QAT量化 demo

news2025/1/10 0:22:08

一、QAT量化基本流程

QAT过程可以分解为以下步骤:

  1. 定义模型:定义一个浮点模型,就像常规模型一样。
  2. 定义量化模型:定义一个与原始模型结构相同但增加了量化操作(如torch.quantization.QuantStub())和反量化操作(如torch.quantization.DeQuantStub())的量化模型。
  3. 准备数据:准备训练数据并将其量化为适当的位宽。
  4. 训练模型:在训练过程中,使用量化模型进行正向和反向传递,并在每个 epoch 或 batch 结束时使用反量化操作计算精度损失。
  5. 重新量化:在训练过程中,使用反量化操作重新量化模型参数,并使用新的量化参数继续训练。
  6. Fine-tuning:训练结束后,使用fine-tuning技术进一步提高模型的准确率。

在这里插入图片描述

二、QAT量化代码示例

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.quantization import QuantStub, DeQuantStub, quantize_dynamic, prepare_qat, convert

# 模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 量化
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 10)
        # 反量化
        self.dequant = DeQuantStub()

    def forward(self, x):
        # 量化
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # 反量化
        x = self.dequant(x)
        return x

# 数据
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
train_data = datasets.CIFAR10(root='./data', train=True, download=True,transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1,shuffle=True, num_workers=0)

# 模型 优化器
model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Prepare the model
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = prepare_qat(model)

# 训练
model.train()
for epoch in range(1):
    for i, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = nn.CrossEntropyLoss()(output, target)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f' %
                  (epoch+1, 10, i+1, len(train_loader), loss.item()))

    # Re-quantize the model
    model = quantize_dynamic(model, {'': torch.quantization.default_dynamic_qconfig}, dtype=torch.qint8)

# 微调
model.eval()
for data, target in train_loader:
    model(data)
model = convert(model, inplace=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1589689.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3+Ant Design表格排序

最近在公司做有关报表的项目时,遇到最多的问题-表格排序,刚开始看到UI设计图的时候,还有些纳闷这个排序如何做,其实实际上并没有想象中的那么难,如果说单纯的排序的话ant design这个组件里的表格有自带的排序和筛选功能…

知名专业定制线缆源头工厂推荐-精工电联:智能制造线缆的前沿技术探索

优质定制线缆源头厂家推荐-精工电联:智能制造线缆的前沿技术探索 知名专业定制线缆源头工厂推荐-精工电联:智能制造线缆的前沿技 在科技飞速发展的当今时代,智能制造已成为工业4.0的核心驱动力。精工电联,作为智能制造领先的高品质…

基于Python的LSTM网络实现单特征预测回归任务(pytorch版)

一、数据集 自建数据集--【load.xlsx】。包含2列: date列(时间列,记录2022年6月2日起始至2023年12月31日为止,日度数据)price列(价格列,记录日度数据对应的某品牌衣服的价格,浮点数…

【LAMMPS学习】八、基础知识(2.2)类型标签

8. 基础知识 此部分描述了如何使用 LAMMPS 为用户和开发人员执行各种任务。术语表页面还列出了 MD 术语,以及相应 LAMMPS 手册页的链接。 LAMMPS 源代码分发的 examples 目录中包含的示例输入脚本以及示例脚本页面上突出显示的示例输入脚本还展示了如何设置和运行各…

请核对您的姓名、证件号码、有效期和年限是否选择正确,请勿使用挂失过的身份证

问题 请核对您的姓名、证件号码、有效期和年限是否选择正确,请勿使用挂失过的身份证 详细问题 笔者在专利业务办理系统进行新用户注册,注册时间为晚上大概22:00以后。注册时已核对姓名、证件号码、有效期和年限,已确保正确&…

【C++】拆分详解 - 内存管理

文章目录 前言一、C/C内存分布二、C语言中动态内存管理方式:malloc/calloc/realloc/free三、C内存管理方式  3.1 new/delete操作内置类型  3.2 new和delete操作自定义类型  3.3 operator new与operator delete函数 四、new和delete的实现原理  4.1 内置类型…

【Java集合】面试题汇总

Java 集合Java 集合概览1. List, Set, Queue, Map 四者的区别?2. ArrayList 和 Array(数组)的区别?3. ArrayList 和 Vector 的区别?4. Vector 和 Stack 的区别?(了解即可)5. ArrayList 可以添加 null 值吗…

python+appium调@pytest.mark.parametrize返回missing 1 required positional argument:

出错描述: 1、在做pythonappium自动化测试时,使用装饰器pytest.mark.parametrize(“参数”,[值1,值2,值3]),测试脚本执行返回test_xx() missing 1 required positional argument:“…

Bitmap OOM

老机器Bitmap预读仍然OOM&#xff0c;无奈增加一段&#xff0c;终于不崩溃了。 if (Build.VERSION.SDK_INT < 21)size 2; 完整代码&#xff1a; Bitmap bitmap; try {//Log.e(Thread.currentThread().getStackTrace()[2] "", surl);URL url new URL(surl);…

通俗易懂HTTP和HTTPS区别

HTTP&#xff1a;超文本传输协议&#xff0c;它是使用一种明文的方式发送我们的内容&#xff0c;没有任何的加密&#xff0c;例如我们要在网页上输入账号密码&#xff0c;如果使用HTTP协议&#xff0c;账号密码就可能会被暴露&#xff0c;默认端口是80. HTTPS&#xff1a;是HT…

动态规划解决背包问题

目录 动态规划步骤&#xff1a; 1.01背包问题 2.完全背包问题 动态规划步骤&#xff1a; step1.分析问题&#xff0c;定义dp数组&#xff08;下标含义&#xff09; step2.初始化dp数组&#xff08;边界&#xff09; step3.写dp状态转换方程&#xff08;明确dp数组遍历顺序…

【Super数据结构】二叉搜索树与二叉树的非递归遍历(含前/中/后序)

&#x1f3e0;关于此专栏&#xff1a;Super数据结构专栏将使用C/C语言介绍顺序表、链表、栈、队列等数据结构&#xff0c;每篇博文会使用尽可能多的代码片段图片的方式。 &#x1f6aa;归属专栏&#xff1a;Super数据结构 &#x1f3af;每日努力一点点&#xff0c;技术累计看得…

【饿了么笔试题汇总】[全网首发]2024-04-12-饿了么春招笔试题-三语言题解(CPP/Python/Java)

&#x1f36d; 大家好这里是KK爱Coding &#xff0c;一枚热爱算法的程序员 ✨ 本系列打算持续跟新饿了么近期的春秋招笔试题汇总&#xff5e; &#x1f4bb; ACM银牌&#x1f948;| 多次AK大厂笔试 &#xff5c; 编程一对一辅导 &#x1f44f; 感谢大家的订阅➕ 和 喜欢&#x…

.cur 鼠标光标文件读取

备份icon掩码开发代码-CSDN博客 代码改写自 目前bug是高度不足&#xff0c;顶上的几十个像素图片打印需要加常数&#xff0c; i的改写是 i/3,参考上面链接的自简书的代码 #include <stdio.h> #include <windows.h> #pragma warning(disable : 4996) // visu…

[lesson20]初始化列表的使用

初始化列表的使用 类成员的初始化 C中提供了初始化列表对成员变量进行初始化 语法规则 注意事项 成员的初始化顺序与成员的声明顺序相同成员的初始化顺序与初始化列表中的位置无关初始化列表先于构造函数的函数体执行 类中的const成员 类中的const成员会被分配空间的类中…

【Linux】环境搭建

昙花一现&#xff0c;却等待了整个白昼 蝉鸣一夏&#xff0c;却蛰伏了几个四季 目录 购买云服务器 总结 使用 XShell 远程登陆到 Linux 利用Linux编写一个简单C程序 ⭐toush -- 创建文件 ⭐vi -- 文本编译器 ⭐ll -- 查看文件的显示结果分析 ⭐gcc -o ⭐cat -- 查看源代码 契子…

利用细粒度检索增强和自我检查提升对话式问题解答能力

&#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 论文标题&#xff1a;Boosting Conversational Question Answering with Fine-Grained Retrieval-Augmentation and Self-Check 论文地址&#xff1a;https://arxiv.org/abs/2403.18243 检索增强生成…

【报错】AttributeError: ‘NoneType‘ object has no attribute ‘pyplot_show‘(已解决)

【报错】AttributeError: ‘NoneType’ object has no attribute ‘pyplot_show’ 问题描述&#xff1a;python可视化出现下面报错 我的原始代码&#xff1a; import matplotlib.pyplot as pltplt.figure() plt.plot(x, y, bo-) plt.axis(equal) plt.xlabel(X) plt.ylabe…

最短路径问题——(弗洛伊德算法与迪杰斯特拉算法)

最短路径问题——&#xff08;弗洛伊德算法与迪杰斯特拉算法&#xff09;【板子】 题目&#xff1a; 对于下面的图片所给出的关系,回答下面两个问题&#xff1a; 利用迪杰斯特拉算法求点A到每一个点之间的最小距离。利用弗洛伊德算法求每两个点之间的最短路径。 &#xff0…

[RK3399 Linux] 使用ubuntu 20.04.5制作rootfs

一、ubuntu base ubuntu base是用于为特定需求创建自定义映像的最小rootfs,是ubuntu可以运行的最小环境。 1.1 下载源码 下载ubuntu-base的方式有很多,可以从官方的地址:ttp://cdimage.ubuntu.com/ubuntu-base/releases。 也可以其它镜像地址下载,如清华源:https://mi…