Transformer经典模型实战:零基础训练一个面向中文的T5模型(Text to Text Transfer Transformer)

news2025/1/21 0:59:43

scient

scient一个用python实现科学计算相关算法的包,包括自然语言、图像、神经网络、优化算法、机器学习、图计算等模块。

scient源码和编译安装包可以在Python package index获取。

The source code and binary installers for the latest released version are available at the [Python package index].

https://pypi.org/project/scient

可以用pip安装scient

You can install scient like this:

pip install scient

也可以用setup.py安装。

Or in the scient directory, execute:

python setup.py install

scient.neuralnet

神经网络相关算法模块,包括attention、transformer、bert、lstm、resnet、crf、dataset、fit等。

scient.neuralnet.transformer

实现了多个Transformer模型,包括Transformer、T5Transformer、ViTransformer、DecodeTransformer、Encoder、Decoder。

scient.neuralnet.transformer.T5Transformer(vocab_size: int, seq_len: int = 512, embed_size: int = 768,
										   n_head: int = 12, n_encode_layer: int = 12, n_decode_layer: int = 12, n_bucket: int = 32,
										   max_dist: int = 128, norm_first: bool = True, bias: bool = False, attn_scale: bool = False,
										   **kwargs)

Parameters

  • vocab_size : int
    字典规模.
  • seq_len : int, optional
    序列长度. The default is 512.
  • embed_size : int, optional
    embedding向量长度. The default is 768.
  • n_head : int, optional
    multi_head_attention的head数量. The default is 12.
  • n_encode_layer : int, optional
    编码层数. The default is 12.
  • n_decode_layer : int, optional
    解码层数. The default is 12.
  • n_bucket : int, optional
    multi_head_attention中相对位置编码的分桶数量. The default is 32.
  • max_dist : int, optional
    multi_head_attention中相对位置编码的最大距离. The default is 128.
  • norm_first : bool, optional
    在每一个编码/解码层中是否先进行Batch Normalization. The default is True.
  • bias : bool, optional
    模型中的参数是否bias. The default is False.
  • attn_scale : bool, optional
    multi_head_attention中是否需要对注意力矩阵进行scale. The default is False.
  • kwargs : 其它参数,kwargs中的参数将被传递到Encoder层和Decoder层。

Algorithms

T5采用了相对位置分桶(relative_position_bucket)的方式来处理位置编码。
在双向注意力的Encoder阶段,相对位置分桶的公式为:

在这里插入图片描述

在单向注意力的Decoder阶段,相对位置分桶的公式为:

在这里插入图片描述

式中的 n b n_b nb为相对位置编码的分桶数量n_bucket, m a x _ d i s t a n c e max\_distance max_distance为相对位置编码的最大距离max_dist。

T5模型结构

T5Transformer(
  (encoder_position): BucketPosition(
    (projection): Embedding(32, 12)
  )
  (decoder_position): BucketPosition(
    (projection): Embedding(32, 12)
  )
  (embedding): Embedding(32128, 768)
  (encoder): ModuleList(
    (0-11): 12 x Encoder(
      (multi_head_attn): MultiHead(
        (dropout): Dropout(p=0.1, inplace=False)
        (query): Linear(in_features=768, out_features=768, bias=False)
        (key): Linear(in_features=768, out_features=768, bias=False)
        (value): Linear(in_features=768, out_features=768, bias=False)
        (linear): Linear(in_features=768, out_features=768, bias=False)
      )
      (feedforward): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=False)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=3072, out_features=768, bias=False)
      )
      (layernorm1): T5LayerNorm()
      (layernorm2): T5LayerNorm()
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder): ModuleList(
    (0-11): 12 x Decoder(
      (mask_multi_head_attn): MultiHead(
        (dropout): Dropout(p=0.1, inplace=False)
        (query): Linear(in_features=768, out_features=768, bias=False)
        (key): Linear(in_features=768, out_features=768, bias=False)
        (value): Linear(in_features=768, out_features=768, bias=False)
        (linear): Linear(in_features=768, out_features=768, bias=False)
      )
      (multi_head_attn): MultiHead(
        (dropout): Dropout(p=0.1, inplace=False)
        (query): Linear(in_features=768, out_features=768, bias=False)
        (key): Linear(in_features=768, out_features=768, bias=False)
        (value): Linear(in_features=768, out_features=768, bias=False)
        (linear): Linear(in_features=768, out_features=768, bias=False)
      )
      (feedforward): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=False)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=3072, out_features=768, bias=False)
      )
      (layernorm1): T5LayerNorm()
      (layernorm2): T5LayerNorm()
      (layernorm3): T5LayerNorm()
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
      (dropout3): Dropout(p=0.1, inplace=False)
    )
  )
  (encoder_layernorm): T5LayerNorm()
  (decoder_layernorm): T5LayerNorm()
  (linear): Linear(in_features=768, out_features=32128, bias=False)
)

Examples

下面的代码实例是训练一个“对句子进行重写,且不改变语义”的模型,比如“鹿跳过篱笆。”可重写成“一只鹿跳过篱笆。”。

import torch
from scient.neuralnet import transformer,fit
from scient.neuralnet import dataset
import sentencepiece
import pandas
from tqdm import tqdm

tqdm.pandas()

data_path='d:\\rewrite_train3.xlsx'
tokenizer_path='d:\\spiece.model'

#%%model
vocab_size=32128
seq_len_upper=32

tokenizer=sentencepiece.SentencePieceProcessor(tokenizer_path)
model=transformer.T5Transformer(vocab_size=vocab_size,dropout=0.1,ffn_size=3072)

#%% 数据
data=pandas.read_excel(data_path)

#tokenize
data['source_token']=data['input'].progress_apply(lambda x:tokenizer.encode(x))
data['target_token']=data['label'].progress_apply(lambda x:tokenizer.encode(x))

#清洗
data=data[(data['source_token'].apply(len)<seq_len_upper)&(data['target_token'].apply(len)<seq_len_upper)]

#截断
data['source_token']=data['source_token'].progress_apply(lambda x:x[:seq_len_upper]+[tokenizer.eos_id()])#增加<eos>标识
data['target_input_token']=data['target_token'].progress_apply(lambda x:[tokenizer.pad_id()]+x[:seq_len_upper])#增加<bos>标识,这里用pad_id作为<bos>
data['target_output_token']=data['target_token'].progress_apply(lambda x:x[:seq_len_upper]+[tokenizer.eos_id()])#增加<eos>标识

#mask
data['source_pad_mask']=data['source_token'].progress_apply(lambda x:[False]*len(x)+[True]*(seq_len_upper-len(x)))
data['target_pad_mask']=data['target_input_token'].progress_apply(lambda x:[False]*len(x)+[True]*(seq_len_upper-len(x)))

#补齐
data['source_token']=data['source_token'].progress_apply(lambda x:x+[tokenizer.pad_id()]*(seq_len_upper-len(x)))
data['target_input_token']=data['target_input_token'].progress_apply(lambda x:x+[tokenizer.pad_id()]*(seq_len_upper-len(x)))
data['target_output_token']=data['target_output_token'].progress_apply(lambda x:x+[tokenizer.pad_id()]*(seq_len_upper-len(x)))

batch_size=8
#dataLoad
data_train=data.sample(frac=0.7)
data_eval=data.drop(data_train.index).sample(frac=0.7)
data_val=data.drop(data_train.index).drop(data_eval.index)
train_loader = torch.utils.data.DataLoader(dataset=dataset.DataFrame(frame=data_train,tensor_vars=['source_token','target_input_token','source_pad_mask','target_pad_mask'],target_var='target_output_token'),batch_size=batch_size,shuffle=True)
eval_loader = torch.utils.data.DataLoader(dataset=dataset.DataFrame(frame=data_eval,tensor_vars=['source_token','target_input_token','source_pad_mask','target_pad_mask'],target_var='target_output_token'),batch_size=batch_size,shuffle=False)
val_loader = torch.utils.data.DataLoader(dataset=dataset.DataFrame(frame=data_val,tensor_vars=['source_token','target_input_token','source_pad_mask','target_pad_mask'],target_var='target_output_token'),batch_size=1,shuffle=False)
#%% 训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

#损失函数
loss_func_ = torch.nn.CrossEntropyLoss(ignore_index=0)
def loss_func(y_hat,y):
    return loss_func_(y_hat.reshape(-1, vocab_size),y.reshape(-1).to(torch.int64).to(device))  # 计算损失

#优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

def perform_func(y_hat,y):#perform_func的输入是预测值y_hat和实际值y
    y_hat,y=torch.concat(y_hat).reshape(-1, vocab_size).numpy(),torch.concat(y).reshape(-1).numpy()#先将y_hat和y分别concat,由于y_hat和y是按loader分批计算和收集的,所以y_hat和y是batch_size大小的多个对象组成的list
    y_hat=y_hat.argmax(axis=1)
    y_hat=y_hat[y!=0]
    y=y[y!=0]
    return round((y_hat==y).sum()/len(y),4)#输出准确率,并保留4位小数

model=fit.set(model,optimizer=optimizer,loss_func=loss_func,perform_func=perform_func,device=device,n_iter=5)
model.fit(train_loader,eval_loader,mode=('inputs','target'))

附代码中用到的tokenizer模型spiece.model和训练数据rewrite_train3.xlsx的下载地址:
链接:https://pan.baidu.com/s/12vEZBYldXvPrJTiFUEKGUw?pwd=DTFM
提取码:DTFM

通过5轮训练,模型在训练集和测试集上的准确率均已达到99%以上。

train iter 0: avg_batch_loss=3.88477 perform=0.5023: 100%|██████████| 140/140 [06:43<00:00,  2.88s/it]    
eval iter 0: avg_batch_loss=0.56695 perform=0.8973: 100%|██████████| 42/42 [00:28<00:00,  1.47it/s]    
train iter 1: avg_batch_loss=0.27674 perform=0.9539: 100%|██████████| 140/140 [08:02<00:00,  3.45s/it]    
eval iter 1: avg_batch_loss=0.08557 perform=0.9808: 100%|██████████| 42/42 [00:46<00:00,  1.10s/it]    
train iter 2: avg_batch_loss=0.05592 perform=0.9897: 100%|██████████| 140/140 [09:33<00:00,  4.10s/it]    
eval iter 2: avg_batch_loss=0.01999 perform=0.9957: 100%|██████████| 42/42 [00:28<00:00,  1.45it/s]    
train iter 3: avg_batch_loss=0.02244 perform=0.9964: 100%|██████████| 140/140 [07:58<00:00,  3.42s/it]    
eval iter 3: avg_batch_loss=0.01343 perform=0.996: 100%|██████████| 42/42 [00:32<00:00,  1.31it/s]     
train iter 4: avg_batch_loss=0.01273 perform=0.9981: 100%|██████████| 140/140 [07:44<00:00,  3.32s/it]    
eval iter 4: avg_batch_loss=0.01047 perform=0.9977: 100%|██████████| 42/42 [00:29<00:00,  1.41it/s]    

采用训练好的模型对data_val数据集进行预测

#%%
# 验证
model.eval()
progressbar = tqdm(val_loader)#这里batch_size必须为1
preds=[]
with torch.no_grad():
    for index,((source,target_input,source_pad_mask,target_input_pad_mask),target_output) in enumerate(progressbar):
        # break
        memory=model.encode(source.to(torch.int64).to(device),source_pad_mask.to(device))
        pred=torch.tensor([[tokenizer.pad_id()]])#bos
        while True:
            pred_mask=torch.zeros_like(pred).to(torch.bool)
            decode = model.decode(pred.to(torch.int64).to(device),memory,target_pad_mask=pred_mask.to(device))
            output=model.linear(decode)
            _,ids = output.max(dim=-1)
            if ids[0,-1]==tokenizer.eos_id():#eos
                break
            if pred.size(1)>seq_len_upper-1:
                break
            pred=torch.cat([pred.to(device),ids[:,-1:]],dim=-1)
        preds+=pred.tolist()

data_val['target_output_pred']=preds
data_val['target_pred']=data_val['target_output_pred'].progress_apply(lambda x:tokenizer.decode(x))

预测结果

在这里插入图片描述

input是输入,label是期望模型输出的内容,target_pred是模型输出的内容,可以看到模型输出与期望之间基本一致。
值得注意的是第一条数据,模型将
机构认为,随着经济数据及上市公司财报的披露,预计市场主线将逐渐清晰,中大盘成长股有望成为下一阶段的资金偏好。
改写成
他们说,第二季度的收益报告将是给予投资者这种指导的关键。
虽然与期望输出的内容差距较大,但是模型输出的意思却是完全正确的,难道这就是模型涌现出的创造力?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2066945.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

景联文科技提供语音采集服务:开启智能交互新纪元

随着人工智能技术的飞速发展&#xff0c;语音交互已成为连接人与智能设备的重要桥梁。无论是智能助手、智能家居还是自动驾驶汽车&#xff0c;语音识别技术都是其背后不可或缺的核心力量。 为了满足各行各业对高质量语音数据的需求&#xff0c;景联文科技凭借自身强大的数据采集…

XGen —— 导入Hou

动力学解算 选择description&#xff0c;转化为交互式Groom&#xff1b; 选择description&#xff0c;将引导线转化为曲线&#xff1b; 添加linearWire&#xff0c;并选择转化后的曲线生成解算线 选择上述生成的解算线&#xff0c;创建动力学&#xff1b; 导出解算的Xgen 导出a…

Edge SCDN:构建更快更安全的网络世界

什么是Edge SCDN&#xff1f; Edge SCDN&#xff0c;即边缘安全加速&#xff0c;是一种结合了传统CDN&#xff08;内容分发网络&#xff09;技术与网络安全防护功能的服务。传统的CDN通过在全球范围内分布服务器节点来加速网站内容的分发&#xff0c;提高访问速度和用户体验。…

备考计算机二级Python之Day4作业编程题

1、输入一个年份&#xff0c;输出是否为闰年。 #闰年条件&#xff1a;能被4整除但不能被100整除&#xff0c;或者能被400整除的年份都是闰年。 neval(input("请输入一个年份&#xff1a;")) if (n%40 and n%100!0) or (n%4000):print("该年份是闰年") els…

68 H3C SecPath F1000 (系统模块介绍-1)

68 H3C SecPath F1000 &#xff08;系统模块介绍&#xff09; 01-高可靠性 特性简介 高可靠性&#xff08;High Availability&#xff09;&#xff0c;简称为HA&#xff0c;能够在通信线路或设备产生故障时提供备用方案&#xff0c;当其中一个网络节点发生故障时&#xff0c…

生信是什么?生物信息学的基础概念与应用领域-生信圆桌

介绍 生信&#xff0c;全称为生物信息学&#xff08;Bioinformatics&#xff09;&#xff0c;是指将计算机科学、数学和统计学的方法应用于生物学数据的处理、分析和解释。随着基因组测序技术的发展和大规模生物数据的产生&#xff0c;生物信息学成为了生命科学研究中的一个核…

浅谈AI+工业视觉检测技术应用的优化

1 高质量替代人眼&#xff0c;助力智能制造 视觉是人类获取信息最主要的渠道&#xff0c;它使人们得以感知和理解周边的世界。通过视觉&#xff0c;人类可以感知外界物体的大小、明暗、颜色、动静&#xff0c;获得对机体生存具有重要意义的各种信息。人类的大脑皮层约有70%都在…

arthas源码刨析:arthas 命令粗谈(3)

文章目录 dashboardwatchretransform 前面介绍了 arthas 启动相关的代码并聊了聊怎么到一个 shellserver 的建立。 本篇我们来探讨一下几个使用频次非常高的命令是如何实现的。 dashboard 想看这个命令的主要原因是编程这些年来从来没有开发过 terminal 的这种比较花哨的界面&a…

最新出炉 -Web自动化测试之playwright:概述

概述 playwright是由微软开发的Web UI自动化测试工具&#xff0c; 支持Node.js、Python、C# 和 Java语言&#xff0c;本文将介绍playwright的特性以及它的简单使用。 playwright特性 playwright具有以下特点&#xff1a; 一、支持所有主流浏览器 支持所有主流浏览器&#x…

从开发到集成:视频美颜SDK与直播美颜API详解

在本文中&#xff0c;我们将详细探讨视频美颜SDK的开发过程及其与直播美颜API的集成方案&#xff0c;帮助开发者更好地理解和应用这些技术。 一、视频美颜SDK的开发概述 视频美颜SDK是一个用于实时视频处理的开发工具包&#xff0c;提供了包括磨皮、美白、瘦脸、眼睛放大等多…

各类软件历史版本的下载地址

postman,notpad等 https://www.filehorse.com/software-developer-tools/https://www.filehorse.com/software-developer-tools/

数业智能心大陆AI大模型,共情陪伴你的心理健康

大模型的出现&#xff0c;使得AI在语音识别、自然语言处理、计算机视觉等领域的性能得到了极大的提升&#xff0c;随着硬件设备的不断升级和优化&#xff0c;以及算法的不断改进&#xff0c;大模型的规模和性能也在不断提升&#xff0c;大模型的优势在于其强大的表示能力和泛化…

不愿回流上班,离职博主们不断寻找新的“栖息地”

文 | 螳螂观察 作者 | 如意 “替大家试过了&#xff0c;不上班真的很爽。” “985本硕&#xff0c;年薪40万裸辞了。” “不干了&#xff0c;谁家好人半夜12点还在司啊&#xff01;” 标题熟悉吧&#xff1f;对&#xff0c;这拨人你一定看到了&#xff0c;说人生是旷野&am…

45+用户占比近30%,网文产业如何赋能IP长链?

网文市场加速发展&#xff0c;巨头抢占中老年用户 作者&#xff5c;吕娆炜 排版&#xff5c;张思琪 干货抢先看 1. 我国网文产业市场规模突破3000亿元&#xff0c;在用户方面&#xff0c;截至2023年底&#xff0c;我国网文用户数量达5.37亿&#xff0c;同比增长9%&#xff0c…

【FreeRTOS】信号量

0 前言 学习视频&#xff1a; 【FreeRTOS入门与工程实践 --由浅入深带你学习FreeRTOS&#xff08;FreeRTOS教程 基于STM32&#xff0c;以实际项目为导向&#xff09;】 【精准空降到 00:42】 https://www.bilibili.com/video/BV1Jw411i7Fz/?p39&share_sourcecopy_web&…

源2.0-M32大模型发布4bit/8bit量化版! 运行显存仅需23GB,性能可媲美LLaMA3

近日&#xff0c;浪潮信息发布源2.0-M32大模型4bit和8bit量化版&#xff0c;性能比肩700亿参数的LLaMA3开源大模型。4bit量化版推理运行显存仅需23.27GB&#xff0c;处理每token所需算力约为1.9 GFLOPs&#xff0c;算力消耗仅为同等当量大模型LLaMA3-70B的1/80。而LLaMA3-70B运…

删除Eureka注册中心已经注册的服务

1.登录Eureka查看需要删除的服务。 2.使用postman或者apipost工具&#xff0c;请求方式DELETE, 接口地址输入&#xff1a;eureka的ip地址/eureka/apps/ Application / Status 例如: http://192.168.194.60:8761/eureka/apps/VUE-MANAGER-SERVICE/10.42.0.138:vue-manager…

酷家乐 同盾滑块分析

声明: 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 有相关问题请第一时间头像私信联系我…

【Hot100】LeetCode—114. 二叉树展开为链表

目录 1- 思路技巧——借助指针 2- 实现⭐114. 二叉树展开为链表——题解思路 3- ACM 实现 原题连接&#xff1a;114. 二叉树展开为链表 1- 思路 技巧——借助指针 思路&#xff1a;通过 ① 将左子树的右下结点的 .next ——> 拼接到当前节点的右子树上。 构造 cur 指针&a…

KPaaS还是ESB?怎样选择合适的集成方案?

在全球经济一体化和数字化转型的背景下&#xff0c;企业正面临着前所未有的挑战与机遇。随着业务的快速发展&#xff0c;企业内部的信息系统日益复杂&#xff0c;系统间的信息孤岛、系统割裂以及高昂的维护成本等问题逐渐凸显&#xff0c;严重制约了企业的创新能力和市场竞争力…