应用torchinfo计算网络的参数量

news2024/9/24 11:30:41

1 问题

定义好一个VGG11网络模型后,我们需要验证一下我们的模型是否按需求准确无误的写出,这时可以用torchinfo库中的summary来打印一下模型各层的参数状况。这时发现表中有一个param以及在经过两个卷积后参数量(param)没变,出于想知道每层的param是怎么计算出来,于是对此进行探究。

2 方法

1、网络中的参数量(param)是什么?

param代表每一层需要训练的参数个数,在全连接层是突触权重的个数,在卷积层是卷积核的参数的个数。

2、网络中的参数量(param)的计算。

卷积层计算公式:Conv2d_param=(卷积核尺寸*输入图像通道+1)*卷积核数目

池化层:池化层不需要参数。

全连接计算公式:Fc_param=(输入数据维度+1)*神经元个数

3、解释一下图表中vgg网络的结构和组成。vgg11的网络结构即表中的第一列:

287e9ca0c371c5cb5ff7c806b19112bb.pngconv3-64→maxpool→conv3-128→maxpool→conv3-256→conv3-256→maxpool→conv3-512→conv3-512→maxpool→conv3-512→conv3-512→maxpool→FC-4096→FC-4096→FC-1000→softmax。

4、代码展示

import torch
from torch import nn
from torchinfo import summary
class MyNet(nn.Module):
   #定义哪些层
   def __init__(self) :
       super().__init__()
       #(1)conv3-64
       self.conv1 = nn.Conv2d(
           in_channels=1, #输入图像通道数
           out_channels=64,#卷积产生的通道数(卷积核个数)
           kernel_size=3,#卷积核尺寸
           stride=1,
           padding=1       #不改变特征图大小
       )  
       self.max_pool_1 = nn.MaxPool2d(2)
       #(2)conv3-128
       self.conv2 = nn.Conv2d(
           in_channels=64,
           out_channels=128,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_2 = nn.MaxPool2d(2)
       #(3)conv3-256
       self.conv3 = nn.Conv2d(
           in_channels=128,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv4 = nn.Conv2d(
           in_channels=256,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_3 = nn.MaxPool2d(2)
       #(4)conv3-512
       self.conv5 = nn.Conv2d(
           in_channels=256,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv6 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_4 = nn.MaxPool2d(2)
       #(5)conv3-512
       self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv8 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_5 = nn.MaxPool2d(2)
       self.fc1 = nn.Linear(in_features=7*7*512,out_features=4096)
       self.fc2 = nn.Linear(in_features=4096,out_features=4096)
       self.fc3 = nn.Linear(in_features=4096,out_features=1000)
   #计算流向
   def forward(self,x):
       x = self.conv1(x)
       x = self.max_pool_1(x)
       x = self.conv2(x)
       x = self.max_pool_2(x)
       x = self.conv3(x)
       x = self.conv4(x)
       x = self.max_pool_3(x)
       x = self.conv5(x)
       x = self.conv6(x)
       x = self.max_pool_4(x)
       x = self.conv7(x)
       x = self.conv8(x)
       x = self.max_pool_5(x)
       x = torch.flatten(x,1)  #[B,C,H,W]从C开始flatten,B不用flatten,所以要加1
       x = self.fc1(x)
       x = self.fc2(x)
       out = self.fc3(x)
       return out
if __name__ == '__main__':
   x = torch.rand(128,1,224,224)
   net = MyNet()
   out = net(x)
   #print(out.shape)
   summary(net, (12,1,224,224))

2f89632206c7e595774539a2cbd6e8a5.png输出结果:

图片中红色方块计算过程:

1:相关代码及计算过程(卷积层)

self.conv7 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )

Conv2d_param= (3*3*512+1)*512=2,359,808(Conv2d-12代码同,故param同)

2:相关代码及计算过程

self.fc3 = nn.Linear(in_features=4096,out_features=1000)

Fc_fc_param=(4096+1)*1000=4,097,000

3 结语

以上为一般情况下参数量计算方法,当然还有很多细节与很多其他情况下的计算方法没有介绍,主要用来形容模型的大小程度,针对不同batch_size下param的不同,可以用于参考来选择更合适的batch_size。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/131897.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从socket开始讲解网络模式(epoll)

从socket开始讲解网络模式 windows采用IOCP网络模型,而linux采用epoll网络模型(Linux得以实现高并发,并被作为服务器首选的重要原因),接下来讲下epoll模型对网络编程高并发的作用 简单的socket连接 socket连接交互的…

Python学习笔记-PyQt6之MVC项目结构初试

MVC结构是之model-view-controller三层架构的开发框架,用以将项目界面和逻辑进行解耦分析,便于维护。与WPF的MVVM相似。 项目开发做了一个秒表试手: 1.项目架构如下 controller:用于放置界面的控制逻辑model:用于放置…

回顾这十年,感悟

十年前,我35岁了,在体制内工作,到了很多人眼里的躺平的年龄。我眼里的世界,也就那么大,没有想过更进一步,有点中年油腻了,体质也差了。……终于有一天,醒悟了,不想过这样…

高并发系统设计 -- 秒杀系统

高并发秒杀 秒杀问题相信大家都知道的,虽然是一个烂大街的项目,但是秒杀问题背后的知识是很值得学习的,很多高并发系统设计都可以参照秒杀系统来进行实现。而且顺着这个问题,我会教给大家如何进行高并发的系统设计。 我们先来看…

Android集成三方浏览器之Crosswalk

上一篇讲解了腾讯 X5 内核的集成,这一篇是讲解 Crosswalk 的集成 Crosswalk 也是采用了Chromenium 内核,是一款开源的 web 引擎,开发者可以直接把 Crosswalk 嵌入到应用之中,当然也支持共享模式(系统中没有对应的 Cros…

费解的开关(BFS+哈希表+二进制枚举)

费解的开关(BFS哈希表二进制枚举)一、题目二、思路分析1、算法标签2、思路梳理方法1:BFS哈希表方法2:二进制枚举DFS一、题目 二、思路分析 1、算法标签 这道题考察的是BFS哈希表,DFS二进制枚举 2、思路梳理 方法1:…

Cohen–Sutherland 算法介绍(简单易懂)

目录 一、算法介绍 二、算法描述 三、算法总结 一、算法介绍 Cohen–Sutherland 算法用于直线段裁剪,通过判断直线与窗口之间的关系,来决定直线段部分的保留与舍弃。 二、算法描述 ① 首先,我们把屏幕分割成 9 个区域块,最中间区…

音乐相册如何制作?一步一步教会你

很多小伙伴会在旅行时,拍摄各种好看的照片,一趟旅途下来能留下好多照片呢,有些人会习惯将这些照片归类到一个相册里。其实我们也可以使用一些免费的软件将这些照片制作成有纪念意义的音乐相册,那大家知道免费制作音乐相册怎么做吗…

npm install 报警告npm WARN

npm install 报警告npm WARN optional SKIPPING OPTIONAL DEPENDENCY: fsevents1.2.0 (node_modules\fsevents npm notice created a lockfile as package-lock.json. You should commit this file. npm WARN fsevents1.2.0 had bundled packages that do not match the requi…

Crack:Inobitec DICOM Viewer Pro 2.9 多语言版本

Inobitec DICOM Viewer Pro 的使命是扩大医生可见和可能的范围。通过为医学提供高质量的创新 IT 解决方案,Ω578867473为改善全世界人民的健康做出了贡献。感受到自己工作的价值,意识到 21 世纪医学面临的挑战的重要性,以及解决这些挑战的乐趣…

WordPress使用二级域名存储图片等静态资源达到网站加速的详细配置

最近发现源站压力较大(水管太小)于是想着把WordPress博客的图片等静态资源分离到二级域名中,二级域名再使用一次云盾免费加速CDN,达到动静分离的效果,在这个过程中遇到一些坑,特此记录一下,方便…

NumpyPandas 数据处理与挖掘

笔记来源B站:https://www.bilibili.com/video/BV1xt411v7z9?p21 python学习笔记1 Numpy1.1 Numpy优势1.1.1 Numpy介绍1.1.2 ndarray介绍1.1.3 ndarray与Python原生list效率对比1.1.4 ndarray优势1.2 认识N维数组-ndarray属性1.2.1 ndarray的属性1.2.2 ndarray的形状…

11.1、基于Django4的可重用、用户注册和登录系统搭建

文章目录系统的功能思路分析搭建项目环境创建项目(虚拟环境)创建子应用修改语言、时区创建数据库表启动项目git提交项目代码到本地仓库git initi 初始化,创建本地git仓库pycharm安装 .ignore插件,来设置git的忽略文件提交代码修改…

SpringBoot+VUE前后端分离项目学习笔记 - 【09 SpringBoot集成MyBatis-Plus和SwaggerUI】

集成mybatis-plus依赖 官网 : https://baomidou.com/ pom.xml <!-- mybatis-plus --><dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5.1</version></depe…

01月份图形化一级打卡试题

活动时间 从2023年 1月1日至1月21日&#xff0c;每天一道编程题。 本次打卡的规则如下&#xff1a; &#xff08;1&#xff09;小朋友每天利用10~15分钟做一道编程题&#xff0c;遇到问题就来群内讨论&#xff0c;我来给大家答疑。 &#xff08;2&#xff09;小朋友做完题目后&…

认证的未来:2023 年值得关注的四大趋势

在经济不确定性和地缘政治紧张的一年中&#xff0c;数字领域充满网络威胁也就不足为奇了。从广泛的假冒诈骗到日益增多的短信网络钓鱼&#xff0c;网络攻击的频率和严重程度在 2022 年有所增加&#xff0c;这突显了所有行业的组织身份验证漏洞。 因此&#xff0c;当我们翻开新…

amis组件学习的配置介绍(二)

table view 表格视图 这个看文档也很好理解&#xff0c;但是还是需要介绍一下。 trs&#xff1a; <Array>设置表格行属性。tds: <Array>设置单元格属性。 {"type": "table-view",// 设置表格行"trs": [{"background": &…

常见排序算法(上)

篮球哥温馨提示&#xff1a;编程的同时不要忘记锻炼哦&#xff01;稳定的排序算法&#xff0c;可以设计成不稳定的. 目录 1、 认识排序 2、常见排序的分类 3、直接插入排序 4、希尔排序(缩小增量排序) 5、选择排序 6、堆排序 1、 认识排序 在学校中&#xff0c;如果我们…

QML学习笔记【03】:动画

动画是在指定的时间内&#xff0c;一系列属性的持续变化 1 动画元素&#xff08;Animation Elements&#xff09; 有几种类型的动画&#xff0c;每一种都在特定情况下都有最佳的效果&#xff0c;下面列出了一些常用的动画&#xff1a; PropertyAnimation&#xff08;属性动画…

人工智能学习07--pytorch01

一、pytorch简介 1、与TensorFlow区别 2、常用网络层 二、pytorch需要&#xff1a; 1、anaconda 2、CUDA 只能在NVIDIA上运行 ↓我发现电脑果然没有这个显卡 https://zhidao.baidu.com/question/2084255692200398828.html 3、pycharm 新项目要配置python的编译器&#xff…