ROCm上来自Transformers的双向编码器表示(BERT)

news2024/12/23 23:04:55

14.8. 来自Transformers的双向编码器表示(BERT) — 动手学深度学习 2.0.0 documentation (d2l.ai)

代码

import torch
from torch import nn
from d2l import torch as d2l

#@save
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """获取输入序列的词元及其片段索引"""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0和1分别标记片段A和B
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

#@save
class BERTEncoder(nn.Module):
    """BERT编码器"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f"{i}", d2l.EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape,
                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
        # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入参数
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        # 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding.data[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,
                      ffn_num_hiddens, num_heads, num_layers, dropout)

tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None)
encoded_X.shape

tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None)
encoded_X.shape

mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape

mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l.shape

#@save
class NextSentencePred(nn.Module):
    """BERT的下一句预测任务"""
    def __init__(self, num_inputs, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        # X的形状:(batchsize,num_hiddens)
        return self.output(X)

encoded_X = torch.flatten(encoded_X, start_dim=1)
# NSP的输入形状:(batchsize,num_hiddens)
nsp = NextSentencePred(encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)
nsp_Y_hat.shape

nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l.shape

#@save
class BERTModel(nn.Module):
    """BERT模型"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super(BERTModel, self).__init__()
        self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,
                    ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                    dropout, max_len=max_len, key_size=key_size,
                    query_size=query_size, value_size=value_size)
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                    nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None,
                pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # 用于下一句预测的多层感知机分类器的隐藏层,0是“<cls>”标记的索引
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

代码解析

这段代码是基于PyTorch框架实现的BERT(Bidirectional Encoder Representations from Transformers)模型。BERT是一种预训练语言表示模型,它可以用于各种自然语言处理(NLP)任务。下面是代码的中文解析:
1. get_tokens_and_segments(tokens_a, tokens_b=None) 函数用于获取输入句子的词元(tokens)及其对应的片段索引。如果有第二个句子 tokens_b,则会进行拼接,并用不同的索引来标识不同的句子。
2. BERTEncoder 类定义了BERT的编码器结构,它包含嵌入层(用于将词元转换为向量表示)、位置嵌入和多个Transformer编码块。
3. forward 方法定义了模型的前向传播逻辑。它将输入的词元和片段索引通过编码器进行编码,并返回编码后的向量表示。
4. 其中 tokens 是批量输入数据的词元索引,`segments` 是对应的片段索引,这里模拟了输入数据作为模型的示例。
5. 创建一个 BERTEncoder 实例,该实例就是BERT模型的编码器部分,类似于 Transformer 模型中的编码器层。
6. MaskLM 类未在代码中定义,通常用来实现BERT的掩码语言模型任务,它在一定比例的输入词元上应用掩码,并训练模型来预测这些被掩码的词元。
7. NextSentencePred 类定义了BERT的下一句预测(Next Sentence Prediction, NSP)任务,是一个简单的二分类器,用来预测给定的两个句子片段是否在原始文本中顺序相邻。
8. BERTModel 类将编码器、掩码语言模型(MaskLM),以及下一句预测(NSP)整合为完整的BERT模型。它通过前向传播来处理输入,同时能够根据需求进行掩码语言模型预测和下一句预测。
9. 模型实例化后,通过随机生成的 tokens 和 segments 调用其 forward 方法,得到编码后的向量 encoded_X,同时执行MLM和NSP任务,输出预测结果。
10. 最后计算MLM和NSP任务的损失,这些损失通常用于训练模型。`CrossEntropyLoss` 是在类别预测问题中经常使用的一个损失函数。
整体来看,这段代码展示了如何构建一个基于BERT结构的模型,其中涵盖了BERT的两个典型预训练任务:掩码语言模型和下一句预测。需要注意的是,这个代码片段作为一个解析,但实际中运行它需要额外的上下文(例如 MaskLM 类的实现)和适当的数据准备和预处理步骤。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1695841.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Cortex-M3的SysTick 定时器

目录 概述 1 SysTick 定时器 1.1 SysTick 定时器功能介绍 1.2 SysTick 定时器功能实现 1.3 SysTick在系统中的作用 2 SysTick应用的实例 2.1 建立异常服务例程 2.2 使能异常 2.3 闹钟功能 2.4 重定位向量表 2.5 消灭二次触发 3 SysTick在FreeRTOS中的应用 3.1 STM…

(完全解决)Python字典dict如何由键key索引转化为点.dot索引

文章目录 背景解决方案基础版升级版 背景 For example, instead of writing mydict[‘val’], I’d like to write mydict.val. 解决方案 基础版 I’ve always kept this around in a util file. You can use it as a mixin on your own classes too. class dotdict(dict)…

如何进行异地多地兼容组网设置?

跨地区工作、远程办公和异地合作已成为常态。由于网络限制和安全性要求&#xff0c;远程连接仍然是一个具有挑战性的问题。为了解决这一难题&#xff0c;各行各业都在寻找一种能在异地多地兼容的组网设置方案。本文将着重介绍基于【天联】的组网解决方案&#xff0c;探讨其操作…

SpringBoot——整合Thymeleaf模板

目录 模板引擎 新建一个SpringBoot项目 pom.xml application.properties Book BookController bookList.html ​编辑 项目总结 模板引擎 模板引擎是为了用户界面与业务数据分离而产生的&#xff0c;可以生成特定格式的页面在Java中&#xff0c;主要的模板引擎有JSP&…

如何评价刘强东说“业绩不好的人不是我兄弟”

在近日的一次京东管理层会议上&#xff0c;创始人刘强东以不容置疑的口吻表明了对公司文化的坚定态度&#xff1a;“凡是长期业绩不好&#xff0c;从来不拼搏的人&#xff0c;不是我的兄弟。”这句话不仅是对那些工作表现不佳的员工的直接警告&#xff0c;也透露出京东在追求业…

C++语法|多重继承详解(一)|理解虚基类和虚继承

系列汇总讲解&#xff0c;请移步&#xff1a; C语法&#xff5c;虚函数与多态详细讲解系列&#xff08;包含多重继承内容&#xff09; 虚基类是多重继承知识上的铺垫。 首先我们需要明确抽象类和虚基类的区别&#xff1a; 抽象类&#xff1a;有纯虚函数的类 虚基类是什么呢&a…

阿里云的域名购买和备案(一)

前言 本篇文章主要讲阿里云的域名购买和备案。 大家好&#xff0c;我是小荣&#xff0c;我又开始做自己的产品迷途dev了。这里详细记录一下域名购买的流程和备案流程。视频教学 购买流程 1.阿里云官网搜索域名注册 2.搜索你想注册的域名 3.将想要注册的域名加入域名清单 4.点…

[Linux]网络原理与配置

一.NAT模式网路配置 虚拟系统的IP地址处于随机网段&#xff0c;同时在母机上会额外有一个与虚拟IP地址网段相同的IP地址&#xff0c;可以实现母机与虚拟机的通信。虚拟系统的IP地址可以通过主机实际的IP地址作为代理IP&#xff0c;与外部系统进行通信。 优点&#xff1a;不造…

2024.05.25学习记录

1、面经复习&#xff1a; JS异步进阶、vue-react-diff、vue-router模式、requestldleCallback、React Fiber 2、代码随想录刷题、动态规划 3、组件库使用storybook

【C++】牛客——JZ38 字符串的排列

✨题目链接&#xff1a; JZ38 字符串的排列 ✨题目描述 输入一个长度为 n 字符串&#xff0c;打印出该字符串中字符的所有排列&#xff0c;你可以以任意顺序返回这个字符串数组。 例如输入字符串ABC,则输出由字符A,B,C所能排列出来的所有字符串ABC,ACB,BAC,BCA,CBA和CAB。 数…

结合时间复杂度浅谈二分法的好处(将持续更新,绝对值你一个收藏)

前言 笔者虽然刷的算法题不多,但是笔者也敢说,二分法真的是一种很优越的算法,使用上限极高的那种,正因如此,笔者才想浅谈一下二分法. 封面是我很喜欢的一个游戏角色,不知道有没有老gal玩家知道! 什么是二分法? 枚举查找即顺序查找&#xff0c;实现原理是逐个比较数组 a[0:…

【C++】详解二叉搜索树

目录 树概述 二叉搜索树概述 概念 特性 元素操作 插入 删除 模拟实现 框架 查找 插入 删除 树概述 树——在计算机中是一种很常见的数据结构。 树是一种很强大的数据结构&#xff0c;数据库&#xff0c;linux操作系统管理和windows操作系统管理所有文件的结构就是…

【基础详解】快速入门入门 SQLite数据可

简介 SQLite 是一个开源的嵌入式关系数据库&#xff0c;实现了自给自足的、无服务器的、配置无需的、事务性的 SQL 数据库引擎。它是一个零配置的数据库&#xff0c;这意味着与其他数据库系统不同&#xff0c;比如 MySQL、PostgreSQL 等&#xff0c;SQLite 不需要在系统中设置…

golang中的字节序 binary BigEndian 大端 , LittleEndian 小端 理解与write写入注意事项

在golang的binary包中有2个字节系的变量定义BigEndian和LittleEndian 这个东西是go里面很有特点的玩意&#xff0c;我们在java, php等语言中是基本看不到&#xff0c;因为大部分的语言默认使用的是BigEndian 大端模式&#xff0c; 而go语言里面是你自己可选的。 这个字节系大小…

Java的类和对象

Java的类和对象 前言一、面向过程和面向对象初步认识C语言Java 二、类和类的实例化基本语法示例注意事项 类的实例化 三、类的成员字段/属性/成员变量注意事项默认值规则字段就地初始化 方法static 关键字修饰属性代码内存解析 修饰方法注意事项静态方法和实例无关, 而是和类相…

@Async详解,为什么生产环境不推荐直接使用@Async?

一、Async 注解介绍&#xff1a; Async 注解用于声明一个方法是异步的。当在方法上加上这个注解时&#xff0c;Spring 将会在一个新的线程中执行该方法&#xff0c;而不会阻塞原始线程。这对于需要进行一些异步操作的场景非常有用&#xff0c;比如在后台执行一些耗时的任务而不…

ssms用户登陆失败,服务器处于单用户模式。目前只有一位管理员能够连接。解决方案

文章目录 问题解决方案单用户模式什么是单用户模式&#xff1f;为什么使用单用户模式&#xff1f;实现步骤 问题 连接smss的时候发现无法连接&#xff0c;显示 服务器处于单用户模式。目前只有一位管理员能够连接 解决方案 打开SQL Server配置管理器 右键属性 在启动参数的最…

Python 之 日志巡检脚本

脚本说明 使用Paramiko库进行SSH连接的自动化脚本&#xff0c;用于检查、配置和排除设备故障。说明如下&#xff1a; 导入所需的库&#xff1a;paramiko、json、logging和concurrent.futures。定义配置文件路径&#xff08;devices.json&#xff09;和日志文件路径&#xff0…

Unity射击游戏开发教程:(26)创建绕圈跑的效果

unity游戏 在本文中,我将介绍如何为敌人创建圆周运动。gif 中显示的确切行为是敌人沿着屏幕向下移动,直到到达某个点,一旦到达该点,它就会绕圈移动。

c语言:摆脱对指针的恐惧【4】

在上一期指针我们讲到了二级指针是的作用是存放一级指针的地址&#xff0c;还讲了指针数组是一个可以存放若干个指针变量的数组&#xff0c;这里我们再复习一下&#xff0c;下面指针数组是什么意思&#xff1f; int* arr1[10]; //整形指针的数组 char *arr2[4]; //一级字符指针…