Pytorch QAT for UNet

news2025/1/11 18:40:54

对UNet进行Pytorch QAT量化感知训练研究了一周,终于跑通了,中间踩了不少坑,特此把正常操作记录一下,以备后续参考。

Pytorch提供了两种量化模式:Eager Mode 和FX Graph Mode.

Eager Mode需要手动指定需要融合(Fusion)的层,以及量化和反量化的位置,非常不好用,最开始我就是用的这种方式,踩了很多坑之后,虽然QAT训练完成了,但是在转换成int8模型的时候又报错,后来索性放弃该模式,直接使用FX Graph模式了。

FX Graph Mode虽然也没那么好用,但是它已经比Eager Mode方便多了,毕竟是一个自动化的量化框架。

下图给出了两种模式的比较:

好了,废话不多说,直接上代码说明网络的QAT过程吧。

1. 训练浮点模型—>QAT训练—>转换成int8模型

首先,需要包含我们使用到的相关量化库:

import torch
import copy
from torch.quantization import quantize_fx

接下来,创建一个Float32的新模型,并训练:

    # 根据自己的机器配置选择合适的device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    #Create a new model and train from scratch
    model = unet(4, 4)   # unet是提前定义好的模型,输入和输出都为4通道数据
    model.to(device)  # 将模型拷贝到Device
    train_model(model, 10)   # train_model是提前定义的模型训练函数,本例中为了验证简单,先进行了10个epoch的训练
    torch.save(model.state_dict(), 'model_fp32.pth')  # 保存state_dict
    print('Train over.')

接下来,我们需要进行一些QAT设置:

print('Begin QAT...')
model_to_quantize = copy.deepcopy(model)
model_to_quantize.train()  # Set model mode to train

# Get default qconfig
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
# Prepare model
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_dict)

model_prepared.to(device)  # 模型拷贝至用于训练的device
train_model(model_prepared, 5)   # 使用与Float32模型同样的训练函数,对prepared模型继续训练若干轮,这里为了方便,我只设置了5轮
torch.save(model_prepared.state_dict(), 'model_prepared.pth')  # 保存prepared模型

# Convert model to int8
print('Converting model to int8...')
model_quantized = quantize_fx.convert_fx(model_prepared)   # 将prepared模型转换成真正的int8定点模型
print('Convert done.')
torch.save(model_quantized.state_dict(), 'model_int8.pth')  # 保存定点模型的state_dict

以上代码中,对原始的Float32模型是从头训练的,其实我们也可以把训练好的浮点模型加载进来,再继续通过QAT训练之后进行量化。

2. 加载预训练好的浮点模型参数—>QAT训练—>转换成int8模型

首先,加载已经训练好的Float32模型:

# 实例化一个模型
model = unet(4, 4)
model.to(device)
# 加载提前训练好的模型参数
checkpoints = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoints)

# 接下来直接进行QAT准备和训练
print('Begin QAT...')
model_to_quantize = copy.deepcopy(model)
model_to_quantize.train()  # Set model mode to train

# Get default qconfig
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
# Prepare model
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_dict)

model_prepared.to(device)  # 模型拷贝至用于训练的device
train_model(model_prepared, 5)   # 使用与Float32模型同样的训练函数,对prepared模型继续训练若干轮,这里为了方便,我只设置了5轮
torch.save(model_prepared.state_dict(), 'model_prepared.pth')  # 保存prepared模型

# Convert model to int8
print('Converting model to int8...')
model_quantized = quantize_fx.convert_fx(model_prepared)   # 将prepared模型转换成真正的int8定点模型
print('Convert done.')
torch.save(model_quantized.state_dict(), 'model_int8.pth')  # 保存定点模型的state_dict

3. int8模型的使用

那么,对于训练好的int8模型,怎样调用来做推理呢?这个时候,直接拿原来的模型结构来加载就会失败,需要我们把原来的模型结构,按照QAT流程转换成int8形式之后,再进行加载,具体见代码:

# 加载int8模型的参数
state_dict_int8 = torch.load('model_int8.pth', map_location=device)

# 实例化原始模型
model = unet(4,4)
model_to_quantize = copy.deepcopy(model)

# 获取qconfig参数
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
# 模型prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_dict)
# 将prepared模型转换成int8结构
model_quantized = quantize_fx.convert_fx(model_prepared)
# 用转换出的int8模型结构加载int8模型参数
model_quantized.load_state_dict(state_dict_int8)
# 设置int8模型模式为eval
model_quantized.eval()

# Pre-process for input_data
# int8模型的调用,input_data是符合输入要求的4通道数据,output_data是模型输出的4通道数据,注意这里我省略了输入输出数据的前后处理,主要展示模型的QAT过程及定点化模型在pytorch中的调用方法。
output_data = model_quantized(input_data)
# Post-process for output_data

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/667743.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

迪赛智慧数——饼图(环形饼图):2022年618期间各品类销售额分布

效果图 2022年“618”全网成交额达6959亿元,较2021年增加了1174.20亿元,同比增长20.30%。“618”网购狂欢节首先是由京东发起的,京东618每年6月是京东的店庆月,2022年京东“618”成交额达3793亿元,较2021年增加了355亿…

性能测试工具:Jmeter介绍

JMeter是一个开源的Java应用程序,由Apache软件基金会开发和维护,可用于性能测试、压力测试、接口测试等。 1. 原理 JMeter的基本原理是模拟多用户并发访问应用程序,通过发送HTTP请求或其他协议请求,并测量响应时间、吞吐量、并发…

VSCode远程开发入门指南

我的开发环境 我的开发主机是一台Centos7的远端云服务器,在本地的Windows电脑上使用xshell进行ssh连接,在Windows使用vscode的Remote进行远程连接,进行编写与开发,主要是C领域的开发 为什么不推荐使用vim 一个趁手的编辑器在开…

【Linux】网络编程基础包含TCP详解

目录 网络结构模式C/S结构B/S结构 MAC地址IP地址端口网络模型OSI七层模型TCP/IP四层模型 通信过程数据包封装协议以太网协议ARP协议IP数据报格式UDP协议格式TCP协议格式封装分用TCP详解TCP和UDPTCP通信流程TCP三次握手TCP滑动窗口TCP四次挥手 网络结构模式 C/S结构 客户机-服务…

03、非受控组件与受控组件、高阶函数、prop-types、生命周期、hook

总结 一、非受控组件与受控组件 非受控组件 表单项不与state数据相向关联, 需要手动读取表单元素的值 借助于 ref获取真实DOM,在通过value获得输入值,使用原生 DOM 方式来获取表单元素值 非受控组件: 表单项不与 state 数据相向关联, 需要手动读取表…

宏景eHR SQL注入漏洞复现(CNVD-2023-08743)

0x01 产品简介 宏景eHR人力资源管理软件是一款人力资源管理与数字化应用相融合,满足动态化、协同化、流程化、战略化需求的软件。 0x02 漏洞概述 宏景eHR 存在SQL注入漏洞,未经过身份认证的远程攻击者可利用此漏洞执行任意SQL指令,从而窃取数…

android 如何分析应用的内存(七)——malloc hook

android 如何分析应用的内存(七) 接上文,介绍六大板块中的第二个————malloc hook 上一篇的自定义分配函数,常常只能解决当前库中的分配,而不能跟踪整个app中的分配。 为此,android的libc库&#xff…

正运动即将亮相2023年深圳激光展,助力个性化激光智能制造!

■展会名称: 第⼗六届深圳国际激光与智能装备、光子技术博览会(以下简称“深圳激光展”) ■展会日期 2023年6月27日-29日 ■展馆地点 深圳国际会展中心(宝安新馆) ■展位号 9D115 激光加工是一种基于光热效应的…

STM32的中断系统详解(嵌入式学习)

中断系统 1. 基本概念2. 中断的意义3. 中断处理过程处理过程过程详述 4. 中断体系结构5. NVIC概念主要功能 6. EXTI概念主要功能结构框图中断和事件的区别 7. 总结 1. 基本概念 中断是处理器中的一种机制,用于响应和处理突发事件或紧急事件。当发生中断时&#xff…

每日学术速递6.9

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 Subjects: cs.CV 1.Segment Anything in High Quality 标题:以高质量分割任何内容 作者:Lei Ke, Mingqiao Ye, Martin Danelljan, Yifan Liu, Yu-Wing Tai, Chi-Keung Tang, …

Reids分布式锁详细介绍原理和实现

Reids 分布式锁 问题描述 1、单体单机部署的系统被演化成分布式集群系统后 2、由于分布式系统多线程、多进程并且分布在不同机器上,这将使原单机部署情况下的并发控制锁策略失效 3、单纯的Java API 并不能提供分布式锁的能力 4、为了解决这个问题就需要一种跨J…

abd shell后,getevent退出方法

abd shell后,getevent退出方法 输入 exit 然后回车退出

一种很新的交互式智能标注技术

随着人工智能应用的大规模落地,数据标注市场在高速增长的同时,也面临着标注成本的挑战。据IDC报告显示:数据标注在AI应用开发过程中所耗费的时间占到了25%,部分医学类应用一条数据的标注成本甚至高达20元。数据精度的高要求、强人…

RocketMQ 环境搭建

环境:linux(centos) 或 windos; jdk 1.8 场景:rocket入门学习 时间:2023-04-20 吐槽:可能是本人学习能力不足,想使用docker搭建rocketmq 一直失败,可能是我想使用的比较新…

正排倒排,并不是 MySQL 的排序的全部!

引言 一个悠闲的上午,小航送了我,一袋坚果,他看我吃的正香,慢慢问道:”温哥,mysql的排序,有什么要注意的吗,不就是正排倒排吗?” 我一听他问我的问题,顿感坚…

软件测试简历如何包装?

首先明确的包装简历不等于欺骗,只是把你的最好一面展示出来,给别人一个好的映像;(就相当于相亲,哈哈) 无论如何包装简历,注意简历上的东西一定要会、一定要会、一定要会(面试官一般…

Java框架-Spring

文章目录 1、你了解Spring IOC吗?2、SpringIOC的应用?3、SpringIOC的getBean方法的解析?4、面试题5、你了解Spring AOP吗?6、事务ACID特性7、事务传播 1、你了解Spring IOC吗? IoC(Inversion of control&a…

C++编程启蒙-2——你适合学习编程吗?

英语差,数学孬,照样可以学好编程。但,如果你逻辑思维差,动力能力弱,那么学习编程真的会难上加难。本课用来帮助读者实现对逻辑思维与动手能力的自我判断,并给出了实际测试方案。 英语差,数学孬&…

15个常见的AI绘画网站推荐

无论你是专业的艺术家还是对人工智能绘画感兴趣的普通人,AI绘画网站都可以为你提供新的创作灵感和艺术体验,给艺术界带来更多的创新和可能性。以下是15个常见的AI绘画网站的介绍。 即时 AI 灵感 「即时 AI 灵感」是通过文字描述等方式生成精致图像的AI…

QGIS实现shape、geojson数据的矢量切片教程

能够实现矢量切片的办法有很多,可以使用geoserver,可以使用qgis,当然也可以自己写代码实现。这篇文章我们来介绍一下如何使用qgis完成shape数据的矢量切片。 首先我们还是要准备一份矢量数据。矢量数据的格式是shape文件或者是geojson文件都…