Flops计算量,params参数量
在文件中
tools/analysis_tools/get_flops.py
利用以下命令实现
python tools/analysis_tools/get_flops.py configs/xxx/xxx-Net.py
后面可跟参数shape控制输入图片尺寸,例如
python tools/analysis_tools/get_flops.py configs/xxx/xxx-Net.py --shape 512 512
``
如下展示
```python
python tools/analysis_tools/get_flops.py configs/danet/danet_r50-d8_4xb4-40k_voc12aug-512x512.py
输出
==============================
Compute type: direct: randomly generate a picture
Input shape: (512, 512)
Flops: 0.211T
Params: 47.485M
==============================
坑点1
input_shape" and "inputs" cannot be both set.
在87行左右,由于配置文件在配置了data = model.data_preprocessor(data_batch)
,所有data中有数据,同时,input_shape通过默认参数得到,两个不能同时有值,所以将data注释掉。希望通过输入的参数计算。
接下来看一下这个get_model_complexity_info
函数都输出的是什么
{'flops': 211043745792,
'flops_str': '0.211T',
'activations': 168120320,
'activations_str': '0.168G',
'params': 47484961,
'params_str': '47.485M',
'out_table': '', 'out_arch': ''}
如何计算某一模块的计算量和参数量呢?
主要看以下代码
flop_handler = FlopAnalyzer(model, inputs)
activation_handler = ActivationAnalyzer(model, inputs)
flops = flop_handler.total()
activations = activation_handler.total()
params = parameter_count(model)['']
导入
from mmengine.analysis import (ActivationAnalyzer, FlopAnalyzer, parameter_count)
这里看一下FlopAnalyzer是如何使用
Examples:
>>> import torch.nn as nn
>>> import torch
>>> class TestModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(in_features=1000, out_features=10)
... self.conv = nn.Conv2d(
... in_channels=3, out_channels=10, kernel_size=1
... )
... self.act = nn.ReLU()
... def forward(self, x):
... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> flops = FlopAnalyzer(model, inputs)
>>> flops.total()
13000
>>> flops.total("fc")
10000
>>> flops.by_operator()
Counter({"addmm" : 10000, "conv" : 3000})
>>> flops.by_module()
Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
>>> flops.by_module_and_operator()
{"" : Counter({"addmm" : 10000, "conv" : 3000}),
"fc" : Counter({"addmm" : 10000}),
"conv" : Counter({"conv" : 3000}),
"act" : Counter()
}