DistilBERT模型训练实战

news2025/1/12 9:49:39

LLM似乎正在接管世界,但许多人仍然不真正理解他们是如何运作的。 我从事机器学习工作已有几年,并且对自然语言处理和最近的进展非常着迷。

尽管我阅读了大部分随附的论文,但训练这些模型对我来说仍然是一个谜,这就是为什么我决定继续自己训练一个模型,以真正了解它是如何工作的。 我将其与训练问答模型结合起来,但这里仅详细介绍 DistilBERT 模型。

为了让你的生活更轻松,我决定对其工作原理进行简短回顾。 请查看这篇文章中的 distilbert.ipynb文件来查找相关代码。

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 

1、为什么选择 DistilBERT

要回答的第一个问题是为什么我选择 DistilBERT 而不是 BERT、ALBERT 和该模型的所有其他变体。 不幸的是,我没有无限的云计算访问权限,只有内存有限的本地 GPU,因此我必须针对模型大小和训练时间而不是性能进行优化。

也就是说,与 BERT 相比,官方的 DistilBERT 性能仅下降了3%,这似乎是一个合理的权衡。 BERT 基础有1.1亿个参数,训练时间为12天,而 DistilBERT 有6600 万个参数,训练时间只有3.5天左右。 原始论文中指出模型减小了 40%,保留了97% 的语言理解能力,速度提高 60%。

我查看了这篇文章中对 BERT、RoBERTA、DistilBERT 和 XLNet 的简短总结和比较,文章在评论中提供了一个很棒的表格,比较了所有模型。

2、数据

我使用 HuggingFace  的 OpenWebText 数据集来训练模型。 它是 OpenAI 的 WebText 数据集的开源版本。 它包含从 Reddit 采样的 8013769 个段落。

HuggingFace 为许多数据集和模型提供了一个令人惊叹的(!!!)界面,我在整个项目中都使用了它。 只需使用以下命令即可下载整个数据集。

from datasets import load_dataset

ds = load_dataset("openwebtext")

然后我继续将数据集以 10 000 个为单位存储在本地,因为这需要一些时间,而且我不想每次都等待。

3、分词(tokenization)

接下来,我们需要为模型训练一个分词器(因为我们无法将自然语言输入到模型中)。 我们可以使用 HuggingFace 的 BertWordPieceTokenizer。 我们只需传递文件的路径,它就会自动完成所有操作。 此外,我们还需要添加特殊标记 PAD(填充)、UNK(未知)、CLS(分类)、SEP(分隔符)和 MSK(掩码)标记。 有关这些标记的解释,请参阅基本 BERT 模型教程。

from tokenizers import BertWordPieceTokenizer

paths = [str(x) for x in Path('data/original').glob('**/*.txt')]

tokenizer = BertWordPieceTokenizer(
        clean_text=True,
        handle_chinese_chars=False,
        strip_accents=False,
        lowercase=True
)
tokenizer.train(files=paths[:10], vocab_size=30_000, min_frequency=2,
                    limit_alphabet=1000, wordpieces_prefix='##',
                    special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])

当我们测试它时,我们得到以下标记并再次解码它们,结果表明标记生成器在每个输入的开头添加了一个 CLS 标记,并在句子后面添加了分隔符标记。 此外,我们还看到标记化输入包含输入 id(每个单词的 id)和注意掩码(告诉模型哪些标记很重要,即如果我们将序列填充到给定长度,它们将为 0)。

tokens = tokenizer('Hello, how are you?')
print(tokens)
# {'input_ids': [2, 21694, 16, 2287, 2009, 1991, 35, 3], 
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

tokenizer.decode(tokens['input_ids']) 
# '[CLS] hello, how are you? [SEP]'

4、数据集和数据加载器

我们可以继续使用自定义的 Dataset 类和 PyTorch 中的 DataLoader 准备要加载到模型中的数据。 数据集类可以在这里找到。 我们基本上加载文件并使用我们的分词器对输入进行编码。

我在数据集中做的另一件事是逐个加载文件。 考虑到内存限制,我必须以这种方式实现它。 它有一些缺点,即你不能以这种方式洗牌数据,因为这会把一切搞乱。 不过,这应该不是什么太大的问题,因为数据集已经根据数据集描述进行了改组。

在训练过程中,模型尝试预测被屏蔽的标记,我们需要对其进行屏蔽。 因此,我屏蔽了(分配 MSK 令牌)15% 的输入,效果非常好。 其中一些基于 DistilBERT 的 HuggingFace 实现,可以在这里找到。

dataset = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][50:70], tokenizer=tokenizer)
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

test_dataset = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][10:12], tokenizer=tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)

5、模型

接下来我们必须定义我们的模型,是的,你猜对了,我们在这里也使用 HuggingFace。 它提供了一个令人惊叹的界面,使训练变得非常容易。

from transformers import DistilBertForMaskedLM, DistilBertConfig

config = DistilBertConfig(
    vocab_size=30000,
    max_position_embeddings=514
)
model = DistilBertForMaskedLM(config)

我们使用学习率为 1e-4 的 AdamW 作为优化器并训练 10 个 epoch(这已经花费了很多时间)。 在下面,你可以找到我的训练过程,这是非常基础的代码。

epochs = 10

for epoch in range(epochs):
    loop = tqdm(loader, leave=True)
    
    # set model to training mode
    model.train()
    losses = []
    
    # iterate over dataset
    for batch in loop:
        optim.zero_grad()
        
        # copy input to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # predict
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        
        # update weights
        loss = outputs.loss
        loss.backward()
        
        optim.step()
        
        # output current loss
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
        losses.append(loss.item())
        
    print("Mean Training Loss", np.mean(losses))
    losses = []
    loop = tqdm(test_loader, leave=True)
    
    # set model to evaluation mode
    model.eval()
    
    # iterate over dataset
    for batch in loop:
        # copy input to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # predict
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        
        # update weights
        loss = outputs.loss
        
        # output current loss
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
        losses.append(loss.item())
    print("Mean Test Loss", np.mean(losses))

6、测试

之后,我们可以运行一些健全性测试来查看模型对某些屏蔽标记的预测。 我们可以再次使用 HuggingFace 创建一个管道,它将为我们处理预测。 我们使用 fill.tokenizer.mask_token 将 MSK 令牌添加到输入中。

from transformers import pipeline

fill = pipeline("fill-mask", model='distilbert', config=config, tokenizer='distilbert_tokenizer')
fill(f'It seems important to tackle the climate {fill.tokenizer.mask_token}.')

此外,我们得到了以下带有置信水平的预测,这些预测似乎都是这句话中合理的下一个标记。

  • change: 0.19
  • crisis: 0.12
  • issues: 0.05
  • issue: 0.04

7、结束语

总而言之,考虑到基础设施的限制,结果相当不错。 显然,我们没有达到与原始模型相当的性能,但如果确实想在应用程序中使用它,你可以使用预训练模型(请参考这里)。


原文链接:DistilBERT模型训练实战 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1274678.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CSS中的非布局样式+CSS布局 前端开发入门笔记(十一)

CSS中的非布局样式 在CSS中,非布局样式是指那些不会直接影响页面布局的样式。这些样式主要关注的是元素的颜色、字体、背景、边框、阴影等视觉效果。以下是一些常见的非布局CSS样式: 文本样式:包括字体(font-family)…

以太网PHY,MAC接口

本文主要介绍以太网的 MAC 和 PHY,以及之间的 MII(Media Independent Interface ,媒体独立接口)和 MII 的各种衍生版本——GMII、SGMII、RMII、RGMII等。 简介 从硬件的角度看,以太网接口电路主要由MAC(M…

GAN:WGAN-GP-带有梯度惩罚的WGAN

论文:https://arxiv.org/pdf/1704.00028.pdf 代码:GitHub - igul222/improved_wgan_training: Code for reproducing experiments in "Improved Training of Wasserstein GANs" 发表:2017 WGAN三部曲的终章-WGAN-GP 摘要 WGAN在…

YOLOv5全网独家首发改进:SENetv2,Squeeze-Excitation模块融合Dense Layer,效果秒杀SENet

💡💡💡本文自研创新改进:SENet v2,针对SENet主要优化点,提出新颖的多分支Dense Layer,并与Squeeze-Excitation网络模块高效融合,融合增强了网络捕获通道模式和全局知识的能力 推荐指数:五星 收录 YOLOv5原创自研 https://blog.csdn.net/m0_63774211/catego…

GPT市场将取代插件商店 openAI已经关闭plugins申请,全部集成到GPTs(Actions)来连接现实世界,可以与物理世界互动了。

Actions使用了plugins的许多核心思想,也增加了新的特性。 ChatGPT的"Actions"与"Plugins"是OpenAI在GPT模型中引入的两种不同的功能扩展机制。这两种机制的目的是增强模型的功能,使其能够处理更多样化的任务和请求。下面是对两者的比…

应用于智慧工厂的AI边缘计算盒子+AI算法软硬一体化方案

智慧工厂解决方案,传统工厂/生产管理,普遍存在运营粗放、效率低、应变能力差、安全隐患突出、资源不平衡等“行业症状”; 以英码产品为核心的智能化场景解决方案,可以从本质上根治这些“症状”,如企业可利用智能预测系…

RocketMQ Copilot 一款面向 Apache RocketMQ 的智能辅助运维系统

一、RocketMQ简介 ocketMQ是阿里巴巴研发的一款分布式消息中间件,后开源给Apache基金会,成为apache的顶级开源项目。它具有高性能、高可靠、高实时和分布式的特点。RocketMQ主要应用于解决应用耦合,消息分发,流量削锋等问题。 R…

七年 4 个阶段:滴滴可观测架构演进与实践

一分钟精华速览 在当前阶段,可观测性的建设并没有统一的执行路径。每家公司会根据自身的业务需求、运营模式和规模,形成一套独特的实践方案。为了应对业务规模的扩大和需求的变化,可观测团队必须持续优化和升级其架构,并始终保证…

2023年中国金融科技研究报告

第一章 行业概况 1.1 定义 金融科技(FinTech, Financial Technology)代表了金融和技术的交汇。这一领域虽然处于发展的初期阶段,但已经展现出深远的影响力。金融科技的业务模式多样,涵盖了从传统金融服务的数字化转型到新兴技术…

亚马逊云与生成式 AI 的融合——生成式AI的应用领域

文章目录 前言亚马逊云科技增强客户体验聊天机器人和虚拟助手亚马逊云科技 鸿翼:提供精准检索和问答,显著提升全球化售后服务体验AI 赋能的联络中心智能导购&个性化推荐智慧数字人 提升员工生成力和创造力对话式搜索亚马逊云科技 西门子&#xff1…

mongoDB非关系型数据库学习记录

一、简介 1.1Mongodb是什么 MongoDB是一个基于分布式文件存储的数据库,官方地址https://www.mongodb.com/ 1.2数据库是什么 数据库(DataBase)是按照数据结构来组织、存储和管理数据的应用程序 1.3数据库的作用 数据库的主要作用就是管理数据,对数据进行增©、删(d)、…

Node-red

Node-Red 什么是Node-redNode-red的特点 Node-red的Windows安装安装Node.js安装包下载安装包安装安装检查 安装Node-red安装Note-red运行Note-red 什么是Node-red Node-RED 是一种编程工具,用于以新颖有趣的方式将硬件设备、API 和在线服务连接在一起。 Node-RED 是…

【【Micro Blaze按键中断实验】】

Micro Blaze按键中断实验 中断是一种当满足要求的突发事件发生时通知处理器进行处理的信号。中断可以由硬件处理单元和外 部设备产生,也可以由软件本身产生。对硬件来说,中断信号是一个由某个处理单元产生的异步信号,用 来引起处理器的注意。…

如何利用 Snapchat 制定数字营销战略

近年来,Snapchat已成为数字营销领域的重要参与者。Snapchat 已经发展成为一种复杂的营销工具,被各种公司用来与年轻受众互动,此前它最初被认为是一个专门供青少年发布快速、转瞬即逝内容的平台。Snapchat 上的日活跃用户数量超过 2.8 亿&…

CMD命令切换至D盘

1.使用快捷键winr进入如下所示界面: 2.在框内输入CMD 后点击确定,即可进入如下界面; 3.输入d: 后按下enter即可转换成功; 补充一些CMD命令: 1. appwiz.cpl:程序和功能 2. calc:启动计算器 3.…

分布式仿真SNN的思考(二)

经过漫长的思考,我依然无法为昨天的第二个问题找到合适的解决方法。然后今天依然对整体的放着进行思考,找出规律再去写代码。考虑SNN网络: 那么他的邻接表gabal_adj: 0 1 2 1 3 2 1 3 3 4 5 4 6 5 2 6 5 3 假设有两…

SSL证书为什么要收费?

SSL证书之所以需要收费,主要涉及以下几个方面的原因: 验证过程成本 SSL证书颁发机构(CA,Certificate Authority)必须执行验证过程,以确保证书请求者的身份和域名所有权。这些验证程序需要时间和资源&…

PyQt6 QCommandLinkButton命令链接按钮控件

​锋哥原创的PyQt6视频教程: 2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计32条视频,包括:2024版 PyQt6 Python桌面开发 视频教程(无废话…

IO延迟引起的虚拟机故障排查

vmware 虚拟机连上之后总感觉非常卡,查看CPU 内存资源使用率是正常的。 message 日志有cpu卡住的报错 NMI watchdog: BUG: soft lockup - CPU#8 stuck for 23s! [container-31451:45878]下面分析是什么导致的服务器cpu卡住。 1、打开prometheus,观察服务…

IP地理定位技术的服务内容详解

IP地理定位技术是一种通过IP地址确定设备或用户地理位置的技术,广泛应用于广告定向、网络安全、位置服务等领域。本文将深入探讨IP地理定位技术的服务内容,解析其在不同场景中提供的多种服务。 1. 准确的地理位置信息提供: IP地理定位技术的…