人工智能(Pytorch)搭建transformer模型,真正跑通transformer模型,深刻了解transformer的架构

news2024/9/20 16:57:04

大家好,我是微学AI,今天给大家讲述一下人工智能(Pytorch)搭建transformer模型,手动搭建transformer模型,我们知道transformer模型是相对复杂的模型,它是一种利用自注意力机制进行序列建模的深度学习模型。相较于 RNN 和 CNN,transformer 模型更高效、更容易并行化,广泛应用于神经机器翻译、文本生成、问答等任务。

一、transformer模型

transformer模型是一种用于进行序列到序列(seq2seq)学习的深度神经网络模型,它最初被应用于机器翻译任务,但后来被广泛应用于其他自然语言处理任务,如文本摘要、语言生成等。

Transformer模型的创新之处在于,在不使用LSTM或GRU等循环神经网络(RNN)的情况下,实现了序列数据的建模,这使得它具有了与RNN相比的许多优点,如更好的并行性、更高的训练速度和更长的序列依赖性。

二、transformer模型的结构

Transformer模型的主要组成部分是自注意力机制(self-attention mechanism)和前馈神经网络(feedforward neural network)。在使用自注意力机制时,模型会根据输入序列中每个位置的信息,生成一个与序列长度相同的向量表示。这个向量表示很好地捕捉了输入序列中每个位置和其他位置之间的关系,从而为模型提供了一个更好的理解输入信息的方式。

在Transformer中,输入序列由多个编码器堆叠而成,在每个编码器中,自注意力机制和前馈神经网络形成了一个块,多个块组成了完整的编码器。为了保持序列的信息,Transformer还使用了一个注意力机制(attention mechanism)来将输入序列中每个位置的信息传递到输出序列中。

6270bb3884e6498f9930ed48bb8f8436.png

 Transformer模型包括部分:

词嵌入层:将每个单词映射到一个向量表示,这个向量表示被称为嵌入向量(embedding vector),词嵌入层也可以使用预训练的嵌入向量。

位置编码:由于Transformer模型没有循环神经网络,因此需要一种方式来处理序列中单词的位置信息。位置编码是一组向量,它们被添加到嵌入向量中,以便模型能够对序列中单词的位置进行编码。

多头自注意力机制:是Transformer模型的核心部分,可以将输入序列中每个位置的信息传递到其他位置,并生成一个与输入序列长度相同的向量表示。

前馈神经网络:在自注意力机制之后,使用一层前馈神经网络来给各个位置的表示添加非线性变换。

残差连接:在自注意力机制和前馈神经网络之间添加了残差连接,来捕捉序列中长距离依赖性。

规范化层:规范化层分为两种:1.在层的维度上进行归一化处理,即对每个样本的所有神经元进行计算,以该样本在所有神经元输出的均值和方差作为归一化的参数。2.在每个mini-batch中进行归一化处理,即对一个mini-batch中所有样本在同一维度上进行归一化处理,然后使用该维度上mini-batch的均值和方差作为归一化的参数。

编码器层:由多个(通常为6-12个)完全相同的块组成,每个块包含一个自注意力机制、一个前馈神经网络和残差连接,用于对输入序列进行编码。

解码器层:在翻译任务中,还需要使用解码器从已经编码的源语言序列中生成目标语言序列。

三、transformer模型的搭建

import math
import torch
import torch.nn as nn
import torch.optim as optim

# 位置编码类
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=0.1)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)

# transformer 模型搭建
class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers):
        super(TransformerModel, self).__init__()

        #词嵌入层
        self.embedding = nn.Embedding(ntoken, ninp)

        # 位置编码
        self.pos_encoder = PositionalEncoding(ninp)

        #编码器层
        encoder_layers = nn.TransformerEncoderLayer(ninp, nhead, nhid)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def init_weights(self):
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-init_range, init_range)

    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        x = self.decoder(x)
        return x

def data_gen(batch_size=20, seq_len=10, limit=500):
    for _ in range(limit):
        data = torch.randint(1, 10, (batch_size, seq_len))
        targets = data * 2
        yield data, targets

if __name__ == "__main__":
    ntokens = 20
    emsize = 200
    nhead = 2
    nhid = 200
    nlayers = 2
    model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers)

    criterion = nn.CrossEntropyLoss()
    lr = 0.001
    optimizer = optim.Adam(model.parameters(), lr=lr)

    num_epochs = 5
    for epoch in range(num_epochs):
        for i, (data, targets) in enumerate(data_gen()):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output.view(-1, ntokens), targets.view(-1))
            loss.backward()
            optimizer.step()

            if i % 50 == 0:
                print(f"Epoch: {epoch}, Loss: {loss.item():.6f}")

    # Testing on some data
    test_data = torch.tensor([[3, 6, 9], [2, 4, 6]])
    print("Test input:", test_data)
    test_output = torch.argmax(model(test_data), dim=2)
    print("Test output:", test_output)

为了让大家更容易掌握,这里编码器和解码器过程直接利用nn.TransformerEncoderLayer,根据transformer编码器的结构中其实是包括:多头自注意力机制,前馈神经网络,残差连接、规范化层的。我们在项目在可以直接引用,进行调整。

为了让每个人跑通Transformer模型,我将输入序列中的每个整数乘以2。数据生成器data_gen函数生成用于训练的随机序列。这里设置了一个较小的词汇表大小为20,表示单词无需太多。

在训练完成后,给出一个简单的测试样例。测试数据包括两个序列[3, 6, 9]和[2, 4, 6]。模型的输出是输入整数加倍的序列。给大家举这个例子只是为了展示如何搭建一个Transformer模型,并在实际任务上完成简单的训练,让我们真正零距离地接触Transformer模型。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/417218.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构Java】--图、BFS、DFS、拓扑结构

目录 一、图(Graph) 1.概念 2.有向图 3.出度、入度 4.无向图 5.简单图、多重图 6.无向完全图 7.有向完全图 8.有权图 9.连通图 10.连通分量(无向图) 11.强连通图(有向图) 12.强连通分量 13.邻接矩…

微服务架构-服务网关(Gateway)-权限认证(分布式session替代方案)

权限认证-分布式session替代方案 前面我们了解了Gateway组件的过滤器,这一节我们就探讨一下Gateway在分布式环境中的一个具体用例-用户鉴权。 1、传统单应用的用户鉴权 从我们开始学JavaEE的时候,就被洗脑式灌输了一种权限验证的标准做法,…

Adobe全新AI工具引关注,Adobe firefly助力创作更高效、更有创意

原标题:Adobe全新AI工具引关注,Adobe firefly(萤火虫)助力创作更高效、更有创意。 以ChatGPT为首的生成式AI、AIGC等工具的战局正如火如荼的进行中..... 除了微软、百度的聊天机器人和一些初创公司的AI画图工具令人惊艳&#xff…

Greenplum数据库执行器——PartitionSelector执行节点

为了能够对分区表有优异的处理能力,对于查询优化系统来说一个最基本的能力就是做分区裁剪partition pruning,将query中并不涉及的分区提前排除掉。如下执行计划所示,由于单表谓词在parititon key上,在优化期间即可确定哪些可以分区…

003:Mapbox GL设定不同的投影方式

第003个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+mapbox中设定不同的投影方式 。默认情况下为Mercator投影,或者设置为null或者undefined时候,显示为Mercator投影。 直接复制下面的 vue+mapbox源代码,操作2分钟即可运行实现效果 文章目录 示例效果配置方式示例源…

【分享】维格表集成易聊实现线索自动化,减少流失率

公司•介绍 北京某职业教育公司专注行业发展、国际就业、留学、移民咨询。秉承专业性至上的原则,与行业内专家、高等学府以及产业集团合作,并邀请各领域专家组建了强大的专委会团队,为公司的业务开展提供专业性支持。 客户•遇到的问题 作为…

【Java面试八股文宝典之MySQL篇】备战2023 查缺补漏 你越早准备 越早成功!!!——Day23

大家好,我是陶然同学,软件工程大三即将实习。认识我的朋友们知道,我是科班出身,学的还行,但是对面试掌握不够,所以我将用这100多天更新Java面试题🙃🙃。 不敢苟同,相信大…

用Spring Doc代替Swagger

1 OpenApi OpenApi 是一个业界的 API 文档标准,是一个规范,这个规范目前有两大实现,分别是: SpringFoxSpringDoc 其中 SpringFox 其实也就是我们之前所说的 Swagger,SpringDoc 则是我们今天要说的内容。 OpenApi 就…

苹果智能戒指专利曝光,Find My技术加持不易丢

根据美国商标和专利局(USPTO)公示的清单,苹果近日获得了一项“智能戒指”相关的设计专利,编号为“US 11625098 B2”。 这款智能戒指专利主要服务于增强现实(AR)或者虚拟现实(VR)场…

leetcodeTmp

39. 组合总和 39. 组合总和 DFS排列&#xff1a;每个元素可选0次&#xff0c;1次以及多次 public List<List<Integer>> combinationSum(int[] candidates, int target) {//Arrays.sort(candidates);//注释了也能通过this.candidates candidates;ans.clear();co…

Omniverse Replicator 入门

OmniverseReplicator入门 Omniverse Replicator 作为 Omniverse Kit 扩展创建&#xff0c;并通过 Omniverse Code 方便地分发。 要使用复制器&#xff0c;您需要下载可在此处找到的 Omniverse 启动器。 有关 Omniverse 启动器的更多详细信息&#xff0c;请查看此视频。 使用…

kafaka学习

kafaka 消息队列&#xff1a;通常用来解决一个进程内&#xff0c;多线程环境下&#xff0c;资源竞争的问题&#xff1b;但是消息队列的锁的粒度太大了&#xff0c;需要进行拆分 消息队列中间组件 一个进程中&#xff0c;同时存在生产者、消费者、消息队列&#xff0c;在分布…

网络文件传输防止篡改-校验工具(md5sum)的使用

说明 MD5报文摘要算法&#xff08;Message-Digest Algorithm 5&#xff09;常常被用来验证网络文件传输的完整性&#xff0c;防止文件被人篡改。此算法对任意长度的信息逐位进行计算&#xff0c;产生一个二进制长度为128位&#xff08;十六进制长度就是32位&#xff09;的“指…

wordpres漏洞扫描器——wpscan

WordPress 使用PHP语言开发的博客平台 WordPress是使用PHP语言开发的博客平台&#xff0c;用户可以在支持PHP和MySQL数据库的服务器上架设属于自己的网站。也可以把 WordPress当作一个内容管理系统&#xff08;CMS&#xff09;来使用。 WordPress是一款个人博客系统&#xff0c…

手把手教你在linux中部署stable-diffusion-webui

stable-diffusion-webui是什么就不用多说了&#xff0c;以下是安装步骤&#xff0c;我以linux系统为例介绍&#xff0c;windows系统大同小异&#xff0c;安装期间没有用到梯子&#xff0c;安装目录/opt/stable-diffusion-webui/。 1.安装Anaconda stable-diffusion-webui要求p…

2023年小红书用户种草转化新路径

随着消费者对商品选择性提高&#xff0c;品牌转化链路随之被拉长&#xff0c;在投放操盘上竞争也愈发激烈&#xff0c;本期和大家聊聊如何在关键节点上引领用户决策&#xff0c;完成用户种草转化。种草链路拉长品牌发力点在何处&#xff1f; 基于平台用户的洞察分析&#xff0c…

ESXi安装CentOS

ESXi安装 参考&#xff1a;https://blog.csdn.net/tongxin_tongmeng/article/details/129466704 CentOS安装 镜像&#xff1a;http://mirrors.aliyun.com/centos/7/isos/x86_64-->CentOS-7-x86_64-DVD-2009.iso CentOS配置 FinalShell连接 ESXi简介 1.ESXi是由VMware公司…

leedcode刷题(6)

各位朋友们大家好&#xff0c;今天是我的leedcode刷题系列的第六篇。这篇文章将与队列方面的知识相关&#xff0c;因为这些知识用C语言实现较为复杂&#xff0c;所以我们就只使用Java来实现。 文章目录设计循环队列题目要求用例输入提示做题思路代码实现用栈实现队列题目要求用…

【回溯法】-----求一个集合的子集问题

leetcode78 subsetsleetcode 78 问题原文ExampleConstraints:解决思路回溯法代码实现leetcode 78 问题原文 Given an integer array nums of unique elements, return all possible subsets (the power set). The solution set must not contain duplicate subsets. Return t…

银行数仓分层架构

一、为什么要对数仓分层 实现好分层架构&#xff0c;有以下好处&#xff1a; 1清晰数据结构&#xff1a; 每一个数据分层都有对应的作用域&#xff0c;在使用数据的时候能更方便的定位和理解。 2数据血缘追踪&#xff1a; 提供给业务人员或下游系统的数据服务时都是目标数据&…