2.3 TensorRT基于Entropy的校准

news2024/11/25 4:31:11

tensorRT的Entropy Calibration的伪代码,具体流程如下:

  • for循环:遍历所有可能的分割点,从128到2048
  • reference_distribution_P:将原始直方图bins按照当前分割点i进行切割,得到左侧的i个bin。
  • outliers_count:将原始直方图bins按照当前分割点i进行切割,得到右侧的2048-i个bin。
  • reference_distribution_P[ i-1 ] += outliers_count:将outliers_count加入到reference_distribution_P中,得到新的概率分布。
  • P /= sum§:将reference_distribution_P进行归一化。
  • candidate_distribution_Q:将当前的i个bin分成128个level,得到candidate_distribution_Q,表示我们将reference_distribution_P进行量化。
  • Q /= sum(Q):将candidate_distribution_Q进行归一化。
  • KL_divergence( reference_distribution_P, candidate_distribution_Q):计算当前量化方法下的KL散度,并将其保存在divergence中。
  • 循环结束后,divergence中记录了每个分割点i下的KL散度。我们选取KL散度最小的分割点i作为最优的分割点,并将其作为最终的量化参数。
    在这里插入图片描述

总的来说,Entropy Calibration的过程就是将概率分布量化成少量的level,并寻找最优的level,使得量化后的分布和原始分布的KL散度最小。

有一个问题需要讨论,既然是INT8量化(2^8=256),为什么我们量化的是128个bins而不是256个bins?
回答:因为量化中针对的数据是激活函数ReLU后的,即经过ReLU后的值均为正数,所以负数就不用考虑了,而原来INT8的取值范围是在[-128,127]之间,因此[-128,0]就不用考虑了,而原始的分布[0,127]就能够表达,因此for循环就是从[128,2048]

完整的示例代码如下:

import random
import numpy as np
import matplotlib.pyplot as plt
def generator_P(size):
    walk = []
    avg  = random.uniform(3.000, 600.999)
    std  = random.uniform(500.000, 1024.959)
    for _ in range(size):
        walk.append(random.gauss(avg, std))
    return walk
# smooth_distribution:对概率分布 P 和 Q 进行平滑处理,避免 KL 散度计算时出现分母为0的情形
def smooth_distribution(p, eps=0.0001):
    is_zeros = (p == 0).astype(np.float32)
    is_nonzeros = (p != 0).astype(np.float32)
    n_zeros = is_zeros.sum()
    n_nonzeros = p.size - n_zeros
    if not n_nonzeros:
        raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
    eps1 = eps * float(n_zeros) / float(n_nonzeros)
    assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
    hist = p.astype(np.float32)
    hist += eps * is_zeros + (-eps1) * is_nonzeros
    assert (hist <= 0).sum() == 0
    return hist

import copy
import scipy.stats as stats
def threshold_distribution(distribution, target_bin=128):
    # distribution = distribution[1:] # 将distribution数组的第一个元素去掉  [1:]???   
    distribution = distribution[:] 
    length = distribution.size  # distribution的长度
    # 计算概率分布从target_bin位置开始的累加和,即outliers_count
    outliers_count = sum(distribution[target_bin:])  
    # 初始化一个numpy数组,用来存放每个阈值下计算得到的所有KL散度
    kl_divergence = np.zeros(length - target_bin)   
    # for i in range(128,2048)
    for threshold in range(target_bin, length):
        # 将distribution数组中前threshold个元素拷贝到sliced_nd_hist数组中。
        sliced_nd_hist = copy.deepcopy(distribution[:threshold])
        # print(sliced_nd_hist.size)
        # generate reference distribution P
        p = sliced_nd_hist.copy()
        #  将后面outliers_count加到reference_distribution_P中,得到新的概率分布  
        p[threshold - 1] += outliers_count  
        # 将p进行归一化  量化前的p
        p = np.array(p) / np.sum(p)
        # 更新outliers_count的值,第一次循环的outliers_count为distribution[128:],第二次循环的outliers_count为distribution[129:],...
        outliers_count = outliers_count - distribution[threshold] 
        
        # is_nonzeros[k] indicates whether hist[k] is nonzero
        is_nonzeros = (p != 0).astype(np.int64)   # 判断每一位是否非零

        # 量化后的bins
        quantized_bins = np.zeros(target_bin, dtype=np.int64)
        
        # calculate how many bins should be merged to generate
        # quantized distribution q
        num_merged_bins = sliced_nd_hist.size // target_bin    # 计算stride
        
        # merge hist into num_quantized_bins bins
        for j in range(target_bin):
            start = j * num_merged_bins
            # stop最大为127
            stop  = start + num_merged_bins
            quantized_bins[j] = sliced_nd_hist[start:stop].sum()
        # [target_bin * num_merged_bins:]:这里要注意一下     
        quantized_bins[-1] += sliced_nd_hist[target_bin * num_merged_bins:].sum() # 将多余位累加到最后整除的位置上
        # expand quantized_bins into p.size bins
        q = np.zeros(sliced_nd_hist.size, dtype=np.float64) # 进行位扩展
        for j in range(target_bin):
            start = j * num_merged_bins
            
            # 这几行代码改为stop = start + num_merged_bins
            # if j == target_bin - 1:
            #     stop = -1
            # else:
            #     stop = start + num_merged_bins
            stop = start + num_merged_bins
            
            norm = is_nonzeros[start:stop].sum()
            if norm != 0:
                # 求q的平均值
                q[start:stop] = float(quantized_bins[j]) / float(norm)
        # 平滑处理,保证KLD计算出来不会无限大
        # print(p[-10:],q[-10:])
        # exit(1)
        q = np.array(q) /np.sum(q)
        p = smooth_distribution(p)
        q = smooth_distribution(q)

        # calculate kl_divergence between p and q
        kl_divergence[threshold - target_bin] = stats.entropy(p, q) # 计算KL散度
    min_kl_divergence = np.argmin(kl_divergence)    # 选择最小的KL散度
    threshold_value = min_kl_divergence + target_bin
    
    return threshold_value

if __name__ == '__main__':
    # 获取KL最小阈值
    size = 20480
    # generator_P(size):生成一个大小为size的随机数列,并使用高斯分布生成其中每个数的值。
    P = generator_P(size)
    P = np.array(P)
    # 只取大于零的数
    P = P[P>0]
    print("最大的激活值", max(np.absolute(P)))
    # 使用np.histogram(P, bins=2048)将P划分成2048组,并返回各组数据量和区间范围。
    hist, bins = np.histogram(P, bins=2048)
    # print(hist)
    threshold = threshold_distribution(hist, target_bin=128)
    print("threshold 所在组:", threshold)
    print("threshold 所在组的区间范围:", bins[threshold])
    # 分成split_zie组,density表示是否要normed
    plt.title("Relu activation value Histogram")
    plt.xlabel("Activation values")
    plt.ylabel("Normalized number of Counts")
    plt.hist(P, bins=2047)
    plt.vlines(bins[threshold], 0, 30, colors='r', linestyles='dashed')
    plt.show()

输出如下:

最大的激活值 3878.868170933664
threshold 所在组: 1777
threshold 所在组的区间范围: 3365.600012434412

在这里插入图片描述
上述示例代码主要是实现了一种量化方法——基于熵的量化(Entropy-based Quantization)。该方法的基本思路是通过统计神经网络中激活值的分布情况,并通过一些数学模型和技巧将其分解为多个区间,最终实现对神经网络模型中权重和激活值进行有损压缩的目的,从而减少模型的存储和计算开销。

参考链接:
https://blog.csdn.net/qq_40672115/article/details/129942542

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/614649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

构建新型智能化智慧档案馆十防一体化解决技术方案

HONSOR现代化智慧档案馆智慧档案库房自动化温湿度环境安全监控系统方案【推介】 目前&#xff0c;我国档案正处于现代化科技飞升的起点&#xff0c;以物联网、大数据、智能化、云计算等为标志的网络充斥到社会生活、生产的各个角落。在我国&#xff0c;档案馆产生与发展经历了一…

为了搞懂ERP,我连问 ChatGPT 30个问题,通透!

我对ERP有很多的疑问&#xff0c;这次向ChatGPT请教&#xff0c;连问30个问题&#xff0c;瞬间觉得通透了很多&#xff0c;以下是我的问题和ChatGPT的回复&#xff0c;分为概念篇、架构篇和生态篇三部分&#xff0c;希望能带给你新的启示。 一、概念篇 1、ERP是什么&#xff1f…

图论与算法(5)图的广度优先遍历应用

1. 广度优先遍历 1.1 树的广度优先遍历 树的广度优先遍历&#xff08;Breadth-First Traversal&#xff09;&#xff0c;也称为层次遍历&#xff0c;是一种按层次顺序逐级访问树节点的遍历方式。在广度优先遍历中&#xff0c;先访问树的根节点&#xff0c;然后按照从上到下、…

高频面试八股文用法篇(二) hive中几种排序类型区别

目录 排序函数 1、order by 2、sort by 3、distribute by 4、cluster by 总结 排序类型 1、order by order by是与关系型数据库的用法是一样的。select * from emp order by empno desc; 针对全局数据进行排序&#xff0c;所以最终只会有1个reduce&#xff0c;因…

开源 Golang 微服务入门一: HTTP 框架 Hertz

前言 从本篇笔记开始将介绍 Go 框架三件套&#xff08;Web / RPC / ORM&#xff09;&#xff0c;框架的学习有助于后续课程的学习以及大项目的完成。本文主要介绍字节跳动的开源 Golang 微服务 HTTP 框架 Hertz。先了解一下三件套的相关基本知识&#xff0c;做一下铺垫&#x…

ArgoCD(二)--部署

3.2 ArgoCD部署 ArgoCD部署官网&#xff1a;https://argo-cd.readthedocs.io/en/stable/getting_started/ ArgoCD有两种部署方式&#xff1a;多租户部署和核心化部署&#xff1a; 多租户 常用于多个应用程序开发团队提供服务&#xff0c;并由平台团队维护的场景&#xff1b; …

BPMN模拟动画执行流程

目录 第一步&#xff1a;构建BPMN图 第二步&#xff1a;开启模拟 第三步&#xff1a;执行模拟 第四步&#xff1a;监听模拟 第一步&#xff1a;构建BPMN图 通过id标记&#xff0c;每一个流程 第二步&#xff1a;开启模拟 BPMN官方提供了各种各样的模块&#xff0c;比如执行…

修改linux ssh 22 端口

1、找到 sshd 的配置文件&#xff0c;增加一行 Port 32586 ,默认是 22 端口&#xff0c;记得&#xff0c;先不要把 22 端口的这一行删除&#xff0c;或者注释&#xff0c;因为我们要先验证一下&#xff0c;我们修改后的端口是否可以使用&#xff0c;都ok后&#xff0c;再把 22 …

精通Java数组的艺术:从初学者到高手的进阶之路(二)

⭐ 多维数组⭐ 数组存储表格数据⭐ Comparable 接口 ⭐ 多维数组 多维数组可以看成以数组为元素的数组。可以有二维、三维、甚至更多维数组&#xff0c;但是实际开发中用的非常少。最多到二维数组。 【eg】二维数组的声明 public class Test {public static void main(Strin…

一起看 I/O | Wear OS 更新一览

作者 / Android 开发者关系工程师 Kseniia Shumelchyk 随着 Wear OS 平台的不断发展&#xff0c;我们很高兴与您分享一些最新的功能和改进&#xff0c;以帮助您为用户打造富有吸引力的创新体验。 Peloton 和 Todoist 等合作伙伴一直以来都针对 Wear OS 打造卓越体验&#xff0c…

Android自定义一个车牌字母选择键盘

在一般和车相关的应用&#xff0c;难免会和车牌打交道&#xff0c;组成车牌的要素&#xff0c;国内无非就是省份简称地区代码英文或者数字组成&#xff0c;比如京A12345&#xff0c;在需要输入车牌的功能上&#xff0c;就需要有省份简称键盘和英文数字键盘了&#xff0c;在上篇…

深度学习(自编码器)

深度学习目录 自适应线性单元 (Widrow and Hoff, 1960)神经认知机 (Fukushima, 1980)GPU-加速 卷积网络 (Chellapilla et al., 2006)深度玻尔兹曼机 (Salakhutdinov and Hinton, 2009a)无监督卷积网络 (Jarrett et al., 2009b)GPU-加速 多层感知机 (Ciresan et al., 2010)分布…

人工影响天气期末复习笔记

&#xff08;一&#xff09;什么是人工影响天气 利用自然云微物理不稳定性&#xff0c;通过一定的技术方法改变云的微结构&#xff0c;从而改变云降水的发展过程&#xff0c;从而达到增加降水&#xff0c;防雹&#xff0c;消云雾等目的 &#xff08;二&#xff09;为什么要人工…

【历史上的今天】6 月 6 日:世界 IPv6 启动纪念日;《俄罗斯方块》发布;小红书诞生

整理 | 王启隆 透过「历史上的今天」&#xff0c;从过去看未来&#xff0c;从现在亦可以改变未来。 今天是 2023 年 6 月 6 日&#xff0c;在 2019 年的今天&#xff0c;工信部正式发放 5G 牌照。这一天&#xff0c;有四家企业被颁发了基础电信业务经营许可证&#xff0c;从此…

社区团购系统源码后台解析

近年来&#xff0c;随着购物方式的改变&#xff0c;社区团购可以说是越来越受关注了&#xff0c;大家应该对社区团购多多少少有一些认知&#xff0c;其实社区团购这样的商业模式拥有强大的赚钱的潜力&#xff0c;主要就是因为它的运营成本低&#xff0c;而且上手也不需要很复杂…

FPGA设计的指导性原则 (四)

在FPGA Express/FPGA Compiler II中,用鼠标右键单击编译后的芯片图标, 在弹出的命令对话框中选择“Edit Constraints”命令编辑综合约束文件(扩展 名为CTL),选择端口(Ports)选项卡,指定所需信号的全局时钟域为 “DONT USE”。图22所示为在FPGA Express综合约束编辑器中…

私有化部署低代码开发工具:jvs-rules 规则引擎决策流参数说明

JVS规则引擎决策调用 通过决策流水号查询入参变量 [请求参数]决策流 ​ GET/mgr/risk//test/parameter/flow/{no} 请求数据类型 application/x-www-form-urlencoded 响应数据类型 [ "*/*" ] 请求参数 参数名称 参数说明 请求类型 是否必须 数据类型 sch…

【Flutter混合开发】开发一个简单的快速启动框架

目录 前言启动插件Flutter代码Android代码IOS代码 启动模块使用android端ios端 前言 因为在移动端中启动Flutter页面会有短暂空白&#xff0c;虽然官方提供了引擎预热机制&#xff0c;但是需要提前将所有页面都进行预热&#xff0c;这样开发成本较高&#xff0c;在研究了闲鱼的…

通过点引导掩码表示的弱半监督实例分割

文章目录 The Devil is in the Points: Weakly Semi-Supervised Instance Segmentation via Point-Guided Mask Representation摘要本文方法Weakly Semi-Supervised Instance Segmentation using Point LabelsMask Refinement Network 实验结果消融实验 The Devil is in the Po…

【JavaEE】HTTP状态码-HTTP数据报的构造

HTTP状态码HTTP数据报的构造 文章目录 JavaEE & HTTP状态码 & HTTP数据报的构造1. HTTP状态码1.1 200 - OK1.2 404 - Not Found1.3 403 - Forbidden1.4 500 - Internal Server Error1.5 504 - Gateway Timeout1.6 302/301 重定向 2. 构造HTTP请求2.1 浏览器搜索栏输入u…