Transformer实现以及Pytorch源码解读(一)-数据输入篇

news2025/2/6 17:39:39

目标

以词性标注任务为例子,实现Transformer,并分析实现Pytorch的源码解读。

数据准备

所选的数据为nltk数据工具中的treebank数据集。treebank数据集的样子如以下两幅图所示:
该数据集中解释变量为若干句完整的句子:
在这里插入图片描述
被解释变量为该句子中每个词的词性:
在这里插入图片描述
具体每个词性简写的意思,大概如下文所示(参考博客):

标注词表:
名词:NN,NNS,NNP,NNPS
代词:PRP,PRP$
形容词:JJ,JJR,JJS
数词:CD
动词:VB,VBD,VBG,VBN,VBP,VBZ
副词:RB,RBR,RBS
1.CC      Coordinating conjunction 连接词
2.CD     Cardinal number  基数词
3.DT     Determiner 
 限定词(如this,that,these,those,such,不定限定词:no,some,any,each,every,enough,either,neither,all,both,half,several,many,much,(a)
 few,(a) little,other,another.
4.EX     Existential there 存在句
5.FW     Foreign word 外来词
6.IN     Preposition or subordinating conjunction 介词或从属连词
7.JJ     Adjective 形容词或序数词
8.JJR     Adjective, comparative 形容词比较级
9.JJS     Adjective, superlative 形容词最高级
10.LS     List item marker 列表标示
11.MD     Modal 情态助动词
12.NN     Noun, singular or mass 常用名词 单数形式
13.NNS     Noun, plural  常用名词 复数形式
14.NNP     Proper noun, singular  专有名词,单数形式
15.NNPS     Proper noun, plural  专有名词,复数形式
16.PDT     Predeterminer 前位限定词
17.POS     Possessive ending 所有格结束词
18.PRP     Personal pronoun 人称代词
19.PRP$     Possessive pronoun 所有格代名词
20.RB     Adverb 副词
21.RBR     Adverb, comparative 副词比较级
22.RBS     Adverb, superlative 副词最高级
23.RP     Particle 小品词
24.SYM     Symbol 符号
25.TO     to 作为介词或不定式格式
26.UH     Interjection 感叹词
27.VB     Verb, base form 动词基本形式
28.VBD     Verb, past tense 动词过去式
29.VBG     Verb, gerund or present participle 动名词和现在分词
30.VBN     Verb, past participle 过去分词
31.VBP     Verb, non-3rd person singular present 动词非第三人称单数
32.VBZ     Verb, 3rd person singular present 动词第三人称单数
33.WDT     Wh-determiner 限定词(如关系限定词:whose,which.疑问限定词:what,which,whose.)
34.WP      Wh-pronoun 代词(who whose which)
35.WP$     Possessive wh-pronoun 所有格代词
36.WRB     Wh-adverb   疑问代词(how where when)

处理过程


from nltk.corpus import treebank

#表示句子,以及句子中每个词的词性
sents, postags = zip(*(zip(*sent) for sent in treebank.tagged_sents()))

# 对涉及到的单词和词性进行唯一化的处理,并为每个词指定一个整数
vocab = Vocab.build(sents, reserved_tokens=["<pad>"])
tag_vocab = Vocab.build(postags)

#前3000的句子作为训练集,后3000的句子作为测试集.同时将每个单词用整数表示
train_data = [(vocab.convert_tokens_to_ids(sentence), tag_vocab.convert_tokens_to_ids(tags)) for sentence, tags in zip(sents[:3000], postags[:3000])]
test_data = [(vocab.convert_tokens_to_ids(sentence), tag_vocab.convert_tokens_to_ids(tags)) for sentence, tags in zip(sents[3000:], postags[3000:])]
pos_vocab=tag_vocab

vocab类如下:

class Vocab:
    def __init__(self, tokens=None):
        self.idx_to_token = list()
        self.token_to_idx = dict()

        if tokens is not None:
            if "<unk>" not in tokens:
                tokens = tokens + ["<unk>"]
            for token in tokens:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
            self.unk = self.token_to_idx['<unk>']

    @classmethod
    def build(cls, text, min_freq=1, reserved_tokens=None):
        token_freqs = defaultdict(int)
        for sentence in text:
            for token in sentence:
                token_freqs[token] += 1
        uniq_tokens = ["<unk>"] + (reserved_tokens if reserved_tokens else [])
        uniq_tokens += [token for token, freq in token_freqs.items() \
                        if freq >= min_freq and token != "<unk>"]
        return cls(uniq_tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, token):
        return self.token_to_idx.get(token, self.unk)

    def convert_tokens_to_ids(self, tokens):
        return [self[token] for token in tokens]

    def convert_ids_to_tokens(self, indices):
        return [self.idx_to_token[index] for index in indices]

对vocab类的解释:统计解释变量中涉及到的单词的出现频率,同时为每个单词分配一个整数,作为该单词的整数表示。关于@classmethod的意义请看博客1。
随后根据处理后的数据构建迭代器:

# 给输入数据构造迭代器
train_dataset = TransformerDataset(train_data)
test_dataset = TransformerDataset(test_data)

#将迭代器中的数据,用加载器加载。最主要的是collate_fn函数,他表示对迭代器中的数据进行怎样的处理
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

num_class = len(pos_vocab)

#加载模型
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

TransformerDataset类的定义如下:

class TransformerDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return self.data[i]

def collate_fn(examples):
    """
    wfj:该函数表示对于batch_size中的每一个元素做以下一下的操作,通常用来进行数据的标准化工作
    """
    # print("==========================")
    # print(examples)
    # print(len(examples))
    lengths = torch.tensor([len(ex[0]) for ex in examples])
    inputs = [torch.tensor(ex[0]) for ex in examples]
    targets = [torch.tensor(ex[1]) for ex in examples]
    # 对batch内的样本进行padding,使其具有相同长度
    inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab["<pad>"])
    targets = pad_sequence(targets, batch_first=True, padding_value=vocab["<pad>"])
    #输出的几个参数的解释:解释变量;每个解释变量的长度;被解释变量;是否为填充位的标记。
    return inputs, lengths, targets, inputs != vocab["<pad>"]

关于collate_fn的解释,请看博客2.
最后是构建Transformer类:

class Transformer(nn.Module):
   def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
                dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
       super(Transformer, self).__init__()
       # 词嵌入层
       self.embedding_dim = embedding_dim
       self.embeddings = nn.Embedding(vocab_size, embedding_dim)
       self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
       # 编码层:使用Transformer
       encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
       self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
       # 输出层
       self.output = nn.Linear(hidden_dim, num_class)

   def forward(self, inputs, lengths):
       inputs = torch.transpose(inputs, 0, 1)
       hidden_states = self.embeddings(inputs)
       for inp,hid in zip(inputs,hidden_states):
           print("===================================")
           print(inp)
           print(hid)
       hidden_states = self.position_embedding(hidden_states)
       attention_mask = length_to_mask(lengths) == False
       hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
       logits = self.output(hidden_states)
       log_probs = F.log_softmax(logits, dim=-1)
       return log_probs

由于Transformer类继承了nn.model,因此实例化该类并调用的时候forward中的代码会自动的执行。通过查阅源码我们发现,所有的forward是通过_call_impl这个方法实现的,而_call_impl被__call__方法调用,因此实例化后的类可以被直接的调用。
在这里插入图片描述
最后是模型的整体训练部分。

model.train()
for epoch in range(num_epoch):
  total_loss = 0
  for batch in tqdm(train_data_loader, desc=f"Training Epoch {epoch}"):
      inputs, lengths, targets, mask = [x.to(device) for x in batch]
      log_probs = model(inputs, lengths)
      loss = nll_loss(log_probs[mask], targets[mask])
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      total_loss += loss.item()
      break
  print(f"Loss: {total_loss:.2f}")
  break

接下来我们看Transformer类中的方法是怎样被具体实现的,比如词编码即embedding的过程,请看博客:
Transformer实现以及Pytorch源码解读(二)-embedding源码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/102786.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker-DockerFile制定镜像

什么是DockerFile&#xff1f; DockerFile是一个用来编写Docker镜像的文本文件&#xff0c;文本内容包含了一条条构建镜像所需要的指令和说明。DockerFile就想要一个脚本文件一样。把我们想要执行的操作放到文本文件里&#xff0c;一键执行。这样我们就可以复用这个DockerFile…

读论文:Learning to Compare: Relation Network for Few-Shot Learning

Abstract 我们提出了一个概念上简单、灵活且通用的少镜头学习框架&#xff0c;其中分类器必须学习识别每个只给出少量示例的新类。我们的方法称为关系网络(RN)&#xff0c;从头到尾进行训练。在元学习过程中&#xff0c;它学习学习一个深度距离度量来比较插曲中的少量图像&…

RNA-seq 详细教程:时间点分析(14)

学习内容 了解如何使用 DESeq2 进行时间的分析LRT 使用 LRT 进行 Time course 分析尽管基因表达的静态测量很受欢迎&#xff0c;但生物过程的时程捕获对于反映其动态性质至关重要&#xff0c;特别是当模式复杂且不仅仅是上升或下降时。在处理此类数据时&#xff0c;似然比检验 …

doris入门后遇到的几个问题总结

文章目录1. Access denied for user anonymnull (using password: NO)2. timeout when waiting for send fragments RPC. Wait(sec): 5, host: xxx(ip)3. Failed to initialize JNI: Failed to find the library libjvm.so.4. 从mysql库导出的json文件大于100M时报错5. csv格式…

OA办公系统:颠覆企业办公模式,激活组织潜能打造新模式

企业的生命力在于生存&#xff0c;而想要在竞争日益激烈的市场环境下生存&#xff0c;就必须不断革新自己的内部条件&#xff0c;否则将会在发展的洪流中被社会所淘汰。如今社会的发展正在信息化世界中进行&#xff0c;企业搭建信息化平台是一条必经之道&#xff0c;而OA办公自…

太爽了!看酷开系统帮你沉浸式带娃!

现如今&#xff0c;OTT大屏涉及的线上内容与娱乐方式与日俱增&#xff0c;不仅常规的电视节目、网剧影视能够随心选择&#xff0c;还发展出以大屏为载体的短视频、健身、云游戏等丰富内容。在人们的居家生活走向常态化的当下&#xff0c;更长的开机使用时间自然对电视操作系统的…

codeforces:C. Another Array Problem【分类讨论 + 找规律】

目录题目截图题目分析ac code总结题目截图 题目分析 做cf题目别老想着套算法模版 找规律才是正道&#xff0c;这就是所谓的「思维」 n 2很简单 n > 4: # 肯定有一个最大值&#xff0c;不妨设它的位置在第三个或以后的x# 前两个值经过两次操作&#xff0c;都变为0# 第0…

Vue.js 目录结构

当我们初始化一个项目后目录结构是这样的&#xff1a; 目录解析 目录/文件说明build项目构建(webpack)相关代码config配置目录&#xff0c;包括端口号等。我们初学可以使用默认的。node_modulesnpm 加载的项目依赖模块src这里是我们要开发的目录&#xff0c;基本上要做的事情都…

# 关于“table“中更新传参回填form

关于"table"中更新传参回填form 一、id查询数据库回填form 使用阶段&#xff1a;Javaweb/ssm/Springboot出现场景&#xff1a;jsp页面&#xff08;el表达式&#xff09;、thymeleaf页面&#xff08;thymeleaf表达式&#xff0c;具体使用方法请前往百度&#xff09;…

Python成求职中最吃香的三大编程语言之一

程序员培训公司 CodinGame 发布的一份开发人员调查报告显示&#xff0c;在开发人员招聘中&#xff0c;拥有 JavaScript、Java 和 Python 三大编程语言技能的开发人员最受招聘经理欢迎。 该报告基于对全球近 15,000 名开发人员和人力资源专业人员的调查。报告显示&#xff0c;每…

【IO流】JAVA基础篇(一)

文章目录一、字节流和字符流的区别1、字节和字符换算关系2、字节、位、二进制之间的关系3、在64位的操作系统中&#xff0c;一个字等于多少字节&#xff1f;4、字节流和字符流区别二、InputStream1、FileInputStream2、FilterInputStream3、ObjectInputStream4、PipedInputStre…

玩客云刷ARMBIAN当服务器过程记录

玩客云的可玩性 1、可以刷成电视游戏盒子的双系统。也可以刷成单独的电视盒子和游戏盒子。不过因为内存有限放不了多少游戏。还是建议用外置SD卡存储游戏比较合适。 2、刷成Armbian linux系统&#xff08;可以实现docker、可道云、甜糖等多种功能&#xff09; 3、最后它还可…

jsp+ssm计算机毕业设计风景区管理系统【附源码】

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; JSPSSM mybatis Maven等等组成&#xff0c;B/S模式 Mave…

生物安全防护实验室建设要点SICOLAB

生物安全实验室&#xff08;BiosafetyLaboratory&#xff09;&#xff0c;也称生物安全防护实验室&#xff08;BiosafetyContainmentforLaboratories&#xff09;&#xff0c;是通过防护屏障和管理措施&#xff0c;能够避免或控制被操作的有害生物因子危害&#xff0c;达到生物…

磺丁基醚环糊精盐内水相/桂利嗪/EGF/吲哚美辛-环糊精/黄芩苷β-环糊精包合物脂质体制备

小编今天分享了磺丁基醚环糊精盐内水相/桂利嗪/EGF/吲哚美辛-环糊精/黄芩苷β-环糊精包合物脂质体的研究内容&#xff0c;和小编一起来看&#xff01; 黄芩苷β-环糊精(β-CD)包合物脂质体: 采用薄膜-超声分散法制备黄芩苷-CD包合物脂质体,并测定脂质体的粒径分布,Zeta电位以及…

灿芯股份冲刺科创板上市:计划募资6亿元,中芯国际、小米为股东

12月19日&#xff0c;灿芯半导体&#xff08;上海&#xff09;股份有限公司&#xff08;下称“灿芯股份”&#xff09;在上海证券交易所递交招股书&#xff0c;准备在科创板上市。本次冲刺科创板上市&#xff0c;灿芯股份计划募资6亿元&#xff0c;海通证券为其保荐机构。 招股…

赫夫曼树 | 实战演练(二)

&#x1f388; 作者&#xff1a;Linux猿 &#x1f388; 简介&#xff1a;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我&#xff0c;关注我&#xff0c;有问题私聊&#xff01; &…

高效空气过滤器检漏

广州特耐苏净化设备有限公司详细介绍&#xff1a;高效空气过滤器安装后的检漏 高效空气过滤器安装后的检漏是确认安装质量&#xff0c;检测高效空气过滤器送风口的整个面、过滤器的周边、过滤器外框和安装框架之间的密封处。检漏时&#xff0c;从过滤器的上风侧引入测试气溶胶…

北京理工大学汇编语言复习重点(可打印)

文章目录前言第一章&#xff1a;基础性能指标计算储存器原理第二章&#xff1a;微处理器管理模式CPU工作模式实模式保护模式虚拟8086模式&#xff08;V86模式&#xff09;寄存器概述GDTR&#xff08;Global Descriptor Table Registr&#xff09;全局描述符表寄存器LDTRIDTRTR内…

神仙级python入门教程(非常详细),从零基础入门到精通,从看这篇开始!

前言 一.初聊Python【文末有惊喜福利】 1.为什么要学习Python&#xff1f; 在学习Python之前&#xff0c;你不要担心自己没基础或“脑子笨”&#xff0c;我始终认为&#xff0c;只要你想学并为之努力&#xff0c;就能学好&#xff0c;就能用Python去做很多事情。在这个喧嚣的…