【模型学习之路】手写+分析bert

news2024/11/4 15:34:21

手写+分析bert

目录

前言

架构

embeddings

Bertmodel

预训练任务

MLM

NSP

Bert

后话

netron可视化

code2flow可视化

fine tuning


前言

Attention is all you need!

读本文前,建议至少看懂【模型学习之路】手写+分析Transformer-CSDN博客。

毕竟Bert是transformer的变种之一。

架构

embeddings

Bert可以说就是transformer的Encoder,就像训练卷积网络时可以利用现成的网络然后fine tune就投入使用一样,Bert的动机就是训练一种预训练模型,之后根据不同的场景可以做不同的fine tune。

这里我们还是B代表批次(对于Bert,一个Batch可以输入一到两个句子,输入两个句子时,两个直接拼接就好了),m代表一个batch的单词数,n表示词向量的长度。

Bert的输入是三种输入之和(维度设定我们与本系列上一篇文章保持相同):

token_embeddings  和Transformer完全一样。

segment_embeddings  用来标记句子。第一个句子每个单词标0,第二个句子的每个单词标1。

pos_embeddings  用来标记位置,维度和Transformer中的一样,但是Bert的pos_embeddings是训练出来的(这意味它成为了神经网络里要训练的参数了)。

def get_token_and_segments(tokens_a, tokens_b=None):
    """
    bert的输入之一:token embeddings
    bert的输入之二:segment embeddings
    pos_embeddings在后面的模型里面
    """
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

Bertmodel

Bert的单个EncoderLayer和Transformer是一样的,我们直接把上一节的代码复制过来就好。

组装好。

class BertModel(nn.Module):
    def __init__(self, vocab, n, d_ff, h, n_layers,
                 max_len=1000, k=768, v=768):
        super(BertModel, self).__init__()
        self.token_embeddings = nn.Embedding(vocab, n)  # [B, m]->[B, m, vocab]->[B, m, n]
        self.segment_embeddings = nn.Embedding(2, n)  # [B, m]->[B, m, 2]->[B, m, n]
        self.pos_embeddings = nn.Parameter(torch.randn(1, max_len, n))  # [1, max_len, n]
        self.layers = nn.ModuleList([EncoderLayer(n, h, k, v, d_ff)
                                     for _ in range(n_layers)])

    def forward(self, tokens, segments, m):  # m是句子长度
        X = self.token_embeddings(tokens) + \
            self.segment_embeddings(segments)
        X += self.pos_embeddings[:, :X.shape[1], :]
        for layer in self.layers:
            X, attn = layer(X)
        return X

简单测试一下。

# 弄一点数据测试一下
    tokens = torch.randint(0, 100, (2, 10))  # [B, m]
    segments = torch.randint(0, 2, (2, 10))  # [B, m]
    m = 10
    bert = BertModel(100, 768, 3072, 12, 12)
    out = bert(tokens, segments, m)
    print(out.shape)  # [2, 10, 768]

预训练任务

Bert在训练时要做两种训练,这里先画个图表示架构,后面给出分析和代码。

MLM

Maked language model,是指在训练的时候随即从输入预料上mask掉一些单词,然后通过的上下文预测该单词,该任务非常像我们在中学时期经常做的完形填空。

在BERT的实验中,15%的WordPiece Token会被随机Mask掉。在训练模型时,一个句子会被多次喂到模型中用于参数学习,但是Google并没有在每次都mask掉这些单词,而是在确定要Mask掉的单词之后,80%的时候会直接替换为[Mask],10%的时候将其替换为其它任意单词,10%的时候会保留原始Token。(这里就不深入了)

class MLM(nn.Module):
    def __init__(self, vocab, n, mlm_hid):
        super(MLM, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n, mlm_hid),
            nn.ReLU(),
            nn.LayerNorm(mlm_hid),
            nn.Linear(mlm_hid, vocab))

    def forward(self, X, P):
        # X: [B, m, n]
        # P: [B, p]
        # 这里P指的是记录了要mask的元素的矩阵,若P(i,j)==k,表示X(i,k)被mask了
        p = P.shape[1]
        P = P.reshape(-1)
        batch_size = X.shape[0]
        batch_idx = torch.arange(batch_size)
        batch_idx = torch.repeat_interleave(batch_idx, p)
        X = X[batch_idx, P].reshape(batch_size, p, -1)  # [B, p, n]
        out = self.mlp(X)
        return out

这里的forward的逻辑有点麻烦,要读懂的话可以要手推一下。p是每一个Batch中mask的词的个数。(即在一个Batch中,m个词挑出了p个)。其实意会也行:就是训练时,在X[B, m, n]里每个batch有p个是<mask>,我们把它们挑出来,得到[B, p, n]。

NSP

Next Sentence Prediction的任务是判断句子B是否是句子A的下文。训练数据的生成方式是从平行语料中随机抽取的连续两句话,其中50%保留抽取的两句话,它们符合is_next关系,另外50%的第二句话是随机从预料中提取的,不符合is_next关系。分别记为1 | 0。

这个关系由每个句子的第一个token——<cls>捕捉。

class NSP(nn.Module):
    def __init__(self, n, nsp_hid):
        super(NSP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n, nsp_hid),
            nn.Tanh(),
            nn.Linear(nsp_hid, 2))

    def forward(self, X):
        # X: [B, m, n]
        X = X[:, 0, :]  # [B, n]
        out = self.mlp(X)  # [B, 2]
        return out

Bert

下面拼装Bert。

class Bert(nn.Module):
    def __init__(self, vocab, n, d_ff, h, n_layers,
                 max_len=1000, k=768, v=768, mlm_feat=768, nsp_feat=768):
        super(Bert, self).__init__()
        self.encoder = BertModel(vocab, n, d_ff, h, n_layers, max_len, k, v)
        self.mlm = MLM(vocab, n, mlm_feat)
        self.nsp = NSP(n, nsp_feat)

    def forward(self, tokens, segments, m, P=None):
        X = self.encoder(tokens, segments, m)
        mlm_out = self.mlm(X, P) if P is not None else None
        nsp_out = self.nsp(X)
        return X, mlm_out, nsp_out

后话

netron可视化

利用netron可视化。

test_tokens = torch.randint(0, 100, (2, 10))  # [B, m]
test_segments = torch.randint(0, 2, (2, 10))  # [B, m]
test_P = torch.tensor([[1, 2, 4, 6, 8], [1, 3, 4, 5, 6]])
test_m = 10
test_bert = Bert(100, 768, 3072, 12, 12)
test_X, test_mlm_out, test_nsp_out = test_bert(test_tokens, test_segments, test_m, test_P)

modelData = "./demo.pth"
torch.onnx.export(test_bert, (test_tokens, test_segments), modelData)
netron.start(modelData)

截取部分看一下。

code2flow可视化

code2flow可以可视化代码函数和类的相互调用关系。

code2flow.code2flow([r'代码路径.py'], '输出路径.svg')

这里生成的png,其实svg清晰得多。

fine tuning

Bert的精髓在于,Bert只是一个编码器(Encoder),经过MLM和NSP两个任务的训练之后,可以自己在它的基础上训练一个Decoder来输出特定的值、得到特定的效果。这也是Bert的神奇和魅力所在!通过两个任务训练出一个编码器,然后可以通过不同的Decoder达到各种效果!

持续探索Bert......

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2231901.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

stm32移植LVGL(LVGL 8.2.0)

目录 1.下载LVGL源码 2.修改LVGL文件夹 (1)文件夹 examples 改动 (2)文件夹 demos 改动 3.最终LVGL文件夹内容 4.软件Keil配置、添加头文件 5.程序配置 6.其它配置参考链接 1.下载LVGL源码 LVGL源码地址&#xff1a;https://github.com/lvgl/lvgl 2.修改LVGL文件夹…

海南华志亿星电子商务有限公司电商服务的璀璨新星

在这个全民直播、短视频带货风起云涌的时代&#xff0c;抖音电商以其独特的魅力成为了众多商家竞相追逐的蓝海市场。而在这片波澜壮阔的商海中&#xff0c;海南华志亿星电子商务有限公司犹如一颗璀璨的新星&#xff0c;以其专业的服务、创新的策略&#xff0c;为无数商家照亮了…

动手学深度学习66 使用注意力机制的seq2seq

1. 使用注意力机制的seq2seq key value等价 是一个东西 第i个词rnn的输出 根据加权的不同&#xff0c;解码器前面用编码器前面的输出&#xff0c;到后面用后面的输出。 2. code 核心代码: context 怎么算 embedding没变&#xff0c;Decoder加了attention层 class Seq2SeqAt…

高校大数据实训平台介绍

高校大数据实验室架构 具体实训平台介绍 编程实训平台 1、大数据开发实训平台 大数据开发实训平台是面向实训课和课后训练的编程实训平台&#xff0c;平台底层基于Docker技术&#xff0c;采用容器云部署方案&#xff0c;预装大数据相关课程教学所需的实训环境…

【快速上手】pyspark 集群环境下的搭建(Yarn模式)

目录 前言&#xff1a; 一、安装步骤 安装前准备 1.第一步&#xff1a;安装python 2.第二步&#xff1a;在bigdata01上安装spark 3.第三步&#xff1a;同步bigdata01中的spark到bigdata02和03上 二、启动 三、可打开yarn界面查看任务 前言&#xff1a; 上一篇介绍的是…

【ARM Linux 系统稳定性分析入门及渐进 1.2 -- Crash 工具依赖内容】

文章目录 Prerequisites1. 内核对象文件2. 内存镜像3. 平台处理器类型4. Linux 内核版本 Prerequisites crash 工具需要依赖下面的内容&#xff1a; 1. 内核对象文件 vmlinux 文件&#xff1a;需要一个 vmlinux 内核对象文件&#xff0c;在本文中称为命名列表&#xff08;na…

【Canal 中间件】Canal 实现 MySQL 增量数据的异步缓存更新

文章目录 一、安装 MySQL1.1 启动 mysql 服务器1.2 开启 Binlog 写入功能1.2.1创建 binlog 配置文件1.2.2 修改配置文件权限1.2.3 挂载配置文件1.2.4 检测 binlog 配置是否成功 1.3 创建账户并授权 二、安装 RocketMQ2.1 创建容器共享网络2.2 启动 NameServer2.3 启动 Broker2.…

Spring Boot2.x教程:(十)从Field injection is not recommended谈谈依赖注入

从Field injection is not recommended谈谈依赖注入 1、问题引入2、依赖注入的三种方式2.1、字段注入&#xff08;Field Injection&#xff09;2.2、构造器注入&#xff08;Constructor Injection&#xff09;2.3、setter注入&#xff08;Setter Injection&#xff09; 3、为什…

解决 ClickHouse 高可用集群中 VRID 冲突问题:基于 chproxy 和 keepalived 的实践分析

Part1背景描述 近期&#xff0c;我们部署了两套 ClickHouse 生产集群&#xff0c;分别位于同城的两个数据中心。这两套集群的数据保持一致&#xff0c;以便在一个数据中心发生故障时&#xff0c;能够迅速切换应用至另一个数据中心的 ClickHouse 实例&#xff0c;确保服务连续性…

【Android】View的事件分发机制

文章目录 分发顺序ActivityViewGroupView 协作方法整体流程注意 Activity事件分发ViewGroup事件分发View点击事件总结 分发顺序 Activity->ViewGroup->View Activity 分发事件&#xff1a;Activity 通过 dispatchTouchEvent 方法分发事件&#xff0c;首先尝试将事件传递…

java项目之微服务在线教育系统设计与实现(springcloud)

风定落花生&#xff0c;歌声逐流水&#xff0c;大家好我是风歌&#xff0c;混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的闲一品交易平台。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; 微服务在线教育系统设计与…

ChatGPT:真如吹的那般神乎其神吗?

ChatGPT的确是个神奇的东西。短短600多天&#xff0c;就已成全球访问量最大的网站之一。 ChatGPT已经出现在与这些大佬顶级大佬Google、Youtube、X.com、Baidu、Yahoo、amazon、Tiktok一起。 当然ChatGPT很优秀&#xff0c;这没有疑问&#xff0c;主要问题还是对度的把握上。…

【深度学习】实验 — 动手实现 GPT【二】:注意力机制、注意力掩码、多头注意力机制

【深度学习】实验 — 动手实现 GPT【二】&#xff1a;注意力机制、多头注意力机制 注意力机制简单示例&#xff1a;单个元素的情况简单示例&#xff1a;计算所有输入词元的注意力权重推广到所有输入序列词元&#xff1a; 注意力掩码代码实现多头注意力测试 注意力机制 简单示例…

简单的kafkaredis学习之kafka

简单的kafka&redis学习整理之kafka 1. kafka 1.1 什么是消息队列 在学习Kafka之前我们先来看一下什么是消息队列&#xff0c;消息队列(Message Queue)&#xff1a;可以简称为MQ 例如&#xff1a;Java中的Queue队列&#xff0c;也可以认为是一个消息队列 消息队列&#x…

基于人工智能的搜索和推荐系统

互联网上的搜索历史分析和用户活动是个性化推荐的基础&#xff0c;这些推荐已成为电子商务行业和在线业务的强大营销工具。随着人工智能的使用&#xff0c;在线搜索也在改进&#xff0c;因为它会根据用户的视觉偏好提出建议&#xff0c;而不是根据每个客户的需求和偏好量身定制…

ssm042在线云音乐系统的设计与实现+jsp(论文+源码)_kaic

摘 要 随着移动互联网时代的发展&#xff0c;网络的使用越来越普及&#xff0c;用户在获取和存储信息方面也会有激动人心的时刻。音乐也将慢慢融入人们的生活中。影响和改变我们的生活。随着当今各种流行音乐的流行&#xff0c;人们在日常生活中经常会用到的就是在线云音乐系统…

TVS 静电管 选型

参数选型举例: 静电管选型举例: 针对信号引脚一般只需ESD防护,关注其在IEC 61000−4−2波形下的测试结果:最大耐压值、钳位电压等,注意此时钳位电压的限值就不是Absolute maximum ratings值了,原因有2 1、Absolute maximum ratings值是指持续加压会损坏芯片 2、如果关…

监控调度台在交通运输行业的优势?

在当今快速发展的交通运输行业中&#xff0c;高效、安全的管理成为确保运营顺畅和乘客满意的关键。监控调度台作为这一领域的核心设备&#xff0c;正发挥着越来越重要的作用。它集成了视频监控、数据分析、实时通讯等多种功能&#xff0c;为交通运输行业带来了诸多优势。下面我…

华为ENSP--ISIS路由协议

项目背景 为了确保资源共享、办公自动化和节省人力成本&#xff0c;公司E申请两条专线将深圳总部和广州、北京两家分公司网络连接起来。公司原来运行OSFP路由协议&#xff0c;现打算迁移到IS-IS路由协议&#xff0c;张同学正在该公司实习&#xff0c;为了提高实际工作的准确性和…

设计模式07-结构型模式2(装饰模式/外观模式/代理模式/Java)

4.4 装饰模式 4.4.1 装饰模式的定义 1.动机&#xff1a;在不改变一个对象本身功能的基础上给对象增加额外的新行为 2.定义&#xff1a;动态地给一个对象增加一些额外的职责&#xff0c;就增加对象功能来说&#xff0c;装饰模式比生成子类实现更为灵活 4.4.2 装饰模式的结构…