pytroch实战12:基于pytorch的网络结构可视化

news2025/1/16 19:08:41

基于pytorch的网络结构可视化

前言

​ 之前实现了一些常见的网络架构,但是有些网络架构并没有细说,并且网络传输过程中shape的变化也很少谈及过。

​ 除此之外,前面的实现很少涉及到可视化的内容,比如损失值的可视化、网络结构的可视化。

​ 所以本期博客就是补充一下这几点。

目录结构

文章目录

    • 基于pytorch的网络结构可视化
      • 1. 安装:
      • 2. summary使用方法:
      • 3. tensorboardX使用方法:
      • 4. 总结

1. 安装:

安装tensorboardX

​ 安装可视化库:

pip install tensorboardX

​ 当然建议指定镜像源:

pip install tensorboardX -i https://pypi.tuna.tsinghua.edu.cn/simple some-package

安装torchkeras

​ 如果你想要使用tensorflow库的summary,可以安装这个库:

pip install torchkeras -i https://pypi.tuna.tsinghua.edu.cn/simple some-package

2. summary使用方法:

​ 还记得LeNet5这个网络架构吗,我下面演示的代码都是基于这个代码来的,不清楚的可以看这篇文章。

​ 首先导入库:

from torchkeras import summary

​ 当然,其它的代码,比如创建模型、定义优化器等我们不说。只说如何使用这个方法。

​ summary使用很简单,直接按照下面的格式使用即可:

# 打印summary值
print(summary(model,input_shape=(1,28,28)))

注意:

  • 第一个参数值为模型对象
  • 第二个参数值input_shape为输入的shape大小,比如这里用的MNIST数据集,所以shape=[1,28,28]
  • 需要保证model和input_shape在同一设备中,比如上面model不能放入GPU中,或者要么将input_shape改为GPU中的变量

​ 打印的结果值:

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
==========================================================================
Conv2d-1                             [-1, 6, 28, 28]                  156
Sigmoid-2                            [-1, 6, 28, 28]                    0
MaxPool2d-3                          [-1, 6, 14, 14]                    0
Conv2d-4                            [-1, 16, 10, 10]                2,416
Sigmoid-5                           [-1, 16, 10, 10]                    0
MaxPool2d-6                           [-1, 16, 5, 5]                    0
Linear-7                                   [-1, 120]               48,120
Sigmoid-8                                  [-1, 120]                    0
Linear-9                                    [-1, 84]               10,164
Sigmoid-10                                  [-1, 84]                    0
Linear-11                                   [-1, 10]                  850
==========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.002991
Forward/backward pass size (MB): 0.111404
Params size (MB): 0.235390
Estimated Total Size (MB): 0.349785
--------------------------------------------------------------------------
--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
==========================================================================
Conv2d-1                             [-1, 6, 28, 28]                  156
Sigmoid-2                            [-1, 6, 28, 28]                    0
MaxPool2d-3                          [-1, 6, 14, 14]                    0
Conv2d-4                            [-1, 16, 10, 10]                2,416
Sigmoid-5                           [-1, 16, 10, 10]                    0
MaxPool2d-6                           [-1, 16, 5, 5]                    0
Linear-7                                   [-1, 120]               48,120
Sigmoid-8                                  [-1, 120]                    0
Linear-9                                    [-1, 84]               10,164
Sigmoid-10                                  [-1, 84]                    0
Linear-11                                   [-1, 10]                  850
==========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.002991
Forward/backward pass size (MB): 0.111404
Params size (MB): 0.235390
Estimated Total Size (MB): 0.349785
--------------------------------------------------------------------------

3. tensorboardX使用方法:

​ 相比于summary,我个人觉得tensorboardX更好用,因为它可以将结果可视化,并且操作也十分简单。

代码中使用流程与举例

​ 首先,标准的使用流程为:

# 1. 导入库
from tensorboardX import SummaryWriter

# 2. 创建对象,
writer = SummaryWriter() 
	# 路径一般默认,不过也可以指定
	# 指定路径: writer = SummaryWriter('.\temp')
	
# 3. 可视化1: 可视化某个变量的值,一般为损失值
writer.add_scalar('变量名字',存储的值,序号)
	# 其中序号指的是损失值的序号,比如1、2、3这样的,目的是区分不同值
    # 日志记录: 序号 --- 变量名字 --- 存储值

# 4. 可视化2:可视化模型结构
writer.add_graph(model,input_to_model=batch_data)
	# 第一个参数为模型
	# 第二个参数为输入的shape,一般直接用batch_data即可
	# 注意两者必须在同一设备中,和summary类似

# 5. 关闭可视化
writer.close()

​ 那么,以LeNet5举个例子:

# 导入可视化
from tensorboardX import SummaryWriter

# 创建模型
class LeNet(nn.Module):
	...
	
# 下载数据集或者加载数据集
...
# 加载数据: 分批次,每批256个数据
...
# 创建模型
model = LeNet()
# 模型放入GPU中
...
# 定义损失函数、优化器
...

# *****初始化可视化对象*****
writer = SummaryWriter() # 路径一般默认,不过也可以指定

# 开始训练
x = 0  # 用于指定序号
for i in range(10):
	...
    for j,(batch_data,batch_label) in enumerate(train_loader):
        ...
        if (j + 1) % 200 == 0:
			...
            # 可视化1:一般添加loss值
            writer.add_scalar('200_step_loss',loss_temp / 200,x)
            x += 1
            ...

# 可视化2:模型结构
writer.add_graph(model,input_to_model=batch_data)
# 关闭可视化
writer.close()

​ 对于上面需要说明的一点:(图画的有点抽象,见谅)

在这里插入图片描述

可视化

​ 当上述代码云心完毕后,如果你没有更改默认路径,那么在所属文件夹会出现一个名为runs的文件夹:

在这里插入图片描述

​ 那么,你打开Windows的cmd,进入当前目录,运行下面的代码:

tensorboard --logdir=runs

在这里插入图片描述

​ 然后,将给出的网址复制到浏览器打开(此时不要关闭cmd窗口):

在这里插入图片描述

在这里插入图片描述

4. 总结

​ 网络结构的可视化操作还是比较简单的,而且效果也非常不错。如果你要做ppt或者其它的,建议可以试一试。

LeNet5案例完整代码(需要根据需求修改注释)

# author: baiCai
# 导包
import time
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
# 导入可视化
from tensorboardX import SummaryWriter
from torchkeras import summary

# 创建模型
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        # 定义模型
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=1,out_channels=6,kernel_size=(5,5),stride=1,padding=2),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.Conv2d(in_channels=6,out_channels=16,kernel_size=(5,5),stride=1),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2,stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(in_features=400, out_features=120),
            nn.Sigmoid(),
            nn.Linear(in_features=120, out_features=84),
            nn.Sigmoid(),
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self,x):
        # 定义前向算法
        x = self.features(x)
        # print(x.shape)
        x = torch.flatten(x,1)
        # print(x.shape)
        result = self.classifier(x)
        return result

# 下载数据集或者加载数据集
train_dataset = MNIST(root='../data',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = MNIST(root='../data',train=False,transform=transforms.ToTensor())
# 加载数据: 分批次,每批256个数据
batch_size = 32
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)
# start time
start_time = time.time()
# 创建模型
model = LeNet()
# 模型放入GPU中
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 定义损失函数
loss_func = nn.CrossEntropyLoss()
loss_list = [] # 用来存储损失值
# 定义优化器
SGD = optim.Adam(params=model.parameters(),lr=0.001)
# 初始化可视化对象
writer = SummaryWriter() # 路径一般默认,不过也可以指定
# 打印summary值
# print(summary(model,input_shape=(1,28,28)))
# 训练指定次数
x = 0
for i in range(10):
    loss_temp = 0 # 定义一个损失值,用来打印查看
    # 其中j是迭代次数,data和label都是批量的,每批32个
    for j,(batch_data,batch_label) in enumerate(train_loader):
        # 启用GPU
        batch_data,batch_label = batch_data.cuda(),batch_label.cuda()
        # 清空梯度
        SGD.zero_grad()
        # 模型训练
        prediction = model(batch_data)
        # 计算损失
        loss = loss_func(prediction,batch_label)
        loss_temp += loss
        # BP算法
        loss.backward()
        # 更新梯度
        SGD.step()
        if (j + 1) % 200 == 0:
            print('第%d次训练,第%d批次,损失值: %.3f' % (i + 1, j + 1, loss_temp / 200))
            # 可视化1:一般添加loss值
            writer.add_scalar('200_step_loss',loss_temp / 200,x)
            x += 1
            loss_temp = 0
# end_time
end_time = time.time()
print('训练花了: %d s' % int((end_time-start_time)))
# 可视化2:模型结构
writer.add_graph(model,input_to_model=batch_data)
# 关闭可视化
writer.close()

# 使用GPU: 训练花了: 124 s
# 不适用GPU:训练花了: 160 s
# 测试
# correct = 0
# for batch_data,batch_label in test_loader:
#     batch_data, batch_label = batch_data.cuda(), batch_label.cuda()
#     prediction = model(batch_data)
#     predicted = torch.max(prediction.data, 1)[1]
#     correct += (predicted == batch_label).sum()
# print('准确率: %.2f %%' % (100 * correct / 10000)) # 因为总共10000个测试数据

#  准确率: 11.35 %

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/575294.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL的下载、安装、配置(图文详解)

MySQL的下载、安装、配置(图文详解) 一、MySQL的4大版本二、软件的下载三、MySQL8.0 版本的安装四、配置MySQL8.0五、配置MySQL8.0 环境变量六、MySQL5.7 版本的安装、配置七、安装失败问题 一、MySQL的4大版本 MySQL Community Server 社区版本&#xf…

专高六第一次项目答辩学到的知识点【未完成】

目录标题 1、animation和traslation定义动画的区别?2、微信小程序的支付流程?3、canvas和svg有什么区别?4、app自定义导航栏,如果说打包成小程序导航栏如何适配?4、express权限,接口权限?5、一键…

如何在Linux系统安装Nginx

博主介绍:✌全网粉丝4W,全栈开发工程师,从事多年软件开发,在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战、定制、远程,博主也曾写过优秀论文,查重率极低,在这方面…

设计思维及在Thoughtworks的应用

图:史江鸿 第一次听到"设计思维"是在2016年,那时我刚加入Thoughtworks。我总能在各种场合听到这个词,似乎它在Thoughtworks具有不可撼动的地位。然而,作为QA角色,我并没有机会深入了解它。 我曾感到疑惑&…

2-python的变量类型

内容提要 主要介绍了python中的变量类型,之前不经常用的点有: 列表的下标可以是负数,无论正负,都是从左侧开始,从左到右依次递增。 还有截取操作[头:尾:步长),表示连接,*表示重复。 列表与元组…

springboot+vue地方废物回收机构管理(java项目源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的地方废物回收机构管理。项目源码以及部署相关请联系风歌,文末附上联系信息 。 💕💕作者&#xff1…

leetcode--删除链表的倒数第N个节点(java)

删除链表的倒数第N个节点 Leetcode 19 题解题思路代码演示链表专题 Leetcode 19 题 19 删除链表的倒数第N个节点 -可以测试 题目描述: 给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点 示例1: 输入:he…

javascript基础五:深拷贝浅拷贝的区别?如何实现一个深拷贝?

一、数据类型存储 JavaScript中存在两大数据类型: 基本类型引用类型 基本类型数据保存在在栈内存中 引用类型数据保存在堆内存中,引用数据类型的变量是一个指向堆内存中实际对象的引用,存在栈中 二、浅拷贝 浅拷贝,指的是创建新…

springcloud分布式架构网上商城(java项目源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的分布式架构网上商城。项目源码以及部署相关请联系风歌,文末附上联系信息 。 💕💕作者:…

LLM时代NLP研究何去何从?一个博士生的角度出发

深度学习自然语言处理 原创作者:Winni 前言 最近,大语言模型(LLMs)在许多任务上表现出接近人类水平的性能,这引发了行业兴趣和资金投入的激增,有关LLMs的论文最近也层出不穷。 看起来,NLP领域似…

博客系统(ssm版本)

在前面的文章中给大家介绍过博客系统的servlet版本,但是servlet的技术非常的老旧,我们在企业中用的都是springboot相关的框架,本章内容就是讲述如何一步一步的利用ssm的技术来实现博客系统。 目录 前期配置 创建数据库 配置文件 公共文件…

30 VueComponent 事件的绑定

前言 这是最近的碰到的那个 和响应式相关的问题 特定的操作之后响应式对象不“响应“了 引起的一系列的文章 主要记录的是 vue 的相关实现机制 呵呵 理解本文需要 vue 的使用基础, js 的使用基础 测试用例 用例如下, 我们这里核心关注 事件的处理流程 问题的调试 整个…

c# cad二次开发 通过选择txt文件将自动转换成多段线

c# cad二次开发 通过选择txt文件将自动转换成多段线,txt样式如下 using System; using System.Collections.Generic; using System.Text; using Autodesk.AutoCAD.ApplicationServices; using Autodesk.AutoCAD.EditorInput; using Autodesk.AutoCAD.Runtime; usi…

chatgpt赋能python:Python改变图片大小对SEO的影响

Python改变图片大小对SEO的影响 简介 Python作为一门高效的编程语言,广泛应用于各个行业,并在图像处理领域中也有很多应用。其中一个常见的应用就是改变图片的大小。在SEO(搜索引擎优化)中,图片大小的优化对网站的排…

chatgpt赋能python:Python批量输出:提高工作效率的必备技能

Python批量输出:提高工作效率的必备技能 在日常工作中,我们往往需要批量处理某些数据。Python作为一种流行的编程语言,可以帮助我们快速地完成这项任务。本文将介绍Python批量输出的基本知识和实用技巧,帮助读者提高工作效率。 …

chatgpt赋能python:Python改变当前目录的SEO指南

Python改变当前目录的SEO指南 介绍 对于SEO来说,网站的目录结构和文件命名是非常重要的。良好的目录结构可以帮助搜索引擎更好地理解您的网站内容,而有意义的文件命名可以提高页面的可读性并有助于排名。 但在开发过程中,我们经常需要在不…

铁粉数量上一百了

铁粉数量上一百了 常写博客,常进步。

【Python】类与对象

知识目录 一、写在前面✨二、类与对象简介三、Car类的实现四、Date类的实现五、总结撒花😊 一、写在前面✨ 大家好!我是初心,希望我们一路走来能坚守初心! 今天跟大家分享的文章是 Python中面向对象编程的类与对象。 &#xff0…

一道北大强基题背后的故事(一)——从走弯路到看答案

早点关注我,精彩不错过! 在前面的系列文章《我的数学学习回忆录——一个数学爱好者的反思(二)》中,我从宏观层面回忆了我的数学学习历程和反思。其实,我和数学之间还有很多很多意识流一样的交流和故事&…

训练DeeplabV3+来分割车道线

本例我们训练DeepLabV3语义分割模型来分割车道线。 DeepLabV3模型的原理有以下一些要点: 1,采用Encoder-Decoder架构。 2,Encoder使用类似Xception的结构作为backbone。 3,Encoder还使用ASPP(Atrous Spatial Pyramid Pooling)&…