Introduction
- It has been a while since I last read papers seriously, so I'm picking the habit back up. This is also the first CV paper I've read that uses a Transformer architecture.
- I chose this paper because I previously worked on a handwriting-font generation project: this work can synthesize handwriting datasets to help train handwriting recognition models.
- This post walks through the paper and the code side by side, which makes it easy to locate the key points of the paper, see how they are implemented, and pick up the essentials faster. The project's code is a pleasure to read, with clear documentation and a tidy coding style; following the repository README is enough to get the whole project running quickly.
- If you are comfortable reading English, I recommend going straight to the original paper first; it gives the most direct view of the whole picture.
- PDF | Code
Overview of the SDT Architecture
- Overall framework:
- This work proposes disentangling writer-level and character-level style representations from individual handwriting in order to synthesize realistic, stylized online handwritten characters.
- From the framework diagram above, the model splits into three main parts: the Style Encoder, the Content Encoder, and the Transformer Decoder.
- Style Encoder: learns two style representations from the given style samples, a writer-wise one and a glyph-wise one, which guide the synthesis of stylized characters. It consists of two parts: a CNN encoder and a Transformer encoder.
- Content Encoder: extracts features of the input character, and likewise consists of a CNN encoder and a Transformer encoder.
- ❓ Question: why combine a CNN encoder with a Transformer encoder?
- The paper only spells this out for the Content Encoder: the CNN learns a compact feature map from the content reference, and the Transformer encoder then extracts the textual content representation on top of it. Thanks to the Transformer's strong ability to capture long-range dependencies, the Content Encoder ends up with a content feature that carries global context. This reminds me of the classic CRNN architecture, which likewise pairs a CNN with an RNN. A minimal sketch of the CNN + Transformer pattern follows below.
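- The sketch below shows this CNN-backbone + Transformer-encoder pattern in isolation. It borrows the same ResNet-18 trick the repository uses for its Feat_Encoder (a 1-channel first conv, avgpool/fc dropped), but everything else is my own minimal illustration of the idea, not the repository's Content_TR; layer sizes are illustrative assumptions.
import torch
import torch.nn as nn
from torchvision import models


class CNNTransformerEncoder(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_layers=2):
        super().__init__()
        # CNN backbone: ResNet-18 without avgpool/fc, first conv adapted to 1-channel input
        resnet = models.resnet18(weights=None)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            *list(resnet.children())[1:-2],
        )
        # Transformer encoder models long-range dependencies across feature-map positions
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048)
        self.transformer = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x):                        # x: (B, 1, H, W)
        feat = self.cnn(x)                       # (B, 512, h, w) compact feature map
        feat = feat.flatten(2).permute(2, 0, 1)  # (h*w, B, 512) sequence of positions
        return self.transformer(feat)            # (h*w, B, 512) globally contextualized


x = torch.randn(4, 1, 64, 64)
print(CNNTransformerEncoder()(x).shape)          # torch.Size([4, 4, 512]) for 64x64 inputs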
Mapping the Code to the Paper
- The paper maps onto two core pieces of code: building the model, and processing the dataset.
Model Construction
- This code lives in models/model.py in the repository. I only excerpt the most critical parts here and add comments to explain them; the remaining details are left for you to dig into yourself.
class SDT_Generator(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=2, num_head_layers=1,
                 wri_dec_layers=2, gly_dec_layers=2, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=True, return_intermediate_dec=True):
        super(SDT_Generator, self).__init__()

        # Style Encoder, part 1: a CNN backbone (ResNet-18 whose first conv is swapped
        # for a 1-channel version) that turns each style image into a compact feature map.
        self.Feat_Encoder = nn.Sequential(*(
            [nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)]
            + list(models.resnet18(pretrained=True).children())[1:-2]))

        # Style Encoder, part 2: a shared Transformer encoder followed by two separate
        # heads that produce the writer-wise and glyph-wise (character-wise) style memories.
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        self.base_encoder = TransformerEncoder(encoder_layer, num_encoder_layers, None)
        writer_norm = nn.LayerNorm(d_model) if normalize_before else None
        glyph_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.writer_head = TransformerEncoder(encoder_layer, num_head_layers, writer_norm)
        self.glyph_head = TransformerEncoder(encoder_layer, num_head_layers, glyph_norm)

        # Content Encoder: its own CNN + Transformer encoder stack.
        self.content_encoder = Content_TR(d_model, num_encoder_layers)

        # Two-stage Transformer decoder: the writer decoder injects the writer-wise style
        # first, then the glyph decoder refines the result with the character-wise style.
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        wri_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.wri_decoder = TransformerDecoder(decoder_layer, wri_dec_layers, wri_decoder_norm,
                                              return_intermediate=return_intermediate_dec)
        gly_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.gly_decoder = TransformerDecoder(decoder_layer, gly_dec_layers, gly_decoder_norm,
                                              return_intermediate=return_intermediate_dec)

        # Projection MLPs that map pooled style features into the embedding spaces
        # used by the two contrastive (NCE) objectives.
        self.pro_mlp_writer = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))
        self.pro_mlp_character = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))

        # Mappings between point sequences and the decoder's hidden space,
        # plus sinusoidal positional encoding.
        self.SeqtoEmb = SeqtoEmb(hid_dim=d_model)
        self.EmbtoSeq = EmbtoSeq(hid_dim=d_model)
        self.add_position = PositionalEncoding(dropout=0.1, dim=d_model)
        self._reset_parameters()

    def forward(self, style_imgs, seq, char_img):
        # style_imgs: style reference images of one writer, (B, num_imgs, 1, H, W);
        # seq: ground-truth point sequence of the target character (teacher forcing);
        # char_img: content reference image of the target character.
        batch_size, num_imgs, in_planes, h, w = style_imgs.shape

        # Run the CNN over every style image, then flatten each feature map into a
        # sequence of positions: (B*num_imgs, 512, h*w) -> (h*w, B*num_imgs, 512).
        style_imgs = style_imgs.view(-1, in_planes, h, w)
        style_embe = self.Feat_Encoder(style_imgs)
        anchor_num = num_imgs // 2
        style_embe = style_embe.view(batch_size * num_imgs, 512, -1).permute(2, 0, 1)
        FEAT_ST_ENC = self.add_position(style_embe)

        # Shared encoder, then the two style heads.
        memory = self.base_encoder(FEAT_ST_ENC)
        writer_memory = self.writer_head(memory)
        glyph_memory = self.glyph_head(memory)

        # Split each sample's style images into two halves (p=2): one half acts as
        # the anchor set, the other as the positive set for the contrastive losses.
        writer_memory = rearrange(writer_memory, 't (b p n) c -> t (p b) n c',
                                  b=batch_size, p=2, n=anchor_num)
        glyph_memory = rearrange(glyph_memory, 't (b p n) c -> t (p b) n c',
                                 b=batch_size, p=2, n=anchor_num)

        # Writer-wise contrastive pairs: average the writer features into one vector
        # per half, project and L2-normalize them for the writer NCE loss.
        memory_fea = rearrange(writer_memory, 't b n c -> (t n) b c')
        compact_fea = torch.mean(memory_fea, 0)
        pro_emb = self.pro_mlp_writer(compact_fea)
        query_emb = pro_emb[:batch_size, :]
        pos_emb = pro_emb[batch_size:, :]
        nce_emb = torch.stack((query_emb, pos_emb), 1)
        nce_emb = nn.functional.normalize(nce_emb, p=2, dim=2)

        # Glyph-wise contrastive pairs: randomly sample patch features from the glyph
        # memory, then average and project them for the character NCE loss.
        patch_emb = glyph_memory[:, :batch_size]
        anc, positive = self.random_double_sampling(patch_emb)
        n_channels = anc.shape[-1]
        anc = anc.reshape(batch_size, -1, n_channels)
        anc_compact = torch.mean(anc, 1, keepdim=True)
        anc_compact = self.pro_mlp_character(anc_compact)
        positive = positive.reshape(batch_size, -1, n_channels)
        positive_compact = torch.mean(positive, 1, keepdim=True)
        positive_compact = self.pro_mlp_character(positive_compact)
        nce_emb_patch = torch.cat((anc_compact, positive_compact), 1)
        nce_emb_patch = nn.functional.normalize(nce_emb_patch, p=2, dim=2)

        # Style memories handed to the two decoders.
        writer_style = memory_fea[:, :batch_size, :]
        glyph_style = glyph_memory[:, :batch_size]
        glyph_style = rearrange(glyph_style, 't b n c -> (t n) b c')

        # Build the decoder input: the pooled content feature is prepended to the
        # embedded target sequence, and a causal mask keeps the decoding autoregressive.
        seq_emb = self.SeqtoEmb(seq).permute(1, 0, 2)
        T, N, C = seq_emb.shape
        char_emb = self.content_encoder(char_img)
        char_emb = torch.mean(char_emb, 0)
        char_emb = repeat(char_emb, 'n c -> t n c', t=1)
        tgt = torch.cat((char_emb, seq_emb), 0)
        tgt_mask = generate_square_subsequent_mask(sz=(T + 1)).to(tgt)
        tgt = self.add_position(tgt)

        # Two-stage decoding: writer decoder first, glyph decoder second, then
        # project the hidden states back to the point-sequence space.
        wri_hs = self.wri_decoder(tgt, writer_style, tgt_mask=tgt_mask)
        hs = self.gly_decoder(wri_hs[-1], glyph_style, tgt_mask=tgt_mask)
        h = hs.transpose(1, 2)[-1]
        pred_sequence = self.EmbtoSeq(h)
        return pred_sequence, nce_emb, nce_emb_patch
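- One helper that appears in the forward pass but is not part of the excerpt is generate_square_subsequent_mask. A standard implementation of such a causal mask (the same pattern used in PyTorch's Transformer examples) looks like the snippet below; the repository's version may differ in detail. Allowed positions are 0, masked positions are -inf, and the mask is added to the attention logits.
import torch

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # Causal mask: position i may only attend to positions <= i.
    # 0.0 marks allowed positions, -inf marks masked ones.
    mask = torch.triu(torch.ones(sz, sz), diagonal=1)
    return mask.masked_fill(mask == 1, float('-inf'))

print(generate_square_subsequent_mask(4))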
Dataset
- Directory layout of the CASIA_CHINESE dataset:
data/CASIA_CHINESE
├── character_dict.pkl # character dictionary
├── Chinese_content.pkl # Content reference
├── test
├── test_style_samples
├── train
├── train_style_samples # 1300 pkl files; each holds the characters written by one writer, with varying numbers of samples
└── writer_dict.pkl
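- To get a first feel for these files, a quick pickle inspection is enough. This is just a convenience snippet of mine; the internal structure of each object depends on the dataset release, so only types and sizes are printed.
import pickle

# Peek at the top-level pkl files of the dataset (inspection only).
for name in ["character_dict.pkl", "writer_dict.pkl", "Chinese_content.pkl"]:
    with open(f"data/CASIA_CHINESE/{name}", "rb") as f:
        obj = pickle.load(f)
    print(name, type(obj), len(obj) if hasattr(obj, "__len__") else "n/a")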
- Format of a single training sample:
{
    'coords': torch.Tensor(coords),                # online trajectory (point sequence) of the target character
    'character_id': torch.Tensor([character_id]),  # index into character_dict.pkl
    'writer_id': torch.Tensor([writer_id]),        # index into writer_dict.pkl
    'img_list': torch.Tensor(img_list),            # style reference images written by the same writer
    'char_img': torch.Tensor(char_img),            # content reference image of the target character
    'img_label': torch.Tensor([label_id]),
}
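- Roughly speaking, once the dataloader has batched these fields, they map onto the three arguments of SDT_Generator.forward. The snippet below is only my conceptual mapping; the names model and batch are placeholders, and the exact shapes and padding of 'coords' are handled by the repository's dataset and collate code.
# Conceptual mapping from a collated batch to the generator's forward pass
# (illustrative only; exact shapes depend on the repository's dataloader).
pred_sequence, nce_emb, nce_emb_patch = model(
    style_imgs=batch['img_list'],   # (B, num_style_imgs, 1, H, W) style references
    seq=batch['coords'],            # (B, T, C) target point sequences (teacher forcing)
    char_img=batch['char_img'],     # content reference images
)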
- At inference time: the ground-truth sequence is not available, so the decoder starts from the content feature alone and generates the trajectory autoregressively, feeding each newly predicted point back in as the next decoder input (see the generic sketch below).
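- The snippet below is a generic, self-contained illustration of that autoregressive loop with a standard nn.TransformerDecoder. It is NOT the repository's actual inference code: all dimensions and the toy (dx, dy) point format are made up for the demo.
import torch
import torch.nn as nn

# Generic autoregressive decoding with a causal mask: at each step the full prefix
# is re-fed to the decoder and only the newest output is kept.
d_model, steps = 16, 5
decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model, nhead=4), num_layers=1)
memory = torch.randn(10, 1, d_model)              # stands in for the style memory
proj_in, proj_out = nn.Linear(2, d_model), nn.Linear(d_model, 2)

points = torch.zeros(1, 1, 2)                     # start token: one (dx, dy) point
for _ in range(steps):
    tgt = proj_in(points)                         # (T, 1, d_model)
    T = tgt.shape[0]
    causal = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
    out = decoder(tgt, memory, tgt_mask=causal)
    next_point = proj_out(out[-1:])               # keep only the newest prediction
    points = torch.cat([points, next_point], dim=0)
print(points.shape)                               # torch.Size([6, 1, 2])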
Summary
- The way the Transformer is used here feels fairly brute-force. Then again, Transformers are brute-force by nature.
- The model is only 69M (position_layer2_dim512_iter138k_test_acc0.9443.pth), which is easy to live with. That contradicts my old assumption that everything Transformer-based must be huge, so it corrected a blind spot of mine.
- I learned the einops library: its semantic tensor operations are very readable and well worth picking up.
- I also encountered the NCE (Noise Contrastive Estimation) loss for the first time; it mainly addresses settings with a very large number of classes by recasting the problem as a binary classification between data and noise samples. A minimal InfoNCE-style sketch (using einops for the unpacking) follows below.
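- To make that last point concrete, here is a minimal InfoNCE-style sketch for paired embeddings shaped (B, 2, D), which is exactly the shape of nce_emb returned by the forward pass above (query at index 0, positive at index 1). It illustrates the idea only and is not the repository's exact loss implementation; the einops rearrange is just there to show the library's semantic style.
import torch
import torch.nn.functional as F
from einops import rearrange

def info_nce(pairs: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    # pairs: (B, 2, D) L2-normalized embeddings; index 0 = query, index 1 = positive.
    q, k = rearrange(pairs, 'b p d -> p b d')    # einops: move the pair axis first, unpack two (B, D) tensors
    logits = q @ k.t() / temperature             # (B, B) cosine similarities
    labels = torch.arange(q.shape[0])            # the positive of sample i sits at column i;
    return F.cross_entropy(logits, labels)       # all other columns act as negatives

pairs = F.normalize(torch.randn(8, 2, 256), dim=2)   # same shape as nce_emb in the forward pass
print(info_nce(pairs))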