模块出处
[ECCV 24] [link] [code] Context-Guided Spatial Feature Reconstruction for Efficient Semantic Segmentation
模块名称
Rectangular self-Calibration Attention (RCA)
模块作用
空间注意力
模块结构
模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class RCA(nn.Module):
    """Rectangular self-Calibration Attention block.

    A depthwise square convolution produces a local feature map, and a gate
    in the range (0, 1) is derived from axis-wise average-pooled context.
    The output is the element-wise product of the gate and the local map.

    Args:
        inp: Number of input (and output) channels.
        kernel_size: Accepted for interface compatibility; not used here.
        ratio: Channel reduction factor for the gate branch
            (hidden channels = ``inp // ratio``).
        band_kernel_size: Length of the 1-D band kernels used in the gate
            branch (first ``1 x k``, then ``k x 1``).
        dw_size: Accepted for interface compatibility; not used here.
        padding: Accepted for interface compatibility; not used here.
        stride: Accepted for interface compatibility; not used here.
        square_kernel_size: Kernel size of the depthwise square convolution.
        relu: Accepted for interface compatibility; not used here.
    """

    def __init__(self, inp, kernel_size=1, ratio=1, band_kernel_size=11, dw_size=(1,1), padding=(0,0), stride=1, square_kernel_size=3, relu=True):
        super().__init__()

        # Local branch: one filter per channel (groups == channels).
        square_pad = square_kernel_size // 2
        self.dwconv_hw = nn.Conv2d(
            inp,
            inp,
            square_kernel_size,
            padding=square_pad,
            groups=inp,
        )

        # Context branch: average across the width and across the height.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        hidden = inp // ratio
        band_pad = band_kernel_size // 2

        # Layer order (and therefore the Sequential indices) is kept as
        # conv -> BN -> ReLU -> conv -> Sigmoid.
        gate_layers = [
            nn.Conv2d(
                inp,
                hidden,
                kernel_size=(1, band_kernel_size),
                padding=(0, band_pad),
                groups=hidden,
            ),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                hidden,
                inp,
                kernel_size=(band_kernel_size, 1),
                padding=(band_pad, 0),
                groups=hidden,
            ),
            nn.Sigmoid(),
        ]
        self.excite = nn.Sequential(*gate_layers)

    def sge(self, x):
        """Return the attention gate computed from pooled context of ``x``."""
        # The two pooled maps have singleton width / height respectively, so
        # their sum broadcasts back to the full spatial size of ``x``.
        pooled_sum = self.pool_h(x) + self.pool_w(x)
        return self.excite(pooled_sum)

    def forward(self, x):
        """Return the gated local features for input ``x``."""
        local_feat = self.dwconv_hw(x)
        gate = self.sge(x)
        return gate * local_feat
if __name__ == '__main__':
    # Smoke test: push a random batch through the module and print the
    # shape of the result.
    sample = torch.randn(3, 256, 40, 40)
    module = RCA(inp=256)
    result = module(sample)
    print(result.shape)  # 3, 256, 40, 40
原文表述
矩形自校准注意力（RCA）显式地对矩形区域进行建模，并校准注意力的形状。