Pytorch导出FP16 ONNX模型

news2024/12/26 0:19:12

一般Pytorch导出ONNX时默认都是用的FP32,但有时需要导出FP16的ONNX模型,这样在部署时能够方便的将计算以及IO改成FP16,并且ONNX文件体积也会更小。想导出FP16的ONNX模型也比较简单,一般情况下只需要在导出FP32 ONNX的基础上调用下model.half()将模型相关权重转为FP16,然后输入的Tensor也改成FP16即可,具体操作可参考如下示例代码。这里需要注意下,当前Pytorch要导出FP16的ONNX必须将模型以及输入Tensor的device设置成GPU,否则会报很多算子不支持FP16计算的提示。

import torch
from torchvision.models import resnet50


def main():
    export_fp16 = True
    export_onnx_path = f"resnet50_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = resnet50()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.inference_mode():
        dtype = torch.float16 if export_fp16 else torch.float32
        x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
        torch.onnx.export(model=model,
                          args=(x,),
                          f=export_onnx_path,
                          input_names=["image"],
                          output_names=["output"],
                          dynamic_axes={"image": {2: "width", 3: "height"}},
                          opset_version=17)


if __name__ == '__main__':
    main()

通过Netron可视化工具可以看到导出的FP16 ONNX的输入/输出的tensor类型都是float16
在这里插入图片描述

并且通过对比可以看到,FP16的ONNX模型比FP32的文件更小(48.6MB vs 97.3MB)。
在这里插入图片描述
大多数情况可以按照上述操作进行正常转换,但也有一些比较头大的场景,因为你永远无法知道拿到的模型会有多奇葩,例如下面示例:
错误导出FP16 ONNX示例

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)

        kernel = torch.tensor([[0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])
        x = F.conv2d(x, weight=kernel, bias=None, stride=1)

        return x


def main():
    export_fp16 = True
    export_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = MyModel()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.inference_mode():
        dtype = torch.float16 if export_fp16 else torch.float32
        x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
        model(x)
        torch.onnx.export(model=model,
                          args=(x,),
                          f=export_onnx_path,
                          input_names=["image"],
                          output_names=["output"],
                          dynamic_axes={"image": {2: "width", 3: "height"}},
                          opset_version=17)


if __name__ == '__main__':
    main()

执行以上代码后会报如下错误信息:

/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,
Traceback (most recent call last):
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 47, in <module>
    main()
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 36, in main
    model(x)
  File "/home/wz/miniconda3/envs/torch2.0.1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 17, in forward
    x = F.conv2d(x, weight=kernel, bias=None, stride=1)
  RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

简单来说就是在推理过程中遇到两种不同类型的数据要计算,torch.cuda.HalfTensor(FP16) 和torch.cuda.FloatTensor(FP32)。遇到这种情况一般常见有两种解法:

  • 一种是找到数据类型与我们预期不一致的地方,然后改成我们要想的dtype,例如上面示例是将kernel的dtype写死成了torch.float32,我们可以改成torch.float16或者写成x.dtype(这种会比较通用,会根据输入的Tensor类型自动切换)。这种方法有个弊端,如果代码里写死dtype的位置很多,改起来会比较头大。
  • 另一种是使用torch.autocast上下文管理器,该上下文管理器能够实现推理过程中自动进行混合精度计算,例如遇到能进行float16/bfloat16计算的场景会自动切换。具体使用方法可以查看官方文档。下面示例代码就是用torch.autocast上下文管理器来做自动转换。
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)

        kernel = torch.tensor([[0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1],
                               [0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])
        x = F.conv2d(x, weight=kernel, bias=None, stride=1)

        return x


def main():
    export_fp16 = True
    export_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"
    device = torch.device("cuda:0")

    model = MyModel()
    model.eval()
    model.to(device)
    if export_fp16:
        model.half()

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        with torch.inference_mode():
            dtype = torch.float16 if export_fp16 else torch.float32
            x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)
            model(x)
            torch.onnx.export(model=model,
                              args=(x,),
                              f=export_onnx_path,
                              input_names=["image"],
                              output_names=["output"],
                              dynamic_axes={"image": {2: "width", 3: "height"}},
                              opset_version=17)


if __name__ == '__main__':
    main()

使用上述代码能够正常导出ONNX模型,并且使用Netron可视化后可以看到导出的FP16 ONNX模型是符合预期的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1581404.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LINUX系统触摸工业显示器芯片应用方案--Model4(简称M4芯片)

背景介绍&#xff1a; 触摸工业显示器传统的还是以WINDOWS为主&#xff0c;但近年来&#xff0c;安卓紧随其后&#xff0c;但一直市场应用情况不够理想&#xff0c;反而是LINUX系统的触摸工业显示器大受追捧呢&#xff1f; 触摸工业显示器传统是以Windows系统为主&#xff0c…

微信小程序用户登录授权指定(旧版本)

配置旧版本基础库2.12.3 实现效果 点击登录按钮即可直接登录&#xff0c;获取用户昵称和头像 点击获取头像昵称按钮则需要授权&#xff0c;才能成功登录 代码实现 my.xml <!-- 登录页面,调试基础库为2.20.2库 --> <view class"mylogin"><block w…

权威报道 | 百分点科技:《突发事件应急预案管理办法》解读

近日&#xff0c;百分点科技CTO刘译璟作为唯一企业界代表&#xff0c;接受应急领域权威期刊——《中国应急管理》杂志邀请&#xff0c;与中国安全生产科学研究院、中央党校、中国政法大学等单位的专家一起&#xff0c;就《突发事件应急预案管理办法》&#xff08;以下简称《办法…

三支冲突分析介绍

Pawlak最早通过观察一组智能体对一组问题的意见&#xff0c;提出了冲突分析模型。U表示对象集&#xff0c;V表示属性集&#xff0c;R表示对象集和属性集之间的二元关系&#xff0c;这样一个刻画冲突分析的信息系统通过三元组&#xff08;U&#xff0c;V&#xff0c;R&#xff0…

hive管理之ctl方式

hive管理之ctl方式 hivehive --service clictl命令行的命令 #清屏 Ctrl L #或者 &#xff01; clear #查看数据仓库中的表 show tabls; #查看数据仓库中的内置函数 show functions;#查看表的结构 desc表名 #查看hdfs上的文件 dfs -ls 目录 #执行操作系统的命令 &#xff01;命令…

如何用西门子PLC手工做一个电闸监控控制系统

//S-H4CK13Maptnh// 项目地址:https://github.com/MartinxMax/S7-200_Power_monitoring 欢迎三连支持…电闸监控与控制 注意 注意安全,220V!!! 应用场景 一些工业自动化企业中需要对电闸的监控与控制。 原理图 组装 上面一行检测跳闸情况 下面一行控制当前一路电源 控制…

小剧场短剧剧集收费短剧小程序APP

1. 内容展现 付费、免费、任务解锁&#xff1a;用户可以通过付费直接观看短剧&#xff0c;也可以通过完成平台任务&#xff08;如签到、分享等&#xff09;获得免费观看的机会。这种灵活的解锁方式既满足了用户的多种需求&#xff0c;也促进了平台的活跃度。主流展现形式&…

InsectMamba:基于状态空间模型的害虫分类

InsectMamba&#xff1a;基于状态空间模型的害虫分类 摘要IntroductionRelated WorkImage ClassificationInsect Pest Classification PreliminariesInsectMambaOverall Architecture InsectMamba: Insect Pest Classification with State Space Model 摘要 害虫分类是农业技术…

【黑马头条】-day07APP端文章搜索-ES-mongoDB

文章目录 今日内容1 搭建es环境1.1 拉取es镜像1.2 创建容器1.3 配置中文分词器ik1.4 测试 2 app文章搜索2.1 需求说明2.2 思路分析2.3 创建索引和映射2.3.1 PUT请求添加映射2.3.2 其他操作 2.4 初始化索引库数据2.4.1 导入es-init2.4.2 es-init配置2.4.3 导入数据2.4.4 查询已导…

Docker容器嵌入式开发:Docker Ubuntu18.04配置mysql数据库

在 Ubuntu 18.04 操作系统中安装 MySQL 数据库的过程。下面是安装过程的详细描述: 首先,使用以下命令安装 MySQL 服务器: sudo apt install mysql-server系统会提示是否继续安装,按下 Y 键确认。 安装过程中,系统会下载并安装 MySQL 相关的软件包,包括 libaio1、mysql…

ChromeDriver / Selenium-server

一、简介 ChromeDriver 是一个 WebDriver 的实现&#xff0c;专门用于自动化控制 Google Chrome 浏览器。以下是关于 ChromeDriver 的详细说明&#xff1a; 定义与作用&#xff1a; ChromeDriver 是一个独立的服务器程序&#xff0c;作为客户端库与 Google Chrome 浏览…

STM32H7通用定时器计数功能的使用

目录 概述 1 STM32定时器介绍 1.1 认识通用定时器 1.2 通用定时器的特征 1.3 递增计数模式 1.4 时钟选择 2 STM32Cube配置定时器时钟 2.1 配置定时器参数 2.2 配置定时器时钟 3 STM32H7定时器使用 3.1 认识定时器的数据结构 3.2 计数功能实现 4 测试案例 4.1 代码…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之七 简单指定视频某片段快放效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之七 简单指定视频某片段快放效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之七 简单指定视频某片段快放效果 一、简单介绍 二、简单指定视频某片段快放效果实现原理…

vue3使用jsQR解析二维码

1.了解jsQR jsQR是一个纯javascript脚本实现的二维码识别库&#xff0c;不仅可以在浏览器端使用&#xff0c;而且支持后端node.js环境。jsQR使用较为简单&#xff0c;有着不错的识别率。 2.效果图 3.二维码 4.下载jsqr包 npm i -d jsqr5.代码 <script setup> import …

STM32F407+FreeRTOS+LWIP UDP组播

开发环境介绍&#xff1a; MCU&#xff1a;STM32F407ZET6 网卡&#xff1a;LAN8720A LWIP版本&#xff1a;V1.1.0 FreeRTOS 版本&#xff1a;V10.2.1 LAN8720A硬件原理图&#xff1a; 硬件连接说明&#xff1a; MII_RX_CLK/RMII_REF_CLK ------>PA1 …

[lesson15]类与封装的概念

类与封装的概念 类的封装 类通常分为以下两个部分 类的实现细节类的使用方式 当使用类时&#xff0c;不需要关心其实现细节 当创建类时&#xff0c;才需要考虑其内部实现细节 封装的基本概念 根据经验&#xff1a;并不是类的每个属性都是对外公开的 如&#xff1a;女孩子不…

hive-3.1.2分布式搭建与hive的三种交互方式

hive-3.1.2分布式搭建&#xff1a; 一、上传解压配置环境变量 在官网或者镜像站下载驱动包 华为云镜像站地址&#xff1a; hive&#xff1a;Index of apache-local/hive/hive-3.1.2 mysql驱动包&#xff1a;Index of mysql-local/Downloads/Connector-J # 1、解压 tar -zx…

gpt科普1 GPT与搜索引擎的对比

GPT&#xff08;Generative Pre-trained Transformer&#xff09;是一种基于Transformer架构的自然语言处理模型。它通过大规模的无监督学习来预训练模型&#xff0c;在完成这个阶段后&#xff0c;可以用于各种NLP任务&#xff0c;如文本生成、机器翻译、文本分类等。 以下是关…

在Linux系统上实现TCP(socket)通信

一.什么TCP TCP&#xff08;传输控制协议&#xff09;是一种面向连接的、可靠的、基于字节流的传输层通信协议。 二.TCP通信流程 三. TCP 服务器端 1 创建socket int sockfd socket(AF_INET, SOCK_STREAM, 0); //SOCK_STREAM tcp通信2 绑定(bind) struct sockaddr_in myad…

.net 6 集成NLog

.net 6 webapi项目集成NLog 上代码step 1 添加nugetstep 2 添加支持step 3 添加配置文件 结束 上代码 step 1 添加nuget 添加nuget 包 Roc step 2 添加支持 修改program.cs var builder WebApplication.CreateBuilder(args); // 添加NLog日志支持 builder.AddRocNLog();ste…