PyTorch中使用Transformer对一维序列进行分类的源代码

news2025/1/22 15:50:44

 在PyTorch中使用Transformer对一维序列进行分类是一种常见做法,尤其是在处理时间序列数据、自然语言处理等领域。Transformer模型因其并行化处理能力和自注意力机制而成为许多任务的首选模型。

下面是一个使用PyTorch实现Transformer对一维序列进行分类的完整示例代码,包括数据预处理、模型定义、训练和评估等部分。

1. 准备工作

首先,我们需要导入所需的库,并定义一些基本的参数。

1import torch
2import torch.nn as nn
3import torch.optim as optim
4from torch.utils.data import Dataset, DataLoader
5from sklearn.model_selection import train_test_split
6from sklearn.preprocessing import StandardScaler
7import numpy as np
8
9# 定义超参数
10input_dim = 10  # 序列的维度
11seq_length = 50  # 序列长度
12hidden_dim = 128  # Transformer编码器隐藏层维度
13num_heads = 8  # 多头注意力机制中的头数
14num_layers = 6  # 编码器层数
15dropout = 0.1  # dropout概率
16num_classes = 3  # 分类类别数
17batch_size = 32
18num_epochs = 100
19learning_rate = 0.001
2. 数据预处理

假设我们有一组一维序列数据,我们将对其进行预处理,并将其划分为训练集和测试集。

1# 生成模拟数据
2def generate_data(n_samples, seq_length, input_dim, num_classes):
3    X = np.random.randn(n_samples, seq_length, input_dim)
4    y = np.random.randint(0, num_classes, size=(n_samples,))
5    return X, y
6
7# 生成数据
8X, y = generate_data(1000, seq_length, input_dim, num_classes)
9
10# 数据标准化
11scaler = StandardScaler()
12X = scaler.fit_transform(X.reshape(-1, input_dim)).reshape(X.shape)
13
14# 划分数据集
15X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
16
17# 定义数据集类
18class SequenceDataset(Dataset):
19    def __init__(self, X, y):
20        self.X = X
21        self.y = y
22
23    def __len__(self):
24        return len(self.y)
25
26    def __getitem__(self, idx):
27        return torch.tensor(self.X[idx], dtype=torch.float32), torch.tensor(self.y[idx], dtype=torch.long)
28
29train_dataset = SequenceDataset(X_train, y_train)
30test_dataset = SequenceDataset(X_test, y_test)
31
32train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
33test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
3. 定义Transformer模型

接下来定义一个基于Transformer的模型,该模型包含嵌入层、位置编码、多头自注意力机制和前馈神经网络等组件。

 
1class PositionalEncoding(nn.Module):
2    def __init__(self, d_model, max_len=5000):
3        super(PositionalEncoding, self).__init__()
4        pe = torch.zeros(max_len, d_model)
5        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
6        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
7        pe[:, 0::2] = torch.sin(position * div_term)
8        pe[:, 1::2] = torch.cos(position * div_term)
9        pe = pe.unsqueeze(0).transpose(0, 1)
10        self.register_buffer('pe', pe)
11
12    def forward(self, x):
13        return x + self.pe[:x.size(0), :]
14
15class TransformerClassifier(nn.Module):
16    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, num_classes, dropout):
17        super(TransformerClassifier, self).__init__()
18        self.embedding = nn.Linear(input_dim, hidden_dim)
19        self.positional_encoding = PositionalEncoding(hidden_dim)
20        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, dropout=dropout)
21        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
22        self.classifier = nn.Linear(hidden_dim, num_classes)
23        self.dropout = nn.Dropout(dropout)
24
25    def forward(self, src):
26        embedded = self.embedding(src) * np.sqrt(hidden_dim)
27        encoded = self.positional_encoding(embedded)
28        output = self.transformer_encoder(encoded)
29        output = output.mean(dim=0)  # 平均池化
30        output = self.classifier(output)
31        return output
32
33model = TransformerClassifier(input_dim, hidden_dim, num_heads, num_layers, num_classes, dropout)
4. 训练模型

定义损失函数、优化器,并进行模型训练。

1criterion = nn.CrossEntropyLoss()
2optimizer = optim.Adam(model.parameters(), lr=learning_rate)
3
4def train(model, data_loader, criterion, optimizer, device):
5    model.train()
6    total_loss = 0.0
7    for batch_idx, (data, target) in enumerate(data_loader):
8        data, target = data.to(device), target.to(device)
9        optimizer.zero_grad()
10        output = model(data.permute(1, 0, 2))  # 调整数据维度为 (seq_len, batch_size, input_dim)
11        loss = criterion(output, target)
12        loss.backward()
13        optimizer.step()
14        total_loss += loss.item()
15    return total_loss / (batch_idx + 1)
16
17def evaluate(model, data_loader, criterion, device):
18    model.eval()
19    total_loss = 0.0
20    correct = 0
21    with torch.no_grad():
22        for data, target in data_loader:
23            data, target = data.to(device), target.to(device)
24            output = model(data.permute(1, 0, 2))
25            loss = criterion(output, target)
26            total_loss += loss.item()
27            pred = output.argmax(dim=1, keepdim=True)
28            correct += pred.eq(target.view_as(pred)).sum().item()
29    accuracy = correct / len(data_loader.dataset)
30    return total_loss / len(data_loader), accuracy
31
32device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33model.to(device)
34
35for epoch in range(num_epochs):
36    train_loss = train(model, train_loader, criterion, optimizer, device)
37    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
38    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
5. 总结

以上代码实现了一个使用Transformer对一维序列进行分类的完整流程,包括数据预处理、模型定义、训练和评估。该模型适用于处理时间序列数据或其他一维序列数据的分类任务。通过调整超参数和网络结构,可以进一步优化模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2095203.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ECC密码与RSA

一、ECC密码(椭圆曲线密码) 1.基本知识 定义: ECC 全称为椭圆曲线加密,EllipseCurve Cryptography,是一种基于椭圆曲线数学的公钥密码。与传统的基于大质数因子分解困难性的加密方法不同,ECC 依赖于解决椭圆…

@antv/x6 要求不显示水平滚动条,并且如果水平方向上显示不全的节点,则要求自动显示全部节点,垂直方向可以出现滚动条来滚动显示所有的节点。

1、要求一共有二个: 要求一:水平滚动条不显示。之前的文章中就已经发表过,可以用Scroller的className来处理。要求二:水平方向上显示全部节点,如果有显示不全的节点(即看不到的节点)要求能够显示…

asp.net实验:数据库写入不成功

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

国庆节微信头像怎么制作?制作国庆国旗节日头像的4个方法

国庆将至,不少朋友的微信头像都换成了渐变红旗头像,是不是觉得超酷呢?如果你也想拥有这样的头像,那就跟着这篇文章一起操作吧! 国庆节前夕,让我们先来了解一下如何制作渐变红旗头像。首先,我们需…

基于Python的Flask框架实战全流程从新建到部署【2】

本项目是基于win10系统运行以及操作的,部署在win7系统。 Flask 是一个轻量级的可定制框架,使用Python语言编写,较其他同类型框架更为灵活、轻便、安全且容易上手。 本文是flask框架实战项目,从新建、运行、测试、部署项目…

C/C++的内存分布、动态内存管理等的介绍

文章目录 前言一、C/C的内存分布二、C/C动态内存管理总结 前言 C/C的内存分布、动态内存管理等的介绍 一、C/C的内存分布 因为程序在运行过程中需要存储一些不同的数据,所以需要对内存空间进行分类 二、C/C动态内存管理 C语言动态内存管理是malloc / calloc / rea…

光降解水凝胶:三色光响应

大家好,今天来了解一种三色可见光波长选择性光降解水凝胶生物材料——《Tricolor visible wavelength-selective photodegradable hydrogel biomaterials》发表于《Nature Communications》,其交联剂Rubiq、Rubp和oNB对低能可见光(400 - 617n…

洛科威岩棉板重塑屋面应用,以多重优势成为“优选材料”

屋面作为建筑物的“外衣”,不仅承载着遮风挡雨的基本功能,更在保温隔热、防火安全、防潮隔音等方面发挥着举足轻重的作用。然而,面对极端气候、自然灾害以及日益严苛的环保标准,传统屋面材料逐渐暴露出其局限性,保温效…

JVM垃圾判定算法

垃圾收集技术是Java的一堵高墙。Java堆内存中存放着几乎所有的对象实例,垃圾收集器在对堆内存进行回收前,第一件事情就是要确定这些对象中哪些还存活,哪些已经死去(即不可能再被任何途径使用的对象)。也就是判定垃圾。…

STM32 使用8720 通过LWIP发送数据

一、硬件IOC 1、GPIO 2、NVIC 3、SYS 4、RCC 5、ETH 6、USART 7、LWIP 二、软件函数 1、Main /* USER CODE BEGIN Includes */ #include "ytcesys.h" /* USER CODE END Includes *//* USER CODE BEGIN 2 */ ethreset(); MX_LWIP_Init(); OPEN_USART1…

chunqiude

CVE-2022-28512 靶标介绍: Fantastic Blog (CMS)是一个绝对出色的博客/文章网络内容管理系统。它使您可以轻松地管理您的网站或博客,它为您提供了广泛的功能来定制您的博客以满足您的需求。它具有强大的功能,您无需接触任何代码即可启动并运…

【Java开发】Maven安装配置详细教程

原创文章,不得转载。 文章目录 产生背景用途安装配置本地仓库配置镜像 产生背景 在Java应用程序开发中,随着项目规模的不断扩大和复杂性增加,项目依赖的库、插件和配置文件也变得愈加复杂。传统的项目构建工具(如Ant)…

简单选择排序例题

从上面题目看出,如果排序方法可保证在排序前后排序码相同的相对位置不变,也就是四个选项里,21和21*之间不会交换 简单选择排序方法是:首先在所有记录中找到排序吗最小的记录,把它与第一个记录交换,然后在其…

EXO:模型最终验证的地方;infer_tensor;step;MLXDynamicShardInferenceEngine

目录 EXO:模型最终验证的地方 EXO:infer_tensor EXO:step MXNet的 mx.array 类型是什么 NDArray优化了什么 1. 异步计算和内存优化 2. 高效的数学和线性代数运算 3. 稀疏数据支持 4. 自动化求导 举例说明 EXO:模型最终验证的地方 EXO:infer_tensor 这段代码定…

【科技前沿】用深度强化学习优化电网,让电力调度更聪明!

Hey小伙伴们,今天我要跟大家分享一个超级酷炫的技术应用——深度强化学习在电网优化中的典型案例!如果你对机器学习感兴趣,或是正寻找如何用AI技术解决实际问题的方法,这篇分享绝对不容错过!👩‍&#x1f4…

Pyqt5高级技巧2:Tab顺序、伙伴快捷键、各类常用控件的事件、可移动的卡片式布局(含基础Demo)

目录 一、编辑Tab顺序 二、编辑伙伴 三、设置快捷键(仅MainWindow可用) 四、信号槽 【基本介绍】 【常用信号槽】控件对窗体(拖地) 【常用信号槽】控件对控件 【自定义信号槽】步骤 五、设计文件的转化 六、GUI的运行 1…

【研发日记】吃透新能源充电协议(一)——GB27930实例报文解析

文章目录 前言 背景介绍 充电协议框架 充电握手阶段 充电准备阶段 充电传输阶段 充电结束阶段 错误处理阶段 总结 参考资料 前言 近期在一个嵌入式开发项目中,用到了新能源充电协议,期间在翻阅各种资料文件时,一些地方还是容易理解…

包装和类练习 Stack的使用

目录 1.最小栈 2.有效的括号 3.栈的压入、弹出序列 4.逆波兰表达式求值 5.链栈与顺序栈相比&#xff0c;比较明显的优点是&#xff08; &#xff09; 1.最小栈 2.有效的括号 class Solution {public boolean isValid(String s) {Stack<Character> st new Stack<&g…

I/O方式

目录 一、程序查询方式 1.程序查询方式的特点 2.程序查询方式可分类 ①独占查询 ②定时查询 二、中断方式 1.中断I/O流程 2.例题 三、DMA方式 1.DMA控制器 2.特点 3. DMA的传送方式 ①停止CPU ②周期挪用 ③DMA和CPU交替访存 4.传送流程 ①预处理 ②数据传…

AIGC时代算法工程师的面试秘籍(第二十一式2024.8.19-9.1) |【三年面试五年模拟】

写在前面 【三年面试五年模拟】旨在整理&挖掘AI算法工程师在实习/校招/社招时所需的干货知识点与面试经验&#xff0c;力求让读者在获得心仪offer的同时&#xff0c;增强技术基本面。也欢迎大家提出宝贵的优化建议&#xff0c;一起交流学习&#x1f4aa; 欢迎大家关注Rocky…