机器学习模型搭建与评估

news2025/1/16 8:02:17

模型搭建和评估

    • 第三章 模型搭建和评估--建模
      • 模型搭建
        • 任务一:切割训练集和测试集
        • 任务二:模型创建
        • 任务三:输出模型预测结果
    • 第三章 模型搭建和评估-评估
      • 模型评估
        • 任务一:交叉验证
        • 任务二:混淆矩阵
        • 任务三:ROC曲线


Datawhale社区 动手学数据分析 第三章 模型搭建和评估 学习记录

  • 模型搭建
    • 数据划分
      • train_test_split()
    • 模型选择和拟合数据
      • 模型实例化、 fit()
    • 预测
      • predict()
      • predict_proba()
  • 模型评估
    • 交叉验证
      • cross_val_score()
      • cross_val_predict()
    • 混淆矩阵
      • cofusion_matrix()
      • classification_report() , precision_score() , recall_score() , f1_score()
      • precision_recall_curve()
    • ROC曲线
      • roc_curve()
      • roc_auc_score()

第三章 模型搭建和评估–建模

我们拥有的泰坦尼克号的数据集,那么我们这次的目的就是,完成泰坦尼克号存活预测这个任务。

import pandas as pd
import numpy as np
%matplotlib inline

载入我们提供清洗之后的数据(clear_data.csv),大家也将原始数据载入(train.csv),说说他们有什么不同

#写入代码
orgin_data = pd.read_csv('train.csv')
orgin_data.head(3)
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
1211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
#写入代码
clean_data = pd.read_csv('clear_data.csv')
clean_data.head(3)
PassengerIdPclassAgeSibSpParchFareSex_femaleSex_maleEmbarked_CEmbarked_QEmbarked_S
00322.0107.250001001
11138.01071.283310100
22326.0007.925010001

#不同

  1. 删除了 Survived 特征(将标签独立出来)
  2. 将文本特征 Sex, Embarked 转换为了 One-hot 编码
  3. 删除了原数据中的 Name, Ticket, Cabin 特征

模型搭建

  • 处理完前面的数据我们就得到建模数据,下一步是选择合适模型
  • 在进行模型选择之前我们需要先知道数据集最终是进行监督学习还是无监督学习
  • 模型的选择一方面是通过我们的任务来决定的。
  • 除了根据我们任务来选择模型外,还可以根据数据样本量以及特征的稀疏性来决定
  • 刚开始我们总是先尝试使用一个基本的模型来作为其baseline,进而再训练其他模型做对比,最终选择泛化能力或性能比较好的模型

【思考】数据集哪些差异会导致模型在拟合数据是发生变化

#思考回答
样本数量、特征数量、特征重要性、样本分布差距、噪声等

任务一:切割训练集和测试集

这里使用留出法划分数据集

  • 将数据集分为自变量和因变量
  • 按比例切割训练集和测试集(一般测试集的比例有30%、25%、20%、15%和10%)
  • 使用分层抽样
  • 设置随机种子以便结果能复现

【思考】

  • 划分数据集的方法有哪些?
  • 为什么使用分层抽样,这样的好处有什么?

#思考

  1. 划分数据集的方法
  • 留出法
    • 随机划分
    • 分层抽样
  • K 折交叉验证
  • 自助法
  1. 分层抽样可以使划分后数据保存原始数据的分布

任务提示1

  • 切割数据集是为了后续能评估模型泛化能力
  • sklearn中切割数据集的方法为train_test_split
  • 查看函数文档可以在jupyter noteboo里面使用train_test_split?后回车即可看到
  • 分层和随机种子在参数里寻找

要从clear_data.csv和train.csv中提取train_test_split()所需的参数

#写入代码
X = clean_data
y = orgin_data.Survived
#写入代码
from sklearn.model_selection import train_test_split
#写入代码
?train_test_split
#写入代码
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((712, 11), (179, 11), (712,), (179,))

【思考】

  • 什么情况下切割数据集的时候不用进行随机选取
#思考回答

1. 数据集非常大,随机划分即可能浪费大量时间,也可能降低准确率
2. 时序数据,采用分段切割,避免数据泄露
3. 数据类别不平衡,会采用重采样

任务二:模型创建

  • 创建基于线性模型的分类模型(逻辑回归)
  • 创建基于树的分类模型(决策树、随机森林)
  • 分别使用这些模型进行训练,分别的到训练集和测试集的得分
  • 查看模型的参数,并更改参数值,观察模型变化

提示

  • 逻辑回归不是回归模型而是分类模型,不要与LinearRegression混淆
  • 随机森林其实是决策树集成为了降低决策树过拟合的情况
  • 线性模型所在的模块为sklearn.linear_model
  • 树模型所在的模块为sklearn.ensemble
#写入代码
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
#写入代码
lr = LogisticRegression(max_iter=3000)
lr.fit(X_train, y_train)
lr.get_params()
{'C': 1.0,
 'class_weight': None,
 'dual': False,
 'fit_intercept': True,
 'intercept_scaling': 1,
 'l1_ratio': None,
 'max_iter': 3000,
 'multi_class': 'auto',
 'n_jobs': None,
 'penalty': 'l2',
 'random_state': None,
 'solver': 'lbfgs',
 'tol': 0.0001,
 'verbose': 0,
 'warm_start': False}
#写入代码 查看训练集和测试集的得分
print("Training set score: {:.2f} ,Testing set score: {:.2f}".format(lr.score(X_train, y_train), lr.score(X_test, y_test)))
Training set score: 0.81 ,Testing set score: 0.80
#写入代码
rfc = RandomForestClassifier()
rfc.fit(X_train, y_train)
rfc.get_params()
{'bootstrap': True,
 'ccp_alpha': 0.0,
 'class_weight': None,
 'criterion': 'gini',
 'max_depth': None,
 'max_features': 'auto',
 'max_leaf_nodes': None,
 'max_samples': None,
 'min_impurity_decrease': 0.0,
 'min_impurity_split': None,
 'min_samples_leaf': 1,
 'min_samples_split': 2,
 'min_weight_fraction_leaf': 0.0,
 'n_estimators': 100,
 'n_jobs': None,
 'oob_score': False,
 'random_state': None,
 'verbose': 0,
 'warm_start': False}
#写入代码 查看训练集和测试集的得分
print("Training set score: {:.2f} ,Testing set score: {:.2f}".format(rfc.score(X_train, y_train), rfc.score(X_test, y_test)))
Training set score: 1.00 ,Testing set score: 0.82

【思考】

  • 为什么线性模型可以进行分类任务,背后是怎么的数学关系
  • 对于多分类问题,线性模型是怎么进行分类的

#思考回答

  • 利用输入的特征计算一个加权和 s = w T x s=w^Tx s=wTx,再通过激活函数映射到0和1之间进行分类
  • 可以转换为多个二分类问题

任务三:输出模型预测结果

  • 输出模型预测分类标签
  • 输出不同分类标签的预测概率

提示3

  • 一般监督模型在sklearn里面有个predict能输出预测标签,predict_proba则可以输出标签概率
#写入代码
pred = lr.predict(X_test)
pred[:5]
array([0, 0, 0, 0, 1], dtype=int64)
#写入代码
pred_proba = lr.predict_proba(X_test)
pred_proba[:5]
array([[0.92961653, 0.07038347],
       [0.95523041, 0.04476959],
       [0.8395754 , 0.1604246 ],
       [0.96137521, 0.03862479],
       [0.34067921, 0.65932079]])

【思考】

  • 预测标签的概率对我们有什么帮助
#思考回答  
获取模型对预测结果的确信程度

第三章 模型搭建和评估-评估

根据之前的模型的建模,我们知道如何运用sklearn这个库来完成建模,以及我们知道了的数据集的划分等等操作。那么一个模型我们怎么知道它好不好用呢?以至于我们能不能放心的使用模型给我的结果呢?那么今天的学习的评估,就会很有帮助。

加载下面的库

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import Image
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
%matplotlib inline
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.figsize'] = (10, 6)  # 设置输出图片大小

任务:加载数据并分割测试集和训练集

#写入代码
data = pd.read_csv('clear_data.csv')
train = pd.read_csv('train.csv')
X = data
y = train['Survived']
from sklearn.model_selection import train_test_split
#写入代码
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
#写入代码
lr = LogisticRegression(max_iter=1000)
lr.fit(X_train, y_train)
LogisticRegression(max_iter=1000)

模型评估

模型评估是为了知道模型的泛化能力。

  • K 折交叉验证(k-fold cross-validation)
  • 混淆矩阵(confusion_matrix)
    • 准确率(precision)
    • 召回率(recall)
    • F1 分数
  • ROC曲线

任务一:交叉验证

  • 用10折交叉验证来评估之前的逻辑回归模型
  • 计算交叉验证精度的平均值
#提示:交叉验证
Image('Snipaste_2020-01-05_16-37-56.png')

在这里插入图片描述

提示4

  • 交叉验证在sklearn中的模块为sklearn.model_selection
#载入模块
from sklearn.model_selection import cross_val_score
#K折交叉验证
lr = LogisticRegression(C=100, max_iter=1000)
scores = cross_val_score(lr, X_train, y_train, cv=10)
scores
array([0.8358209 , 0.7761194 , 0.82089552, 0.80597015, 0.85074627,
       0.86567164, 0.73134328, 0.85074627, 0.75757576, 0.6969697 ])
#获取K折平均值
print("Average cross-validation score: {:.2f}".format(scores.mean()))
Average cross-validation score: 0.80

思考4

  • k折越多的情况下会带来什么样的影响?
#思考回答  
增加了计算开销

任务二:混淆矩阵

  • 计算二分类问题的混淆矩阵
  • 计算精确率、召回率以及f-分数

【思考】什么是二分类问题的混淆矩阵,理解这个概念,知道它主要是运算到什么任务中的

#思考回答  
混淆矩阵用于评估分类器性能,其总体思路就是统计A类别实例被分成为B类别的次数  
混淆矩阵中的行表示实际类别,列表示预测类别  
#提示:混淆矩阵
Image('Snipaste_2020-01-05_16-38-26.png')


在这里插入图片描述

#提示:准确率 (Accuracy),精确度(Precision),Recall,f-分数计算方法
Image('Snipaste_2020-01-05_16-39-27.png')


在这里插入图片描述

提示5

  • 混淆矩阵的方法在sklearn中的sklearn.metrics模块
  • 混淆矩阵需要输入真实标签和预测标签
  • 精确率、召回率以及f-分数可使用classification_report模块
#写入代码
from sklearn.metrics import confusion_matrix
#写入代码
lr = LogisticRegression(C=100, max_iter=500)
lr.fit(X_train, y_train)
LogisticRegression(C=100, max_iter=500)
#写入代码
pred = lr.predict(X_train)
#写入代码
confusion_matrix(y_train, pred)
array([[350,  62],
       [ 71, 185]], dtype=int64)
from sklearn.metrics import classification_report
print(classification_report(y_train, pred))
              precision    recall  f1-score   support

           0       0.83      0.85      0.84       412
           1       0.75      0.72      0.74       256

    accuracy                           0.80       668
   macro avg       0.79      0.79      0.79       668
weighted avg       0.80      0.80      0.80       668

【思考】

  • 如果自己实现混淆矩阵的时候该注意什么问题
#思考回答  
混淆矩阵中每个值代表的含义和位置,即 TN,FP,FN,TP。  

绘制 PR 曲线

from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_test, lr.decision_function(X_test))
plt.plot(precisions, recalls)
plt.xlabel("precision")
plt.ylabel("recall")
plt.grid();

在这里插入图片描述

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
    plt.legend(loc=0)
    plt.xlabel("thresholds")
    plt.grid()

plot_precision_recall_vs_threshold(precisions, recalls, thresholds);

在这里插入图片描述

任务三:ROC曲线

  • 绘制ROC曲线

【思考】什么是ROC曲线,ROC曲线的存在是为了解决什么问题?

#《机器学习实战》  
**受试者工作特征曲线(简称ROC)**
绘制是真正类率(召回率,灵敏度)和假正类率(FPR)关系。  
FPR是被错误分为正类的负类实例比率。它等于1减去真负类率(TNR),后者是被正确分类为负类的负类实例比率,也称为特异度。  


**ROC曲线和PR曲线的选取:**  
一个经验法则是,当正类非常少见或者你更关注假正类而不是假负类时,应该选择PR曲线,反之则是ROC曲线。

提示6

  • ROC曲线在sklearn中的模块为sklearn.metrics
  • ROC曲线下面所包围的面积越大越好
#写入代码
from sklearn.metrics import roc_curve
#写入代码
fpr, tpr, thresholds = roc_curve(y_test, lr.decision_function(X_test))
plt.plot(fpr, tpr, label="ROC Curve")
plt.xlabel("FPR")
plt.ylabel("TPR (recall)")
plt.grid();


在这里插入图片描述

思考6

  • 对于多分类问题如何绘制ROC曲线
#思考回答
绘制每个类别的ROC曲线

【思考】你能从这条ROC曲线的到什么信息?这些信息可以做什么?

#思考回答  
ROC曲线下面积,用来指导模型选择。

【注】
仅个人学习记录,详细内容见Datawhale社区开源课程 动手学数据分析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/178471.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python爬虫学习笔记-mysql数据库介绍下载安装

数据库概述 为什么要使用数据库? 那我们在没有学习数据库的时候,数据存放在json或者磁盘文件中不也挺好的嘛,为啥还要学习数据库? 文件中存储数据,无法基于文件直接对数据进行操作或者运算,必须借助python将…

IDEA搭建Finchley.SR2版本的SpringCloud父子基础项目-------Ribbon负载均衡

1.概念 Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具。简单的说,Ribbon是Netflix发布的开源项目,主要功能是提供客户端的软件负载均衡算法,将Netflix的中间层服务连接在一起。Ribbon客户端组件提供一系列完善的配…

Python闭包与闭包陷阱

1 什么是闭包 在 Python 中,闭包是一种特殊的函数,它能够记住它所在的环境(也称作上下文)。这意味着闭包能够访问定义它的作用域中的变量。闭包通常用于封装数据和提供对外部访问的接口。 在 Python 中使用闭包有以下几点好处&a…

数据库和SQL概述

数据库和SQL概述 数据库的好处 实现数据的持久化使用完整的管理系统统一管理,易于查询 常用的一些名称缩写 DB:数据库(Database):存储数据的“仓库”。它保存了一系列有组织的数据DBMS:数据库管理系统(Database Management Sy…

离线用户召回定时更新

3.6 离线用户召回定时更新 学习目标 目标 知道离线内容召回的概念知道如何进行内容召回计算存储规则应用 应用spark完成离线用户基于内容的协同过滤推荐 3.6.1 定时更新代码 完整代码 import os import sys # 如果当前代码文件运行测试需要加入修改路径,否则后面…

游戏启动器:LaunchBox Premium with Big Box v13.1

LaunchBox知道您会喜欢的功能,具有风格的游戏启动器,我们最初将 Launchbox 构建为 DOSBox 的一个有吸引力的前端,但它现在拥有对现代游戏和复古游戏模拟的支持。我们让您的所有游戏看起来都很漂亮。 整理您的游戏收藏 我们不仅漂亮&#xff…

基于微信小程序奶茶店在线点餐下单系统

奶茶在线下单系统用户端是基于微信小程序端,管理员端是基于web端,基于java编程语言,mysql数据库,idea工具开发,用户微信端可以注册登陆小程序,查看奶茶详情,搜索下单奶茶,在线奶茶评…

CSS @property(CSS 自定义属性)

CSS property(CSS 自定义属性)参考描述propertyHoudiniproperty兼容性描述符规则syntax扩展initial-valueinherits示例描述符的注意事项使用 JavaScript 来创建自定义属性CSS 变量与自定义属性重复赋值过渡简单的背景过渡动画更复杂的背景过渡动画错误示…

【ARM体系结构】之数据类型约定与工作模式

1、RISC和CISC的区别 1.1 RISC : 精简指令集 使用精简指令集的架构:ARM架构 RISC-V架构 PowerPC架构 MIPS架构ARM架构 :目前使用最广泛的架构,ARM面向的低端消费类市场RISC-V架构 :第五代,精简指令集的架构&#xff…

这样定义通用人工智能

🍿*★,*:.☆欢迎您/$:*.★* 🍿 正文 人类解决问题的途径,大体可以分为两种。一种是事实推理,另一种是事实验证。 为什么只是两种分类,因为根据和环境的交互与否。 事实推理解释为当遇到事件发生的时候,思考的过程。可以使用概率模型,或者更复杂的模型(目前没…

Out of Vocabulary处理方法

Out of Vocabulary 我们在NLP任务中一般都会有一个词表,这个词表一般可以使用一些大牛论文中的词表或者一些大公司的词表,或者是从自己的数据集中提取的词。但是无论当后续的训练还是预测,总有可能会出现并不包含在词表中的词,这…

(教程)如何在BERT模型中添加自己的词汇(pytorch版)

来源:投稿 作者:皮皮雷 编辑:学姐 参考文章: NLP | How to add a domain-specific vocabulary (new tokens) to a subword tokenizer already trained like BERT WordPiece | by Pierre Guillou | Medium https://medium.com/pi…

ROS2机器人编程简述humble-第三章-BUMP AND GO IN C++ .3

简述本章项目,参考如下:ROS2机器人编程简述humble-第三章-PERCEPTION AND ACTUATION MODELS .1流程图绘制,参考如下:ROS2机器人编程简述humble-第三章-COMPUTATION GRAPH .2然后,在3.3和3.4分别用C和Python编程实现&am…

Bus Hound 工具抓取串口数据(PC端抓取USB转串口数据)

测试环境: PC端 USB转串口 链接终端板卡串口 目标:抓取通信过程中的通信数据 工具介绍:Bus Hound是是由美国perisoft公司研制的一款超级软件总线协议分析器,它是一种专用于PC机各种总线数据包监视和控制的开发工具软件&#xff0c…

通信原理简明教程 | 数字调制传输

文章目录1 二进制数字调制和解调1.1 二进制数字调制的基本原理1.2 二进制数字调制信号的特性1.3 解调方法2 二进制差分相移键控2.1 2PSK的倒π现象2.2 2DPSK调制和解调3 二进制调制系统的抗噪性能3.1 2ASK系统的抗噪声性能3.2 2FSK系统的抗噪声性能4 二进制数字调制系统性能比较…

服务器配置定时脚本 crontab + Python;centos6或centos 7或centos8 实现定时执行 python 脚本

一、crontab的安装 默认情况下,CentOS 7中已经安装有crontab,如果没有安装,可以通过yum进行安装。 yum install crontabs 二、crontab的定时语法说明 corntab中,一行代码就是一个定时任务,其语法结构可以通过这个图来理解。 字符含义如下: * 代表取值范围内的数字 /…

Linux内核驱动初探(三) 以太网卡

目录 0. 前言 1. menuconfig 2. 设备树 0. 前言 这次的网卡驱动就比较顺利,基本就是参考 4.19.x 内核以及 imx6qdl-sabrelite.dtsi、imx6qdl-sabreauto.dtsi 中的设备树,来设置以太网各项参数。 1. menuconfig 其实笔者接手的时候,网口这…

本质安全设备标准(IEC60079-11)的理解(三)

本质安全设备标准(IEC60079-11)的理解(三) 对于标准中“fault”的理解 第一,标准中对fault的定义是这样的: 3.7.2 fault any defect of any component, separation, insulation or connection between c…

C++空间命名

前言 提示:由于C是在C语言基础之上,增加了很多新的东西。 本文讲解命名空间的具体使用方法 文章目录 目录 前言 一、命名空间 二、命名空间定义 1.嵌套性 2.和并性 总结 提示:以下是本篇文章正文内容,下面案例可供参考 一…

【华为上机真题】区间交集

🎈 作者:Linux猿 🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊! &…