类BERT模型蒸馏实战

news2025/1/13 13:11:17

机器学习模型已经变得越来越大,以至于训练模型可能会给那些没有空闲集群的人带来痛苦。 此外,即使使用训练好的模型,当你的硬件与模型对其运行的期望不符时,推理的时间和内存成本也会飙升。 因此,为了缓解这个问题,我们并没有放弃类似 BERT 模型的深层知识,而是开发了一种称为蒸馏(distillation)的技术,将网络缩小到合理的大小,同时最大限度地减少性能损失。

如果你已经阅读了本系列的第一篇文章,那么这并不是什么新闻。 在其中,我们讨论了 DistilBERT  如何引入一种简单而有效的蒸馏技术,可以轻松应用于任何类似 BERT 的模型,但我们避开了任何具体的实现。 现在,我们将详细介绍如何将想法转化为 .py 文件。

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 

1、学生模型的初始化

由于我们想要从现有模型初始化一个新模型,因此需要访问旧模型(即教师)的权重。 我们假设预先存在的模型是在 PyTorch 上实现的 Hugging Face 模型, 因此,要获得权重,首先必须知道如何访问它们。 我们将使用 RoBERTa large 作为我们的教师模型。

1.1 Hugging Face的模型结构

我们可以尝试的第一件事是打印模型结构,这应该让我们深入了解它是如何制作的。 当然,我们总是可以深入研究 Hugging Face 文档 ,但这并不有趣。

from transformers import AutoModelForMaskedLM

roberta = AutoModelForMaskedLM.from_pretrained("roberta-large")

print(roberta)

运行此代码后,我们得到:

简单打印 RoBERTA 的第一印象

模型的结构开始出现,但我们可以让它变得更漂亮。 在 Hugging Face 模型中,我们可以使用 .children() 生成器访问模块的子组件。 因此,如果我们想要遍历整个模型,我们需要在其上调用 .children() ,并在每个产生的子级上继续调用 .children() ,等等......这描述了一个递归函数,代码如下:

from typing import Any
from transformers import AutoModelForMaskedLM

roberta = AutoModelForMaskedLM.from_pretrained("roberta-large")

def visualize_children(
    object : Any,
    level : int = 0,
) -> None:
    """
    Prints the children of (object) and their children too, if there are any.
    Uses the current depth (level) to print things in a ordonnate manner.
    """
    print(f"{'   ' * level}{level}- {type(object).__name__}")
    try:
        for child in object.children():
            visualize_children(child, level + 1)
    except:
        pass

visualize_children(roberta)

输出结果如下:

RoBERTa 的递归预览

通过展开这棵树,看起来 RoBERTa 模型的结构与其他类似 BERT 的模型一样,如下所示:

类 BERT 模型的架构

1.2 复制教师模型的权重

我们知道,要以 DistilBERT 的方式初始化类似 BERT 的模型,我们只需要复制除 Roberta 层最深层之外的所有内容,我们省略了其中的一半。

首先,我们需要创建学生模型,其架构与教师模型相同,但隐藏层数量只有一半。
为此,我们只需要使用教师模型的配置,它是一个类似字典的对象,描述了 Hugging Face 模型的架构。 当查看 roberta.config 属性时,我们可以看到以下内容:

RoBERTa 配置

我们在这里感兴趣的是 num-hidden-layers 属性。 让我们编写一个函数来复制此配置,通过将其除以 2 来更改该属性,并使用新配置创建一个新模型:

from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaConfig

def distill_roberta(
    teacher_model : RobertaPreTrainedModel,
) -> RobertaPreTrainedModel:
    """
    Distilates a RoBERTa (teacher_model) like would DistilBERT for a BERT model.
    The student model has the same configuration, except for the number of hidden layers, which is // by 2.
    The student layers are initilized by copying one out of two layers of the teacher, starting with layer 0.
    The head of the teacher is also copied.
    """
    # Get teacher configuration as a dictionnary
    configuration = teacher_model.config.to_dict()
    # Half the number of hidden layer
    configuration['num_hidden_layers'] //= 2
    # Convert the dictionnary to the student configuration
    configuration = RobertaConfig.from_dict(configuration)
    # Create uninitialized student model
    student_model = type(teacher_model)(configuration)
    # Initialize the student's weights
    distill_roberta_weights(teacher=teacher_model, student=student_model)
    # Return the student model
    return student_model

当然,这个函数引入了一个缺失的部分 distill_roberta_weights ,该函数会将教师模型权重的一半置于学生层中,但我们仍然需要对其进行编码。 由于递归对于探索教师模型效果很好,因此我们可以使用相同的想法来探索和复制其中的部分内容。 我们将同时浏览教师模型和学生模型,同时将部分内容从一个模型复制到另一个模型。 唯一的技巧是要小心隐藏层部分并只复制一半。实现代码如下:

from transformers.models.roberta.modeling_roberta import RobertaEncoder, RobertaModel
from torch.nn import Module

def distill_roberta_weights(
    teacher : Module,
    student : Module,
) -> None:
    """
    Recursively copies the weights of the (teacher) to the (student).
    This function is meant to be first called on a RobertaFor... model, but is then called on every children of that model recursively.
    The only part that's not fully copied is the encoder, of which only half is copied.
    """
    # If the part is an entire RoBERTa model or a RobertaFor..., unpack and iterate
    if isinstance(teacher, RobertaModel) or type(teacher).__name__.startswith('RobertaFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_roberta_weights(teacher_part, student_part)
    # Else if the part is an encoder, copy one out of every layer
    elif isinstance(teacher, RobertaEncoder):
            teacher_encoding_layers = [layer for layer in next(teacher.children())]
            student_encoding_layers = [layer for layer in next(student.children())]
            for i in range(len(student_encoding_layers)):
                student_encoding_layers[i].load_state_dict(teacher_encoding_layers[2*i].state_dict())
    # Else the part is a head or something else, copy the state_dict
    else:
        student.load_state_dict(teacher.state_dict())

该函数通过递归和类型检查,确保学生模型与教师模型相同,对于 Roberta 层来说是安全的。 可以注意到,如果我们想在初始化教师模型时更改复制哪些层,则只有编码器部分中的 for 循环需要更改。

现在我们有了学生模型,我们需要训练它。 除了要使用的损失函数之外,这部分相对简单。

2、自定义损失函数

作为对 DistilBERT 训练过程的回顾,我们可以看下图:

DistilBERT 的蒸馏过程

我们将把注意力转向那个写着 LOSS 的红色大盒子。 但在揭示里面有什么之前,我们需要知道如何收集我们要喂它的东西。 从这张图中,我们可以看到我们需要三样东西:标签、学生模型和教师模型的嵌入。 标签,我们已经有了,否则,我们可能会遇到更大的问题。 现在让我们得到另外两个。

2.1 检索教师和学生的输入

在这里,我们将坚持我们的示例并使用带有分类头的 RoBERTa 来说明这部分。 我们需要的是一个函数,给定类似 BERT 模型的输入,即两个张量( input_ids 和  Attention_mask)以及模型本身,将返回该模型的输出 logits

由于我们使用的是 Hugging Face,所以这非常简单,我们唯一需要的知识就是看哪里。

from torch import Tensor

def get_logits(
    model : RobertaPreTrainedModel, 
    input_ids : Tensor,
    attention_mask : Tensor,
) -> Tensor:
    """
    Given a RoBERTa (model) for classification and the couple of (input_ids) and (attention_mask),
    returns the logits corresponding to the prediction.
    """
    return model.classifier(
        model.roberta(input_ids, attention_mask)[0]
    )

我们为学生模型和老师模型都执行这个操作,第一个有梯度,第二个没有梯度。

2.2 损失函数计算

如果损失函数有点不透明,我们建议你返回第一篇文章来阅读损失函数。 但是,如果没有时间这样做,下图应该会有所帮助:

DistilBERT 的损失

我们所说的 Converging consine loss(收敛余弦损失)是用于对齐两个输入向量的常规余弦损失。 有关更多信息,请参阅该系列的第一部分。 这是代码:

import torch
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss

def distillation_loss(
    teacher_logits : Tensor,
    student_logits : Tensor,
    labels : Tensor,
    temperature : float = 1.0,
) -> Tensor:
    """
    The distillation loss for distilating a BERT-like model.
    The loss takes the (teacher_logits), (student_logits) and (labels) for various losses.
    The (temperature) can be given, otherwise it's set to 1 by default.
    """
    # Temperature and sotfmax
    student_logits, teacher_logits = (student_logits / temperature).softmax(1), (teacher_logits / temperature).softmax(1)
    # Classification loss (problem-specific loss)
    loss = CrossEntropyLoss()(student_logits, labels)
    # CrossEntropy teacher-student loss
    loss = loss + CrossEntropyLoss()(student_logits, teacher_logits)
    # Cosine loss
    loss = loss + CosineEmbeddingLoss()(teacher_logits, student_logits, torch.ones(teacher_logits.size()[0]))
    # Average the loss and return it
    loss = loss / 3
    return loss

3、更优雅的实现

我希望你不会对 Python 是一种面向对象的编程语言感到震惊。 因此,由于所有这些函数都使用几乎相同的对象,因此不让它们成为类的一部分似乎很奇怪。 如果你想实现这一点,我建议使用 Distillator 类来整理代码,就像这个gist 。 我们不会嵌入这个,因为它很长。

当然,缺少一些东西,比如 GPU 支持、整个训练例程等。但是 DistilBERT 的所有关键思想都可以在那里找到。

4、蒸馏结果

那么以这种方式提炼出来的模型最终表现如何呢? 对于DistilBERT,可以阅读原论文。 对于 RoBERTa,Hugging Face 上已经存在类似 DistilBERT 的精简版本,就在这里。 在 GLUE 基准测试 上,我们可以比较这两个模型:

RoBERTa 与 DistilRoBERTa的对比

至于时间和内存成本,该模型的大小大约是 roberta-base 的三分之二,速度是 roberta-base 的两倍。

5、结束语

通过本系列文章,你应该拥有足够的知识来提炼遇到的任何类似 BERT 的模型。 但为什么要停在那里呢? 大自然充满了蒸馏方法,例如 TinyBERT  或 MobileBERT。 如果你认为其中一个更适合你的需求,那么应该阅读这些文章。 谁知道呢,你可能想尝试一种新的蒸馏方法,因为这是一个日益发展的领域。


原文链接:类 BERT 模型蒸馏实战 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1218805.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

idea 环境搭建及运行java后端源码

1、 idea 历史版本下载及安装 建议下载和我一样的版本,2020.3 https://www.jetbrains.com/idea/download/other.html,idea分为专业版本(Ultimate)和社区版本(Community),前期可以下载专业版本…

新品|CASAIM-IS(2ND)自动化智能检测系统正式上市,打造更高效、更智能、更安全新体验!

全新第二代中科广电CASAIM-IS自动化智能检测系统正式上市,集合CASAIM最新的“智能控制、智能成像、智能检测”三智技术,为中小型精密复杂工件测量及检测提供一站式高效全自动化智能检测解决方案

设计模式(5)-使用设计模式实现简易版springIoc

自定义简易版springIoc 1 spring使用回顾 自定义spring框架前,先回顾一下spring框架的使用,从而分析spring的核心,并对核心功能进行模拟。 数据访问层。定义UserDao接口及其子实现类 public interface UserDao {public void add(); }public…

使用VC++设计程序,进行全局固定阈值分割、自适应阈值分割

图像分割 文章目录 图像分割实验内容一、全局固定阈值分割全局固定阈值分割的原理全局固定阈值分割的实验代码全局固定阈值分割的实验现象 二、自适应阈值分割自适应阈值分割的实验原理自适应阈值分割的实验代码自适应阈值分割的实验现象 实验内容 实验目的: &…

移交计划书、移交确认单

项目移交过程文件: 1、移交计划书 2、移交确认单 1、移交计划 2、移交确认单

Day48 力扣动态规划 : 647. 回文子串 |516.最长回文子序列 |动态规划总结篇

Day48 力扣动态规划 : 647. 回文子串 |516.最长回文子序列 |动态规划总结篇 647. 回文子串第一印象看完题解的思路dp递推公式初始化递归顺序 实现中的困难感悟代码 516.最长回文子序列第一印象我的尝试遇到的问题 看完题解的思路dp递推公式初始化 实现中…

基于springboot实现大学生体质测试管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现大学生体质测试管理系统演示 摘要 大学生体质测试管理系统提供给用户一个简单方便体质测试管理信息,通过留言区互动更方便。本系统采用了B/S体系的结构,使用了java技术以及MYSQL作为后台数据库进行开发。系统主要分为系统管理员、教师…

C/C++高频面经-秋招篇

自己在秋招找工作过程中遇到的一些C/C面试题,大中小厂都有,分享出来,希望能帮到有缘人。 C语言 snprintf()的使用 函数原型为int snprintf(char *str, size_t size, const char *format, …) 两点注意: (1) 如果格式化后的字符…

《Linux从练气到飞升》No.30 深入理解 POSIX 信号量与生产消费模型

🕺作者: 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux菜鸟刷题集 😘欢迎关注:👍点赞🙌收藏✍️留言 🏇码字不易,你的👍点赞🙌收藏❤️关注对我真的…

vs2017 编译Qt 5.11.2 源码

SDK 10.0.22000.194 有 2种编译方式 ,第二种 看下面 推荐使用方式二,简单方便,唯一不好是慢 方式一: 1、问题描述: 使用VS编译程序时,运行库选择多线程(/MT),表示采用多线程静态…

安卓用户当心: CERT-IN 发布高危漏洞警告

已发现的漏洞一旦被利用,将构成严重风险,可能导致未经授权访问敏感信息。 印度计算机应急响应小组(CERT-IN)在最近发布的一份公告中,就影响印度安卓用户的新安卓漏洞发出了重要警告。 该警告对使用安卓 11、12、12L、…

modbus转profinet网关连接PLC与变频器控制摆辊应用在涂布机案例

通过兴达易控modbus转profinet网关的应用,PLC能够直接与变频器进行通讯,并实现对摆辊的精确控制。兴达易控modbus转profinet网关(XD-MDPN100)作为一个高性能的转换设备,能够稳定可靠地完成modbus和profinet之间的数据转…

2023最新最全【Python3.11.3】下载安装零基础教程【附安装包】

前言:链接在最底下 Python是一种可在多个平台上运行的计算机程序设计语言,它是一种高层次的脚本语言,结合了解释性、编译性、互动性和面向对象的特点。最初,它的设计目的是用于编写自动化脚本(shell)。但随着版本的更新和新功能的…

vue 城市选择器的使用 element-china-area-data

一、Element UI 中国省市区级联数据 本文参考:element-china-area-data - npm 1. 安装 npm install element-china-area-data -S2. 使用 import { provinceAndCityData, regionData, provinceAndCityDataPlus, regionDataPlus, CodeToText, TextToCode } from e…

学习模拟简明教程【Learning to simulate】

深度神经网络是一项令人惊叹的技术。 有了足够的标记数据,他们可以学习为图像和声音等高维输入生成非常准确的分类器。 近年来,机器学习社区已经能够成功解决诸如对象分类、图像中对象检测和图像分割等问题。 上述声明中的加黑字体警告是有足够的标记数…

git 构建报错

钉钉插件]当前任务未配置机器人,已跳过 org.codehaus.groovy.control.MultipleCompilationErrorsException: startup failed: WorkflowScript: 4: Tool type “maven” does not have an install of “maven-3.8.8” configured - did you mean “Maven-3.8.8”? …

Docker Desktop 配置阿里云镜像加速

阿里云搜索镜像,打开容器镜像服务,复制镜像加速器地址 Docker Desktop 右上角设置,选择 Docker Engine,在配置中添加阿里云的镜像地址,右下 Apply & restart 即可。 "registry-mirrors": ["https…

android适配鸿蒙系统开发

将一个Android应用迁移到鸿蒙系统需要进行细致的工作,因为两者之间存在一些根本性的差异,涉及到代码、架构、界面等多个方面的修改和适配。以下是迁移工作可能涉及的一些主要方面,希望对大家有所帮助。北京木奇移动技术有限公司,专…

《Linux从练气到飞升》No.29 生产者消费者模型

🕺作者: 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux菜鸟刷题集 😘欢迎关注:👍点赞🙌收藏✍️留言 🏇码字不易,你的👍点赞🙌收藏❤️关注对我真的…

机械人必须要了解的丝杆螺母参数

丝杆螺母是机械中重要的零部件之一,主要用于将旋转运动转化为直线运动,或者将直线运动转化为旋转运动。只有正确了解丝杆螺母的参数,才能进行选型。 1、螺纹规格:丝杆螺母的螺纹规格是按照国家标准进行分类的,常见的有…