Transformer模型-数据预处理,训练,推理(预测)的简明介绍

news2025/3/1 14:03:41

Transformer模型-数据预处理,训练,推理(预测)的简明介绍

在继续探讨之前,假定已经对各个模块的功能有了充分的了解:

人工智能AI 虚拟现实VR 黑客帝国_Ankie(资深技术项目经理)的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/ank1983/category_12546474.html

数据预处理

将德语翻译成英语,本文将使用torchtext.datasets中的Multi30k数据集。它包含训练集、验证集和测试集。所有用于加载分词器、生成词汇表、处理数据和生成批次的自定义函数都可以在附录中找到。

第一步是从spaCy加载每种语言的分词器,并使用load_vocab为两种语言创建词汇表。它调用build_vocabary,这是一个自定义函数,它使用torchtext.vocab中的build_vocab_from_iterator函数。词汇表中单词出现的最小频率为2,并且词汇表中的每个单词均为小写。build_vocabulary函数加载Multi30k数据集以生成词汇表。

生成词汇表后,可以设定一些全局变量,这些变量将以大写字母表示。以下变量分别代表“<bos>”、“<eos>”和“<pad>”的索引,这些索引对于源语言和目标语言的词汇表都是相同的。

BOS_IDX = vocab_trg['<bos>']
EOS_IDX = vocab_trg['<eos>']
PAD_IDX = vocab_trg['<pad>']

接下来可以加载数据集进行处理。

# raw data
train_data_raw, val_data_raw, test_data_raw = datasets.Multi30k(language_pair=("de", "en"))

每个集合都是一个数据迭代器,可以看作是一系列元组的列表。每个元组包含一个德语-英语对,例如(“Wie heißt du?”,“What is your name?”)。这些数据可以根据词汇表进行分词并转换为相应的索引。这些操作在自定义函数data_process中执行。

# processed data
train_data = data_process(train_data_raw)
val_data = data_process(val_data_raw)
test_data = data_process(test_data_raw)

现在,这些数据迭代器可以传递给torch.utils.data中的DataLoader,用于在训练期间生成批次。DataLoader需要数据迭代器、批次大小以及用于自定义批次的collate函数。它还允许打乱批次,并在最后一批不是完整批次时将其丢弃。需要提醒的是,批次大小是每个优化步骤中使用的序列数。

在下面的代码中,MAX_PADDING表示序列可以拥有的最大令牌数。torch.nn.functional中的pad函数会截断任何长于该值的序列,否则添加填充。这由generate_batch函数使用,该函数向序列中添加“<bos>”、“<eos>”和“<pad>”令牌,并生成用于训练的批次。在创建每个DataLoader时,数据迭代器被转换为映射风格的数据集,因为它们可以轻松地被打乱,并且可以根据需要提供其大小。

MAX_PADDING = 20
BATCH_SIZE = 128

train_iter = DataLoader(to_map_style_dataset(train_data), batch_size=BATCH_SIZE,
shuffle=True, drop_last=True, collate_fn=generate_batch)

valid_iter = DataLoader(to_map_style_dataset(val_data), batch_size=BATCH_SIZE,
shuffle=True, drop_last=True, collate_fn=generate_batch)

test_iter = DataLoader(to_map_style_dataset(test_data), batch_size=BATCH_SIZE,
shuffle=True, drop_last=True, collate_fn=generate_batch)

创建模型

下一步是创建用于训练数据的模型。可以使用make_model函数传递参数来创建模型,如果GPU可用,可以使用model.cuda()来确保模型将在GPU上进行训练。这些值是凭经验选择的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = make_model(device, vocab_src, vocab_trg,
n_layers=3, n_heads=8, d_model=256,
d_ffn=512, max_length=50)
model.cuda()

此外,还可以预览模型的总可训练参数,以评估其大小。

创建训练函数

为了训练模型,可以使用学习率为0.0005的Adam优化器,以及Cross Entropy Loss作为损失函数。Cross Entropy Loss接受模型输出的logits作为输入,通过softmax函数进行转换,取每个令牌的argmax,并将其与预期的目标输出进行比较。

LEARNING_RATE = 0.0005

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)

可以使用以下函数来训练模型,这些步骤在每个训练周期中执行。模型根据损失计算logits并更新参数。最后,该函数返回周期中批次的平均损失。请注意,logits和预期输出被重塑为单个序列,而不是单独的序列。对于logits,给定(3, 10, 27),表示由27个元素向量表示的十个令牌的三个序列,新形状将为(30, 27),即一个长序列。当执行argmax时,输出是一个30个元素的向量。预期输出,其形状为(3,10),也可以被重塑为30个元素的向量,然后这两个向量可以很容易地进行比较。

def train(model, iterator, optimizer, criterion, clip):
"""
Train the model on the given data.

Args:
model: Transformer model to be trained
iterator: data to be trained on
optimizer: optimizer for updating parameters
criterion: loss function for updating parameters
clip: value to help prevent exploding gradients

Returns:
loss for the epoch
"""

# set the model to training mode
model.train()

epoch_loss = 0

# loop through each batch in the iterator
for i, batch in enumerate(iterator):

# set the source and target batches
src,trg = batch

# zero the gradients
optimizer.zero_grad()

# logits for each output
logits = model(src, trg[:,:-1])

# expected output
expected_output = trg[:,1:]

# calculate the loss
loss = criterion(logits.contiguous().view(-1, logits.shape[-1]),
expected_output.contiguous().view(-1))

# backpropagation
loss.backward()

# clip the weights
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

# update the weights
optimizer.step()

# update the loss
epoch_loss += loss.item()

# return the average loss for the epoch
return epoch_loss / len(iterator)

下面的评估函数执行与训练函数相同的流程,但不会更新权重。这将在测试集和验证集上使用,以查看模型的泛化能力。

def evaluate(model, iterator, criterion):
"""
Evaluate the model on the given data.

Args:
model: Transformer model to be trained
iterator: data to be evaluated
criterion: loss function for assessing outputs

Returns:
loss for the data
"""

# set the model to evaluation mode
model.eval()

epoch_loss = 0

# evaluate without updating gradients
with torch.no_grad():

# loop through each batch in the iterator
for i, batch in enumerate(iterator):

# set the source and target batches
src, trg = batch


# logits for each output
logits = model(src, trg[:,:-1])

# expected output
expected_output = trg[:,1:]

# calculate the loss
loss = criterion(logits.contiguous().view(-1, logits.shape[-1]),
expected_output.contiguous().view(-1))

# update the loss
epoch_loss += loss.item()

# return the average loss for the epoch
return epoch_loss / len(iterator)

最后,可以创建一个函数来计算每个周期所需的时间。

def epoch_time(start_time, end_time):
elapsed_time = end_time - start_time
elapsed_mins = int(elapsed_time / 60)
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
return elapsed_mins, elapsed_secs

训练模型

现在可以创建训练循环来训练模型,并评估其在验证集上的性能。

N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')

# loop through each epoch
for epoch in range(N_EPOCHS):

start_time = time.time()

# calculate the train loss and update the parameters
train_loss = train(model, train_iter, optimizer, criterion, CLIP)

# calculate the loss on the validation set
valid_loss = evaluate(model, valid_iter, criterion)

end_time = time.time()

# calculate how long the epoch took
epoch_mins, epoch_secs = epoch_time(start_time, end_time)

# save the model when it performs better than the previous run
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), 'transformer-model.pt')

print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')

输出结果:

Epoch: 01 | Time: 0m 21s
Train Loss: 4.534 | Train PPL: 93.169
Val. Loss: 3.474 | Val. PPL: 32.280

Epoch: 05 | Time: 0m 13s
Train Loss: 1.801 | Train PPL: 6.055
Val. Loss: 1.829 | Val. PPL: 6.229

Epoch: 10 | Time: 0m 13s
Train Loss: 1.093 | Train PPL: 2.984
Val. Loss: 1.677 | Val. PPL: 5.351

在评估结果之前,还可以使用评估函数在测试集上评估模型的准确性。

虽然损失已经显著下降,但并没有表明模型在德语到英语的翻译任务上有多成功。这可以通过两种方式评估。第一种是提供一个句子并在推理过程中预览其翻译。第二种是通过另一个指标(如BLEU)来计算其准确性,BLEU是翻译任务的标准指标。

推理(预测)

通过将句子传递给下面的函数,可以执行实时翻译。句子将被分词并传递给模型,一次生成一个令牌。一旦出现“<eos>”令牌,就会返回输出。

def translate_sentence(sentence, model, device, max_length = 50):
"""
Translate a German sentence to its English equivalent.

Args:
sentence: German sentence to be translated to English; list or str
model: Transformer model used for translation
device: device to perform translation on
max_length: maximum token length for translation

Returns:
src: return the tokenized input
trg_input: return the input to the decoder before the final output
trg_output: return the final translation, shifted right
attn_probs: return the attention scores for the decoder heads
masked_attn_probs: return the masked attention scores for the decoder heads
"""

model.eval()

# tokenize and index the provided string
if isinstance(sentence, str):
src = ['<bos>'] + [token.text.lower() for token in spacy_de(sentence)] + ['<eos>']
else:
src = ['<bos>'] + sentence + ['<eos>']

# convert to integers
src_indexes = [vocab_src[token] for token in src]

# convert list to tensor
src_tensor = torch.tensor(src_indexes).int().unsqueeze(0).to(device)

# set <bos> token for target generation
trg_indexes = [vocab_trg.get_stoi()['<bos>']]

# generate new tokens
for i in range(max_length):

# convert the list to a tensor
trg_tensor = torch.tensor(trg_indexes).int().unsqueeze(0).to(device)

# generate the next token
with torch.no_grad():

# generate the logits
logits = model.forward(src_tensor, trg_tensor)

# select the newly predicted token
pred_token = logits.argmax(2)[:,-1].item()

# if <eos> token or max length, stop generating
if pred_token == vocab_trg.get_stoi()['<eos>'] or i == (max_length-1):

# decoder input
trg_input = vocab_trg.lookup_tokens(trg_indexes)

# decoder output
trg_output = vocab_trg.lookup_tokens(logits.argmax(2).squeeze(0).tolist())

return src, trg_input, trg_output, model.decoder.attn_probs, model.decoder.masked_attn_probs

# else, continue generating
else:
# add the token
trg_indexes.append(pred_token)

测试:

# 'a woman with a large purse is walking by a gate'
src = ['eine', 'frau', 'mit', 'einer', 'großen', 'geldbörse', 'geht', 'an', 'einem', 'tor', 'vorbei', '.']

src, trg_input, trg_output, attn_probs, masked_attn_probs = translate_sentence(src, model, device)

print(f'source = {src}')
print(f'target input = {trg_input}')
print(f'target output = {trg_output}')

输出:

source = ['<bos>', 'eine', 'frau', 'mit', 'einer', 'großen', 'geldbörse', 'geht', 'an', 'einem', 'tor', 'vorbei', '.', '<eos>']
target input = ['<bos>', 'a', 'woman', 'with', 'a', 'large', 'purse', 'walking', 'past', 'a', 'gate', '.']
target output = ['a', 'woman', 'with', 'a', 'large', 'purse', 'walking', 'past', 'a', 'gate', '.', '<eos>']

原文链接:

https://medium.com/@hunter-j-phillips/putting-it-all-together-the-implemented-transformer-bfb11ac1ddfe

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1584888.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue2创建过程记录

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、搭建node二、安装Vue CLI三、搭建新项目四、Elemet安装&#xff08;参照官网步骤[Element官网](https://element.eleme.cn/#/zh-CN/component/installation)&am…

2024年安卓轮播图代码+定时翻页(全网代码最少实现)

2024年安卓轮播图代码定时翻页 asda 这里是Fragment子类的继承如果使用 AppCompatActivity请修改一下很简单的如果又看不懂的话可以访问使用我的gpt&#xff1a;https://0.00000.work/ 免费3.5的 直接吧代码扔给他然后和他说帮忙解释一下每一行作用 Integer[] data{R.drawab…

Nexpose v6.6.245 for Linux Windows - 漏洞扫描

Nexpose v6.6.245 for Linux & Windows - 漏洞扫描 Rapid7 Vulnerability Management, Release Apr 03, 2024 请访问原文链接&#xff1a;Nexpose v6.6.245 for Linux & Windows - 漏洞扫描&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&…

JVM字节码与类的加载——class文件结构

文章目录 1、概述1.1、class文件的跨平台性1.2、编译器分类1.3、透过字节码指令看代码细节 2、虚拟机的基石&#xff1a;class文件2.1、字节码指令2.2、解读字节码方式 3、class文件结构3.1、魔数&#xff1a;class文件的标识3.2、class文件版本号3.3、常量池&#xff1a;存放所…

Vue笔记 2

数据代理 数据代理&#xff1a;通过一个对象代理对另一个对象中属性的操作&#xff08;读/写&#xff09; let obj{x:100} let obj2{y:200} Object.defineProperty(obj2,x,{get(){return obj.x},set(value){obj.x value} })Vue中的数据代理 Vue中的数据代理&#xff1a; 通…

cesium 添加动态波纹效果 圆形扩散效果 波纹材质

一、扩展材质 /*** 水波纹扩散材质* param {*} options* param {String} options.color 颜色* param {Number} options.duration 持续时间 毫秒* param {Number} options.count 波浪数量* param {Number} options.gradient 渐变曲率*/function CircleWaveMaterialProperty(opt…

顶顶通呼叫中心中间件(mod_cti基于FreeSWITCH)-回铃音补偿

文章目录 前言联系我们解决问题操作步骤 前言 回铃音&#xff1a; 当别人打电话给你时&#xff0c;你的电话响铃了&#xff0c;而他听到的声音叫做回铃音。回铃音是被叫方向主叫方传送&#xff0c;也是彩铃功能的基础。我们平时打电话听到的“嘟 嘟 嘟 嘟”的声音&#xff0c;就…

Golang | Leetcode Golang题解之第20题有效的括号

题目&#xff1a; 题解&#xff1a; func isValid(s string) bool {n : len(s)if n % 2 1 {return false}pairs : map[byte]byte{): (,]: [,}: {,}stack : []byte{}for i : 0; i < n; i {if pairs[s[i]] > 0 {if len(stack) 0 || stack[len(stack)-1] ! pairs[s[i]] {…

白盒测试-条件覆盖

​ 条件覆盖是指运行代码进行测试时&#xff0c;程序中所有判断语句中的条件取值为真值为假的情况都被覆盖到&#xff0c;即每个判断语句的所有条件取真值和假值的情况都至少被经历过一次。 ​ 条件覆盖率的计算方法为&#xff1a;测试时覆盖到的条件语句真、假情况的总数 / 程…

期货学习笔记-MACD指标学习2

MACD底背离把握买入多单的技巧 底背离的概念及特征 底背离指的是MACD指标与价格低点之间的对比关系&#xff0c;这里需要明白的是MACD指标的涨跌动能和价格形态衰竭形态之间的关系&#xff0c;如果市场价格创新低而出现衰竭形态同时也有底背离形态的出现&#xff0c;此时下跌…

2024认证杯数学建模A题思路模型代码

目录 2024认证杯数学建模A题思路模型代码&#xff1a;4.11开赛后第一时间更更新&#xff0c;获取见文末名片 2023年认证杯数学建模 2024年认证杯思路代码获取见此 2024认证杯数学建模A题思路模型代码&#xff1a;4.11开赛后第一时间更更新&#xff0c;获取见文末名片 2023年认…

关于AI发展的3种声音:杨植麟 朱啸虎 王小川

1、杨植麟&#xff1a;技术信仰派 2、朱啸虎&#xff1a;市场信仰派 3、王小川&#xff1a;中间派 References 对话月之暗面杨植麟&#xff1a;向延绵而未知的雪山前进朱啸虎讲了一个中国现实主义AIGC故事王小川想提出中国AGI第三种可能性

【C 数据结构】线性表

文章目录 【 1. 线性表 】【 2. 顺序存储结构、链式存储结构 】【 3. 前驱、后继 】 【 1. 线性表 】 线性表&#xff0c;全名为线性存储结构&#xff0c;线性表结构存储的数据往往是可以依次排列的&#xff08;不考虑数值大小顺序&#xff09;。 例如&#xff0c;存储类似 {1…

Golang快速入门教程(一)

目录 一、环境搭建 1.windows安装 2.linux安装 3.开发工具 二、变量定义与输入输出 1.变量定义 2.全局变量与局部变量 3.定义多个变量 4.常量定义 5.命名规范 6.输出 7.输入 三、基本数据类型 1.整数型 2.浮点型 3.字符型 4.字符串类型 转义字符 多行字符…

RN使用蓝牙扫描

我项目需要用到蓝牙模块,蓝牙扫描到设备并且获取到电量显示到页面上,因此我做了如下demo,使用了react-native-ble-plx这个插件 点击进入官方文档官方文档 1.安卓环境配置(ios暂定,还没做ios,不过下面的方法是兼容的,自行配置ios权限) android/app/src/main/AndroidManifest.xml…

API管理平台:你用的到底是哪个?

Apifox是不开源的&#xff0c;在github的项目只是readme文件&#xff0c;私有化需要付费。当然saas版目前是免费使用的。 一、Swagger 为了让Swagger界面更加美观&#xff0c;有一些项目可以帮助你实现这一目标。以下是一些流行的项目&#xff0c;它们提供了增强的UI和额外的功…

代码签名证书:确保软件安全与可信性的关键工具

在数字化时代&#xff0c;软件已成为各行各业的核心驱动力。然而&#xff0c;随着网络威胁日益复杂&#xff0c;用户对于软件来源的可靠性、内容的完整性和发布者的可信度提出了更高要求。为满足这一需求&#xff0c;代码签名证书应运而生&#xff0c;它通过先进的加密技术&…

【SpringBoot】-- mapstruct进行类型转换时Converter实现类不能自动生成代码问题解决

问题描述 我的问题如下&#xff1a; 应该在红色区域生成对应的转换细节&#xff0c;但是这里只返回了一个空对象 问题解决 加入lombok-mapstruct-binding依赖,也要注意依赖引用顺序问题 <dependency><groupId>org.projectlombok</groupId><artifactId&…

【智能优化算法详解】粒子群算法PSO量子粒子群算法QPSO

1.粒子群算法PSO 博主言简意赅总结-算法思想&#xff1a;大方向下个体自学习探索群体交流共享 对比适应度找到最优点 背景 粒子群算法&#xff0c;也称粒子群优化算法或鸟群觅食算法&#xff08;Particle Swarm Optimization&#xff09;&#xff0c; 缩写为 PSO。粒子群…

小白部署springboot+vue网站到服务器踩坑总结【维护更新篇】

目录 前言如何更新前端nginx安装和启动言归正传&#xff0c;更新前端 如何更新后端杀死进程 前言 在上一篇文章里详细介绍了怎样部署一个前后端分离的网站&#xff0c;链接: 链接在此 但是部署完之后&#xff0c;在本地的开发更新后&#xff0c;重新上传到服务器上又遇到了一些…