【机器学习实战】使用SGD、随机森林对MNIST数据集实现多分类(jupyterbook)

news2025/1/18 18:59:48

1. 获取数据集并重新划分数据集

# 获取MNIST数据集
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)

# 查看测试器和标签
X, y = mnist['data'], mnist['target']
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
# 对数据进行洗牌,防止输入许多相似实例导致的执行性能不佳
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

# 重新创建目标向量(以是5和非5作为二分类标准)
y_train_5 = (y_train == '5')
y_test_5 = (y_test == '5')

在这里插入图片描述

2. 使用SGD随机梯度下降进行多分类

some_digit = X[36000]

from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train)
sgd_clf.predict([some_digit])

在这里插入图片描述

3. 对二分类算法强制使用一对一、一对多策略进行多分类

3.1 SGD

# 1. 使用OvO(一对一)策略,基于SGD创建多分类器
from sklearn.multiclass import OneVsOneClassifier

ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))
ovo_clf.fit(X_train, y_train)
ovo_clf.predict([some_digit])

在这里插入图片描述

3.2 随机森林

# 1. 训练随机森林(因为随机森林本身就可以进行多分类)
from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(random_state=42)
forest_clf.fit(X_train, y_train)
forest_clf.predict([some_digit])

在这里插入图片描述

4. 对模型进行评估(使用准确率)

4.1 数据未标准化

# 1. 使用交叉验证对SGD多分类器进行评估
from sklearn.model_selection import cross_val_score

cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")

4.2 数据标准化后

# 2. 对训练集进行标准化,再进行评估看看
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64)) # 标准化
cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")

在这里插入图片描述

5. 绘制混淆矩阵并进行分类错误分析

5.1 原始混淆矩阵

# 混淆矩阵
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
# 绘制混淆矩阵的图像
import matplotlib.pyplot as plt

plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

在这里插入图片描述

  • 结论:
    • 大多数图片都在主对角线上,说明它们被正确分类。
    • 数字5稍微暗一点,可能数据集中5的图片比较少,也可能是分类器在5上的执行效果不如其他数字好。

5.2 将正确分类的剔除后只留下错误的

# 将混淆矩阵中的每个值 除以 相应类别中的图片数量
row_sums = conf_mx.sum(axis=1, keepdims=True) # 同行相加
norm_conf_mx = conf_mx / row_sums

# 用0填充对角线,保存错误,重新绘制混淆矩阵
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()

在这里插入图片描述

  • 结论:
    • 第8列、第9列比较亮,说明许多图片被错分为8和9;
    • 第8行、第9行也偏亮,说明8、9容易和其他数字混淆;
    • 行1很暗,说明大多数1都被正确分类;
    • 数字5被分成8的数量比8分成5的数量更多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/69606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nuxt3使用echart,使用中国地图

目录 第一步安装echart 第二步配置plugins 第三步使用 例如使用饼状图 例如使用中国地图 第一步安装echart npm install echarts --save 第二步配置plugins 在plugins创建echarts.ts文件并写入下面内容 import * as echarts from echartsexport default defineNuxtPlugin((…

springboot事件监听机制二:基本工作原理

前言 这是继《springboot事件监听机制一:实战应用》第二篇,知其然,当然还要知其所以然,深入的源码里面探寻一下这一有套机制的工作原理。spring生态很茂盛,这里不会站太高去分析这个问题,大扯spring的一些原…

优秀的项目跟踪管理软件有哪些?

国内外优秀的项目跟踪管理软件有:1、软件项目跟踪管理PingCode;2、通用项目跟踪管理Worktile;3、小型团队项目跟踪管理Asana;4、基于桌面的项目跟踪软件Microsoft Project;5、适用所有类型项目的跟踪软件Clickup&#…

[ vulhub漏洞复现篇 ] GhostScript 沙箱绕过(任意命令执行)漏洞CVE-2018-16509

🍬 博主介绍 👨‍🎓 博主介绍:大家好,我是 _PowerShell ,很高兴认识大家~ ✨主攻领域:【渗透领域】【数据通信】 【通讯安全】 【web安全】【面试分析】 🎉点赞➕评论➕收藏 养成习…

Cellobiose-PEG-DBCO 纤维二糖-聚乙二醇-二苯基环辛炔,DBCO-PEG-纤维二糖

Cellobiose-PEG-DBCO 纤维二糖-聚乙二醇-二苯基环辛炔,DBCO-PEG-纤维二糖 中文名称:纤维二糖-二苯基环辛炔 英文名称:Cellobiose-DBCO 别称:二苯基环辛炔修饰纤维二糖,二苯基环辛炔-纤维二糖 PEG分子量可选&…

2023年湖北安全员ABC报名时间和考试时间是什么时候?甘建二

2023年湖北安全员ABC报名时间和考试时间是什么时候? 安全员ABC考试和报名时间,12月份安全员ABC考试时间是12月底,12月份湖北安全员ABC报名是现在开始报名了,目前报名入口已经开通需要开始报名了。 2023年湖北安全员ABC报名时间&am…

人工智能历史上的重要一步:ChatGPT影响到谷歌地位?

AI神器ChatGPT 火了。 能直接生成代码、会自动修复bug、在线问诊、模仿莎士比亚风格写作……各种话题都能hold住,它就是OpenAI刚刚推出的——ChatGPT。 有脑洞大开的网友甚至用它来设计游戏:先用ChatGPT生成游戏设定,再用Midjourney出图&…

如何让 useEffect 支持 async/await?

大家在使用 useEffect 的时候,假如回调函数中使用 async...await... 的时候,会报错如下。 看报错,我们知道 effect function 应该返回一个销毁函数(return返回的 cleanup 函数),如果 useEffect 第一个参数传…

[毕业设计]C++程序类内聚度的计算与存储

目录 前言 课题背景和意义 实现技术思路 实现效果图样例 前言 📅大四是整个大学期间最忙碌的时光,一边要忙着备考或实习为毕业后面临的就业升学做准备,一边要为毕业设计耗费大量精力。近几年各个学校要求的毕设项目越来越难,有不少课题是研究生级别难度的,对本科…

651页23万字智慧教育大数据信息化顶层设计及智慧应用建设方案

目录 一、 方案背景 1.1 以教育现代化支撑国家现代化 1.2 教育信息化是教育现代化重要内容和标志 1.3 大数据驱动教育信息化发展 1.4 政策指导大数据推动教育变革 1.5 教育大数据应用生态服务教育现代化 二、 建设需求 2.1 地区教育系统亟待进行信息共享、系统融合 2.2…

L2正则线性回归(岭回归)

岭回归 数据的特征比样本点还多,非满秩矩阵在求逆时会出现问题 岭回归即我们所说的L2正则线性回归,在一般的线性回归最小化均方误差的基础上增加了一个参数w的L2范数的罚项,从而最小化罚项残差平方和 简单说来,岭回归就是在普通…

FreeRTOS基础知识

目录 1.任务调度器简介 1.1抢占式调度举例 1.2时间片调度举例 2.任务状态 3.总结 1.任务调度器简介 调度器就是使用相关的调度算法来决定当前需要执行哪个任务。 FreeRTOS一共支持以下三种任务调度方式: FreeRTOS调度方式抢占式调度主要是针对优先级不同的任务…

vector深度剖析及模拟实现

vector模拟实现🏞️1. vector的扩容机制🌁2. vector迭代器失效问题📖2.1 insert导致的失效📖2.2 erase导致的失效🌿3. vector拷贝问题🏜️4. 模拟实现vector🏞️1. vector的扩容机制 #include&…

SQL快速入门、查询(SqlServer)[郝斌SqlServer完整版]

文章目录SQL学前导图一 、基本信息1 相关名词数据库相关基本概念:字段、属性、记录(元祖)、表、主键、外键2 基本语句3 约束:主键约束、外键约束、check约束、default约束、唯一约束二、查询1 计算列2 distinct(去重)3 between4 i…

生产跟踪是生产控制的基础,其主要功能有哪些?

生产跟踪是生产控制的基础,只有对生产的过程全面了解,才能掌握和控制生产的执行情况,所以生产跟踪模块在制造执行系统中一种起着举足轻重的作用。生产跟踪,不单单是对生产过程进行监控和记录数据,还需要将各个生产环节…

2023最新SSM计算机毕业设计选题大全(附源码+LW)之java校园新闻发布管理系统574ec

面对老师五花八门的设计要求,首先自己要明确好自己的题目方向,并且与老师多多沟通,用什么编程语言,使用到什么数据库,确定好了,在开始着手毕业设计。 1:选择课题的第一选择就是尽量选择指导老师…

ubuntu18.04上点云PCL 库使用初探

PCL 库使用资料 在 ubuntu18.04 上使用pcl记录 一、 安装 首先需要在 ubuntu 上安装c 库 sudo apt install libpcl-dev dpkg -S pcl 查看包文件安装的位置,包括头文件和库文件,进到库文件路径下看,目前安装的是 pcl 1.8.1 /usr/include/pc…

最全Java知识点总结归纳

一、流 Java所有的流类位于http://java.io包中,都分别继承字以下四种抽象流类型。 继承自InputStream/OutputStream的流都是用于向程序中输入/输出数据,且数据的单位都是字节(byte8bit)。 继承自Reader/Writer的流都是用于向程序中输入/输出数据&#x…

黄佳《零基础学机器学习》chap3笔记

黄佳 《零基础学机器学习》 chap3笔记 第3课 线性回归——预测网店的销售额 文章目录黄佳 《零基础学机器学习》 chap3笔记第3课 线性回归——预测网店的销售额3.1 问题定义:小冰的网店广告该如何投放3.2 数据的收集和预处理3.2.1 收集网店销售额数据3.2.2 数据读取…

功能测试(五)—— web项目抓包操作与测试报告

目录 目标 一、网络相关知识介绍 1.1 请求 1.2 响应 二、抓包工具的应用 2.1 过滤 2.2 删除数据 2.3 查看数据包内容 2.4 定位Bug 2.5 弱网测试 2.6 设置断点(请求之前) 2.7 设置断点(响应之后) 三、测试报告 目标 …