量化自定义PyTorch模型入门教程

news2024/10/6 20:25:26

在以前Pytorch只有一种量化的方法,叫做“eager mode qunatization”,在量化我们自定定义模型时经常会产生奇怪的错误,并且很难解决。但是最近,PyTorch发布了一种称为“fx-graph-mode-qunatization”的方方法。在本文中我们将研究这个fx-graph-mode-qunatization”看看它能不能让我们的量化操作更容易,更稳定。

本文将使用CIFAR 10和一个自定义AlexNet模型,我对这个模型进行了小的修改以提高效率,最后就是因为模型和数据集都很小,所以CPU也可以跑起来。

 import os
 import cv2
 import time
 import torch
 import numpy as np
 import torchvision
 from PIL import Image
 import torch.nn as nn
 import matplotlib.pyplot as plt
 from torchvision import transforms
 from torchvision import datasets, models, transforms
 
 device = "cpu"
 
 print(device)
 transform = transforms.Compose([
     transforms.Resize(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
 
 batch_size = 8
 
 trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
 
 testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
 
 trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                           shuffle=True, num_workers=2)
 
 testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)
 
 def print_model_size(mdl):
     torch.save(mdl.state_dict(), "tmp.pt")
     print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
     os.remove('tmp.pt')

模型代码如下,使用AlexNet是因为他包含了我们日常用到的基本层:

 from torch.nn import init
 class mAlexNet(nn.Module):
     def __init__(self, num_classes=2):
         super().__init__()
         self.input_channel = 3
         self.num_output = num_classes
         
         self.layer1 = nn.Sequential(
             nn.Conv2d(in_channels=self.input_channel, out_channels= 16, kernel_size= 11, stride= 4),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=3, stride=2)
         )
         init.xavier_uniform_(self.layer1[0].weight,gain= nn.init.calculate_gain('conv2d'))
 
         self.layer2 = nn.Sequential(
             nn.Conv2d(in_channels= 16, out_channels= 20, kernel_size= 5, stride= 1),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=3, stride=2)
         )
         init.xavier_uniform_(self.layer2[0].weight,gain= nn.init.calculate_gain('conv2d'))
 
         self.layer3 = nn.Sequential(
             nn.Conv2d(in_channels= 20, out_channels= 30, kernel_size= 3, stride= 1),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=3, stride=2)
         )
         init.xavier_uniform_(self.layer3[0].weight,gain= nn.init.calculate_gain('conv2d'))
        
 
         self.layer4 = nn.Sequential(
             nn.Linear(30*3*3, out_features=48),
             nn.ReLU(inplace=True)
         )
         init.kaiming_normal_(self.layer4[0].weight, mode='fan_in', nonlinearity='relu')
 
         self.layer5 = nn.Sequential(
             nn.Linear(in_features=48, out_features=self.num_output)
         )
         init.kaiming_normal_(self.layer5[0].weight, mode='fan_in', nonlinearity='relu')
 
 
     def forward(self, x):
         x = self.layer1(x)
         x = self.layer2(x)
         x = self.layer3(x)
         
         # Squeezes or flattens the image, but keeps the batch dimension
         x = x.reshape(x.size(0), -1)
         x = self.layer4(x)
         logits= self.layer5(x)
         return logits
 
 model = mAlexNet(num_classes= 10).to(device)

现在让我们用基本精度模型做一个快速的训练循环来获得基线:

 import torch.optim as optim 
 
 def train_model(model):
   criterion =  nn.CrossEntropyLoss()
   optimizer = optim.SGD(model.parameters(), lr=0.001, momentum = 0.9)
 
   for epoch in range(2):
     running_loss =0.0
     
     for i, data in enumerate(trainloader,0):
       
       inputs, labels = data
       inputs, labels = inputs.to(device), labels.to(device)
 
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = criterion(outputs, labels)
       loss.backward()
       optimizer.step()
 
       # print statistics
       running_loss += loss.item()
       if i % 1000 == 999:
         print(f'[Ep: {epoch + 1}, Step: {i + 1:5d}] loss: {running_loss / 2000:.3f}')
         running_loss = 0.0
   
   return model
 
 model = train_model(model)
 PATH = './float_model.pth'
 torch.save(model.state_dict(), PATH)

可以看到损失是在降低的,我们这里只演示量化,所以就训练了2轮,对于准确率我们只做对比。

我将做所有三种可能的量化:

  1. 动态量化 Dynamic qunatization:使权重为整数(训练后)
  2. 静态量化 Static quantization:使权值和激活值为整数(训练后)
  3. 量化感知训练 Quantization aware training:以整数精度对模型进行训练

我们先从动态量化开始:

 import torch
 from torch.ao.quantization import (
   get_default_qconfig_mapping,
   get_default_qat_qconfig_mapping,
   QConfigMapping,
 )
 import torch.ao.quantization.quantize_fx as quantize_fx
 import copy
 
 # Load float model
 model_fp = mAlexNet(num_classes= 10).to(device)
 model_fp.load_state_dict(torch.load("./float_model.pth", map_location=device))
 
 # Copy model to qunatize
 model_to_quantize = copy.deepcopy(model_fp).to(device)
 model_to_quantize.eval()
 qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
 
 # a tuple of one or more example inputs are needed to trace the model
 example_inputs = next(iter(trainloader))[0]
 
 # prepare
 model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, 
                   example_inputs)
 # no calibration needed when we only have dynamic/weight_only quantization
 # quantize
 model_quantized_dynamic = quantize_fx.convert_fx(model_prepared)

正如你所看到的,只需要通过模型传递一个示例输入来校准量化层,所以代码十分简单,看看我们的模型对比:

 print_model_size(model)
 print_model_size(model_quantized_dynamic)

可以看到的,减少了0.03 MB或者说模型变为了原来的75%,我们可以通过静态模式量化使其更小:

 model_to_quantize = copy.deepcopy(model_fp)
 qconfig_mapping = get_default_qconfig_mapping("qnnpack")
 model_to_quantize.eval()
 # prepare
 model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
 # calibrate 
 with torch.no_grad():
     for i in range(20):
         batch = next(iter(trainloader))[0]
         output = model_prepared(batch.to(device))

静态量化与动态量化是非常相似的,我们只需要传递更多批次的数据来更好地校准模型。

让我们看看这些步骤是如何影响模型的:

可以看到其实程序为我们做了很多事情,所以我们才可以专注于功能而不是具体的实现,通过以上的准备,我们可以进行最后的量化了:

 # quantize
 model_quantized_static = quantize_fx.convert_fx(model_prepared)

量化后的model_quantized_static看起来像这样:

现在可以更清楚地看到,将Conv2d和Relu层融合并替换为相应的量子化对应层,并对其进行校准。可以将这些模型与最初的模型进行比较:

 print_model_size(model)
 print_model_size(model_quantized_dynamic)
 print_model_size(model_quantized_static)

量子化后的模型比原来的模型小3倍,这对于大模型来说非常重要

现在让我们看看如何在量化的情况下训练模型,量化感知的训练就需要在训练的时候加入量化的操作,代码如下:

 model_to_quantize = mAlexNet(num_classes= 10).to(device)
 qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
 model_to_quantize.train()
 # prepare
 model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
 
 # training loop 
 model_trained_prepared = train_model(model_prepared)
 
 # quantize
 model_quantized_trained = quantize_fx.convert_fx(model_trained_prepared)

让我们比较一下到目前为止所有模型的大小。

 print("Regular floating point model: " )
 print_model_size( model_fp)
 print("Weights only qunatization: ")
 print_model_size( model_quantized_dynamic)
 print("Weights/Activations only qunatization: ")
 print_model_size(model_quantized_static)
 print("Qunatization aware trained: ")
 print_model_size(model_quantized_trained)

量化感知的训练对模型的大小没有任何影响,但它能提高准确率吗?

 def get_accuracy(model):
   correct = 0
   total = 0
   with torch.no_grad():
       for data in testloader:
           images, labels = data
           images, labels = images, labels
           outputs = model(images)
           _, predicted = torch.max(outputs.data, 1)
           total += labels.size(0)
           correct += (predicted == labels).sum().item()
 
       return 100 * correct / total
 
 fp_model_acc = get_accuracy(model)
 dy_model_acc = get_accuracy(model_quantized_dynamic)
 static_model_acc = get_accuracy(model_quantized_static)
 q_trained_model_acc = get_accuracy(model_quantized_trained)
 
 
 print("Acc on fp_model:" ,fp_model_acc)
 print("Acc weigths only quantization:", dy_model_acc)
 print("Acc weigths/activations quantization" ,static_model_acc)
 print("Acc on qunatization awere trained model:" ,q_trained_model_acc)

为了更方便的比较,我们可视化一下:

可以看到基础模型与量化模型具有相似的准确性,但模型尺寸大大减小,这在我们希望将其部署到服务器或低功耗设备上时至关重要。

最后一些资料:

https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html#motivation-of-fx-graph-mode-quantization

https://pytorch.org/docs/stable/quantization.html

本文代码:

https://avoid.overfit.cn/post/a72a7478c344466581295418f1620f9b

作者:mor40

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/970483.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JAVA】多态

作者主页:paper jie_的博客 本文作者:大家好,我是paper jie,感谢你阅读本文,欢迎一建三连哦。 本文录入于《JAVASE语法系列》专栏,本专栏是针对于大学生,编程小白精心打造的。笔者用重金(时间和…

【Sentinel】ProcessorSlotChain处理器插槽链与Node

文章目录 1、Sentinel的基本概念2、ProcessorSlotChain3、Node 1、Sentinel的基本概念 Sentinel实现限流、隔离、降级、熔断等功能,本质要做的就是两件事情: 统计数据:统计某个资源的访问数据(QPS、RT等信息)规则判断…

FPGA输出lvds信号点亮液晶屏

1 概述 该方案用于生成RGB信号,通过lvds接口驱动逻辑输出,点亮并驱动BP101WX-206液晶屏幕。 参考:下面为参考文章,内容非常详细。Xilinx LVDS Output——原语调用_vivado原语_ShareWow丶的博客http://t.csdn.cn/Zy37p 2 功能描述 …

从零开始学习 Java:简单易懂的入门指南之Collection集合及list集合(二十一)

Collection集合及list集合 1.Collection集合1.1数组和集合的区别1.2集合类体系结构1.3Collection 集合概述和使用1.4Collection集合的遍历1.4.1 迭代器遍历1.4.2 增强for1.4.3 lambda表达式 2.List集合2.1List集合的概述和特点2.2List集合的特有方法2.3List集合的五种遍历方式2…

JS动态计算自动滚动距离

先上效果 具体实现代码&#xff08;如果用到vue项目中的css要取消scoped否则不生效&#xff09; 在这里插入代码片<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible"…

基于STM32的厨房环境监测系统

前言 本篇文章将之前所有的文章进行整合&#xff0c;是之前几篇文章的综合版。 MQ-2烟雾传感器模块功能实现&#xff08;STM32&#xff09; MQ-7一氧化碳传感器模块功能实现&#xff08;STM32&#xff09; dht11温湿度模块功能实现&#xff08;STM32&#xff09; 0…

返回序列中最大值第一次出现时对应的索引(位置)Series.idxmax()

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 返回序列中最大值第一次 出现时对应的索引(位置) Series.idxmax() [太阳]选择题 以下说法错误的是? import pandas as pd apd.Series(data[1,6,None,5,6], index[A,B,C,D,E]) print(【显示】…

Spring MVC:域对象共享数据

Spring MVC 前言域对象共享数据使用 ModelAndView 向 request 域对象中共享数据使用 Map 、Model 或 ModelMap 向 request 域对象中共享数据使用 SesionAttributes 注解向 session 域对象中共享数据使用 Servlet API 向 application 域对象中共享数据 附 前言 在上一章中&…

发收一体的2.4G射频合封芯片Y62G,内置九齐MCU

宇凡微2.4GHz发收一体合封芯片Y62G是一款高度集成的系统芯片&#xff0c;融合了2.4G芯片G350和微控制器&#xff08;MCU&#xff09;功能&#xff0c;为开发人员提供了更好的设计自由度和成本效益的解决方案。以下是Y62G芯片的主要特点和优势&#xff1a; 高度合封集成 Y62G芯…

细说GNSS模拟器的RTK功能(二)应用实例01 — 硬件和软件设置

在之前的文章中&#xff0c;我们介绍了什么是RTK&#xff0c;接下来我们将为大家展示RTK使用实例&#xff0c;可以通过两种不同的方法来模拟RTCM的使用&#xff0c;一种是基于RTCM插件&#xff0c;另一种是基于多实例来模拟两个同步的射频信号。 RTK插件方法可以帮助没有基础接…

iOS系统下轻松构建自动化数据收集流程

在当今信息爆炸的时代&#xff0c;我们经常需要从各种渠道获取大量的数据。然而&#xff0c;手动收集这些数据不仅耗费时间和精力&#xff0c;还容易出错。幸运的是&#xff0c;在现代科技发展中有两个强大工具可以帮助我们解决这一问题——Python编程语言和iOS设备上预装的Sho…

RPC框架

博主介绍&#xff1a;✌全网粉丝3W&#xff0c;全栈开发工程师&#xff0c;从事多年软件开发&#xff0c;在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战&#xff0c;博主也曾写过优秀论文&#xff0c;查重率极低&#xff0c;在这方面有丰富的经验…

将虚拟机网络适配器改为仅主机模式,Vmware弹出“仅主机模式适配器驱动程序似乎未运行

这个问题的原因是&#xff0c;主机上缺乏VMware安装后没有VMnet1和VMnet8网卡。 所以

2023年超爆火的15款AI设计软件

随着人工智能技术的快速发展&#xff0c;数字插画之外的“泛设计”行业的从业者也开始在AI中逐渐受益。可能很多设计师还停留在“AI设计软件只能做一些动漫风格插画”的认知中&#xff0c;实际上受到行业需求提升的刺激&#xff0c;软件厂商已经开始积极研究并发布更多针对特定…

uni——input的提示语(placeholder)修改样式等

案例说明 操控input的提示语 案例代码 <template><view><view><input type"text" placeholder"请输入内容" :placeholder-class"isDialogHidden?redColor:" /><button click"hideDialog">按钮</…

正中优配:股票k线图入门?

随着股票市场的不断发展&#xff0c;对股民们来说&#xff0c;了解股票行情变得越来越重要。而股票k线图能够帮助股民们更好地调查和剖析股票行情。但关于一些没有相关经历的新手来说&#xff0c;股票k线图可能会带来一些困惑。那么&#xff0c;股票k线图入门应该从哪些方面着手…

蓝牙资讯|2023Q2全球TWS耳机出货量同比增长15%

TechInsights 报告指出&#xff0c;2023 年二季度全球 TWS 耳机出货量同比增长 15%&#xff0c;收益同比增长 5.1%。 苹果仍以 17% 出货量份额和 43% 的收益份额主导 TWS 市场&#xff0c;但来自印度和中国厂商的竞争&#xff0c;使苹果的份额有所下降。 在收益方面&#x…

浏览器是如何验证SSL证书的?

事实上&#xff0c;SSL证书作为目前网站数据安全的第一道防线&#xff0c;已被大部分企业所熟知。然而&#xff0c;这份认知主要是关于SSL证书可以实现网站HTTPS加密保护及身份的可信认证&#xff0c;防止传输数据的泄露或篡改方面&#xff0c;对于浏览器到底是如何验证SSL证书…

python安装wind10

一、下载: 官网:Python Releases for Windows | Python.org 二、安装 双击下载的安装程序文件。这将打开安装向导。安装界面图下方两个框的" Use admin privileges wheninstalling py. exe和” Add python. exe to PATH"都要勾选,一定要勾选!一定要勾选!一定要勾选…

昆明Sectigo dv通配符https证书

Sectigo是近些年发展比较快速的CA认证机构&#xff0c;Sectigo比较重视国际发展&#xff0c;先后成立了亚太审核中心等机构&#xff0c;是一家全球知名的数字证书颁发机构&#xff0c;Sectigo成立几十年来在全球范围内都受到信任。 1.Sectigo旗下的通配符https证书是市场占有率…