Python基于Pytorch Transformer实现对iris鸢尾花的分类预测,分别使用CPU和GPU训练

news2025/1/6 18:17:40

1、鸢尾花数据iris.csv

在这里插入图片描述

iris数据集是机器学习中一个经典的数据集,由英国统计学家Ronald Fisher在1936年收集整理而成。该数据集包含了3种不同品种的鸢尾花(Iris Setosa,Iris Versicolour,Iris Virginica)各50个样本,每个样本包含了花萼长度(sepal length)、花萼宽度(sepal width)、花瓣长度(petal length)、花瓣宽度(petal width)四个特征。

iris数据集的主要应用场景是分类问题,在机器学习领域中被广泛应用。通过使用iris数据集作为样本集,我们可以训练出一个分类器,将输入的新鲜鸢尾花归类到三种品种中的某一种。iris数据集的特征数据已经被广泛使用,也是许多特征选择算法和模型选择算法的基础数据集之一。

在这里插入图片描述
总共150条数据
在这里插入图片描述
数据分布均匀,每种分类50条数据。

2、Transformer模型 CPU版本

# -*- coding:utf-8 -*-
import torch                        # 导入 PyTorch 库
from torch import nn                # 导入 PyTorch 的神经网络模块
from sklearn import datasets        # 导入 scikit-learn 库中的 dataset 模块
from sklearn.model_selection import train_test_split  # 从 scikit-learn 的 model_selection 模块导入 split 方法用于分割训练集和测试集
from sklearn.preprocessing import StandardScaler      # 从 scikit-learn 的 preprocessing 模块导入方法,用于数据缩放

print("# 加载鸢尾花数据集")
# 加载鸢尾花数据集,这个数据集在机器学习中比较著名
iris = datasets.load_iris()
X = iris.data           # 对应输入变量或属性(features),含有4个属性:花萼长度、花萼宽度、花瓣长度 和 花瓣宽度
y = iris.target         # 对应目标变量(target),也就是类别标签,总共有3种分类

print("拆分训练集和测试")
# 把数据集按照80:20的比例来划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

print("数据缩放")
# 对训练集和测试集进行归一化处理,常用方法之一是StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

print("数据转tensor类型")
# 将训练集和测试集转换为PyTorch的张量对象并设置数据类型
X_train = torch.tensor(X_train).float()
y_train = torch.tensor(y_train).long()
X_test = torch.tensor(X_test).float()
y_test = torch.tensor(y_test).long()

# 定义 Transformer 模型
class TransformerModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(TransformerModel, self).__init__()
        # 定义 Transformer 编码器,并指定输入维数和头数
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=1)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        # 定义全连接层,将 Transformer 编码器的输出映射到分类空间
        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        # 在序列的第2个维度(也就是时间步或帧)上添加一维以适应 Transformer 的输入格式
        x = x.unsqueeze(1)
        # 将输入数据流经 Transformer 编码器进行特征提取
        x = self.encoder(x)
        # 通过压缩第2个维度将编码器的输出恢复到原来的形状
        x = x.squeeze(1)
        # 将编码器的输出传入全连接层,获得最终的输出结果
        x = self.fc(x)
        return x

print("创建模型")
# 初始化 Transformer 模型
model = TransformerModel(input_size=4, num_classes=3)

print("定义损失函数和优化器")
# 定义损失函数(交叉熵损失)和优化器(Adam)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print("训练模型")
# 训练模型,对数据集进行多次迭代学习,更新模型的参数
num_epochs = 100
for epoch in range(num_epochs):
    # 前向传播计算输出结果
    outputs = model(X_train)
    loss = criterion(outputs, y_train)

    # 反向传播,更新梯度并优化模型参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 打印每10个epoch的loss值
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

print("测试模型")
# 测试模型的准确率
with torch.no_grad():
    # 对测试数据集进行预测,并与真实标签进行比较,获得预测
    outputs = model(X_test)
    _, predicted = torch.max(outputs.data, 1)
    accuracy = (predicted == y_test).sum().item() / y_test.size(0)
    print(f'Test Accuracy: {accuracy:.2f}')

控制台输出:

# 加载鸢尾花数据集
拆分训练集和测试
数据缩放
数据转tensor类型
创建模型
定义损失函数和优化器
训练模型
Epoch [10/100], Loss: 0.5863
Epoch [20/100], Loss: 0.3978
Epoch [30/100], Loss: 0.2954
Epoch [40/100], Loss: 0.1765
Epoch [50/100], Loss: 0.1548
Epoch [60/100], Loss: 0.1184
Epoch [70/100], Loss: 0.0847
Epoch [80/100], Loss: 0.2116
Epoch [90/100], Loss: 0.0941
Epoch [100/100], Loss: 0.1062
测试模型
Test Accuracy: 0.97

正确率97%

3、Transformer模型 GPU版本

# -*- coding:utf-8 -*-
import torch                        # 导入 PyTorch 库
from torch import nn                # 导入 PyTorch 的神经网络模块
from sklearn import datasets        # 导入 scikit-learn 库中的 dataset 模块
from sklearn.model_selection import train_test_split  # 从 scikit-learn 的 model_selection 模块导入 split 方法用于分割训练集和测试集
from sklearn.preprocessing import StandardScaler      # 从 scikit-learn 的 preprocessing 模块导入方法,用于数据缩放

print("# 检查GPU是否可用")
# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("# 加载鸢尾花数据集")
# 加载鸢尾花数据集,这个数据集在机器学习中比较著名
iris = datasets.load_iris()
X = iris.data           # 对应输入变量或属性(features),含有4个属性:花萼长度、花萼宽度、花瓣长度 和 花瓣宽度
y = iris.target         # 对应目标变量(target),也就是类别标签,总共有3种分类

print("拆分训练集和测试")
# 把数据集按照80:20的比例来划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

print("数据缩放")
# 对训练集和测试集进行归一化处理,常用方法之一是StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

print("数据转tensor类型")
# 将训练集和测试集转换为PyTorch的张量对象并设置数据类型,加上to(device)可以运行在GPU上
X_train = torch.tensor(X_train).float().to(device)
y_train = torch.tensor(y_train).long().to(device)
X_test = torch.tensor(X_test).float().to(device)
y_test = torch.tensor(y_test).long().to(device)

# 定义 Transformer 模型
class TransformerModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(TransformerModel, self).__init__()
        # 定义 Transformer 编码器,并指定输入维数和头数
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=1)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        # 定义全连接层,将 Transformer 编码器的输出映射到分类空间
        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        # 在序列的第2个维度(也就是时间步或帧)上添加一维以适应 Transformer 的输入格式
        x = x.unsqueeze(1)
        # 将输入数据流经 Transformer 编码器进行特征提取
        x = self.encoder(x)
        # 通过压缩第2个维度将编码器的输出恢复到原来的形状
        x = x.squeeze(1)
        # 将编码器的输出传入全连接层,获得最终的输出结果
        x = self.fc(x)
        return x

print("创建模型")
# 初始化 Transformer 模型
model = TransformerModel(input_size=4, num_classes=3).to(device)

print("定义损失函数和优化器")
# 定义损失函数(交叉熵损失)和优化器(Adam)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print("训练模型")
# 训练模型,对数据集进行多次迭代学习,更新模型的参数
num_epochs = 100
for epoch in range(num_epochs):
    # 前向传播计算输出结果
    outputs = model(X_train)
    loss = criterion(outputs, y_train)

    # 反向传播,更新梯度并优化模型参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 打印每10个epoch的loss值
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

print("测试模型")
# 测试模型的准确率
with torch.no_grad():
    # 对测试数据集进行预测,并与真实标签进行比较,获得预测
    outputs = model(X_test)
    _, predicted = torch.max(outputs.data, 1)
    accuracy = (predicted == y_test).sum().item() / y_test.size(0)
    print(f'Test Accuracy: {accuracy:.2f}')

控制台输出:

# 检查GPU是否可用
# 加载鸢尾花数据集
拆分训练集和测试
数据缩放
数据转tensor类型
创建模型
定义损失函数和优化器
训练模型
Epoch [10/100], Loss: 0.6908
Epoch [20/100], Loss: 0.4861
Epoch [30/100], Loss: 0.3541
Epoch [40/100], Loss: 0.2136
Epoch [50/100], Loss: 0.2149
Epoch [60/100], Loss: 0.1263
Epoch [70/100], Loss: 0.1227
Epoch [80/100], Loss: 0.0685
Epoch [90/100], Loss: 0.1775
Epoch [100/100], Loss: 0.0889
测试模型
Test Accuracy: 0.97

正确率:97%

4、代码说明

在这段代码中,我们首先通过 torch.cuda.is_available() 检查GPU是否可用,如果GPU可用,则将计算转移到GPU,以便更快地训练模型。

然后使用 datasets.load_iris() 函数加载鸢尾花数据集。对于机器学习任务,我们通常会将数据集分成训练集和测试集,以便评估模型的性能。在本例中,使用 train_test_split() 方法将数据集分成训练集和测试集。

接下来,我们使用 StandardScaler 对数据进行缩放,以获得更好的模型性能。然后将数据集转换为PyTorch tensor格式,并使用 to() 将它们移动到GPU上(如果存在)。

然后定义了一个类名为 TransformerModel 的模型,并继承了 nn.Module。这个模型包括 TransformerEncoder 层、全局平均池化层和线性层。在这个模型中,输入是一组4维数值(表示鸢尾花的4种特征),输出需要有3个类别,因此最后一层的输出大小设置为3。

接下来,我们初始化模型并将其移动到GPU上,之后定义损失函数和优化器以进行模型的优化。在每个迭代步骤内进行前向传递、反向传递和梯度更新,同时打印出损失值以便调试和优化模型。经过若干次迭代后,我们使用测试集对模型进行测试,最后输出测试集的精度值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/461372.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BatchNormalization和LayerNormalization的理解、适用范围、PyTorch代码示例

文章目录 为什么要NormalizationBatchNormLayerNormtorch代码示例 学习神经网络归一化时,文章形形色色,但没找到适合小白通俗易懂且全面的。学习过后,特此记录。 为什么要Normalization 当输入数据量级极大或极小时,为保证输出数…

【算法基础】直接插入排序 + 希尔排序

👦个人主页:Weraphael ✍🏻作者简介:目前正在学习c和算法 ✈️专栏:【C/C】算法 🐋 希望大家多多支持,咱一起进步!😁 如果文章有啥瑕疵 希望大佬指点一二 如果文章对你有…

基于人类反馈的强化学习(RLHF)在LLM领域是如何运作的?

基于人类反馈的强化学习在LLM领域是如何运作的? 为什么需要强化学习RLHFPre-training modelReward ModelFine-tune with RL 参考 为什么需要强化学习 指标无法衡量。在过去的nlp任务中,词性标注、机器翻译、语义判别等任务是nlp任务的主力军&#xff0c…

Hytrix原理

这里写目录标题 Hytrix容错机制熔断资源隔离线程池隔离信号量隔离 服务降级请求缓存请求合并 Hystrix流程图 Hytrix容错机制 熔断 在一个统计时间窗口(HYST rixCommandProperties.metricsRollingStatisticalWindowInMilliseconds())内,处理…

转化率双倍暴涨——客户自助服务门户

近年来,社交媒体的兴起使客户负责品牌对话。随着电子商务和在线帮助需求的扩大,公司必须满足并超越新的期望,以保持客户满意度。 通过SaleSmartly(ss客服)自动化流程功能建立客户自助服务是一种双赢的决策&#xff0c…

Ajax XHR响应

文章目录 AJAX 服务器 响应服务器响应responseText 属性responseXML 属性 AJAX 服务器 响应 服务器响应 如需获得来自服务器的响应,请使用 XMLHttpRequest 对象的 responseText 或 responseXML 属性。 属性描述responseText获得字符串形式的响应数据。responseXML…

2015年下半年软件设计师下午试题

试题一 【说明】 某慕课教育平台欲添加在线作业批改系统,以实现高效的作业提交与批改,并进行统计。学生和讲师的基本信息已经初始化为数据库中的学生表和讲师表。系统的主要功能如下: 提交作业。验证学生标识后,学生将电子作业通…

Ajax 数据库

文章目录 AJAX Database 实例AJAX 数据库实例实例解释 showCustomer() 函数AJAX 服务器页面 AJAX Database 实例 AJAX 可用来与数据库进行动态通信。 AJAX 数据库实例 下面的例子将演示网页如何通过 AJAX 从数据库读取信息: 请在下面的下拉列表中选择一个客户&…

【wpf】转换器 Converter

今天积攒了一个转换器的用法,分享给各位。 我们经常会有这种需求: 某些控件有时需要显示,有时需要隐藏,比如: 那,我就想通过一个bool变量和是否显示绑定。 但是我们知道,是否显示,…

nssctf web

[BJDCTF 2020]easy_md5 测试发现输入的内容会通过get传递但是没有其他内容 观察一下响应头 我们看到了 select * from admin where passwordmd5($pass,true) 这里的md5(string,raw) 这里的string应该清楚就是字符串,是必选参数 raw是可选参数 默认不写为FA…

08 - Linux终端与进程

---- 整理自狄泰软件唐佐林老师课程 查看所有文章链接:(更新中)Linux系统编程训练营 - 目录 文章目录 1. 详解控制台与终端1.1 问题1.2 历史回顾1.2.1 控制台(Console)1.2.2 终端(Terminal) 1.3 历史发展进程1.4 终端与…

【Springboot系列】项目启动时怎么给mongo表加自动过期索引

1、前言 在之前操作mongo的过程中,都是自动创建,几乎没有手动创建过表。 这次开发中有张表数据量大,并且不是特别重要,数据表维护一个常见的问题是过期数据没有被及时清理,导致数据量过大,查询变得缓慢。…

【操作系统】CPU 缓存一致性、MESI 协议

【操作系统】CPU 缓存一致性、MESI 协议 参考资料: CPU缓存一致性协议(MESI) 【JUC】Java并发机制的底层实现原理 CPU 缓存一致性 文章目录 【操作系统】CPU 缓存一致性、MESI 协议CPU Cache 的数据写入写直达写回 缓存一致性问题总线嗅探MESI 协议总结 CPU Cache …

电感知识大全

目录 一、电感的种类 1、共模电感 2、差模电感 3、工字电感 功率电感 4、磁珠 5、变压器 6、R棒电感、棒形电感、差模电感 二、电感符号 三、电感特性 前面在学习电容的时候,为了让大家更形象,更通俗的去理解这个元器件,都是拿水缸去…

IO多路复用——select函数

1.select函数原型和fd_set结构体说明 1.1 select函数原型 ​ 使用 select 这种 IO 多路转接方式需要调用一个同名函数 select,这个函数是跨平台的,Linux、Mac、Windows 都是支持的。程序员通过调用这个函数可以委托内核帮助我们检测若干个文件描述符的…

mybatisPlus·入门·贰

文章目录 1 简单CRUD接口1.1 根据id查询({id传参)1.1.1 接口类直接继承IService1.1.2 controller直接调用方法 1.2 根据ids查询1.3 新增1.3.1 接口类直接继承IService1.3.2 controller直接调用方法 1.4 修改状态(Query传参)1.4.1 …

GPT-4 IDEA神仙插件亲测帮助亿万用户解决痛点!

最近,Intellij IDEA的插件商店推出了一款新的插件——Bito,据说使用了GPT-4和ChatGPT来帮助开发人员编写代码,并且下载量已经达到了65K以上。 这款插件可以将GPT-4和ChatGPT引入IDE来大大提高开发人员的效率。它使用了OpenAI的模型&#xff0…

ESP32设备驱动-BMM150数字地磁传感器驱动

BMM150数字地磁传感器驱动 文章目录 BMM150数字地磁传感器驱动1、BMM150介绍2、硬件准备3、软件准备4、驱动实现1、BMM150介绍 BMM150 是一款低功耗、低噪声的 3 轴数字地磁传感器,用于罗盘应用。 具有 1.56 x 1.56 mm 和 0.60 mm 高度的 12 引脚晶圆级芯片级封装 (WLCSP) 为…

JavaEE 2(4/24)

目录 1.线程 2.前台线程和后台线程 3.run和start的区别 4.线程的终止 5.线程等待 6.获取当前线程的引用 1.线程 创建线程需要继承Thread方法 调用start方法就会生成一个新的线程,调用run方法会在老的线程继续跑 main也是个线程,他是自动调用的.线程休息了先唤醒main和thre…

【文章学习系列之模型】Informer

本章内容 文章概况总体结构重点结构self-attention distilling operation(自注意蒸馏操作)generative style decoder(生成式解码器)ProbSparse self-attention mechanism(概率稀疏自注意机制) 实验结果主要…