深度学习基础知识 使用torchsummary、netron、tensorboardX查看模参数结构

news2024/12/23 17:31:03

深度学习基础知识 使用torchsummary、netron、tensorboardX查看模参数结构

  • 1、直接打印网络参数结构
  • 2、采用torchsummary检测、查看模型参数结构
  • 3、采用netron检测、查看模型参数结构
  • 3、使用tensorboardX

1、直接打印网络参数结构

import torch.nn as nn
from torchsummary import summary
import torch


class Alexnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 10))

    def forward(self, X):
        return self.net(X)

if __name__=="__main__":
    
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model=Alexnet().to(device)
    print(model)
    # summary(model,(3,224,224),16)

结果输出:

Alexnet(
  (net): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (13): Flatten(start_dim=1, end_dim=-1)
    (14): Linear(in_features=6400, out_features=4096, bias=True)
    (15): ReLU()
    (16): Dropout(p=0.5, inplace=False)
    (17): Linear(in_features=4096, out_features=4096, bias=True)
    (18): ReLU()
    (19): Dropout(p=0.5, inplace=False)
    (20): Linear(in_features=4096, out_features=10, bias=True)
  )
)

上述方案存在的问题是:当网络参数设置存在错误时,无法检测出来

2、采用torchsummary检测、查看模型参数结构

安装torchsummary

pip install torchsummary

通常采用torchsummary打印网络结构参数时,会出现以下问题
代码:

import torch.nn as nn
from torchsummary import summary


class Alexnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 10))

    def forward(self, X):
        return self.net(X)


net = Alexnet()
print(summary(net, (3, 224, 224), 8))

报错内容如下:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

报错原因分析:

在使用torchsummary可视化模型时候报错,报这个错误是因为类型不匹配,根据报错内容可以看出Input type为torch.FloatTensor(CPU数据类型),而weight type(即网络权重参数这些)为torch.cuda.FloatTensor(GPU数据类型)

解决方案:

将model传到GPU上便可。将代码如下修改便可正常运行

if __name__ == "__main__":
    from torchsummary import summary
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)	# modify
    print(model)
    summary(model, input_size=(3, 224, 224))

整体代码:

import torch.nn as nn
from torchsummary import summary
import torch


class Alexnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 10))

    def forward(self, X):
        return self.net(X)

if __name__=="__main__":
    
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model=Alexnet().to(device)
    # print(model)
    summary(model,(3,224,224),16)  # 16:表示传入的数据批次

打印结果:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [16, 96, 54, 54]          34,944
              ReLU-2           [16, 96, 54, 54]               0
         MaxPool2d-3           [16, 96, 26, 26]               0
            Conv2d-4          [16, 256, 26, 26]         614,656
              ReLU-5          [16, 256, 26, 26]               0
         MaxPool2d-6          [16, 256, 12, 12]               0
            Conv2d-7          [16, 384, 12, 12]         885,120
              ReLU-8          [16, 384, 12, 12]               0
            Conv2d-9          [16, 384, 12, 12]       1,327,488
             ReLU-10          [16, 384, 12, 12]               0
           Conv2d-11          [16, 256, 12, 12]         884,992
             ReLU-12          [16, 256, 12, 12]               0
        MaxPool2d-13            [16, 256, 5, 5]               0
          Flatten-14                 [16, 6400]               0
           Linear-15                 [16, 4096]      26,218,496
             ReLU-16                 [16, 4096]               0
          Dropout-17                 [16, 4096]               0
           Linear-18                 [16, 4096]      16,781,312
             ReLU-19                 [16, 4096]               0
          Dropout-20                 [16, 4096]               0
           Linear-21                   [16, 10]          40,970
================================================================
Total params: 46,787,978
Trainable params: 46,787,978
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 9.19
Forward/backward pass size (MB): 163.58
Params size (MB): 178.48
Estimated Total Size (MB): 351.25
----------------------------------------------------------------

3、采用netron检测、查看模型参数结构

安装netron与onnx

pip install netron onnx

代码实现:

import torch.nn as nn
import netron
import torch
from onnx import shape_inference
import onnx


class Alexnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 10))

    def forward(self, X):
        return self.net(X)

if __name__=="__main__":
    
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model=Alexnet()
    temp_image=torch.rand((1,3,224,224))

    # 1、利用torch.onnx.export,先将模型导出为onnx格式的文件,保存到本地./model.onnx
    torch.onnx.export(model=model,args=temp_image,f='model.onnx',input_names=['image'],output_names=['feature_map'])
    
    # 2、加载进onxx模型,并推理,然后再保存覆盖原先模型
    onnx.save(onnx.shape_inference.infer_shapes(onnx.load("model.onnx")),"model.onnx")
    netron.start('model.onnx')

运行后,显示结构:
在这里插入图片描述
在这里插入图片描述

3、使用tensorboardX

在这里插入图片描述
代码实现:

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter as SummaryWriter


class Alexnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
                                 nn.MaxPool2d(kernel_size=3, stride=2),
                                 nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 4096), nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(4096, 10))

    def forward(self, X):
        return self.net(X)


net = Alexnet()
img = torch.rand((1, 3, 224, 224))
with SummaryWriter(log_dir='logs') as w:
    w.add_graph(net, img)

运行后,会在本地生成一个log日志文件
在命令行运行以下指令:

tensorboard --logdir ./logs --port 6006

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1072525.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux入门攻坚——2、基础命令学习

Linux就是命令集的操作系统,精通Linux,就要先精通各种命令。 date : date [OPTION] ... [FORMAT] : 显示日期时间(系统时钟) FORMAT:格式符号 %D:月/日/年…

优思学院|零库存:丰田精益管理的成功之道(CLMP)

在如今竞争激烈的商业世界中,企业需要不断寻求新的方法来提高效率、降低成本,并确保产品的高质量。其中一种成功的策略是实施零库存管理,而日本汽车制造巨头丰田公司就以其独特的零库存策略而闻名全球。优思学院在本文中将会深入探讨什么是零…

Android Studio: unrecognized Attribute name MODULE

错误完整代码: ������ (1.8.0_291) �г����쳣������&#xff…

C++对象模型(3)-- 类对象所占用的空间

类对象所占用的空间可以用sizeof()函数获取,在C对象模型中,类对象所占用的空间规则如下: (1) 空类占1字节,以使该类对象在内存得以配置一个地址。 (2) 对象所占用的空间由3个因素决定:非静态成员变量、虚函数、字节填…

图形学 -- Rasterization栅格化

参考视频:Lecture 05 Rasterization 1 (Triangles)_哔哩哔哩_bilibili 视锥: 定义一个垂直角度,定义宽高比 投到[-1,1]^3之后要呈现到屏幕上 屏幕 屏幕(一个二位数组) 屏幕是个典型的光栅成像设备 定义屏幕空间 映…

【每日一练】勾股定理困难版

目录 题目官方给的解题思路源代码附最大公因数辗转相除法更相减损术 所有因数参考文献 题目 给定斜边z的值&#xff0c;求所有直角边x和y的组合数&#xff08;x、y和z都是正整数&#xff09;。 仅有一行输入&#xff0c;即斜边z的值&#xff08;z是正整数&#xff0c;且z<1…

23种经典设计模式:单例模式篇(C++)

前言&#xff1a; 博主将从此篇单例模式开始逐一分享23种经典设计模式&#xff0c;并结合C为大家展示实际应用。内容将持续更新&#xff0c;希望大家持续关注与支持。 什么是单例模式&#xff1f; 单例模式是设计模式的一种&#xff08;属于创建型模式 (Creational Pa…

将本地代码提交到git新仓库

建仓 首先需要新建一个仓库&#xff0c;注意一定要是空仓库&#xff0c;不要选任何初始化 在代码所在目录右击&#xff0c;进入Git Bash Here 初始化git仓库 git init将文件添加进库 git add .进行提交&#xff0c;-m 后面引号中的内容是本次提交内容&#xff0c;自行填写…

STM32F103 最小系统 PCB 设计与原理

这篇文章是来自我学习&#xff1a; ​​​​​​带着你从手册开始画板 STM最小系统板教程系列(一)_哔哩哔哩_bilibili​​​​​​ 这套教程的笔记&#xff0c;同时本文中也参考了其他教程以及我遇到的困惑与自答&#xff0c;最终汇总。 一、单片机最小系统 单片机最小系统是由…

Centos7中安装Jenkins教程

1.必须先配置jdk环境&#xff0c;安装jdk参考 Linux配置jdk 2.先卸载Jenkins # rpm卸载 rpm -e jenkins # 检查是否卸载成功 rpm -ql jenkins # 彻底删除残留文件 find / -iname jenkins | xargs -n 1000 rm -rf 3.安装Jenkins 在 /usr/ 目录下创建 jenkins文件夹 mkdir -p je…

Fastadmin 子级菜单展开合并,分类父级归纳

这里踩过一个坑&#xff0c;fastadmin默认的展开合并预定义处理的变量是pid。 所以建表时父级id需要是pid&#xff1b; 当然不是pid也没关系&#xff0c;这里以cat_id为例&#xff0c;多加一步处理一样能实现。 废话少说上代码&#xff1a; 首先在控制器&#xff0c; 引用…

使用HbuilderX运行uniapp中小程序项目

下载HbuilderX&#xff0c;下载链接&#xff1a; HBuilderX-高效极客技巧 导入相关项目。下载微信开发者工具。使用微信开发者工具打开&#xff1a;注意&#xff1a;如果是第一次使用&#xff0c;需要先配置小程序ide的相关路径&#xff0c;才能运行成功。如下图&#xff0c;需…

国产开源无头CMS,MyCms v4.7 快捷生成接口开发后台

MyCms 是一款基于 Laravel 开发的开源免费的开源多语言商城 CMS 企业建站系统。 MyCms 基于 Apache2.0 开源协议发布&#xff0c;免费且可商业使用&#xff0c;欢迎持续关注我们。技术交流 QQ 群&#xff1a;887522124 加群请备注来源&#xff1a;如gitee、github、官网等 v4…

什么是智能档案柜?如何使用智能档案柜?

智能档案柜是一种具有智能化功能的文件存储设备&#xff0c;它通过应用现代科技&#xff0c;集成了电子锁、自动化控制、智能管理系统技术&#xff0c;具有自动识别、高效存储、安全可靠等特点&#xff0c;提高档案管理的效率和安全性。适用于企业单位、图书馆等需要储存文件资…

(自学)黑客技术方法——网络安全篇

如果你想自学网络安全&#xff0c;首先你必须了解什么是网络安全&#xff01;&#xff0c;什么是黑客&#xff01;&#xff01; 1.无论网络、Web、移动、桌面、云等哪个领域&#xff0c;都有攻与防两面性&#xff0c;例如 Web 安全技术&#xff0c;既有 Web 渗透2.也有 Web 防…

JRebel在IDEA中实现热部署 (JRebel实用版)

JRebel简介&#xff1a; JRebel是与应用程序服务器集成的JVM Java代理&#xff0c;可使用现有的类加载器重新加载类。只有更改的类会重新编译并立即重新加载到正在运行的应用程序中&#xff0c;JRebel特别不依赖任何IDE或开发工具&#xff08;除编译器外&#xff09;。但是&…

Pyside6 QRadioButton

Pyside6 QRadioBox QRadioButton使用QRadioButton分组QRadioButton设置文本代码设置界面设置 QRadioButton禁用和启用代码设置界面设置 QRadioButton设置默认值代码设置界面设置 读取QRadioButton状态QRadioButton样式设计代码设置界面设置 完整程序界面程序主程序 QRadioButto…

语音芯片基础知识 什么是语音芯 他有什么作用 发展趋势是什么

目录 一、语音芯片的简介 常见的语音芯片有哪些&#xff1f; 语音芯片的种类有很多&#xff0c;大体区分下来也就4个类别而已&#xff1a; 选型的经验说明如下&#xff1a; 推荐使用flash型语音芯片 一、语音芯片的简介 语音芯片基础知识&#xff1a; 什么是语音芯片&…

计算机竞赛 题目:基于深度学习的手势识别实现

文章目录 1 前言2 项目背景3 任务描述4 环境搭配5 项目实现5.1 准备数据5.2 构建网络5.3 开始训练5.4 模型评估 6 识别效果7 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于深度学习的手势识别实现 该项目较为新颖&#xff0c;适合作为竞赛课题…