pytorch 神经网络模型 2D+3D 可视化,这个工具库够猛!

news2024/12/23 22:32:48

生信碱移

torch模块可视化

小编近期冲浪的时候发现一个torch模型架构可视化的神级python库VisualTorch,给各位铁子分享一下doge。

VisualTorch旨在帮助可视化基于Torch的神经网络架构,似乎是今年才上传到github上。它目前支持为PyTorch的Sequential和Custom模型生成分层风格、图形风格和LeNet风格的架构。工具的灵感源自visualkeras、pytorchviz和pytorch-summary。

图片

▲ 可视化示例

0.安装

使用以下代码安装该库

pip install visualtorch

环境依赖如下,实测的时候发现python版本还需要大于3.10:

pillow>=10.0.0
numpy>=1.18.1
aggdraw>=1.3.11
torch>=2.0.0

1.Layered可视化

2D可视化

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)
img = visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

基础自定义模型的可视化

图片

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as func
import visualtorch
from torch import nn


# Example of a simple CNN model
class SimpleCNN(nn.Module):
    """Simple CNN Model."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Define the forward pass."""
        x = self.conv1(x)
        x = func.relu(x)
        x = func.max_pool2d(x, 2, 2)
        x = self.conv2(x)
        x = func.relu(x)
        x = func.max_pool2d(x, 2, 2)
        x = self.conv3(x)
        x = func.relu(x)
        x = func.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = func.relu(x)
        return self.fc2(x)


# Create an instance of the SimpleCNN
model = SimpleCNN()

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(model, input_shape=input_shape, legend=True)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

基本Sequential模型的可视化

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(model, input_shape=input_shape, legend=True)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块的颜色

图片

from collections import defaultdict

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

color_map: dict = defaultdict(dict)
color_map[nn.Conv2d]["fill"] = "LightSlateGray"  # Light Slate Gray
color_map[nn.ReLU]["fill"] = "#87CEFA"  # Light Sky Blue
color_map[nn.MaxPool2d]["fill"] = "LightSeaGreen"  # Light Sea Green
color_map[nn.Flatten]["fill"] = "#98FB98"  # Pale Green
color_map[nn.Linear]["fill"] = "LightSteelBlue"  # Light Steel Blue

input_shape = (1, 3, 224, 224)
img = visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块的不透明度

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(model, input_shape=input_shape, opacity=100)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块的方向

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(
    model,
    input_shape=input_shape,
    one_dim_orientation="x",
    spacing=40,
)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块的阴影

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)
img = visualtorch.layered_view(model, input_shape=input_shape, shade_step=50)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块间空间距离

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(model, input_shape=input_shape, spacing=50)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

忽略某些模块,即仅可视化某些层

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

ignored_layers = [nn.ReLU, nn.Flatten]

input_shape = (1, 3, 224, 224)
img = visualtorch.layered_view(
    model,
    input_shape=input_shape,
    type_ignore=ignored_layers,
)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

2.全连接层可视化

可视化基本的全连接层,当然像颜色、空间啥的也都可以调整:

图片

import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn


class SimpleDense(nn.Module):
    """Simple Dense Model."""

    def __init__(self) -> None:
        super().__init__()
        self.h0 = nn.Linear(4, 8)
        self.h1 = nn.Linear(8, 8)
        self.h2 = nn.Linear(8, 4)
        self.out = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Define the forward pass."""
        x = self.h0(x)
        x = self.h1(x)
        x = self.h2(x)
        return self.out(x)


model = SimpleDense()

input_shape = (1, 4)

img = visualtorch.graph_view(model, input_shape)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

LeNet风格示例

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.MaxPool2d(2, 2),
)

input_shape = (1, 3, 128, 128)

img = visualtorch.lenet_view(model, input_shape=input_shape)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

上面用到的几个API参数,这里就介绍了,可以自行查看文档:

  • https://visualtorch.readthedocs.io/en/latest/index.html

够猛,宝

赶紧收藏关注起来

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2164544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jQuery——jQuery的2把利器

1、jQuery 核心函数 ① 简称:jQuery 函数,即为 $ 或者 jQuery ② jQuery 库向外直接暴露的是 $ 或者 jQuery ③ 引入 jQuery 库后,直接使用 $ 即可 当函数用:$(xxx) 当对象用:$.xxx&#x…

华为官宣,不支持安卓应用的纯血鸿蒙终于来了

华为前不久与苹果新品发布会撞车的全球首款量产三折叠屏幕手机 Mate XT,本以为已是其下半年狠活儿担当。 但直到看完昨天下午的华为秋季全场景发布会才发现,好家伙,此前那都叫小打小闹,原来大招全搁在后头呢! 这场近两…

蒙语学习快速方法,速记蒙语单词怎么学习更高效!

要高效学习蒙古语和速记单词,首先要掌握基础知识,如字母表和发音规则。接着,专注于学习日常用语和基础词汇,并运用记忆技巧如联想、发音和构词法来帮助记忆。利用专门的学习软件,如“蒙语学习通”,可以提供…

进程间通信 (一)【管道通信(上)】

目录 1. 概况2. 管道通信的原理2.1 初步理解2.2 深入理解 1. 概况 是什么:两个及以上的进程实现数据层面的交互,称为进程间的通信。 因为进程独立性的存在,所以一个进程无法直接访问另一个进程的数据,即便是父子进程,子…

数字IC设计\FPGA 职位经典笔试面试整理--基础篇2

1. 卡诺图 逻辑函数表达式可以使用其最小项相加来表示,用所有的最小项可以转换为卡诺图进行逻辑项化简 卡诺图讲解资料1 卡诺图讲解资料2 卡诺图讲解资料3 最小项的定义 一个函数的某个乘积项包含了函数的全部变量,其中每个变量都以原变量或反变量的形…

从传统到智能:低代码平台在生产型企业中的应用实践

在全球数字化浪潮的推动下,生产型企业正面临前所未有的变革压力。为了在激烈的市场竞争中保持竞争力,企业迫切需要通过技术手段实现业务流程的优化和创新。然而,传统的软件开发方式往往耗时耗力,难以快速响应市场需求。低代码平台…

一些依赖库的交叉编译步骤

交叉编译链版本:12.3.0 一、curl-7.43.0库交叉编译 libcurl是一个跨平台的网络协议库,支持http, https, ftp, gopher, telnet, dict, file, 和ldap 协议。libcurl同样支持HTTPS证书授权,HTTP POST, HTTP PUT, FTP 上传, HTTP基本表单上传&a…

Django学习实战篇六(适合略有基础的新手小白学习)(从0开发项目)

前言: 上一章中,我们完成了页面样式的配置,让之前简陋的页面变得漂亮了些。 整理一下目前已经完成的系统,从界面上看,已经完成了以下页面: 首页分类列表页标签列表页口博文详情页 这离我们的需求还有些距离&#xff0…

哪款手机软件适合记事?记事本软件推荐

在这个信息爆炸的时代,手机已经成为我们生活中不可或缺的一部分。它不仅携带方便,而且功能强大,几乎可以完成我们日常所需的所有任务。随着生活节奏的加快,人们越来越需要一个可靠的工具来帮助自己记录重要信息和工作事项。这时候…

德勤校招网申笔试综合能力测试SHL题库与面试真题攻略

德勤的综合能力测试(General Ability)是其校园招聘在线测评的关键环节,旨在评估应聘者的多项认知能力。以下是对这部分内容的全面整合: 综合能力测试(General Ability) 测试时长为46分钟,包含…

ORA-12560:TNS:协议适配器错误

今天准备在数据库服务器创建一个用户,使用管理员账号进行登录 sqlplus / as sysdba时,突然报了个ORA-12560:TNS:协议适配器错误,吓的我一激灵,不应该啊,之前一直都是正常的,也是在网…

大漠yolo-数据集标注

参考 【按键精灵】大漠插件yolo环境配置_哔哩哔哩_bilibili 1. 2. 3.启动

MySQL高阶1873-计算特殊奖金

目录 题目 准备数据 分析数据 总结 题目 编写解决方案,计算每个雇员的奖金。如果一个雇员的 id 是 奇数 并且他的名字不是以 M 开头,那么他的奖金是他工资的 100% ,否则奖金为 0 。 返回的结果按照 employee_id 排序。 准备数据 Crea…

记录踩坑 uniapp 引入百度地图(微信小程序,H5,APP)

前言 因为公司要求一定要用百度地图,网上引入百度地图的方法说的就三种(插件,异步,webview组件),因为我用的是VUE3 第一种方法引入插件(插件名vue-baidu-map)一直报错vue2没试过反正vue3引进去就是报错第二种方法用异步引入 如果只开发app和h5可以用,微信小程序反正不显示,但…

android studio 批量修改包名 app package name

1、批量修改包名:project view模式 我们可以看到,只可以修改myapplication的部分包名,前面的com.demo这个修改了,可以进行如下设置来达到修改demo的目的。 2、设置下,通过不同的目录来达到批量修改的目的:…

2024最新甄选7款超好用的文档加密软件 | 好用的企业文档加密软件大盘点!赶快码住!

在数字化时代,文档如同古代的锦书密函,承载着企业的智慧与机密。 正如古诗所云:"锦书难托云中雁,密语常藏月下窗。" 2024年,我们不仅要传承古人的智慧,更要借助现代科技的力量,守护…

张朝阳的物理课第三卷:量子力学的硬核探索与启发

💂 个人网站:【 摸鱼游戏】【神级代码资源网站】【海拥导航】🤟 找工作,来万码优才:👉 #小程序://万码优才/HDQZJEQiCJb9cFi💅 想寻找共同学习交流,摸鱼划水的小伙伴,请点击【全栈技…

使用Prometheus进行系统监控,包括Mysql、Redis,并使用Grafana图形化表示

Prometheus是一个开源的的监控工具,而且还免费。这一次我们用Prometheus来对之前安装的所有服务,包括Mysql、Redis、系统状况等进行监控,并结合Grafana进行图形化展示 Prometheus下载和安装 下载地址(以下所有插件的官方下载地址…

二叉搜索树(来学包会) C++经验+1

目录 什么是二叉搜索树 解二叉搜索树 二叉搜索树的操作 二叉搜索树的插入(三步走) 二叉搜索树的搜索 二叉搜索树的删除 1.删除的节点是叶子节点 2.删除的节点只有一边的子树 3.删除的节点左子树和右子树都有 详细完整代码 什么是二叉搜索树 二…

MT76X8、MT7621、MT7981和QCA9531的GPIO列表

一、 MT76X8 GPIO列表; 二、 MT7621 GPIO列表; 三、MTK7981 GPIO列表; 四、QCA9531 GPIO列表;