BERT的代码实现

news2024/12/27 0:37:52

目录

1.BERT的理论

2.代码实现 

 2.1构建输入数据格式

 2.2定义BERT编码器的类

 2.3BERT的两个任务

2.3.1任务一:Masked Language Modeling MLM掩蔽语言模型任务 

2.3.2 任务二:next sentence prediction

3.整合代码 

 4.知识点个人理解


 

1.BERT的理论

BERT全称叫做Bidirectional Encoder Representations from Transformers, 论文地址: [1810.04805] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (arxiv.org)

BERT是谷歌AI研究院在2018年10月提出的一种预训练模型. BERT本质上就是Transformer模型的encoder部分, 并且对encoder做了一些改进.

  • 官方代码和预训练模型 Github: https://github.com/google-research/bert

下图中编码器部分即BERT的基本结构.

  

2.代码实现 

import torch
from torch import nn
import dltools

 2.1构建输入数据格式

def get_tokens_and_segments(tokens_a, tokens_b=None):
    #classification 分类
    #BERT是两句话作为一对句子一同传入的,也可以单独传一句话,若序列长度长,可以补padding
    #假设先传一句话tokens_a
    tokens = ['<cls>'] + tokens_a + ['<sep>']  #tokens_embedding层的处理
    segments = [0] * (len(tokens_a) + 2)  #判断词元属于哪一句话,加标记,0属于第一句话
    if tokens_b is not None:
        tokens += tokens_b + ['sep']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments


#测试上面的函数
get_tokens_and_segments([1, 2, 3], [4, 5, 6])

(['<cls>', 1, 2, 3, '<sep>', 4, 5, 6, 'sep'], [0, 0, 0, 0, 0, 1, 1, 1, 1])

 2.2定义BERT编码器的类

class BERTEncoder(nn.Module):
    #由于前馈网络的ffn_num_outputs = num_hiddens,没有初始化传入
    #__init__()里面的参数,是创建类的时候传入的参数
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,
                max_len=1000, key_size=768, query_size=768, value_size=768, **kwargs):
        super().__init__(**kwargs)
        #token_embeddings层
        self.token_emdedding = nn.Embedding(vocab_size, num_hiddens)
        #segment_embedding层  (传入两个句子,所以第0维为2)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        #pos_embedding层  :位置嵌入层是可以学习的, 用nn.Parameter()定义可学习的参数
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))
        
        #设置Encoder_block的数量
        self.blks = nn.Sequential()  #为使用的Encoder_block依次编号
        for i in range(num_layers):  #有几层网络循环几层
            self.blks.add_module(f'{i}', dltools.EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, 
                                                              ffn_num_input, ffn_num_hiddens, num_heads, dropout))
    
    #__init__()里面的参数,是创建类的时候传入的参数
    #foward里面的参数是创建完类对象之后,调用类方法时传入的参数
    def forward(self, tokens, segments, valid_lens):
        #X = token_embedding + segment_embedding + pos_embedding
        #传入的token_embedding,segment_embedding两者的shape相同,可以直接相加
        X = self.token_emdedding(tokens) + self.segment_embedding(segments)
        #pos_embedding与前两层的数据shape不相同,不能直接相加
        #切片让self.pos_embedding的第1维度的数据切片到token_embedding,segment_embedding相加之后的数
        X = X + self.pos_embedding.data[:, :X.shape[1], :]
        
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X  
#测试上面代码


#创建BERTEncoder类对象
vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)


tokens = torch.randint(0, vocab_size, (2, 8)) #生成随机正整数
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 0, 1, 1, 1, 1]])
#调用类方法
encoded_X = encoder(tokens, segments, None)


encoded_X.shape
torch.Size([2, 8, 768])

#  nn.Sequential()是PyTorch中的一个类,它允许用户将多个计算层按照顺序组合成一个模型。在深度学习中,模型可以是由各种不同类型的层组成的,例如卷积层、池化层、全连接层等。nn.Sequential()方法可以将这些层组合在一起,形成一个整体模型。 

 2.3BERT的两个任务

2.3.1任务一:Masked Language Modeling MLM掩蔽语言模型任务 

class MaskLM(nn.Module):
    def __init__(self, vocab_size, num_inputs=768, **kwargs):
        super().__init__(**kwargs)
        self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),  #全连接层
                                nn.ReLU(),  
                                nn.LayerNorm(num_hiddens), 
                                nn.Linear(num_hiddens, vocab_size))  #输出层
    
    #X表示随机(15%概率)将一些词元换成mask
    #pred_positions表示已经处理好的80%概率将选中的词换成mask>, 10%概率换成随机词元,10%概率保持原有词元
    #pred_position是二维数据
    def forward(self, X, pred_positions):  
        num_pred_positions = pred_positions.shape[1]  #索引出80%、10%、10%三个概率选出的需要转换的词位置数量
        pred_positions = pred_positions.reshape(-1)  #变成一维数据
        batch_size = X.shape[0]  #获取批次
        batch_idx = torch.arange(0, batch_size) #获取批次的编号
        #将批次编号与元素数量对应起来
        #例如:batch_size = [0, 1]   -->   [0, 0, 0, 1, 1, 1]
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)  #将batch_idx中每个元素重复num_pred_positions次
        #把要预测位置的数据取出来
        masked_X = X[batch_idx, pred_positions]
        masked_X = masked_X.reshape(batch_size, num_pred_positions, -1)  #还原维度
        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat
#测试代码


mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)


mlm_Y_hat.shape    #2:2个批次,   3:三个需要转换词元的位置     10000:计算的概率数量(在最后会用softmax函数计算分类结果),vocab_size有10000个,
torch.Size([2, 3, 10000])
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])  #假设真实值
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))  # mlm_Y_hat的shape=(6, 10000)     mlm_Y的shape=(6)
mlm_l.shape

torch.Size([6])

2.3.2 任务二:next sentence prediction

class NextSentencePred(nn.Module):
    def __init__(self, num_inputs, **kwargs):
        super().__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)  #预测输入的句子是否为下一个句子,预测目标值为“是/否”二分类问题
        
    def forward(self, X):
        #X的形状(batch_size, num_hiddens)
        return self.output(X)
#测试代码


encoded_X = torch.flatten(encoded_X, start_dim=1)  #将数据展平,相当于reshape
nsp = NextSentencePred(encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)

nsp_Y_hat.shape

torch.Size([2, 2])
#计算损失
nsp_y = torch.tensor([0, 1])   #假设真实值
nsp_1 = loss(nsp_Y_hat, nsp_y)
nsp_1.shape

torch.Size([2])

3.整合代码 

class BERTModel(nn.Module):
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768, nsp_in_features=768, **kwargs):
        super().__init__(**kwargs)
        #初始化编码器对象
        self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,
                                   max_len=max_len, key_size=key_size, query_size=query_size, value_size=value_size)
        #掩蔽语言模型任务
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        #中间隐藏层的线性转换+激活函数
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())
        #预测出下一句
        self.nsp = NextSentencePred(nsp_in_features)
        
    def forward(self, tokens, seqments, valid_lens=None, pred_position=None):
        encoded_X = self.encoder(tokens, seqments, valid_lens)
        if pred_position is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_position)
        else:
            pred_position = None
            
        #0表示<cls>标记的索引
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

 4.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2156197.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录算法训练营第58天|卡码网 117. 软件构建、47. 参加科学大会

1. 卡码网 117. 软件构建 题目链接&#xff1a;https://kamacoder.com/problempage.php?pid1191 文章链接&#xff1a;https://www.programmercarl.com/kamacoder/0117.软件构建.html 思路&#xff1a;使用BFS BFS的实现思路&#xff1a; 拓扑排序的过程&#xff0c;其实就两步…

Java : 图书管理系统

图书管理系统的作用&#xff1a; 高效的图书管理 图书管理系统通过自动化管理&#xff0c;实现了图书的采编、编目、流通管理等操作的自动化处理&#xff0c;大大提高了图书管理的效率和准确性。 工作人员可以通过系统快速查找图书信息&#xff0c;实时掌握图书的借还情况&…

经典报童问题的2类扩展实例:带广告的报童问题和多产品报童问题

文章目录 1 引言2 经典报童问题3 带广告的报童问题3.1 论文解读3.2 样本均值近似方法 4 多产品报童问题4.1 论文解读4.2 算法模型4.3 简单实例求解4.4 复杂实例求解 5 总结6 相关阅读 1 引言 中秋已过&#xff0c;国庆未至&#xff0c;趁着这个空窗期&#xff0c;学点新知识&a…

解决DockerDesktop启动redis后采用PowerShell终端操作

如图&#xff1a; 在启动redis容器后&#xff0c;会计入以下界面 &#xff1a; 在进入执行界面后如图&#xff1a; 是否会觉得界面过于单调&#xff0c;于是想到使用PowerShell来操作。 步骤如下&#xff1a; 这样就能使用PowerShell愉快地敲命令了&#xff08;颜值是第一生…

AttributeError: ‘Sequential‘ object has no attribute ‘predict_classes‘如何解决

今天跟着书敲代码&#xff0c;报错&#xff1a; Sequential object has no attribute predict_classes&#xff0c;如图所示&#xff1a; 上网百度&#xff0c;发现predict_classes函数在新版本中已经删除了&#xff0c;需要使用 model.predict() 替代 model.predict_classes()…

【java面经速记】Mysql和ES数据同步

目录 Mysql业务数据库 ES查询数据库 数据同步方案 同步双写 异步双写&#xff08;MQ方式&#xff09; 基于Mysql的定时扫描同步 基于Binlog实时同步 使用canal监听binlog同步数据到es&#xff08;流行方案&#xff09; 拓展:mysql的主从复制原理 canal原理&#xff1a…

Via浏览器自动关闭CSDN弹窗

不知道大家有没有突发灵感迫切需要在手机上搜索一些技术性博客的时候。 不知道大家是不是搜索到的基本都是CSDN的文章。 不知道大家是否也被CSDN各种弹窗确认搞得心态爆炸。 不知道大家现在在手机上用的是什么浏览器&#xff0c;一直以来&#xff0c;我用的都是夸克&#xf…

时钟的配置

在使用51单片机时&#xff0c;系统使用的时钟源是一个外部晶体振荡器&#xff0c;频率为12M。由于51单片机每个指令周期都是12分频的&#xff0c;所以实际工作频率仅为1M。2440作为一种性能远高于51的Soc&#xff0c;主频肯定要远远高于51&#xff0c;因此2440有着比51单片机复…

【Android】DataBinding的运用

引言 之前对databinding有了基础的运用与介绍&#xff0c;但databinding的用处不单单在于Text的绑定&#xff0c;接下来就一起看看吧&#xff01; 意义&#xff1a;让布局文件承担了部分原本属于页面的工作&#xff0c;使页面与布局耦合度进一步降低。允许用户界面&#xff0…

Maven-一、分模块开发

Maven进阶 文章目录 Maven进阶前言创建新模块向新模块装入内容使用新模块把模块部署到本地仓库补充总结 前言 分模块开发可以把一个完整项目中的不同功能分为不同模块管理&#xff0c;然后模块间可以相互调用&#xff0c;该篇以一个SSM项目为目标展示如何使用maven分模块管理。…

操作系统之I/O设备管理

I/O系统的组成 I/O系统的结构 微机I/O系统 总线型I/O系统结构,CPU与内存之间可以直接进行信息交换&#xff0c;但是不能与设备直接进行信息交换&#xff0c;必须经过设备控制器。 主机I/O系统 I/O系统可能采用四级结构&#xff0c;包括主机、通道、控制器和设备。一个通道…

神经网络面试题目

1. 批规范化(Batch Normalization)的好处都有啥&#xff1f;、 A. 让每一层的输入的范围都大致固定 B. 它将权重的归一化平均值和标准差 C. 它是一种非常有效的反向传播(BP)方法 D. 这些均不是 正确答案是&#xff1a;A 解析&#xff1a; ‌‌‌‌  batch normalization 就…

TikTokDownloader 开源项目操作教程

TikTokDownloader TikTokDownloader 是一个开源的多功能视频下载工具&#xff0c;它专门用于从抖音和TikTok平台下载无水印的视频、图集和直播内容。这个工具支持批量下载账号作品、收藏内容&#xff0c;并可以采集详细数据。它提供了命令行和Web界面&#xff0c;具有多线程下…

图像处理基础知识点简记

简单记录一下图像处理的基础知识点 一、取样 1、释义 图像的取样就是图像在空间上的离散化处理,即使空间上连续变化的图像离散化, 决定了图像的空间分辨率。 2、过程 简单描述一下图象取样的基本过程,首先用一个网格把待处理的图像覆盖,然后把每一小格上模拟图像的各个…

re题(38)BUUCTF-[FlareOn6]Overlong

BUUCTF在线评测 (buuoj.cn) 运行一下.exe文件 查壳是32位的文件&#xff0c;放到ida反汇编 对unk_402008前28位进行一个操作&#xff0c;我们看到运行.exe文件的窗口正好是28个字符&#xff0c;而unk_402008中不止28个数据&#xff0c;所以猜测MessageBoxA&#xff08;&#x…

十一、 JDK17 新特性梳理

文章目录 为什么是JDK17语法层面新特性1、文本块2 、Switch 表达式增强3、instanceof的模式匹配4、var 局部变量推导 模块化及类封装1、记录类 record2 、隐藏类 Hidden Classes3 、密封类 Sealed Classes4、模块化 Module System1 、什么是模块化2、声明一个module3 、require…

从零开始讲DDR(4)——Xilinx方案

本文依据的是xilinx的PG150文档&#xff0c;主要介绍的是xilinx的ultrascale系列中内存资源的使用。 一、方案概述 Xilinx UltraScale™架构中的DDR3/DDR4 SDRAM ip核旨在支持高性能的内存接口解决方案。这些ip可以用于将DDR3和DDR4 SDRAM内存类型集成到设计中&#xff0c;提供…

干货 | 2024数智新时代制造业数字化创新实践白皮书(免费下载)

导读&#xff1a;本白皮书将对制造业发展历程、现状、趋势与核心难题做深入解读&#xff0c;并在此基础上提出了相应的制造行业解决方案,结合业内实践成功的客户案例来详析信息化转型的有效方法&#xff0c;以供生产制造行业的从业者参考交流。

计算机毕业设计 基于SpringBoot框架的网上蛋糕销售系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

解决SVN蓝色问号的问题

桌面或文件夹右键&#xff0c;选择TortoiseSVN->Settings打开设置对话框&#xff0c;选择Icon Overlays->Overlay Handlers->取消钩选Unversioned。确定&#xff0c;重启系统即可