卷积神经网络参数量和计算量的计算方法

news2024/11/25 16:48:37

本文总结了标准卷积、分组卷积和全连接层参数量和计算量的计算方法,如有错误,麻烦大家指正

一、标准卷积

假设输入特征的shape为[C_{in}, H_{in}, W_{in}],卷积核的shape为[C_{in}, C_{out}, k, k],输出特征的shape为[C_{out}, H_{out}, W_{out}],则,

标准卷积运算的参数量为:

  • 不考虑bias:params=k\times k\times C_{in}\times C_{out}
  • 考虑bias:params=(k\times k\times C_{in}\times +1)\times C_{out}

标准卷积运算的计算量为:

  • 不考虑bias:FLOPs=[(k\times k\times C_{in})+(k\times k\times C_{in}-1)]\times H_{out}\times W_{out}\times C_{out}
  • 考虑bias:FLOPs=[(k\times k\times C_{in})+(k\times k\times C_{in}-1)+1]\times H_{out}\times W_{out}\times C_{out}

二、深度可分离卷积

假设输入特征的shape为[C_{in}, H_{in}, W_{in}],卷积核的shape为[gC_{in}, C_{out}, k, k](g为分组数量),输出特征的shape为[C_{out}, H_{out}, W_{out}],则,

分组卷积运算的参数量为:

  • 不考虑bias:params=k\times k\times \frac{C_{in}}{g}\times \frac{C_{out}}{g}\times g
  • 考虑bias:params=(k\times k\times \frac{C_{in}}{g}+1)\times \frac{C_{out}}{g}\times g

分组卷积运算的计算量为:

  • 不考虑bias:FLOPs=[(k\times k\times \frac{C_{in}}{g})+(k\times k\times \frac{C_{in}}{g}-1)]\times H_{out}\times W_{out}\times \frac{C_{out}}{g}\times g
  • 考虑bias:FLOPs=[(k\times k\times \frac{C_{in}}{g})+(k\times k\times \frac{C_{in}}{g}-1)+1]\times H_{out}\times W_{out}\times \frac{C_{out}}{g}\times g

举个栗子,假设输入特征的shape为[1,3,6,6],卷积的shape为[3,6,3,3],分组卷积的g为3,输出特征的shape为[1,6,2,2],则,标准卷积的计算量=[(3*3*3)+(3*3*3-1)+1]*2*2*6=1296,分组卷积的计算量=[(3*3*3)+(3*3*3-1)+1]*2*2*6/3*3=1296/3=432。下面使用thop验证一下,

import torch
import torch.nn as nn
from torchstat import stat
from thop import profile


x = torch.randn((1,3,6,6))
conv1 = nn.Conv2d(3,6,3,3)
conv2 = nn.Conv2d(3,6,3,3,groups=3)
flops1, params1 = profile(conv1, inputs=(x))
print(flops1)
print("自测:", (3*3*3+3*3*3-1+1)*6*2*2)
flops2, params2 = profile(conv2, inputs=(x))
print(flops2)
print("自测:", ((3*3*1+3*3*1-1+1)*6*2*2))


"""
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
1296.0
自测: 1296
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
432.0
自测: 432
"""

深入思考一下,经过标准卷积层后,输出特征图上的每一个点都由k\times k\times C_{in}计算得来。而如果经过的是分组卷积层,那么输出特征图上的每一个点则由k\times k\times \frac{C_{in}}{g}计算得来。因此,分组卷积的参数量为标准卷积参数量的\frac{1}{g}

然而,由于深度卷积的FLOPs与内存访问的比率太低,难以有效利用硬件,所以只是理论上比标准卷积的计算要快

三、全连接层

假设全连接层的shape为[C_{in}, C_{out}],则,

全连接层的参数量为:

  • 不考虑bias:params=C_{in}\times C_{out}
  • 考虑bias:params=(C_{in}+1)\times C_{out}

全连接层的计算量为:

  • 不考虑bias:FLOPs=[C_{in}+(C_{in}-1)]\times C_{out}
  • 考虑bias:FLOPs=[C_{in}+(C_{in}-1)+1]\times C_{out}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/760490.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++特殊类设计及类型转换

目录 一、特殊类的设计 1.不能被拷贝的类 2.只能在堆区构建对象的类 3.只能在栈区构建对象的类 4.不能被继承的类 二、单例模式 1.饿汉模式 2.懒汉模式 3.线程安全 4.单例的释放 三、C类型转换 1.C语言的类型转换 2.static_cast 3.reinterpret_cast 4.const_cas…

Python补充笔记1-字符串

目录 1.字符串的驻留机制​编辑 2.字符串查找 2.1字符串查询操作方法 3.字符串大小写转换 3.1字符串的大小写转换方法 4.字符串内容对齐 4.1字符串内容对齐操作方法 5.字符串的劈分 5.1字符串劈分操作的方法​编辑 6.字符串判断 6.1判断字符串操作的方法​编辑 6.2字符串替换和…

虚拟化技术及实时虚拟化概述

版权声明&#xff1a;本文为本文为博主原创文章&#xff0c;未经本人同意&#xff0c;禁止转载。如有问题&#xff0c;欢迎指正。博客地址&#xff1a;https://www.cnblogs.com/wsg1100/ 实时虚拟化技术是一种针对实时应用场景的虚拟化技术&#xff0c;它要求在保证虚拟化优势…

STM32 ws2812b 最快点灯cubemx

文章目录 前言一、cubemx配置二、代码1.ws2812b.c/ws2812b.h2.主函数 前言 吐槽 想用stm32控制一下ws2812b的灯珠&#xff0c;结果发下没有一个好用的。 emmm&#xff01;&#xff01;&#xff01; 自己来吧&#xff01;&#xff01;&#xff01;&#xff01; 本篇基本不讲原理…

6、传输层TCP28

TCP协议&#xff1a;传输控制协议 1、协议实现 16位源端端口&16位对端端口&#xff1a;描述通信俩端进程32位序号&#xff1a;告诉接收端&#xff0c;这条数据在整体数据中的排序&#xff0c;接收端根据序号进行排序32位确认序号&#xff1a;向发送端进行回复确定&#xff…

pytest-html报告修改与汉化

目录 前言 生成报告 测试代码 原始报告 修改Environment 修改后的效果 修改Summary 修改后的效果 修改Results 优化Test 解决中文乱码 删除多余部分 修改后的效果 删除Links 修改后的效果 增加失败截图与用例描述 完整的conftest.py代码 汉化报告 修改plugin…

ClickHouse进阶

一、Explain查看执行计划 在 clickhouse 20.6 版本之前要查看 SQL 语句的执行计划需要设置日志级别为 trace 才能可以看到&#xff0c;并且只能真正执行 sql&#xff0c;在执行日志里面查看。 在 20.6 版本引入了原生的执行计划的语法。在 20.6.3 版本成为正式版本的功能。 …

常见的JS内置对象——字符串、数学、日期

二、字符串&#xff08;string&#xff09; 创建 一般使用第一种方式 2&#xff09;字符串的遍历 注意&#xff1a;没有foreach方法 3&#xff09;字符串的常见方法 substr()和substring()&#xff1a; substr()参数是从哪个位置开始&#xff0c;截多长 substring()参数是从…

完美匹配:一种简单的神经网络反事实推理学习表示方法

英文题目&#xff1a;Perfect Match: A Simple Method for Learning Representations For Counterfactual Inference With Neural Networks 翻译&#xff1a;完美匹配&#xff1a;一种简单的神经网络反事实推理学习表示方法 单位&#xff1a; 论文链接&#xff1a;https://a…

【状态估计】基于FOMIAUKF、分数阶模块、模型估计、多新息系数的电池SOC估计研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

C++ 创建共享内存

共享内存用于实现进程间大量的数据传输&#xff0c;共享内存是在内存中单独开辟一段内存空间&#xff0c;这段内存空间有自己特有的数据结构&#xff0c;包括访问权限、大小和最近访问时间等。 1、shmget函数 #include <sys/ipc.h> #include <sys/shm.h> int shm…

c++——多态(补充)

优先查看&#xff1a;c——多态_Hiland.的博客-CSDN博客 目录 菱形虚拟继承子类的重写问题 菱形虚拟继承中的偏移量补充 逆向思维——汇编查看多态中被重写的虚函数 菱形虚拟继承子类的重写问题 继承环节时&#xff0c;菱形虚拟继承解决了菱形继承的数据冗余和二义性问题。…

C# Modbus通信从入门到精通(11)——Modbus RTU(调试软件Modbus Slave和Modbus Poll的使用)

前言 我们在开发Modbus程序的时候,会需要测试以下我们写的Modbus程序有没有问题,这时候就需要使用到Modbus Slave和Modbus Poll这两个软件,Modbus Slave是模拟Modbus从站,Modbus Poll是模拟Modbus从站主站的, 1、Modbus Slave 一般情况下我们开发的嗾使Modbus主站程序,…

性能测试(Jemeter)

1.性能指标 响应时间&#xff1a;一次请求的往返时间tps&#xff1a;每秒系统能够处理的事务数&#xff0c;比如订单中的下单操作&#xff0c;下单后续有很多操作&#xff0c;比如创建订单&#xff0c;扣除库存&#xff0c;清算库存等&#xff0c;这个完整操作就是一个完整的事…

【数据分享】1929-2022年全球站点的逐日最大持续风速数据(Shp\Excel\12000个站点)

气象数据是在各项研究中都经常使用的数据&#xff0c;气象指标包括气温、风速、降水、能见度等指标&#xff0c;说到气象数据&#xff0c;最详细的气象数据是具体到气象监测站点的数据&#xff01; 对于具体到监测站点的气象数据&#xff0c;之前我们分享过1929-2022年全球气象…

Qt添加第三方字体

最近开发项目时&#xff0c;据说不能用系统自带的微软雅黑字体&#xff0c;于是找一个开源的字体&#xff0c;思源黑体&#xff0c;这个是google和Adobe公司合力开发的可以免费使用。本篇记录一下Qt使用第三方字体的方式。字体从下载之家下载http://www.downza.cn/soft/266042.…

Pytest参数化——那些你不知道的使用技巧

目录 前言 装饰测试类 输出 说明 装饰测试函数 单个数据 输出 说明 一组数据 输出 说明 图解对应关系 组合数据 输出 说明 标记用例 输出 说明 嵌套字典 输出 增加可读性 使用ids参数 输出 说明 自定义id做标识 输出 说明 总结 总结&#xff1a; 前…

给你二叉树的根节点 root ,返回它节点值的中序遍历

题目&#xff1a;给你二叉树的根节点 root &#xff0c;返回它节点值的中序遍历。 要求&#xff1a;非递归实现。 1/ \2 3/ \ / \4 5 6 7中序遍历结果为&#xff1a; 4 2 5 1 6 3 7这里考察中序遍历思想&#xff0c;使用Stack的后进先出特性输出结果。 TreeNode树状结…

spring项目的创建和使用(详细教程 手把手)方法一

今天我们来讲使用maven方式创建一个sping项目。 1、创建一个普通的maven项目。 2、添加spring框架(引入依赖)支持。添加到pom.xml文件中。 添加的框架有 spring-context&#xff1a;spring 上下⽂&#xff0c;还有 spring-beans&#xff1a;管理对象的模块。 <dependenc…

python将.h5文件转换成csv

五、在jupyter中找到results文件夹&#xff0c;然后可以把跑的.h5结果转换为csv文件 pip install tables import h5py import numpy as np import pandas as pd filename Mnist_FEDL_0.003_0_10u_20b_5_avg.h5 f h5py.File(filename, r) # List all groups print("K…