From a layer normalized tensor $\mathbf{Y} \in \mathbb{R}^{\hat{H} \times \hat{W} \times \hat{C}}$, our MDTA first generates query $(\mathbf{Q})$, key $(\mathbf{K})$ and value $(\mathbf{V})$ projections, enriched with local context. This is achieved by applying $1 \times 1$ convolutions to aggregate pixel-wise cross-channel context, followed by $3 \times 3$ depth-wise convolutions to encode channel-wise spatial context, yielding $\mathbf{Q}=W_d^Q W_p^Q \mathbf{Y}$, $\mathbf{K}=W_d^K W_p^K \mathbf{Y}$ and $\mathbf{V}=W_d^V W_p^V \mathbf{Y}$, where $W_p^{(\cdot)}$ is the $1 \times 1$ point-wise convolution and $W_d^{(\cdot)}$ is the $3 \times 3$ depth-wise convolution. We use bias-free convolutional layers in the network. Next, we reshape the query and key projections such that their dot-product interaction generates a transposed-attention map $\mathbf{A}$ of size $\mathbb{R}^{\hat{C} \times \hat{C}}$, instead of the huge regular attention map of size $\mathbb{R}^{\hat{H} \hat{W} \times \hat{H} \hat{W}}$. Overall, the MDTA process is defined as:
$$
\hat{\mathbf{X}} = W_p \operatorname{Attention}(\hat{\mathbf{Q}}, \hat{\mathbf{K}}, \hat{\mathbf{V}}) + \mathbf{X}
$$
$$
\operatorname{Attention}(\hat{\mathbf{Q}}, \hat{\mathbf{K}}, \hat{\mathbf{V}}) = \hat{\mathbf{V}} \cdot \operatorname{Softmax}(\hat{\mathbf{K}} \cdot \hat{\mathbf{Q}} / \alpha)
$$
where $\mathbf{X}$ and $\hat{\mathbf{X}}$ are the input and output feature maps; the matrices $\hat{\mathbf{Q}} \in \mathbb{R}^{\hat{H} \hat{W} \times \hat{C}}$, $\hat{\mathbf{K}} \in \mathbb{R}^{\hat{C} \times \hat{H} \hat{W}}$ and $\hat{\mathbf{V}} \in \mathbb{R}^{\hat{H} \hat{W} \times \hat{C}}$ are obtained after reshaping tensors from the original size $\mathbb{R}^{\hat{H} \times \hat{W} \times \hat{C}}$. Here, $\alpha$ is a learnable scaling parameter that controls the magnitude of the dot product of $\hat{\mathbf{K}}$ and $\hat{\mathbf{Q}}$ before applying the softmax function. Similar to conventional multi-head SA, we divide the number of channels into 'heads' and learn separate attention maps in parallel.
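To make the size difference concrete, here is a small sketch comparing the number of entries in a regular spatial attention map with those in the transposed channel attention map. The spatial resolution and channel count below are illustrative assumptions, not values from the text:

```python
# Illustrative sizes only (assumed for this example, not taken from the paper).
H, W, C = 128, 128, 48

regular = (H * W) ** 2   # (H*W) x (H*W) spatial attention map: 268,435,456 entries
transposed = C ** 2      # C x C transposed attention map: 2,304 entries

print(f"regular attention map:    {regular:,} entries")
print(f"transposed attention map: {transposed:,} entries")
```

The transposed map therefore grows with the number of channels rather than with the image resolution, which is what keeps MDTA affordable for high-resolution inputs.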
## Multi-DConv Head Transposed Self-Attention (MDTA)
```python
import torch
import torch.nn as nn
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # Learnable per-head scaling, playing the role of alpha in the equation above.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # 1x1 point-wise conv (W_p) producing Q, K, V jointly,
        # followed by a 3x3 depth-wise conv (W_d) for channel-wise spatial context.
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Reshape to (batch, heads, channels_per_head, H*W).
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        # L2-normalize along the spatial dimension.
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # Transposed attention: a (c x c) channel attention map per head.
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w',
                        head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out
```
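As a quick sanity check that the module preserves the input's spatial shape, here is a minimal usage sketch; the batch size, channel count, resolution and head count are arbitrary test values, not settings from the paper:

```python
# Arbitrary test sizes (assumed for illustration only).
x = torch.randn(1, 48, 64, 64)               # (batch, channels, H, W)
mdta = Attention(dim=48, num_heads=8, bias=False)
out = mdta(x)
print(out.shape)                             # torch.Size([1, 48, 64, 64])
```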
This code does not implement the Norm module shown in the figure; for that module, refer to Layer Normalization (a minimal sketch is given at the end of this section). Let us now look at how the Transformer block is wrapped:
```python
# LayerNorm and FeedForward are defined elsewhere in the Restormer code.
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # MDTA branch with residual connection
        x = x + self.ffn(self.norm2(x))   # feed-forward branch with residual connection
        return x
```
As you can see, the implementation first applies the Norm, then passes the result through the Attention, and finally adds the residual connection; this whole pre-norm flow is exactly what the figure above shows.
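For completeness, here is a minimal sketch of the missing Norm module: a channel-wise layer normalization over (B, C, H, W) tensors with the same constructor signature that TransformerBlock uses above. This is an illustrative assumption rather than the authors' exact code; the official Restormer implementation provides separate 'BiasFree' and 'WithBias' variants selected via LayerNorm_type.

```python
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    """Minimal channel-wise LayerNorm sketch for (B, C, H, W) feature maps.

    LayerNorm_type is accepted only to match the TransformerBlock interface;
    this sketch always applies a with-bias normalization.
    """

    def __init__(self, dim, LayerNorm_type='WithBias'):
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # Normalize each pixel's channel vector to zero mean and unit variance.
        mu = x.mean(dim=1, keepdim=True)
        sigma = x.var(dim=1, keepdim=True, unbiased=False)
        x = (x - mu) / torch.sqrt(sigma + 1e-5)
        # Per-channel affine parameters, broadcast over the spatial dimensions.
        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
```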