From AI Basics to a Handwritten Transformer - 13. The Encoder Structure
- 13. The Encoder Structure
- Code
Compiled from the video series by 老袁不说话.
13. The Encoder Structure
TransformerEncoder: input [b, n]
- Embedding: -> [b, n, d]
- PositionalEncoding: -> [b, n, d]
- Dropout: -> [b, n, d]
- EncoderBlock: [b, n, d] -> [b, n, d], repeated N times
  - MultiheadAttention: 3 * [b, n, d] -> [b, n, d]
  - Dropout: [b, n, d] -> [b, n, d]
  - AddNorm: 2 * [b, n, d] (Dropout output, MultiheadAttention input) -> [b, n, d]
  - FFN: [b, n, d] -> [b, n, d]
  - Dropout: [b, n, d] -> [b, n, d]
  - AddNorm: 2 * [b, n, d] (Dropout output, FFN input) -> [b, n, d]
Figure: the overall encoder structure.
Dropout is applied at multiple points in the encoder; the shape flow listed above is checked in the sketch below.
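To confirm that the shapes in the list above work out, here is a minimal sketch that runs one pass through the pipeline using standard PyTorch layers as stand-ins for the modules named above. The concrete values (b=2, n=10, d=64, vocab_size=1000, nhead=4, dropout p=0.1) and the ReLU feed-forward network are illustrative assumptions, not values from the video.

import torch
import torch.nn as nn

b, n, d = 2, 10, 64           # assumed batch size, sequence length, model dim
vocab_size, nhead = 1000, 4   # assumed vocabulary size and number of heads

tokens = torch.randint(0, vocab_size, (b, n))          # [b, n]
x = nn.Embedding(vocab_size, d)(tokens)                # Embedding: -> [b, n, d]
# (positional encoding would be added here; it keeps the shape [b, n, d])
x = nn.Dropout(0.1)(x)                                 # Dropout: -> [b, n, d]

# one EncoderBlock
mha = nn.MultiheadAttention(d, nhead, batch_first=True)
attn_out, _ = mha(x, x, x)                             # 3 * [b, n, d] -> [b, n, d]
x = nn.LayerNorm(d)(x + nn.Dropout(0.1)(attn_out))     # AddNorm: 2 * [b, n, d] -> [b, n, d]
ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
ffn_out = ffn(x)                                       # FFN: [b, n, d] -> [b, n, d]
x = nn.LayerNorm(d)(x + nn.Dropout(0.1)(ffn_out))      # AddNorm: 2 * [b, n, d] -> [b, n, d]

print(x.shape)  # torch.Size([2, 10, 64])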
Code
import torch.nn as nn

# Placeholder modules: each forward() only prints its class name,
# so that the call order inside the encoder can be observed.
class Embedding(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self):
        print(self.__class__.__name__)

class PositionalEncoding(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self):
        print(self.__class__.__name__)

class MultiheadAttention(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self):
        print(self.__class__.__name__)

class Dropout(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self):
        print(self.__class__.__name__)

class AddNorm(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self):
        print(self.__class__.__name__)

class FFN(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self):
        print(self.__class__.__name__)

class EncoderBlock(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mha = MultiheadAttention()
        self.dropout1 = Dropout()
        self.addnorm1 = AddNorm()
        self.ffn = FFN()
        self.dropout2 = Dropout()
        self.addnorm2 = AddNorm()

    def forward(self):
        # attention sub-layer, then feed-forward sub-layer,
        # each followed by Dropout and AddNorm
        self.mha()
        self.dropout1()
        self.addnorm1()
        self.ffn()
        self.dropout2()
        self.addnorm2()

class TransformerEncoder(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embedding = Embedding()  # turns token indices into vectors that carry semantic information
        self.posenc = PositionalEncoding()
        self.dropout = Dropout()
        self.encblocks = nn.Sequential()
        for i in range(3):  # stack N = 3 encoder blocks
            self.encblocks.add_module(str(i), EncoderBlock())

    def forward(self):
        self.embedding()
        self.posenc()
        self.dropout()
        for i, blk in enumerate(self.encblocks):
            print(i)
            blk()

te = TransformerEncoder()
te()
Output
Embedding
PositionalEncoding
Dropout
0
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
1
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
2
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
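As a hint of where the skeleton is headed, each placeholder can later be given a real forward pass. The sketch below shows one way the AddNorm placeholder could be filled in, assuming it implements a residual connection followed by LayerNorm, consistent with the "AddNorm: 2 * [b, n, d] -> [b, n, d]" signature in the list above; the constructor and argument names are hypothetical, not taken from the video.

import torch
import torch.nn as nn

class AddNorm(nn.Module):
    # Sketch: takes the sub-layer output (already passed through Dropout)
    # and the sub-layer input, adds them, and applies LayerNorm.
    def __init__(self, d: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(d)

    def forward(self, sublayer_out, sublayer_in):
        # 2 * [b, n, d] -> [b, n, d]
        return self.norm(sublayer_out + sublayer_in)

x = torch.randn(2, 10, 64)
y = torch.randn(2, 10, 64)
print(AddNorm(64)(x, y).shape)  # torch.Size([2, 10, 64])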