日撸java三百行day63-65

news2024/11/26 5:44:13

文章目录

  • 说明
  • 1. Day63-65 AdaBoosting算法
    • 1 AdaBoostin举例
      • 1.1数据样本
      • 1.2 举例过程
    • 2. 理论知识
    • 3. 总结
  • 2. 代码理解
    • 1. WeightedInstances类
    • 2. 选择基分类器并进行训练(树桩分类器)
    • 3. 计算误差率和误差系数(树桩分类器)
    • 4. 计算精确度
    • 5. 总结
  • 3. java知识
    • 1. 抽象类
    • 2. static关键字
    • 3. 静态代码块

说明

闵老师的文章链接: 日撸 Java 三百行(总述)_minfanphd的博客-CSDN博客
自己也把手敲的代码放在了github上维护:https://github.com/fulisha-ok/sampledata

1. Day63-65 AdaBoosting算法

1 AdaBoostin举例

AdaBoost算法是一种集成学习算法,是Boosting算法中的一种,通过组合多个弱分类器来构建一个强分类器。因为我也是第一次接触这个算法,直接去看一些算法以及公式,感觉还是有点吃力,所以我结合网上看的例子,自以及己也模拟一个例子先手动过一遍这个算法过程,然后再去学习他的理论知识。(其中的计算结果我是通过文章的代码计算所得的),如果遇到一些概念不懂,我们先假装自己懂,看完例子再去看理论,于我而言,蛮有用的。(若有问题,欢迎指正~)

1.1数据样本

我把原来iris.aff文件的150个缩减为12个数据,数据如下:一共有4种特征,3个类别,12个数据集。

@RELATION iris

@ATTRIBUTE sepallength	REAL
@ATTRIBUTE sepalwidth 	REAL
@ATTRIBUTE petallength 	REAL
@ATTRIBUTE petalwidth	REAL
@ATTRIBUTE class 	{Iris-setosa,Iris-versicolor,Iris-virginica}

@DATA
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica

1.2 举例过程

先初始化这12个数据集的权重: 1 N \frac{1}{N} N1

索引01234567891011
权重0.0830.0830.0830.0830.0830.0830.0830.0830.0830.0830.0830.083

由于我们有三个类别,对于属性的选择我们都是随机选择属性来进行训练。我们的基学习器采用的是树桩分类器。(树桩分类器是决策树的一种特殊形式,它只包含一个根结点和两个叶子结点,相当于把数据按一个阈值一分为二,基类学习器还有我们之前学的KNN,贝叶斯算法等。)
前提:我们在每次学习的时候,基类学习器G(x)的值我们用+1来表示与我们预期相符合(正类别);-1来表示与我们预期不符合(负类别)

  • 第一个弱学习器(基类学习器)
    随机选择的属性为sepallength
索引01234567891011
属性5.14.94.74.67.06.46.95.56.35.87.16.3
权重w10.0830.0830.0830.0830.0830.0830.0830.0830.0830.0830.0830.083
正确结果 y 1 ( x ) y_{1}(x) y1(x)000011112222
预测结果 G 1 ( x ) G_{1}(x) G1(x)000011111111

我们的最佳分割点bestCut=5.3; leftLeafLabel = 0 ; rightLeafLabel = 1;
误差率 e 1 = 0.083 + 0.083 + 0.083 + 0.083 = 0.3333 e1=0.083 + 0.083 + 0.083 + 0.083 = 0.3333 e1=0.083+0.083+0.083+0.083=0.3333
误差系数: α 1 = 1 2 log ⁡ 1 − e 1 e 1 = 0.3465 \alpha _{1}=\frac{1}{2}\log \frac{1-e_{1}}{e_{1}}=0.3465 α1=21loge11e1=0.3465
训练数据的准确率:0.6667
弱学习器 G 1 ( x ) G_{1}(x) G1(x) G 1 ( x ) = { 0 , x < 5.3 1 , x > 5.3 G_{1}(x)= \begin{cases} 0, & \text {x < 5.3} \\ 1, & \text{x > 5.3} \end{cases} G1(x)={0,1,x < 5.3x > 5.3

  • 第二个弱学习器(基类学习器)
    随机选择的属性为petallength
索引01234567891011
属性1.41.41.31.54.74.54.94.06.05.15.95.6
权重w20.0620.0620.0620.0620.0620.0620.0620.0620.1250.1250.1250.125
正确结果 y 2 ( x ) y_{2}(x) y2(x)000011112222
预测结果 G 2 ( x ) G_{2}(x) G2(x)000022222222

我们的阈值取值为: 2.75; leftLeafLabel = 0 ; rightLeafLabel = 2;
误差权重之和 e 2 = 0.062 + 0.062 + 0.062 + 0.062 = 0.249 e2=0.062 + 0.062 + 0.062 + 0.062 = 0.249 e2=0.062+0.062+0.062+0.062=0.249
误差系数: α 2 = 1 2 log ⁡ 1 − e 2 e 2 = 0.5493 \alpha _{2}=\frac{1}{2}\log \frac{1-e_{2}}{e_{2}}=0.5493 α2=21loge21e2=0.5493
训练数据的准确率:0.6666
弱学习器 G 2 ( x ) G_{2}(x) G2(x) G 2 ( x ) = { 0 , x < 2.75 2 , x >2.75 G_{2}(x)= \begin{cases} 0, & \text {x < 2.75} \\ 2, & \text{x >2.75} \end{cases} G2(x)={0,2,x < 2.75x >2.75

  • 第三个弱学习器(基类学习器)
    随机选择的属性为petalwidth
索引01234567891011
属性0.20.20.20.21.41.51.51.32.51.92.11.8
权重w30.0410.0410.0410.0410.1250.1250.1250.1250.0830.0830.0830.083
正确结果 y 3 ( x ) y_{3}(x) y3(x)000011112222
预测结果 G 3 ( x ) G_{3}(x) G3(x)111111112222

我们的阈值取值为:1.65; leftLeafLabel = 1 ; rightLeafLabel = 2;
误差权重之和 e 3 = 0.041 + 0.041 + 0.041 + 0.041 = 0.166 e3=0.041 + 0.041 + 0.041 + 0.041 = 0.166 e3=0.041+0.041+0.041+0.041=0.166
误差系数: α 3 = 1 2 log ⁡ 1 − e 3 e 3 = 0.8047 \alpha _{3}=\frac{1}{2}\log \frac{1-e_{3}}{e_{3}}=0.8047 α3=21loge31e3=0.8047
训练数据的准确率:1.0
弱学习器 G 3 ( x ) G_{3}(x) G3(x) G 3 ( x ) = { 1 , x < 1.65 2 , x > 1.65 G_{3}(x)= \begin{cases} 1, & \text {x < 1.65} \\ 2, & \text{x > 1.65} \end{cases} G3(x)={1,2,x < 1.65x > 1.65

我针对第三次的学习来计算他们的准确率:
假设给出的实例:5.1,3.5,1.4,0.2,Iris-setosa
他在第一次基类学习器中预测类别为0: 则第一个类别累加误差系数 0.3465
他在第二次基类学习器中预测类别为0: 则第一个类别累加0.5493后为0.8958
他在第一次基类学习器中预测类别为1: 则第二个类别为:0.8047
故第一个实例的预测类别概率为:[0.8958797346140277, 0.8047189562170501, 0.0]可知预测的为类别0

其他11个实例依次预测的类别为
[0.8958797346140277, 0.8047189562170501, 0.0]
[0.8958797346140277, 0.8047189562170501, 0.0]
[0.8958797346140277, 0.8047189562170501, 0.0]
[0.0, 1.1512925464970227, 0.549306144334055]
[0.0, 1.1512925464970227, 0.549306144334055]
[0.0, 1.1512925464970227, 0.549306144334055]
[0.0, 1.1512925464970227, 0.549306144334055]
[0.0, 0.34657359027997264, 1.3540251005511053]
[0.0, 0.34657359027997264, 1.3540251005511053]
[0.0, 0.34657359027997264, 1.3540251005511053]
[0.0, 0.34657359027997264, 1.3540251005511053]

结合上面的例子,就可以来学习理论知识了:

2. 理论知识

算法步骤(我参考了网上的帖子,这里梳理出来是为了方便以后回顾)

  • 1. 假设训练集样本是这样的:
    T = ( x 1 , y 1 ) , ( x 2 , y 2 ) . . . . . ( x m , y m ) T={({x_{1}, y_{1}}), ({x_{2}, y_{2}})..... ({x_{m}, y_{m}})} T=(x1,y1),(x2,y2).....(xm,ym)

y i {y_{i}} yi 的取值如果在一个二分类问题中,可以取值+1和-1,在多类别问题中,要根据情况而定。如我们在上面的例子中,表示样本 x i x_{i} xi所属的类别

  • 2. 第m个弱学习器的各个数据的权重是
    D k = ( w m 1 , w m 2 , w m 3 . . . . w m i ) D_{k}=(w_{m1},w_{m2},w_{m3}....w_{mi}) Dk=(wm1,wm2,wm3....wmi)
    其中初始化的权重为如下:N为样本数量
    D 1 = ( w 11 , w 12 , w 13 . . . . ) , w 1 i = 1 N , i = 1 , 2 , 3 , 4... N D_{1}=(w_{11},w_{12},w_{13}....), w_{1i}=\frac{1}{N}, i=1,2,3,4...N D1=(w11,w12,w13....),w1i=N1,i=1,2,3,4...N
  • 3. 第m个学习器 G m ( x ) G_{m}(x) Gm(x)

G m ( x ) G_{m}(x) Gm(x)其取值只能有两个(正类别+1和负类别-1),而在多类别问题中,也依情况而定,我在上面例子中去的是具体的类别。

  • 4. G m ( x ) G_{m}(x) Gm(x)在训练数据集上的误差率(即失败样本数*样本权重之和):
    e m = P ( G m ( x i ) ≠ y i ) = ∑ i = 1 m w m i I ( G m ( x i ) ≠ y i ) e_{m}=P(G_{m}(x_{i})\neq y_{i})=\sum_{i=1}^{m}w_{mi}I(G_{m}(x_{i})\neq y_{i}) em=P(Gm(xi)=yi)=i=1mwmiI(Gm(xi)=yi)
    其中 I ( G m ( x i ) ≠ y i ) { I(G_m(x_i) ≠ y_i) } I(Gm(xi)=yi)的值为1,即预测的结果与实际结果不符合

  • 5.计算 G m ( x ) G_{m}(x) Gm(x)的权重系数:
    (由公式可以看出,当误差率越小的弱分类器,在最后的强分类器中贡献越大)
    α m = 1 2 log ⁡ 1 − e m e m \alpha _{m}=\frac{1}{2}\log \frac{1-e_{m}}{e_{m}} αm=21logem1em
    权重系数在最终分类器中会发挥作用,当值越大对最终的影响就较大;同时权重系数还会影响下一次权重更新。

  • 6.更新权重的公式:
    已知: D m = ( w m 1 , w m 2 , w m 3 . . . . w m i ) D_{m}=(w_{m1},w_{m2},w_{m3}....w_{mi}) Dm=(wm1,wm2,wm3....wmi)的权重,现在计算第m+1个样本集的权重

w m + 1 , i = w m i Z m e x p ( − α m y i G m ( x i ) ) w_{m+1,i}=\frac{w_{mi}}{Z_{m}}exp(-\alpha _{m}y_{i}G_{m}(x_{i})) wm+1,i=Zmwmiexp(αmyiGm(xi))
其中我们可以简化公式:
w m + 1 , i = { w m i Z m 1 e α m , G m ( x i ) = y i → ( G m ( x i ) ∗ y i = 1 ) w m i Z m e α m , G m ( x i ) ≠ y i → ( G m ( x i ) ∗ y i = − 1 ) w_{m+1,i} = \begin{cases} \frac{w_{mi}}{Z_{m}}\frac{1}{e^{\alpha_{m}}}, & G_{m}(x_{i})=y_{i} \rightarrow (G_{m}(x_{i})*y_{i}=1)\\ \frac{w_{mi}}{Z_{m}}e^{\alpha_{m}}, & G_{m}(x_{i})\neq y_{i} \rightarrow (G_{m}(x_{i})*y_{i}=-1)\\ \end{cases} wm+1,i={Zmwmieαm1,Zmwmieαm,Gm(xi)=yi(Gm(xi)yi=1)Gm(xi)=yi(Gm(xi)yi=1)
( G m ( x i ) ∗ y i = 1 ) (G_{m}(x_{i})*y_{i}=1) (Gm(xi)yi=1)的前提是 G m ( x i ) 和 y i G_{m}(x_{i})和y_{i} Gm(xi)yi取值在+1和-1,而实际上我觉得这个公式要描述的意思就可以理解为:如果预测结果一致就为1,预测结果不一致为-1。就如例子上面我们有3个类别,并没有按+1,-1去计算,所以要依情况而定,但核心思想不变!其中 Z k Z_{k} Zk是规范化因子
Z k = ∑ i = 1 m w k i e x p ( − α k y i G k ( x i ) ) Z_{k}=\sum_{i=1}^{m}w_{ki}exp(-\alpha_{k}y_{i}G_{k}(x_{i})) Zk=i=1mwkiexp(αkyiGk(xi))
从上面的是公式,我们可以知道在计算 w m + 1 , i w_{m+1,i} wm+1,i 的值时, α m \alpha _{m} αm的值肯定是大于0的 则 e α m e^{\alpha _{m}} eαm的值一定是大于1,当我们在第m+1次更新权重时,如果在第m次样本被正确分类( ( G m ( x i ) = y i (G_{m}(x_{i})=y_{i} (Gm(xi)=yi),则他在第m+1次时权重就会变小,而若被错误分类( G m ( x i ) ≠ y i G_{m}(x_{i})\neq y_{i} Gm(xi)=yi),则在第m+1次时权重就会变大 (因为在上一次已经被错误分类了,那么我在这一次的分类中我需要更重视错误分类的,所以就把权重要调大!)

3. 总结

结合上面的例子和公式,去看这个图就非常的生动,我们训练多个弱分类器(串行的),每一次弱分类器中数据的权重都是根据上一次的权重,误差概率和误差系数来更新,经过多次训练后我们最终形成一个最终的强分类器,我们预测类别是就通过计算多个弱分类器的一个加权和来预测那个类别概率最大(并行的)
在这里插入图片描述
AdaBoosting算法在每一次的弱分类器的训练中,实际上是一个二分类的问题。但是我们在上面的例子中,他实际上是一个多类别问题,那我们也可以通过这个算法来实现。我们的实现是对每个基本分类器都选择两个类别。在预测阶段,通过对这些基本分类器的预测结果进行投票,选择概率最大的。

2. 代码理解

我从Booster类去理解代码。下面是Booster的main方法,而最核心的东西就是train()方法,其中我们设置了训练的次数是20个。
在这里插入图片描述
下面是train的核心代码

public void train() {
		// Step 1. Initialize.
		WeightedInstances tempWeightedInstances = null;
		double tempError;
		numClassifiers = 0;
		SimpleTools.processTrackingOutput("Booster.train() Step 1\r\n");

		// Step 2. Build other classifiers.
		for (int i = 0; i < classifiers.length; i++) {
			Common.runSteps ++;
			// Step 2.1 Construct or adjust the weightedInstances
			if (i == 0) {
				tempWeightedInstances = new WeightedInstances(trainingData);
			} else {
				// Adjust the weights of the data.
				tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
						classifierWeights[i - 1]);
			}
			SimpleTools.processTrackingOutput("Booster.train() Step 2.1\r\n");

			// Step 2.2 Train the next classifier.
			switch (baseClassifierType) {
			case STUMP_CLASSIFIER:
				classifiers[i] = new StumpClassifier(tempWeightedInstances);
				break;
			case BAYES_CLASSIFIER:
				classifiers[i] = new BayesClassifier(tempWeightedInstances);
				break;
			case Gaussian_CLASSIFIER:
				classifiers[i] = new GaussianClassifier(tempWeightedInstances);
				break;
			default:
				System.out.println(
						"Internal error. Unsupported base classifier type: " + baseClassifierType);
				System.exit(0);
			}
			classifiers[i].train();
			SimpleTools.processTrackingOutput("Booster.train() Step 2.2\r\n");

			// tempAccuracy = classifiers[i].computeTrainingAccuracy();
			//计算加权错误率
			tempError = classifiers[i].computeWeightedError();
			// Set the classifier weight. 弱分类器的权重
			classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
			if (classifierWeights[i] < 1e-6) {
				classifierWeights[i] = 0;
			}
				// SimpleTools.variableTrackingOutput("Booster.train()");

			SimpleTools.variableTrackingOutput("Classifier #" + i + " , weighted error = "
					+ tempError + ", weight = " + classifierWeights[i] + "\r\n");

			numClassifiers++;

			// The accuracy is enough. 记录当前训练轮次(迭代)中集成分类器在训练数据上的准确率
			if (stopAfterConverge) {
				double tempTrainingAccuracy = computeTrainingAccuray();
				SimpleTools.variableTrackingOutput(
						"The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
				if (tempTrainingAccuracy > 0.999999) {
					SimpleTools.processTrackingOutput(
							"Stop at the round: " + i + " due to converge.\r\n");
					break;
				}
			}
		}
	}

其中的for循环则是训练的次数,我就以一次训练过程来学习整个内容。通过Booster类中的train方法为入口,去了解这个方法中调用其他类的一些方法。

1. WeightedInstances类

在WeightedInstances类中最重要的方法就是adjustWeights,调整数据样本的权重。
下面这个是Booster方法中train()中调用WeightedInstances类的方法:
在这里插入图片描述

  • new WeightedInstances(trainingData)
    这里的方法是更新数据集样本中每个数据的权重,其中new WeightedInstances(trainingData)方法是在第一次训练时,初始化权重的值。初始化权重为 1 N \frac{1}{N} N1,N为样本的个数
	public WeightedInstances(Instances paraInstances) {
		super(paraInstances);
		setClassIndex(numAttributes() - 1);

		// Initialize weights
		weights = new double[numInstances()];
		double tempAverage = 1.0 / numInstances();
		for (int i = 0; i < weights.length; i++) {
			Common.runSteps ++;
			weights[i] = tempAverage;
		}
		SimpleTools.variableTrackingOutput("Instances weights are: " + Arrays.toString(weights));
	}
  • adjustWeights方法
    • 入参paraCorrectArray是上一次训练的预测结果(bool值:true or false)
    • 入参paraAlpha 上一次训练的误差系数 α m \alpha _{m} αm
      循环数据样本,更新每个数据的权重值(weights[i]的值):
      若在上一次训练中预测结果为true,则weights[i] = weights[i] * e α m e^{\alpha _{m}} eαm;
      若预测结果为false,则weights[i] = weights[i] * 1 e α m \frac{1}{e^{\alpha _{m}}} eαm1
      最后的循环中进行归一化操作。
public void adjustWeights(boolean[] paraCorrectArray, double paraAlpha) {
		// Step 2. Calculate alpha.
		double tempIncrease = Math.exp(paraAlpha);

		// Step 3. Adjust.
		double tempWeightsSum = 0; // For normalization.
		for (int i = 0; i < weights.length; i++) {
			Common.runSteps ++;
			if (paraCorrectArray[i]) {
				weights[i] /= tempIncrease;
			} else {
				weights[i] *= tempIncrease;
			} // Of if
			tempWeightsSum += weights[i];
		}

		// Step 4. Normalize.
		for (int i = 0; i < weights.length; i++) {
			Common.runSteps ++;
			weights[i] /= tempWeightsSum;
		}

		SimpleTools.variableTrackingOutput(
				"After adjusting, instances weights are: " + Arrays.toString(weights));
	}

2. 选择基分类器并进行训练(树桩分类器)

在这里插入图片描述

这里选择基分类器为树桩分类器(StumpClassifier类),如下是StumpClassifier类的train()方法。
具体的实现步骤是:

  • 因为特征值较多,这里采用随机的方法,随机选择。(java的Random方法)
  • tempValuesArray记录所选择特征值的取值并进行排序。(Arrays.sort方法)
  • 迭代所有的数据,找出最佳的分割点bestCut,以及左右子结点的类别取值(leftLeafLabel,rightLeafLabel)
    每一次训练,遍历所有的特征值,寻找可能取得的分割点进行计算,以获取最佳的分割点(选取分割点的方法有很多,在文章中的分割点是选取的前后两个特征值的平均值,但分割点的选取方式有很多,可以依情况而定) 选择最佳分割点是使划分的两个子集数据误差最小。
@Override
	public void train() {
		// Step 1. Randomly choose an attribute.
		selectedAttribute = Common.random.nextInt(numConditions);

		// Step 2. Find all attribute values and sort.
		double[] tempValuesArray = new double[numInstances];
		for (int i = 0; i < tempValuesArray.length; i++) {
			tempValuesArray[i] = weightedInstances.instance(i).value(selectedAttribute);
		}
		Arrays.sort(tempValuesArray);
		Common.runSteps += (long)(numInstances * Math.log(numInstances) / Math.log(2));

		// Step 3. Initialize, classify all instances as the same with the
		// original cut.
		int tempNumLabels = numClasses;
		double[] tempLabelCountArray = new double[tempNumLabels];
		int tempCurrentLabel;

		// Step 3.1 Scan all labels to obtain their counts.
		for (int i = 0; i < numInstances; i++) {
			Common.runSteps ++;
			// The label of the ith instance
			tempCurrentLabel = (int) weightedInstances.instance(i).classValue();
			tempLabelCountArray[tempCurrentLabel] += weightedInstances.getWeight(i);
		}

		// Step 3.2 Find the label with the maximal count.
		double tempMaxCorrect = 0;
		int tempBestLabel = -1;
		for (int i = 0; i < tempLabelCountArray.length; i++) {
			if (tempMaxCorrect < tempLabelCountArray[i]) {
				tempMaxCorrect = tempLabelCountArray[i];
				tempBestLabel = i;
			}
		}

		// Step 3.3 The cut is a little bit smaller than the minimal value.
		bestCut = tempValuesArray[0] - 0.1;
		leftLeafLabel = tempBestLabel;
		rightLeafLabel = tempBestLabel;

		// Step 4. Check candidate cuts one by one.
		// Step 4.1 To handle multi-class data, left and right.
		double tempCut;
		double[][] tempLabelCountMatrix = new double[2][tempNumLabels];

		for (int i = 0; i < tempValuesArray.length - 1; i++) {
			// Step 4.1 Some attribute values are identical, ignore them.
			if (tempValuesArray[i] == tempValuesArray[i + 1]) {
				continue;
			}
			tempCut = (tempValuesArray[i] + tempValuesArray[i + 1]) / 2;

			// Step 4.2 Scan all labels to obtain their counts wrt. the cut.
			// Initialize again since it is used many times.
			for (int j = 0; j < 2; j++) {
				for (int k = 0; k < tempNumLabels; k++) {
					Common.runSteps ++;
					tempLabelCountMatrix[j][k] = 0;
				}
			}

			for (int j = 0; j < numInstances; j++) {
				Common.runSteps ++;
				// The label of the jth instance
				tempCurrentLabel = (int) weightedInstances.instance(j).classValue();
				if (weightedInstances.instance(j).value(selectedAttribute) < tempCut) {
					tempLabelCountMatrix[0][tempCurrentLabel] += weightedInstances.getWeight(j);
				} else {
					tempLabelCountMatrix[1][tempCurrentLabel] += weightedInstances.getWeight(j);
				}
			}

			// Step 4.3 Left leaf. 记录左叶子结点的数据
			double tempLeftMaxCorrect = 0;
			int tempLeftBestLabel = 0;
			for (int j = 0; j < tempLabelCountMatrix[0].length; j++) {
				Common.runSteps ++;
				if (tempLeftMaxCorrect < tempLabelCountMatrix[0][j]) {
					tempLeftMaxCorrect = tempLabelCountMatrix[0][j];
					tempLeftBestLabel = j;
				}
			}

			// Step 4.4 Right leaf.
			double tempRightMaxCorrect = 0;
			int tempRightBestLabel = 0;
			for (int j = 0; j < tempLabelCountMatrix[1].length; j++) {
				Common.runSteps ++;
				if (tempRightMaxCorrect < tempLabelCountMatrix[1][j]) {
					tempRightMaxCorrect = tempLabelCountMatrix[1][j];
					tempRightBestLabel = j;
				}
			}

			// Step 4.5 Compare with the current best.
			if (tempMaxCorrect < tempLeftMaxCorrect + tempRightMaxCorrect) {
				Common.runSteps ++;
				tempMaxCorrect = tempLeftMaxCorrect + tempRightMaxCorrect;
				bestCut = tempCut;
				leftLeafLabel = tempLeftBestLabel;
				rightLeafLabel = tempRightBestLabel;
			}
		}

		SimpleTools.variableTrackingOutput("Attribute = " + selectedAttribute + ", cut = " + bestCut
				+ ", leftLeafLabel = " + leftLeafLabel + ", rightLeafLabel = " + rightLeafLabel);
	}

3. 计算误差率和误差系数(树桩分类器)

在这里插入图片描述
调用StumpClassifier类的computeWeightedError方法。(实际上这个computeWeightedError方法是公用的,使用的是StumpClassifier的父类SimpleClassifier提供的实现方法)

  • 在这个方法中,调用computeCorrectnessArray()获取训练样本在这一次训练中的预测结果和实际结果是否一样,例如下:
    在这里插入图片描述
  • 计算误差率:失败样本数*样本权重之和
    resultError += weightedInstances.getWeight(i);
  • 计算误差系数(弱分类器的权重) α m = 1 2 log ⁡ 1 − e m e m \alpha _{m}=\frac{1}{2}\log \frac{1-e_{m}}{e_{m}} αm=21logem1em
    classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
public double computeWeightedError() {
		double resultError = 0;
		boolean[] tempCorrectnessArray = computeCorrectnessArray();
		for (int i = 0; i < tempCorrectnessArray.length; i++) {
			Common.runSteps ++;
			if (!tempCorrectnessArray[i]) {
				resultError += weightedInstances.getWeight(i);
			}
		}

		if (resultError < 1e-6) {
			resultError = 1e-6;
		}

		return resultError;
	}

4. 计算精确度

在这里插入图片描述
随着训练的次数越来越多,精确度会越来越好,若精确度足够大,就可以跳出训练。

  • classify方法
    入参为传入的数据实例。 如 5.1,3.5,1.4,0.2,Iris-setosa
    第一个for循环累加已经训练的基分类器中的权重系数,第二个for循环选择所有类别中概率最大的即为预测的类别
    计算所有训练样本的预测结果与实际结果对比,计算准确率。
    从这里也可以看出,当你训练次数越多(即numClassifiers越大),那预测正确的概率也会越大。
public double computeTrainingAccuray() {
		double tempCorrect = 0;

		for (int i = 0; i < trainingData.numInstances(); i++) {
			Common.runSteps ++;
			if (classify(trainingData.instance(i)) == (int) trainingData.instance(i).classValue()) {
				tempCorrect++;
			}
		}

		double tempAccuracy = tempCorrect / trainingData.numInstances();

		return tempAccuracy;
	}

	public int classify(Instance paraInstance) {
		double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
		for (int i = 0; i < numClassifiers; i++) {
			Common.runSteps ++;
			int tempLabel = classifiers[i].classify(paraInstance);
			tempLabelsCountArray[tempLabel] += classifierWeights[i];
		}

		SimpleTools.variableTrackingOutput(Arrays.toString(tempLabelsCountArray));

		int resultLabel = -1;
		double tempMax = -1;
		for (int i = 0; i < tempLabelsCountArray.length; i++) {
			Common.runSteps ++;
			if (tempMax < tempLabelsCountArray[i]) {
				tempMax = tempLabelsCountArray[i];
				resultLabel = i;
			}
		}

		return resultLabel;
	}

5. 总结

因为这个AdaBosting算法的实现,涉及到不同类中方法的调用,所以需要知道这个算法的一个大致思想以及他的一些理论公式,再去看代码会更容易接受。同时,基分类器这里选择了树桩分类器,还有其他分类器可选。每个分类器的大致思想都是:

  • 初始化新训练样本的权重或根据上一次训练结果更新权重
  • 选择分类器,并根据分类器进行训练并计算在这一次训练中的错误率以及分类器的权重系数

3. java知识

在这个代码的实现中,用了一些java的基础知识

1. 抽象类

SimpleClassifier类。方法包含的成员变量和方法

在这里插入图片描述
抽象类不能被实例只能被继承。抽象类定义的是一组通用相关的属性和方法,并可以提供一些默认的实现,但SimpleClassifier这个类是无法实例化的(即不能被new)他主要的特点有:

  • 1.无法被实例化
  • 2.可以包含抽象方法:抽象方法是没有实现的,且要用abstract关键字声明,子类继承SimpleClassifier必须要实现这两个方法。如SimpleClassifier类的这两个方法。StumpClassifier类集成了SimpleClassifier类,且实现了这两个方法、
	/**
	 * Train the classifier.
	 */
	public abstract void train();

	/**
	 * Classify an instance.
	 * @param paraInstance The given instance.
	 * @return Predicted label.
	 */
	public abstract int classify(Instance paraInstance);	

在这里插入图片描述

  • 可以包含非抽象方法:这些方法包含具体的实现,子类可以继承这些方法,当然也可以重写。例如SimpleClassifier类中的computeCorrectnessArray(),computeTrainingAccuracy(),computeWeightedError()方法
  • 可以包含成员变量和构造方法。

2. static关键字

在今天的代码中,有Common类和SimpleTools类两个公共类,其中的成员变量和方法都是用static关键字修饰的,将在类的所有实例之间共享,即使没有创建类的实例对象,也可以直接通过类名来访问该成员变量。静态变量在内存中只有一份拷贝,被所有实例共享。通常用于表示类的共享数据或常量。
如代码中:
在这里插入图片描述
在这里插入图片描述

3. 静态代码块

在今天的代码中,Common类中有静态代码块,如下:
在这里插入图片描述
静态代码块在类加载的过程中执行,且只执行一次。他的作用主要是在类加载时执行一些初始化操作。(所以一些初始化的操作或变量可以放在这个里面执行)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/649897.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

谈谈Java高并发网站的设计思路

目前Java都在流行一个说词&#xff1a;高并发。 反正不管是不是&#xff0c;反正就是高并发。 谈高并发&#xff0c;我们需要知道几个名词&#xff1a; -响应时间(Response Time&#xff0c;RT)-吞吐量(Throughput)-每秒查询率QPS(Query Per Second)-每秒事务处理量TPS(Transa…

用单元测试读懂 vue3 中的 defineComponen

目录 前言&#xff1a; I. 测试用例 II. 一些基础类型定义 III. 官网文档中的 props V. 开发实践 VI. 全文总结 前言&#xff1a; Vue3是一种流行的JavaScript框架&#xff0c;它在组件开发方面提供了更多的表现力和灵活性&#xff0c;通过使用defineComponent高阶函数&…

腾讯安全联动头部企业携手验证“数字安全免疫力”模型框架

6月13日&#xff0c;腾讯安全联合IDC发布“数字安全免疫力”模型框架及《加强企业数字安全免疫力&#xff0c;助力数字时代下的韧性发展》白皮书&#xff0c;提出用免疫的思维应对新时期下安全建设与企业发展难以协同的挑战&#xff0c;围绕“数据”和“数字业务”建立三层由内…

视觉SLAM十四讲——ch8实践(视觉里程计2)

视觉SLAM十四讲----ch8的实践操作及避坑 0.实践前小知识介绍1. 实践操作前的准备工作2. 实践过程2.1 LK光流2.2 直接法 3. 遇到的问题及解决办法3.1 编译时遇到的问题 0.实践前小知识介绍 里程计的历史渊源是什么&#xff1f; 里程计是一种用来测量车辆或机器人行驶距离的装置…

Uniapp uni-app学习与快速上手

个人开源uni-app开源项目地址&#xff1a;准备中 在线展示项目地址&#xff1a;准备中 什么是uni-app uni&#xff0c;读 you ni&#xff0c;是统一的意思。 Dcloud即数字天堂(北京)网络技术有限公司是W3C成员及HTML5中国产业联盟发起单位&#xff0c;致力于推进HTML5发展构…

MP : Human Motion 人体运动的MLP方法

Back to MLP: A Simple Baseline for Human Motion Prediction conda install -c conda-forge easydict 简介 papercodehttps://arxiv.org/abs/2207.01567v2https://github.com/dulucas/siMLPe Back to MLP是一个仅使用MLP的新baseline,效果SOTA。本文解决了人类运动预测的问…

iconfont彩色图标

进入阿里巴巴矢量图标库iconfont-阿里巴巴矢量图标库&#xff0c;添加图标到项目&#xff0c;然后下载至本地 解压后的本地文件如下&#xff0c;核心的是 iconfont.eot 文件 2.打开电脑命令行&#xff0c;执行以下命令&#xff0c;全局安装 iconfont-tools 转换工具 npm insta…

创新指南 | 推动销售的17个可落地的集客式营销示例

无论您是开启集客式的营销有一段时间还是处于起步阶段&#xff0c;了解像您这样的企业是如何粉碎竞争对手的的集客式策略总是有帮助的。无论您的公司做什么&#xff0c;它所服务的行业&#xff0c;是B2B还是B2C &#xff0c;您都可以在这里找到许多可以使用的示例。 在本文中&…

rtklib短基线分析(香港基站)

1、下载香港基站 用filezilla下载&#xff0c;地址ftp://ftp.geodetic.gov.hk 下载hklt、hkkt&#xff0c;这两座站的观测值及星历&#xff0c;必须同一天。用crx2rnx工具解压。 2、打开rtklib的RTKPOST&#xff0c;输入文件&#xff0c;如下图所示。 选择static模式&#xf…

C语言进阶--指针(C语言灵魂)

目录 1.字符指针 2.指针数组 3.数组指针 4.数组参数与指针参数 4.1.一维数组传参 4.2.二维数组传参 4.3.一级指针传参 4.4.二级指针传参 5.函数指针 6.函数指针数组 7.指向函数指针数组的指针 8.回调函数 qsort函数 9.指针和数组笔试题 10.指针笔试题 前期要点回…

《机器学习公式推导与代码实现》chapter10-AdaBoost

《机器学习公式推导与代码实现》学习笔记&#xff0c;记录一下自己的学习过程&#xff0c;详细的内容请大家购买作者的书籍查阅。 AdaBoost 将多个单模型组合成一个综合模型的方式早已成为现代机器学习模型采用的主流方法-集成模型(ensemble learning)。AdaBoost是集成学习中…

python获取热搜数据并保存成Excel

python获取百度热搜数据 一、获取目标、准备工作二、开始编码三、总结 一、获取目标、准备工作 1、获取目标&#xff1a; 本次获取教程目标&#xff1a;某度热搜 2、准备工作 环境python3.xrequestspandas requests跟pandas为本次教程所需的库&#xff0c;requests用于模拟h…

牛客网基础语法51~60题

牛客网基础语法51~60题&#x1f618;&#x1f618;&#x1f618; &#x1f4ab;前言&#xff1a;今天是咱们第六期刷牛客网上的题目。 &#x1f4ab;目标&#xff1a;对每种的循环知识掌握熟练&#xff0c;用数学知识和循环结合运用熟练&#xff0c;对逻辑操作符运用熟练。 &am…

接口测试入门神器 —— Requests

起源 众所周知&#xff0c;自动化测试是软件测试爱好者毕生探索的课题。我认为&#xff0c;只要把 接口测试 做好&#xff0c;你的自动化测试就至少成功了一半。 应部分热情读者要求&#xff0c;今天跟大家一起了解 python 接口测试库- Requests 的基本用法并进行实践&#x…

【跑实验03】如何可视化GT边界框,如何选择边界框内部的边界框,如何可视化GT框和预测框,如何定义IoU阈值下的不同边界框?

文章目录 一、如何可视化GT边界框&#xff1f;二、GT框和预测框的可视化三、根据IoU阈值来选择 一、如何可视化GT边界框&#xff1f; from PIL import Image, ImageDrawdef draw_bboxes(image, bboxes, color"red", thickness2):draw ImageDraw.Draw(image)for bbo…

精雕细琢,Smartbi电子表格软件重构、新增、完善

Smartbi SpreadSheet电子表格软件自发布以来&#xff0c;我们一直关注着用户的诉求&#xff0c;也在不断地对产品进行改进和优化&#xff0c;确保产品能够持续满足用户需求。经过一段时间的努力&#xff0c;产品在各方面都有了明显的改进&#xff0c;接下来&#xff0c;让我们一…

全网最详细的postman接口测试教程,一篇文章满足你

1、前言   之前还没实际做过接口测试的时候呢&#xff0c;对接口测试这个概念比较渺茫&#xff0c;只能靠百度&#xff0c;查看各种接口实例&#xff0c;然后在工作中也没用上&#xff0c;现在呢是各种各样的接口都丢过来&#xff0c;总算是有了个实际的认识。因为只是接口的…

不写单元测试的我,被批了 ,怎么说?

我是凡哥&#xff0c;一年CRUD经验用十年的markdown程序员&#x1f468;&#x1f3fb;‍&#x1f4bb;常年被誉为职业八股文选手 最近在看单元测试的东西&#xff0c;想跟大家聊聊我的感受。单元测试这块说实在的&#xff0c;我并不太熟悉&#xff0c;我几乎不写单元测试&…

k8s 的 Deployment控制器

1. RS与RC与Deployment关联 RC&#xff08;Replication Controller&#xff09;主要作用就是用来确保容器应用的副本数始终保持在用户定义的副本数。即如果有容器异常退出&#xff0c;会自动创建新的pod来替代&#xff1b;而如果异常多出来的容器也会自动回收。K8S官方建议使用…

JDBC BasicDAO详解(通俗易懂)

目录 一、前言 二、BasicDAO的引入 1.为什么需要BasicDAO&#xff1f; 2.BasicDAO示意图 : 三、BasicDAO的分析 1.基本说明 : 2.简单设计 : 四、BasicDAO的实现 0.准备工作 : 1.工具类 : 2.JavaBean类 : 3.BasicDAO类 / StusDAO类 : 4.测试类 : 一、前言 第七节内容…