Python使用pytorch深度学习框架构造Transformer神经网络模型预测红酒分类例子

news2024/12/23 15:47:46

1、红酒数据介绍

经典的红酒分类数据集是指UCI机器学习库中的Wine数据集。该数据集包含178个样本,每个样本有13个特征,可以用于分类任务。

具体每个字段的含义如下:
alcohol:酒精含量百分比
malic_acid:苹果酸含量(克/升)
ash:灰分含量(克/升)
alcalinity_of_ash:灰分碱度(以mEq/L为单位)
magnesium:镁含量(毫克/升)
total_phenols:总酚含量(以毫克/升为单位)
flavanoids:类黄酮含量(以毫克/升为单位)
nonflavanoid_phenols:非类黄酮酚含量(以毫克/升为单位)
proanthocyanins:原花青素含量(以毫克/升为单位)
color_intensity:颜色强度(以 absorbance 为单位,对应于 1cm 路径长度处的相对宽度)
hue:色调,即色彩的倾向性或相似性(在 1 至 10 之间的一个数字)
od280/od315_of_diluted_wines:稀释葡萄酒样品的光密度比值,用于测量葡萄酒中各种化合物的浓度
proline:脯氨酸含量(以毫克/升为单位),是一种天然氨基酸,与葡萄酒的品质和口感有关。

2、红酒数据集分析

2.1 加载红酒数据集

# 加载红酒数据集
wineBunch = load_wine()
type(wineBunch)

sklearn.utils.Bunch
sklearn.utils.Bunch是Scikit-learn库中的一个数据容器,类似于Python字典(dictionary),
它可以存储任意数量和类型的数据,并且可以使用点(.)操作符来访问数据。Bunch常用于存储机器学习模型的数据集,
例如描述特征矩阵的数据、相关联的目标向量、特征名称等等,以便于组织和传递这些数据到模型中进行训练或预测。

2.2 红酒数据集形状

len(wineBunch.data),len(wineBunch.target)

(178, 178)

2.3 红酒数据集打印前5行和后5行

featuresDf = pd.DataFrame(data=wineBunch.data, columns=wineBunch.feature_names)   # 特征数据
labelDf = pd.DataFrame(data=wineBunch.target, columns=["target"])               # 标签数据
wineDf = pd.concat([featuresDf, labelDf], axis=1)  # 横向拼接
wineDf.head(5).append(wineDf.tail(5))              # 打印首尾5行

在这里插入图片描述

2.4 红酒数据集列名

wineDf.columns

Index([‘alcohol’, ‘malic_acid’, ‘ash’, ‘alcalinity_of_ash’, ‘magnesium’,
‘total_phenols’, ‘flavanoids’, ‘nonflavanoid_phenols’,
‘proanthocyanins’, ‘color_intensity’, ‘hue’,
‘od280/od315_of_diluted_wines’, ‘proline’, ‘target’],
dtype=‘object’)

2.5 红酒数据集目标标签

print(wineDf.target.unique())
[0 1 2]

3、Transformer对红酒进行分类

3.1 Transformer介绍

Transformer是一种基于注意力机制的神经网络结构,主要用于自然语言处理领域中的序列到序列转换任务,比如机器翻译、文本摘要等。它在2017年被Google提出,并被成功应用于Google Translate中。

Transformer的主要特点在于使用了完全基于注意力机制的编码器-解码器结构,避免了传统循环神经网络(如LSTM)中存在的长序列依赖问题和梯度消失问题。此外,Transformer还使用了残差连接和层归一化等技术,增强了模型的训练效果和泛化能力。

在Transformer模型中,输入序列和输出序列都被表示为固定长度的向量,称为词向量,由多个嵌入层和多个编码器和解码器层组成。其中,编码器和解码器层包括多头注意力机制、前馈神经网络和残差连接等模块,以实现对序列的有效建模和转换。

3.2 引入依赖库

import random         # 导入 random 模块,用于随机数生成
import torch          # 导入 PyTorch 模块,用于深度学习任务
import numpy as np    # 导入 numpy 模块,用于数值计算
from torch import nn  # 从 PyTorch 中导入神经网络模块
from sklearn import datasets  # 从sklearn引入数据集
from sklearn.model_selection import train_test_split  # 导入 sklearn 库中的 train_test_split 函数,用于数据划分
from sklearn.preprocessing import StandardScaler     # 导入 sklearn 库中的 StandardScaler 类,用于数据标准化

3.3 设置随机种子

# 设置随机种子,让模型每次输出的结果都一样
seed_value = 42
random.seed(seed_value)                         # 设置 random 模块的随机种子
np.random.seed(seed_value)                      # 设置 numpy 模块的随机种子
torch.manual_seed(seed_value)                   # 设置 PyTorch 中 CPU 的随机种子
#tf.random.set_seed(seed_value)                 # 设置 Tensorflow 中随机种子
if torch.cuda.is_available():                   # 如果可以使用 CUDA,设置随机种子
    torch.cuda.manual_seed(seed_value)          # 设置 PyTorch 中 GPU 的随机种子
    torch.backends.cudnn.deterministic = True   # 使用确定性算法,使每次运行结果一样
    torch.backends.cudnn.benchmark = False      # 不使用自动寻找最优算法加速运算

3.4 检测GPU是否可用

# 检测GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3.5 加载数据集

# 加载红酒数据集
wine = datasets.load_wine()
X = wine.data
y = wine.target

3.6 拆分训练集和测试集

# 拆分成训练集和测试集,训练集80%和测试集20%
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

3.7 缩放数据

# 缩放数据
scaler = StandardScaler() # 创建一个标准化转换器的实例
X_train = scaler.fit_transform(X_train) # 对训练集进行拟合(计算平均值和标准差)
X_test = scaler.transform(X_test) # 对测试集进行标准化转换,使用与训练集相同的平均值和标准差

3.8 转化成pytorch张量

# 将训练集转换为 PyTorch 张量,并转换为浮点数类型,如果 GPU 可用,则将张量移动到 GPU 上
X_train = torch.tensor(X_train).float().to(device)
# 将训练集转换为 PyTorch 张量,并转换为长整型,如果 GPU 可用,则将张量移动到 GPU 上
y_train = torch.tensor(y_train).long().to(device)
X_test = torch.tensor(X_test).float().to(device)
y_test = torch.tensor(y_test).long().to(device)

3.9 定义Transformer模型

定义 Transformer 模型

class TransformerModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(TransformerModel, self).__init__()
        # 构建Transformer编码层,参数包括输入维度、注意力头数
        # 其中d_model要和模型输入维度相同
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size,  # 输入维度
                                                        nhead=1)             # 注意力头数
        # 构建Transformer编码器,参数包括编码层和层数
        self.encoder = nn.TransformerEncoder(self.encoder_layer,             # 编码层
                                             num_layers=1)                   # 层数
        # 构建线性层,参数包括输入维度和输出维度(num_classes)
        self.fc = nn.Linear(input_size,                                      # 输入维度
                            num_classes)                                     # 输出维度

    def forward(self, x):
        #print("A:", x.shape)  # torch.Size([142, 13])
        x = x.unsqueeze(1)    # 增加一个维度,变成(batch_size, 1, input_size)的形状
        #print("B:", x.shape)  # torch.Size([142, 1, 13])
        x = self.encoder(x)   # 输入Transformer编码器进行编码
        #print("C:", x.shape)  # torch.Size([142, 1, 13])
        x = x.squeeze(1)      # 压缩第1维,变成(batch_size, input_size)的形状
        #print("D:", x.shape)  # torch.Size([142, 13])
        x = self.fc(x)        # 输入线性层进行分类预测
        #print("E:", x.shape)  # torch.Size([142, 3])
        return x
# 初始化Transformer模型,并移动到GPU
model = TransformerModel(input_size=13,             # 输入维度
                         num_classes=3).to(device)  # 输出维度

3.10 定义损失函数和优化器

定义损失函数和优化器

criterion = nn.CrossEntropyLoss() # 定义损失函数-交叉熵损失函数

定义优化器

optimizer = torch.optim.Adam(model.parameters(), # 模型参数
lr=0.01) # 学习率

3.11 训练模型

# 训练模型
num_epochs = 100     # 训练100轮
for epoch in range(num_epochs):
    # 正向传播:将训练数据放到模型中,得到模型的输出
    outputs = model(X_train)
    loss = criterion(outputs, y_train)  # 计算交叉熵损失

    # 反向传播和优化:清零梯度、反向传播计算梯度,并根据梯度更新模型参数
    optimizer.zero_grad()  # 清零梯度
    loss.backward()        # 反向传播计算梯度
    optimizer.step()       # 根据梯度更新模型参数

    # 每10轮打印一次损失值,查看模型训练的效果
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

3.12 测试模型

# 测试模型,在没有梯度更新的情况下,对测试集进行推断
with torch.no_grad():
    outputs = model(X_test)   # 使用训练好的模型对测试集进行预测
    _, predicted = torch.max(outputs.data, 1)  # 对输出的结果取 argmax,得到预测概率最大的类别
    accuracy = (predicted == y_test).sum().item() / y_test.size(0)  # 计算模型在测试集上的准确率
    print(f'Test Accuracy: {accuracy:.2f}')   # 打印测试集准确率

3.13 控制输出

Epoch [10/100], Loss: 0.1346
Epoch [20/100], Loss: 0.0325
Epoch [30/100], Loss: 0.0116
Epoch [40/100], Loss: 0.0064
Epoch [50/100], Loss: 0.0040
Epoch [60/100], Loss: 0.0029
Epoch [70/100], Loss: 0.0026
Epoch [80/100], Loss: 0.0021
Epoch [90/100], Loss: 0.0019
Epoch [100/100], Loss: 0.0019
Test Accuracy: 1.00

Process finished with exit code 0

正确率:100%

3.14 保存模型

# 保存模型
PATH = "model.pt"
torch.save(model.state_dict(), PATH)

3.15 加载模型

加载模型

model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/479700.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python之硬汉巴特勒

一、前言 2023年4月27日,NBA季后赛热火4:1淘汰雄鹿,实现黑八。全NBA联盟最硬气的男人——巴特勒,再次向全世界证明了他是NBA最硬气的男人。上一场刚狂轰56分大比分逆转雄鹿,这一场又是带领球队打出了血性,超高难度绝平…

快速搭建简单图床 - 远程访问本地搭建的EasyImage图床【内网穿透】

文章目录 1.前言2. EasyImage网站搭建2.1. EasyImage下载和安装2.2. EasyImage网页测试2.3.cpolar的安装和注册 3.本地网页发布3.1.Cpolar云端设置3.2 Cpolar本地设置 4. 公网访问测试5. 结语 1.前言 一个好的图床,是网站或者文章图片能稳定显示的关键,…

驱动管理软件推荐

最近发现电脑右下角的任务栏中有一个叹号图标,如下: 点进去之后发现是Windows自家的安全中心的内核隔离出现了点问题,内核隔离功能打不开 点击“查看不兼容的驱动程序”,发现是一些驱动作祟 我的电脑中显示了好多不兼容的驱动程序…

跟着我学习 AI丨语音识别:将语音转为数字信号

语音识别是一种人工智能技术,其主要目的是将人类说话转化为计算机可以理解的信息。语音识别技术的应用非常广泛,包括智能家居、汽车导航、语音搜索、人机交互、语音翻译等。 语音识别的技术原理 语音识别的技术原理是将人类的语音信号转化为数字信号。这…

『python爬虫』06. 数据解析之re正则解析(保姆级图文)

目录 1. 什么是re解析2. 正则规则元字符量词匹配模式测试 3. 正则案例4. re模块的使用4.1 findall: 匹配字符串中所有的符合正则的内容4.2 finditer: 匹配字符串中所有的内容[返回的是迭代器]4.3 search, 找到一个结果就返回, 返回的结果是match对象4.4 match 从头开始匹配&…

Windows forfiles命令详解,Windows按时间搜索特定类型的文件。

「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:对网络安全感兴趣的小伙伴可以关注专栏《网络安全入门到精通》 forfiles 一、结果输出格式二、按时间搜索三、搜索指定类型文件四、批量删除文件 forfile…

Ubuntu远程SSH连接与远程桌面连接

目录 一、远程桌面连接 二、远程SSH连接 1、安装客户端 2、安装服务端 3、SSH客户端和服务端的区别 一、远程桌面连接 首先需要在Ubuntu里进行些设置,点击界面右上角的控制区,选择设置选项; 弹出界面进入网络中,点击设置图…

【致敬未来的攻城狮计划】— 连续打卡第十八天:FSP固件库开发GPT — PWM输出波形 — LED呼吸灯

系列文章目录 1.连续打卡第一天:提前对CPK_RA2E1是瑞萨RA系列开发板的初体验,了解一下 2.开发环境的选择和调试(从零开始,加油) 3.欲速则不达,今天是对RA2E1 基础知识的补充学习。 4.e2 studio 使用教程 5.…

Rust开发环境搭建到运行第一个程序HelloRust

一、Rust语言 1.1 Rust语言介绍 Rust 语言是一种高效、可靠的通用高级语言。其高效不仅限于开发效率,它的执行效率也是令人称赞的,是一种少有的兼顾开发效率和执行效率的语言。 Rust 语言由 Mozilla 开发,最早发布于2014年 9月。Rust 的编…

《ADC和DAC的基本架构》----学习记录(二)

2 模数转换器 2.1 ADC架构I:Flash转换器 2.1.1 比较器:1位ADC 转换开关是 1 位 DAC,而比较器是 1 位 ADC,如图所示。如果输入超过阈值,输出即会具有一个逻辑值,而输入低于阈值时输出又会有另一个值。此外…

寻找2020+跳蚱蜢(蓝桥杯JAVA解法)

目录 寻找2020:用户登录 题目描述 运行限制 跳蚱蜢:用户登录 题目描述 运行限制 寻找2020:用户登录 题目描述 本题为填空题,只需要算出结果后,在代码中使用输出语句将所填结果输出即可。 小蓝有一个数字矩阵&a…

使用ChatGPT生成了十种排序算法

前言 当前ChatGPT非常火爆,对于程序员来说,ChatGPT可以帮助编写很多有用的代码。比如:在算法的实现上,就可以替我们省很多事。所以,小试牛刀一下,看看ChatGPT生成了排序算法怎么样? 简介 排序…

网站搭建之配置tomcat

【 本次配置架构】 【安全配置】 1.删除后台登录 在tomcat安装目录下的/conf文件下编辑tomcat-users.xml,删除里面带有标签的内容块,默认这部分是被注释了的。注释了任然会显示后台登录,需要彻底删除。 进入末行模式,也就是使用vim进去后&…

Flask开发之环境搭建

目录 1、安装flask 2、创建Flask工程 ​编辑 3、初始化效果 4、运行效果 5、设置Debug模式 6、设置Host 7、设置Port 8、在app.config中添加配置 1、安装flask 如果电脑上从没有安装过flask,则在命令行界面输入以下命令: pip install flask 如果电…

【MFAC】基于偏格式动态线性化的无模型自适应控制(Matlab代码)

例题来源:侯忠生教授的《无模型自适应控制:理论与应用》(2013年科学出版社)。 👉对应书本 4.3 单输入单输出系统(SISO)偏格式动态线性化(PFDL)的无模型自适应控制(MFAC) 上一篇博客介绍了基于紧格式动态线性化的无模型…

C++每日一练:打家劫室(详解动态规划法)

文章目录 前言一、题目二、分析三、代码总结 前言 这题目出得很有意思哈,打劫也是很有技术含量滴!不会点算法打劫这么粗暴的工作都干不好。 提示:以下是本篇文章正文内容,下面案例可供参考 一、题目 题目名称: 打家…

实现Newton方法的最小化函数(pytorch)

首先,我们要明确需求 def newton(theta, f, tol 1e-8, fscale1.0, maxit 100, max_half 20) ● theta是优化参数的初始值的一个向量。 ● f是要最小化的目标函数。该函数应将PyTorch张量作为输入,并返回一个张量。 ● tol是收敛容忍度。 ● fscale 粗…

【Leetcode -328.奇偶链表 - 725.分隔链表】

Leetcode Leetcode -328.奇偶链表Leetcode - 725.分隔链表 Leetcode -328.奇偶链表 题目:给定单链表的头节点 head ,将所有索引为奇数的节点和索引为偶数的节点分别组合在一起,然后返回重新排序的列表。 第一个节点的索引被认为是 奇数 &am…

苏州百特电器有限公司网站设计

苏州百特电器有限公司网站设计 五一假期作业企业门户网站布局设计 基于 <div> 的企业门户网站设计 by 小喾苦 我这里仅仅是使用 html css 来实现这个网站的效果&#xff0c;并不是宣传这个网站(现在这个网站已经过时并且无法进入) 实现效果 https://xkk1.github.io/…

出差在外,远程访问企业局域网象过河ERP系统「内网穿透」

文章目录 概述1.查看象过河服务端端口2.内网穿透3. 异地公网连接4. 固定公网地址4.1 保留一个固定TCP地址4.2 配置固定TCP地址 5. 使用固定地址连接 转载自远程穿透文章&#xff1a;公网远程访问公司内网象过河ERP系统「内网穿透」 概述 ERP系统对于企业来说重要性不言而喻&am…