【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

news2024/11/27 4:16:45

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏


目录

文章目录

  • 【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
    • 目录
      • 摘要
      • 研究背景
      • 问题与挑战
      • 如何解决
      • 创新点
      • 算法模型
      • 实验效果
      • 代码
      • 推荐阅读指数:✭✭✭✭✩
    • 后记


BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
在这里插入图片描述

摘要

本文介绍了BitDistiller,这是一个通过结合量化感知训练(QAT)和知识蒸馏(KD)来提升超低精度(亚4比特)大型语言模型(LLMs)性能的框架。BitDistiller首先采用定制的非对称量化和裁剪技术来尽可能保持量化权重的保真度,然后提出了一种新颖的基于置信度的Kullback-Leibler散度(CAKLD)目标,用于自蒸馏,以实现更快的收敛和更优的模型性能。实验评估表明,BitDistiller在3比特和2比特配置下,无论是在通用语言理解还是复杂推理基准测试中,都显著超越了现有方法。值得注意的是,BitDistiller更具成本效益,需要更少的数据和训练资源。

研究背景

随着大型语言模型(LLMs)规模的扩大,自然语言处理领域取得了令人印象深刻的进展。然而,这种模型规模的扩大在部署上带来了显著的挑战,尤其是在资源受限的设备上,因为它们需要大量的内存和计算能力。权重量化作为一种流行的策略,通过减少模型大小来提高LLMs的效率和可访问性,同时最小化性能损失。尽管4比特量化已被广泛采用,提供了显著的压缩比和保留LLM能力之间的平衡,但亚4比特量化会显著降低模型权重的保真度,尤其是在小型模型或需要复杂推理的任务中,导致模型性能恶化。
在这里插入图片描述

问题与挑战

在极端低比特QAT中实现高性能的两个基本挑战是:如何在量化过程中最大限度地保持权重保真度,以及如何在训练中有效学习低比特表示。

如何解决

BitDistiller通过以下方式解决上述挑战:

  1. 非对称量化和裁剪:BitDistiller采用了定制的非对称量化和裁剪策略,以保持全精度模型的能力,特别是在超低比特水平上。
  2. 自蒸馏:BitDistiller利用全精度模型作为教师,低比特模型作为学生,通过自蒸馏方法进行有效的低比特表示学习。
  3. CAKLD目标:BitDistiller创新性地提出了一种基于置信度的Kullback-Leibler散度(CAKLD)目标,优化知识传递效率,实现更快的收敛和增强的模型性能。

创新点

  • 非对称量化和裁剪:BitDistiller针对不同比特级别的量化采用了不同的量化策略,如NF格式和INT格式,以及非对称裁剪,以提高量化权重的表示保真度。
  • CAKLD目标:BitDistiller提出了一种新颖的CAKLD目标,它根据全精度模型对训练数据的置信度自动权衡模式寻求和模式覆盖行为。
  • 自蒸馏框架:BitDistiller将QAT与知识蒸馏相结合,使用全精度模型作为教师来指导低比特学生模型,这是一种简单而有效的自蒸馏方法。
    在这里插入图片描述

算法模型

BitDistiller的框架包括以下几个关键步骤:

  1. 非对称量化和裁剪:在QAT初始化阶段,BitDistiller对权重进行非对称裁剪,以减少量化误差。
  2. 自蒸馏:在训练过程中,全精度模型生成数据,低比特模型学习这些数据,通过CAKLD目标进行优化。
  3. CAKLD目标:CAKLD目标结合了反向KL散度和正向KL散度,根据全精度模型的置信度自动调整模式寻求和模式覆盖行为。
    在这里插入图片描述

实验效果

实验评估表明,BitDistiller在3比特和2比特配置下的性能显著优于现有的PTQ和QAT方法。以下是一些重要的数据和结论:

  • 语言建模任务:在WikiText-2的困惑度(PPL)和MMLU(5-shot)准确性方面,BitDistiller超越了竞争对手。
  • 推理任务:在HumanEval和GSM8K等推理基准测试中,BitDistiller在3比特和2比特量化中均展现出优越性能。
  • 成本效益:BitDistiller需要的训练数据和资源更少,更具成本效益。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

代码

https://github.com/DD-DuDa/BitDistiller.git
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from tqdm import tqdm
import gc
# import bitsandbytes as bnb
import torch.nn as nn
from functools import partial
# import bitsandbytes.functional as bnbF

class Round(Function):
    @staticmethod
    def forward(self, input):
        sign = torch.sign(input)
        output = sign * torch.floor(torch.abs(input) + 0.5)
        return output

    @staticmethod
    def backward(self, grad_output):
        grad_input = grad_output.clone()
        return grad_input

# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False
                           ):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    elif q_group_size == -1:
        w = w.reshape(-1, w.shape[-1])
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = - 2 ** (n_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w



@torch.no_grad()
def real_quantize_model_weight(
    model, w_bit, q_config,
    init_only=False
):
    from .qmodule import WQLinear
    from .pre_quant import get_blocks, get_named_linears, set_op_by_name
    assert q_config["zero_point"], "We only support zero_point quantization now."
    
    layers = get_blocks(model)
    for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
        layer = layers[i]
        named_linears = get_named_linears(layer)
        # scale_activations(layer)

        for name, module in named_linears.items():
            if init_only:
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            else:
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                # scales = scales.t().contiguous()
                # zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()
                
    torch.cuda.empty_cache()
    gc.collect()




def pseudo_quantize_n2f3_tensor(w, q_group_size=-1):
    quantizer = SteN2F3Quantizer(q_group_size=q_group_size)
    w = quantizer(w)
    return w


class SteInt3AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 3
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, x.shape[-1])
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteInt2AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=64):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 2
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteN2F3Quantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
    
    def forward(self, x):
        org_w_shape = x.shape

        # reshape to groupsize
        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            qx = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            qx = x.reshape(-1, x.shape[-1])
        assert qx.dim() == 2

        # Get the Min Max
        max_val = qx.amax(dim=1, keepdim=True)
        min_val = qx.amin(dim=1, keepdim=True)

        
        scale_pos = torch.abs(max_val)
        scale_neg = torch.abs(min_val)

        dev = qx.device
        x_pos = torch.zeros_like(qx)
        x_neg = torch.zeros_like(qx)
        x_pos = torch.where(qx >= 0, qx, x_pos)
        x_neg = torch.where(qx < 0, qx, x_neg)
        q_pos = x_pos / scale_pos
        q_neg = x_neg / scale_neg

        q_pos, q_neg = self.round_pass(q_pos, q_neg, dev)

        qx = q_pos * scale_pos + q_neg * scale_neg

        qx = qx.reshape(org_w_shape)

        return qx
    
    def round_n2f3(self, q_pos, q_neg, dev):
        q_pos = torch.where(q_pos >= 0.8114928305149078,                                        torch.tensor(1.0).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.8114928305149078)    & (q_pos >= 0.5024898052215576),    torch.tensor(0.6229856610298157).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.5024898052215576)    & (q_pos >= 0.2826657369732857),    torch.tensor(0.3819939494132996).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.2826657369732857)    & (q_pos >= 0.0916687622666359),    torch.tensor(0.1833375245332718).to(dev), q_pos)
        q_pos = torch.where(q_pos < 0.0916687622666359,                                        torch.tensor(0).to(dev), q_pos)

        q_neg = torch.where(q_neg >= -0.1234657019376755,                                     torch.tensor(0).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.1234657019376755)   & (q_neg >= -0.39097706973552704),   torch.tensor(-0.2469314038753510).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.39097706973552704)   & (q_neg >= -0.7675113677978516),   torch.tensor(-0.5350227355957031).to(dev), q_neg)
        q_neg = torch.where(q_neg < -0.7675113677978516,                                        torch.tensor(-1.0).to(dev), q_neg)

        return q_pos, q_neg

    def round_pass(self, q_pos, q_neg, dev):
        y_grad_pos, y_grad_neg = q_pos, q_neg
        y_pos, y_neg = self.round_n2f3(q_pos, q_neg, dev)
        
        return (y_pos - y_grad_pos).detach() + y_grad_pos, (y_neg - y_grad_neg).detach() + y_grad_neg

推荐阅读指数:✭✭✭✭✩

推荐理由

  • 创新性:BitDistiller通过结合QAT和KD,在亚4比特量化领域提供了一种新的解决方案,具有显著的性能提升。
  • 实用性:BitDistiller不仅在理论上具有创新性,而且在实际应用中也显示出了成本效益,这对于资源受限的设备尤为重要。
  • 广泛适用性:BitDistiller在多种语言和推理任务中都展现出了优越的性能,表明其方法的广泛适用性。

后记

如果您对我的博客内容感兴趣,欢迎三连击(点赞、收藏、关注和评论),我将持续为您带来计算机人工智能前沿技术(尤其是AI相关的大语言模型,深度学习和计算机视觉相关方向)最新学术论文及工程实践方面的内容分享,助力您更快更准更系统地了解 AI前沿技术

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2231844.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Tomcat 和 Docker部署Java项目的区别

在 Java 项目部署中&#xff0c;Tomcat 和 Docker 是两种常见的选择。虽然它们都可以用来运行 Java 应用&#xff0c;但它们在定位、部署方式、依赖环境、资源隔离、扩展性和适用场景等方面有显著区别。 1. 功能定位 1.1 Tomcat Apache Tomcat 是一种轻量级的 Java 应用服务器…

基于SSM的学生选课系统+LW参考示例

系列文章目录 1.基于SSM的洗衣房管理系统原生微信小程序LW参考示例 2.基于SpringBoot的宠物摄影网站管理系统LW参考示例 3.基于SpringBootVue的企业人事管理系统LW参考示例 4.基于SSM的高校实验室管理系统LW参考示例 5.基于SpringBoot的二手数码回收系统原生微信小程序LW参考示…

Java I/O流详解

文章目录 I/O流概念I/O流的分类字节流&#xff08;Byte Streams&#xff09;字节字节流概述方法主要类和继承关系示例代码字节流读取文件 字符流字符流概述子类Reader1.FileReader&#xff1a;2.CharArrayReader&#xff1a;3.StringReader&#xff1a;4.InputStreamReader&…

基于Multisim数字频率计频率范围0-9999HZ电路(含仿真和报告)

【全套资料.zip】数字频率计仿真电路设计Multisim仿真设计数字电子技术 文章目录 功能一、Multisim仿真源文件二、原理文档报告资料下载【Multisim仿真报告讲解视频.zip】 功能 1.采用纯数字电路&#xff0c;非单片机。 2.频率计测量的频率范围0-9999HZ。 3.使用数码管进行频…

Python画笔案例-095 绘制鼠标画笔

1、绘制 鼠标画笔 通过 python 的turtle 库绘制 鼠标画笔,如下图: 2、实现代码 绘制 鼠标画笔,以下为实现代码: """鼠标画笔.py本程序可以用鼠标指针在屏幕上画画儿。 """ import turtlescreen = turtle.getscreen() screen.setup(

【温酒笔记】SPI

1. SPI基础 物理层 片选线 &#xff1a;选中拉低SCK: 时钟线MOSI:主出从入MISO:主入从出 协议层 CPOL:时钟极性&#xff1a;空闲电平高低 CPHA:时钟相位&#xff1a;第一个还是第二个边沿采样 2. 示例SPI-W25Q16 (见模组分类下文章)

mac电脑设置crontab定时任务,以及遇到的问题解决办法

crontab常用命令 crontab -u user&#xff1a;用来设定某个用户的crontab服务&#xff1b; crontab file&#xff1a;file是命令文件的名字,表示将file做为crontab的任务列表文件并载入crontab。如果在命令行中没有指定这个文件&#xff0c;crontab命令将接受标准输入&#xf…

MySQL中,如何定位慢查询?定位到的慢SQL如何分析?

目录 1. 慢查询发生的场景&#xff1f; 2. MySQL中&#xff0c;如何定位慢查询&#xff1f; 2.1 详细解释 3. 定位到的慢SQL如何分析&#xff1f; 3.1 详细说明 1. 慢查询发生的场景&#xff1f; 2. MySQL中&#xff0c;如何定位慢查询&#xff1f; 介绍一下当时产生问题…

大数据新视界 -- 大数据大厂之提升 Impala 查询效率:索引优化的秘籍大揭秘(上)(3/30)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

Nico,从零开始干掉Appium,移动端自动化测试框架实现

开头先让我碎碎念一波~去年差不多时间发布了一篇《 UiAutomator Nico&#xff0c;一个基于纯 adb 命令实现的安卓自动化测试框》&#xff08;https://testerhome.com/topics/37042&#xff09;&#xff0c; 由于种种原因 (详见此篇帖子) 当时选择了用纯 adb 命令来实现安卓自动…

小样本语义分割(HDMNet网络)

小样本语义分割&#xff08;HDMNet网络&#xff09; 摘要HDMNet 解决的问题本文贡献HDMNet 模型1. 特征提取2. 解耦下采样和匹配模块&#xff08;分层匹配结构&#xff09;2.1. 粗粒度到细粒度解码器2.2 . 自注意力模块2.3. 相关性模块 3. 损失函数 总结 摘要 小样本语义分割&…

layui 自定义验证单选框必填

对于输入框类型必填验证&#xff0c;只需要在 input 输入框加入 lay-verify "required" 即可。但对于单选按钮这种特殊的该怎么办呢&#xff1f;layui 为我们提供了自定义验证。 1. 在单选按钮上添加自定义验证的名称 2. 验证规则如下 // 单选框自定义验证form.ve…

植物神经紊乱别担心,这些运动让你重拾健康与平衡✨

在这个快节奏、高压力的时代&#xff0c;植物神经紊乱似乎已经成为现代人的“隐形杀手”。焦虑、失眠、心跳过速、呼吸不规律……这些症状不仅影响了我们的日常生活&#xff0c;更在无声中侵蚀着我们的身心健康。但别担心&#xff0c;通过科学合理的运动&#xff0c;我们可以有…

第1篇 引言

一、AIGC概念 1、AIGC定义 AIGC&#xff0c;即生成式人工智能&#xff08;Artificial Intelligence Generated Content&#xff09;&#xff0c;是指利用人工智能技术自动生成或辅助创作内容的过程和结果。 简单来说&#xff1a;过去&#xff0c;写文章、画张图、唱首歌、弄个…

2. 从服务器的主接口入手

Webserver 的主函数 main.cpp&#xff0c;完成了哪些功能&#xff1f; #include "config.h"int main(int argc, char *argv[]) {string user "";string passwd "";string databasename "";Config config;config.parse_arg(argc, a…

向量数据库 PieCloudVector 进阶系列丨打造音乐推荐系统

在上一篇内容中&#xff0c;我们介绍了 PieCloudVector 如何助力构建基于图片数据的商品推荐系统&#xff0c;详细描述从数据集的准备到数据向量化处理&#xff0c;再到向量数据的存储和相似性搜索的完整流程。本文将进一步探讨如何将 PieCloudVector 应用于音频数据&#xff0…

python之数据结构与算法(数据结构篇)-- 栈

一、栈的概念 这里我们不去了解教科书上面的“教条概念”&#xff0c;其实“栈”的概念和古代的时候的“客栈”是有异曲同工之妙的。 在这里我们把客栈看成“栈”&#xff0c;旅客看作“栈元素” 1.当旅客进来住店时&#xff0c;叫做“入栈”&#xff1b; 2.当旅客退房时&#…

Java调用chatgpt

目前openai的chatgpt在国内使用有一定难度&#xff0c;不过国内的大模型在大部分情况下已经不弱于chatgpt&#xff0c;而且还更便宜&#xff0c;又能解决国内最敏感的内容安全问题。本文后续以spring ai调用国内chatgpt厂商实现为例&#xff0c;讲解怎么构建一个java调用chatgp…

web前端多媒体标签设置(图片,视频,音频)以及图片热区(usemap)的设置

多媒体标签运用 在HTML中有以下常见多媒体标签&#xff1a; <img> &#xff08;图像标签&#xff09; - 作用&#xff1a;用于在网页中嵌入图像。 - 示例&#xff1a; <img src"image.jpg" alt"这是一张图片"> 。其中 src 属性指定图像的…

安卓开发之数据库的创建与删除

目录 前言&#xff1a;基础夯实&#xff1a;数据库的创建数据库的删除注意事项 效果展示&#xff1a;遇到问题&#xff1a;如何在虚拟机里面找到这个文件首先&#xff0c;找到虚拟机文件的位置其次&#xff0c;找到数据库文件的位置 核心代码&#xff1a; 前言&#xff1a; 安…