[机器学习]AdaBoost(数学原理 + 例子解释 + 代码实战)

news2024/12/21 22:45:18

AdaBoost

AdaBoost(Adaptive Boosting)是一种Boosting算法,它通过迭代地训练弱分类器并将它们组合成一个强分类器来提高分类性能。

AdaBoost算法的特点是它能够自适应地调整样本的权重,使那些被错误分类的样本在后续的训练中得到更多的关注。

加法模型

AdaBoost算法的最终分类器是一个加法模型,即多个弱分类器的线性组合。数学表达式如下:

f ( x ) = ∑ m = 1 M α m G m ( x ) f(x) = \sum_{m=1}^{M} \alpha_m G_m(x) f(x)=m=1MαmGm(x)

其中, G m ( x ) G_m(x) Gm(x)是第m个弱分类器, α m \alpha_m αm 是第m个弱分类器的权重, M M M 是弱分类器的总数。

训练过程

  1. 初始化样本权重:在第一轮迭代中,所有样本的权重都相等,即每个样本的权重为 1 N \frac{1}{N} N1,其中N是样本总数。

  2. 训练弱分类器:在每一轮迭代中,使用当前的样本权重来训练一个弱分类器 G m ( x ) G_m(x) Gm(x)

  3. 计算分类误差率:计算弱分类器 G m ( x ) G_m(x) Gm(x)在训练集上的分类误差率 ϵ m \epsilon_m ϵm,即被错误分类的样本数占总样本数的比例。

    • 分类误差率范围确定: 0 < = ϵ m < = 0.5 0<=\epsilon_m<=0.5 0<=ϵm<=0.5
    • 分类误差率计算公式为: ϵ m = ∑ i = 1 N w i ( m ) ⋅ I ( y i ≠ G m ( x i ) ) = ∑ y i ≠ G m ( x i ) w i ( m ) \epsilon_m ={\sum_{i=1}^{N} w_i^{(m)} \cdot \mathbb{I}(y_i \neq G_m(x_i))} = {\sum_{y_i \neq G_m(x_i)} w_i^{(m)}} ϵm=i=1Nwi(m)I(yi=Gm(xi))=yi=Gm(xi)wi(m)
      • I ( y i ≠ G m ( x i ) ) \mathbb{I}(y_i \neq G_m(x_i)) I(yi=Gm(xi))是一个指示函数(也称为指示变量),当样本i的真实标签 y i y_i yi 与弱分类器对样本i的预测 G m ( x i ) G_m(x_i) Gm(xi) 不相等,即样本被错误分类时,该函数的值为1;如果相等,即样本被正确分类时,该函数的值为0。
      • ∑ y i ≠ G m ( x i ) w i ( m ) \sum_{y_i \neq G_m(x_i)} w_i^{(m)} yi=Gm(xi)wi(m)对所有被第m个弱分类器错误分类的样本的权重进行累加。
  4. 计算弱分类器权重:根据分类误差率 ϵ m \epsilon_m ϵm 计算弱分类器的权重 α m \alpha_m αm
    - 分类误差率越大,权重越小;反之,分类误差率越小,权值越大。
    - 权重的计算公式为:
    α m = 1 2 ln ⁡ ( 1 − ϵ m ϵ m ) \alpha_m = \frac{1}{2} \ln \left( \frac{1 - \epsilon_m}{\epsilon_m} \right) αm=21ln(ϵm1ϵm)

    • 1 − ϵ m 1−ϵ_m 1ϵm 是第m个弱分类器在训练集上的正确率,即被正确分类的样本数占总样本数的比例。
    • 对数函数 ln ⁡ ( 1 − ϵ m ϵ m ) \ln \left( \frac{1 - \epsilon_m}{\epsilon_m} \right) ln(ϵm1ϵm) 用于计算正确率与误差率的比值的自然对数。这个比值反映了弱分类器的性能,正确率越高,误差率越低,比值越大。
  5. 更新样本权重:根据弱分类器的预测结果更新样本权重。对于被正确分类的样本,权重降低;对于被错误分类的样本,权重提高。将样本权重的更新视为损失函数
    在这里插入图片描述

更新公式为:
w i ( m + 1 ) = w i ( m ) ⋅ exp ⁡ ( − α m ⋅ y i ⋅ G m ( x i ) ) Z m w_{i}^{(m+1)} =\frac {w_{i}^{(m)} \cdot \exp(-\alpha_m \cdot y_i \cdot G_m(x_i)) }{Z_m} wi(m+1)=Zmwi(m)exp(αmyiGm(xi))
- 其中, w i ( m ) w_{i}^{(m)} wi(m) 是第m轮中第i个样本的权重, y i y_i yi 是第i个样本的真实标签, G m ( x i ) G_m(x_i) Gm(xi) 是第m个弱分类器对第i个样本的预测结果
- Z m Z_m Zm是归一化因子,目的是把分子映射到0-1范围内。 Z m = ∑ i = 1 N w i ( m ) ⋅ exp ⁡ ( − α m ⋅ y i ⋅ G m ( x i ) ) Z_m = \sum_{i = 1}^Nw_{i}^{(m)} \cdot \exp(-\alpha_m \cdot y_i \cdot G_m(x_i)) Zm=i=1Nwi(m)exp(αmyiGm(xi))
- 对于被正确分类的样本, y i ⋅ G m ( x i ) y_i \cdot G_m(x_i) yiGm(xi)同号,指数函数的值为 e − α m e^{-\alpha_m} eαm 小于1,样本权重降低。
- 对于被错误分类的样本, y i ⋅ G m ( x i ) y_i \cdot G_m(x_i) yiGm(xi)异号,指数函数的值为 e α m e^{\alpha_m} eαm 大于1,样本权重提高。
- 上述公式也可以写成这样:
在这里插入图片描述

  1. 迭代:重复步骤2到6,直到达到指定的迭代次数M或总分类器的精度达到设定的阈值。

  2. 最终预测:在所有弱分类器训练完成后,AdaBoost算法通过加权多数表决来确定最终的分类结果。对于一个新样本x,最终的预测结果是所有弱分类器预测结果的加权和:
    f ( x ) = ∑ m = 1 M α m G m ( x ) f(x) = \sum_{m=1}^{M} \alpha_m G_m(x) f(x)=m=1MαmGm(x)
    对于分类问题,最终的预测类别是使 f ( x ) f(x) f(x)最大化的类别。

例子

例子来源

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

代码实现

import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# make_classification生成包含1000个样本和20个特征的模拟二分类数据集
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=0, random_state=42)
y = np.where(y == 0, -1, 1)  # 将标签转换为-1和1

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 1. 初始化样本权重
sample_weights = np.ones_like(y_train) / len(y_train)

# 设置参数
n_estimators = 50  # 弱分类器的数量
learning_rate = 1.0  # 学习率

# 初始化弱分类器列表
weak_classifiers = []

for m in range(n_estimators):
    # 2. 训练弱分类器
    from sklearn.tree import DecisionTreeClassifier
    clf = DecisionTreeClassifier(max_depth=1)
    clf.fit(X_train, y_train, sample_weight=sample_weights)
    
    y_pred = clf.predict(X_train)
    
    # 3. 计算分类误差率
    incorrect = np.sum(sample_weights * (y_train != y_pred))
    error_rate = incorrect / np.sum(sample_weights)
    
    # 如果误差率大于0.5,则拒绝这个分类器
    if error_rate > 0.5:
        continue
    
    # 4. 计算弱分类器权重
    alpha = np.log((1.0 - error_rate) / error_rate) / 2.0
    
    # 更新弱分类器列表
    weak_classifiers.append((clf, alpha))
    
    # 5. 更新样本权重
    sample_weights *= np.exp(-alpha * y_train * y_pred)
    sample_weights /= np.sum(sample_weights)  # 归一化权重


def predict(X, classifiers):
    votes = np.zeros((X.shape[0],))
    for clf, alpha in classifiers:
        votes += alpha * clf.predict(X)
    return np.sign(votes)

# 7.预测
train_pred = predict(X_train, weak_classifiers)
test_pred = predict(X_test, weak_classifiers)

# 计算准确率
train_accuracy = accuracy_score(y_train, train_pred)
test_accuracy = accuracy_score(y_test, test_pred)

print(f"Train Accuracy: {train_accuracy:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2263465.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

详细解读TISAX认证的意义

详细解读TISAX认证的意义&#xff0c;犹如揭开信息安全领域的一颗璀璨明珠&#xff0c;它不仅代表了企业在信息安全管理方面的卓越成就&#xff0c;更是通往全球汽车供应链信任桥梁的关键一环。TISAX&#xff0c;即“Trusted Information Security Assessment Exchange”&#…

黑马Redis数据结构学习笔记

Redis数据结构 动态字符串 Intset Dict ZipList QuickList SkipList 类似倍增 RedisObject 五种数据类型 String List Set ZSet Hash

sqlilabs靶场二十一关二十五关攻略

第二十一关 第一步 可以发现cookie是经过64位加密的 我们试试在这里注入 选择给他编码 发现可以成功注入 爆出表名 爆出字段 爆出数据 第二十二关 跟二十一关一模一样 闭合换成" 第二十三关 第二十三关重新回到get请求&#xff0c;会发现输入单引号报错&#xff0c…

Win10将WindowsTerminal设置默认终端并添加到右键(无法使用微软商店)

由于公司内网限制&#xff0c;无法通过微软商店安装 Windows Terminal&#xff0c;本指南提供手动安装和配置新版 Windows Terminal 的步骤&#xff0c;并添加右键菜单快捷方式。 1. 下载新版终端安装包: 访问 Windows Terminal 的 GitHub 发布页面&#xff1a;https://githu…

从地铁客流讲开来:十二城日常地铁客运量特征

随着城市化进程的加速和人口的不断增长&#xff0c;公共交通系统在现代都市生活中扮演着日益重要的角色。地铁作为高效、环保的城市交通方式&#xff0c;已经成为居民日常出行不可或缺的一部分。本文聚焦于2024年10月28日至12月1日期间&#xff0c;对包括北上广深这四个超一线城…

firefox浏览器如何安装驱动?

firefox的浏览器驱动:https://github.com/mozilla/geckodriver/releases 将geckodriver.exe所在文件路径追加到系统环境变量path的后面

2.2.3 进程通信举例

文章目录 PV操作实现互斥PV操作实现同步高级通信 PV操作实现互斥 PV操作实现互斥。先明确临界资源是什么&#xff0c;然后确定信号量的初值。实现互斥时&#xff0c;一般是执行P操作&#xff0c;进入临界区&#xff0c;出临界区执行V操作。 以统计车流量为例。临界资源是记录统…

UE5 移植Editor或Developer模块到Runtime

要将源码中的非运行时模块移植到Runtime下使用,个人理解就是一个解决编译报错的过程,先将目标模块复制到项目的source目录内,然后修改模块文件夹名称,修改模块.build.cs与文件夹名称保持一致 修改build.cs内的类名 ,每个模块都要修改 // Copyright Epic Games, Inc. All …

Qt WORD/PDF(三)使用 QAxObject 对 Word 替换(QML)

关于QT Widget 其它文章请点击这里: QT Widget 国际站点 GitHub: https://github.com/chenchuhan 国内站点 Gitee : https://gitee.com/chuck_chee 姊妹篇: Qt WORD/PDF&#xff08;一&#xff09;使用 QtPdfium库实现 PDF 操作 Qt WORD/PDF&#xff08;二…

RAG基础知识及综述学习

RAG基础知识及综述学习 前言1.RAG 模块1.1 检索器&#xff08;Retriever&#xff09;1.2 检索融合&#xff08;Retrieval Fusion&#xff09;1.3 生成器&#xff08;Generator&#xff09; 2.构建检索器&#xff08;Retriever&#xff09;2.1 分块语料库2.2 编码文本块2.3 构建…

移动网络(2,3,4,5G)设备TCP通讯调试方法

背景&#xff1a; 当设备是移动网络设备连接云平台的时候&#xff0c;如果服务器没有收到网络数据&#xff0c;移动物联设备发送不知道有没有有丢失数据的时候&#xff0c;需要一个抓取设备出来的数据和服务器下发的数据的方法。 1.服务器系统是很成熟的&#xff0c;一般是linu…

深入剖析MyBatis的架构原理

架构设计 简要画出 MyBatis 的架构图 >> ​​ Mybatis 的功能架构分为哪三层&#xff1f; API 接口层 提供给外部使用的接口 API&#xff0c;开发人员通过这些本地 API 来操纵数据库。接口层一接收到调用请求就会调用数据处理层来完成具体的数据处理。MyBatis 和数据库的…

android opencv导入进行编译

1、直接新建module进行导入&#xff0c;选择opencv的sdk 导入module模式&#xff0c;选择下载好的sdk&#xff0c;修改module name为OpenCV490。 有报错直接解决报错&#xff0c;没报错直接运行成功。 2、解决错误&#xff0c;同步成功 一般报错是gradle版本问题较多。我的报…

智能座舱进阶-应用框架层-Jetpack主要组件

Jetpack的分类 1. DataBinding&#xff1a;以声明方式将可观察数据绑定到界面元素&#xff0c;通常和ViewModel配合使用。 2. Lifecycle&#xff1a;用于管理Activity和Fragment的生命周期&#xff0c;可帮助开发者生成更易于维护的轻量级代码。 3. LiveData: 在底层数据库更…

设计模式-访问者设计模式

介绍 访问者模式&#xff08;Visitor&#xff09;&#xff0c;表示一个作用于某对象结构中的各元素的操作&#xff0c;它使你可以在不改变个元素的类的前提下定义作用于这些元素的新操作。 问题&#xff1a;在一个机构里面有两种员工&#xff0c;1.Teacher 2.Engineer 员…

springmvc的拦截器,全局异常处理和文件上传

拦截器: 拦截不符合规则的&#xff0c;放行符合规则的。 等价于过滤器。 拦截器只拦截controller层API接口。 如何定义拦截器。 定义一个类并实现拦截器接口 public class MyInterceptor implements HandlerInterceptor {public boolean preHandle(HttpServletRequest reque…

宿舍管理系统(源码+数据库+报告)

356基于SpringBoot的宿舍管理系统&#xff0c;系统包含两种角色&#xff1a;管理员、用户,系统分为前台和后台两大模块 二、项目技术 编程语言&#xff1a;Java 数据库&#xff1a;MySQL 项目管理工具&#xff1a;Maven 前端技术&#xff1a;Vue 后端技术&#xff1a;SpringBo…

基于 HC_SR04的超声波测距数码管显示(智能小车超声波避障部分)

超声波测距模块HC-SR04 1、产品特色 ①典型工作用电压&#xff1a;5V ②超小静态工作电流&#xff1a;小于 5mA ③感应角度(R3 电阻越大,增益越高,探测角度越大)&#xff1a; R3 电阻为 392,不大于 15 度 R3 电阻为 472, 不大于 30 度 ④探测距离(R3 电阻可调节增益,即调节探测…

(OCPP服务器)SteVe编译搭建全过程

注意&#xff1a;建议使用3.6.0&#xff0c;我升级到3.7.1&#xff0c;并没有多什么新功能&#xff0c;反而电表的实时数据只能看到累计电能了&#xff0c;我回退了就正常&#xff0c;数据库是兼容的&#xff0c;java版本换位java11&#xff0c;其他不变就好 背景&#xff1a;…

搭建Tomcat(四)---Servlet容器

目录 引入 Servlet容器 一、优化MyTomcat ①先将MyTomcat的main函数搬过来&#xff1a; ②将getClass()函数搬过来 ③创建容器 ④连接ServletConfigMapping和MyTomcat 连接&#xff1a; ⑤完整的ServletConfigMapping和MyTomcat方法&#xff1a; a.ServletConfigMappin…