文章目录
- 一、基本使用
- 二、常见指标
- 2.1Input size
- 2.2Forward/backward pass size
一、基本使用
torchsummary
库是一个好用的模型可视化工具,用于帮助开发者把握每个网络层级的细节,包括其中的连接和维度。使用方法:
from torchsummary import summary
库中仅有一个函数:
summary(model, input_size, batch_size=-1, device="cuda"):
model
:模型对象。input_size
:输入数据的格式,使用(C,H,W)格式。batch_size
:批数据的数量。device
:使用的设备。
以自定义的LeNet网络模型为例:
import torch
from torch import nn
from torchsummary import summary
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 手写数字图片大小为32*32,故需填充2个像素
self.model = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(in_features=16 * 5 * 5, out_features=120),
nn.Linear(in_features=120, out_features=84),
nn.Linear(in_features=84, out_features=10),
)
def forward(self, x):
return self.model(x)
myLeNet = LeNet().to(device)
print(summary(myLeNet, input_size=(1, 28, 28), batch_size=64, device='cuda'))
二、常见指标
2.1Input size
Input size
表示输入数据的大小。在上述例子中,batch_size=64
,每张图片大小为(1,28,28)
,而Pytorch默认使用float32(双精度浮点数)占4字节,则每个batch所用内存大小为:
64
x
1
x
28
x
28
x
4
=
200
,
704
(
B
y
t
e
s
)
64x1x28x28x4=200,704(Bytes)
64x1x28x28x4=200,704(Bytes)
转化为以MB为单位:
200
,
704
/
102
4
2
(
B
y
t
e
s
)
=
0.19140625
(
B
y
t
e
s
)
200,704/1024^2(Bytes)=0.19140625(Bytes)
200,704/10242(Bytes)=0.19140625(Bytes)
约等于0.19MB。
2.2Forward/backward pass size
https://blog.csdn.net/weixin_43589323/article/details/137105988?ops_request_misc=&request_id=&biz_id=102&utm_term=torchsummary&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-5-137105988.142v100pc_search_result_base2&spm=1018.2226.3001.4187