96. BERT预训练代码

news2025/1/16 0:21:53

利用实现的BERT模型和从WikiText-2数据集生成的预训练样本,我们将在本节中在WikiText-2数据集上对BERT进行预训练。

import torch
from torch import nn
from d2l import torch as d2l

首先,我们加载WikiText-2数据集作为小批量的预训练样本,用于遮蔽语言模型和下一句预测。批量大小是512,BERT输入序列的最大长度是64。注意,在原始BERT模型中,最大长度是512。

batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)

原始BERT 有两个不同模型尺寸的版本。基本模型( BERT_BASE )使用12层(Transformer编码器块),768个隐藏单元(隐藏大小)和12个自注意头。大模型( BERT_LARGE )使用24层,1024个隐藏单元和16个自注意头。值得注意的是,前者有1.1亿个参数,后者有3.4亿个参数。

为了便于演示,我们定义了一个小的BERT,使用了2层、128个隐藏单元和2个自注意头

net = d2l.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],
                    ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,
                    num_layers=2, dropout=0.2, key_size=128, query_size=128,
                    value_size=128, hid_in_features=128, mlm_in_features=128,
                    nsp_in_features=128)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()

在定义训练代码实现之前,我们定义了一个辅助函数_get_batch_loss_bert。给定训练样本,该函数计算遮蔽语言模型和下一句子预测任务的损失。请注意,BERT预训练的最终损失是遮蔽语言模型损失和下一句预测损失的和。

def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # 前向传播
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # 计算遮蔽语言模型损失
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
    mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # 计算下一句子预测任务的损失
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l

通过调用上述两个辅助函数,下面的train_bert函数定义了在WikiText-2(train_iter)数据集上预训练BERT(net)的过程。训练BERT可能需要很长时间。以下函数的输入num_steps指定了训练的迭代步数,而不是像train_ch13函数那样指定训练的轮数.

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # 遮蔽语言模型损失的和,下一句预测任务损失的和,句子对的数量,计数
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')

在预训练过程中,我们可以绘制出遮蔽语言模型损失和下一句预测损失。

train_bert(train_iter, net, loss, len(vocab), devices, 50)

运行结果:

在这里插入图片描述

2. 用BERT表示文本

在预训练BERT之后,我们可以用它来表示单个文本、文本对或其中的任何词元。下面的函数返回tokens_atokens_b中所有词元的BERT(net)表示。

def get_bert_encoding(net, tokens_a, tokens_b=None):
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)
    segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)
    valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    return encoded_X

考虑“a crane is flying”这句话。回想一下BERT的输入表示。插入特殊标记“< cls>”(用于分类)和“< sep>”(用于分隔)后,BERT输入序列的长度为6。因为零是“< cls>”词元,encoded_text[:, 0, :]是整个输入语句的BERT表示。为了评估一词多义词元“crane”,我们还打印出了该词元的BERT表示的前三个元素。

tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 词元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]

运行结果:

在这里插入图片描述

现在考虑一个句子“a crane driver came”和“he just left”。类似地,encoded_pair[:, 0, :]是来自预训练BERT的整个句子对的编码结果。注意,多义词元“crane”的前三个元素与上下文不同时的元素不同。这支持了BERT表示是上下文敏感的。

tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 词元:'<cls>','a','crane','driver','came','<sep>','he','just',
# 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]

运行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/187694.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Logstash:如何使用 Logstash 解析并摄入 JSON 数据到 Elasticsearch

在我之前的文章 “Logstash&#xff1a;Data 转换&#xff0c;分析&#xff0c;提取&#xff0c;丰富及核心操作” 有涉及到这个话题。今天我想使用一个具体的例子来更深入地展示。 准备数据 我们先来把如下的数据拷贝下来&#xff0c;并保存到一个叫做 sample.json 的文件中。…

OS 学习笔记(5) 操作系统的体系结构

OS 学习笔记(5) 操作系统的体系结构 王道OS 1.4 操作系统的体系结构 文章目录OS 学习笔记(5) 操作系统的体系结构知识总览分层结构模块化操作系统的内核大内核 vs 微内核知识回顾与重要考点外核王道chap1 回顾英文表达、术语积累&#xff08;《操作系统概念》第九版、ostep 《O…

电子模块|心率血氧传感器模块MAX30102及其驱动代码

电子模块|心率血氧传感器模块MAX30102及其驱动代码实物照片模块简介工作原理原理图及引脚说明STM32软件驱动IIC通信代码数值转换代码main函数结果实物照片 模块简介 MAX30102是一个集成的脉搏血氧仪和心率监测仪生物传感器的模块。 它集成了一个红光LED和一个红外光LED、光电…

【经济学】MIT 微观经济学 Microeconomoics

MIT 微观经济学P1 Introduction and Supply & Demand约束优化和机会成本供给和需求P1 Introduction and Supply & Demand 约束优化和机会成本 微观经济学是研究如何个人和公司做决定在一个稀缺的世界。稀缺性是微观经济的驱动力。 微观经济学是一系列约束优化练习&a…

Hadoop安全之Kerberos

简介 安全无小事&#xff0c;我们常常要为了预防安全问题而付出大量的代价。虽然小区楼道里面的灭火器、消防栓常年没人用&#xff0c;但是我们还是要准备着。我们之所以愿意为了这些小概率事件而付出巨大的成本&#xff0c;是因为安全问题一旦发生&#xff0c;很多时候我们将…

自学数据分析——数据分析方法和模型

一、数据分析方法 数据分析的思维需要培养&#xff0c;先模仿别人&#xff0c;从模仿者到创造者。首先需要建立数据的敏感性&#xff0c;能快速了解数据在说什么&#xff0c;下面我们以抖音教育直播为例&#xff0c;首先来了解核心指标&#xff0c;以及各个指标所表示的含义。…

17.Stream流

目录 一.Stream流 1.1 什么是Stream流 1.2 Stream流思想 1.3 Stream流的三类方法 1.4 获取Stream流 1.4.1 集合获取Stream流的方式 1.4.2 数组获取Stream流的方式 1.5 中间方法 1.6 终结方法 1.7 收集Stream流 1.7.1 什么是收集Stream流 1.7.2 收集方法 一.Stream流…

Ant Design Vue 之a-tree-select

Ant Design Vue 是比较流行的vue框架之一&#xff0c;主要是展示a-tree-select 的简单用法&#xff0c;a-tree-select组件主要用于展示树结构的选择。 <template><a-spin :spinning"confirmLoading"><a-form :form"form"><a-form-ite…

CnOpenDataA股上市公司社会责任报告数据

一、数据简介 A股上市公司社会责任报告数据由和讯网自2013年开始独家策划的产品&#xff0c;也是国内首家上市公司社会责任专业测评产品。上市公司社会责任报告专业测评体系从股东责任、员工责任、供应商、客户和消费者权益责任、环境责任和社会责任五项考察&#xff0c;各项分…

Linux Workqueue

Linux Workqueue 1、前言 Workqueue 是内核里面很重要的一个机制&#xff0c;特别是内核驱动&#xff0c;一般的小型任务 (work) 都不会自己起一个线程来处理&#xff0c;而是扔到 Workqueue 中处理。Workqueue 的主要工作就是用进程上下文来处理内核中大量的小任务。 所以 …

基于php的旅游管理系统

摘要随着计算机技术&#xff0c;网络技术的迅猛发展&#xff0c;Internet 的不断普及&#xff0c;网络在各个领域里发挥了越来越重要的作用。特别是随着近年人民生活水平不断提高&#xff0c;在线旅游给人们的旅游业带来了更大的发展机遇。在经济快速发展的带动下&#xff0c;旅…

【Linux】tar命令打包 | 查看压缩文件 | 打包时忽略文件

tar命令打包 | 查看压缩文件 | 打包时忽略文件 等操作 1.起因 今天下午写阿狸bot的代码的时候&#xff0c;写错了aiofiles的保存操作 # 正确写法 async def write_file_aio(path:str, value):async with aiofiles.open(path, w, encodingutf-8) as f:await f.write(json.dump…

MyBatis持久层框架详细解读:核心配置文件

文章目录1. 前言2. 多环境配置3. 类型别名4. 对象工厂5. 总结1. 前言 前面我们在使用 MyBatis 开发时&#xff0c;编写核心配置文件替换 JDBC 中的连接信息&#xff0c;解决了 JDBC 硬编码的问题。其实&#xff0c;MyBatis 核心配置文件中还可以配置很多的内容。 MyBatis 的配…

mongodb分片

分片是MongoDB的扩展方式,通过分片能够增加更多的机器来用对不断增加的负载和数据,还不影响应用.1.分片简介分片是指将数据拆分,将其分散存在不同机器上的过程.有时也叫分区.将数据分散在不同的机器上,不需要功能强大的大型计算机就可以存储更多的数据,处理更大的负载.使用几乎…

屏幕录制下载推荐(可以无水印录制视频)

您有没有遇到过这种情况&#xff0c;在使用录屏工具录制电脑屏幕时&#xff0c;录制出来的视频是带有明显水印的。那有没有可以无水印录制的屏幕录制推荐呢&#xff1f;当然有。最近小编发现了一款可以无水印&#xff08;自定义图文水印&#xff09;录制的视频&#xff0c;快来…

Pycharm误触ignore的解决方法--有图

步骤1&#xff1a;进入pycharm编辑器之后&#xff0c;找到菜单栏中的file选项&#xff0c;点击之后会有一个下拉列表&#xff0c;直接选择settings&#xff0c;进入到设置的窗口。步骤2&#xff1a;在设置界面的左侧&#xff0c;找到Inspections选项&#xff0c;点击之后&#…

JavaScript 练手小技巧:拖拽事件、把图片拖拽入页面

HTML5 新增了拖拽事件 drag&#xff0c;利用它可以实现把外部文件拖拽入页面中&#xff0c;可以实现文件的读取&#xff0c;上传等等功能。 拖拽&#xff0c;又叫拖拉、拖动&#xff0c;英文为 drag。 拖拽事件是 HTML5 新增的事件操作。 拖拽指的是&#xff0c;用户在某个对…

【Rust】4. Rust 基础

4. Rust 基础 4.1 变量和可变性 4.1.1 常量 const xxx: type ...&#xff1a;常量使用 const 来定义&#xff0c;且必须注明值的类型常量在声明它的作用域之中&#xff0c;常量在整个程序生命周期中都有效 4.1.2 隐藏&#xff08;Shadowing&#xff09; 隐藏&#xff08;Sh…

基于卡尔曼滤波器的PID控制-1

采用M语言对算例进行仿真&#xff01;&#xff01;设置控制对象传递函数&#xff1a;取采样时间为1ms&#xff0c;采用Z变换将对象离散化&#xff0c;并描述为离散状态方程的形式&#xff1a;x(k 1) Ax(k) B(u(k)wk))y(k) Cx(k)带有测量噪声的被控对象输出为&#xff1a;yv(k)C…

Ubuntu18.04下安装OpenCV4.2.0与Opencv_contrib(图文详细报错总结)

Ubuntu18.04下安装OpenCV4.2.0与Opencv_contrib&#xff08;图文详细&#xff09;前期准备—环境依赖Cmake&#xff08;编译器&#xff09;依赖环境Python环境streamer环境图像处理依赖安装OpenCV编译OpenCV配置cmake编译参数make编译配置OpenCV动态库验证OpenCV环境# python环…