论文阅读: (CVPR2023 SDT )基于书写者风格和字符风格解耦的手写文字生成及源码对应

news2024/10/5 11:25:54

目录

        • 引言
        • SDT整体结构介绍
        • 代码与论文对应
          • 搭建模型部分
          • 数据集部分
        • 总结

引言

  • 许久不认真看论文了,这不赶紧捡起来。这也是自己看的第一篇用到Transformer结构的CV论文。
  • 之所以选择这篇文章来看,是考虑到之前做过手写字体生成的项目。这个工作可以用来合成一些手写体数据集,用来辅助手写体识别模型的训练。
  • 本篇文章将从论文与代码一一对应解析的方式来撰写,这样便于找到论文重点地方以及用代码如何实现的,更快地学到其中要点。这个项目的代码写得很好看,有着清晰的说明和整洁的代码规范。跟着仓库README就可以快速跑起整个项目。
  • 如果读者可以阅读英文的话,建议先去直接阅读英文论文,会更直接看到整个面貌。
  • PDF | Code

SDT整体结构介绍

  • 整体框架:
    SDT
  • 该工作提出从个体手写中解耦作家和字符级别的风格表示,以合成逼真的风格化在线手写字符。
  • 从上述框架图,可以看出整体可分为三大部分:Style encoderContent EncoderTransformer Decoder
    • Style Encoder: 主要学习给定的Style的Writer和Glyph两种风格表示,用于指导合成风格化的文字。包含两部分:CNN EncoderTransformer Encdoer
    • Content Encoder: 主要提取输入文字的特征,同样包含两部分:CNN EncoderTransformer Encdoer
  • ❓疑问:为什么要将CNN Encoder + Transformer Encoder结合使用呢?
    • 这个问题在论文中只说了Content Encoder使用两者的作用。CNN部分用来从content reference中学到compact feature map。Transformer encoder用来提取textual content表示。得益于Transformer强大的long-range 依赖的捕捉能力,Content Encdoer可以得到一个全局上下文的content feature。这里让我想到经典的CRNN结构,就是结合CNN + RNN两部分。
      在这里插入图片描述

代码与论文对应

  • 论文结构的最核心代码有两部分,一是搭建模型部分,二是数据集处理部分。
搭建模型部分
  • 该部分代码位于仓库中models/model.py,我这里只摘其中最关键部分添加注释来解释,其余细节请小伙伴自行挖掘。
class SDT_Generator(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=2, num_head_layers= 1,
                 wri_dec_layers=2, gly_dec_layers=2, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=True, return_intermediate_dec=True):
        super(SDT_Generator, self).__init__()
        
        ### style encoder with dual heads
        # Feat_Encoder:对应论文中的CNN Encoder,用来提取图像经过CNN之后的特征,backbone选的是ResNet18
        self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2]))
        
        # self.base_encoder:对应论文中Style Encoder的Transformer Encoderb部分
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        self.base_encoder = TransformerEncoder(encoder_layer, num_encoder_layers, None)
        
        writer_norm = nn.LayerNorm(d_model) if normalize_before else None
        glyph_norm = nn.LayerNorm(d_model) if normalize_before else None
 
        # writer_head和glyph_head分别对应论文中的Writer Head和Glyph Head
        # 从这里来看,这两个分支使用的是1层的Transformer Encoder结构
        self.writer_head = TransformerEncoder(encoder_layer, num_head_layers, writer_norm)
        self.glyph_head = TransformerEncoder(encoder_layer, num_head_layers, glyph_norm)

        ### content ecoder
        # content_encoder:对应论文中Content Encoder部分,
        # 从Content_TR源码来看,同样也是ResNet18作为CNN Encoder的backbone
        # Transformer Encoder部分用了3层的Transformer Encoder结构
        # 详情参见:https://github.com/dailenson/SDT/blob/1352b5cb779d47c5a8c87f6735e9dde94aa58f07/models/encoder.py#L8
        self.content_encoder = Content_TR(d_model, num_encoder_layers)

        ### decoder for receiving writer-wise and character-wise styles
        # 这里对应框图中Transformer Decoder中前后两个部分
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        wri_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.wri_decoder = TransformerDecoder(decoder_layer, wri_dec_layers, wri_decoder_norm,
                                              return_intermediate=return_intermediate_dec)
        gly_decoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.gly_decoder = TransformerDecoder(decoder_layer, gly_dec_layers, gly_decoder_norm,
                                          return_intermediate=return_intermediate_dec)
        
        ### two mlps that project style features into the space where nce_loss is applied
        self.pro_mlp_writer = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))
        self.pro_mlp_character = nn.Sequential(
            nn.Linear(512, 4096), nn.GELU(), nn.Linear(4096, 256))

        self.SeqtoEmb = SeqtoEmb(hid_dim=d_model)
        self.EmbtoSeq = EmbtoSeq(hid_dim=d_model)
  
        # 这里位置嵌入来源于论文Attention is all you need.
        self.add_position = PositionalEncoding(dropout=0.1, dim=d_model)        
        self._reset_parameters()

    # the shape of style_imgs is [B, 2*N, C, H, W] during training
    def forward(self, style_imgs, seq, char_img):
        batch_size, num_imgs, in_planes, h, w = style_imgs.shape

        # style_imgs: [B, 2*N, C:1, H, W] -> FEAT_ST_ENC: [4*N, B, C:512]
        style_imgs = style_imgs.view(-1, in_planes, h, w)  # [B*2N, C:1, H, W]
        
        # 经过CNN Encoder
        style_embe = self.Feat_Encoder(style_imgs)  # [B*2N, C:512, 2, 2]

        anchor_num = num_imgs//2
        style_embe = style_embe.view(batch_size*num_imgs, 512, -1).permute(2, 0, 1)  # [4, B*2N, C:512]
        FEAT_ST_ENC = self.add_position(style_embe)

        memory = self.base_encoder(FEAT_ST_ENC)  # [4, B*2N, C]
        writer_memory = self.writer_head(memory)
        glyph_memory = self.glyph_head(memory)

        writer_memory = rearrange(writer_memory, 't (b p n) c -> t (p b) n c',
                           b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]
        glyph_memory = rearrange(glyph_memory, 't (b p n) c -> t (p b) n c',
                           b=batch_size, p=2, n=anchor_num)  # [4, 2*B, N, C]

        # writer-nce
        memory_fea = rearrange(writer_memory, 't b n c ->(t n) b c')  # [4*N, 2*B, C]
        compact_fea = torch.mean(memory_fea, 0) # [2*B, C]
        
        # compact_fea:[2*B, C:512] ->  nce_emb: [B, 2, C:128]
        pro_emb = self.pro_mlp_writer(compact_fea)
        query_emb = pro_emb[:batch_size, :]
        pos_emb = pro_emb[batch_size:, :]
        nce_emb = torch.stack((query_emb, pos_emb), 1) # [B, 2, C]
        nce_emb = nn.functional.normalize(nce_emb, p=2, dim=2)

        # glyph-nce
        patch_emb = glyph_memory[:, :batch_size]  # [4, B, N, C]
        
        # sample the positive pair
        anc, positive = self.random_double_sampling(patch_emb)
        n_channels = anc.shape[-1]
        anc = anc.reshape(batch_size, -1, n_channels)
        anc_compact = torch.mean(anc, 1, keepdim=True) 
        anc_compact = self.pro_mlp_character(anc_compact) # [B, 1, C]
        positive = positive.reshape(batch_size, -1, n_channels)
        positive_compact = torch.mean(positive, 1, keepdim=True)
        positive_compact = self.pro_mlp_character(positive_compact) # [B, 1, C]

        nce_emb_patch = torch.cat((anc_compact, positive_compact), 1) # [B, 2, C]
        nce_emb_patch = nn.functional.normalize(nce_emb_patch, p=2, dim=2)

        # input the writer-wise & character-wise styles into the decoder
        writer_style = memory_fea[:, :batch_size, :]  # [4*N, B, C]
        glyph_style = glyph_memory[:, :batch_size]  # [4, B, N, C]
        glyph_style = rearrange(glyph_style, 't b n c -> (t n) b c') # [4*N, B, C]

        # QUERY: [char_emb, seq_emb]
        seq_emb = self.SeqtoEmb(seq).permute(1, 0, 2)
        T, N, C = seq_emb.shape

        # ========================Content Encoder部分=========================
        char_emb = self.content_encoder(char_img) # [4, N, 512]
        char_emb = torch.mean(char_emb, 0) #[N, 512]
        char_emb = repeat(char_emb, 'n c -> t n c', t = 1)
        tgt = torch.cat((char_emb, seq_emb), 0) # [1+T], put the content token as the first token
        tgt_mask = generate_square_subsequent_mask(sz=(T+1)).to(tgt)
        tgt = self.add_position(tgt)

		# 注意这里的执行顺序,Content Encoder输出 → Writer Decoder → Glyph Decoder → Embedding to Sequence
        # [wri_dec_layers, T, B, C]
        wri_hs = self.wri_decoder(tgt, writer_style, tgt_mask=tgt_mask)
        # [gly_dec_layers, T, B, C]
        hs = self.gly_decoder(wri_hs[-1], glyph_style, tgt_mask=tgt_mask)  

        h = hs.transpose(1, 2)[-1]  # B T C
        pred_sequence = self.EmbtoSeq(h)
        return pred_sequence, nce_emb, nce_emb_patch
数据集部分
  • CASIA_CHINESE
    data/CASIA_CHINESE
    ├── character_dict.pkl   # 词典
    ├── Chinese_content.pkl  # Content reference
    ├── test
    ├── test_style_samples
    ├── train
    ├── train_style_samples  # 1300个pkl,每个pkl中是同一个人写的各个字,长度不一致
    └── writer_dict.pkl
    
  • 训练集中单个数据格式解析
    {
        'coords': torch.Tensor(coords),                # 写这个字,每一划的点阵
        'character_id': torch.Tensor([character_id]),  # content字的索引
        'writer_id': torch.Tensor([writer_id]),        # 某个人的style
        'img_list': torch.Tensor(img_list),            # 随机选中style的img_list
        'char_img': torch.Tensor(char_img),            # content字的图像
        'img_label': torch.Tensor([label_id]),         # style中图像的label
    }
    
  • 推理时:
    • 输入:
      • 一种style15个字符的图像
      • 原始输入字符
    • 输出:属于该style的原始字符

总结

  1. 感觉对于Transformer的用法,比较粗暴。当然,Transformer本来就很粗暴
  2. 模型69M (position_layer2_dim512_iter138k_test_acc0.9443.pth) 比较容易接受,这和我之前以为的Transformer系列都很大,有些出入。这也算是纠正自己的盲目认知了
  3. 学到了einops库的用法,语义化操作,很有意思,值得学习。
  4. 第一次了解到NCE(Noise Contrastive Estimation)这个Loss,主要解决了class很多时,将其转换为二分类问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/699914.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浅析基于物联网技术的校园能耗智慧监控平台的设计及应用

摘 要:为打造低碳绿色校园,营造良好的学习环境,针对目前校园建筑能耗大,特别是空调节能困难等问题,特采用物联网技术构建校园建筑能耗智慧监控平台。通过设计空调监控子系统,搭建空调监控模型实现了空调等智…

在 Jetpack Compose 中使用 Snackbar

Jetpack Compose 是 Android 的现代 UI 工具库,提供了丰富的组件和功能来构建漂亮、交互丰富的用户界面。在本文中,我们将学习如何在 Jetpack Compose 中使用 Snackbar 组件来显示临时消息或操作反馈。 什么是 Snackbar? Snackbar 是一种用于…

基于Layui实现管理页面

基于Layui实现的后台管理页面(仅前端) 注:这是博主在帮朋友实现的一个简单的系统前端框架(无后端),跟大家分享出来,可以直接将对应菜单跟html文件链接起来,页面使用标签页方式存在&…

面试了一个前阿里P7,Java八股文与架构核心知识简直背得炉火纯青

前几天,跟个老朋友吃饭,他最近想跳槽去大厂,觉得压力很大,问我能不能分享些所谓的经验套路。 每次有这类请求,都觉得有些有趣,不知道你发现没有大家身边真的有很多人不知道怎么面试,也不知道怎…

赛效:如何将PDF文件免费转换成Word文档

1:在网页上打开wdashi,默认进入PDF转Word页面,点击中间的上传文件图标。 2:将PDF文件添加上去之后,点击右下角的“开始转换”。 3:稍等片刻转换成功后,点击绿色的“立即下载”按钮,将…

win10修改IP地址报错:出现一个意外情况,不能完成所有你在......

问题描述 在修改网卡适配器的时候出现一下报错:出现一个意外情况,不能完成你在设置中所要求的更改 问题原因 该问题是由于我之前卸载VMware导致注册表出现问题。 解决方法 解决办法为:修复注册表(下载一个CCleaner下载试用版即可(https…

2. 查询至少连续三天下单的用户

文章目录 题目需求思路一实现一思路二实现二 题目需求 查询订单信息表(order_info)中 最少连续3天 下单的用户id,期望结果如下: user_id101 订单信息表:order_info order_id(订单id)user_id(用户id)create_date (下单日期)total_amount(订…

MySQL 数据表修复方法

MySQL表检查与修复 — check/repair指令 目录 MySQL表检查与修复 --- check/repair指令1. 指令详解2. 操作方法:命令提示符(cmd指令)操作方法SQLyog 操作方法(推荐) 本文主要讲check table和repair table指令; 1. 指令详解 在检…

如何把图片转文字?图片转文字方法分享!​

如何把图片转文字呢?在我们日常的工作或者生活当中,总会遇到需要将图片中的文字提取出来整理出文档,比如同事领导给你发的文件,或者在自己看到了喜欢书的段落句子,想要摘抄下来,这些都是可以用图片转文字来…

C++ Vector容器使用方法详解

Vector概述 C 标准库向量类是序列容器的类模板。 向量以线性排列方式存储给定类型的元素,并允许快速随机访问任何元素。 向量是需要力求保证访问性能时的首选序列容器。vector是种容器,类似数组一样,但它的size可以动态改变。vector的元素在内…

【GESP】2023年06月图形化二级 -- 时间规划

文章目录 时间规划【题目描述】【输入描述】【输出描述】【参考答案】其他测试用例 时间规划 【题目描述】 默认小猫角色和白色背景,小明在为自己规划学习时间。现在他想知道两个时刻之间有多少分钟。你能通过编程帮他做到吗? 【输入描述】 新建变量“…

餐饮市场分析(上)

阅读原文 研究某一类餐饮产品的市场概况,并在不同地区和品牌之间进行对比 一、数据需求 使用美团搜索商品返回的数据。 首先进入美团首页,切换到对应城市,并搜索感兴趣的关键词。接下来尝试翻页获取更多数据,点击下一页时发现页…

跨越时空限制,酷暑天气用VR看房是一种什么体验?

近年来,全球厄尔尼诺现象越来越频繁,夏季温度不断创下新高,持续大范围的高温天气让人们对出门“望而生畏”。很多购房者也不愿意在如此酷暑期间,四处奔波看房,酷暑天气让带看房效率大大降低,更有新闻报道&a…

Linux:LAMP-phpmyadmin

LAMP环境 (1条消息) Linux:LAMP搭建(全源码包安装)_鲍海超-GNUBHCkalitarro的博客-CSDN博客 phpmyadminphpMyAdminhttps://www.phpmyadmin.net/ 传进Linux tar xfz phpMyAdmin-5.2.1-all-languages.tar.gz 这个是解出来的包 mv phpMyAdmin-5.2.1-all-languages /…

【NOSQL数据库】Redis数据库的配置与优化一

目录 一、关系型数据库与非关系型数据库1.1关系型数据库1.2非关系型数据库1.3关系型数据库与非关系型数据库的区别1.3.1数据存储方式不同1.3.2扩展方式不同1.3.3对事务性的支持不同 1.4非关系型数据库产生的背景1.5总结 二、Redis简介2.1Redis的优点2.2使用场景2.3哪些数据适合…

大二网页设计实训-豆瓣首页(html+css)

免费开源一个前端网页,豆瓣首页,可以用来当实训等

探索神奇的甲方需求:提出异常要求的背后逻辑

在IT行业,每个人都可能遇到“神奇的甲方”和他们提出的匪夷所思甚至无厘头的需求。虽然这些要求可能让人摸不着头脑,但背后通常隐藏着某种逻辑和需求。让我们来探索一下这些“无理需求”背后的心理和可能的应对策略。 首先,为什么会出现这些…

Maven安装与配置详解

安装JDK JDK1.8所有版本官网下载链接: https://www.oracle.com/java/technologies/javase/javase8-archive-downloads.html 所有JDK下载地址: https://www.oracle.com/java/technologies/oracle-java-archive-downloads.html 可参照我的另一篇博客 安…

Unity | HDRP高清渲染管线学习笔记:HDRP Custom Pass

目录 一、Custom Pass Volume组件介绍 1.Mode(模式) 2.Injection Point(注入点) 3.Priority 4.Fade Radius 5.custom passes 二、查看Custom Pass的渲染阶段 Custom Pass允许你执行以下操作(官方文档&#xff0…

Linux--在当前路径下创建目录/文件夹指令:mkdir

语法: mkdir [选项] 文件名 功能: 在当前目录下创建一个名为 “文件名”的目录 常用选项: -p, --parents 可以是一个路径名称。此时若路径中的某些目录尚不存在,加上此选项后,系统将自动建立好那些不存在的目录,即一次可以建立…