工作需要,了解了一下weka的java api,主要是随机森林这一块,刚开始学习,记录下。
了解不多,直接上demo,里面有一些注释说明:
package weka;
import java.io.File;
import weka.classifiers.Classifier;
import weka.classifiers.trees.RandomForest;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ArffLoader;
public class demo {
public static void main(String[] args) throws Exception {
Classifier m_classifier = new RandomForest();
File inputFile = new File("F:/java/weka/trainData.arff");//训练语料文件
ArffLoader atf = new ArffLoader();
atf.setFile(inputFile);
Instances instancesTrain = atf.getDataSet(); // 读入训练文件
inputFile = new File("F:/java/weka/testData.arff");//测试语料文件
atf.setFile(inputFile);
Instances instancesTest = atf.getDataSet(); // 读入测试文件
instancesTest.setClassIndex(0); //设置分类属性所在行号(第一行为0号),instancesTest.numAttributes()可以取得属性总数
double sum = instancesTest.numInstances(),//测试语料实例数
right = 0.0f;
instancesTrain.setClassIndex(0);
m_classifier.buildClassifier(instancesTrain); //训练
System.out.println(m_classifier);
// 保存模型
SerializationHelper.write("LibSVM.model", m_classifier);//参数一为模型保存文件,classifier4为要保存的模型
for(int i = 0;i<sum;i++)//测试分类结果 1
{
if(m_classifier.classifyInstance(instancesTest.instance(i))==instancesTest.instance(i).classValue())//如果预测值和答案值相等(测试语料中的分类列提供的须为正确答案,结果才有意义)
{
right++;//正确值加1
}
}
// 获取上面保存的模型
Classifier classifier8 = (Classifier) weka.core.SerializationHelper.read("LibSVM.model");
double right2 = 0.0f;
for(int i = 0;i<sum;i++)//测试分类结果 2 (通过)
{
if(classifier8.classifyInstance(instancesTest.instance(i))==instancesTest.instance(i).classValue())//如果预测值和答案值相等(测试语料中的分类列提供的须为正确答案,结果才有意义)
{
right2++;//正确值加1
}
}
System.out.println(right);
System.out.println(right2);
System.out.println(sum);
System.out.println("RandomForest classification precision:"+(right/sum));
}
}
其中包含了随机森林的使用,包括训练、模型保存及算法结果。实际使用时拆分为两部分,在后台进行训练,客户端使用训练后的模型计算结果。当然还有实时训练的情况,这个后面再去了解。
随机森林可能还有很多配置参数需要调整,后续慢慢去学习。
扫码关注微信公众号--IT老五
微信扫一扫关注公众号,获取更多实用app,订阅地址不定时更新