深度学习(28)——YOLO系列(7)
咱就是说,需要源码请造访:Jane的GitHub
:在这里
上午没写完的,下午继续,是一个小尾巴。其实上午把训练的关键部分和数据的关键部分都写完了,现在就是写一下推理部分
在推理过程为了提高效率,速度更快:
detect 全过程
1.1 attempt_load(weights)
- weights是加载的yolov7之前训练好的权重
- 刚开始load以后还有BN,没有合并的
- 关键在下面的fuse()
1.2 model.fuse()
# 很隐蔽,刚开始我没想到接口是在这里的
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
print('Fusing layers... ')
for m in self.model.modules():
if isinstance(m, RepConv):
#print(f" fuse_repvgg_block")
m.fuse_repvgg_block()
elif isinstance(m, RepConv_OREPA):
#print(f" switch_to_deploy")
m.switch_to_deploy()
elif type(m) is Conv and hasattr(m, 'bn'):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.fuseforward # update forward
elif isinstance(m, (IDetect, IAuxDetect)):
m.fuse()
m.forward = m.fuseforward
self.info()
return self
当遇到conv后面一定是有BN的,所以
1.3 fuse_conv_and_bn(conv,bn)
- 先定义一个新的conv【和原来传入的是一样的inputsize,outputsize和kernel】
- 先得到w_conv:
w_conv = conv.weight.clone().view(conv.out_channels, -1)
- 得到w_bn:
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
【bn.weight 就是以下公式中的gamma,sigma平方是方差bn.running_var
】 - 得到w_fuse:
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
- 得到b_conv,因为在学习过程中bias我们都设置为0,所以:
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
- 得到b_bn :
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
【bn.bias是上面公式中的β,μ为均值bn.running_mean】 - 计算b_fuse
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
def fuse_conv_and_bn(conv, bn):
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
fusedconv = nn.Conv2d(conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
groups=conv.groups,
bias=True).requires_grad_(False).to(conv.weight.device)
# prepare filters bn.weight 对应论文中的gamma bn.bias对应论文中的beta bn.running_mean则是对于当前batch size的数据所统计出来的平均值 bn.running_var是对于当前batch size的数据所统计出来的方差
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
# prepare spatial bias
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
return fusedconv
1.4 Repvgg_block
把Repvgg中的卷积和BN合在一起
- 原来的block↓
- 融合rbr_dense后:
- 融合rbr_1*1后:
1.5 将1* 1卷积padding成3* 3
padding后
所有的都改变以后:model长这样——>
OK,这次真没啦,886~~~~