【原创】实现ChatGPT中Transformer模型之Encoder-Decoder

news2024/11/28 9:27:37

作者:黑夜路人

时间:2023年7月

Transformer Block (通用块)实现

看以上整个链路图,其实我们可以很清晰看到这心其实在Encoder环节里面主要是有几个大环节,每一层主要的核心作用如下:

  • Multi-headed self Attention(注意力机制层):通过不同的注意力函数并拼接结果,提高模型的表达能力,主要计算词与词的相关性和长距离次的相关性。
  • Normalization layer(归一化层):对每个隐层神经元进行归一化,使其特征值的均值为0,方差为1,解决梯度爆炸和消失问题。通过归一化,可以将数据压缩在一个合适范围内,避免出现超大或超小值,有利于模型训练,也能改善模型的泛化能力和加速模型收敛以及减少参数量的依赖。
  • Feed forward network(前馈神经网络):对注意力输出结果进行变换。
  • Another normalization layer(另一个归一化层):Weight Normalization用于对模型中层与层之间的权重矩阵进行归一化,其主要目的是解决梯度消失问题

注意力(Attention)实现

参考其他我们了解的信息可以看到里面核心的自注意力层等等,我们把每个层剥离看看核心这一层应该如何实现。

简单看看注意力的计算过程:

这张图所表示的大致运算过程是:

对于每个token,先产生三个向量query,key,value:

query向量类比于询问。某个token问:“其余的token都和我有多大程度的相关呀?”

key向量类比于索引。某个token说:“我把每个询问内容的回答都压缩了下装在我的key里”

value向量类比于回答。某个token说:“我把我自身涵盖的信息又抽取了一层装在我的value里”

注意力计算代码:

def attention(query: Tensor,
              key: Tensor,
              value: Tensor,
              mask: Optional[Tensor] = None,
              dropout: float = 0.1):
    """
    定义如何计算注意力得分
    参数:
        query: shape (batch_size, num_heads, seq_len, k_dim)
        key: shape(batch_size, num_heads, seq_len, k_dim)
        value: shape(batch_size, num_heads, seq_len, v_dim)
        mask: shape (batch_size, num_heads, seq_len, seq_len). Since our assumption, here the shape is
              (1, 1, seq_len, seq_len)
    返回:
        out: shape (batch_size, v_dim). 注意力头的输出。注意力分数:形状(seq_len,seq_ln)。
    """
    k_dim = query.size(-1)

    # shape (seq_len ,seq_len),row: token,col: token记号的注意力得分
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(k_dim)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e10)

    attention_score = F.softmax(scores, dim=-1)

    if dropout is not None:
        attention_score = dropout(attention_score)

    out = torch.matmul(attention_score, value)

    return out, attention_score  # shape: (seq_len, v_dim), (seq_len, seq_lem)

以图中的token a2为例:

它产生一个query,每个query都去和别的token的key做“某种方式”的计算,得到的结果我们称为attention score(即为图中的$$\alpha $$)。则一共得到四个attention score。(attention score又可以被称为attention weight)。

将这四个score分别乘上每个token的value,我们会得到四个抽取信息完毕的向量。

将这四个向量相加,就是最终a2过attention模型后所产生的结果b2。

整个这一层,我们通过代码来进行这个逻辑的简单实现:

class MultiHeadedAttention(nn.Module):
    def __init__(self,
                 num_heads: int,
                 d_model: int,
                 dropout: float = 0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # 假设v_dim总是等于k_dim
        self.k_dim = d_model // num_heads
        self.num_heads = num_heads
        self.proj_weights = clones(
            nn.Linear(d_model, d_model), 4)  # W^Q, W^K, W^V, W^O
        self.attention_score = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                mask: Optional[Tensor] = None):
        """
        参数:
            query: shape (batch_size, seq_len, d_model)
            key: shape (batch_size, seq_len, d_model)
            value: shape (batch_size, seq_len, d_model)
            mask: shape (batch_size, seq_len, seq_len). 由于我们假设所有数据都使用相同的掩码,因此这里的形状也等于(1,seq_len,seq-len)

        返回:
            out: shape (batch_size, seq_len, d_model). 多头注意力层的输出
        """
        if mask is not None:
            mask = mask.unsqueeze(1)
        batch_size = query.size(0)

        # 1) 应用W^Q、W^K、W^V生成新的查询、键、值
        query, key, value \
            = [proj_weight(x).view(batch_size, -1, self.num_heads, self.k_dim).transpose(1, 2)
                for proj_weight, x in zip(self.proj_weights, [query, key, value])]  # -1 equals to seq_len

        # 2) 计算注意力得分和out
        out, self.attention_score = attention(query, key, value, mask=mask,
                                              dropout=self.dropout)

        # 3) "Concat" 输出
        out = out.transpose(1, 2).contiguous() \
            .view(batch_size, -1, self.num_heads * self.k_dim)

        # 4) 应用W^O以获得最终输出
        out = self.proj_weights[-1](out)

        return out

Norm 归一化层实现

# 归一化层,标准化的计算公式
class NormLayer(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

前馈神经网络实现

class FeedForward(nn.Module):

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()

        # 设置 d_ff 缺省值为2048
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)

Encoder(编码器曾)实现

Encoder 就是将前面介绍的整个链路部分,全部组装迭代起来,完成将源编码到中间编码的转换。


class EncoderLayer(nn.Module):

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.attn = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x

class Encoder(nn.Module):

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = Embedder(d_model, vocab_size)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(EncoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, src, mask):
        x = self.embed(src)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, mask)
        return self.norm(x)

Decoder(解码器层)实现

Decoder部分和 Encoder 的部分非常的相似,它主要是把 Encoder 生成的中间编码,转换为目标编码。

class DecoderLayer(nn.Module):

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

        self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)

    def forward(self, x, e_outputs, src_mask, trg_mask):
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs,
                                           src_mask))
        x2 = self.norm_3(x)
        x = x + self.dropout_3(self.ff(x2))
        return x

class Decoder(nn.Module):

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        x = self.embed(trg)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, e_outputs, src_mask, trg_mask)
        return self.norm(x)

Transformer 实现

把整个链路结合,包括Encoder和Decoder,最终就能够形成一个Transformer框架的基本MVP实现。

class Transformer(nn.Module):

    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab)

    def forward(self, src, trg, src_mask, trg_mask):
        e_outputs = self.encoder(src, src_mask)
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = self.out(d_output)
        return output

代码说明

如果想要学习阅读整个代码,访问 black-transformer 项目,github访问地址:

GitHub - heiyeluren/black-transformer: black-transformer 是一个轻量级模拟Transformer模型实现的概要代码,用于了解整个Transformer工作机制

取代你的不是AI,而是比你更了解AI和更会使用AI的人!

##End##

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/763900.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flask 分页Demo

项目结构 app.py from flask import Flask, render_template, requestapp Flask(__name__)books [{title: Book 1, author: Author 1, year: 2020},{title: Book 2, author: Author 2, year: 2021},{title: Book 3, author: Author 3, year: 2022},{title: Book 4, author: …

《面试1v1》面试官让我讲一下Kafka的性能哪里好

🍅 作者简介:王哥,CSDN2022博客总榜Top100🏆、博客专家💪 🍅 技术交流:定期更新Java硬核干货,不定期送书活动 🍅 王哥多年工作总结:Java学习路线总结&#xf…

公钥加密之『迪菲–赫尔曼密钥交换』,颜色混合的把戏

前奏: 迪菲-赫尔曼密钥交换是一种安全协议。它可以让双方在完全没有对方任何预先信息的条件下通过不安全信道创建起一个密钥。这个密钥可以在后续的通讯中作为对称密钥来加密通讯内容。 颜色混合: 如图所示,蛇和老鼠要交换一个共享密钥&…

Springboot热部署相关功能

文章目录 前言一、Springboot如何在IDEA中开启热部署二、热部署的相关知识1.热部署的范围2.关闭热部署 前言 环境是Mac电脑下的IDEA 2023.1.X版本 如何在修改程序后自动进行加载修改后的程序而不是重启加载所有资源而更新程序,这就使用到了Springboot相关的热部署功…

120页商业银行企业级IT架构规划ppt

导读:原文《商业银行企业级IT架构规划ppt》(获取来源见文尾),本文精选其中精华及架构部分,逻辑清晰、内容完整,为快速形成售前方案提供参考。 完整版领取方式 完整版领取方式: 如需获取完整的电…

2023年7月17日,比较器,TreeMap底层,LinkedHashMap,Set接口

比较器 Comparator是外部比较器,用于比较来对象与对象之间的,两个对象进行比较,多用于集合排序 Comparable可以认为是一个内比较器,根据对象某一属性进行排序的。 1. 使用场景 ​ 内置比较器(Comparable)的…

Windows10下ChatGLM2-6B模型本地化安装部署教程图解

随着人工智能技术的不断发展,自然语言处理模型在研究和应用领域备受瞩目。ChatGLM2-6B模型作为其中的一员,以其强大的聊天和问答能力备受关注,并且最突出的优点是性能出色且轻量化。然而,通过云GPU部署安装模型可能需要支付相应的…

手把手带你创建微服务项目

1.先创建以下项目结构 2.在父项目中导入以下依赖 <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter</artifactId></dependency><!-- Web依赖 --><dependency>&l…

关于Java集合框架的总结

关于Java集合框架的总结 本篇文章先从整体介绍了java集合框架包含的接口和类&#xff0c;然后总结了集合框架中的一些基本知识和关键点&#xff0c;并结合实例进行简单分析。当我们把一个对象放入集合中后&#xff0c;系统会把所有集合元素都当成Object类的实例进行处理。从JDK…

分享四款导航页 个人主页html源码

一、开源免费&#xff0c;可以展示很多社交账号&#xff0c;也可以更换社交账号图标指向你的网站&#xff0c;上传后即可使用 https://wwwf.lanzout.com/ik7R912s031g 二、开源免费&#xff0c;不过部署稍微麻烦点 https://wwwf.lanzout.com/iCq2u12s02wb 三、适合做成导航页面…

Android性能优化篇[谷歌官方]

网上看到了个和Android性能优化相关的系列文章&#xff0c;觉的还不错&#xff0c;和大家分享下。 在Android领域&#xff0c;性能永远是一块大头。市场对这类人才的需求也是有增不减&#xff0c;而且薪资待遇也不错。如果大家想深入学习Android某个领域&#xff0c; 那性能这块…

190 → 169,50天瘦20斤随感

一头猪瘦二十斤没有人会在意&#xff0c;但一个人猛地瘦二十斤或许就会有意思~ 从五月底到7月中旬&#xff0c;大致50天瘦了21斤。本文大致从我自己的感想、方法、减肥前后的心态及身体变化等方面来给予你一些关键信息&#xff0c;希望对你有用吧。 当你发现自己真的在一斤一斤…

react CSS :last-child 最后一个下边框线如何去掉

需求&#xff1a;调用分类接口后&#xff0c;tab的最后一个border不要横线。 代码如下 逻辑是 i是否等于books数组的长度-1。 books.map((book, i) > { return( <View style{borderBottom:idx ! dictype.length - 1 && "1px solid #ECEFF7"} key…

CentOS7 mariadb10.x 安装

1、添加mariabd yum源 vi /etc/yum.repos.d/mariadb.repo [mariadb] name MariaDB baseurl https://mirrors.tuna.tsinghua.edu.cn/mariadb/yum/10.5/centos7-amd64/ gpgkey https://mirrors.tuna.tsinghua.edu.cn/mariadb/yum/RPM-GPG-KEY-MariaDB gpgcheck 12、建立yum…

EMC学习笔记(十六)射频PCB的EMC设计(三)

射频PCB的EMC设计&#xff08;三&#xff09; 1.布线1.1 阻抗控制2.2 转角1.3 微带线布线1.4 微带线耦合器1.5 微带线功分器1.6 微带线基本元件1.7 带状线布线1.8 射频信号走线两边包地铜皮 2.其他设计考虑 1.布线 1.1 阻抗控制 PCB信号走线的阻抗与板材的介电常数、PCB结构、…

【山河送书第二期】:《零基础学会Python编程(ChatGPT版》

【山河送书第二期】&#xff1a;《零基础学会Python编程&#xff08;ChatGPT版》 前言内容简介作者简介 前言 在过去的 5 年里&#xff0c;Python 已经 3 次获得 TIOBE 指数年度大奖&#xff0c;这得益于数据科学和人工智能领域的发展&#xff0c;使得 Python 变得异常流行&am…

电脑pdf怎么转换成word文档?你不知道的几种方法

在电脑上&#xff0c;我们经常需要查阅PDF文件&#xff0c;并且PDF文件在生活中应用广泛&#xff0c;可以保存许多重要内容。有时候我们需要将PDF文件转换为Word文档&#xff0c;以便对文件内容进行编辑。幸运的是&#xff0c;在电脑上将PDF转换为Word文档非常方便&#xff0c;…

nginx实现反向代理

Nginx Nginx (“engine x”) 是一个高性能的HTTP和反向代理服务器&#xff0c;特点是占有内存少&#xff0c;并发能力强。Nginx可以作为静态页面的web服务器&#xff0c;同时还支持CGI协议的动态语言&#xff0c;比如perl、php等。但是不支持java。Java程序只能通过与tomcat配合…

生命在于折腾——MacOS(Inter)渗透测试环境搭建

一、前景提要 之前使用的是2022款M2芯片的MacBook Air 13寸&#xff0c;不得不说&#xff0c;是真的续航好&#xff0c;轻薄&#xff0c;刚开始我了解到M芯片的底层是ARM架构&#xff0c;我觉得可以接受&#xff0c;虚拟机用的不多&#xff0c;但在后续的使用过程中&#xff0…

《MySQL》事务

文章目录 概念事务的操作属性&#xff08;aicd&#xff09; 概念 一组DML语句&#xff0c;这组语句要一次性执行完毕&#xff0c;是一个整体 为什么要有事务&#xff1f; 为应用层提供便捷服务 事务的操作 有一stu表 # 查看事务提交方式(默认是开启的) show variables like au…