Day 65: 集成学习之 AdaBoosting (3. 集成器)

news2024/11/25 22:27:33

代码:

package dl;

import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;

/**
 * The booster which ensembles base classifiers.
 */
public class Booster {

    /**
     * Classifiers.
     */
    SimpleClassifier[] classifiers;

    /**
     * Number of classifiers.
     */
    int numClassifiers;

    /**
     * Whether or not stop after the training error is 0.
     */
    boolean stopAfterConverge = false;

    /**
     * The weights of classifiers.
     */
    double[] classifierWeights;

    /**
     * The training data.
     */
    Instances trainingData;

    /**
     * The testing data.
     */
    Instances testingData;

    /**
     ******************
     * The first constructor. The testing set is the same as the training set.
     *
     * @param paraTrainingFilename
     *            The data filename.
     ******************
     */
    public Booster(String paraTrainingFilename) {
        // Step 1. Read training set.
        try {
            FileReader tempFileReader = new FileReader(paraTrainingFilename);
            trainingData = new Instances(tempFileReader);
            tempFileReader.close();
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
            System.exit(0);
        } // Of try

        // Step 2. Set the last attribute as the class index.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // Step 3. The testing data is the same as the training data.
        testingData = trainingData;

        stopAfterConverge = true;

        System.out.println("****************Data**********\r\n" + trainingData);
    }// Of the first constructor

    /**
     ******************
     * Set the number of base classifier, and allocate space for them.
     *
     * @param paraNumBaseClassifiers
     *            The number of base classifier.
     ******************
     */
    public void setNumBaseClassifiers(int paraNumBaseClassifiers) {
        numClassifiers = paraNumBaseClassifiers;

        // Step 1. Allocate space (only reference) for classifiers
        classifiers = new SimpleClassifier[numClassifiers];

        // Step 2. Initialize classifier weights.
        classifierWeights = new double[numClassifiers];
    }// Of setNumBaseClassifiers

    /**
     ******************
     * Train the booster.
     *
     * @see algorithm.StumpClassifier#train()
     ******************
     */
    public void train() {
        // Step 1. Initialize.
        WeightedInstances tempWeightedInstances = null;
        double tempError;
        numClassifiers = 0;

        // Step 2. Build other classifiers.
        for (int i = 0; i < classifiers.length; i++) {
            // Step 2.1 Key code: Construct or adjust the weightedInstances
            if (i == 0) {
                tempWeightedInstances = new WeightedInstances(trainingData);
            } else {
                // Adjust the weights of the data.
                tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
                        classifierWeights[i - 1]);
            } // Of if

            // Step 2.2 Train the next classifier.
            classifiers[i] = new StumpClassifier(tempWeightedInstances);
            classifiers[i].train();

            tempError = classifiers[i].computeWeightedError();

            // Key code: Set the classifier weight.
            classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
            if (classifierWeights[i] < 1e-6) {
                classifierWeights[i] = 0;
            } // Of if

            System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
                    + classifierWeights[i] + "\r\n");

            numClassifiers++;

            // The accuracy is enough.
            if (stopAfterConverge) {
                double tempTrainingAccuracy = computeTrainingAccuray();
                System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
                if (tempTrainingAccuracy > 0.999999) {
                    System.out.println("Stop at the round: " + i + " due to converge.\r\n");
                    break;
                } // Of if
            } // Of if
        } // Of for i
    }// Of train

    /**
     ******************
     * Classify an instance.
     *
     * @param paraInstance
     *            The given instance.
     * @return The predicted label.
     ******************
     */
    public int classify(Instance paraInstance) {
        double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
        for (int i = 0; i < numClassifiers; i++) {
            int tempLabel = classifiers[i].classify(paraInstance);
            tempLabelsCountArray[tempLabel] += classifierWeights[i];
        } // Of for i

        int resultLabel = -1;
        double tempMax = -1;
        for (int i = 0; i < tempLabelsCountArray.length; i++) {
            if (tempMax < tempLabelsCountArray[i]) {
                tempMax = tempLabelsCountArray[i];
                resultLabel = i;
            } // Of if
        } // Of for

        return resultLabel;
    }// Of classify

    /**
     ******************
     * Test the booster on the training data.
     *
     * @return The classification accuracy.
     ******************
     */
    public double test() {
        System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");

        return test(testingData);
    }// Of test

    /**
     ******************
     * Test the booster.
     *
     * @param paraInstances
     *            The testing set.
     * @return The classification accuracy.
     ******************
     */
    public double test(Instances paraInstances) {
        double tempCorrect = 0;
        paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

        for (int i = 0; i < paraInstances.numInstances(); i++) {
            Instance tempInstance = paraInstances.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double resultAccuracy = tempCorrect / paraInstances.numInstances();
        System.out.println("The accuracy is: " + resultAccuracy);

        return resultAccuracy;
    } // Of test

    /**
     ******************
     * Compute the training accuracy of the booster. It is not weighted.
     *
     * @return The training accuracy.
     ******************
     */
    public double computeTrainingAccuray() {
        double tempCorrect = 0;

        for (int i = 0; i < trainingData.numInstances(); i++) {
            if (classify(trainingData.instance(i)) == (int) trainingData.instance(i).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double tempAccuracy = tempCorrect / trainingData.numInstances();

        return tempAccuracy;
    }// Of computeTrainingAccuray

    /**
     ******************
     * For integration test.
     *
     * @param args
     *            Not provided.
     ******************
     */
    public static void main(String args[]) {
        System.out.println("Starting AdaBoosting...");
        Booster tempBooster = new Booster("C:\\Users\\86183\\IdeaProjects\\deepLearning\\src\\main\\java\\resources\\iris.arff");

        tempBooster.setNumBaseClassifiers(100);
        tempBooster.train();

        System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuray());
        tempBooster.test();
    }// Of main

}// Of class Booster

结果:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/785066.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决报错:Can‘t connect to HTTPS URL because the SSL module is not available.

本人今天准备打开安装一个label-studio包&#xff0c;试了很多次&#xff0c;接连报如下错误&#xff0c;因此我就去找了一些解决方案&#xff0c;现在总结如下&#xff1a; 1、报错信息如下 2、解决方案如下&#xff1a; github上有对应的解决方案&#xff0c;链接&#xff…

教师ChatGPT的23种用法

火爆全网的ChatGPT&#xff0c;作为教师应该如何正确使用&#xff1f;本文梳理了教师ChatGPT的23种用法&#xff0c;一起来看看吧&#xff01; 1、回答问题 ChatGPT可用于实时回答问题&#xff0c;使其成为需要快速获取信息的学生的有用工具。 从这个意义上说&#xff0c;Cha…

安卓开发后台应用周期循环获取位置信息上报服务器

问题背景 最近有需求&#xff0c;在APP启动后&#xff0c;退到后台&#xff0c;还要能实现周期获取位置信息上报服务器&#xff0c;研究了一下实现方案。 问题分析 一、APP退到后台后网络请求实现 APP退到后台后&#xff0c;实现周期循环发送网络请求。目前尝试了两种方案是…

Sui Builder House巴黎站精彩集锦

Sui Builder House巴黎站于7月19日圆满结束&#xff0c;Mysten Labs联合创始人兼CTO的Sam Blackshear在活动上发表了主题演讲。两天的Builder House活动还邀请了Mysten Labs的其他杰出成员分享Sui的发展情况和近期进展&#xff0c;社区成员展示了自己项目并提供见解&#xff0c…

C++继承体系中,基类析构函数请加上virtual,设置为虚函数

为什么建议在存在继承体系时刻我们的类的析构函数加上virtual呢&#xff1f; 大家看段代码。 咋一看&#xff0c;没什么毛病这段代码&#xff0c;让我们画图理解下。 紫框中的前4个字节指向new开辟的空间。 我们知道&#xff0c;当基类A指针指向基类B时候会发生切片 当我们del…

小程序如何修改商品

​商家可能会遇到需要修改产品信息的情况。无论是价格调整、库存更新还是商品描述的修改&#xff0c;小程序提供了简便的方式来帮助你们完成这些操作。下面是一些简单的步骤和注意事项&#xff0c;帮助你们顺利地修改商品。 一、进入商品管理页面 在个人中心点击管理入口&…

工厂电力监控解决方案

1、概述 电力监控系统实现对变压器、柴油发电机、断路器以及其它重要设备进行监视、测量、记录、报警等功能&#xff0c;并与保护设备和远方控制中心及其他设备通信&#xff0c;实时掌握供电系统运行状况和可能存在的隐患&#xff0c;快速排除故障&#xff0c;提高工厂供电可靠…

2023年Q2京东环境电器市场数据分析(京东数据产品)

今年Q2&#xff0c;环境电器市场中不少类目表现亮眼&#xff0c;尤其是以净水器、空气净化器、除湿机等为代表的环境健康电器。此外&#xff0c;像冷风扇这类具有强季节性特征的电器也呈现出比较好的增长态势。 接下来&#xff0c;结合具体数据我们一起来分析Q2环境电器市场中…

承接箱体透明拼接屏项目时,需要注意哪些事项?

承接箱体透明拼接屏项目时&#xff0c;需要注意以下事项&#xff1a; 确定需求&#xff1a;在承接箱体透明拼接屏项目之前&#xff0c;需要明确客户的需求&#xff0c;包括屏幕的大小、分辨率、亮度、色彩等参数&#xff0c;以及使用的环境、观看距离和观看角度等。 材料选择&…

图文教程:如何在 3DS Max 中创建3D迷你卡通房屋

推荐&#xff1a; NSDT场景编辑器助你快速搭建可二次开发的3D应用场景 在本教程中&#xff0c;我们将学习如何创建一个有趣的、低多边形的迷你动画房子&#xff0c;你可以在自己的插图或视频游戏项目中使用它。您将学习的一些技能将包括创建基本的3D形状和基本的建模技术。让我…

最简单的固定表格列实现

ref: https://dev.to/nicolaserny/table-with-a-fixed-first-column-2c5b 假设我们现在有这样一个表格 <table><thead><tr><th>姓名</th><th>性别</th><th>民族</th><th>年龄</th><th>籍贯</th>…

好用的敏捷开发项目管理工具有哪些?这3款真的绝绝子!

随着数字化的转型和企业团队成员不断追求高效的工作效率&#xff0c;越来越多优质的敏捷开发项目管理工具&#xff0c;深受广大管理者的青睐。今天我将通过这篇文章为大家介绍3款非常好用的开发项目管理工具&#xff0c;建议收藏起来&#xff01; ​ 1.boardmix boardmix博思…

Jmeter 中 Beanshell 的使用

目录 前言&#xff1a; Beanshell 介绍 常用内置变量 log vars 和 props vars 常用方法&#xff1a; props 常用方法&#xff1a; prev 综合运用 前言&#xff1a; JMeter 是一个广泛使用的性能测试工具&#xff0c;它支持许多不同的测试技术和方法。其中&#xff0c…

浏览器协议TCP详解

浏览器协议TCP详解 浏览器进程负责存储、界面、下载等管理。在渲染进程中&#xff0c;运行着熟知的主线程、合成线程、JavaScript 解释器、排版引擎等。 浏览器进程处理用户在地址栏的输入&#xff0c;然后将 URL 发送给网络进程。网络进程发送 URL 请求&#xff0c;在接收到响…

【Yolov8自动标注数据集完整教程】

Yolov8自动标注数据集完整教程 1 前言2 先手动标注数据集&#xff0c;训练出初步的检测模型2.1 手动标注数据集2.2 Yolov8环境配置2.2.1 Yolov8下载2.2.2 Yolov8环境配置 2.3 Yolov8模型训练&#xff0c;得到初步的检测模型2.3.1 训练方式 3 使用初步的检测模型实现自动数据集标…

STM32 I2C OVR 错误

一、问题 STM32 I2C 用作从机时&#xff0c;开启如下中断并启用 callback 回调函数。 每一次复位后&#xff0c;从机都可以正常触发地址匹配中断ADDR&#xff0c;之后在该中断的回调函数中启用接收中断去收取数据时&#xff0c;却无法进入RXNE中断&#xff0c;而是触发了 OVR …

《数据分析-JiMuReport08》JiMuReport报表开发-报表列数量开发限制调整

JiMuReport报表开发列数量限制调整 1.开发列数限制 JiMuReport报表在开发的时候&#xff0c;需要100-200列的数据&#xff0c;但是在设计到一定数量的时候&#xff0c;水平下拉框就不能滑动了 2.报表参数调整 col: n 在application.yml文件的jmreport配置处&#xff0c;如果想…

【C++】特殊类的设计 | 类型转换

文章目录 1. 特殊类的设计单例模式饿汉模式具体代码 懒汉模式具体代码 懒汉模式和饿汉模式的优缺点 2. C的类型转换C语言的类型转换C的类型转换static_castreinterpret_castconst_castdynamic_cast 1. 特殊类的设计 单例模式 设计模式是 被反复使用 多数人知晓 经过分类的、代…

【Docker】Docker中容器之间通信方式

文章目录 1. Docker容器之间通信的主要方式1.1 通过容器ip访问1.2. 通过宿主机的ip:port访问1.3. 通过link建立连接&#xff08;官方不推荐使用&#xff09;1.4. 通过 User-defined networks&#xff08;推荐&#xff09; 2. 参考资料 1. Docker容器之间通信的主要方式 1.1 通…

OpenCV图像处理-视频分割静态背景-MOG/MOG2/GMG

视频分割背景 1.概念介绍2. 函数介绍MOG算法MOG2算法GMG算法 原视频获取链接 1.概念介绍 视频背景扣除原理&#xff1a;视频是一组连续的帧&#xff08;一幅幅图组成&#xff09;&#xff0c;帧与帧之间关系密切(GOP/group of picture)&#xff0c;在GOP中&#xff0c;背景几乎…