Dive into Deep Learning
Contents
- Attention
- Additive Attention
- Scaled Dot-Product Attention
- Multi-Head Attention
- Self-Attention
- Self-Attention + Scaled Dot-Product Attention: A Case Study
- Transformer
Attention
The output of attention pooling is a weighted sum of the values.
A query has length q, a key has length k, and a value has length v:
$$\mathbf{q}\in\mathbb{R}^{1\times q},\quad \mathbf{k}\in\mathbb{R}^{1\times k},\quad \mathbf{v}\in\mathbb{R}^{1\times v}$$
For n queries and m key-value pairs:
$$\mathbf{Q}\in\mathbb{R}^{n\times q},\quad \mathbf{K}\in\mathbb{R}^{m\times k},\quad \mathbf{V}\in\mathbb{R}^{m\times v}$$
$$\mathbf{a}(\mathbf{Q},\mathbf{K})\in\mathbb{R}^{n\times m}$$
is the attention scoring function.
$$\boldsymbol{\alpha}(\mathbf{Q},\mathbf{K}) = \mathrm{softmax}\left(\mathbf{a}(\mathbf{Q},\mathbf{K})\right) \in \mathbb{R}^{n\times m},\qquad \left[\boldsymbol{\alpha}\right]_{ij} = \frac{\exp\left(\left[\mathbf{a}\right]_{ij}\right)}{\sum_{j'=1}^{m}\exp\left(\left[\mathbf{a}\right]_{ij'}\right)}$$
are the attention weights; the softmax normalizes each row over the m keys.
$$f(\mathbf{Q},\mathbf{K},\mathbf{V}) = \boldsymbol{\alpha}(\mathbf{Q},\mathbf{K})\,\mathbf{V} \in \mathbb{R}^{n\times v}$$
is the attention pooling function (since $\boldsymbol{\alpha}\in\mathbb{R}^{n\times m}$ and $\mathbf{V}\in\mathbb{R}^{m\times v}$, the product is taken directly).
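As a shape sanity check, here is a minimal PyTorch sketch of this generic pooling; `attention_pooling` and the example sizes are illustrative, not from the book:

```python
import torch

def attention_pooling(scores: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """weights = row-wise softmax(scores); output = weights @ values.

    scores: (n, m) attention scores a(Q, K)
    values: (m, v) value matrix V
    returns: (n, v) weighted sums of the values
    """
    weights = torch.softmax(scores, dim=-1)  # alpha(Q, K); each row sums to 1
    return weights @ values                  # (n, m) @ (m, v) -> (n, v)

# Hypothetical sizes: n = 2 queries, m = 3 key-value pairs, v = 4
scores = torch.randn(2, 3)
values = torch.randn(3, 4)
print(attention_pooling(scores, values).shape)  # torch.Size([2, 4])
```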
Additive Attention
$$\mathbf{q}\in\mathbb{R}^{1\times q},\quad \mathbf{k}\in\mathbb{R}^{1\times k}$$
$$\mathbf{W}_q\in\mathbb{R}^{h\times q},\quad \mathbf{W}_k\in\mathbb{R}^{h\times k},\quad \mathbf{w}_v\in\mathbb{R}^{h\times 1}$$
$$a(\mathbf{q},\mathbf{k}) = \mathbf{w}_v^\top \tanh\left(\mathbf{W}_q \mathbf{q}^\top + \mathbf{W}_k \mathbf{k}^\top\right) \in \mathbb{R}$$
is the attention scoring function.
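A minimal PyTorch sketch of this scoring function; the class name `AdditiveScore` and the sizes are illustrative assumptions:

```python
import torch
import torch.nn as nn

class AdditiveScore(nn.Module):
    """Additive attention score a(q, k) = w_v^T tanh(W_q q^T + W_k k^T)."""
    def __init__(self, q_dim: int, k_dim: int, h: int):
        super().__init__()
        self.W_q = nn.Linear(q_dim, h, bias=False)  # W_q in R^{h x q}
        self.W_k = nn.Linear(k_dim, h, bias=False)  # W_k in R^{h x k}
        self.w_v = nn.Linear(h, 1, bias=False)      # w_v in R^{h x 1}

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # q: (1, q_dim), k: (1, k_dim) -> scalar score of shape (1, 1)
        return self.w_v(torch.tanh(self.W_q(q) + self.W_k(k)))

score = AdditiveScore(q_dim=8, k_dim=6, h=16)
print(score(torch.randn(1, 8), torch.randn(1, 6)).shape)  # torch.Size([1, 1])
```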
Scaled Dot-Product Attention
$$\mathbf{q}\in\mathbb{R}^{1\times d},\quad \mathbf{k}\in\mathbb{R}^{1\times d},\quad \mathbf{v}\in\mathbb{R}^{1\times v}$$
$$a(\mathbf{q},\mathbf{k}) = \frac{1}{\sqrt{d}}\,\mathbf{q}\mathbf{k}^\top \in \mathbb{R}$$
is the attention scoring function.
$$f(\mathbf{q},\mathbf{k},\mathbf{v}) = \alpha(\mathbf{q},\mathbf{k})\,\mathbf{v} = \mathrm{softmax}\left(\frac{1}{\sqrt{d}}\,\mathbf{q}\mathbf{k}^\top\right)\mathbf{v} \in \mathbb{R}^{1\times v}$$
is the attention pooling function.
For n queries and m key-value pairs:
$$\mathbf{Q}\in\mathbb{R}^{n\times d},\quad \mathbf{K}\in\mathbb{R}^{m\times d},\quad \mathbf{V}\in\mathbb{R}^{m\times v}$$
$$\mathbf{a}(\mathbf{Q},\mathbf{K}) = \frac{1}{\sqrt{d}}\,\mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{n\times m}$$
is the attention scoring function.
$$f(\mathbf{Q},\mathbf{K},\mathbf{V}) = \boldsymbol{\alpha}(\mathbf{Q},\mathbf{K})\,\mathbf{V} = \mathrm{softmax}\left(\frac{1}{\sqrt{d}}\,\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V} \in \mathbb{R}^{n\times v}$$
is the attention pooling function.
A. Vaswani et al., "Attention Is All You Need," in Proc. Neural Information Processing Systems (NeurIPS), 2017.
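A minimal PyTorch sketch of the batched form above; the function name and example shapes are illustrative:

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """f(Q, K, V) = softmax(Q K^T / sqrt(d)) V.

    Q: (n, d), K: (m, d), V: (m, v) -> (n, v)
    """
    d = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)  # (n, m)
    weights = torch.softmax(scores, dim=-1)          # rows sum to 1 over the m keys
    return weights @ V

Q, K, V = torch.randn(2, 8), torch.randn(5, 8), torch.randn(5, 4)
print(scaled_dot_product_attention(Q, K, V).shape)  # torch.Size([2, 4])
```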
Multi-Head Attention
$$\mathbf{q}\in\mathbb{R}^{1\times d_q},\quad \mathbf{k}\in\mathbb{R}^{1\times d_k},\quad \mathbf{v}\in\mathbb{R}^{1\times d_v}$$
$$\mathbf{W}_i^{(q)}\in\mathbb{R}^{p_q\times d_q},\quad \mathbf{W}_i^{(k)}\in\mathbb{R}^{p_k\times d_k},\quad \mathbf{W}_i^{(v)}\in\mathbb{R}^{p_v\times d_v}$$
$$\mathbf{q}\mathbf{W}_i^{(q)\top}\in\mathbb{R}^{1\times p_q},\quad \mathbf{k}\mathbf{W}_i^{(k)\top}\in\mathbb{R}^{1\times p_k},\quad \mathbf{v}\mathbf{W}_i^{(v)\top}\in\mathbb{R}^{1\times p_v}$$
$$\mathbf{h}_i = f\left(\mathbf{q}\mathbf{W}_i^{(q)\top},\ \mathbf{k}\mathbf{W}_i^{(k)\top},\ \mathbf{v}\mathbf{W}_i^{(v)\top}\right) \in \mathbb{R}^{1\times p_v}$$
is attention head $i$.
The $h$ attention heads are concatenated and then linearly transformed:
$$\mathbf{W}_o\in\mathbb{R}^{p_o\times h p_v}$$
$$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1^\top \\ \vdots \\ \mathbf{h}_h^\top \end{bmatrix} \in \mathbb{R}^{p_o}$$
$$p_q h = p_k h = p_v h = p_o$$
(e.g., with $h = 8$ heads and $p_o = 512$, each head uses $p_q = p_k = p_v = 64$).
Multi-head attention: multiple attention heads are concatenated and then linearly transformed, as sketched below.
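A minimal PyTorch sketch under the convention $p_q = p_k = p_v = p_o / h$; the class and its parameter names are illustrative, not a reference implementation:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Illustrative multi-head attention (no mask), with p_q = p_k = p_v = p_o / h."""
    def __init__(self, d_q: int, d_k: int, d_v: int, p_o: int, h: int):
        super().__init__()
        assert p_o % h == 0
        self.h, self.p = h, p_o // h
        self.W_q = nn.Linear(d_q, p_o, bias=False)  # all h heads' W_i^{(q)} stacked
        self.W_k = nn.Linear(d_k, p_o, bias=False)
        self.W_v = nn.Linear(d_v, p_o, bias=False)
        self.W_o = nn.Linear(p_o, p_o, bias=False)  # output transform W_o

    def forward(self, q, k, v):
        # q: (n, d_q), k: (m, d_k), v: (m, d_v)
        n, m = q.shape[0], k.shape[0]
        # Project, then split into h heads: (h, n, p) and (h, m, p)
        Q = self.W_q(q).reshape(n, self.h, self.p).transpose(0, 1)
        K = self.W_k(k).reshape(m, self.h, self.p).transpose(0, 1)
        V = self.W_v(v).reshape(m, self.h, self.p).transpose(0, 1)
        # Scaled dot-product attention within each head: (h, n, p)
        weights = torch.softmax(Q @ K.transpose(-2, -1) / self.p ** 0.5, dim=-1)
        heads = weights @ V
        # Concatenate heads, then linearly transform: (n, h*p) -> (n, p_o)
        concat = heads.transpose(0, 1).reshape(n, self.h * self.p)
        return self.W_o(concat)

mha = MultiHeadAttention(d_q=16, d_k=16, d_v=16, p_o=32, h=4)
x = torch.randn(5, 16)
print(mha(x, x, x).shape)  # torch.Size([5, 32])
```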
Self-Attention
$$\mathbf{x}_i\in\mathbb{R}^{1\times d},\quad \mathbf{X} = \begin{bmatrix} \mathbf{x}_1 \\ \vdots \\ \mathbf{x}_n \end{bmatrix} \in \mathbb{R}^{n\times d}$$
$$\mathbf{Q} = \mathbf{X},\quad \mathbf{K} = \mathbf{X},\quad \mathbf{V} = \mathbf{X}$$
$$f(\mathbf{Q},\mathbf{K},\mathbf{V}) = \boldsymbol{\alpha}(\mathbf{Q},\mathbf{K})\,\mathbf{V} = \mathrm{softmax}\left(\frac{1}{\sqrt{d}}\,\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V} \in \mathbb{R}^{n\times d}$$
$$\mathbf{y}_i = f\left(\mathbf{x}_i, (\mathbf{x}_1,\mathbf{x}_1), \ldots, (\mathbf{x}_n,\mathbf{x}_n)\right) \in \mathbb{R}^{1\times d}$$
Each output $\mathbf{y}_i$ is the attention pooling of query $\mathbf{x}_i$ over all $n$ key-value pairs $(\mathbf{x}_j, \mathbf{x}_j)$.
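Self-attention then reduces to the earlier sketch with $\mathbf{Q} = \mathbf{K} = \mathbf{V} = \mathbf{X}$:

```python
import torch
# Reusing scaled_dot_product_attention from the sketch above, with Q = K = V = X.
X = torch.randn(6, 8)  # n = 6 positions, d = 8
Y = scaled_dot_product_attention(X, X, X)
print(Y.shape)         # torch.Size([6, 8]); row i is y_i
```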
For n queries and m key-value pairs (with self-attention over a single $\mathbf{X}$, $m = n$):
$$\mathbf{Q} = \tanh\left(\mathbf{W}_q \mathbf{X}^\top\right)^\top \in \mathbb{R}^{n\times d}$$
$$\mathbf{K} = \tanh\left(\mathbf{W}_k \mathbf{X}^\top\right)^\top \in \mathbb{R}^{m\times d}$$
$$\mathbf{V} = \tanh\left(\mathbf{W}_v \mathbf{X}^\top\right)^\top \in \mathbb{R}^{m\times v}$$
Self-Attention + Scaled Dot-Product Attention: A Case Study
J. Xu, F. Zhong, and Y. Wang, “Learning multi-agent coordination for enhancing target coverage in directional sensor networks,” in Proc. Neural Information Processing Systems (NeurIPS), Vancouver, BC, Canada, Dec. 2020, pp. 1–16.
https://github.com/XuJing1022/HiT-MAC/blob/main/perception.py
By analogy with multi-head attention,
$$d_q = d_k = d_v = d_{in}, \qquad p_q = p_k = p_v = d_{att}$$
$$\mathbf{x}_i\in\mathbb{R}^{1\times d_{in}},\quad \mathbf{X} = \begin{bmatrix} \mathbf{x}_1 \\ \vdots \\ \mathbf{x}_{nm} \end{bmatrix} \in \mathbb{R}^{nm\times d_{in}}$$
$$\mathbf{W}_q, \mathbf{W}_k, \mathbf{W}_v \in \mathbb{R}^{d_{att}\times d_{in}}$$
$$\mathbf{Q} = \tanh\left(\mathbf{W}_q \mathbf{X}^\top\right)^\top \in \mathbb{R}^{nm\times d_{att}}$$
$$\mathbf{K} = \tanh\left(\mathbf{W}_k \mathbf{X}^\top\right)^\top \in \mathbb{R}^{nm\times d_{att}}$$
$$\mathbf{V} = \tanh\left(\mathbf{W}_v \mathbf{X}^\top\right)^\top \in \mathbb{R}^{nm\times d_{att}}$$
$$f(\mathbf{Q},\mathbf{K},\mathbf{V}) = \boldsymbol{\alpha}(\mathbf{Q},\mathbf{K})\,\mathbf{V} = \mathrm{softmax}\left(\frac{1}{\sqrt{d_{att}}}\,\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V} \in \mathbb{R}^{nm\times d_{att}}$$
(The repository code below applies the softmax without the $1/\sqrt{d_{att}}$ scaling factor.)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def xavier_init(layer):
    # Helper assumed from the repository's utils: Xavier-initialize the layer's weight.
    nn.init.xavier_uniform_(layer.weight)
    return layer

class AttentionLayer(nn.Module):
    def __init__(self, feature_dim, weight_dim, device):
        super(AttentionLayer, self).__init__()
        self.in_dim = feature_dim
        self.device = device
        # Q, K, V projections: in_dim (d_in) -> weight_dim (d_att)
        self.Q = xavier_init(nn.Linear(self.in_dim, weight_dim))
        self.K = xavier_init(nn.Linear(self.in_dim, weight_dim))
        self.V = xavier_init(nn.Linear(self.in_dim, weight_dim))
        self.feature_dim = weight_dim

    def forward(self, x):
        # param x: [num_agent, num_target, in_dim]
        # return z: [num_agent, num_target, weight_dim]
        # z = softmax(Q K^T) V  -- dot-product attention, without the 1/sqrt(d_att) scaling
        q = torch.tanh(self.Q(x))  # [batch_size, sequence_len, weight_dim]
        k = torch.tanh(self.K(x))  # [batch_size, sequence_len, weight_dim]
        v = torch.tanh(self.V(x))  # [batch_size, sequence_len, weight_dim]
        z = torch.bmm(F.softmax(torch.bmm(q, k.permute(0, 2, 1)), dim=2), v)  # [batch_size, sequence_len, weight_dim]
        # Sum over the sequence (targets) to get one global feature per agent.
        global_feature = z.sum(dim=1)
        return z, global_feature
```
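A quick shape check with hypothetical sizes (2 agents, 3 targets, $d_{in} = 5$, $d_{att} = 4$), reusing the class above:

```python
layer = AttentionLayer(feature_dim=5, weight_dim=4, device="cpu")
x = torch.randn(2, 3, 5)  # [num_agent, num_target, in_dim]
z, g = layer(x)
print(z.shape, g.shape)   # torch.Size([2, 3, 4]) torch.Size([2, 4])
```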