目录
创建项目
导入依赖
加载数据
特征选择
学习算法
对新数据分类
评估与预测误差度量
混淆矩阵
通过模型的预测结果生成 ROC 曲线数据
选择分类算法
完整代码
结论
创建项目
首先创建spring boot项目,我这里用的JDK8,springboot2.7.6,在根目录下创建lib文件夹,resources目录下创建data文件夹,将weka.jar放入lib文件夹,将数据文件放到data里面,如图:
导入依赖
在pom中导入依赖
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.8.9</version>
</dependency>
<dependency>
<groupId>weka</groupId>
<artifactId>weka</artifactId>
<version>3.6.2</version>
<scope>system</scope>
<systemPath>${project.basedir}/lib/weka.jar</systemPath>
</dependency>
同时在build标签下添加如下内容,防止加载不到数据资源:
<resources>
<resource>
<directory>src/main/resources</directory>
<filtering>true</filtering>
<includes>
<include>**/*.*</include>
</includes>
</resource>
<resource>
<directory>src/main/resources/data</directory>
<filtering>true</filtering>
<includes>
<include>**/*.*</include>
</includes>
</resource>
</resources>
大概在这个位置:
注意这里也修改一下:
在main函数前:
加载数据
//加载数据
//DataSource对象。它可以接受各种文件格式,并将其转换成Instances。
ConverterUtils.DataSource ds = new ConverterUtils.DataSource(PATH);
Instances data = ds.getDataSet();
我们的任务是学习创建模型,以便预测新样本的animal属性。对于这些新样本,我们只知 道它们的一些其他属性,但不知道它们的animal属性。因此,从训练集中移除animal属性。为 此,只要使用Remove过滤器将animal属性过滤即可。
首先,设置一个参数的字符串表,指定必须移走第一个属性。其余属性用作数据集,用来训 练分类器。
最后,调用Filter.useFilter(Instances, Filter)静态方法,将过滤器应用于所选数 据集。
// 移除 animal 属性
Remove remove = new Remove();
remove.setOptions(new String[]{"-R", "1"});
remove.setInputFormat(data);
data = Filter.useFilter(data, remove);
特征选择
它的目 标是选择相关属性的一个子集,用在学习模型中。为什么特征选择如此重要呢?因为一个更小的 属性集可以简化模型,并且让它们更容易被用户解释,这样通常会减少所需的训练时间,也会降 低过拟合的风险。
做属性选择时,可以考虑类值,也可以不考虑。第一种情况下,可以使用属性选择算法评估 特征的不同子集,并且计算出一个分数,表明所选属性的品质。我们可以使用不同的搜索算法(比 如穷举搜索、最佳优先搜索)与不同的品质分数(比如信息增益、基尼指数等)。
Weka提供了一个AttributeSelection对象进行属性选择,它需要两个额外的参数:评价 器(evaluator),用于计算属性的有用程度;排行器(ranker),用于根据评价者给出的分数对属 性进行分类排序。
//将把信息增益用作评价器,通过它们的信息增益分数对特征进行分类排序。
InfoGainAttributeEval infoGainAttributeEval = new InfoGainAttributeEval();
Ranker ranker = new Ranker();
//对AttributeSelection对象进行初始化,设置评价器、排行器与数据。
AttributeSelection attributeSelection = new AttributeSelection();
attributeSelection.setEvaluator(infoGainAttributeEval);
attributeSelection.setSearch(ranker);
attributeSelection.SelectAttributes(data);
//将属性索引数组转换成字符串并打印
int[] selectedAttributes = attributeSelection.selectedAttributes();
System.out.println(Utils.arrayToString(selectedAttributes));
打印结果如下:12,3,7,2,0,1,8,9,13,4,11,5,15,10,6,14,16
最有价值的属性是12(fins)、3(eggs)、7(aquatic)、2(hair)等。基于这个结果,按次序 剔除无用特征,让学习算法生成更准确、更快的学习模型。 那么,“要保留多少个属性”最后是由什么决定的呢?对于确切的数目,没有什么现成的经 验可以借鉴,究竟保留多少属性取决于具体的数据与问题。属性选择的目的是选择那些可以让你 的模型变得更好的属性,所以应该把重点放在考察“属性是否有助于进一步改进模型”上。
学习算法
加载数据并选好最佳特征后,接下来学习一些分类模型。先从最基本的决策树开始。 Weka中,决策树在J48类中实现,它重新实现了著名的Quinlan's C4.5决策树学习器(Quinlan,
1993)。 首先,初始化一个新的J48决策树学习器。可以使用一个字符串表传递额外的参数,比如剪 枝(tree pruning),它用来控制模型的复杂度(请参考第1章)。例子中,我们将创建一棵未剪枝 树(un-pruned tree),为此传递一个U参数。
//构建决策树
String[] options = new String[1];
options[0] = "-U";
J48 j48 = new J48();
j48.setOptions(options);
//调用buildClassifier(Instances)方法,对学习过程进行初始化。
j48.buildClassifier(data);
System.out.println(j48);
打印结果如下:
J48 unpruned tree
------------------
feathers = false
| milk = false
| | backbone = false
| | | airborne = false
| | | | predator = false
| | | | | legs <= 2: invertebrate (2.0)
| | | | | legs > 2: insect (2.0)
| | | | predator = true: invertebrate (8.0)
| | | airborne = true: insect (6.0)
| | backbone = true
| | | fins = false
| | | | tail = false: amphibian (3.0)
| | | | tail = true: reptile (6.0/1.0)
| | | fins = true: fish (13.0)
| milk = true: mammal (41.0)
feathers = true: bird (20.0)
Number of Leaves : 9
Size of the tree : 17
从输出结果可以看到,未剪枝树总共有17个节点,其中9个是叶子(Leaves)。 另一种呈现树的方法是利用内建的TreeVisualizer树浏览器
//构建决策树浏览
TreeVisualizer tv = new TreeVisualizer(null, j48.graph(), new PlaceNode2());
JFrame frame = new JFrame("Decision Tree Visualization");
frame.setSize(800, 500);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
frame.getContentPane().add(tv);
frame.setVisible(true);
tv.fitToScreen();
决策过程从顶部节点(也叫根节点)开始。节点标签指定要检查的属性值。示例中,先检查 feathers属性的值。如果feathers属性值存在,进入右手分支,到达标记为bird的叶子,表示 有20个样本支持这个输出结果。如果feathers属性值不存在,进入左手分支,到达下一个属性——milk,继续检查milk属性的值,然后进入与属性值相匹配的分支。不断重复这个过程, 直到到达叶节点。
对新数据分类
假如我们有一个动物的属性,但是不知道它的标签,此时可以使用已经学习的分类模型进行 预测。
首先,构造一个特征向量描述新样本,然后,调用模型的classify(Instance)方法获取类值,并返回标签索引:
//区分一个新的实例
double[] vals = new double[data.numAttributes()];
vals[0] = 1.0; // hair {false, true}
vals[1] = 0.0; // feathers {false, true}
vals[2] = 0.0; // eggs {false, true}
vals[3] = 1.0; // milk {false, true}
vals[4] = 0.0; // airborne {false, true}
vals[5] = 0.0; // aquatic {false, true}
vals[6] = 0.0; // predator {false, true}
vals[7] = 1.0; // toothed {false, true}
vals[8] = 1.0; // backbone {false, true}
vals[9] = 1.0; // breathes {false, true}
vals[10] = 1.0; // venomous {false, true}
vals[11] = 0.0; // fins {false, true}
vals[12] = 4.0; // legs INTEGER [0,9]
vals[13] = 1.0; // tail {false, true}
vals[14] = 1.0; // domestic {false, true}
vals[15] = 0.0; // catsize {false, true}
Instance newInst = new DenseInstance(1.0, vals);
newInst.setDataset(data);
double label = j48.classifyInstance(newInst);
System.out.println(data.classAttribute().value((int)label));
最后输出的结果是mammal类标签。
评估与预测误差度量
虽然创建了模型,但还不知道它是否值得我们信任。为了评估其性能,需要用到交叉验证技术。 Weka提供Evaluation类,帮助我们实现交叉验证。使用时,需要提供模型、数据、折数 (number of folds)以及一个初始的随机种子,评估结果存储在Evaluation对象中。
//评估
Classifier classifier = new J48();
Evaluation evaluation = new Evaluation(data);
evaluation.crossValidateModel(classifier, data, 10, new Random(1), new Object[]{});
System.out.println(evaluation.toSummaryString());
结果如下:
Correctly Classified Instances 93 92.0792 %
Incorrectly Classified Instances 8 7.9208 %
Kappa statistic 0.8955
Mean absolute error 0.0225
Root mean squared error 0.14
Relative absolute error 10.2478 %
Root relative squared error 42.4398 %
Coverage of cases (0.95 level) 96.0396 %
Mean rel. region size (0.95 level) 15.4173 %
Total Number of Instances 101
混淆矩阵
// 混淆矩阵
double[][] confusionMatrix = evaluation.confusionMatrix();
System.out.println(evaluation.toMatrixString());
结果如下:
=== Confusion Matrix ===
a b c d e f g <-- classified as
41 0 0 0 0 0 0 | a = mammal
0 20 0 0 0 0 0 | b = bird
0 0 3 1 0 1 0 | c = reptile
0 0 0 13 0 0 0 | d = fish
0 0 1 0 3 0 0 | e = amphibian
0 0 0 0 0 5 3 | f = insect
0 0 0 0 0 2 8 | g = invertebrate
第一行中,第一列的名称对应于分类模型指派的标签。然后,每一个附加行对应于一个实际 为真的类值。比如,第二行对应于那些实际带有mammal类标签的实例。在列中,读取所有被正 确分类为mammals的哺乳动物。第四行“爬行动物”中,可以看到有3个样本被正确分类为 reptile,一个被分类为fish,一个被分类为insect。由此可见,混淆矩阵可以让我们进一步 了解分类模型所犯错误的具体类型。
通过模型的预测结果生成 ROC 曲线数据
// 绘制 ROC 曲线
ThresholdCurve tc = new ThresholdCurve(); // 创建 ThresholdCurve 对象,用于生成 ROC 曲线数据
int classIndex = 0; // 选择要分析的目标类,通常为正类的索引
Instances result = tc.getCurve(evaluation.predictions(), classIndex); // 通过模型的预测结果生成 ROC 曲线数据
// 绘制曲线
ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); // 创建 ThresholdVisualizePanel 对象,用于可视化 ROC 曲线
vmc.setROCString("(ROC 曲线下面积 = " + tc.getROCArea(result) + ")"); // 设置 ROC 曲线的标签,显示 ROC 曲线下面积
vmc.setName(result.relationName()); // 设置曲线的名称,通常为数据集的名称
PlotData2D temp = new PlotData2D(result); // 将生成的 ROC 曲线数据转换为可绘制的数据格式
temp.setPlotName(result.relationName()); // 设置绘图的名称
temp.addInstanceNumberAttribute(); // 添加实例编号属性,便于后续绘图
// 指定哪些点需要连接
boolean[] cp = new boolean[result.numInstances()]; // 创建一个布尔数组,用于指定哪些点需要连接
for (int n = 1; n < cp.length; n++) {
cp[n] = true; // 将所有点标记为需要连接
}
temp.setConnectPoints(cp); // 设置连接点的配置
// 添加绘图
vmc.addPlot(temp); // 将生成的曲线数据添加到可视化面板中
// 显示曲线
JFrame frameRoc = new javax.swing.JFrame("ROC Curve"); // 创建一个 JFrame 窗口,标题为 "ROC Curve"
frameRoc.setSize(800, 500); // 设置窗口大小为 800x500
frameRoc.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); // 设置窗口关闭操作,点击关闭按钮时退出程序
frameRoc.getContentPane().add(vmc); // 将可视化面板添加到窗口中
frameRoc.setVisible(true); // 显示窗口
选择分类算法
机器学习中,朴素贝叶斯是一种最简单、最有效果且最有效率的归纳算法。特征独立(现实 世界中很少有这种情况)时,从理论上来说,这是最好的。即使带有从属特征,它的性能也是非 常好的(Zhang,2004)。主要缺点是它不能学习特征之间如何进行相互作用,比如,尽管你喜欢 往茶里放些柠檬或牛奶,可你讨厌同时放入二者。
决策树的主要优点在于,模型是一棵树,它在样本的学习过程中很容易做解释与说明。决策 树既可以处理名义特征,也可以处理数值特征,并且不用担心数据是否是线性可分的。
完整代码
private static String PATH = ClassUtils.getDefaultClassLoader().getResource("zoo.arff").getPath();
public static void main(String[] args) throws Exception {
//加载数据
//DataSource对象。它可以接受各种文件格式,并将其转换成Instances。
ConverterUtils.DataSource ds = new ConverterUtils.DataSource(PATH);
Instances data = ds.getDataSet();
// 移除 animal 属性
Remove remove = new Remove();
remove.setOptions(new String[]{"-R", "1"});
remove.setInputFormat(data);
data = Filter.useFilter(data, remove);
//将把信息增益用作评价器,通过它们的信息增益分数对特征进行分类排序。
InfoGainAttributeEval infoGainAttributeEval = new InfoGainAttributeEval();
Ranker ranker = new Ranker();
//对AttributeSelection对象进行初始化,设置评价器、排行器与数据。
AttributeSelection attributeSelection = new AttributeSelection();
attributeSelection.setEvaluator(infoGainAttributeEval);
attributeSelection.setSearch(ranker);
attributeSelection.SelectAttributes(data);
//将属性索引数组转换成字符串并打印
int[] selectedAttributes = attributeSelection.selectedAttributes();
System.out.println(Utils.arrayToString(selectedAttributes));
//构建决策树
String[] options = new String[1];
options[0] = "-U";
J48 j48 = new J48();
j48.setOptions(options);
//调用buildClassifier(Instances)方法,对学习过程进行初始化。
j48.buildClassifier(data);
System.out.println(j48);
//区分一个新的实例
double[] vals = new double[data.numAttributes()];
vals[0] = 1.0; // hair {false, true}
vals[1] = 0.0; // feathers {false, true}
vals[2] = 0.0; // eggs {false, true}
vals[3] = 1.0; // milk {false, true}
vals[4] = 0.0; // airborne {false, true}
vals[5] = 0.0; // aquatic {false, true}
vals[6] = 0.0; // predator {false, true}
vals[7] = 1.0; // toothed {false, true}
vals[8] = 1.0; // backbone {false, true}
vals[9] = 1.0; // breathes {false, true}
vals[10] = 1.0; // venomous {false, true}
vals[11] = 0.0; // fins {false, true}
vals[12] = 4.0; // legs INTEGER [0,9]
vals[13] = 1.0; // tail {false, true}
vals[14] = 1.0; // domestic {false, true}
vals[15] = 0.0; // catsize {false, true}
Instance newInst = new DenseInstance(1.0, vals);
newInst.setDataset(data);
double label = j48.classifyInstance(newInst);
System.out.println(data.classAttribute().value((int) label));
//构建决策树浏览
TreeVisualizer tv = new TreeVisualizer(null, j48.graph(), new PlaceNode2());
JFrame frame = new JFrame("Decision Tree Visualization");
frame.setSize(800, 500);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
frame.getContentPane().add(tv);
frame.setVisible(true);
tv.fitToScreen();
//评估
Classifier classifier = new J48();
Evaluation evaluation = new Evaluation(data);
evaluation.crossValidateModel(classifier, data, 10, new Random(1), new Object[]{});
System.out.println(evaluation.toSummaryString());
// 混淆矩阵
double[][] confusionMatrix = evaluation.confusionMatrix();
System.out.println(evaluation.toMatrixString());
// 绘制 ROC 曲线
ThresholdCurve tc = new ThresholdCurve(); // 创建 ThresholdCurve 对象,用于生成 ROC 曲线数据
int classIndex = 0; // 选择要分析的目标类,通常为正类的索引
Instances result = tc.getCurve(evaluation.predictions(), classIndex); // 通过模型的预测结果生成 ROC 曲线数据
// 绘制曲线
ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); // 创建 ThresholdVisualizePanel 对象,用于可视化 ROC 曲线
vmc.setROCString("(ROC 曲线下面积 = " + tc.getROCArea(result) + ")"); // 设置 ROC 曲线的标签,显示 ROC 曲线下面积
vmc.setName(result.relationName()); // 设置曲线的名称,通常为数据集的名称
PlotData2D temp = new PlotData2D(result); // 将生成的 ROC 曲线数据转换为可绘制的数据格式
temp.setPlotName(result.relationName()); // 设置绘图的名称
temp.addInstanceNumberAttribute(); // 添加实例编号属性,便于后续绘图
// 指定哪些点需要连接
boolean[] cp = new boolean[result.numInstances()]; // 创建一个布尔数组,用于指定哪些点需要连接
for (int n = 1; n < cp.length; n++) {
cp[n] = true; // 将所有点标记为需要连接
}
temp.setConnectPoints(cp); // 设置连接点的配置
// 添加绘图
vmc.addPlot(temp); // 将生成的曲线数据添加到可视化面板中
// 显示曲线
JFrame frameRoc = new javax.swing.JFrame("ROC Curve"); // 创建一个 JFrame 窗口,标题为 "ROC Curve"
frameRoc.setSize(800, 500); // 设置窗口大小为 800x500
frameRoc.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); // 设置窗口关闭操作,点击关闭按钮时退出程序
frameRoc.getContentPane().add(vmc); // 将可视化面板添加到窗口中
frameRoc.setVisible(true); // 显示窗口
}
结论
总体表现:模型在分类任务上的表现较好,正确分类率高达 92.0792%。
分类效果:在哺乳动物、鸟类、鱼类这三类上,模型的分类效果非常出色,全部正确分类。在爬行动物、两栖动物、昆虫、无脊椎动物这四类上,虽然也表现不错,但存在一些错误分类。
改进方向:可以通过收集更多数据或调整特征选择来进一步提高模型在爬行动物、两栖动物、昆虫、无脊椎动物这四类上的分类效果。