pytorch 模型量化quantization

news2025/1/12 17:34:22

pytorch 模型量化quantization

  • 1.workflow
    • 1.1 PTQ
    • 1.2 QAT
  • 2. demo
    • 2.1 构建resnet101_quantization模型
    • 2.2 PTQ
    • 2.3 QAT
  • 参考文献

pytorch框架提供了三种量化方法,包括:

  • Dynamic Quantization
  • Post-Training Static Quantization(PTQ)
  • Quantization Aware Training(QAT)

此博客结合CIFAR100数据集分类任务,分别采用Post-Training Static QuantizationQuantization Aware Training对resnet101模型进行量化。

1.workflow

1.1 PTQ

在这里插入图片描述

图片来自Practical Quantization in PyTorch。

1.2 QAT

在这里插入图片描述

图片来自Practical Quantization in PyTorch。

从两张图片来看,PTQ和QAT的差别在于:PTQ量化前使用了calibration data,而QAT则有一个和训练阶段类似的训练过程。

2. demo

2.1 构建resnet101_quantization模型

import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub

class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                      padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes, momentum=0.1),
            nn.ReLU(inplace=False)
        )


class ConvBN(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBN, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                      padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes, momentum=0.1),
        )


class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            ConvBNReLU(in_channels, out_channels, kernel_size=1),
            ConvBNReLU(out_channels, out_channels, kernel_size=3),
            ConvBN(out_channels, out_channels *
                   BottleNeck.expansion, kernel_size=1)
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = ConvBN(in_channels, out_channels *
                                   BottleNeck.expansion, kernel_size=1)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.skip_add.add(self.residual_function(x), self.shortcut(x)))


class ResNet(nn.Module):

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            ConvBNReLU(3, 64, kernel_size=3))

        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def _make_layer(self, block, out_channels, num_blocks, stride):

        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.quant(x)
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)
        x = self.dequant(x)
        return output
    # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
    # This operation does not change the numerics

    def fuse_model(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        for m in self.modules():
            if type(m) == ConvBNReLU:
                fuse_modules(m, ['0', '1', '2'], inplace=True)

            if type(m) == ConvBN:
                fuse_modules(m, ['0', '1'], inplace=True)
def resnet101():
    """ return a ResNet 101 object
    """
    return ResNet(BottleNeck, [3, 4, 23, 3])

代码改编自https://github.com/weiaicunzai/pytorch-cifar100。

如果要使用quantization,构建的模型和常规模型差别主要在以下内容:


class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        ...
        ...
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.skip_add.add(self.residual_function(x), self.shortcut(x)))

这是因为没有直接用于相加的算子。

如果没有这一操作,可能会报如下错误:

NotImplementedError: Could not run 'aten::add.out' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). 

另外就是:


class ResNet(nn.Module):
    def __init__(self, block, num_block, num_classes=100):
        super().__init__()
        ...
        ...
        self.quant = QuantStub() #observer
        self.dequant = DeQuantStub()

    def _make_layer(self, block, out_channels, num_blocks, stride):
        ...
        ...

    def forward(self, x):
        x = self.quant(x)
        ...
        ...
        x = self.dequant(x)
        return output
    # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
    # This operation does not change the numerics

    def fuse_model(self, is_qat=False):
        ...
        ...

即添加observer,以及将Conv+BN 和Conv+BN+Relu 模块融合到一起。

2.2 PTQ

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

qmodel = resnet101()
# qmodel.load_state_dict(torch.load(args.weights))

qmodel = qmodel.to(device)
# # print(qmodel)
qmodel.eval()

print("Size of model befor quantization")
print_size_of_model(qmodel)

num_calibration_batches = 32
qmodel.eval()
# Fuse Conv, bn and relu
qmodel.fuse_model()

# Specify quantization configuration
# Start with simple min/max range estimation and per-tensor quantization of weights
qmodel.qconfig = torch.ao.quantization.default_qconfig
print(qmodel.qconfig)
torch.ao.quantization.prepare(qmodel, inplace=True)

# Calibrate first
print('Post Training Quantization Prepare: Inserting Observers')


# Calibrate with the training set
criterion = nn.CrossEntropyLoss()
evaluate(qmodel, criterion, cifar100_test_loader,
             neval_batches=10)
print('Post Training Quantization: Calibration done')

# Convert to quantized model
torch.ao.quantization.convert(qmodel, inplace=True)
print('Post Training Quantization: Convert done')
# print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',
#       qmodel.features[1].conv)

print("Size of model after quantization")
print_size_of_model(qmodel)
Size of model befor quantization
Size (MB): 171.40158

Size of model after quantization
Size (MB): 42.970334

size大致缩小了四倍。

经过测试,在本地cpu上推断时间也缩小了3~4倍。

2.3 QAT

#%% QAT
float_model_file="resnet101.pt"
qat_model=resnet101()
qat_model.load_state_dict(torch.load(float_model_file))
qat_model.fuse_model(is_qat=True)

optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
# The old 'fbgemm' is still available but 'x86' is the recommended default.
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
torch.ao.quantization.prepare_qat(qat_model, inplace=True)
# print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)

num_train_batches,num_eval_batches = 2,2
eval_batch_size=32
criterion = nn.CrossEntropyLoss()
nepochs=10

# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch
for nepoch in range(nepochs):
    train_one_epoch(qat_model, criterion, optimizer, cifar100_test_loader, torch.device('cpu'),  num_train_batches)
    if nepoch > 3:
        # Freeze quantizer parameters
        qat_model.apply(torch.ao.quantization.disable_observer)
    if nepoch > 2:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # Check the accuracy after each epoch
    quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
    quantized_model.eval()
    top1, top5 = evaluate(quantized_model,criterion, cifar100_test_loader, neval_batches=2)
    print('Epoch %d :Evaluation accuracy on %d images, %2.2f'%(nepoch, num_eval_batches * eval_batch_size, top1.avg))

完整代码后续将分享在github或csdn资源中。

参考文献

[1] Introduction to Quantization on PyTorch
[2] https://github.com/pytorch/pytorch/wiki/Introducing-Quantized-Tensor
[3] tensorflow 训练后量化
[4] pytorch dynamic and static quantization 适用的算子
[5] ★★★pytorch_static quantization tutorial
[6] PyTorch Static Quantization
[7] Practical Quantization in PyTorch
[8] ★★★https://github.com/weiaicunzai/pytorch-cifar100

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1281810.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Logistic回归实现二分类

目录 Logistic回归公式推导: Sigmoid函数: Logistic回归如何实现分类: 优化的方法: 代码: 1.创建一个随机数据集,分类直线为y2x: 为什么用np.hstack()增加一列1? 为什么返回…

协同过滤算法:个性化推荐的艺术与科学

目录 引言: 一、协同过滤算法的基本原理 二、协同过滤算法的应用领域 三、协同过滤算法的优缺点 四、协同过滤算法的未来发展方向 五、结论 引言: 在当今数字化时代,信息过载成为了一个普遍的问题。为了帮助人们更好地发现符合个性化需…

Linux驱动开发学习笔记2《LED驱动开发试验》

目录 一、Linux下LED灯驱动原理 1.地址映射 二、硬件原理图分析 三、实验程序编写 1.LED 灯驱动程序编写 2.编写测试APP 四、运行测试 1.编译驱动程序和测试APP (1)编译驱动程序 (2)编译测试APP 2.运行测试 一、Linux下…

分享81个节日PPT,总有一款适合您

分享81个节日PPT,总有一款适合您 81个节日PPT下载链接:https://pan.baidu.com/s/1V0feg5pZ8C1Szycy40CrUw?pwd6666 提取码:6666 Python采集代码下载链接:采集代码.zip - 蓝奏云 学习知识费力气,收集整理更不易…

Android CardView基础使用

目录 一、CardView 1.1 导入material库 1.2 属性 二、使用(效果) 2.1 圆角卡片效果 2.2 阴影卡片效果 2.3 背景 2.3.1 设置卡片背景(app:cardBackgroundColor) 2.3.2 内嵌布局,给布局设置背景色 2.4 进阶版 2.4.1 带透明度 2.4.2 无透明度 一、CardView 顾名…

【编码魔法师系列_构建型1.3 】抽象工厂模式(Abstract Factory)

学会设计模式,你就可以像拥有魔法一样,在开发过程中解决一些复杂的问题。设计模式是由经验丰富的开发者们(GoF)凝聚出来的最佳实践,可以提高代码的可读性、可维护性和可重用性,从而让我们的开发效率更高。通…

vs 安装 qt qt扩展 改迅雷下载qt

Qt5.14.2安装教程和VS2019中的qt环境配置-CSDN博客 1 安装qt 社区版 免费 Download Qt OSS: Get Qt Online Installer 2 vs安装 qt vs tools 3 vs添加 qt添加 bin/cmake.exe 路径 3.1 扩展 -> qt versions 3.2 4 新版要源码安装 需要自己安装 安装独立安装的旧版 官网…

pygame时序模块time

文章目录 简介时钟对象平抛运动 pygame系列:初步💎加载图像💎图像变换💎直线绘制 简介 之前在更新图形的时候,为了调控死循环的响应时间,用到了time.sleep。而实际上,我们并不需要额外导入其他…

最强Node js 后端框架学习看这一篇文章就够

距离上次认真花时间写作,似乎已经过了许久许久,前端讲了一个新框架 ,叫 Nest.js 下方是课件,有过一定开发经验可跟随视频学习 B站 地址 : https://www.bilibili.com/video/BV1Lg4y197u1/?vd_sourcead427ffaf8a5c8344…

【计算机网络笔记】物理层——数据通信基础

系列文章目录 什么是计算机网络? 什么是网络协议? 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能(1)——速率、带宽、延迟 计算机网络性能(2)…

【1】基于多设计模式下的同步异步日志系统-项目介绍

1. 项目介绍 本项⽬主要实现⼀个日志系统, 其主要支持以下功能: • 支持多级别日志消息 • 支持同步日志和异步日志 • 支持可靠写⼊日志到控制台、文件以及滚动文件中 • 支持多线程程序并发写日志 • 支持扩展不同的日志落地⽬标地 2. 开发环境 • CentOS 7 • vs…

Node版本管理nvm工具安装及使用问题

安装和配置 下载地址 nvm官方下载window环境直接下nvm-setup.zip解压安装即可。 安装效验以及镜像配置 在cmd中,输入nvm -v 会反馈相应的安装版本,即表示安装成功。配置镜像源: nvm node_mirror https://npm.taobao.org/mirrors/node/ nvm npm_mir…

GitHub上1.5K标星的QA和软件测试学习路线图

​最近在GitHub上发现一个项目,项目描述了作为QA工程师,进行软件测试技能提升时的,建议的软件测试学习顺序图​。 虽然2021年起就不再更新了,但是居然有1.5K的​星。 整个项目有两个部分​: ​1.QA和软件测试学习顺序…

线程池大小设置多少,比较合适?

设置线程数的核心点 压测!压测!再压测!实际对性能要求比较高的场景,压测是最佳的方式! 并发编程适用于什么场景? CPU 密集型 对于 CPU 密集型任务,希望最大限度地提高 CPU 利用率&#xff0c…

使用String.valueOf()的坑

说明:记录一次使用String.valueOf()的坑,以下是一段有问题的代码: String count String.valueOf(listData.get(0).get(0).get("count");if (StringUtils.isBlank(count) || "0".equals(count)) {result.setResult(page)…

LLM推理部署(五):AirLLM使用4G显存即可在70B大模型上进行推理

众所周知,大模型的训练和推理需要大量的GPU资源,70B参数的大模型需要130G的GPU显存来存储,需要两个A100(显存为100G)。 ​ 在推理过程中,整个输入序列也需要加载到内存中进行复杂的“注意力”计算&am…

【WinForm.NET开发】演示:创建一个图片查看器 Windows 窗体应用

本文演示将创建一个 Windows 窗体应用程序,用于加载和显示图片。 Visual Studio 集成设计环境 (IDE) 提供了创建应用所需的工具。 1、先决条件 若要完成本教程,必须具有 Visual Studio。 请访问Visual Studio 下载页获取免费版本。 2、创建 Windows …

C语言扫雷游戏

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、扫雷游戏的分析和设计1.1扫雷游戏的功能说明1.2数据结构的分析1.3文件结构设计 二、扫雷游戏的代码实现总结 前言 详细介绍扫雷游戏的思路和实现过程。 一…

基于Java SSM框架实现美好生活九宫格日志网站系统项目【项目源码+论文说明】计算机毕业设计

基于java的SSM框架实现美好生活九宫格日志网站系统演示 摘要 21世纪的今天,随着社会的不断发展与进步,人们对于信息科学化的认识,已由低层次向高层次发展,由原来的感性认识向理性认识提高,管理工作的重要性已逐渐被人…

风险评估有什么用

风险评估就是量化测评某一事件或事物带来的影响或损失的可能程度。 为什么要做风险评估? 1.更准确地认识风险-系统地评估资产风险事件发生的概率大小和概率分布,及发生后损失的严重程度。帮助区分主要风险和次要风险。 2.保证规划的合理性和可行性-正确反映各风…