Day 63 : 集成学习之 AdaBoosting (1. 带权数据集)

news2025/1/11 23:35:24

63.1 AdaBoosting基本算法:先从初始训练集训练一个弱学习器,再根据弱学习器的表现对训练样本的权重进行调整(被错分的样本权重增大,被正确分类的样本权重减小),然后基于调整后的权重训练下一个弱学习器。经过若干轮之后,将得到一组分类器,将数据输入这组分类器后会得到一个综合且准确的分类结果。“三个臭皮匠,顶个诸葛亮”,多个这样的弱分类器相互补充,最后会变成一个强分类器。

63.2 代码:

package dl;

import java.io.FileReader;
import java.util.Arrays;

import weka.core.Instances;

 * Weighted instances.
public class WeightedInstances extends Instances {

     * Just the requirement of some classes, any number is ok.
    private static final long serialVersionUID = 110;

     * Weights.
    private double[] weights;

     * The first constructor.
     * @param paraFileReader
     *            The given reader to read data from file.
    public WeightedInstances(FileReader paraFileReader) throws Exception {
        setClassIndex(numAttributes() - 1);

        // Initialize weights
        weights = new double[numInstances()];
        double tempAverage = 1.0 / numInstances();
        for (int i = 0; i < weights.length; i++) {
            weights[i] = tempAverage;
        } // Of for i
        System.out.println("Instances weights are: " + Arrays.toString(weights));
    } // Of the first constructor

     * The second constructor.
     * @param paraInstances
     *            The given instance.
    public WeightedInstances(Instances paraInstances) {
        setClassIndex(numAttributes() - 1);

        // Initialize weights
        weights = new double[numInstances()];
        double tempAverage = 1.0 / numInstances();
        for (int i = 0; i < weights.length; i++) {
            weights[i] = tempAverage;
        } // Of for i
        System.out.println("Instances weights are: " + Arrays.toString(weights));
    } // Of the second constructor

     * Getter.
     * @param paraIndex
     *            The given index.
     * @return The weight of the given index.
    public double getWeight(int paraIndex) {
        return weights[paraIndex];
    } // Of getWeight

     * Adjust the weights.
     * @param paraCorrectArray
     *            Indicate which instances have been correctly classified.
     * @param paraAlpha
     *            The weight of the last classifier.
    public void adjustWeights(boolean[] paraCorrectArray, double paraAlpha) {
        // Step 1. Calculate alpha.
        double tempIncrease = Math.exp(paraAlpha);

        // Step 2. Adjust.
        double tempWeightsSum = 0; // For normalization.
        for (int i = 0; i < weights.length; i++) {
            if (paraCorrectArray[i]) {
                weights[i] /= tempIncrease;
            } else {
                weights[i] *= tempIncrease;
            } // Of if
            tempWeightsSum += weights[i];
        } // Of for i

        // Step 3. Normalize.
        for (int i = 0; i < weights.length; i++) {
            weights[i] /= tempWeightsSum;
        } // Of for i

        System.out.println("After adjusting, instances weights are: " + Arrays.toString(weights));
    } // Of adjustWeights

     * Test the method.
    public void adjustWeightsTest() {
        boolean[] tempCorrectArray = new boolean[numInstances()];
        for (int i = 0; i < tempCorrectArray.length / 2; i++) {
            tempCorrectArray[i] = true;
        } // Of for i

        double tempWeightedError = 0.3;

        adjustWeights(tempCorrectArray, tempWeightedError);

        System.out.println("After adjusting");

    } // Of adjustWeightsTest

     * For display.
    public String toString() {
        String resultString = "I am a weighted Instances object.\r\n" + "I have " + numInstances() + " instances and "
                + (numAttributes() - 1) + " conditional attributes.\r\n" + "My weights are: " + Arrays.toString(weights)
                + "\r\n" + "My data are: \r\n" + super.toString();

        return resultString;
    } // Of toString

     * For unit test.
     * @param args
     *            Not provided.
    public static void main(String args[]) {
        WeightedInstances tempWeightedInstances = null;
        String tempFilename = "C:\\Users\\86183\\IdeaProjects\\deepLearning\\src\\main\\java\\resources\\iris.arff";
        try {
            FileReader tempFileReader = new FileReader(tempFilename);
            tempWeightedInstances = new WeightedInstances(tempFileReader);
        } catch (Exception exception1) {
            System.out.println("Cannot read the file: " + tempFilename + "\r\n" + exception1);
        } // Of try


    } // Of main

} // Of class WeightedInstances

63.3 结果(部分)







