一文详解大模型蒸馏工具TextBrewer

news2024/9/21 12:34:51

原文:https://zhuanlan.zhihu.com/p/648674584

本文分享自华为云社区《TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用》,作者:汀丶。

TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包,
融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。

1.简介

TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。

主要特点:

  • 模型无关:适用于多种模型结构(主要面向Transfomer结构)
  • 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块
  • 非侵入式:无需对教师与学生模型本身结构进行修改
  • 支持典型的NLP任务:文本分类、阅读理解、序列标注等

TextBrewer目前支持的知识蒸馏技术有:

  • 软标签与硬标签混合训练
  • 动态损失权重调整与蒸馏温度调整
  • 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer, …
  • 任意构建中间层特征匹配方案
  • 多教师知识蒸馏

TextBrewer的主要功能与模块分为3块:

  1. Distillers:进行蒸馏的核心部件,不同的distiller提供不同的蒸馏模式。目前包含GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller等
  2. Configurations and Presets:训练与蒸馏方法的配置,并提供预定义的蒸馏策略以及多种知识蒸馏损失函数
  3. Utilities:模型参数分析显示等辅助工具

用户需要准备:

  1. 已训练好的教师模型, 待蒸馏的学生模型
  2. 训练数据与必要的实验配置, 即可开始蒸馏

在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果。

2.TextBrewer结构

2.1 安装要求

  • Python >= 3.6
  • PyTorch >= 1.1.0
  • TensorboardX or Tensorboard
  • NumPy
  • tqdm
  • Transformers >= 2.0 (可选, Transformer相关示例需要用到)
  • Apex == 0.1.0 (可选,用于混合精度训练)
  • 从PyPI自动下载安装包安装:
pip install textbrewer
  • 从源码文件夹安装:
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer

2.2工作流程

  • Stage 1 : 蒸馏之前的准备工作:
  1. 训练教师模型
  2. 定义与初始化学生模型(随机初始化,或载入预训练权重)
  3. 构造蒸馏用数据集的dataloader,训练学生模型用的optimizer和learning rate scheduler
  • Stage 2 : 使用TextBrewer蒸馏:
  1. 构造训练配置(TrainingConfig)和蒸馏配置(DistillationConfig),初始化distiller
  2. 定义adaptor 和 callback ,分别用于适配模型输入输出和训练过程中的回调
  3. 调用distillertrain方法开始蒸馏

2.3 以蒸馏BERT-base到3层BERT为例展示TextBrewer用法

在开始蒸馏之前准备:

  • 训练好的教师模型teacher_model (BERT-base),待训练学生模型student_model (3-layer BERT)
  • 数据集dataloader,优化器optimizer,学习率调节器类或者构造函数scheduler_class 和构造用的参数字典 scheduler_args

使用TextBrewer蒸馏:

import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig

#展示模型参数量的统计
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)

print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)

#定义adaptor用于解释模型的输出
def simple_adaptor(batch, model_outputs):
    # model输出的第二、三个元素分别是logits和hidden states
    return {'logits': model_outputs[1], 'hidden': model_outputs[2]}

#蒸馏与训练配置
# 匹配教师和学生的embedding层;同时匹配教师的第8层和学生的第2层
distill_config = DistillationConfig(
    intermediate_matches=[    
     {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},
     {'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig()

#初始化distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config = distill_config,
    model_T = teacher_model, model_S = student_model, 
    adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)

#开始蒸馏
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)

2.4蒸馏任务示例

  • Transformers 4示例
    • examples/notebook_examples/sst2.ipynb (英文): SST-2文本分类任务上的BERT模型训练与蒸馏。
    • examples/notebook_examples/msra_ner.ipynb (中文): MSRA NER中文命名实体识别任务上的BERT模型训练与蒸馏。
    • examples/notebook_examples/sqaudv1.1.ipynb (英文): SQuAD 1.1英文阅读理解任务上的BERT模型训练与蒸馏。
  • examples/random_token_example: 一个可运行的简单示例,在文本分类任务上以随机文本为输入,演示TextBrewer用法。
  • examples/cmrc2018_example (中文): CMRC 2018上的中文阅读理解任务蒸馏,并使用DRCD数据集做数据增强。
  • examples/mnli_example (英文): MNLI任务上的英文句对分类任务蒸馏,并展示如何使用多教师蒸馏。
  • examples/conll2003_example (英文): CoNLL-2003英文实体识别任务上的序列标注任务蒸馏。
  • examples/msra_ner_example (中文): MSRA NER(中文命名实体识别)任务上,使用分布式数据并行训练的Chinese-ELECTRA-base模型蒸馏。

2.4.1蒸馏效果

我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。

  • 模型
  • 对于英文任务,教师模型为BERT-base-cased
  • 对于中文任务,教师模型为HFL发布的RoBERTa-wwm-ext 与 Electra-base

我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。

  • 英文模型

  • 中文模型

  • T6的结构与DistilBERT[1], BERT6-PKD[2], BERT-of-Theseus[3] 相同。
  • T4-tiny的结构与 TinyBERT[4] 相同。
  • T3的结构与BERT3-PKD[2] 相同。

2.4.2 蒸馏配置
distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches) #其他参数为默认值
不同的模型用的matches我们采用了以下配置:


各种matches的定义在examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。
2.4.3训练配置
蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。
2.4.4英文实验结果
在英文实验中,我们使用了如下三个典型数据集。


我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。

Public results:


Our results:


说明:

  1. 公开模型的名称后括号内是其等价的模型结构
  2. 蒸馏到T4-tiny的实验中,SQuAD任务上使用了NewsQA作为增强数据;CoNLL-2003上使用了HotpotQA的篇章作为增强数据
  3. 蒸馏到T12-nano的实验中,CoNLL-2003上使用了HotpotQA的篇章作为增强数据

2.4.5中文实验结果
在中文实验中,我们使用了如下典型数据集。


实验结果如下表所示。


说明:

  1. 以RoBERTa-wwm-ext为教师模型蒸馏CMRC 2018和DRCD时,不采用学习率衰减
  2. CMRC 2018和DRCD两个任务上蒸馏时他们互作为增强数据
  3. Electra-base的教师模型训练设置参考自Chinese-ELECTRA
  4. Electra-small学生模型采用预训练权重初始化

3.核心概念
3.1Configurations

  • TrainingConfig 和 DistillationConfig:训练和蒸馏相关的配置。

3.2Distillers
Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:

  • BasicDistiller: 提供单模型单任务蒸馏方式。可用作测试或简单实验。
  • GeneralDistiller (常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用
  • MultiTeacherDistiller: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配
  • MultiTaskDistiller:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。
  • BasicTrainer:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型

3.3用户定义函数
蒸馏实验中,有两个组件需要由用户提供,分别是callback 和 adaptor :
3.3.1Callback
回调函数。在每个checkpoint,保存模型后会被distiller调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。
3.3.2Adaptor
将模型的输入和输出转换为指定的格式,向distiller解释模型的输入和输出,以便distiller根据不同的策略进行不同的计算。在每个训练步,batch和模型的输出model_outputs会作为参数传递给adaptoradaptor负责重新组织这些数据,返回一个字典。
更多细节可参见完整文档中的说明。
4.FAQ
Q: 学生模型该如何初始化?
A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重(中文任务)或BERT的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。
Q: 如何设置蒸馏的训练参数以达到一个较好的效果?
A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base上训练SQuAD一般以lr=3e-5训练3轮左右即可达到较好的效果;而蒸馏时需要以lr=1e-4训练30~50轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考
Q: 我的教师模型和学生模型的输入不同(比如词表不同导致input_ids不兼容),该如何进行蒸馏?
A: 需要分别为教师模型和学生模型提供不同的batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?
A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1977378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

谷粒商城实战笔记-122~124-全文检索-ElasticSearch-分词

文章目录 一,122-全文检索-ElasticSearch-分词-分词&安装ik分词二,124-全文检索-ElasticSearch-分词-自定义扩展词库1,创建nginx容器1.1 创建nginx文件夹1.2 创建nginx容器获取nginx配置1.3 复制nginx容器配置文件1.4 删除临时的nginx容器…

《Milvus Cloud向量数据库指南》——什么是高可用:深入理解数据库系统中的高可用性架构

什么是高可用:深入理解数据库系统中的高可用性架构 在信息技术日新月异的今天,高可用性(High Availability,简称HA)已成为衡量一个系统,尤其是数据库系统稳定性和可靠性的重要标准。高可用性的核心目标在于确保系统能够持续不断地提供服务,最大限度地减少因维护活动、硬…

从零开始安装Jupyter Notebook和Jupyter Lab图文教程

前言 随着人工智能热浪(机器学习、深度学习、卷积神经网络、强化学习、AGC以及大语言模型LLM, 真的是一浪又一浪)的兴起,小伙伴们Python学习的热情达到了空前的高度。当我20年前接触Python的时候,做梦也没有想到Python会发展得怎么…

【初阶数据结构题目】10. 链表的回文结构

链表的回文结构 点击链接做题 思路1:创建新的数组,遍历原链表,遍历原链表,将链表节点中的值放入数组中,在数组中判断是否为回文结构。 例如: 排序前:1->2->2->1 设置数组来存储链表&a…

KubeSphere 最佳实战:探索 K8s GPU 资源的管理,在 KubeSphere 上部署 AI 大模型 Ollama

转载:KubeSphere 最佳实战:探索 K8s GPU 资源的管理,在 KubeSphere 上部署 AI 大模型 Ollama 随着人工智能、机器学习、AI 大模型技术的迅猛发展,我们对计算资源的需求也在不断攀升。特别是对于需要处理大规模数据和复杂算法的 AI…

数据恢复软件:电脑丢失文件,及时使用数据恢复软件恢复!

数据恢复软件什么时候会用到? 答:如果真的不小心删除文件,清空回收站,电脑重装系统等情况发生,我们要懂的及时停止使用电子设备,使用可靠的数据恢复软件,帮助我们恢复这些电子设备的数据&#…

【SQL Server 】故障排除:端口冲突排查、网络问题诊断及日志分析与监控6.1 端口冲突排查

目录 第6章:故障排除 端口冲突排查 示例:使用 PowerShell 排查端口冲突 网络问题诊断 示例:使用 Wireshark 捕获 SQL Server 网络流量 日志分析与监控 示例:使用 SQL Server Profiler 监控网络连接 安全注意事项 第6章&am…

Celery:Python异步任务处理的终极利器

文章目录 **Celery:Python异步任务处理的终极利器**第一部分:背景介绍异步任务处理的挑战为什么选择Celery?引入Celery 第二部分:Celery概述什么是Celery? 第三部分:安装Celery使用pip安装Celery 第四部分&…

腰部 KOL 发展潜力预测与企业定制 AI 智能名片 O2O 商城小程序的协同发展

摘要:随着社交媒体和内容创作平台的蓬勃发展,KOL(关键意见领袖)在品牌推广和营销领域的作用日益凸显。在头部 KOL 资源竞争激烈的当下,腰部 KOL 成为了新的运营重点。然而,挖掘有潜力的腰部 KOL 并非易事。…

【机器学习】重塑游戏世界:机器学习如何赋能游戏创新与体验升级

📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀目录 🔍1. 引言:游戏世界的变革前夜📒2. 机器学习驱动的游戏创新🌞智能化游戏设计与开发&…

项目实战_图书管理系统(简易版)

你能学到什么 一个简单的项目——图书管理系统(浏览器:谷歌)基础版我们只做两个功能(因为其它的功能涉及的会比较多,索性就放在升级版里了,基础版先入个门) 登录: ⽤⼾输⼊账号,密码完成登录功…

华水2022年专升本计算机培养方案

华水2022年专升本计算机培养方案 文章目录 华水2022年专升本计算机培养方案计科第一学期第二学期第三学期第四学期 软工第一学期第二学期第三学期第四学期 计科 第一学期 通识必修课 大学外语线性代数离散数学 专业基础课 高级语言程序设计 专业选修课 Java 第二学期 通识…

我知道越来越多的专业摄影师在他们的修饰工作流程中使用 Portraiture,因为它可以让你在保持重要纹理的同时使皮肤非常光滑

Portraiture4.5新版功能亮点: 1. 高级皮肤修饰技术:4.5版本引入了更为先进的皮肤修饰算法,能够更自然地平滑皮肤,同时保留必要的皮肤纹理和细节,实现专业级别的人像修饰效果。 Portraiture4.5新版 2. 智能面部特征识…

计算机的错误计算(五十一)

摘要 探讨 的符号。 例1. 请确定 的符号[1]。 在计算过程中&#xff0c;若保留8位、16位、20位有效数字&#xff0c;则计算过程与结果分别如下: 若在Windows 10&#xff0c;Visual Studio 2010下计算&#xff1a; #include <math.h>double ysin(pow(2,(double)1…

Java11.0标准之重要特性及用法实例(二十一)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列…

(四)springboot2.7.6集成activit5.23.0之更换数据源

前面学习时&#xff0c;使用的内存数据库H2&#xff0c;实际使用时&#xff0c;一般会替换我们指定的数据库&#xff0c;这个时候要怎么配置呢&#xff1f; 1.查看activiti-spring-boot-starter-basic的spring.factories配置。 2.查看DataSourceProcessEngineAutoConfigurati…

诗意、甜美、可爱的水果:berry和cherry

我曾经在单词记忆的课上讲过&#xff0c;sweat(汗)和sweet(甜)的记忆之法&#xff0c;是甜这个单词sweet可以拟作甜丝丝来记忆&#xff0c;它是双写的-ee-结构&#xff0c;这样就能很好地与sweat相区别&#xff0c;同样desert(沙漠)和dessert(甜点)也是如此&#xff0c;和甜有关…

策略模式的一次应用

项目的需求是将一组图像按照相似度分类。 采用了模板匹配计算相似度的实现方式。 #include <opencv2/core.hpp> #include <openev2/core/utility.hpp> #include <opencv2/highqui.hpp> #include <openav2/imgproc.hpp> cv::Mat image matched; double …

基于微信小程序的游戏王交流平台设计与实现-计算机毕设 附源码 06533

基于微信小程序的游戏王交流平台设计与实现 摘要 本项目旨在设计并实现一款基于微信小程序的游戏王交流平台&#xff0c;旨在为广大游戏王爱好者提供一个交流互动的平台。通过该平台&#xff0c;用户可以分享游戏交流、分享卡片信息、参与线上比赛等活动&#xff0c;促进玩家之…

Python数据库连接全解析:5大方案实战对比

在本文中&#xff0c;我们将通过实际示例&#xff0c;深入探讨Python中5种主流的数据库连接方案。这些例子将帮助您更好地理解每种方法的特点和适用场景。 目录 不同方案说明1. DB-API&#xff1a;以sqlite3为例2. SQLAlchemy&#xff1a;ORM示例3. psycopg2&#xff1a;Postgr…