【深度学习】pytorch pth模型转为onnx模型后出现冗余节点“identity”,onnx模型的冗余节点“identity”

news2024/10/5 18:26:52

情況描述

onnx模型的冗余节点“identity”如下图。
在这里插入图片描述

解决方式

首先,确保您已经安装了onnx-simplifier库:

pip install onnx-simplifier

然后,您可以按照以下方式使用onnx-simplifier库:

import onnx
from onnxsim import simplify

# 加载导出的 ONNX 模型
onnx_model = onnx.load("your_model.onnx")

# 简化模型
simplified_model, check = simplify(onnx_model)

# 保存简化后的模型
onnx.save_model(simplified_model, "simplified_model.onnx")

通过这个过程,onnx-simplifier库将会检测和移除不必要的"identity"节点,从而减少模型中的冗余。

请注意,使用onnx-simplifier库可能会改变模型的计算图,因此在使用简化后的模型之前,务必进行测试和验证以确保其功能没有受到影响。

问题原因

在将 PyTorch 模型转换为 ONNX 格式时,有时会出现冗余的"identity"节点的问题。这是因为 PyTorch 和 ONNX 在计算图构建和表示方式上存在一些差异。

在 PyTorch 中,计算图是动态构建的,其中包含了很多临时变量和操作。但在 ONNX 中,计算图是静态定义的,每个操作都显式地表示为一个节点。这种差异可能导致在将 PyTorch 模型转换为 ONNX 格式时引入一些不必要的中间"identity"节点。

一个常见的原因是,PyTorch 中的某些操作或模型结构在 ONNX 中没有直接的等价表示。为了保持模型结构的一致性,转换过程中可能会引入额外的"identity"节点,用于保留原始模型中的特定计算图结构或操作。

另外,有时候这些"identity"节点并不会对模型的性能或功能产生任何影响,它们只是在图形表示上引入了一些冗余。这些冗余节点在模型尺寸较小的情况下可能并不明显,但对于大型模型来说可能会显著增加模型文件的大小。

通过使用onnx-simplifier库,您可以对导出的 ONNX 模型进行后处理,去除这些不必要的"identity"节点,从而减少模型的冗余。

需要注意的是,由于 PyTorch 和 ONNX 之间的差异,无法完全避免所有的冗余节点。但大部分情况下这些冗余节点并不会对模型的性能或功能产生实质性的影响。

我的模型代码

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import init


class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out


# 注意力机制
class SeModule(nn.Module):
    def __init__(self, in_channel, reduction=4):
        super(SeModule, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channel, in_channel // reduction, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(in_channel // reduction)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(in_channel // reduction, in_channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.hs = hsigmoid()

    def forward(self, x):
        out = self.avgpool(x)
        out = self.fc1(out)
        out = self.bn(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.hs(out)
        return x * out


# 线性瓶颈和反向残差结构
class Block(nn.Module):
    def __init__(self, kernel_size, in_channel, expand_size, out_channel, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule
        # 1*1展开卷积
        self.conv1 = nn.Conv2d(in_channel, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        # 3*3(或5*5)深度可分离卷积
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        # 1*1投影卷积
        self.conv3 = nn.Conv2d(expand_size, out_channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_channel != out_channel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channel),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # 注意力模块
        if self.se != None:
            out = self.se(out)
        # 残差链接
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV3_Small_050(nn.Module):
    def __init__(self):
        super(MobileNetV3_Small_050, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = nn.ReLU(inplace=True)
        self.bneck = nn.Sequential(
            Block(3, 16, 8, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 40, 16, nn.ReLU(inplace=True), None, 2),
            Block(3, 16, 56, 16, nn.ReLU(inplace=True), None, 1),
            Block(5, 16, 64, 24, hswish(), SeModule(24), 2),
            Block(5, 24, 144, 24, hswish(), SeModule(24), 1),
            Block(5, 24, 144, 24, hswish(), SeModule(24), 1),
            Block(5, 24, 72, 24, hswish(), SeModule(24), 1),
            Block(5, 24, 72, 24, hswish(), SeModule(24), 1),
            Block(5, 24, 144, 48, hswish(), SeModule(48), 2),
            Block(5, 48, 288, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 48, hswish(), SeModule(48), 1),
        )
        self.conv2 = nn.Conv2d(48, 288, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(288)
        self.hs2 = hswish()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(288, 6)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = out.view(-1, 288)
        out = self.fc(out)
        return out


class MobileNetV3_Small(nn.Module):
    def __init__(self):
        super(MobileNetV3_Small, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(576, 6)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = out.view(-1, 576)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    # from torchsummary import summary
    # net = MobileNetV3_Small_050().train()
    # summary(net, (3, 64, 64))
    #
    # from torchstat import stat
    # net = MobileNetV3_Small_050().train()
    # stat(net, input_size=(3, 64, 64))  # 输出模型的FLOPs和参数数量

    # 转为onnx
    import torch.onnx

    dummy_input = torch.randn(1, 3, 64, 64)
    net = MobileNetV3_Small_050().eval()
    torch.onnx.export(net, dummy_input, "mobilenetv3_small_050.onnx", input_names=["input"], output_names=["output"],
                      opset_version=11, )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/638529.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手机短视频设置背景文字工具

代码地址 github: https://github.com/iotzzh/tools-web gitee: https://gitee.com/iotzzh/tools-web.git 以前喜欢发抖音,内容是一些古诗文,然后找不到合适模板,于是自己就写了一个小工具,功能如下: 时间展示、支持选…

爬虫基础学习记录

爬虫介绍 互联网爬虫 如果我们把互联网比作一张大的蜘蛛网,那一台计算机上的数据便是蜘蛛网上的一个猎物,而爬虫程序就是一只小蜘蛛,沿着蜘蛛网抓取自己想要的数据 解释1:通过一个程序,根据Url(http://www.taobao.c…

❤ vue主要使用的版本和对应体系

❤ 项目主要使用的版本和对应体系 vue地址: https://cn.vuejs.org/ Vue所有版本地址 https://github.com/vuejs/core/blob/main/changelogs/CHANGELOG-3.2.md NPM镜像地址 npm install -g cnpm --registryhttps://registry.npm.taobao.org nvm 地址: htt…

element(兼容2.72以下的版本)实现树形数据+复选框的效果

用最新的element是可以实现树形数据的展示,但是没有复选框效果,用2.72以前的版本的话,是根本没有展开树形数据的效果,也没有复选框效果, 需求:在2.72以下的老版本上做一个树形展示的效果,并且还…

初识Notes Domino 14 Drop1

大家好,才是真的好。 周末花了点时间,安装了一下Notes Domino 14 Drop1版本。考虑到大多数人的习惯,没采用Docker或K8s方式来部署,也没采用一键配置功能,依旧通过传统方式一步一步进行安装和配置,这样大家…

【Spring Boot 初识丨五】beans 详解

上一篇讲了 Spring Boot 的主程序类 本篇来讲一讲 beans 详解 Spring Boot 初识: 【Spring Boot 初识丨一】入门实战 【Spring Boot 初识丨二】maven 【Spring Boot 初识丨三】starter 【Spring Boot 初识丨四】主应用类 beans 一、 定义二、 命名三、 生命周期3.1 …

Linux防火墙学习笔记7

安装apahce: yum install -y httpd echo 123 >> /var/www/html/index.html systemctl start httpd curl http://localhost 然后给iptables插入一条防火墙策略: iptables -t filter -I INPUT -p tcp --dport 80 -j ACCEPT注意:这里使…

【Spring学习之更简单的读取和存储Bean对象】教会你使用五大类注解和方法注解去存储 Bean 对象

前言: 💞💞今天我们依然是学习Spring,这里我们会更加了解Spring的知识,知道Spring是怎么更加简单的读取和存储Bean对象的。也会让大家对Spring更加了解。 💟💟前路漫漫,希望大家坚持…

高能预警!融云WICC发布《社交泛娱乐出海作战地图》

最近圈子里风很大的《社交泛娱乐出海作战地图》, 必须说,真的有亿点点厉害!这简直是一张集社交泛娱乐市场、品类知识和出海实战指南于一体的教材级地图,实感入手不亏。关注【融云全球互联网通信云】了解更多 首先,容我先秀一把实…

Definition of regularity in PDE theory

Regularity is one of the vague yet very useful terms to talk about a vast variety of results in a uniform way. Other examples of such words include “dynamics” in dynamical systems (I have never seen a real definition of this term but everyone uses it, an…

学习Vue 之 创建一个 Vue 应用

文章目录 Vue.js概述了解 Vue创建一个 Vue 应用参考 Vue.js 概述 计划学习前端,已有一些HTML,js,CSS的基础知识,下一步学习Vue.js。 以下是一些适合新手的Vue.js教程,你可以根据自己的实际情况和需求选择适合自己的…

独家揭秘:Kotlin编译器前端—解析阶段

独家揭秘:Kotlin编译器前端:解析阶段 Kotlin编译器对我来说就像一个黑盒子,虽然有关于Kotlin PSI在IDE插件中有使用的文档,但除了源代码中留下的注释之外,几乎没有其他信息可用。接下来的文章中我们来探索Kotlin编译器…

6. WebGPU 将图像导入纹理

我们在上一篇文章中介绍了有关使用纹理的一些基础知识。在本文中,我们将介绍从图像导入纹理。 在上一篇文章中,通过调用 device.createTexture 创建了一个纹理,然后通过调用 device.queue.writeTexture 将数据放入纹理中。 device.queue 上还…

Axure教程—穿梭框(中继器+动态面板)

本文将教大家如何用AXURE中动态面板和中继器制作穿梭框效果 一、效果 预览地址:https://8k99mh.axshare.com 下载地址:https://download.csdn.net/download/weixin_43516258/87897661?spm1001.2014.3001.5503 二、功能 在待选区域选项中可以选择一个选…

CURL获取与使用

背景:在日常工作中,经常会遇到需要获取CURL构造请求来进行问题定位,那如何获取及使用CURL则成为一个测试人员必备的技能; CURL是什么 CURL是一个命令行工具,开发人员使用它来与服务器进行数据交互。 如何获取完整 C…

Python开源自动化工具Playwright安装及介绍

目录 前言 1、Playwright介绍 2、Playwright安装 3、实操演示 4、小结 总结: 前言 微软开源了一个非常强大的自动化项目叫 playwright-python 它支持主流的浏览器,包含:Chrome、Firefox、Safari、Microsoft Edge 等,同时支…

简单使用Hystrix

使用Hystrix之前&#xff0c;需要先对SpringCloud有所了解&#xff0c;然后才会使用的顺畅&#xff0c;它是我们SpringCould的一种保护机制&#xff0c;非常好用。 下面直接开始 先导入Hystrix所需要的依赖 <!-- 引入openfiegn--> <dependency> <groupId>org…

Java学习笔记(视频:韩顺平老师)3.0

如果你喜欢这篇文章的话&#xff0c;请给作者点赞哟&#xff0c;你的支持是我不断前进的动力。 因为作者能力水平有限&#xff0c;欢迎各位大佬指导。 目录 如果你喜欢这篇文章的话&#xff0c;请给作者点赞哟&#xff0c;你的支持是我不断前进的动力。 算数运算符 号使用…

体验 TDengine 3.0 高性能的第一步,请学会控制建表策略

正如我们之前所言&#xff0c;在 3.0 当中&#xff0c;我们在产品底层做了很大的变化调整&#xff0c;除了架构更加科学高效以外&#xff0c;用户体验也是我们重点优化的方向。以之前一篇文章为例&#xff1a;对于 Update 功能&#xff0c;用户不再需要任何配置 &#xff0c;默…

社交泛娱乐出海如何抓住AIGC?我在融云WICC上看到了答案

大模型掀起的AIGC时代&#xff0c;所有企业的所有业务与产品都值得利用大模型技术重做一遍&#xff0c;接下来也将有越来越多依托AIGC技术的创新应用涌现。关注【融云全球互联网通信云】了解更多 在社交泛娱乐赛道&#xff0c;AI大模型技术也呈现出了加速落地的态势。日前&…