【机器学习实战】使用SGD-随机梯度下降、随机森林对MNIST数据进行二分类(Jupyterbook)

news2025/1/15 20:38:04

1. 数据集

  1. 由美国高中生和人口调查局员工手写的70000个数字的图片。
  2. 数据集获取
# 获取MNIST数据集
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
mnist
  1. 查看X和Y
    在这里插入图片描述
  2. 找索引为36000的实例,并将其还原成数字(书中是还原成了5,但是我这么获取数据集还原的数字是9,不知道什么原因)
# 随机抽取一个实例,将其还原成数字
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

some_digit = X[36000]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary, interpolation='nearest')
plt.axis('off')
plt.show()

在这里插入图片描述

2. 划分数据集

  • 因为有些机器学习算法对训练实例的顺序敏感,如果连续输入许多相似的实例,可能导致执行性能不佳,所以对训练集进行洗牌。
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
# 对数据进行洗牌,防止输入许多相似实例导致的执行性能不佳
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

在这里插入图片描述

3. 重新创建目标向量(以是5和非5作为二分类标准)

y_train_5 = (y_train == '5')
y_test_5 = (y_test == '5')

在这里插入图片描述

4. 训练SGD分类器

from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

在这里插入图片描述

4.1 使用准确率进行评估

from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

skfolds = StratifiedKFold(n_splits=3, random_state=42, shuffle=True) # 三折

for train_index, test_index in skfolds.split(X_train, y_train_5):
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_5[train_index]
    X_test_folds = X_train[test_index]
    y_test_folds = y_train_5[test_index]
    
    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_folds)
    n_correct = sum(y_pred == y_test_folds)
    print('预测正确的数量比例:', n_correct / len(y_pred))

在这里插入图片描述

# 使用 cross_val_score()函数来评估模型
from sklearn.model_selection import cross_val_score

cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring='accuracy')

在这里插入图片描述

4.2 使用混淆矩阵进行评估(精度、召回率和F1分数)

from sklearn.model_selection import cross_val_predict

y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_pred)
# 精度和召回率
from sklearn.metrics import precision_score, recall_score

precision_score(y_train_5, y_train_pred)
# F1-score
from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred)

在这里插入图片描述

4.3 绘制precision、recall相对于threshold的函数图

# 1. 获取训练集中所有实例的决策分数
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method='decision_function')
y_scores
# 2. 计算所有可能的阈值的precision、recall
from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
# 3. 绘制precision、recall相对于threshold的函数图
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[: -1], 'b--', label='Precision')
    plt.plot(thresholds, recalls[:-1], 'g--', label='Recall')
    plt.xlabel('Threshold')
    plt.legend(loc='upper left') # 图例位置
    plt.ylim([0, 1]) # 设置刻度

plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()

在这里插入图片描述

4.4 精度-召回率(PR曲线)

# 4.精度-召回率曲线
plt.plot(recalls[:-1], precisions[: -1], 'r--')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.show()

在这里插入图片描述

4.5 ROC曲线、AUC值

4.5.1 ROC曲线
# 1. 计算多种阈值的TPR和FPR
from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)

# 2. 绘制FPR对TPR的曲线
def plot_roc_curve(fpr, tpr, label=None):
    plt.plot(fpr, tpr, lineWidth=2, label=label)
    plt.plot([0, 1], [0, 1], 'k--')  # 副对角线
    plt.axis([0, 1, 0, 1]) # 设置横轴、竖轴的刻度
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')

plot_roc_curve(fpr, tpr)
plt.show()

在这里插入图片描述

4.5.2 AUC值
# 完美分类器的ROC AUC为1,纯随机分类器的ROC AUC等于0.5
from sklearn.metrics import roc_auc_score
roc_auc_score(y_train_5, y_scores)

在这里插入图片描述

5. 使用随机森林训练、评估

from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method='predict_proba')
y_probas_forest # 某个实例属于某个类别的概率
# 使用正类的概率作为分数值(第二列)
y_scores_forest = y_probas_forest[:, 1]
y_scores_forest
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
# 绘制ROC曲线
plt.plot(fpr, tpr, "b:", label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.legend(loc="lower right") # 右下角
plt.show()

在这里插入图片描述
其他评估指标(AUC值、精度、召回率)如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/70236.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AXI4-Lite总线读写BRAM

博主参考和学习的博客 AXI协议基础知识 。这篇博客比较详细地介绍了AXI总线,并且罗列了所有的通道和端口,写代码的时候可以方便地进行查表。AXI总线,AXI_BRAM读写仿真测试 。 这篇文章为代码的书写提供大致的思路,比如状态机和时…

GDB调试

文章目录1.什么是GDB2. 准备工作3.GDB命令-启动、退出、查看代码4.设置断点5.GDB命令-调试命令1.什么是GDB 2. 准备工作 通常,在为调试而编译时,我们会关掉编译器的优化选项"-o",并打开调试选选项“-g”,另外,“-wall”…

【第一章 Linux目录结构,网络连接模式,vi和vim,Linux关机重启命令,Linux用户管理】

第一章 Linux目录结构,网络连接模式,vi和vim,Linux关机&重启命令,Linux用户管理 1.Linux和Unix: ①Unix针对于大型,高性能主机或服务器; ②Linux适用于个人计算机。 2.网络连接的三种模式…

图解pytorch里面的torch.gather()

在 Dim1 的情况下应用 torch.gather() 上图显示了 torch gather() 函数在 dim1 的二维张量上的工作。 这里索引张量的行对应于输入张量的行(用灰色阴影突出显示)。现在对于索引张量中的每个索引值,从该行和输入张量的索引中选取相应的值。 让…

LEADTOOLS 22-23 .Net/NetCore/JS/JAVA/Win/Linux

破解版功能齐全:LEADTOOLS 是一系列综合工具包,旨在帮助程序员将光栅、文档、医学、多媒体和矢量图像集成到他们的桌面、服务器、平板电脑和移动应用程序中。LEADTOOLS 为开发人员提供最灵活、最强大的成像技术,为 OCR、条形码、表单识别、PD…

推荐大家一些CTF的网站和工具

一.网站 1.攻防世界 网址:攻防世界 这是一个有好多题目的网站 主要有Misc、Pwn、Web、Reverse、Crypto、Mobile几种题型 不会的问题还可以查题解 好用度 9星 2.BUUCTF 网址:BUUCTF在线评测 也有很多ctf的题目 逆向、网络等等...... 比攻防世界…

最近火爆了的对话ChatGPT

前言 相信最近小伙伴们已经被ChatGPT的惊艳效果刷屏了,之前笔者也介绍过一些对话方向的工作,感兴趣的小伙伴可以穿梭: 对话系统最新综述II https://zhuanlan.zhihu.com/p/446760658 在对话系统中建模意图、情感: https://zhuanlan.zhihu.com/…

Nacos是什么?

摘要:Nacos是 Dynamic Naming and Configuration Service的首字母简称,相较之下,它更易于构建云原生应用的动态服务发现、配置管理和服务管理平台。本文分享自华为云社区《Nacos入门指南 - Nacos是什么》,作者:华为云P…

.gitlab-ci.yml文件常用规则说明

我自己整理了一份yml文件,里面包含了分支触发,和tag触发,还有缓存等: stages:- install- build- deploycache:key: nodeModulespaths:- node_modules- distjob_install:stage: installtags:- cvtagsonly:refs:- devscript:- npm …

基于LLVM的Fortran编译器分析

简介 本文内容基于LLVM 13.0.0。 目前基于LLVM的Fortran编译器(或者驱动)有3种,分别是flang、f18和flang-new。 flang是pgfortran的开源版本,基于PGI/NVIDIA的商业Fortran 编译器,它并不从属于LLVM项目。NVIDIA团队…

LabVIEW编程LabVIEW开发 ADAM 4015热电阻输入模块例程与相关资料

LabVIEW编程LabVIEW开发 ADAM 4015热电阻输入模块例程与相关资料 ​研华公司的ADAM 4015是6通道热电阻输入模块,可以采集2线或3线热电阻输入信号,ADAM4015T课题采集热敏电阻的输入信号。模块在工业测量和监控的有着广泛的应用,它既可以支持A…

Web3中文|苹果想对以太坊征税

虽然Web3是非常新的技术,但是似乎已经遇到了非常多“劲敌”。 这些“敌人”正在阻碍web3应用程序和区块链游戏的发展,因为在web3里,应用程序和游戏将允许用户自主相互交易数字资产所有权。 所以,那些大公司,如任天堂…

最近全网爆火的黑科技,叫做chatGPT

AI神器ChatGPT 火了。 能直接生成代码、会自动修复bug、在线问诊、模仿莎士比亚风格写作……各种话题都能hold住,它就是OpenAI刚刚推出的——ChatGPT。 有脑洞大开的网友甚至用它来设计游戏:先用ChatGPT生成游戏设定,再用Midjourney出图&…

vue 数据手写分页,定时展示

我们在业务之中,其实会常常用到一些数据的分段展示 , 比如数据量过大导致echarts无法展示,我们就可以将数据进行算法分页 , 然后套用定时器实时更新分段数据; 例子展示 : 将下列数组截取成每页5条数据的分…

观察者模式(python)

一、模式定义 1.观察者模式(Observer Pattern):定义对象间的一种一对多依赖关系,使得每当一个对象状态发生改变时,其相关依赖对象皆得到通知并被自动更新。 2.观察者模式又叫做发布-订阅(Publish/Subscribe)模式、模…

SpringBoot微服务的发布与部署(3种方式)

基于 SpringBoot 的微服务开发完成之后,现在到了把它们发布并部署到相应的环境去运行的时候了。 SpringBoot 框架只提供了一套基于可执行 jar 包(executable jar)格式的标准发布形式,但并没有对部署做过多的界定,而且…

2022年Python面试题汇总【面试官爱问】

2022年Python面试题汇总【常问】1、请你讲讲python获取输入的方式,以及python如何打开文件2、Python数据处理的常用函数3、请你说说python传参传引用4、请你说说python和java的区别5、Python你常用的包有哪些?6、简单说明如何选择正确的Python版本。7、简…

Qt动态库

QT带界面的动态库 创建动态库 一、新建一个C的动态库的项目 选择C的动态库的项目,进行下一步 修改项目的名字和项目的保存的路径。 选着编译的方式,不需要改,进行下一步。 选着动态库,编译成动态库,进行下一步。 项目…

[附源码]JAVA毕业设计社区生活超市管理系统(系统+LW)

[附源码]JAVA毕业设计社区生活超市管理系统(系统LW) 项目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目…

[附源码]Python计算机毕业设计SSM计算机学院课程设计管理系统(程序+LW)

[附源码]Python计算机毕业设计SSM计算机学院课程设计管理系统(程序LW) 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。…