【机器学习】分类任务: 二分类与多分类

news2024/12/26 17:59:34

二分类与多分类:概念与区别

二分类多分类是分类任务的两种类型,区分的核心在于目标变量(label)的类别数:

  • 二分类:目标变量 y 只有两个类别,通常记为 y∈{0,1} 或 y∈{−1,1}。
    示例:垃圾邮件分类(垃圾邮件或非垃圾邮件)。

  • 多分类:目标变量 y 包含三个或更多类别,记为 y∈{1,2,…,K}。
    示例:手写数字识别(类别为 0 到 9 的数字)。


1. 二分类问题

特征与目标
  • 输入:特征向量 x \in \mathbb{R}^d
  • 输出:目标 y ∈ {0,1}。
  • 模型预测:预测值为类别 1 的概率 P(y=1|x) = \hat{y}
模型与算法
  1. 常用模型

    • 逻辑回归
    • 支持向量机(SVM)
    • 决策树
    • 随机森林
    • 神经网络(二分类输出层使用 Sigmoid 激活)
  2. 损失函数

    • 对数似然损失(Log-Likelihood Loss): \mathcal{L} = -\frac{1}{N} \sum_{i=1}^N \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]
  3. 评估指标

    • 准确率(Accuracy)
    • 精确率(Precision)
    • 召回率(Recall)
    • F1 分数(F1-Score)
    • AUC-ROC 曲线
案例代码
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score

# 生成二分类数据
# 参数说明:n_samples=100表示生成100个样本,n_features=4表示数据有4个特征,n_classes=2表示二分类问题,
# n_informative=2表示其中2个特征是有信息的,n_redundant=1表示1个特征是冗余的,n_repeated=0表示没有重复的特征,
# random_state=0表示随机种子,保证结果可重复
X, y = make_classification(n_samples=100, n_features=4, n_classes=2, n_informative=2, n_redundant=1, n_repeated=0,
                           random_state=0)

# 数据集划分
# 将数据集划分为训练集和测试集,test_size=0.2表示测试集占20%,random_state=42保证划分结果可重复
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 使用逻辑回归模型进行训练
# 初始化逻辑回归模型
model = LogisticRegression()
# 使用训练集数据拟合模型
model.fit(X_train, y_train)
# 预测测试集的类别
y_pred = model.predict(X_test)
# 预测测试集的正类概率
y_prob = model.predict_proba(X_test)[:, 1]

# 评估模型性能
# 输出测试集的准确率
print("Accuracy:", accuracy_score(y_test, y_pred))
# 输出测试集的AUC-ROC分数
print("AUC-ROC:", roc_auc_score(y_test, y_prob))

输出结果

Accuracy: 0.9
AUC-ROC: 0.9090909090909091

2. 多分类问题

特征与目标
  • 输入:特征向量 x \in \mathbb{R}^d
  • 输出:目标 y \in \{1, 2, \dots, K\}
  • 模型预测:预测每个类别的概率 P(y=k|x),所有类别概率之和为 1。
模型与算法
  1. 常用模型

    • Softmax 回归(多类别逻辑回归)
    • 决策树与随机森林
    • 梯度提升树(如 XGBoost、LightGBM)
    • 神经网络(输出层使用 Softmax 激活)
  2. 损失函数

    • 交叉熵损失(Cross-Entropy Loss):\mathcal{L} = -\frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K 1(y_i = k) \log(\hat{y}_{i,k}),k​ 是样本 i 被预测为类别 k 的概率。
  3. 评估指标

    • 准确率(Accuracy)
    • 混淆矩阵(Confusion Matrix)
    • 平均精确率、召回率与 F1 分数(Macro / Micro / Weighted)
案例代码
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report

# 生成二分类数据
# 参数说明:n_samples=100表示生成100个样本,n_features=4表示数据有4个特征,n_classes=2表示二分类问题,
# n_informative=2表示其中2个特征是有信息的,n_redundant=1表示1个特征是冗余的,n_repeated=0表示没有重复的特征,
# random_state=0表示随机种子,保证结果可重复
X, y = make_classification(n_samples=100, n_features=4, n_classes=2, n_informative=2, n_redundant=1, n_repeated=0,
                           random_state=0)

# 数据集划分
# 将数据集划分为训练集和测试集,test_size=0.2表示测试集占20%,random_state=42保证划分结果可重复
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化随机森林分类器模型
model = RandomForestClassifier()

# 使用训练集数据拟合模型
model.fit(X_train, y_train)

# 使用拟合好的模型对测试集进行预测
y_pred = model.predict(X_test)

# 评估
# 输出模型的准确率
print("Accuracy:", accuracy_score(y_test, y_pred))
# 输出模型的分类报告,包含精确度、召回率、F1分数等指标
print("Classification Report:\n", classification_report(y_test, y_pred))

输出结果

Accuracy: 0.9
Classification Report:
               precision    recall  f1-score   support

           0       1.00      0.82      0.90        11
           1       0.82      1.00      0.90         9

    accuracy                           0.90        20
   macro avg       0.91      0.91      0.90        20
weighted avg       0.92      0.90      0.90        20

3. 二分类与多分类的区别

属性二分类多分类
目标变量y∈{0,1}y∈{1,2,…,K}
损失函数对数似然损失交叉熵损失
预测输出类别 0 或 1 的概率每个类别的概率分布
模型复杂度相对简单更复杂,需要考虑类别间关系
评估指标精确率、召回率、AUC 等混淆矩阵、宏平均 F1 等

4. 注意事项

  1. 模型选择

    • 对于二分类问题,许多模型(如逻辑回归、SVM)内置支持;
    • 多分类问题可通过**一对多(OvR)多对多(OvO)**策略,将多分类问题分解为多个二分类问题。
  2. 不平衡数据

    • 二分类和多分类中,不平衡数据都会导致评估指标偏差,需要关注 AUC 或调整权重。
  3. 概率解释

    • 二分类中概率直接表示为某一类别的置信度;
    • 多分类中概率分布表示样本属于每个类别的可能性。

总结而言,二分类和多分类的问题框架和方法类似,但多分类问题需要更复杂的模型和损失函数来捕捉类别间关系,是分类任务中的重要延伸!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2253252.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【自用】管材流转项目前端重部署流程 vue2 webpackage4 vuecli4

一、配置 1.下载项目,使用 IDEA 打开,并配置 Nodejs 它提示我,需要 Node.js,因为 nodejs 14 的 installer 已经官网已经找不到了,使用 fnm 又太麻烦, 所以直接采用在 IDEA 中下载的方式就好了。 2.清除缓…

java调用ai模型:使用国产通义千问完成基于知识库的问答

整体介绍: 基于RAG(Retrieval-Augmented Generation)技术,可以实现一个高效的Java智能问答客服机器人。核心思路是将预先准备的问答QA文档(例如Word格式文件)导入系统,通过数据清洗、向量化处理…

跨平台应用开发框架(4)----Qt(系统篇)

目录 1.Qt事件 1.事件来源 2.事件处理 3.按键事件 1.组合按键 4.鼠标事件 1.鼠标单击事件 2.鼠标释放事件 3.鼠标双击事件 4.鼠标移动事件 5.滚轮事件 5.定时器 1.QTimerEvent类 2.QTimer 类 3.获取系统日期及时间 6.事件分发器 7.事件过滤器 2.Qt文件 1.输入…

算法刷题Day8:BM30 二叉搜索树与双向链表

题目 牛客网题目传送门 思路 对二叉搜索树进行中序遍历,结果就是按序数组。因此想办法把前面遍历过的节点给记下来,记作pre。当遍历到某个节点node的时候,令前驱指向pre,然后让pre的后驱指向node。 代码 class TreeNode:def…

MySQL--视图

目录 1 认识视图 1.1 视图的定义 1.1 创建视图 1.2 查询 1.3 修改 1.4 删除 1.5 视图的优缺点 1.5.1 优点 1.5.2 缺点 1.6 视图的类型 1.7 视图与物化视图 2 视图检查选项 2.1 CASCADED 2.2 LOCAL 3 视图更新及作用 3.1 视图案列结合 3.1.1 屏蔽敏感数据 3.1…

基于Matlab高速动车组转臂定位橡胶节点刚度对车辆动力学影响仿真研究

本研究针对高速动车组转臂定位系统中橡胶节点的刚度对车辆动力学性能的影响进行仿真研究。随着高速铁路的发展,动车组的运行稳定性和舒适性成为设计和运营的核心问题,其中,转臂定位系统作为动车组悬挂系统的重要组成部分,其性能对…

并发专题(8)之JUC阻塞容器源码剖析

一、ArrayBlockingQueue源码剖析 ArrayBlockingQueue底层是采用数组实现的一个队列。因为底层是数据,一般被成为有界队列、其阻塞模式是基于ReentrantLock来实现的。 // 存数据操作 add(E),offer(E),put(E),offer(E,time,unit) // add(E):添加…

AI/ML 基础知识与常用术语全解析

目录 一.引言 二.AI/ML 基础知识 1.人工智能(Artificial Intelligence,AI) (1).定义 (2).发展历程 (3).应用领域 2.机器学习(Machine Learning,ML) (1).定义 (2).学习方式 ①.监督学习 ②.无监督…

【WRF-Urban】WPS中有关Urban的变量设置

【WRF-Urban】WPS中有关Urban的变量设置 地理数据源的配置WRF-Urban所需静态地理数据1、LANDUSE:包含城市地表分类的土地利用数据。2、URB_PARAM:城市参数数据集。3、FRC_URB2D:城市覆盖度数据集 WRF默认设置(美国)数据…

NVR录像机汇聚管理EasyNVR多个NVR同时管理基于B/S架构的技术特点与能力应用

EasyNVR视频融合平台基于云边端协同设计,能够轻松接入并管理海量的视频数据。该平台兼容性强、拓展灵活,提供了视频监控直播、录像存储、云存储服务、回放检索以及平台级联等一系列功能。B/S架构使得EasyNVR实现了视频监控的多元化兼容与高效管理。 其采…

c++预编译头文件

文章目录 c预编译头文件1.使用g编译预编译头文件2.使用visual studio进行预编译头文件2.1visual studio如何设置输出预处理文件(.i文件)2.2visual studio 如何设置预编译(初始创建空项目的情况下)2.3 visual studio打开输出编译时…

Zookeeper的通知机制是什么?

大家好,我是锋哥。今天分享关于【Zookeeper的通知机制是什么?】面试题。希望对大家有帮助; Zookeeper的通知机制是什么? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Zookeeper的通知机制主要通过Watcher实现,它是Zookeeper客…

基于Pyside6开发一个通用的在线升级工具

UI main.ui <?xml version"1.0" encoding"UTF-8"?> <ui version"4.0"><class>MainWindow</class><widget class"QMainWindow" name"MainWindow"><property name"geometry"&…

开源 - Ideal库 - Excel帮助类,ExcelHelper实现(四)

书接上回&#xff0c;前面章节已经实现Excel帮助类的第一步TableHeper的对象集合与DataTable相互转换功能&#xff0c;今天实现进入其第二步的核心功能ExcelHelper实现。 01、接口设计 下面我们根据第一章中讲解的核心设计思路&#xff0c;先进行接口设计&#xff0c;确定Exce…

嵌入式系统应用-LVGL的应用-平衡球游戏 part1

平衡球游戏 part1 1 平衡球游戏的界面设计2 界面设计2.1 背景设计2.2 球的设计2.3 移动球的坐标2.4 用鼠标移动这个球2.5 增加边框规则2.6 效果图2.7 游戏失败重启游戏 3 为小球增加增加动画效果3.1 增加移动效果代码3.2 具体效果图片 平衡球游戏 part2 第二部分文章在这里 1 …

《Python基础》之Pandas库

目录 一、简介 二、Pandas的核心数据结构 1、Series 2、DataFrame 三、数据读取与写入 1、数据读取 2、数据写入 四、数据清洗与处理 1、处理缺失值 2、处理重复值 3、数据转换 五、数据分析与可视化 1、统计描述 2、分组聚合 3、数据可视化 六、高级技巧 1、时…

网络安全-夜神模拟器如何通过虚拟机的Burp Suite代理应用程序接口

第一步、查看虚拟机的IP地址 我们可以通过ifconfig命令来查看虚拟机的IP地址,如下图所示。 第二步、在Burp Suite上设置代理 打开虚拟机上的Burp Suite,进入到代理模块中,进入到代理设置中心 打开系统代理设置中心之后,将我们虚拟机的地址添加到上面,作为新的代理。 第…

PyTorch 2.5.1: Bugs修复版发布

一&#xff0c;前言 在深度学习框架的不断迭代中&#xff0c;PyTorch 社区始终致力于提供更稳定、更高效的工具。最近&#xff0c;PyTorch 2.5.1 版本正式发布&#xff0c;这个版本主要针对 2.5.0 中发现的问题进行了修复&#xff0c;以提升用户体验。 二&#xff0c;PyTorch 2…

SpringAi整合大模型(进阶版)

进阶版是在基础的对话版之上进行新增功能。 如果还没弄出基础版的&#xff0c;请参考 https://blog.csdn.net/weixin_54925172/article/details/144143523?sharetypeblogdetail&sharerId144143523&sharereferPC&sharesourceweixin_54925172&spm1011.2480.30…

Python实现网站资源批量下载【可转成exe程序运行】

Python实现网站资源批量下载【可转成exe程序运行】 背景介绍解决方案转为exe可执行程序简单点说详细了解下 声明 背景介绍 发现 宣讲家网 的PPT很好&#xff0c;作为学习资料使用很有价值&#xff0c;所以想下载网站的PPT课件到本地&#xff0c;但是由于网站限制&#xff0c;一…