图片速览 BitNet: 1-bit LLM

news2024/12/25 0:42:49

输入数据

  • 模型使用absmax 量化方法进行b比特量化,将输入量化到 [ − Q b , Q b ] ( Q b = 2 b − 1 ) \left[-Q_{b},Q_{b}\right](Q_{b}=2^{b-1}) [Qb,Qb](Qb=2b1)
    x ~ = Q u a n t ( x ) = C l i p ( x × Q b γ , − Q b + ϵ , Q b − ϵ ) , Clip ⁡ ( x , a , b ) = max ⁡ ( a , min ⁡ ( b , x ) ) , γ = ∣ ∣ x ∣ ∣ ∞ , \widetilde{x}=\mathrm{Quant}(x)=\mathrm{Clip}(x\times\frac{Q_b}{\gamma},-Q_b+\epsilon,Q_b-\epsilon),\\ \operatorname{Clip}(x,a,b)=\max(a,\min(b,x)),\quad\gamma=||x||_\infty, x =Quant(x)=Clip(x×γQb,Qb+ϵ,Qbϵ),Clip(x,a,b)=max(a,min(b,x)),γ=∣∣x,

  • 其中 ε 是一个小的浮点数,可防止在执行截断时溢出。

// https://github.com/kyegomez/BitNet/blob/main/bitnet/bitbnet_b158.py
def absmean_quantize_weights(weights):
    """
    Quantizes the weights to -1, 0, or +1 using an absmean quantization function.

    Parameters:
    - weights (Tensor): The weights of a neural network layer.

    Returns:
    - Tensor: The quantized weights.
    """
    # Calculate the average absolute value (γ) of the weights
    gamma = torch.mean(torch.abs(weights))
    
    # Scale weights by γ and round to the nearest integer among {-1, 0, +1}
    quantized_weights = torch.clamp(torch.round(weights / gamma), min=-1, max=1)
    
    return quantized_weights

权重

  • 权重 W 的二值化可以公式化为:

α = 1 n m ∑ i j W i j W ~ = S i g n ( W − α ) , Sign ⁡ ( W i j ) = { + 1 , if W i j > 0 , − 1 , if W i j ≤ 0 , \\ \alpha=\frac1{nm}\sum_{ij}W_{ij} \\ \widetilde{W}=\mathrm{Sign}(W-\alpha),\\ \left.\operatorname{Sign}(W_{ij})=\left\{\begin{array}{ll}+1,&\quad\text{if}W_{ij}>0,\\-1,&\quad\text{if}W_{ij}\leq0,\end{array}\right.\right. α=nm1ijWijW =Sign(Wα),Sign(Wij)={+1,1,ifWij>0,ifWij0,

在这里插入图片描述

矩阵乘法

  • 使用上述量化方程,矩阵乘法可以写成:

y = W ~ x ~ y=\widetilde W\widetilde{x} y=W x

  • 为了保持量化后的方差,我们在激活量化之前引入了一个 LayerNorm函数。这样,输出 y 的方差就估计为 1

y = W ~ x ~ = W ~ Quant ( LN ( x ) ) × β γ Q b y=\widetilde{W}\widetilde{x}=\widetilde{W}\text{Quant}(\text{LN}(x))\times\frac{\beta\gamma}{Q_b} y=W x =W Quant(LN(x))×Qbβγ
L N ( x ) = x − E ( x ) V a r ( x ) + ϵ , β = 1 n m ∥ W ∥ 1 \mathrm{LN}(x)=\frac{x-E(x)}{\sqrt{\mathrm{Var}(x)+\epsilon}},\quad\beta=\frac1{nm}\|W\|_1 LN(x)=Var(x)+ϵ xE(x),β=nm1W1

在这里插入图片描述

// https://github.com/kyegomez/BitNet/blob/main/bitnet/bitlinear.py
import torch
from torch import Tensor, nn


class BitLinear(nn.Linear):
    """
    BitLinear is a custom linear layer that performs binarization of weights and quantization of activations
    in a group-wise manner.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True.
        num_groups (int, optional): Number of groups to divide the weights and activations into. Default is 1.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        num_groups: int = 1,
        b: int = 8,
    ):
        super().__init__(in_features, out_features, bias)
        self.in_features = in_features
        self.out_features = out_features
        self.b = b
        self.num_groups = num_groups
        self.eps = 1e-5
        self.norm = nn.LayerNorm(in_features)

    def ste(self, x):
        """
        Applies the sign function for binarization and uses Straight-Through Estimator (STE) during backward pass.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Binarized tensor.
        """
        binarized_x = torch.sign(x)
        binarized_x = (binarized_x - x).detach() + x
        return binarized_x

    def binarize_weights_groupwise(self):
        """
        Binarizes the weights of the layer in a group-wise manner using STE.

        Returns:
            Tensor: Binarized weights tensor.
        """
        group_size = self.weight.shape[0] // self.num_groups
        binarized_weights = torch.zeros_like(self.weight)

        for g in range(self.num_groups):
            start_idx = g * group_size
            end_idx = (g + 1) * group_size
            weight_group = self.weight[start_idx:end_idx]

            alpha_g = weight_group.mean()
            binarized_weights[start_idx:end_idx] = self.ste(weight_group - alpha_g)

        return binarized_weights

    def quantize_activations_groupwise(self, x):
        """
        Quantizes the activations of the layer in a group-wise manner.

        Args:
            x (Tensor): Input tensor.
            b (int, optional): Number of bits for quantization. Default is 8.

        Returns:
            Tensor: Quantized activations tensor.
        """
        Q_b = 2 ** (self.b - 1)

        group_size = x.shape[0] // self.num_groups
        quantized_x = torch.zeros_like(x)

        for g in range(self.num_groups):
            start_idx = g * group_size
            end_idx = (g + 1) * group_size
            activation_group = x[start_idx:end_idx]

            gamma_g = activation_group.abs().max()
            quantized_x[start_idx:end_idx] = torch.clamp(
                activation_group * Q_b / (gamma_g + self.eps),
                -Q_b + self.eps,
                Q_b - self.eps,
            )

        return quantized_x
    
    def dequantize_activations_groupwise(self, x):
        """
        Dequantizes the activations of the layer in a group-wise manner.

        Args:
            x (Tensor): Quantized input tensor.
            b (int, optional): Number of bits used during the quantization. Default is 8.

        Returns:
            Tensor: Dequantized activations tensor.
        """
        Q_b = 2 ** (self.b - 1)
        dequantized_x = torch.zeros_like(x)
        for g in range(self.num_groups):
            start_idx = g * x.shape[0] // self.num_groups
            end_idx = (g + 1) * x.shape[0] // self.num_groups
            quantized_group = x[start_idx:end_idx]
            gamma_g = quantized_group.abs().max()
            dequantized_x[start_idx:end_idx] = quantized_group * gamma_g / Q_b
        return dequantized_x

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the BitLinear layer.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor.
        """
        # Normalize input
        x = self.norm(x)

        # Binarize weights and quantize activations
        binarized_weights = self.binarize_weights_groupwise()

        # Perform linear transformation
        output = torch.nn.functional.linear(x, binarized_weights, self.bias)

        # Quantize activations
        output = self.quantize_activations_groupwise(output)
        
        # Dequantize activations
        output = self.dequantize_activations_groupwise(output)

        # Return output
        return output



# Example usage
bitlinear = BitLinear(10, 5, num_groups=2, b=8)
input_tensor = torch.randn(5, 10)  # Example input tensor
output = bitlinear(input_tensor)
print(output)  # Example output tensor

CG

  • 【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM

  • BitNet: Scaling 1-bit Transformers for Large Language Models

  • The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits

  • Implementation of “BitNet: Scaling 1-bit Transformers for Large Language Models” in pytorch

  • DB-LLM: Accurate Dual-Binarization for Efficient LLMs

  • 如何看待微软提出的BitNet b1.58?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1499887.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录算法训练营第day9|28. 找出字符串中第一个匹配项的下标、459.重复的子字符串

a.28. 找出字符串中第一个匹配项的下标 题目链接 给你两个字符串 haystack 和 needle ,请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标(下标从 0 开始)。如果 needle 不是 haystack 的一部分,则返回 -1 。 示…

小火星露谷管理器 如何禁用管理器下载?

错误操作 当你在N网点击下载时,你可能会点击左边第一个按钮进行下载,如图: 然后你可能会看到这样的一个提示: 很多用户看着这个提示误以为小火星露谷管理器禁用了N网的下载。 正确操作 N网网页上的按钮MOD MANAGER DOWNLOAD翻…

[PTA] 分解质因子

输入一个正整数n(1≤n≤1e15),编程将其分解成若干个质因子(素数因子)积的形式。 输入格式: 任意给定一个正整数n(1≤n≤1e15)。 输出格式: 将输入的正整数分解成若干个质因子积的形式&#…

Linux 之五:权限管理(文件权限和用户管理)

1. 文件权限 在Linux系统中,文件权限是一个非常基础且重要的安全机制。它决定了用户和用户组对文件或目录的访问控制级别。 每个文件或目录都有一个包含9个字符的权限模式,这些字符分为三组,每组三个字符,分别对应文件所有者的权限…

面向对象中类与对象

思考系统1000个对象逻辑结构 理解系统1000个对象物理结构 对象this 引用 类的静态变量和静态函数 静态变量和静态函数属于类本身,而不是类的实例。它们可以在不创建类的实例的情况下直接通过类名访问。静态变量在内存中只有一份拷贝,被所有实例共享&…

基于FPGA加速的bird-oid object算法实现

导语 今天继续康奈尔大学FPGA 课程ECE 5760的典型案例分享——基于FPGA加速的bird-oid object算法实现。 (更多其他案例请参考网站: Final Projects ECE 5760) 1. 项目概述 项目网址 ECE 5760 Final Project 模型说明 Bird-oid object …

关于esp8266的一些经验汇总,新手必看

说实话,esp8266的nodemcu 已经使用了2年多了,各种问题遇到过,就尝试各种解决,而现在回头来看真的是稀里糊涂的在用,当然这个问题也同样涉及到esp32. 因为最近打算自己打一块esp8266的板,之前打的比较多的是…

数据结构之单链表详解(C语言手撕)

​ 🎉个人名片:🐼作者简介:一名乐于分享在学习道路上收获的大二在校生 🙈个人主页🎉:GOTXX 🐼个人WeChat:ILXOXVJE 🐼本文由GOTXX原创,首发CSDN…

(五)关系数据库标准语言SQL

注:课堂讲义使用的数据库 5.1利用SQL语言建立数据库 5.1.1 create Database 5.1.2 create schema...authorization... 创建数据库和创建模式的区别: 数据库是架构的集合,架构是表的集合。但在MySQL中,他们使用的方式是相同的。 …

如何修改SAP标准代码

文章目录 1 Introduction2 Method2.1 Click Change2.2 Switch off Assistent 3 Summary 1 Introduction In the sap sometimes we need change the standard code . I.E. How to comment code ? 2 Method 2.1 Click Change 2.2 Switch off Assistent This is the result wh…

GAMMA电源维修高压直流电源ES30P-5W ES系列

美国Gamma高压电源维修型号:D-ES30R-10N-5W/M,LXR30-1N,XRM5N-100W,ES50P-10W/DDPM,ES60P-10W/DDPM,RR20-20P/DDPM,ES30P-10W,ES60P-10W DDPM,RR60-18P/220V,…

iStoreOS系统内安装HomeAssistant服务

iStoreOS系统内安装HomeAssistant服务 1. HomeAssistant服务 HomeAssistant是一款基于Python的开源智能家居系统,简称HA。 HomeAssistant可以方便地连接各种外部设备,如智能设备、摄像头、邮件、短消息和云服务等,其成熟的可连接组件有近千…

rocketmq学习笔记(一)安装部署

初次使用rocketmq,记录一下全流程步骤。 1、下载安装包 首先在官网,下载安装包,可也根据官方文档进行部署,但有一些细节没说明,可能会有坑,本文会尽量详细的描述每个步骤,把我踩过的坑填补上。…

Python自动化测试:API接口自动化——requests、webSocket

接口自动化测试1 一、requests二、简单示例1.导入/引入库2.请求与响应示例1>简单访问百度主页-GET请求2>简单的登录请求-POST请求3>保存cookies至头信息headers4>其他接口请求时携带headers 三、webSocketwebSocket连接与数据收发示例 本文介绍了借助Python的reque…

Manacher 算法——Leetcode 5.最长回文子串

在了解之前,我们先要了解什么是回文串,什么是回文子串。 回文串和回文子串: 回文串是指一个字符串正序遍历和反向遍历结果相同的字符串。如 ABBA,正着读反着读结果是一样的。 有了回文串的概念,回文子串的概念也就显…

顺势交易中,用什么方法识别趋势的开始与结束?

在交易过程中,大家都知道顺势交易的重要性,但如何对趋势的开始和结束进行量化判断呢? 趋势交易需要一个正确的出发点和思想方向。也就是说,趋势交易需要关注什么呢?有哪些相关的技术手段可以利用呢? 首先&a…

springboot使用异步多线程

shigen坚持更新文章的博客写手,擅长Java、python、vue、shell等编程语言和各种应用程序、脚本的开发。记录成长,分享认知,留住感动。 个人IP:shigen 在shigen之前的很多文章中,提到了线程池: 高性能API设计…

一 windso10 笔记本刷linux cent os7.9系统

1:准备材料 16G以上U盘, 笔记本一台 镜像选了阿里云镜像:centos-7-isos-x86_64安装包下载_开源镜像站-阿里云 软件:链接:https://pan.baidu.com/s/13WDp2bBU1Pdx4gRDfmBetg 提取码:09s3 2:把镜像写入U盘,本人已经写入好了,选择镜像,点开始就是,确定等…

javascript正则深入

文章目录 一、前言二、高级`API`2.1、模式匹配的用法`(x)`2.2、非捕获括号的模式匹配`(?:x)`2.3、先行断言`x(?=y)`2.4、后行断言`(?<=y)x`2.5、正向否定查找`x(?!y)`2.6、反向否定查找`(?<!y)x`2.7、字符集合和反向字符集合的用法 `[xyz] / [^xyz]`2.8、词边界和非…

开关电源安规测试标准与测试要求

安规测试是对开关电源进行电气性能、安全性能等检测&#xff0c;确保开关电源符合规定并且安全可靠&#xff0c;为开关电源的质量把关。那么开关电源安规测试有哪些测试要求和标准呢&#xff1f; 开关电源安规测试要求 一、测试前 1. 首先&#xff0c;要检查测试环境&#xff0…