PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类

news2025/1/9 0:37:16

鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为:Iris Setosa,Iris Versicolour,Iris Virginica。

对于多分类任务,有较多机器学习的算法可以支持。本文将使用决策树、线性回归、SVM等多种算法来完成这一任务,并对不同方法进行比较。

01、使用Logistic实现鸢尾花分类

在前面介绍过Logistic用于二分类任务,对其进行扩展也用于多分类任务。下面将使用sklearn库完成一个基于Logistic的鸢尾花分类任务。如代码清单1所示,首先是导入sklearn.datasets包从而加载数据集,并将数据集按照测试集占比0.2随机分为训练集和测试集。

代码清单1 导入包以及加载数据集

rom sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import confusion_matrix, precision_score, accuracy_score,recall_score, f1_score, roc_auc_score, \
    roc_curve
import matplotlib.pyplot as plt
 
# 加载数据集
def loadDataSet():
    iris_dataset = load_iris()
    X = iris_dataset.data
    y = iris_dataset.target
    # 将数据划分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    return X_train, X_test, y_train, y_test

如代码清单2所示,编写函数训练Logistic模型。

代码清单2 训练Logistic模型

# 训练Logistic模性
def trainLS(x_train, y_train):
    # Logistic生成和训练
    clf = LogisticRegression()
    clf.fit(x_train, y_train)
    return clf

Logistic模型较为简单,不需要额外设置超参数即可开始训练。如代码清单3所示,初始化Logistic模型并将模型在训练集上训练,返回训练好的模型。

代码清单3 测试模型及打印各种评价指标

# 测试模型
def test(model, x_test, y_test):
    # 将标签转换为one-hot形式
    y_one_hot = label_binarize(y_test, np.arange(3))
    # 预测结果
    y_pre = model.predict(x_test)
    # 预测结果的概率
    y_pre_pro = model.predict_proba(x_test)
 
    # 混淆矩阵
    con_matrix = confusion_matrix(y_test, y_pre)
    print('confusion_matrix:\n', con_matrix)
    print('accuracy:{}'.format(accuracy_score(y_test, y_pre)))
    print('precision:{}'.format(precision_score(y_test, y_pre, average='micro')))
    print('recall:{}'.format(recall_score(y_test, y_pre, average='micro')))
    print('f1-score:{}'.format(f1_score(y_test, y_pre, average='micro')))
 
    # 绘制ROC曲线
    drawROC(y_one_hot, y_pre_pro)

在预测结果时,为了方便后面绘制ROC曲线,需要首先将测试集的标签转化为one-hot的形式,并得到模型在测试集上预测结果的概率值即y_pre_pro,从而传入drawROC函数完成ROC曲线的绘制。除此外,该函数实现了输出混淆矩阵以及计算准确率、精确率、查全率以及f1-score的功能。

代码清单4 绘制ROC曲线

def drawROC(y_one_hot, y_pre_pro):
    # AUC值
    auc = roc_auc_score(y_one_hot, y_pre_pro, average='micro')
    # 绘制ROC曲线
    fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(), y_pre_pro.ravel())
    plt.plot(fpr, tpr, linewidth=2, label='AUC=%.3f' % auc)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.axis([0, 1.1, 0, 1.1])
    plt.xlabel('False Postivie Rate')
    plt.ylabel('True Positive Rate')
    plt.legend()
    plt.show()

如代码清单4所示为绘制ROC曲线的代码实现。最后将加载数据集,训练模型,以及模型验证的整个流程连接起来从而实现main函数,如代码清单5所示。

代码清单5 main函数设置

if __name__ == '__main__':
    X_train, X_test, y_train, y_test = loadDataSet()
    model = trainLS(X_train, y_train)
    test(model, X_test, y_test)

将上述所有代码放在同一py脚本文件中,如图1所示可得最终的输出结果为

图1 命令行打印的测试结果

绘制得到的ROC曲线如图2所示。

图2 ROC曲线

Logistic是一个较为简单的模型,参数量较少,一般也用于较为简单的分类任务中,当任务更为复杂时,可以选取更为复杂的模型获得更好的效果,下面将使用不同的模型从而验证同一任务在不同模型下的表现。

02、使用决策树实现鸢尾花分类

由于只改动了模型,加载数据集、模型评价等其他部分的代码不需要改动,如代码清单6所示,增加新的函数用于训练决策树模型。

代码清单6 使用决策树模型进行训练

from sklearn import tree
# 训练决策树模性
def trainDT(x_train, y_train):
    # DT生成和训练
    clf = tree.DecisionTreeClassifier(criterion="entropy")
    clf.fit(x_train, y_train)
    return clf

同时修改main函数中调用的训练函数如代码清单7所示。

代码清单7 修改main函数内容

if __name__ == '__main__':
    X_train, X_test, y_train, y_test = loadDataSet()
    model = trainDT(X_train, y_train)
    test(model, X_test, y_test)

最后运行可得命令行输出如图3所示。

图3 决策树模型预测结果

以及ROC曲线如图4所示。

图4 决策树模型绘制ROC曲线

相比Logistic模型,决策树模型无论在哪一项指标上都得到了更高的评分,且决策树模型不会像Logistic模型一样受初始化的影响,多次运行程序均可获得相同的输出模型,而Logistic模型运行多次会发现评价指标会在某个范围内上下抖动。

03、使用SVM实现鸢尾花分类

到现在相信大家都已经非常熟悉如何继续修改代码从而实现SVM模型的预测,实现SVM模型的训练代码如代码清单8所示

代码清单8 使用SVM模型进行训练

# 训练SVM模性
from sklearn import svm
def trainSVM(x_train, y_train):
    # SVM生成和训练
    clf = svm.SVC(kernel='rbf', probability=True)
    clf.fit(x_train, y_train)
    return clf

同时修改main函数,如代码清单9所示。

代码清单9 修改main函数内容

if __name__ == '__main__':
    X_train, X_test, y_train, y_test = loadDataSet()
    model = trainSVM(X_train, y_train)
    test(model, X_test, y_test)

程序运行输出如图5所示。

图5 使用SVM模型预测结果

绘制得到的ROC曲线如图6所示。

图6 使用SVM模型绘制的ROC曲线

可以发现,随着模型进一步变得复杂,最终预测的各项指标进一步上升,在三个模型中SVM模型的高斯核最终结果在测试集中表现得最好且没有发生过拟合的现象,因此可以选用SVM模型来完成鸢尾花分类这一任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/417737.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AI帮我写代码,上班摸鱼不是梦】手摸手图解CodeWhisperer的安装使用

IDEA插件 除了借助ChatGPT通过问答的方式生成代码,也可以通过IDEA插件在写代码是直接帮助我们生成代码。 目前,IDEA插件有CodeGeeX、CodeWhisperer、Copilot。其中,CodeGeeX和CodeWhisperer是完全免费的,Copilot是收费的&#x…

数据分析:麦当劳食品营养数据探索并可视化

系列文章目录 作者:i阿极 作者简介:Python领域新星作者、多项比赛获奖者:博主个人首页 😊😊😊如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒关注…

c++模板整理

目录 一.泛型编程​​​​​​​ 二.函数模板 2.1 函数模板概念 2.2函数模板格式 2.3 函数模板的原理 2.4 函数模板的实例化 2.5 模板参数的匹配原则 三.类模板 3.1 类模板的定义格式 3.2 类模板的实例化 3.3模板类 一.泛型编程​​​​​​​ 如何实现一个通用的交…

【前端之旅】快速上手Echarts

一名软件工程专业学生的前端之旅,记录自己对三件套(HTML、CSS、JavaScript)、Jquery、Ajax、Axios、Bootstrap、Node.js、Vue、小程序开发(Uniapp)以及各种UI组件库、前端框架的学习。 【前端之旅】Web基础与开发工具 【前端之旅】手把手教你安装VS Code并附上超实用插件…

「高并发业务必读」深入剖析 Java 并发包中的锁机制

故事 程序员小张: 刚毕业,参加工作1年左右,日常工作是CRUD 架构师老李: 多个大型项目经验,精通各种屠龙宝术; 小张和老李一起工作已有数月,双方在技术上也有了很多的交流,但是却总是…

GB28181 协议 SIP

2、注册信令 2.1基本注册 2.1.1 抓包过程 2.1.2 详细步骤 2.1.2.1、REGISTER REGISTER sip:34020000002000000001192.168.9.186:15060 SIP/2.0Via: SIP/2.0/TCP 192.168.9.186:42860;rport;branchz9hG4bK1557586049From: <sip:30514805331320000140192.168.9.186:5060>…

手写Spring框架-前奏-反射获取Annotation

目录 所谓反射 反射机制的作用 反射依赖reflect和Class 反射依赖的Class Class类的特点 获取Class对象的三种方式 获取类的构造方法并使用 获取类的成员变量并使用 获取类的成员方法并使用 问题引入 解析类的注解 解析成员变量的注解标签 解析方法上的注解 注解获…

Java类加载

类加载的时机 一个类型从被加载到虚拟机内存中开始&#xff0c;到卸载出内存为止&#xff0c;它的整个生命周期将会经历加载、验证、准备、解析、初始化、使用和卸载七个阶段。其中验证、准备、解析三个阶段统称为连接。 图中加载、验证、准备、初始化和卸载这五个阶段的顺序是…

CDGP数据治理专家认证含金量如何?值得考一个吗?

CDGP&#xff08;Certified Data Governance Professional&#xff09;数据治理专家认证的含金量非常高。该认证证明了持有人拥有数据治理方面的专业知识和技能&#xff0c;能够有效地管理和保护组织的数据资产。 CDGP认证考试内容涵盖数据治理的各个方面&#xff0c;包括数据…

看这家在线教育企业如何通过DHTMLX Scheduler,实现培训管理系统优化

“我们公司目前有一套培训管理系统&#xff0c;用于管理培训学员。目前学员越来越多&#xff0c;老旧的系统已经没法满足需求&#xff0c;导致我们经常需要手动记录学员出勤培训情况&#xff0c;除此之外&#xff0c;系统课程安排只展示时间&#xff0c;没法展示诸如主题&#…

macOS Big Sur 11.7.6 (20G1231) 正式版 ISO、PKG、DMG、IPSW 下载

本站下载的 macOS 软件包&#xff0c;既可以拖拽到 Applications&#xff08;应用程序&#xff09;下直接安装&#xff0c;也可以制作启动 U 盘安装&#xff0c;或者在虚拟机中启动安装。另外也支持在 Windows 和 Linux 中创建可引导介质。 2023 年 4 月 10 日&#xff08;北京…

【Vue全家桶】Pinia状态管理

【Vue全家桶】Pinia状态管理 文章目录【Vue全家桶】Pinia状态管理写在前面一、认识Pinia1.1 认识Pinia1.2 为什么使用Pinia&#xff1f;二、 Store2.1 定义Store2.2 Option对象2.3 setup函数2.4 使用定义的Store三、Pinia核心概念State3.1 定义State3.2 操作State3.3 使用选项式…

基于小生境粒子群优化算法的考虑光伏波动性的主动配电网有功无功协调优化(Matlab代码实现)

&#x1f468;‍&#x1f393;个人主页&#xff1a;研学社的博客&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5;&#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密…

C语言基础——指针

文章目录一、指针1.指针的意义2.指针类型表示3.一些操作3.1打印1个变量地址3.2通过地址查看改地址的内容以及修改改地址的内容3.3操作某个空间 -- 4个字节,给他赋值为100&#xff0c;只知道该空间的地址0x8000 00004.指针变量的定义5.指针类型的大小6.指针变量的使用6.1 指针变…

python数据分析-matplotlib折线图知识总结01

python绘图库matplotlib的知识总结一.matplotlib是什么二.matplotlib的安装与导入三.matplotlib的常用函数四.matplotlib绘制折线图的使用方法1.设置图形大小2. 利用数据绘图3.调整x,y轴的刻度,旋转角度,显示描述信息,绘制网格,添加图例4.图形的样式5.绘制多条折线6.显示绘制的…

python知识记录:灵活使用numpy提高python数据分析效率!

NumPy是Python语言的一个第三方库&#xff0c;其支持大量高维度数组与矩阵运算。 作为python科学计算领域的三剑客之一&#xff0c;numpy在数据分析处理方面有着独特的魅力&#xff01; numpy模块的出现更多的是在数组处理的操作上面&#xff0c;并且支持和python常用的数据结…

Transformer在时序预测的应⽤第一弹——Autoformer

Transformer在时序预测的应⽤第一弹——Autoformer 原文地址&#xff1a;Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting&#xff08;NIPS 2021&#xff09; 做长时间序列的预测 Decomposition把时间序列做拆分&#xff0c…

目标检测——YOLOv7(十三)

简介&#xff1a; 继美团发布YOLOV6之后&#xff0c;YOLO系列原作者也发布了YOLOV7。主要从两点进行模型的优化&#xff1a;模型结构重参化和动态标签分配。 YOLOv7的特点是快&#xff01;相同体量下比YOLOv5精度更高&#xff0c;速度快120%&#xff0c;比YOLOX快180%。 Github…

RabbitMQ消息丢失的情况,以及如何通过代码解决

目录 RabbitMQ消息丢失问题&#xff1a; 代码部分&#xff1a; 完整代码&#xff1a; RabitMQConfig&#xff1a; CourseMQListener: 生产者跟交换机通信的消息丢失解决 &#xff1a; 交换机跟消息队列的消息丢失&#xff1a; 消息队列跟消费者的消息丢失&#xff1a; …

自动处理【支付宝交易支付投诉管理系统】配置指南

大家好&#xff0c;我是小悟 已经有小伙伴开始使用自动处理【支付宝交易支付投诉管理系统】&#xff0c;所以详细介绍一下如何配置。 阅读这篇文章之前&#xff0c;结合这篇【连夜干出来一个自动处理【支付宝交易支付投诉管理系统】&#xff0c;支持多商户】干货食用更佳。 连…