Day_63-65 集成学习之 AdaBoosting

news2024/11/29 0:32:45

目录

Day_63-65

一. 基本概念介绍

        1. 集成学习

        2. 弱分类器与强分类器

 二. AdaBoosting算法

        1. AdaBoosting算法框架介绍

        2. AdaBoosting算法过程

三. 代码的实现过程

        1. WeightedInstances类

        2. 构造弱分类器的StumpClassifier类和抽象类SimpleClassifier

        3. 主类Booster的分析

四. 代码展示

五. 总结


Day_63-65

一. 基本概念介绍

        1. 集成学习

        集成学习大致可分为两大类:Bagging和Boosting,Bagging在这里不过多介绍,有兴趣的读者请参考另外的资料,这里主要讲Boosting算法。Boosting使用弱分类器,其个体学习器之间存在强依赖关系,是一种序列化方法。Boosting是一族算法,其主要目标为将弱学习器“提升”为强学习器,大部分Boosting算法都是根据前一个学习器的训练效果对样本分布进行调整,再根据新的样本分布训练下一个学习器,如此迭代M次,最后将一系列弱学习器组合成一个强学习器。而这些Boosting算法的不同点则主要体现在每轮样本分布的调整方式上。本文章先讨论Boosting的两大经典算法之一——AdaBoost。

        2. 弱分类器与强分类器

        这里有几个定义大家随便看看,

        ①一个分类器的分类准确率在60%-80%,即:比随机预测略好,但准确率却不太高,我们可以称之为“弱分类器”。反之,如果分类精度90%以上,则是强分类器。

        ②西瓜书上说:弱学习器常指泛化性能略优于随即猜测的学习器(例如在二分类问题上精度略高于50%的分类器)

        简而言之弱分类器是比较简单的分类器,举个简单的例子:若有一本书的页数大于500则判定为较难书籍,反之则判定为简单书籍。这是个非常简单的决策树,仅仅通过一个属性(页数)就来判断书籍的难易程度(显然不算太靠谱,但是准确率又大于50%),这便称为一个弱分类器。

 二. AdaBoosting算法

        1. AdaBoosting算法框架介绍

        AdaBoosting算法简单来说就是将多个弱分类器进行拼接得到一个强分类器。首先我们构建出多个弱分类器,接着每个弱分类器乘以对应的权重,然后把它们组合起来得到一个强分类器。最后输入一个测试数据,用这个强分类器进行判别。

        举个例子:判断一个房子要不要买?

        第一个弱分类器和它对应的权值(w1)

        第二个弱分类器它对应的权值(w2)

        第三个弱分类器它对应的权值(w3)

        最后我们输入一个数据(250万,三室,好地段),得到的结果(1表示买,-1表示不买)是w1×-1+w2×1+w3×1;这个数据若大于0,则买,否则不买。上述只是一个非常简单的例子,这里弱分类器还可以有很多个,并且属性可以重复,因为后期在计算权重(w)的时候,对于一个重复的属性不会有影响。

        现在我们的目标就是构建弱分类器,和调整对应的权重。

        2. AdaBoosting算法过程

        这部分会涉及到许多数学公式,请在小朋友的陪同下观看qaq。

        首先我们接下来的所有例子都来自下面的数据

属性1属性2属性3属性4标签
样本15.13.51.40.2Iris-setosa
样本24.93.01.40.2Iris-setosa
样本34.73.21.30.2Iris-setosa
样本44.63.11.50.2Iris-setosa
样本55.03.61.40.2Iris-setosa
样本65.43.91.70.4Iris-setosa
样本77.03.24.71.4Iris-versicolor
样本86.43.24.51.5Iris-versicolor
样本96.93.14.91.5Iris-versicolor
样本105.52.34.01.3Iris-versicolor
样本116.52.84.61.5Iris-versicolor
样本125.72.84.51.3Iris-versicolor
样本136.33.36.02.5Iris-virginica
样本145.82.75.11.9Iris-virginica
样本157.13.05.92.1Iris-virginica
样本166.32.95.61.8Iris-virginica
样本176.53.05.82.2Iris-virginica
样本187.63.06.62.1Iris-virginica

        2.1 初始化权重和弱分类器

        首先初始化数据的权重(总共18个数据);k表示迭代到第几轮了,i表示第几个数据样本,这里的m表示数据的个数。

        D(k)=(w_{k1},w_{k2},...w_{km})\ \ \ \ w_{1i}=1/m\ \ \ \ i=1,2...m

        接着我们构建一个弱分类器G_{k}(x)(这个过程可以暂时跳过,后面我再详谈)。

        2.2 计算误差率和弱分类器的权重系数

        由于弱分类器的特性,我们在考虑一个弱分类器G_{k}(x)对应的输入数据的的时候,只考虑单一属性。例如上面的数据我们随机选择一个属性作为这个弱分类器G_{k}(x)的输入(其他属性不考虑,集成学习就是多个弱分类器(G_{1}(x)G_{2}(x)...G_{k}(x))构建在一起的,所以多造几个弱分类器器就行了),若下图所示我们在构造分类器的时候只需要考虑一个属性1。

        分类问题的误差率很好理解和计算。由于多元分类是二元分类的推广,这里假设我们是二元分类问题(标签种类大于2的处理方法后面也会给出解释),这里我们先假定是二元分类问题输出为{-1,1}。那么对于第k个弱分类器G_{k}(x)在训练集上的加权误差率为,

e_{k}=P(G_{k}(x_{i})\neq y_{i})=\sum_{i=1}^{m}w_{ki}I(G_{k}(x_{i})\neq yi)

        意思就是上述的18个样本数据,通过这个分类器G_{k}(x),得到的分类结果和原本数据的分类结果不一致的话,那么就将这个数据样本的权重系数相加,这就是第k个弱分类器的误差参数e_{k}

        接着我们计算每个弱学习器的权重参数,对于二元分类问题,第k个弱分类器的G_{k}(x)权重系数有如下公式:至于每个弱分类器的权重系数为什么采用这种计算方式,有兴趣的读者请参考文章1,文章2。

a_{k}=\frac{1}{2}log\frac{1-e_{k}}{e_{k}}

        2.3 更新数据

        现在样本权重D(k)为:

D(k)=(w_{k1},w_{k2},...w_{km})

        则更新一轮的样本权重

w_{k+1,i}=\frac{w_{ki}}{Z_{K}}e^{(-a_{k}y_{i}G_{k}(x_{i}))}

        其中Z_{K}是规范化因子,

Z_{k}=\sum_{i=1}^{m}w_{ki}e^{(-a_{k}y_{i}G_{k}(x_{i}))}

        2.4 最终强分类器和迭代过程

        将上述过程进行K次得到K的弱分类器,并且将他们组合起来,最终我们得到了强分类器

f(x)=sign(\sum_{k=1}^{K}a_{k}G_{k}(x))

        再最后我们回顾一下这个算法的过程:首先我们对18个样本进行初始化权重D(k),接着将这18个样本输入我们构造的某一个弱分类器G_{k}(x),计算分类的误差e_{k}(分类的结果和原结果比对不一致的),接着我们更新每个弱分类器的权重参数a_{k},最后我们得到了一个强分类器f(x)

        2.5 弱分类器的构建

        详见三. 2.3 

        2.6 多元分类的处理方法

        详见三.2.3

三. 代码的实现过程

        这部分代码相当多,而且涉及到许多抽象类和接口,需要单独拎出来分析。所以按着代码的执行过程一步一步从整体框架的角度来分析执行过程。

        1. WeightedInstances类

        WeightedInstances首先继承的是Instances类,即有Instances类的调用方法如数据处理函数,另外WeightedInstances增加了一个weights数组用于存放每个样本数据的权重。WeightedInstances类可以简单理解成对样本数据权重的计算的类。

    /**
     * Weights.
     */
    private double[] weights;

        1.1两个构造函数:

        第一个构造函数表示传入数据的路径,读取数据,并且初始化设置weights数组的权重。

        第二个构造函数仅仅改变传入参数为Instances类,并且初始化设置weights数组的权重。

    /**
     ******************
     * The first constructor.
     *
     * @param paraFileReader
     *            The given reader to read data from file.
     ******************
     */
    public WeightedInstances(FileReader paraFileReader) throws Exception {
        super(paraFileReader);
        setClassIndex(numAttributes() - 1);

        // Initialize weights
        weights = new double[numInstances()];
        double tempAverage = 1.0 / numInstances();
        for (int i = 0; i < weights.length; i++) {
            weights[i] = tempAverage;
        } // Of for i
        System.out.println("Instances weights are: " + Arrays.toString(weights));
    } // Of the first constructor

    /**
     ******************
     * The second constructor.
     *
     * @param paraInstances
     *            The given instance.
     ******************
     */
    public WeightedInstances(Instances paraInstances) {
        super(paraInstances);
        setClassIndex(numAttributes() - 1);

        // Initialize weights
        weights = new double[numInstances()];
        double tempAverage = 1.0 / numInstances();
        for (int i = 0; i < weights.length; i++) {
            weights[i] = tempAverage;
        } // Of for i
        System.out.println("Instances weights are: " + Arrays.toString(weights));
    } // Of the second constructor

        1.2 得到样本的权重

        由于权重是private型变量,只能通过函数进行访问,所以这里设置了getWeight函数,传入某行i,输出第i行样本数据的权重。

    /**
     ******************
     * Getter.
     *
     * @param paraIndex
     *            The given index.
     * @return The weight of the given index.
     ******************
     */
    public double getWeight(int paraIndex) {
        return weights[paraIndex];
    } // Of getWeight

        1.3 WeightedInstances类的核心代码adjustWeights函数

        adjustWeights函数表示传入的参数为(上一个分类器的判定结果是否正确构成的数组,和上一个分类器的权重系数a_{k}),计算e^{(-a_{k}y_{i}G_{k}(x_{i}))},接着对每一个样本进行循环,如果判断正确,则计算公式为:weights[i] /= tempIncrease;否则weights[i] *= tempIncrease;。最后进行归一化操作:tempWeightsSum += weights[i];weights[i] /= tempWeightsSum;。

        这里的理论基础是:

        w_{k+1,i}=\frac{w_{ki}}{Z_{K}}e^{(-a_{k}y_{i}G_{k}(x_{i}))}        Z_{k}=\sum_{i=1}^{m}w_{ki}e^{(-a_{k}y_{i}G_{k}(x_{i}))}

        这样我们就更新了下一轮的样本数据的权重值

    /**
     ******************
     * Adjust the weights.
     *
     * @param paraCorrectArray
     *            Indicate which instances have been correctly classified.
     * @param paraAlpha
     *            The weight of the last classifier.
     ******************
     */
    public void adjustWeights(boolean[] paraCorrectArray, double paraAlpha) {
        // Step 1. Calculate alpha.
        double tempIncrease = Math.exp(paraAlpha);

        // Step 2. Adjust.
        double tempWeightsSum = 0; // For normalization.
        for (int i = 0; i < weights.length; i++) {
            if (paraCorrectArray[i]) {
                weights[i] /= tempIncrease;
            } else {
                weights[i] *= tempIncrease;
            } // Of if
            tempWeightsSum += weights[i];
        } // Of for i

        // Step 3. Normalize.
        for (int i = 0; i < weights.length; i++) {
            weights[i] /= tempWeightsSum;
        } // Of for i

        System.out.println("After adjusting, instances weights are: " + Arrays.toString(weights));
    } // Of adjustWeights

        总结:WeightedInstances类的作用主要是构建了一个权重数组weights用于记录每个样本的权重,并且它的核心代码是adjustWeights函数,它的作用是根据上一个分类器的判定结果上一个分类器的权重值调整下一轮样本数据的权重值

        2. 构造弱分类器的StumpClassifier类和抽象类SimpleClassifier

        StumpClassifier是继承SimpleClassifier类,而SimpleClassifier可以理解为一个接口。而StumpClassifier可以理解为构造弱分类器的类。

        2.1 基本参数解释

        bestCut是切割点(对于所有样本的某一个属性小于bestCut被分类为leftLeafLabel标签,大于bestCut被分类为rightLeafLabel),

    /**
     * The best cut for the current attribute on weightedInstances.
     */
    double bestCut;

    /**
     * The class label for attribute value less than bestCut.
     */
    int leftLeafLabel;

    /**
     * The class label for attribute value no less than bestCut.
     */
    int rightLeafLabel;

    /**
     ******************
     * The only constructor.
     *
     * @param paraWeightedInstances
     *            The given instances.
     ******************
     */

        2.2 构造函数

        调用的SimpleClassifier抽象类的函数(继承),

    /**
     ******************
     * The only constructor.
     *
     * @param paraWeightedInstances
     *            The given instances.
     ******************
     */
    public StumpClassifier(WeightedInstances paraWeightedInstances) {
        super(paraWeightedInstances);
    }// Of the only constructor

        下面是父类SimpleClassifier类的构造函数,主要是些赋值操作。

    /**
     ******************
     * The first constructor.
     *
     * @param paraWeightedInstances
     *            The given instances.
     ******************
     */
    public SimpleClassifier(WeightedInstances paraWeightedInstances) {
        weightedInstances = paraWeightedInstances;

        numConditions = weightedInstances.numAttributes() - 1;
        numInstances = weightedInstances.numInstances();
        numClasses = weightedInstances.classAttribute().numValues();
    }// Of the first constructor

        2.3 训练弱分类器(StumpClassifier类的核心代码,文章重点)

        我们现在所传入的数据是一堆加了权重的样本数据WeightedInstances类。

        首先我们选择样本的一个随机属性selectedAttribute,这是构造这个弱分类的标准

        // Step 1. Randomly choose an attribute.
        selectedAttribute = random.nextInt(numConditions);

        接着我们构造一个tempValuesArray数组存储所有样本的selectedAttribute属性的具体值,然后根据tempValuesArray数组的大小进行排序。

        // Step 2. Find all attribute values and sort.
        double[] tempValuesArray = new double[numInstances];
        for (int i = 0; i < tempValuesArray.length; i++) {
            tempValuesArray[i] = weightedInstances.instance(i).value(selectedAttribute);
        } // Of for i
        Arrays.sort(tempValuesArray);

        接着我们构造一个数组tempLabelCountArray,这个数组的大小为标签种类个数numClasses,然后扫描整个数据集,将这个数据集每个样本根据标签不同分成numClasses个类,并且记录他们的权值之和。

        // Step 3. Initialize, classify all instances as the same with the
        // original cut.
        int tempNumLabels = numClasses;
        double[] tempLabelCountArray = new double[tempNumLabels];
        int tempCurrentLabel;

        // Step 3.1 Scan all labels to obtain their counts.
        for (int i = 0; i < numInstances; i++) {
            // The label of the ith instance
            tempCurrentLabel = (int) weightedInstances.instance(i).classValue();
            tempLabelCountArray[tempCurrentLabel] += weightedInstances.getWeight(i);
        } // Of for i

        然后我们在这个实际样本里面寻找最佳标签,即遍历数组tempLabelCountArray,找到权值最大的标签种类tempBestLabel,和他们的权值tempMaxCorrect。这里的tempBestLabel和tempMaxCorrect都是实际数据得到的结果,一定要注意。

        // Step 3.2 Find the label with the maximal count.
        double tempMaxCorrect = 0;
        int tempBestLabel = -1;
        for (int i = 0; i < tempLabelCountArray.length; i++) {
            if (tempMaxCorrect < tempLabelCountArray[i]) {
                tempMaxCorrect = tempLabelCountArray[i];
                tempBestLabel = i;
            } // Of if
        } // Of for i

        然后我们寻找最佳的切割点,到底该怎么寻找呢?挨着挨着试每一个数据,我们首先设置一个切割点,这个切割点一定不能是数据集本身里面的数据,即这个切割点一定能把数据完全分成两份。

        首先我们先设置切割点为最小的值-0.1(这里-0.1就是为了使切割点不和原本的数据集重叠),初始化左右树判定结果都为tempBestLabel(真实数据里面最多的标签类)。

        // Step 3.3 The cut is a little bit smaller than the minimal value.
        bestCut = tempValuesArray[0] - 0.1;
        leftLeafLabel = tempBestLabel;
        rightLeafLabel = tempBestLabel;

        构建一个二维矩阵tempLabelCountMatrix,设置临时切割点tempCut。首先遍历所有的数据(为了找到最好的分割点),临时切割点tempCut= (tempValuesArray[i] + tempValuesArray[i + 1]) / 2并且这里的两个数据不能相等(若相等的话,就不能完全将数据集分成两个了)。这里需要讲一下tempLabelCountMatrix矩阵是用来干什么的,tempLabelCountMatrix矩阵用于记录每次设置tempCut切割点之后,对应遍历整个数据集然后将对应的权重相加到tempLabelCountMatrix。

        // Step 4. Check candidate cuts one by one.
        // Step 4.1 To handle multi-class data, left and right.
        double tempCut;
        double[][] tempLabelCountMatrix = new double[2][tempNumLabels];

        for (int i = 0; i < tempValuesArray.length - 1; i++) {
            // Step 4.1 Some attribute values are identical, ignore them.
            if (tempValuesArray[i] == tempValuesArray[i + 1]) {
                continue;
            } // Of if
            tempCut = (tempValuesArray[i] + tempValuesArray[i + 1]) / 2;

            // Step 4.2 Scan all labels to obtain their counts wrt. the cut.
            // Initialize again since it is used many times.
            for (int j = 0; j < 2; j++) {
                for (int k = 0; k < tempNumLabels; k++) {
                    tempLabelCountMatrix[j][k] = 0;
                } // Of for k
            } // Of for j

            for (int j = 0; j < numInstances; j++) {
                // The label of the jth instance
                tempCurrentLabel = (int) weightedInstances.instance(j).classValue();
                if (weightedInstances.instance(j).value(selectedAttribute) < tempCut) {
                    tempLabelCountMatrix[0][tempCurrentLabel] += weightedInstances.getWeight(j);
                } else {
                    tempLabelCountMatrix[1][tempCurrentLabel] += weightedInstances.getWeight(j);
                } // Of if
            } // Of for i

        每设置一个临时分割点tempcut,都要针对于这个切割点进行上述操作,接着从tempLabelCountMatrix矩阵里面选择出最有可能被分到的标签。

        这里的最有可能就是上述 ——二 2.6(多元分类的处理方法)的实现过程,即我们根据分割点统标签的权重大小,最大的那一个作为弱分类器的分类结果

        紧接着我们选择第一行权重最大的标签为tempLeftBestLabel,第一行最大的权重为tempLeftMaxCorrect;同理第二行权重最大的标签为tempRightBestLabel,第二行最大的权重为tempRightMaxCorrect。

            // Step 4.3 Left leaf.
            double tempLeftMaxCorrect = 0;
            int tempLeftBestLabel = 0;
            for (int j = 0; j < tempLabelCountMatrix[0].length; j++) {
                if (tempLeftMaxCorrect < tempLabelCountMatrix[0][j]) {
                    tempLeftMaxCorrect = tempLabelCountMatrix[0][j];
                    tempLeftBestLabel = j;
                } // Of if
            } // Of for i

            // Step 4.4 Right leaf.
            double tempRightMaxCorrect = 0;
            int tempRightBestLabel = 0;
            for (int j = 0; j < tempLabelCountMatrix[1].length; j++) {
                if (tempRightMaxCorrect < tempLabelCountMatrix[1][j]) {
                    tempRightMaxCorrect = tempLabelCountMatrix[1][j];
                    tempRightBestLabel = j;
                } // Of if
            } // Of for i

        这样我们就选出了一个可能的弱分类器,这个分类器怎么样呢?到底好不好呢?接着我们开始和之前的数据进行比较。

        如果说这里的分类结果的权重值tempLeftMaxCorrect + tempRightMaxCorrect>tempMaxCorrect(之前的权重值),那就说明这个分类器的效果更好(因为它占有更大的权重)更新一下tempMaxCorrect和最佳切割点bestCut。训练结束

            // Step 4.5 Compare with the current best.
            if (tempMaxCorrect < tempLeftMaxCorrect + tempRightMaxCorrect) {
                tempMaxCorrect = tempLeftMaxCorrect + tempRightMaxCorrect;
                bestCut = tempCut;
                leftLeafLabel = tempLeftBestLabel;
                rightLeafLabel = tempRightBestLabel;
            } // Of if

        最后总结一下,train这部分代码相当难理解,主要是要明白输入参数是什么,输出参数是什么以及我们构建的每一个变量的作用是什么,除此之外还需要对AdaBoosting算法理解到位才行。然后train主要做的一个工作就是根据样本的某一个属性,构建分类器,这个分类器怎么样构建的(挨着挨着试,试出来最优的分类器(权重最大));以及对于多元分类问题的处理方式(普通的AdaBoosting算法是二元分类问题),都是在这里解决,可以说这部分理解到位之后,后面就是豁然开朗。

        2.4 弱分类器的预测

        上面搞了这么大一堆东西,主要不就是为了得到一个弱分类器吗?他来了,

        传入某一个样本(paraInstance);接着bestCut作为最佳分类点,selectedAttribute为挑选的属性,leftLeafLabel为小于bestCut的分类结果,rightLeafLabel为大于bestCut的分类结果。

        3. 主类Booster的分析

        完成了上述的过程,我们终于到主类了

        3.1 基本初始化

        首先来看这个算法的起点——主函数

    /**
     ******************
     * For integration test.
     *
     * @param args
     *            Not provided.
     ******************
     */
    public static void main(String args[]) {
        System.out.println("Starting AdaBoosting...");
        Booster tempBooster = new Booster("D:/data/iris.arff");
        // Booster tempBooster = new Booster("src/data/smalliris.arff");

        tempBooster.setNumBaseClassifiers(100);
        tempBooster.train();

        System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuray());
        tempBooster.test();
    }// Of main

        这里有一个构造函数Booster,输入的是文件的路径,这个构造函数的主要作用是赋值某些参数,没什么特别的,看看就好

    /**
     ******************
     * The first constructor. The testing set is the same as the training set.
     *
     * @param paraTrainingFilename
     *            The data filename.
     ******************
     */
    public Booster(String paraTrainingFilename) {
        // Step 1. Read training set.
        try {
            FileReader tempFileReader = new FileReader(paraTrainingFilename);
            trainingData = new Instances(tempFileReader);
            tempFileReader.close();
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
            System.exit(0);
        } // Of try

        // Step 2. Set the last attribute as the class index.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // Step 3. The testing data is the same as the training data.
        testingData = trainingData;

        stopAfterConverge = true;

//        System.out.println("****************Data**********\r\n" + trainingData);
    }// Of the first constructor

        接着我们来到主函数里面的tempBooster.setNumBaseClassifiers(100),这里主要是设置弱分类器的个数。

        详细的setNumBaseClassifiers函数如下所示,classifiers构造弱分类器的数组,每个数组的值空间里面都是一个弱分类器的数据结构;classifierWeights是每个弱分类器的权重。

    /**
     ******************
     * Set the number of base classifier, and allocate space for them.
     *
     * @param paraNumBaseClassifiers
     *            The number of base classifier.
     ******************
     */
    public void setNumBaseClassifiers(int paraNumBaseClassifiers) {
        numClassifiers = paraNumBaseClassifiers;

        // Step 1. Allocate space (only reference) for classifiers
        classifiers = new SimpleClassifier[numClassifiers];

        // Step 2. Initialize classifier weights.
        classifierWeights = new double[numClassifiers];
    }// Of setNumBaseClassifiers

        3.2 训练函数

        ①由于我们的分类器的个数是100(也是迭代的次数),第一重循环是迭代的次数(100次)。接着做了一个if判断语句,若迭代的次数是第一次时,我们直接初始化权重类WeightedInstances;若不是第一次时,根据上一个分类器的分类结果classifiers[i - 1].computeCorrectnessArray()和分类器的权重值classifierWeights[i - 1]进行调整(详见三. 1.3);总之我们得到了每一个样本现在的权重对象tempWeightedInstances。

    /**
     ******************
     * Train the booster.
     *
     * @see algorithm.StumpClassifier#train()
     ******************
     */
    public void train() {
        // Step 1. Initialize.
        WeightedInstances tempWeightedInstances = null;
        double tempError;
        numClassifiers = 0;

        // Step 2. Build other classifiers.
        for (int i = 0; i < classifiers.length; i++) {
            // Step 2.1 Key code: Construct or adjust the weightedInstances
            if (i == 0) {
                tempWeightedInstances = new WeightedInstances(trainingData);
            } else {
                // Adjust the weights of the data.
                tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
                        classifierWeights[i - 1]);
            } // Of if

        ②最后根据这个tempWeightedInstances对象的权重和数据训练得到classifiers[i](弱分类器的数值(bestcut,leftlabel,rightlabel))。

            // Step 2.2 Train the next classifier.
            classifiers[i] = new StumpClassifier(tempWeightedInstances);
            classifiers[i].train();

        ③接着我们计算这个弱分类器器的误差

tempError = classifiers[i].computeWeightedError();

        以下是computeWeightedError()函数

        tempCorrectnessArray = computeCorrectnessArray();用于计算通过这个分类器,得到判定数组tempCorrectnessArray[ ](若通过这个弱分类器得到的是和数据一样的标签,则为true,否则为false)。

        接着我们遍历整个判定数组tempCorrectnessArray,如果判定数组的某个值为false,则将它的权重相加到resultError。即文章理论部分的e_{k}

    /**
     ******************
     * Compute the weighted error on the training set. It is at least 1e-6 to
     * avoid NaN.
     *
     * @return The weighted error.
     ******************
     */
    public double computeWeightedError() {
        double resultError = 0;
        boolean[] tempCorrectnessArray = computeCorrectnessArray();
        for (int i = 0; i < tempCorrectnessArray.length; i++) {
            if (!tempCorrectnessArray[i]) {
                resultError += weightedInstances.getWeight(i);
            } // Of if
        } // Of for i

        if (resultError < 1e-6) {
            resultError = 1e-6;
        } // Of if

        return resultError;
    }// Of computeWeightedError
} // Of class SimpleClassifier

         以下是computeCorrectnessArray函数

   /**
     ******************
     * Which instances in the training set are correctly classified.
     *
     * @return The correctness array.
     ******************
     */
    public boolean[] computeCorrectnessArray() {
        boolean[] resultCorrectnessArray = new boolean[weightedInstances.numInstances()];
        for (int i = 0; i < resultCorrectnessArray.length; i++) {
            Instance tempInstance = weightedInstances.instance(i);
            if ((int) (tempInstance.classValue()) == classify(tempInstance)) {
                resultCorrectnessArray[i] = true;
            } // Of if

            // System.out.print("\t" + classify(tempInstance));
        } // Of for i
        // System.out.println();
        return resultCorrectnessArray;
    }// Of computeCorrectnessArray

        ④由上述关系得到e_{k}=P(G_{k}(x_{i})\neq y_{i})=\sum_{i=1}^{m}w_{ki}I(G_{k}(x_{i})\neq yi),接着用e_{k}计算这个分类器的权重值a_{k}=\frac{1}{2}log\frac{1-e_{k}}{e_{k}},分类器数(迭代次数)自加1。

            // Key code: Set the classifier weight.
            classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
            if (classifierWeights[i] < 1e-6) {
                classifierWeights[i] = 0;
            } // Of if

            System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
                    + classifierWeights[i] + "\r\n");

            numClassifiers++;

        ⑤最后每次迭代完成之后都需要计算一下这些所有已经构建的分类器的分类效果(训练集作为测试集),如果训练结果为1,那么不用继续迭代了,直接输出已经得到的分类器的值(这个时候可能没有达到100次,但是预测准确率已经达到1了,比如迭代到第k次,那么我们只需要取前k个弱分类器即可,后面的分类器直接抛弃)

            if (stopAfterConverge) {
                double tempTrainingAccuracy = computeTrainingAccuray();
                System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
                if (tempTrainingAccuracy > 0.999999) {
                    System.out.println("Stop at the round: " + i + " due to converge.\r\n");
                    break;
                } // Of if
            } // Of if

        计算准确率的代码computeTrainingAccuray()函数

    /**
     ******************
     * Compute the training accuracy of the booster. It is not weighted.
     *
     * @return The training accuracy.
     ******************
     */
    public double computeTrainingAccuray() {
        double tempCorrect = 0;

        for (int i = 0; i < trainingData.numInstances(); i++) {
            if (classify(trainingData.instance(i)) == (int) trainingData.instance(i).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double tempAccuracy = tempCorrect / trainingData.numInstances();

        return tempAccuracy;
    }// Of computeTrainingAccuray

        3.3 测试函数

        System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuray());
        tempBooster.test();
    /**
     ******************
     * Test the booster.
     *
     * @param paraInstances
     *            The testing set.
     * @return The classification accuracy.
     ******************
     */
    public double test(Instances paraInstances) {
        double tempCorrect = 0;
        paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

        for (int i = 0; i < paraInstances.numInstances(); i++) {
            Instance tempInstance = paraInstances.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double resultAccuracy = tempCorrect / paraInstances.numInstances();
        System.out.println("The accuracy is: " + resultAccuracy);

        return resultAccuracy;
    } // Of test

        这里是核心,主要是tempLabel = classifiers[i].classify(paraInstance);得到具体样本的分离标签tempLabel,接着tempLabelsCountArray[tempLabel] += classifierWeights[i];,统计每个标签的分类器权重值。最后选出最大的权重对应的标签即作为预测结果。

    /**
     ******************
     * Classify an instance.
     *
     * @param paraInstance
     *            The given instance.
     * @return The predicted label.
     ******************
     */
    public int classify(Instance paraInstance) {
        double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
        for (int i = 0; i < numClassifiers; i++) {
            int tempLabel = classifiers[i].classify(paraInstance);
            tempLabelsCountArray[tempLabel] += classifierWeights[i];
        } // Of for i

        int resultLabel = -1;
        double tempMax = -1;
        for (int i = 0; i < tempLabelsCountArray.length; i++) {
            if (tempMax < tempLabelsCountArray[i]) {
                tempMax = tempLabelsCountArray[i];
                resultLabel = i;
            } // Of if
        } // Of for

        return resultLabel;
    }// Of classify

四. 代码展示

        详见文章(19条消息) 日撸 Java 三百行(61-70天,决策树与集成学习)_闵帆的博客-CSDN博客

五. 总结

        在做较大的算法学习的时候,我觉得第一重要的是理解到算法的过程,为什么要这么设置这个计算方法,需要达到的最优化目标是什么。这些都是数学基础,可用数学公式描绘出来,一步一步踏踏实实了解到每个公式的具体含义是什么。

        完成上述理论学习之后,为了加深自己的理解,可以举一个数据集的例子,比如怎么样对数据进行处理的,数学公式的展现是怎么样体现的。        

        最后到了代码的编写阶段,按着上面两步,编写代码其实就是复现的过程,写代码和理解理论本来就是相辅相成的过程,学习理论才可以写得出来代码,写出来代码又对理论进行了加深和巩固,这编写代码我们需要做的只不过是按着步骤一一做下去而已。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/741480.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Elastic 连续第三年被评为 2023 年 Gartner® Magic Quadrant™ 的 APM 和可观察性远见者

作者&#xff1a;Gagan Singh 我们很高兴地宣布&#xff0c;Elastic 连续第三年被评为 2023 年 Gartner 应用程序性能监控 (APM) 和可观测性魔力象限中的远见者。 Elastic 因其愿景的完整性和执行能力而受到认可 我们相信&#xff0c;Elastic 被认可为远见者&#xff0c;验证了…

自动化测试平台策略之:自动化测试与项目的结合之路

目录 前言&#xff1a; 一、自动化测试开展在整个项目中存在的一些问题 二、自动化测试与项目结合之路 三、自动化测试平台之项目系统建设 前言&#xff1a; 自动化测试平台是实施自动化测试的关键组成部分&#xff0c;它可以帮助测试团队提高测试效率、加速反馈周期&#xff0…

vue 后台返回列表H5点击按钮加载更多分页数据与van-tab记住选中状态

效果图&#xff08;点击更多订单加载&#xff0c;一次加载10条&#xff09;&#xff1a; <template><div id"order" class"wap-el page-container wap-com-page"><section><com-header></com-header></section><di…

6.1Java EE——Spring介绍

一、Spring概述 String框架的核心技术 Spring是由Rod Johnson组织和开发的一个分层的Java SE/EE一站式&#xff08;full-stack&#xff09;轻量级开源框架。它最为核心的理念是IoC&#xff08;控制反转&#xff09;和AOP&#xff08;面向切面编程&#xff09;&#xff0c;其中&…

声音合成与克隆——制作用于训练的声音数据集

前言 1.PaddleSpeech 是一个简单易用的all-in-one 的语音工具箱&#xff0c;支持语音处理的相关操作&#xff0c;如语音知别&#xff0c;语音合成&#xff0c;声纹识别&#xff0c;声音分类&#xff0c;语音翻译&#xff0c;语音唤醒等多个方向的应用开发。 这里只使用到语音…

C++之模板类重写基类构造函数(一百五十七)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

matlab[1,1]生成100个随机点

close all; clc; x linspace(0,1,200); y 0 rand(1,200); sz 25; c linspace(1,10,length(x)); scatter(x,y,sz,c,filled)

IO流学习07(Java)

序列化流&#xff08;对象操作输出流&#xff09;&#xff1a; 可以把java中的对象写到本地文件中。 public objectoutputstream(outputstream out) 把基本流包装成高级流。 public final void writeobject(object obj) 把对象序列化&#xff08;写出&#xff09;到文…

如何在Docker和Kubernetes中使用代理IP?

Docker和Kubernetes是目前非常流行的容器化技术&#xff0c;这些技术被广泛用于开发、部署和管理应用程序。在某些情况下&#xff0c;需要使用代理IP来访问特定的网络资源。本文将介绍如何在Docker和Kubernetes中使用代理IP&#xff0c;并提供详细的举例说明。 一、在Docker中使…

如何增强农业防灾减灾能力,加强灾情监测与风险预估

近日&#xff0c;农业农村部会同各部门联合下发通知&#xff0c;要求各地坚持问题导向&#xff0c;分区分类指导&#xff0c;细化实化措施&#xff0c;千方百计夺取秋粮和全年粮食丰收。文件中提到要通过加强灾害风险预报预警和灾情监测调度、分区分类做好灾情防范应对来应对气…

ASEMI整流桥GBU808参数和应用

编辑-Z 整流桥GBU808是一种常见的电子元件&#xff0c;用于将交流电转换为直流电。它由四个二极管组成&#xff0c;可以全波整流。GBU808具有高电流和高电压的特点&#xff0c;适用于各种电源和电路应用。 GBU808的主要特点之一是其高电流能力。它可以承受高达8安培的电流&…

嵌入式开发之串口通讯

串口通信(Serial Communication)&#xff0c; 是指外设和计算机间&#xff0c;通过数据信号线 、地线、控制线等&#xff0c;按位进行传输数据的一种通讯方式。这种通信方式使用的数据线少&#xff0c;在远距离通信中可以节约通信成本&#xff0c;但其传输速度比并行传输低&…

springboot会员制医疗预约服务管理信息系统

针对会员制医疗预约服务行业的管理现状&#xff0c;本会员制医疗预约服务管理信息系统主要实现以下几个目标&#xff1a; 1.系统界面简洁&#xff0c;操作简便。 2.拥有精准&#xff0c;高效的查询功能。 3.使管理人员能够及时的获得精确的报表。 4.对数据…

docker入门(Linux环境下安装Docker,Docker构建镜像)

docker入门(利用docker部署web应用) 一:什么是Docker 1.1 官方解释 Docker is the world’s leading software containerization platform。 Docker公司开发&#xff0c;开源&#xff0c;托管在github跨平台&#xff0c; 支持Windows、Macos、Linux。 1.2 抽象解释 docker…

【状态设计优化DP】ABC307 E

E - Distinct Adjacent (atcoder.jp) 题意&#xff1a; 思路&#xff1a; 组合问题&#xff0c;考虑DP或组合数 组合数不好考虑&#xff0c;我们去考虑DP 因为是个环&#xff0c;我们把环拆成一条链&#xff0c;然后加一个N1&#xff0c;颜色和起点1相同&#xff0c;在这条…

天台玻璃折叠门可实现室内外空间的无缝连接

天玻璃折叠门是指安装在天台上的可折叠开合的玻璃门&#xff0c;可用于将室外空间与室内空间进行隔离或连接。设计天台玻璃折叠门时需要注意以下几点&#xff1a; 1. 结构稳固性&#xff1a;选择坚固、稳定的材料和结构设计&#xff0c;确保门体在风力和其他外力作用下不易摇晃…

如何规范的设计数据库表

前言对于后端开发同学来说&#xff0c;访问数据库&#xff0c;是代码中必不可少的一个环节。系统中收集到用户的核心数据&#xff0c;为了安全性&#xff0c;我们一般会存储到数据库&#xff0c;比如&#xff1a;mysql&#xff0c;oracle等。后端开发的日常工作&#xff0c;需要…

制作搭建宠物商城小程序,打造便捷的宠物购物体验

随着宠物市场的不断发展&#xff0c;宠物商城小程序成为了满足宠物爱好者需求的重要工具。在现代社会&#xff0c;宠物已经成为人们生活中不可或缺的一部分。作为宠物爱好者&#xff0c;我们对于宠物食品、用品、医疗保健品等需求日益增长。而宠物商城小程序则为我们提供了一个…

python_day5_file

open()打开函数&#xff1a; f open(name,mode,encoding) name:要打开的目标文件名 mode:访问模式&#xff1a;只读r、写入w、追加a 等 encoding:编码格式&#xff0c;常为UTF-8 f open("D:\Test.txt", "r", encoding"UTF-8") print(type(f))r…

Dbeaver 显示字段备注信息

一、全局设置显示字段描述