深度学习实战24-人工智能(Pytorch)搭建transformer模型,真正跑通transformer模型,深刻了解transformer的架构

news2025/1/18 20:29:38

大家好,我是微学AI,今天给大家讲述一下人工智能(Pytorch)搭建transformer模型,手动搭建transformer模型,我们知道transformer模型是相对复杂的模型,它是一种利用自注意力机制进行序列建模的深度学习模型。相较于 RNN 和 CNN,transformer 模型更高效、更容易并行化,广泛应用于神经机器翻译、文本生成、问答等任务。

一、transformer模型

transformer模型是一种用于进行序列到序列(seq2seq)学习的深度神经网络模型,它最初被应用于机器翻译任务,但后来被广泛应用于其他自然语言处理任务,如文本摘要、语言生成等。

Transformer模型的创新之处在于,在不使用LSTM或GRU等循环神经网络(RNN)的情况下,实现了序列数据的建模,这使得它具有了与RNN相比的许多优点,如更好的并行性、更高的训练速度和更长的序列依赖性。

二、transformer模型的结构

Transformer模型的主要组成部分是自注意力机制(self-attention mechanism)和前馈神经网络(feedforward neural network)。在使用自注意力机制时,模型会根据输入序列中每个位置的信息,生成一个与序列长度相同的向量表示。这个向量表示很好地捕捉了输入序列中每个位置和其他位置之间的关系,从而为模型提供了一个更好的理解输入信息的方式。

在Transformer中,输入序列由多个编码器堆叠而成,在每个编码器中,自注意力机制和前馈神经网络形成了一个块,多个块组成了完整的编码器。为了保持序列的信息,Transformer还使用了一个注意力机制(attention mechanism)来将输入序列中每个位置的信息传递到输出序列中。

6270bb3884e6498f9930ed48bb8f8436.png

 Transformer模型包括部分:

词嵌入层:将每个单词映射到一个向量表示,这个向量表示被称为嵌入向量(embedding vector),词嵌入层也可以使用预训练的嵌入向量。

位置编码:由于Transformer模型没有循环神经网络,因此需要一种方式来处理序列中单词的位置信息。位置编码是一组向量,它们被添加到嵌入向量中,以便模型能够对序列中单词的位置进行编码。

多头自注意力机制:是Transformer模型的核心部分,可以将输入序列中每个位置的信息传递到其他位置,并生成一个与输入序列长度相同的向量表示。

前馈神经网络:在自注意力机制之后,使用一层前馈神经网络来给各个位置的表示添加非线性变换。

残差连接:在自注意力机制和前馈神经网络之间添加了残差连接,来捕捉序列中长距离依赖性。

规范化层:规范化层分为两种:1.在层的维度上进行归一化处理,即对每个样本的所有神经元进行计算,以该样本在所有神经元输出的均值和方差作为归一化的参数。2.在每个mini-batch中进行归一化处理,即对一个mini-batch中所有样本在同一维度上进行归一化处理,然后使用该维度上mini-batch的均值和方差作为归一化的参数。

编码器层:由多个(通常为6-12个)完全相同的块组成,每个块包含一个自注意力机制、一个前馈神经网络和残差连接,用于对输入序列进行编码。

解码器层:在翻译任务中,还需要使用解码器从已经编码的源语言序列中生成目标语言序列。

三、transformer模型的搭建

import math
import torch
import torch.nn as nn
import torch.optim as optim

# 位置编码类
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=0.1)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)

# transformer 模型搭建
class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers):
        super(TransformerModel, self).__init__()

        #词嵌入层
        self.embedding = nn.Embedding(ntoken, ninp)

        # 位置编码
        self.pos_encoder = PositionalEncoding(ninp)

        #编码器层
        encoder_layers = nn.TransformerEncoderLayer(ninp, nhead, nhid)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def init_weights(self):
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-init_range, init_range)

    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        x = self.decoder(x)
        return x

def data_gen(batch_size=20, seq_len=10, limit=500):
    for _ in range(limit):
        data = torch.randint(1, 10, (batch_size, seq_len))
        targets = data * 2
        yield data, targets

if __name__ == "__main__":
    ntokens = 20
    emsize = 200
    nhead = 2
    nhid = 200
    nlayers = 2
    model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers)

    criterion = nn.CrossEntropyLoss()
    lr = 0.001
    optimizer = optim.Adam(model.parameters(), lr=lr)

    num_epochs = 5
    for epoch in range(num_epochs):
        for i, (data, targets) in enumerate(data_gen()):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output.view(-1, ntokens), targets.view(-1))
            loss.backward()
            optimizer.step()

            if i % 50 == 0:
                print(f"Epoch: {epoch}, Loss: {loss.item():.6f}")

    # Testing on some data
    test_data = torch.tensor([[3, 6, 9], [2, 4, 6]])
    print("Test input:", test_data)
    test_output = torch.argmax(model(test_data), dim=2)
    print("Test output:", test_output)

为了让大家更容易掌握,这里编码器和解码器过程直接利用nn.TransformerEncoderLayer,根据transformer编码器的结构中其实是包括:多头自注意力机制,前馈神经网络,残差连接、规范化层的。我们在项目在可以直接引用,进行调整。

为了让每个人跑通Transformer模型,我将输入序列中的每个整数乘以2。数据生成器data_gen函数生成用于训练的随机序列。这里设置了一个较小的词汇表大小为20,表示单词无需太多。

在训练完成后,给出一个简单的测试样例。测试数据包括两个序列[3, 6, 9]和[2, 4, 6]。模型的输出是输入整数加倍的序列。给大家举这个例子只是为了展示如何搭建一个Transformer模型,并在实际任务上完成简单的训练,让我们真正零距离地接触Transformer模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/435011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【微信小程序】数据监听器,纯数据字段

一、数据监听器 1.1 什么是数据监听器 数据监听器用于 监听和响应任何属性和数据字段的变化,从而执行特定的操作 。它的作用类似于 vue 中 的 watch 侦听器。在小程序组件中, 在componets中新建一个test2文件夹在文件夹里新建component 在app.json …

C学习笔记3

1、将一个整数转换成二进制形式,就是其原码。(通俗的理解,原码就是一个整数本来的二进制形式。) 例如short a 6;a 的原码就是0000 0000 0000 0110 2、反码就区分正负数,因为正负数的反码是不一样的,正数…

2023年MathorCup数模D题赛题解题思路

MathorCup俗称妈杯,是除了美赛国赛外参赛人数首屈一指的比赛,而我们的妈杯今天也如期开赛。今年的妈杯难度,至少在我看来应该是2023年截至目前来讲最难的一场比赛。问题的设置、背景的选取等各个方面都吐露着我要难死你们的想法。难度是恒定的…

Euro-NCAP 2030愿景

每隔五年,Euro NCAP都会将利益相关者聚集在一起,研究当前的汽车技术现状,预测未来几年可能的挑战,并确定未来的机会所在。讨论的结果是该组织的未来发展方向和明确的未来愿景:Euro NCAP 2030年愿景。 2020年初,Euro NCAP开始制定一套新的战略目标,打算在下一年公布其《2…

STM-32:SPI通信协议/W25Q64简介—软件SPI读写W25Q64

目录 一、SPI简介1.1电路模式1.2通信原理1.3SPI时序基本单元1.3.1起始和终止1.3.2交换字节 二、W25Q642.1W25Q64简介2.2W25Q64硬件电路2.3W25Q64框图2.4Flash操作注意事项 三、软件SPI读写W25Q643.1接线图3.2程序代码 一、SPI简介 SPI是串行外设接口(Serial Periph…

Spring Boot异步任务、异步消息

目录 1.异步任务 1.1.概述 1.2.使用 2.异步消息 2.1.概述 2.2.使用 1.异步任务 1.1.概述 举一个例子,我现在有一个网上商城,客户在界面点击下单后,后台需要完成两步: 1.创建客户订单 2.发短信通知客户订单号 这里面第2…

selenium 连接已经打开的chrome浏览器 MAC

selenium 连接已经打开的chrome浏览器 MAC 一,前言 今天在爬取chatGPT的谷歌插件的prompts的时候,发现绕不过他的反爬机制,失败111,所以想用连接已打开的chatGPT页面进行控制 二,具体步骤 1,添加环境变…

Android入门

一、Android系统架构 Android大致可以分为4层架构:Linux内核层、系统运行库层、应用框架层和应用层 1.1Linux内核层 Android系统是基于Linux内核的,这一层为Android设备的各种硬件提供了如显示、音频、照相机、蓝牙、Wi-Fi等底层的驱动。 1.2系统运行层…

2023MathorCup 高校数学建模挑战赛D题思路解析

如下为MathorCup 高校数学建模挑战赛D题思路解析: D 题 航空安全风险分析和飞行技术评估问题 飞行安全是民航运输业赖以生存和发展的基础。随着我国民航业的快速发展,针对飞行安全问题的研究显得越来越重要。2022 年 3 月 21 日,“3.21”空难…

如何使用vim的插件Ctags查看Linux源码

一.ubuntu下安装Linux内核源码 (1).查看自己的内核版本 (2).查看源内的内核源码类表 (3).下载源码 (4).进入/usr/src (5).解压下载的文件到用户主 二.安装vim插件Ctags和使用 插件的介绍 Ctags工具是用来遍历源代码文件生成tags文件,这些tags文件能被编辑器或其它工…

2023年的深度学习入门指南(4) - 在你的电脑上运行大模型

2023年的深度学习入门指南(4) - 在你的电脑上运行大模型 上一篇我们介绍了大模型的基础,自注意力机制以及其实现Transformer模块。因为Transformer被PyTorch和TensorFlow等框架所支持,所以我们只要能够配置好框架的GPU或者其他加速硬件的支持&#xff0…

同为科技(TOWE)防雷科普篇1—雷电灾害认识与雷电预警信号解读

前 言 雷电是自然界最为壮观的大气现象之一。其强大的电流、炙热的高温、猛烈的冲击波以及强烈的电磁辐射等物理效应能够在瞬间产生巨大的破坏作用,常常导致人员伤亡,击毁建筑物、供配电系统、通信设备,造成计算机信息系统中断,引…

电风扇出口欧美CE/UL507认证办理

电风扇简称电扇,是一种利用电动机驱动扇叶旋转,来达到使空气加速流通的家用电器,主要用于清凉解暑和流通空气。风扇主要由扇头、叶片、网罩和控制装置等部件组成。电风扇的主要部件是:交流电动机。其工作原理是:通电线圈在磁场中受力而转动。…

Streamlit 中函数多次进入的问题

Streamlit 函数多次进入的问题 Streamlit 学习的背景重要案例心得踩坑或注意点 Streamlit 学习的背景 最近在学习ai相关的知识,同时需要做一些方便使用的web网页。 例如:调用chatGPT的api,做对话窗。 之前用过gradio, 但是发现在手机端上&am…

NestJS:理解ORM(Object Relational Mapping)

一、理解ORM(Object Relation Mapping) ORM是对象—关系映射(Object/Relation Mapping,ORM)是为了解决面向对象与关系数据库存在的互不匹配现象而产生的技术。业务实体在内存中表现为对象,在数据库中表现为关系数据。ORM通过使用描…

【网络技术】HULK攻击实验

一、重要声明 请勿攻击公网!请勿攻击公网!请勿攻击公网! 一切责任自负!一切责任自负!一切责任自负! 二、工具介绍 HULK(Http Unbearable Load King,字面意思是“不能承受的HTTP&a…

项目三:双人骰子

项目三:双人骰子 文章目录 项目三:双人骰子一、导入(5分钟)学习目的 二、新授(65分钟)1.预展示结果(5分钟)2.本节课所用的软硬件(5分钟)3.硬件介绍(1分钟)4.图形化块介绍(1分钟)5.单个模块的简单使用(1分钟)6.双人骰子编程逻辑分析(30分钟)7.…

Faster RCNN系列——RoI Pooling与全连接层

在RPN网络中,已经计算过一次类别真值与类别预测值、偏移量真值与偏移量预测值之间的损失值,但这里的类别预测值只包含两类:前景和背景。 RPN网络输出的Proposal作为生成的区域,继续在后续的RoI Pooling和全连接层中进行分类与回归…

企业电子招投标系统源码之电子招投标系统建设的重点和未来趋势

项目说明 随着公司的快速发展,企业人员和经营规模不断壮大,公司对内部招采管理的提升提出了更高的要求。在企业里建立一个公平、公开、公正的采购环境,最大限度控制采购成本至关重要。符合国家电子招投标法律法规及相关规范,以及…

性能优化3-分帧寻路+寻路任务统一管理

前言 当项目里的地图越来越大,一些性能上的问题开始逐渐出现,比如寻路。玩家在操控角色移动的时候,指引需要实时更新,同时一些npc也需要做移动,容易出现cpu占用率短时间过高,甚至掉帧的情况。 去年底的时候…