深度学习根据代码可视化模型结构图的方法

news2024/9/21 15:38:55

方法1. Netron

Netron 是一个支持多种深度学习模型格式的可视化工具,可以将 PyTorch 模型转换为 ONNX 格式,然后使用 Netron 进行可视化。

安装 Netron

pip install netron

使用示例

import torch.onnx

# 定义模型
model = EMA(channels=64)
x = torch.randn(1, 64, 128, 128)

# 导出为 ONNX 格式
torch.onnx.export(model, x, "ema_model.onnx")

# 启动 Netron
import netron
netron.start('ema_model.onnx')

在浏览器中打开 http://localhost:8080/ 即可查看模型的结构图。

方法2. tensorboard

TensorBoard 是 TensorFlow 的可视化工具,也可以与 PyTorch 集成,适用于查看模型的计算图和训练过程。

安装 TensorBoard

pip install tensorboard

使用示例

import torch
from torch.utils.tensorboard import SummaryWriter

# 定义模型
class EMA(nn.Module):
    def __init__(self, channels, c2=None, factor=32):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)

# 创建一个SummaryWriter对象
writer = SummaryWriter('runs/ema_example')

# 创建模型和输入数据
model = EMA(channels=64)
x = torch.randn(1, 64, 128, 128)

# 添加计算图到TensorBoard
writer.add_graph(model, x)
writer.close()

然后使用以下命令启动 TensorBoard 并查看模型的计算图:

tensorboard --logdir=runs/ema_example

在浏览器中打开 http://localhost:6006/ 即可查看模型的结构图。

方法3. Torchviz

Torchviz 是一个可以将 PyTorch 模型的计算图转换为可视化结构图的工具。

需要先安装 graphviz

sudo apt-get update
sudo apt-get install graphviz

安装 Torchviz

pip install torchviz

使用示例

import torch
from torchviz import make_dot

# 定义模型
class EMA(nn.Module):
    def __init__(self, channels, c2=None, factor=32):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)

# 创建模型和输入数据
model = EMA(channels=64)
x = torch.randn(1, 64, 128, 128)

# 生成计算图
y = model(x)
dot = make_dot(y, params=dict(model.named_parameters()))

# 保存计算图为文件
dot.format = 'png'
dot.render('ema_model')

执行上面这个代码之后,然后会在当前目录下生成一个 png 文件和另一个文件

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1934744.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

上海理工大学24计算机考研考情分析!初复试分值比55:45,复试逆袭人数不算多!

上海理工大学(University of Shanghai for Science and Technology),位于上海市,是一所以工学为主,工学、理学、经济学、管理学、文学、法学、艺术学等多学科协调发展的应用研究型大学;是上海市属重点建设大…

秒懂C++之类与对象(中)

目录 一.流插入,流提取运算符 二.const成员函数 三.取地址重载 四.构造函数(列表初始化) 小测试: 五.全部代码 前排提醒:本文所参考代码仍是取用上篇文章关于日期类相关功能的实现~末尾有全部代码~ 一.流插入&a…

使用Redis的SETNX命令实现分布式锁

什么是分布式锁 分布式锁是一种用于在分布式系统中控制多个节点对共享资源进行访问的机制。在分布式系统中,由于多个节点可能同时访问和修改同一个资源,因此需要一种方法来确保在任意时刻只有一个节点能够对资源进行操作,以避免数据不一致或…

WPF/C#:实现导航功能

前言 在WPF中使用导航功能可以使用Frame控件,这是比较基础的一种方法。前几天分享了wpfui中NavigationView的基本用法,但是如果真正在项目中使用起来,基础的用法是无法满足的。今天通过wpfui中的mvvm例子来说明在wpfui中如何通过依赖注入与M…

STM32智能安防系统教程

目录 引言环境准备智能安防系统基础代码实现:实现智能安防系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景:家庭与企业安防管理问题解决方案与优化收尾与总结 1. 引言 智能安防系统通过STM32…

解决npm install(‘proxy‘ config is set properly. See: ‘npm help config‘)失败问题

摘要 重装电脑系统后,使用npm install初始化项目依赖失败了,错误提示:‘proxy’ config is set properly…,具体的错误提示如下图所示: 解决方案 经过报错信息查询解决办法,最终找到了两个比较好的方案&a…

【Charles】-雷电模拟器-抓HTTPS包

写在前面 之前的文章我们写过如何通过Charles来抓取IOS手机上的HTTPS包以及遇到的坑。说一个场景,如果你的手机是IOS,但是团队提供的APP安装包是Android,这种情况下你还想抓包,怎么办? 不要慌,我们可以安装…

宝塔面板以www用户运行composer

方式一 执行命令时指定www用户 sudo -u www composer update方式二 在网站配置中的composer选项卡中选择配置运行

c# .net core中间件,生命周期

某些模块和处理程序具有存储在 Web.config 中的配置选项。但是在 ASP.NET Core 中,使用新配置模型取代了 Web.config。 HTTP 模块和处理程序如何工作 官网地址: 将 HTTP 处理程序和模块迁移到 ASP.NET Core 中间件 | Microsoft Learn 处理程序是&#xf…

大数据之路 读书笔记 Day7 实时技术 简介及流式技术架构

回顾: Day6 离线数据开发之数据开发平台Day5 数据同步遇到的问题与解决方案 1. 简介 阿里巴巴在流式数据处理方面采用了多种技术和框架,这些技术的特点包括: 高可伸缩性: 阿里巴巴使用Apache Flink进行大规模数据处理&#xff0c…

【已解决】Django连接MySQL启动报错Did you install mysqlclient?

在终端执行python manage.py makemigrations报错问题汇总 错误1:已安装mysqlclient,提示Did you install mysqlclient? 当你看到这样的错误信息,表明Django尝试加载MySQLdb模块但未找到,因为MySQLdb已被mysqlclient替代。 【解…

如何发一篇顶会论文? 涉及3D高斯,slam,自动驾驶,三维点云等等

SLAM&3DGS 1)SLAM/3DGS/三维点云/医疗图像/扩散模型/结构光/Transformer/CNN/Mamba/位姿估计 顶会论文指导 2)基于环境信息的定位,重建与场景理解 3)轻量级高保真Gaussian Splatting 4)基于大模型与GS的 6D pose e…

全时守护,无死角监测:重点海域渔港视频AI智能监管方案

一、方案背景 随着海洋经济的快速发展和海洋资源的日益紧缺,对重点海域渔港进行有效监控和管理显得尤为重要。视频监控作为一种高效、实时的管理手段,已成为渔港管理中不可或缺的一部分。当前,我国海域面积广阔,渔港众多&#xf…

Axure RP移动端医院在线挂号app问诊原型图模板

医疗在线挂号问诊Axure RP原型图医院APP原形模板,是一款原创的医疗类APP,设计尺寸采用iPhone13(375*812px),原型图上加入了仿真手机壳,使得预览效果更加逼真。 本套原型图主要功能有医疗常识科普、医院挂号…

分布式存储之 ceph 管理操作

一.资源池 Pool 管理 我们已经完成了 Ceph 集群的部署,但是我们如何向 Ceph 中存储数据呢?首先我们需要在 Ceph 中定义一个 Pool 资源池。Pool 是 Ceph 中存储 Object 对象抽象概念。我们可以将其理解为 Ceph 存储上划分的逻辑分区,Pool 由…

Springboot项目远程部署gitee仓库(docker+Jenkins+maven+git)

创建一个Springboot项目,勾选web将该项目创建git本地仓库,再创建远程仓库推送上去 创建TestController RestController RequestMapping("/test") public class TestController { GetMapping("/hello") public String sayHelloJe…

Springboot 启动时Bean的创建与注入-面试热点-springboot源码解读-xunznux

Springboot 启动时Bean的创建与注入,以及对应的源码解读 文章目录 Springboot 启动时Bean的创建与注入,以及对应的源码解读构建Web项目流程图:堆栈信息:堆栈信息简介堆栈信息源码详解1、main:10, DemoApplication (com.xun.demo)2…

OPC UA边缘计算耦合器BL205工业通信的最佳解决方案

OPC UA耦合器BL205是钡铼技术基于下一代工业互联网技术推出的分布式、可插拔、结构紧凑、可编程的IO系统,可直接接入SCADA、MES、MOM、ERP等IT系统,无缝链接OT与IT层,是工业互联网、工业4.0、智能制造、数字化转型解决方案中IO系统最佳方案。…

小阿轩yx-高性能内存对象缓存

小阿轩yx-高性能内存对象缓存 案例分析 案例概述 Memcached 是一款开源的高性能分布式内存对象缓存系统用于很多网站提高访问速度,尤其是需要频繁访问数据的大型网站是典型的 C/S 架构,需要构建 Memcached 服务器端与 Memcached API 客户端用 C 语言…

VisualRules-Web案例展示(一)

VisualRules单机版以其卓越的功能深受用户喜爱。现在,我们进一步推出了VisualRules-Web在线版本,让您无需安装任何软件,即可在任何浏览器中轻松体验VisualRules的强大功能。无论是数据分析、规则管理还是自动化决策,VisualRules-W…