transformer实现词性标注

news2024/10/6 14:33:02

1、self-attention

1.1、self-attention结构图

上图是 Self-Attention 的结构,在计算的时候需要用到矩阵 Q(查询), K(键值), V(值)。在实际中,Self-Attention 接收的是输入(单词的表示向量 x组成的矩阵 X) 或者上一个 Encoder block 的输出。而 QK正是通过 Self-Attention 的输入进行线性变换得到的。

1.2 Q,K,V的计算

Self-Attention 的输入用矩阵 X进行表示,则可以使用线性变阵矩阵 WQWKWV 计算得到 QKV。计算如下图所示,注意 X, Q, K, V每一行都表示一个单词

 3.3 Self-Attention 的输出

得到矩阵 QKV之后就可以计算出 Self-Attention 的输出了,计算的公式如下: 

公式中计算矩阵 Q和 K 每一行向量的内积,为了防止内积过大,因此除以 dk 的平方根。乘以 K 的转置后,得到的矩阵行列数都为 n,n 为句子单词数,这个矩阵可以表示单词之间的 attention 强度。下图为 乘以 的转置,1234 表示的是句子中的单词。

得到 QK^{T} 之后,使用 Softmax 计算每一个单词对于其他单词的 attention 系数,公式中的 Softmax 是对矩阵的每一行进行 Softmax,即每一行的和都变为 1。

对矩阵每一行进行softmax
​​​​​

 

得到 Softmax 矩阵之后可以和 V相乘,得到最终的输出 Z

self-attention输出

 上图中 Softmax 矩阵的第 1 行表示单词 1 与其他所有单词的 attention 系数,最终单词 1 的输出 Z1 等于所有单词 i 的值 Vi 根据 attention 系数的比例加在一起得到,如下图所示:

Zi的计算方法

class Attention(nn.Module):
    def __init__(self, input_n:int,hidden_n:int):
        super().__init__()
        self.hidden_n = hidden_n
        self.input_n=input_n

        self.W_q = torch.nn.Linear(input_n, hidden_n)
        self.W_k = torch.nn.Linear(input_n, hidden_n)
        self.W_v = torch.nn.Linear(input_n, hidden_n)

        

    def forward(self, Q, K, V, mask=None):
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)
        
        attention_scores = torch.matmul(Q, K.transpose(-2, -1))
        attention_weights = softmax(attention_scores)
        output = torch.matmul(attention_weights, V)
        return output
        

2、multi-head attention

       

从上图可以看到 Multi-Head Attention 包含多个 Self-Attention 层,首先将输入 X分别传递到 h 个不同的 Self-Attention 中,计算得到 h 个输出矩阵 Z。下图是 h=8 时候的情况,此时会得到 8 个输出矩阵 Z

多个self-attention

 得到 8 个输出矩阵 Z1 到 Z8 之后,Multi-Head Attention 将它们拼接在一起 (Concat),然后传入一个 Linear层,得到 Multi-Head Attention 最终的输出 Z

Multi-Head Attention的输出

 可以看到 Multi-Head Attention 输出的矩阵 Z与其输入的矩阵 X 的维度是一样的。

class MultiHeadAttention(nn.Module):
    def __init__(self,hidden_n:int, h:int = 2):
        """
        hidden_n: hidden dimension
        h: number of heads
        """
        super().__init__()
        
        embed_size=hidden_n
        heads=h

        self.embed_size = embed_size
        self.heads = heads
        # 每个head的处理的特征个数
        self.head_dim = embed_size // heads
 
        # 如果不能整除就报错
        assert (self.head_dim * self.heads == self.embed_size), 'embed_size should be divided by heads'
 
        # 三个全连接分别计算qkv
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
 
        # 输出层
        self.fc_out = nn.Linear(self.head_dim * self.heads, embed_size)


    def forward(self, Q, K, V, mask=None):

        query,values,keys=Q,K,V

        N = query.shape[0]  # batch
        # 获取每个句子有多少个单词
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
 
        # 维度调整 [b,seq_len,embed_size] ==> [b,seq_len,heads,head_dim]
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
 
        # 对原始输入数据计算q、k、v
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
 
        # 爱因斯坦简记法,用于张量矩阵运算,q和k的转置矩阵相乘
        # queries.shape = [N, query_len, self.heads, self.head_dim]
        # keys.shape = [N, keys_len, self.heads, self.head_dim]
        # energy.shape = [N, heads, query_len, keys_len]
        energy = torch.einsum('nqhd, nkhd -> nhqk', [queries, keys])
 
        # 是否使用mask遮挡t时刻以后的所有q、k
        if mask is not None:
            # 将mask中所有为0的位置的元素,在energy中对应位置都置为 -1*10^10
            energy = energy.masked_fill(mask==0, torch.tensor(-1e10))
 
        # 根据公式计算attention, 在最后一个维度上计算softmax
        attention = torch.softmax(energy/(self.embed_size**(1/2)), dim=3)
        
        # 爱因斯坦简记法矩阵元素,其中query_len == keys_len == value_len
        # attention.shape = [N, heads, query_len, keys_len]
        # values.shape = [N, value_len, heads, head_dim]
        # out.shape = [N, query_len, heads, head_dim]
        out = torch.einsum('nhql, nlhd -> nqhd', [attention, values])
        
        # 维度调整 [N, query_len, heads, head_dim] ==> [N, query_len, heads*head_dim]
        out = out.reshape(N, query_len, self.heads*self.head_dim)
 
        # 全连接,shape不变
        output = self.fc_out(out)


        return output

3、transformer block

3.1 encoder blockg构架图

 上图红色部分是 Transformer 的 Encoder block 结构,可以看到是由 Multi-Head Attention, Add & Norm, Feed Forward, Add & Norm 组成的。刚刚已经了解了 Multi-Head Attention 的计算过程,现在了解一下 Add & Norm 和 Feed Forward 部分。

3.2 Add & Norm

Add & Norm 层由 Add 和 Norm 两部分组成,其计算公式如下:

 其中 X表示 Multi-Head Attention 或者 Feed Forward 的输入,MultiHeadAttention(X) 和 FeedForward(X) 表示输出 (输出与输入 X 维度是一样的,所以可以相加)。

Add指 X+MultiHeadAttention(X),是一种残差连接,通常用于解决多层网络训练的问题,可以让网络只关注当前差异的部分,在 ResNet 中经常用到。

残差连接

 Norm指 Layer Normalization,通常用于 RNN 结构,Layer Normalization 会将每一层神经元的输入都转成均值方差都一样的,这样可以加快收敛。

3.3 Feed Forward

Feed Forward 层比较简单,是一个两层的全连接层,第一层的激活函数为 Relu,第二层不使用激活函数,对应的公式如下。

Feed Forward

 X是输入,Feed Forward 最终得到的输出矩阵的维度与 X 一致。

class TransformerBlock(nn.Module):
    def __init__(self, hidden_n:int, h:int = 2):
        """
        hidden_n: hidden dimension
        h: number of heads
        """
        super().__init__()
        embed_size=hidden_n
        heads=h
        # 实例化自注意力模块
        self.attention =MultiHeadAttention (embed_size, heads)
 
        # muti_head之后的layernorm
        self.norm1 = nn.LayerNorm(embed_size)
        # FFN之后的layernorm
        self.norm2 = nn.LayerNorm(embed_size)
 
        forward_expansion=1
        dropout=0.2

        # 构建FFN前馈型神经网络
        self.feed_forward = nn.Sequential(
            # 第一个全连接层上升特征个数
            nn.Linear(embed_size, embed_size * forward_expansion),
            # relu激活
            nn.ReLU(),
            # 第二个全连接下降特征个数
            nn.Linear(embed_size * forward_expansion, embed_size)
        )
 
        # dropout层随机杀死神经元
        self.dropout = nn.Dropout(dropout)


    def forward(self, value, key, query, mask=None):
        attention = self.attention(value, key, query, mask)
        # 输入和输出做残差连接
        x = query + attention
        # layernorm标准化
        x = self.norm1(x)
        # dropout
        x = self.dropout(x)

         # FFN
        ffn = self.feed_forward(x)
        # 残差连接输入和输出
        forward = ffn + x
        # layernorm + dropout
        out = self.dropout(self.norm2(forward))
 
        return out

transformer

import torch.nn as nn
class Transformer(nn.Module):
    def __init__(self,vocab_size, emb_n: int, hidden_n: int, n:int =3, h:int =2):
        """
        emb_n: number of token embeddings
        hidden_n: hidden dimension
        n: number of layers
        h: number of heads per layer
        """

        embedding_dim=emb_n
        
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(vocab_size,embedding_dim)
        
        self.layers=nn.ModuleList(
            [TransformerBlock(hidden_n,h) for _ in range(n)    
            ]

        )
        


    def forward(self,x):
        N,seq_len=x.shape

        out=self.embeddings(x)
        for layer in self.layers:
            out=layer(out,out,out)

        return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/952294.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是跨域(cross-origin)请求,如何解决跨域问题?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 跨域请求和跨域问题⭐ 解决跨域问题的方法1. CORS(跨域资源共享)2. JSONP(JSON with Padding)3. 代理服务器4. WebSocket5. 使用服务器中继 ⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff…

[C++]构造与毁灭:深入探讨C++中四种构造函数与析构函数

个人主页:北海 🎐CSDN新晋作者 🎉欢迎 👍点赞✍评论⭐收藏✨收录专栏:C/C🤝希望作者的文章能对你有所帮助,有不足的地方请在评论区留言指正,大家一起学习交流!&#x1f9…

ModaHub魔搭社区:星环科技致力于打造更优越的向量数据库

在数字化时代,数据成为了最重要的资源之一。随着人工智能、大数据等技术的不断发展,向量数据库成为了处理这类数据的关键工具。星环科技作为一家专注于数据存储和管理技术的公司,其重要目标就是将向量数据库打造得更为优越。 在星环科技,有一个专注于向量数据库的团队。这个…

当面试遇到难题:解决棘手问题的三大策略

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

【C++】关于fixed和setprecision的学习和介绍

前言 在学习swap函数的时候&#xff0c;偶然了解到了fixed和setprecision&#xff0c;这两条控制语句&#xff0c;在了解了之后&#xff0c;觉得很有用&#xff0c;于是写一篇文章来介绍fixed和setprecision这两条控制语句 fixed控制输出形式 使用fixed语句需要包含<ioma…

Python2021年06月Python二级 -- 编程题解析

题目一 没有重复数字的两位数统计 编写一段程序&#xff0c;实现下面的功能: (1)检查所有的两位数; (2)程序自动分析两位数上的个位与十位上的数字是否相同&#xff0c;相同则剔除&#xff0c; 不同则保留(例:12符合本要求&#xff0c;个位是2&#xff0c;十位是1&#xff0c;两…

第7节——渲染列表+Key作用

一、列表渲染 我们再react中如果渲染列表&#xff0c;一般使用map方法进行渲染 import React from "react";export default class LearnJSX2 extends React.Component {state {infos: [{name: "张三",age: 18,},{name: "李四",age: 20,},{nam…

C#,数值计算——Midexp的计算方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { public class Midexp : Midpnt { public new double func(double x) { return funk.funk(-Math.Log(x)) / x; } public Midexp(UniVarRealValueFun funcc, double aa, d…

跳槽面试:如何转换工作场所而不失去优势

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

16 个前端安全知识

16 个前端安全知识 去年 security course 上的是 React&#xff0c;然后学了一些 一些 React 项目中可能存在的安全隐患&#xff0c;今年看了一下列表&#xff0c;正好看到了前端也有更新&#xff0c;所以就把这个补上了。 一个非常好学习各种安全隐患的机构是 https://owasp…

《Python入门到精通》webbrowser模块详解,Python webbrowser标准库,Python浏览器控制工具

「作者主页」&#xff1a;士别三日wyx 「作者简介」&#xff1a;CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」&#xff1a;小白零基础《Python入门到精通》 webbrowser模块详解 1、常用操作2、函数大全webbrowser.open() 打开浏览器webbro…

Django请求的生命周期

Django请求的生命周期是指: 当用户在浏览器上输入URL到用户看到网页的这个时间段内&#xff0c;Django后台所发生的事情。 直白的来说就是当请求来的时候和请求走的阶段中&#xff0c;Django的执行轨迹。 一个完整的Django生命周期: 用户从客户端发出一条请求以后&#xff…

Kubernetes技术--Kubernetes架构组件以及核心概念

1.Kubernetes集群架构组件 搭建一个Kubernetes环境集群,其架构如下所示: 内容详解: Master:控制节点,指派任务、决策 Node:工作节点,实际干活的。 Master组件内容:

Python小知识 - 如何使用Python的Flask框架快速开发Web应用

如何使用Python的Flask框架快速开发Web应用 现在越来越多的人把Python作为自己的第一语言来学习&#xff0c;Python的简洁易学的语法以及丰富的第三方库让人们越来越喜欢上了这门语言。本文将介绍如何使用Python的Flask框架快速开发Web应用。 Flask是一个使用Python编写的轻量级…

迈向无限可能, ATEN宏正领跑设备切换行业革命!

随着互联网在各个领域的广泛应用,线上办公这一不受时间和地点制约、不受发展空间限制的办公模式开始广受追捧,预示着经济的发展正朝着新潮与活跃的方向不断跃进。当然,在互联网时代的背景下,多线程、多设备的线上办公模式也催生了许多问题:多设备间无法进行高速传输、切换;为保…

Mybatis 日志(JDK Log)

上一篇我们介绍了Mybatis中的参数&#xff0c;本篇我们使用JDK Log打印一下Mybatis运行时的日志&#xff0c;看一下Mybatis执行的过程。 这里我选取上一篇的示例进行JDK Log的集成&#xff0c;这里如果您想对上一篇进行详细了解&#xff0c;可以参考&#xff1a; Mybatis参数…

【python爬虫】4.爬虫实操(菜品爬取)

文章目录 前言项目&#xff1a;解密吴氏私厨分析过程代码实现&#xff08;一&#xff09;获取与解析提取最小父级标签一组菜名、URL、食材写循环&#xff0c;存列表 代码实现&#xff08;二&#xff09;复习总结 前言 上一关&#xff0c;我们学习了用BeautifulSoup库解析数据和…

说说构建流批一体准实时数仓

分析&回答 基于 Hive 的离线数仓往往是企业大数据生产系统中不可缺少的一环。Hive 数仓有很高的成熟度和稳定性&#xff0c;但由于它是离线的&#xff0c;延时很大。在一些对延时要求比较高的场景&#xff0c;需要另外搭建基于 Flink 的实时数仓&#xff0c;将链路延时降低…

国标视频云服务EasyGBS国标视频平台迁移服务器后无法启动的问题解决方法

国标视频云服务EasyGBS支持设备/平台通过国标GB28181协议注册接入&#xff0c;并能实现视频的实时监控直播、录像、检索与回看、语音对讲、云存储、告警、平台级联等功能。平台部署简单、可拓展性强&#xff0c;支持将接入的视频流进行全终端、全平台分发&#xff0c;分发的视频…

即插即生产与基于技能的设计

智能制造领域的主要研究工作就是为制造领域所有事物和行为构建数字化模型。最终实现制造工厂中设备&#xff0c;软件&#xff0c;物流所有事物的互联互通。而且实现这种互联互通是便捷&#xff0c;灵活的。通俗地将就是“即插即生产”。不过&#xff0c;要实现这一目标并非易事…