Machine learning with Weka: calling Random Forest from the Java API and saving the model
I needed this for work, so I spent some time with Weka's Java API, mainly the Random Forest part. I am just starting out, so these are my notes.
I don't know the library in depth yet, so here is the demo directly, with explanatory comments inside:
<span class="hljs-keyword">package</span> weka; <span class="hljs-keyword">import</span> java.io.File; <span class="hljs-keyword">import</span> weka.classifiers.Classifier; <span class="hljs-keyword">import</span> weka.classifiers.trees.RandomForest; <span class="hljs-keyword">import</span> weka.core.Instances; <span class="hljs-keyword">import</span> weka.core.SerializationHelper; <span class="hljs-keyword">import</span> weka.core.converters.ArffLoader; <span class="hljs-keyword">public</span> <span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">demo</span> </span>{ <span class="hljs-function"><span class="hljs-keyword">public</span> <span class="hljs-keyword">static</span> <span class="hljs-keyword">void</span> <span class="hljs-title">main</span><span class="hljs-params">(String[] args)</span> <span class="hljs-keyword">throws</span> Exception </span>{ Classifier m_classifier = <span class="hljs-keyword">new</span> RandomForest(); File inputFile = <span class="hljs-keyword">new</span> File(<span class="hljs-string">"F:/java/weka/trainData.arff"</span>);<span class="hljs-comment">//训练语料文件 </span> ArffLoader atf = <span class="hljs-keyword">new</span> ArffLoader(); atf.setFile(inputFile); Instances instancesTrain = atf.getDataSet(); <span class="hljs-comment">// 读入训练文件 </span> inputFile = <span class="hljs-keyword">new</span> File(<span class="hljs-string">"F:/java/weka/testData.arff"</span>);<span class="hljs-comment">//测试语料文件 </span> atf.setFile(inputFile); Instances instancesTest = atf.getDataSet(); <span class="hljs-comment">// 读入测试文件 </span> instancesTest.setClassIndex(<span class="hljs-number">0</span>); <span class="hljs-comment">//设置分类属性所在行号(第一行为0号),instancesTest.numAttributes()可以取得属性总数 </span> <span class="hljs-keyword">double</span> sum = instancesTest.numInstances(),<span class="hljs-comment">//测试语料实例数 </span> right = <span class="hljs-number">0.0f</span>; instancesTrain.setClassIndex(<span class="hljs-number">0</span>); m_classifier.buildClassifier(instancesTrain); <span class="hljs-comment">//训练</span> System.out.println(m_classifier); <span class="hljs-comment">// 保存模型</span> SerializationHelper.write(<span class="hljs-string">"LibSVM.model"</span>, m_classifier);<span class="hljs-comment">//参数一为模型保存文件,classifier4为要保存的模型</span> <span class="hljs-keyword">for</span>(<span class="hljs-keyword">int</span> i = <span class="hljs-number">0</span>;i<sum;i++)<span class="hljs-comment">//测试分类结果 1</span> { <span class="hljs-keyword">if</span>(m_classifier.classifyInstance(instancesTest.instance(i))==instancesTest.instance(i).classValue())<span class="hljs-comment">//如果预测值和答案值相等(测试语料中的分类列提供的须为正确答案,结果才有意义) </span> { right++;<span class="hljs-comment">//正确值加1 </span> } } <span class="hljs-comment">// 获取上面保存的模型</span> Classifier classifier8 = (Classifier) weka.core.SerializationHelper.read(<span class="hljs-string">"LibSVM.model"</span>); <span class="hljs-keyword">double</span> right2 = <span class="hljs-number">0.0f</span>; <span class="hljs-keyword">for</span>(<span class="hljs-keyword">int</span> i = <span class="hljs-number">0</span>;i<sum;i++)<span class="hljs-comment">//测试分类结果 2 (通过)</span> { <span class="hljs-keyword">if</span>(classifier8.classifyInstance(instancesTest.instance(i))==instancesTest.instance(i).classValue())<span class="hljs-comment">//如果预测值和答案值相等(测试语料中的分类列提供的须为正确答案,结果才有意义) </span> { right2++;<span class="hljs-comment">//正确值加1 </span> } } System.out.println(right); System.out.println(right2); System.out.println(sum); 
System.out.println(<span class="hljs-string">"RandomForest classification precision:"</span>+(right/sum)); } } |
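For reference, both trainData.arff and testData.arff are assumed to be ordinary ARFF files in which the class attribute is declared first, matching the setClassIndex(0) calls above. The relation name, attribute names, and values below are made up purely for illustration:

% hypothetical layout of F:/java/weka/trainData.arff
@relation demo
% the class attribute is declared first, hence setClassIndex(0)
@attribute class {yes, no}
@attribute f1 numeric
@attribute f2 numeric
@data
yes,1.2,3.4
no,0.7,2.1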
The demo covers the basic use of Random Forest: training, saving the model, and checking the classification results. In practice this gets split into two parts: training runs on the backend, and the client uses the already-trained model to compute results. There is also the case of real-time (online) training, which I will look into later.
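A minimal sketch of the client side under that split, assuming the LibSVM.model file produced above and a hypothetical newData.arff whose attributes match the training data (the class name PredictDemo and the file path are placeholders):

package weka;

import java.io.File;

import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ArffLoader;

public class PredictDemo {
    public static void main(String[] args) throws Exception {
        // load the model trained and saved on the backend
        Classifier model = (Classifier) SerializationHelper.read("LibSVM.model");

        // load the new data to score; its attributes must match the training ARFF
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File("F:/java/weka/newData.arff")); // hypothetical file
        Instances newData = loader.getDataSet();
        newData.setClassIndex(0); // same class index as during training

        for (int i = 0; i < newData.numInstances(); i++) {
            double pred = model.classifyInstance(newData.instance(i));
            // map the numeric prediction back to the class label
            System.out.println("instance " + i + " -> " + newData.classAttribute().value((int) pred));
        }
    }
}

The client only needs the serialized model file plus data whose attribute structure is the same as the training set's; no retraining happens on this side.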
Random Forest probably also has many configuration parameters that need tuning; I will learn those gradually. A rough sketch of setting some common options follows below.
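The flags used here (-I for the number of trees, -K for the number of features considered per split, -depth for the maximum tree depth, -S for the random seed) are common RandomForest options, but exact option names and defaults differ between Weka releases, so check the version in use; the class name TuneDemo and the values are only placeholders:

package weka;

import java.io.File;

import weka.classifiers.trees.RandomForest;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.converters.ArffLoader;

public class TuneDemo {
    public static void main(String[] args) throws Exception {
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File("F:/java/weka/trainData.arff"));
        Instances train = loader.getDataSet();
        train.setClassIndex(0);

        RandomForest rf = new RandomForest();
        // -I: number of trees, -K: features per split (0 = use the default),
        // -depth: maximum tree depth (0 = unlimited), -S: random seed
        rf.setOptions(Utils.splitOptions("-I 200 -K 0 -depth 0 -S 1"));

        rf.buildClassifier(train);
        System.out.println(rf);
    }
}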