BERT训练环节(代码实现)

news2025/1/21 6:22:43

1.代码实现

#导包
import torch
from torch import nn
import dltools
#加载数据需要用到的声明变量
batch_size, max_len = 1, 64
#获取训练数据迭代器、词汇表
train_iter, vocab = dltools.load_data_wiki(batch_size, max_len)
#其余都是二维数组
#tokens, segments, valid_lens(一维), pred_position, mlm_weights, mlm, nsp(一维)对应每条数据i中包含的数据
for i in train_iter:  #遍历迭代器
    break   #只遍历一条数据
[tensor([[    3,    25,     0,  4993,     0,    24,     4,    26,    13,     2,
            158,    20,     5,    73,  1399,     2,     9,   813,     9,   987,
             45,    26,    52,    46,    53,   158,     2,     5,  3140,  5880,
              9,   543,     6,  6974,     2,     2,   315,     6,     8,     5,
           8698,     8, 17229,     9,   308,     2,     4,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1]]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 tensor([47.]),
 tensor([[ 9, 15, 26, 32, 34, 35, 45,  0,  0,  0]]),
 tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0., 0.]]),
 tensor([[ 484, 1288,   20,    6, 2808,    9,   18,    0,    0,    0]]),
 tensor([0])]
#创建BERT网络模型
net = dltools.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128], 
                        ffn_num_input=128, ffn_num_hiddens=256, num_heads=2, 
                        num_layers=2, dropout=0.2, key_size=128, query_size=128, 
                        value_size=128, hid_in_features=128, mlm_in_features=128, 
                        nsp_in_features=128)
#调用设备上的GPU
devices = dltools.try_all_gpus()
#损失函数对象
loss = nn.CrossEntropyLoss()   #多分类问题,使用交叉熵
#@save    #表示用于指示某些代码应该被保存或导出,以便于管理和重用
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):
    #前向传播
    #获取遮蔽词元的预测结果、下一个句子的预测结果
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)
    #计算遮蔽语言模型的损失
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1,1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)   #MLM损失函数的归一化版本   #加一个很小的数1e-8,防止分母为0,抵消上一行代码乘以的数值
    #计算下一个句子预测任务的损失
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l  
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):  #文本词元样本量太多,全跑完花费的时间太多,若num_steps=1在BERT中表示,跑了1个batch_size
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])  #调用设备的GPU
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)   #梯度下降的优化算法Adam
    step, timer = 0, dltools.Timer()  #设置计时器
    #调用画图工具
    animator = dltools.Animator(xlabel='step', ylabel='loss', xlim=[1, num_steps], legend=['mlm', 'nsp'])
    #遮蔽语言模型损失的和, 下一句预测任务损失的和, 句子对的数量, 计数
    metric = dltools.Accumulator(4)  #Accumulator类被设计用来收集和累加各种指标(metric)
    num_steps_reached = False  #设置一个判断标志, 训练步数是否达到预设的步数
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y in train_iter:
            #将遍历的数据发送到设备上
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            
            #梯度清零
            trainer.zero_grad()
            timer.start()  #开始计时
            mlm_l, nsp_l, l = _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()  #反向传播
            trainer.step()  #梯度更新
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], l)  #累积的参数指标
            timer.stop() #计时停止
            animator.add(step + 1, (metric[0] / metric[3], metric[1] / metric[3]))  #画图的
            step += 1  #训练完一个batch_size,就+1
            if step == num_steps:  #若步数与预设的训练步数相等
                num_steps_reached = True   #判断标志改为True
                break  #退出while循环
    
    print(f'MLM loss {metric[0] / metric[3]:.3f}, 'f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2]/ timer.sum():.1f} sentence pairs/sec on 'f'{str(devices)}')
   
train_bert(train_iter, net, loss, len(vocab), devices, 500)

 

def get_bert_encoding(net, tokens_a, tokens_b=None):
    tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)  #unsqueeze(0)增加一个维度
    segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)  
    valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)
    endoced_X, _, _ = net(token_ids, segments, valid_len)
    return endoced_X
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 词元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
(torch.Size([1, 6, 128]),
 torch.Size([1, 128]),
 tensor([-0.5872, -0.0510, -0.7376], device='cuda:0', grad_fn=<SliceBackward0>))
encoded_text_crane

 

tensor([[-5.8725e-01, -5.0994e-02, -7.3764e-01, -4.3832e-02,  9.2467e-02,
          1.2745e+00,  2.7062e-01,  6.0271e-01, -5.5055e-02,  7.5122e-02,
          4.4872e-01,  7.5821e-01, -6.1558e-02, -1.2549e+00,  2.4479e-01,
          1.3132e+00, -1.0382e+00, -4.7851e-03, -6.3590e-01, -1.3180e+00,
          5.2245e-02,  5.0982e-01,  7.4168e-02, -2.2352e+00,  7.4425e-02,
          5.0371e-01,  7.2120e-02, -4.6384e-01, -1.6588e+00,  6.3987e-01,
         -6.4567e-01,  1.7187e+00, -6.9696e-01,  5.6788e-01,  3.2628e-01,
         -1.0486e+00, -7.2610e-01,  5.7909e-02, -1.6380e-01, -1.2834e+00,
          1.6431e+00, -1.5972e+00, -4.5678e-03,  8.8022e-02,  5.5931e-02,
         -7.2332e-02, -4.9313e-01, -4.2971e+00,  6.9757e-01,  7.0690e-02,
         -1.8613e+00,  2.0366e-01,  8.9868e-01, -3.4565e-01,  9.6776e-02,
          1.3699e-02,  7.1410e-01,  5.4820e-01,  9.7358e-01, -8.1038e-01,
          2.6216e-01, -5.7850e-01, -1.1969e-01, -2.5277e-01, -2.0046e-01,
         -1.6718e-01,  5.5540e-01, -1.8172e-01, -2.5639e-02, -6.0961e-01,
         -1.1521e-03, -9.2973e-02,  9.5226e-01, -2.4453e-01,  9.7340e-01,
         -1.7908e+00, -2.9840e-02,  2.3087e+00,  2.4889e-01, -7.2734e-01,
          2.1827e+00, -1.1172e+00, -7.0915e-02,  2.5138e+00, -1.0356e+00,
         -3.7332e-02, -5.6668e-01,  5.2251e-01, -5.0058e-01,  1.7354e+00,
          4.0760e-01, -1.2982e-01, -7.0230e-01,  3.1563e+00,  1.8754e-01,
          2.0220e-01,  1.4500e-01,  2.3296e+00,  4.5522e-02,  1.1762e-01,
          1.0662e+00, -4.0858e+00,  1.6024e-01,  1.7885e+00, -2.7034e-01,
         -1.6869e-01, -8.7018e-02, -4.2451e-01,  1.1446e-01, -1.5761e+00,
          7.6947e-02,  2.4336e+00,  4.5346e-02, -6.5078e-02,  1.4203e+00,
          3.7165e-01, -7.9571e-01, -1.3515e+00,  4.1511e-02,  1.3561e-01,
         -3.3006e+00,  1.4821e-01,  1.3024e-01,  1.9966e-01, -8.5910e-01,
          1.4505e+00,  7.6774e-02,  9.3771e-01]], device='cuda:0',
       grad_fn=<SliceBackward0>)
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 词元:'<cls>','a','crane','driver','came','<sep>','he','just', 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]

 

(torch.Size([1, 10, 128]),
 torch.Size([1, 128]),
 tensor([-0.4637, -0.0569, -0.6119], device='cuda:0', grad_fn=<SliceBackward0>))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2164925.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一带一路区块链赛项样题解析(中)

一带一路区块链赛项样题解析 (模块二) 标题任务一 按要求完成智能合约开发 1、学籍信息合约(Roll)接口编码(6分) (1)编写学籍信息合约中的RollInfo 实体接口,完成RollInfo实体通用数据的初始化,实现可追溯的学籍信息上链功能;(2分) // SPDX-License-Identifie…

FPGA IP 和 开源 HDL 一般去哪找?

在FPGA开发的世界中&#xff0c;IP核和HDL模块是构建复杂数字系统的基石。它们如同乐高积木&#xff0c;让开发者能够快速搭建和重用经过验证的电路功能。但你是否曾感到迷茫&#xff0c;不知道从哪里寻找这些宝贵的资源&#xff1f;本文将为你揭开寻找FPGA IP核和HDL模块资源的…

探索MemGPT:AI界的新宠儿

文章目录 探索MemGPT&#xff1a;AI界的新宠儿1. 背景介绍2. MemGPT是什么&#xff1f;3. 如何安装MemGPT&#xff1f;4. 简单的库函数使用方法5. 场景应用场景一&#xff1a;创建持久聊天机器人场景二&#xff1a;文档分析场景三&#xff1a;多会话聊天互动 6. 常见Bug及解决方…

【2.使用VBA自动填充Excel工作表】

目录 前言什么是VBA如何使用Excel中的VBA简单基础入门控制台输出信息定义过程&#xff08;功能&#xff09;定义变量常用的数据类型Set循环For To 我的需求开发过程效果演示文件情况测试填充源文件测试填充目标文件 全部完整的代码sheet1中的代码&#xff0c;对应A公司工作表Us…

社区来稿丨一个真正意义上的实时多模态智能体框架,TEN Framework 为构建下一代 AI Agent 而生

本文由 RTE 开发者社区成员通过社区网站投稿提供&#xff0c;如果你也有与实时互动&#xff08;Real-Time Engagement&#xff0c;RTE&#xff09;相关的项目分享&#xff0c;欢迎访问网站 rtecommunity.dev 发布&#xff0c;优秀项目将会在公众号发布分享。 自从 OpenAI 展示了…

大数据毕业设计选题推荐-手机销售数据分析系统-Hive-Hadoop-Spark

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、PHP、.NET、Node.js、GO、微信小程序、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇…

PINN机器学习登上Science正刊!热门buff叠满!11个创新思路get到就能发

今天我们来聊聊物理信息机器学习PIML。PINN大家都熟悉吧&#xff0c;毕竟研究热度就没下去过&#xff0c;这个热点其实就是PIML的一种典型代表。 PIML是一种融合了物理学与机器学习的创新技术&#xff0c;通过引入物理学的先验知识&#xff0c;来改进和优化机器学习模型的性能…

换脸黑科技FaceFusion 3.0(Windows Mac整合包)震撼来袭!

换脸黑科技FaceFusion 3.0&#xff08;Windows & Mac整合包&#xff09;震撼来袭&#xff01; 各位魔法师们&#xff0c;准备好迎接 FaceFusion 3.0 的强势登场了吗&#xff1f;这款 AI 换脸神器经历了全面升级&#xff0c;功能更强大&#xff0c;效果更惊艳&#xff0c;操…

C++(引用、窄化、输入)

1. 引用 reference&#xff08;重点&#xff09; 1.1 基础使用 引用就是某个变量或常量的别名&#xff0c;对引用进行操作与操作原变量或常量完全相同。 #include <iostream>using namespace std;int main() {int a 1;int& b a; // b是a的引用b;cout << a &…

基于单片机的汽车防酒驾控制系统设计

本设计基于STC12C5A60S2单片机的汽车防酒驾系统&#xff0c;主要包括主控制器、酒精检测模块、显示模块、声光报警模块和语音播报模块等共同组成&#xff0c;从而实现了对车内酒精浓度进行采集&#xff0c;预防酒驾的发生。利用酒精检测传感器对车辆内人员呼出的气体进行酒精浓…

C盘满了怎么清理_C盘满了深度清理详细操作步骤(多种方法)

最近有很多网友问我&#xff0c;我电脑C盘满了怎么清理&#xff1f;说自己不敢乱清理&#xff0c;怕清了系统文件无法正常开机&#xff0c;今天小编就教大家C盘满了清理的详细操作步骤&#xff0c;按教程来不怕系统进不了系统了。 C盘满了清理流程&#xff1a; 清理系统产生的…

vue-pdf 实现pdf预览、高亮、分页、定位功能

vue-pdf 实现pdf预览、高亮、分页、定位功能&#xff08;基于vue2.0&#xff01;&#xff01;&#xff01;&#xff09; 前言一、实现步骤1.引入库2.示例代码3.触发高亮事件4.分页高亮5.跳转指定页面并高亮&#xff08;不分页&#xff09; 参考笔记&#xff08;重要&#xff09…

C# 面对对象基础 枚举,Enum.TryParse的使用

代码&#xff1a; using System; using System.Collections.Generic; using System.Dynamic; using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks;namespace Student_c_ {enum Week : int{Mon,Tus,Wed,Thu,Fri,Sat,Sun,}public cla…

微服务之服务保护

Sentinel引入Java项目中 一&#xff1a;安装Sentinel 官网地址&#xff1a;https://github.com/alibaba/Sentinel/releases 二&#xff1a;安装好后在sentinel-dashboard.jar所在目录运行终端 三&#xff1a;运行命令&#xff0c;端口自己指定 java -Dserver.port8090 -Dcs…

iPhone16新机到手,,这些操作都要设置好

iPhone16新机首批机子已经发货&#xff0c;陆陆续续都几到了买家们手中了&#xff0c;iPhone 16到手后&#xff0c;虽然没有严格意义上的“必须”设置&#xff0c;但有一些推荐设置可以帮助您更好地使用和保护设备&#xff0c;同时提升安全性和使用体验&#xff0c;让你的新iPh…

栈的深度解析:链式队列的实现

引言 队列是一种广泛应用于计算机科学的数据结构&#xff0c;具有先进先出&#xff08;FIFO&#xff09;的特性。在许多实际应用中&#xff0c;例如任务调度、缓冲区管理等&#xff0c;队列扮演着重要角色。本文将详细介绍队列的基本概念&#xff0c;并通过链表实现一个简单的…

初识Jenkins持续集成系统

随着软件开发复杂度的不断提高&#xff0c;团队成员之间如何更好地协同工作以确保软件开发的质量&#xff0c;已经慢慢成为开发过程中不可回避的问题。Jenkins 自动化部署可以解决集成、测试、部署等重复性的工作&#xff0c;工具集成的效率明显高于人工操作;并且持续集成可以更…

网络原理3-应用层(HTTP/HTTPS)

目录 DNSHTTP/HTTPSHTTP协议报文HTTP的方法请求报头、响应报头(header)状态码构造HTTP请求HTTPS 应用层是我们日常开发中最常用的一层&#xff0c;因为其他层&#xff1a;传输层、网络层、数据链路层、物理层这些都是操作系统和硬件、驱动已经实现好的&#xff0c;我们只能使用…

【Python】的语言基础学习方法 快速掌握! 源码可分享!

python语言基础 第一章 你好python 1.1安装python https://www.python.org/downloads/release/python-3104/ 自定义安装&#xff0c;全选 配置python的安装路径 验证&#xff1a;cmd输入python 1.2python解释器 解释器主要做了两件事&#xff1a; 翻译代码提交给计算机去运…

Linux 下安装mysql

1.检查之前是否安装过mysql rpm -qa | grep mysql 如果之前安装过&#xff0c;删除之前的安装包 rpm -e 安装包 如果没有&#xff0c;进行后续安装 2. 下载 MySQL :: Download MySQL Community Server (Archived Versions)https://downloads.mysql.com/archives/community/ 3…