用PyTorch 从零开始构建 BitNet 1.58bit

news2025/1/20 13:29:09

我们手动实现BitNet的编写,并进行的一系列小实验证实,看看1.58bit 模型是否与全精度的大型语言模型相媲美!

什么是量化以及为什么需要它?

量化是用更少的比特数表示浮点数的过程。当两个数字使用不同的比特数进行量化时,浮点运算的计算成本几乎按照减少的比特数的比例降低(理论上)。这使我们能够提高速度并减少机器学习模型的内存消耗。但这通常会导致信息丢失,从而降低准确性,我们可以通过对量化模型进行更多的微调来一定程度上恢复这种损失。

现有的量化方法与 BitNet 1.58bit 对比

大多数量化算法都需要一个全精度的预训练模型。人们通常会应用后训练量化(PTQ)和量化感知训练(QAT)等技术,以使这些算法有效运行。

PTQ 是一种量化技术,模型在训练完成后进行量化。QAT 是对 PTQ 模型的进一步微调,即在考虑量化的情况下进一步训练模型。

而BitNet 采用了一种截然不同的方法,即从头开始训练模型时就进行量化!

BitNet 的量化算法

上图中,通过取绝对值的平均值的一半(假设 n=2)来计算权重裁剪阈值 γ。然后,权重矩阵 W 被相同的值除,导致新的权重矩阵在原始权重值 ≥ γ 时的值 ≥ 1,原始权重值 ≤ -γ 时的值 ≤ -1。对于 -γ 和 γ 之间的值,它们被映射到 -0.99999… 到 0.9999…

当执行 roundclip 时,

对于原始值 ≥ γ,新值为 1.0,原始值 ≤ -γ,新值为 -1.0,原始值在 -γ 和 γ 之间的新值为 0.0。

理论上,结果值可以用信息编码理论表示为 1.58 位。由于位数不能是分数,我们可以用 2 位来表示。

量化函数在Pytorch中的实现

阈值计算:

 def compute_adjustment_factor(self, input_tensor: torch.Tensor):
     absmean_weight = torch.mean(torch.abs(input_tensor))
     adjustment_factor = 1e-4 + absmean_weight * 2 # 1e-4 to avoid zero divison error
     return adjustment_factor

这里没有把绝对值减半,而是乘以了2。但是实验还是成功了!

RoundClip (1.58~= 2bit)

 def compute_2bit_quantized_tensor(self, input_tensor: torch.Tensor):
     twobit_matrix = torch.clip(input=torch.round(input_tensor), min=-1, max=1)
     return twobit_matrix
 
 def compute_1bit_quantized_tensor(self, input_tensor: torch.Tensor):
     return torch.sign(input_tensor)
 
 def compute_quantized_tensor(self, input_tensor: torch.Tensor):
     if self.quantization_mode == QuantizationMode.two_bit:
         return self.compute_2bit_quantized_tensor(input_tensor)
     else:
         return self.compute_1bit_quantized_tensor(input_tensor)

量化步骤

 weight_adjustment_factor = self.compute_adjustment_factor(self.weight)
 adjusted_weight = self.weight / weight_adjustment_factor
 quantized_weight = self.compute_quantized_tensor(adjusted_weight)

线性层操作

 F.linear(weight_adjustment_factor * x, quantized_weight, self.bias)

将调整因子与输入相乘,并将其除以量化权重

如果在将权重传递给线性层函数之前对其进行量化,则对量化矩阵的更新不会通过量化函数(因为大多数更新将在1e-4到1e-2之间,当通过量化步骤反向传播时将变为零)。因为原始的权重矩阵永远不会更新,模型永远不会学习!!

但有一个巧妙的工程技巧可以做到这一点,完整的前向传播是这样的

 def forward(self, x):
     weight_adjustment_factor = self.compute_adjustment_factor(self.weight)
     adjusted_weight = self.weight / weight_adjustment_factor
 
     if self.training:
         quantized_weight = (
             adjusted_weight
             + (
                 self.compute_quantized_tensor(adjusted_weight) - adjusted_weight
             ).detach()
         )
     else:
         quantized_weight = self.compute_quantized_tensor(adjusted_weight)
 
     return F.linear(weight_adjustment_factor * x, quantized_weight, self.bias)

量化权重块的值无论

self.training

是否设置为

True

都是相同的。但是当

self.training

设置为

True

时,计算得到的梯度会被优雅地复制到调整后的权重中。这允许在训练过程中更新调整后的权重,同时也更新原始的权重矩阵。

这是从谷歌 DeepMind 的 VQ VAE PyTorch 实现中借鉴的简单却实用的技巧

自定义Pytorch实现的实验结果

下面的实验选择了一个小型模型和一个相对于小型模型来假设足够大的数据集。此外,为了创建目标模型的量化变体,我简单地使用以下代码块,将

nn.Linear

模块替换为这个自定义实现:

 import copy
 
 def create_quantized_copy_of_model(
     input_model: nn.Module, quantization_mode: QuantizationMode
 ):
     model_copy = copy.deepcopy(input_model)
     hash_table = {n: m for n, m in model_copy.named_modules()}
 
     for key in list(hash_table.keys()):
         if isinstance(hash_table[key], nn.Linear):
             new_module = BitNetLinearLayer(
                 in_features=hash_table[key].in_features,
                 out_features=hash_table[key].out_features,
                 bias=hash_table[key].bias is not None,
                 quantization_mode=quantization_mode,
             )
             name_chain = key.split(".")
             parent_module_attr_name = ".".join(name_chain[:-1])
             parent_module = hash_table[parent_module_attr_name]
             setattr(parent_module, name_chain[-1], new_module)
     for n, m in model_copy.named_modules():
         assert not isinstance(m, nn.Linear)
     return model_copy

结果如下:

4层FFN的Mnist结果 :

128维6层VIT版本训练Fashion MNIST的结果

128维8层VIT在 CIFAR100上的结果

我们可以看到,除了第一个实验外,2位和1位版本的模型与全精度的常规版本的模型表现得一样好。在第一个实验中,量化模型可能发生了灾难性遗忘。

这些实验并未使用大型语言模型(LLMs)进行,但足以证明论文关于这样的系统能与全精度模型竞争的说法。

我们的实验与论文的唯一一个区别是,这个实现并没有将量化权重存储在2位矩阵中,计算仍以fp32执行的,要真正看到计算速度的提升,需要为此专门的计算内核,我们目前没有能力编写,所以实现仅验证了论文的潜在的论点。

以上实验的所有代码和模块代码都可以在github repo中找到

https://avoid.overfit.cn/post/131875e588ac4f4aa4f15d2dfa5b46db

作者:Chidhambararajan R

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1978742.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一篇教会你PXE高效批量网络装机及kickstart无人值守安装

目录 搭建PXE的前提 搭建PEX的过程 如何构建PXE服务器 搭建本地yum源 搭建apache 创建软链接将本地yum源到apache页面下 搭建dhcp服务 dhcp配置文件如下 使用system-config-kickstart生成ks.cfg文件 ,.cfg配置文件如下 搭建TFTP服务 搭建完成后测试 搭建…

跟李沐学AI:NiN网络中的网络

NiN块 一个卷积层后跟着两个全连接层(实际为核窗口大小为1x1的卷积层)。卷积层步幅为1,无填充,输出形状与卷积层输出形状相同,起到全连接层的作用。 NiN架构 无全连接层,交替使用NiN块和步幅为2的最大池化…

【C++标准模版库】list的介绍及使用

list 一.list的介绍二.list的使用1.list 构造函数2.list 空间大小3.list 增删查改4.list 迭代器的使用1.正向迭代器2.反向迭代器 5.list 其他成员函数 三.vector与list关于sort性能的比较 一.list的介绍 C中的list标准模板库(STL)是C标准库中的一个重要组…

Linux文件管理和IO重定向知识总结

目录 一,文件管理 Linux的目录结构是一个树状结构: 文件的分类: 操作文件的常用命令: 文件元数据和节点和inode表结构: 特点: 创建文件: 查看文件inode号: cp和inode&#x…

揭秘Matplotlib等高线图:让数据‘高山流水‘间,笑点与深度并存!

1. 引言 在这个数据如山的时代,你是不是也曾在茫茫数海中迷失方向,渴望找到那片隐藏的“数据绿洲”?别怕,今天咱们就来聊聊Matplotlib这位绘图界的魔术师,特别是它那令人叹为观止的等高线图技能。想象一下&#xff0c…

领域模型(Domain Model)

前言 软件的核心是其为用户解决领域相关的问题的能力。所有其他特性,不管有多么重要,都要服务于这个基本目的。当领域很复杂时,这是一项艰巨的任务,要求高水平技术人员的共同努力。开发人员必须钻研领域以获取业务知识。他们必须…

拉刀基础知识——拉刀的种类

如前面所说:近期要围绕拉削和拉刀这个话题,分享一些相关的内容,从最基础的知识开始,为此还专门买了本旧书——《拉刀设计》入门学习。废话不多说,直接开始。 拉刀最早由冲头演变而来,用于加工方孔&#xf…

【Web】TFCCTF 2024 部分题解

目录 GREETINGS SURFING SAFE_CONTENT FLASK DESTROYER GREETINGS 打express的SSTI GitHub - TheWation/NodeJsSSTI: Express app with Pug templates demonstrating SSTI vulnerability and secure implementation for educational purposes. payload: /result?user…

历史标签如何时间迁移?

本文解析的论文是: Lin, C.; Du, P.; Samat, A.; Li, E.; Wang, X.; Xia, J. Automatic Updating of Land Cover Maps in Rapidly Urbanizing Regions by Relational Knowledge Transferring from GlobeLand30. Remote Sens. 2019, 11, 1397. https://doi.org/10.33…

一刷代码随想录(动态规划2)

62.不同路径 题意: 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(在下图中标记为 “Finish” )。 问总共有多少…

我的面包多

我的面包多主页:https://mbd.pub/o/author-bGubnGpq 欢迎咨询。

JavaSE面试篇章——一文干破Java集合

文章目录 Java集合——一文干破集合一、集合的理解和好处1.1 数组1.2 集合 二、集合的框架体系三、Collection接口和常用方法3.1 Collection接口实现类的特点3.2 Collection接口遍历元素方式1-使用Iterator(迭代器)3.2.1 基本介绍3.2.2 迭代器的执行原理3.2.3 Iterator接口的方…

数据库典型例题2-ER图转换关系模型

1.question solution: 2.做题步骤 一些解释&#xff1a; <1弱实体把强属性的主键写进去&#xff0c;指向强属性。eg:E6_A13指向E5_A13 <21:1&#xff0c;1:n&#xff0c;m:n&#xff1a;将完全参与的一方&#xff08;双线&#xff09;指向另一方&#xff0c;并将对方的…

AutoCAD ObjectArx二次开发(三) 创建MFC界面

主题&#xff1a;本章节主要介绍在ObjectARX项目中如何使用MFC界面进行交互操作&#xff0c;具体采用模态对话框的形式。 一、创建MFC的对话框 在项目中添加新项&#xff0c;选择MFC类&#xff0c;点击确定按钮&#xff0c;如下图所示。 然后会出现下图界面&#xff0c;填写类…

苹果应用程序清理卸载工具:App Cleaner Uninstaller Pro for Mac

App Cleaner & Uninstaller Pro 是一款专为 Mac OS X 操作系统设计的应用程序清理和卸载工具。这款软件的主要功能是帮助用户彻底删除不需要的应用程序、插件和残留文件&#xff0c;从而释放磁盘空间并提高系统性能。 特点和优势&#xff1a; 彻底卸载应用程序&#xff1a;…

【软件设计书】详细设计说明书和概要设计说明书(Word原件直接套用)

系统详细设计说明书案例&#xff08;直接套用&#xff09; 1.系统总体设计 2.性能设计 3.系统功能模块详细设计 4.数据库设计 5.接口设计 6.系统出错处理设计 7.系统处理规定 软件开发全文档下载&#xff08;下面链接或者本文末个人名片直接获取)&#xff1a;本文末个人名片直接…

【C语言】文件操作(下)

文章目录 前言1. 文件的读和写2. 文件的顺序读写2.1 顺序读写函数的介绍2.1.1 fgetc 和 fputc2.1.2 fgets 和 fputs 3. 文件缓冲区4. 总结 前言 在之前文件操作&#xff08;上&#xff09;和文件操作&#xff08;中&#xff09;的文章中&#xff0c;我从为什么要使用文件再到文…

RabbitMQ高级特性 - 生产者消息确认机制

文章目录 生产者消息确认机制概述confirm 代码实现return 代码实现 生产者消息确认机制 概述 为了保证信息 从生产者 发送到 队列&#xff0c;因此引入了生产者的消息确认机制. RabbitMQ 提供了两种解决方案&#xff1a; 通过事务机制实现.通过发送确认机制&#xff08;confi…

CPU利用率100%该怎么办

1 节拍率 Linux 作为一个多任务操作系统&#xff0c;将每个 CPU 的时间划分为很短的时间片&#xff0c;再通过调度器轮流分配给各个任务使用&#xff0c;因此造成多任务同时运行的错觉。 为了维护 CPU 时间&#xff0c;Linux 通过事先定义的节拍率&#xff08;内核中表示为 H…

AI大模型定级体系

前言&#xff1a;一直以来人们对通用人工智能&#xff08;AGI&#xff09;的定义始终缺乏一个具体的衡量标准&#xff0c;而现在OpenAI已创建了一套分级系统。 AI大模型定级 OpenAI对于其大模型的定级有一个独特的分级体系&#xff0c;旨在描述其人工智能系统的发展阶段以及距…