一、CBAM概念
CBAM(Convolutional Block Attention Module)是一种用于卷积神经网络(CNN)的注意力机制,旨在提高网络的表现能力。它通过引入两个注意力模块来增强特征图的表达能力。
二、源码:
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
self.relu = nn.ReLU()
self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
out = self.sigmoid(avg_out + max_out)
return out
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# 1*h*w
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
# 2*h*w
x = self.conv(x)
# 1*h*w
return self.sigmoid(x)
class CBAM(nn.Module):
def __init__(self, c1, c2, ratio=16, kernel_size=7): # ch_in, ch_out, number, shortcut, groups, expansion
super(CBAM, self).__init__()
self.channel_attention = ChannelAttention(c1, ratio)
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x):
out = self.channel_attention(x) * x
# c*h*w
# c*h*w * 1*h*w
out = self.spatial_attention(out) * out
return out
三、改进步骤
第一步,在ultralytics/nn/modules/conv.py文件内添加注意力源码
第二步,在ultralytics/nn/modules/init.py文件内,按下图标识的地方添加注意力名
第一处:在from .conv import()处最后,添加注意力名称
第二处:在__all__={}处最后,添加注意力名称
第三步,在ultralytics/nn/tasks.py文件内,
首先,在from ultralytics.nn.modules import 处添加CBAM
其次,键盘点击CTRL+shift+F打开查找界面,搜索elif m in ,在该函数下方有一堆的elif m in XXX,在某一个elif下方添加如下代码
elif m in {CBAM}:
c1, c2 = ch[f], args[0]
if c2 != nc: # if not output
c2 = make_divisible(min(c2, max_channels) * width, 8)
args = [c1, c2, *args[1:]]
第五步,在ultralytics/cfg/models/v8文件下,复制yolov8.yaml,并改成自己的名字(如yolov8-CBAM.yaml)
第一种修改方法,在backbone中添加CABM,因为添加了一层CABM,所以在Detect处也要相应的做出修改,如下:
运行结果:
第二张修改方法,在backbone和head中分别添加CABM,Detect处也要根据添加的层数做出相应的做出修改,如下:
运行结果: