【Python】基于Python的机器学习分类的模型选择:交叉验证和模型质量评估

news2024/11/23 13:25:47

目录

  • 1 简介
  • 2 思路分解与说明
  • 3 完整代码

1 简介

最近完成一个工作,就基于一些表格化的数据进行机器学习分类。
由于分类是研究中的关键步骤,所以首先要选择到底哪个模型适合我们的分类任务。
比较传统且经典的选择方法就是用交叉验证。
交叉验证是什么可以看以下这幅图,来自这篇论文,这里不过多赘述。
在这里插入图片描述
那么,具体实验中,有了数据之后,要怎么用自动化高效地方法完成交叉验证并选择合适的模型呢?这篇博文就是为了解决这个问题。

本文要实现的东西很简单,在于:
对多个模型进行k折交叉验证,并且对输出每一fold和每个模型的总体评价指标。

因此,本文的思路是:
1,读取数据;
2,拆分数据,并且对每一个fold进行质量指标计算;
3,汇总结果,对一个模型的总体结果的质量指标进行计算。

2 思路分解与说明

我们首先需要从本地文件中读取用于建模的样本数据,一般是用excel存取的,示例如下。
在这里插入图片描述

共有16列数据。其中前15列是特征(自变量),最后一列是目标(因变量)。
因为这个是分类任务,所以我的目标(因变量)只有两个值,即0和1。
首先,需要写个函数读取excel数据,如下。

def read_data(data_path, data_sheet):
    """
    读取excel表格中的数据
    输入excel文件路径,输出dataframe格式的数据
        data_path: excel文件(xlsx或xls)的路径
        data_sheet: excel文件的sheet名称
    """
    from pandas import read_excel
    file = data_path # 读取数据路径
    data = read_excel(open(file, "rb"), sheet_name=data_sheet) # 读取数据
    return data

我们用这个函数读取了样本,并且用这两句代码

data = read_data("data.xlsx", "data") # 读取数据
print(data) # 样本数据

进行读取和打印,结果如下。
在这里插入图片描述
接下来就是实现交叉验证及其模型质量参数的计算。
这里用比较常用的一些参数,比如kappa系数,总体精度(OA),精确度(precision),召回率(recall),F1-score这些,基本覆盖了分类任务中常用的所有评价指标。
好在关于这个分折(k-fold)和这些指标的计算都有比较方便的接口,接下来的任务就是整合以上需求,以分折交叉验证和这些指标的计算为导向写一个函数。函数如下。

def KFold_Classificaton(data, KNumber, classification_model, shuffle=False, random_state=None):
    """
    K折交叉验证
    输入数据和K折交叉验证所需参数,打印各次模型的精度指标
        data: dataframe格式数据
        KNumber: 折数
        classification_model: 实例化后的分类模型
        shuffle: 是否打乱样本,默认不打乱
        random_state: 随机种子
        target_names: 目标类名
        输出所有折验证的结果
    """
    from sklearn.model_selection import KFold
    from sklearn.metrics import classification_report
    from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, cohen_kappa_score
    kf_model = KFold(n_splits=KNumber, shuffle=shuffle, random_state=random_state) # KFold规则,KNumber为折数
    results_list = [] # 存放所有结果dataframe的list
    # 每折的计算
    fold_counter = 1 # 折数计数器
    model_copy = classification_model
    for train_index, test_index in kf_model.split(data):
        results = DataFrame()
        print("=======================================================================")
        # 分割训练集和验证集的特征和目标
        x_train = data.iloc[train_index, 0:-1]
        y_train = data.iloc[train_index, -1]
        x_test = data.iloc[test_index, 0:-1]
        y_test = data.iloc[test_index, -1]
        model_copy.fit(x_train, y_train) # 训练模型
        y_predict = model_copy.predict(x_test) # 模型预测
        kappa = cohen_kappa_score(y_test, y_predict) # kappa系数
        model_report = classification_report(y_test, y_predict, digits=6) # 分类精度报告
        # 存放结果
        results["true"] = y_test
        results["pred"] = y_predict
        results["fold"] = fold_counter
        results_list.append(results)
        # 显示结果
        print(f"第 {fold_counter} 次精度验证:\n模型: {model_copy}\nkappa = {kappa}\n", model_report)
        fold_counter += 1
        print("=======================================================================\n")
    # 整合结果
    tot_results = pd.concat(results_list, axis=0) # 合并所有结果
    tot_results = tot_results.reset_index(drop=True) # 重置索引
    y_test = tot_results["true"]
    y_predict = tot_results["pred"]
    kappa = cohen_kappa_score(y_test, y_predict) # kappa系数
    model_report = classification_report(y_test, y_predict, digits=6) # 分类精度报告
    print(f"总体结果精度验证:\n模型: {model_copy}\nkappa = {kappa}\n", model_report)
    print("=====================================================================================\n")
    return tot_results

这个函数调用了sklearn里的kfold接口,它能自动按照填入的折数分割样本;然后遍历每一折并且划分特征和目标以进行建模和相关精度指标的计算,精度指标计算也用了sklearn.metrics里自带的一些接口,可根据需要自行调节;最后把一个模型的所有折的结果汇总起来,以进行总体的精度评估。
由于上述函数的参数需要输入一个实例化的模型,我就在主函数写了一些以供选取。

# 1.决策树分类
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier()
# 2.SVM分类
from sklearn.svm import SVC
svm = SVC()
# 3.KNN分类
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
# 4.随机森林分类
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(n_estimators=100)
# 5.Adaboost分类
from sklearn.ensemble import AdaBoostClassifier
adb = AdaBoostClassifier(n_estimators=100)
# 6.GBDT分类
from sklearn.ensemble import GradientBoostingClassifier
gbdt = GradientBoostingClassifier(n_estimators=100)
# 7.Bagging分类
from sklearn.ensemble import BaggingClassifier
bag = BaggingClassifier(n_estimators=100)
# 8.极端数分类
from sklearn.tree import ExtraTreeClassifier
et = ExtraTreeClassifier()
# 9.朴素贝叶斯分类
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()

好的,接下来演示一下该函数的工作。比如我要用5折交叉验证测试随机森林(RF)和GBDT模型的效果,我就用以下语句。

KFold_Classificaton(data, 5, classification_model=rf, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=gbdt, shuffle=True, random_state=1)

部分结果输出如下。首先输出5次交叉验证的精度评估报告,然后汇总了5次结果并进行总体结果的评估。
在这里插入图片描述
在这里插入图片描述
好了,到目前为止我们的任务完成了。
完整代码见下方,注释写得很详细了。

3 完整代码

import pandas as pd
from pandas import DataFrame


def read_data(data_path, data_sheet):
    """
    读取excel表格中的数据
    输入excel文件路径,输出dataframe格式的数据
        data_path: excel文件(xlsx或xls)的路径
        data_sheet: excel文件的sheet名称
    """
    from pandas import read_excel
    file = data_path # 读取数据路径
    data = read_excel(open(file, "rb"), sheet_name=data_sheet) # 读取数据
    return data


def KFold_Classificaton(data, KNumber, classification_model, shuffle=False, random_state=None):
    """
    K折交叉验证
    输入数据和K折交叉验证所需参数,打印各次模型的精度指标
        data: dataframe格式数据
        KNumber: 折数
        classification_model: 实例化后的分类模型
        shuffle: 是否打乱样本,默认不打乱
        random_state: 随机种子
        target_names: 目标类名
        输出所有折验证的结果
    """
    from sklearn.model_selection import KFold
    from sklearn.metrics import classification_report
    from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, cohen_kappa_score
    kf_model = KFold(n_splits=KNumber, shuffle=shuffle, random_state=random_state) # KFold规则,KNumber为折数
    results_list = [] # 存放所有结果dataframe的list
    # 每折的计算
    fold_counter = 1 # 折数计数器
    model_copy = classification_model
    for train_index, test_index in kf_model.split(data):
        results = DataFrame()
        print("=======================================================================")
        # 分割训练集和验证集的特征和目标
        x_train = data.iloc[train_index, 0:-1]
        y_train = data.iloc[train_index, -1]
        x_test = data.iloc[test_index, 0:-1]
        y_test = data.iloc[test_index, -1]
        model_copy.fit(x_train, y_train) # 训练模型
        y_predict = model_copy.predict(x_test) # 模型预测
        kappa = cohen_kappa_score(y_test, y_predict) # kappa系数
        model_report = classification_report(y_test, y_predict, digits=6) # 分类精度报告
        # 存放结果
        results["true"] = y_test
        results["pred"] = y_predict
        results["fold"] = fold_counter
        results_list.append(results)
        # 显示结果
        print(f"第 {fold_counter} 次精度验证:\n模型: {model_copy}\nkappa = {kappa}\n", model_report)
        fold_counter += 1
        print("=======================================================================\n")
    # 整合结果
    tot_results = pd.concat(results_list, axis=0) # 合并所有结果
    tot_results = tot_results.reset_index(drop=True) # 重置索引
    y_test = tot_results["true"]
    y_predict = tot_results["pred"]
    kappa = cohen_kappa_score(y_test, y_predict) # kappa系数
    model_report = classification_report(y_test, y_predict, digits=6) # 分类精度报告
    print(f"总体结果精度验证:\n模型: {model_copy}\nkappa = {kappa}\n", model_report)
    print("=====================================================================================\n")
    return tot_results




# 1.决策树分类
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier()
# 2.SVM分类
from sklearn.svm import SVC
svm = SVC()
# 3.KNN分类
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
# 4.随机森林分类
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(n_estimators=100)
# 5.Adaboost分类
from sklearn.ensemble import AdaBoostClassifier
adb = AdaBoostClassifier(n_estimators=100)
# 6.GBDT分类
from sklearn.ensemble import GradientBoostingClassifier
gbdt = GradientBoostingClassifier(n_estimators=100)
# 7.Bagging分类
from sklearn.ensemble import BaggingClassifier
bag = BaggingClassifier(n_estimators=100)
# 8.极端数分类
from sklearn.tree import ExtraTreeClassifier
et = ExtraTreeClassifier()
# 9.朴素贝叶斯分类
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()



data = read_data("data.xlsx", "data") # 读取数据
print(data) # 样本数据

# 交叉验证
KFold_Classificaton(data, 5, classification_model=dt, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=svm, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=knn, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=rf, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=adb, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=gbdt, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=bag, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=et, shuffle=True, random_state=1)
KFold_Classificaton(data, 5, classification_model=gnb, shuffle=True, random_state=1)

如果对你有帮助,还望支持一下~点击此处施舍或扫下图的码。
-----------------------分割线(以下是乞讨内容)-----------------------
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/723624.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

openwrt使用记录

背景: 平时在vmware中做实验时候,经常需要在不同的机器上下载一些github上的项目进行调试,之前解决方案是在路由器层小米ac2100上装openwrt,试用一番发现太卡了。放弃,这次在vmware中安装作为小米ac2100的旁路由 规划…

推荐五款优秀,可替代商业软件的开源软件

​ 在日常的使用中,我们需要使用各种软件来提高我们的工作效率或者进行创意的表达。然而,商业软件价格昂贵,某些国产软件又充斥着广告。因此,开源软件成为了一个不错的选择,以下是我推荐的五款优秀的开源软件。 图片浏…

一文get,最容易碰上的接口自动化测试问题汇总

本篇文章分享几个接口自动化用例编写过程遇到的问题总结,希望能对初次探索接口自动化测试的小伙伴们解决问题上提供一小部分思路。 sql语句内容出现错误 空格:由于有些字段判断是变量,需要将sql拼接起来,但是在拼接字符串时没有…

手写操作系统--完善内核加载器之内存检测

这一篇我们来完善内核加载器的功能,我们知道内存是很重要的区域,我们需要对内存有个大致的描述,哪些可用,那些不可用,内存有多大。因此在内核加载器中我们需要对内存进行检测。内存检测的方法翻译文档如下:…

DAY39:贪心算法(八)无重叠区间+划分字母区间+合并区间

文章目录 435.无重叠区间思路完整版注意点 右区间排序 763.划分字母区间思路完整版如何确定区间分界线debug测试时间复杂度 总结 56.合并区间思路最开始的写法:直接在原数组上修改debug测试 修改版时间复杂度总结 435.无重叠区间 给定一个区间的集合 intervals &am…

“钓鱼”网站也有https?如何一招识破?

作为企业网站安全建设的基础设施, SSL证书可以对数据进行加密传输,保护数据在传输过程中不被监听、截取和篡改,因此部署了SSL证书的网站会比传统的http协议更加安全,也更受主流操作系统和浏览器的信任。 然而随着SSL证书的普及&a…

AI做PPT,五分钟搞定别人一天的量,最喜欢卷PPT了

用AI做PPT 主题生成大纲制作PPT 主题生成大纲 如何使用人工智能工具,如ChatGPT和mindshow,快速生成PPT。 gpt国内版 制作PPT,你可能只有一个主题,但没有明确的提纲或思路。 问gpt:计算机视觉的周工作汇报。我这周学…

MyBatis 与 Hibernate 有哪些不同?

ORM框架的选择与适用场景 MyBatis和Hibernate都是Java领域中流行的面向关系型数据库的ORM(对象关系映射)框架。它们的共同目标是简化开发人员操作数据库的工作,提供便捷的持久化操作。然而,两者在设计理念和适用场景上有所不同。…

Zabbix6.0 的部署

目录 一、概述 二、 zabbix 1.zabbix简介 2.zabbix监控原理 3. Zabbix 6.0 新特性 3.1Zabbix server高可用防止硬件故障或计划维护期的停机 3.2 Zabbix 6.0 LTS新增Kubernetes监控功能,可以在Kubernetes系统从多个维度采集指标 4. Zabbix 6.0 功能组件 4.1Z…

浏览器内核的介绍

文章目录 1、什么是浏览器内核2、常用浏览器内核3、浏览器内核分类3. 1、Trident3.2、Gecko3.3、Webkit3.4、Chromium3.5、Presto3.6、国内主流浏览器 4、五大主流浏览器(诞生顺序)4.1、IE(Internet Explorer)浏览器4.2、Opera浏览…

解决vue3中使用个别form表单校验失灵

当我点击校验时 其他都有触发校验 唯独radio没有触发,绑定都没有问题 看一下代码 const data reactive({form: {},rules: {serverStatus: [{ required: true, message: "服务状态不能为空", trigger: change }],tenantName: [{ required: true, messag…

hypef 五、请求及响应

文档地址 Hyperf https://hyperf.wiki/2.0/#/zh-cn/response 一、请求 1.1 安装 composer require hyperf/http-message 框架自带不用手动安装。 1.2 请求对象 在 onRequest 生命周期内可获得Hyperf\HttpServer\Request对象。 可以通过以来注入和路由对应参数获取。 de…

我的创作纪念日 --- 简单记录

理想并不是一种空虚的东西,也并不玄奇;它既非幻想,更非野心,而是一种追求善美的意识。 机缘 大家好,今天是我成为创作者的第256天,也是我在CSDN上发布的第81篇文章。在这里,我想和大家简单记…

C#(五十五)之线程死锁

死锁是指多个线程共享资源是&#xff0c;都占用同意部分资源&#xff0c;而且都在等待对方师范另一部分资源&#xff0c;从而导致程序停滞不前的情况 示例&#xff1a; /// <summary>/// 定义一个刀/// </summary>public static object knife new object();/// …

git install报错问题

报错如下&#xff1a; fatal: unable to connect to githurb.com: 运行如下命令即可&#xff1a; git config --global url.https://github.com/.insteadOf git://github.com/ git config --global url."https://".insteadOf git://接着再删除node_moudels包&…

Spring Boot 操作 Redis 的各种实现

一、Jedis,Redisson,Lettuce三者的区别 共同点&#xff1a;都提供了基于Redis操作的Java API&#xff0c;只是封装程度&#xff0c;具体实现稍有不同。 不同点&#xff1a; 1.1、Jedis 是Redis的Java实现的客户端。支持基本的数据类型如&#xff1a;String、Hash、List、Se…

卡尔曼滤波实例分析(二)

1 问题 假设一物体以一初速度 v 0 v_0 v0​位于一高度为 y 0 y_0 y0​处正处于匀速下降状态&#xff0c;此时该物体启动制动装置&#xff0c;以一个加速度为 a a a的作用力反向运动 &#xff08;1&#xff09;建模 速度&#xff1a; V V 0 − a ∗ t V V_0 - a*t VV0​−a∗…

API6中JS UI实现富文本组件居右显示

【关键字】 RichText、富文本组件、API6、JS UI、居右显示 【关键代码如下】 index.hml <div class"container"><text>文本行高</text><text>文本行高</text><text>文本行高</text><text>文本行高</text>&…

SciencePub学术 | 计算机与生物信息类重点SCIEEI征稿中

SciencePub学术 刊源推荐: 计算机与生物信息类重点SCIE&EI征稿中&#xff01;信息如下&#xff0c;录满为止&#xff1a; 一、期刊概况&#xff1a; 计算机与生物信息类重点SCIE&EI 【期刊简介】IF&#xff1a;7.5-8.0&#xff0c;JCR1区&#xff0c;中科院1区TOP&…

vscode设置自己用的注释格式

ctrlshiftP 打开设置 输入snippets&#xff0c;选择配置用户代码片段[Snippets: Configure User Snippets]输入JavaScript&#xff0c;选择JavaScript.json 把这段代码替换进去 "Print to jsNoteTitle": {"prefix": "jsNoteTitle","body&q…