目录
1.BAM介绍
2.BAM引入到yolov5
2.1 加入common.py中:
2.2 加入yolo.py中:
2.3 yolov5s_BAM.yaml
1.BAM介绍
论文:https://arxiv.org/pdf/1807.06514.pdf
摘要:提出了一种简单有效的注意力模块,称为瓶颈注意力模块(BAM),可以与任何前馈卷积神经网络集成。我们的模块沿着两条独立的路径,通道和空间,推断出一张注意力图。我们将我们的模块放置在模型的每个瓶颈处,在那里会发生特征图的下采样。我们的模块用许多参数在瓶颈处构建了分层注意力,并且它可以以端到端的方式与任何前馈模型联合训练。我们通过在CIFAR-100、ImageNet-1K、VOC 2007和MS COCO基准上进行大量实验来验证我们的BAM。我们的实验表明,各种模型在分类和检测性能上都有持续的改进,证明了BAM的广泛适用性。
作者将BAM放在了Resnet网络中每个stage之间。有趣的是,通过可视化我们可以看到多层BAMs形成了一个分层的注意力机制,这有点像人类的感知机制。BAM在每个stage之间消除了像背景语义特征这样的低层次特征,然后逐渐聚焦于高级的语义–明确的目标。
作者提出了新的Attention模型——瓶颈注意模块,通过分离的两个路径channel和spatial得到attention map,减少计算开销和参数开销。
实验
BAM可以在大规模数据集中的各种模型上有很好的泛化能力,同时参数和计算的开销可以忽略不计,这表明提出的模块BAM可以有效地提高网络容量。另一个值得注意的是,改进的性能来自于只在网络中放置三个模块。
BAM提高了所有具有两个骨干网络的强大基线的准确性.BAM的准确率提高是以可忽略不计的参数开销实现的,这表明提高不是由于天真的容量增加,而是由于我们有效的特征细化。
2.BAM引入到yolov5
2.1 加入common.py
中:
###################### BAM attention #### START by AI&CV ###############################
import torch
from torch import nn
import torch.nn.functional as F
class ChannelGate(nn.Module):
def __init__(self, channel, reduction=16):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.mlp = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel)
)
self.bn = nn.BatchNorm1d(channel)
def forward(self, x):
b, c, h, w = x.shape
y = self.avgpool(x).view(b, c)
y = self.mlp(y)
y = self.bn(y).view(b, c, 1, 1)
return y.expand_as(x)
class SpatialGate(nn.Module):
def __init__(self, channel, reduction=16, kernel_size=3, dilation_val=4):
super().__init__()
self.conv1 = nn.Conv2d(channel, channel // reduction, kernel_size=1)
self.conv2 = nn.Sequential(
nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
dilation=dilation_val),
nn.BatchNorm2d(channel // reduction),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
dilation=dilation_val),
nn.BatchNorm2d(channel // reduction),
nn.ReLU(inplace=True)
)
self.conv3 = nn.Conv2d(channel // reduction, 1, kernel_size=1)
self.bn = nn.BatchNorm2d(1)
def forward(self, x):
b, c, h, w = x.shape
y = self.conv1(x)
y = self.conv2(y)
y = self.conv3(y)
y = self.bn(y)
return y.expand_as(x)
class BAM(nn.Module):
def __init__(self, channel):
super(BAM, self).__init__()
self.channel_attn = ChannelGate(channel)
self.spatial_attn = SpatialGate(channel)
def forward(self, x):
attn = F.sigmoid(self.channel_attn(x) + self.spatial_attn(x))
return x + x * attn
###################### BAM attention #### END by AI&CV ###############################
2.2 加入yolo.py
中:
def parse_model(d, ch): # model_dict, input_channels(3)
添加以下内容
elif m is BAM:
c1, c2 = ch[f], args[0]
if c2 != no:
c2 = make_divisible(c2 * gw, 8)
args = [c1, *args[1:]]
2.3 yolov5s_BAM.yaml
仅供参考,加入网络位置不同在不同数据集表现不一致是正常现场
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[-1, 1, BAM, [1024]], # 24
[[17, 20, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
3.YOLOv5/YOLOv7魔术师专栏介绍
💡💡💡YOLOv5/YOLOv7魔术师,独家首发创新(原创),持续更新,最终完结篇数≥100+,适用于Yolov5、Yolov7、Yolov8等各个Yolo系列,专栏文章提供每一步步骤和源码,轻松带你上手魔改网络
💡💡💡重点:通过本专栏的阅读,后续你也可以自己魔改网络,在网络不同位置(Backbone、head、detect、loss等)进行魔改,实现创新!!!
专栏介绍:
✨✨✨原创魔改网络、复现前沿论文,组合优化创新
🚀🚀🚀小目标、遮挡物、难样本性能提升
🍉🍉🍉持续更新中,定期更新不同数据集涨点情况
本专栏提供每一步改进步骤和源码,开箱即用,在你的数据集下轻松涨点
通过注意力机制、小目标检测、Backbone&Head优化、 IOU&Loss优化、优化器改进、卷积变体改进、轻量级网络结合yolo等方面进行展开点,
专栏链接如下:
Yolov5/Yolov7魔术师_AI小怪兽的博客-CSDN博客