OCR经典神经网络(三)LayoutLM v2算法原理及其在发票数据集上的应用(NER及RE)
- LayoutLM系列模型是微软发布的、文档理解多模态基础模型领域最重要和有代表性的工作:
- LayoutLM v2:在一个单一的多模态框架中对文本(text)、布局(layout)和图像(image)之间的交互进行建模。
- LayoutXLM:LayoutXLM是 LayoutLMv2的多语言扩展版本。
- LayoutLM v3:借鉴了ViLT和BEIT,不需要经过预训练的视觉backbone,通过MLM、MIM和WPA进行预训练的多模态Transformer。在以视觉为中心的任务上(如文档图像分类和文档布局分析)和以文本为中心的任务上(表单理解、收据理解、文档问答)都表现很好。
- 今天,我们来了解下LayoutLM v2模型。
- 论文链接:https://arxiv.org/pdf/2012.14740
- 同样,百度开源的paddleocr中,在
关键信息抽取
中集成了此算法。 - paddleocr中集成的算法列表:https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/overview.md
1 LayoutLM v2算法原理
- LayoutLM v2是一种多模态Transformer模型,该模型在预训练阶段整合了文档文本、版式及视觉信息,实现了在一个框架内端到端地学习跨模态交互。同时,将一种空间感知的自注意力机制融入到了Transformer架构中。
- 除了掩码视觉语言模型(MVLM)预训练策略外,LayoutLM v2还新增了文本-图像对齐(TIA)和文本-图像匹配(TIM) 作为预训练策略,以强化不同模态间的对齐。
- LayoutLMv2不仅在传统的富视觉文档理解(VrDU)任务上取得了显著的性能提升并达到当时新的最优水平,还在文档图像的视觉问题回答(VQA)任务上实现了新突破,这证明了多模态预训练在富视觉文档理解领域的巨大潜力。
1.1 模型结构
-
模型结构如下图所示,可以看到LayoutLM v2接收文本、视觉及版式信息作为输入,以建立深度的跨模态交互。另外,将spatial-aware的自注意力机制整合到了transformer中。
-
这里,我们主要看下Embedding层:
-
文本嵌入
- 文本嵌入包含三种嵌入:词嵌入代表词本身,一维位置嵌入表示词的位置索引,而段落嵌入用于区分不同的文本段落。
t i = T o k E m b ( w i ) + P o s E m b 1 D ( i ) + S e g E m b ( s i ) t_i= TokEmb(w_i)+PosEmb1D(i)+SegEmb(s_i) ti=TokEmb(wi)+PosEmb1D(i)+SegEmb(si)
- 使用WordPiece对OCR文本序列进行分词,并将每个分词(token)分配给特定的段落。接着,在序列的开始添加[CLS]标记,在每个文本段落的末尾添加[SEP]标记。为了使最终序列的长度恰好等于最大序列长度L,在序列末尾额外添加[PAD]填充符。
-
视觉嵌入
- 给定一个文档页面图像I,将其调整大小至224×224像素后输入到视觉主干网络中。之后,输出的特征图通过平均池化到固定尺寸,宽度为W,高度为H。接下来,它被展平为长度为W×H(例如:7×7)的视觉嵌入序列,此序列被称为VisTokEmb(I)。然后
对每个视觉token嵌入应用线性投影层,以统一其维度与文本嵌入的维度
。 - 由于基于CNN的视觉主干无法捕获位置信息,因此添加一维位置嵌入。
- 对于段落嵌入,将所有视觉令牌附属于视觉段[C]。
v i = P r o j ( V i s T o k E m b ( I ) i + P o s E m b 1 D ( i ) + S e g E m b ( [ C ] ) v_i= Proj(VisTokEmb(I)_i+PosEmb1D(i)+SegEmb([C]) vi=Proj(VisTokEmb(I)i+PosEmb1D(i)+SegEmb([C])
- 给定一个文档页面图像I,将其调整大小至224×224像素后输入到视觉主干网络中。之后,输出的特征图通过平均池化到固定尺寸,宽度为W,高度为H。接下来,它被展平为长度为W×H(例如:7×7)的视觉嵌入序列,此序列被称为VisTokEmb(I)。然后
-
布局嵌入(2D Position Embeddings)
- 将所有的坐标标准化并离散化为[0, 1000]范围内的整数,并使用两个嵌入层分别嵌入x轴特征和y轴特征。
- 给定第i个( 0 ≤ i < W × H + L 0 ≤ i < W×H + L 0≤i<W×H+L)文本/视觉token的标准化边界框 b o x i = ( x m i n , x m a x , y m i n , y m a x , w i d t h , h e i g h t ) box_i = (x_{min}, x_{max}, y_{min}, y_{max}, width, height) boxi=(xmin,xmax,ymin,ymax,width,height),布局嵌入层将这六个边界框特征连接起来构建一个token级的2D位置嵌入,即布局嵌入。
-
- 由于卷积神经网络(CNNs)执行局部变换,因此视觉token嵌入可以一一映射回图像区域,既没有重叠也没有遗漏。
在计算边界框时,视觉token可以被视为均匀划分的网格。
- 对于特殊token [CLS]、[SEP]和[PAD],会附加一个空边界框boxPAD = (0, 0, 0, 0, 0, 0)。这意味着这些特殊符号在空间布局上不占用实际区域,但通过这样的空边界框嵌入,模型能够将它们整合到序列中的相应位置上,同时保持空间信息的一致性。
1.2 预训练目标及数据
1.2.1 MVLM
- 采用了掩码视觉-语言建模(Masked Visual-Language Modeling, MVLM)方法,以便模型在跨模态线索的帮助下更好地学习语言方面。
- 随机掩蔽一些文本token,并要求模型恢复这些被掩蔽的token。
- 与此同时,布局信息保持不变,这意味着模型了解每个被掩蔽token在页面上的位置。
- 为了避免视觉线索泄露,在将原始页面图像输入到视觉编码器之前,会先对应掩蔽掉与被掩蔽文本token相对应的图像区域。
1.2.2 TIA
- Text-Image Alignment(TIA):随机遮盖图像,然后识别文本对应图像是否被遮盖了。
- 为了帮助模型学习图像与边界框坐标的空間位置对应关系,提出了细粒度的跨模态对齐任务——文本-图像对齐(Text-Image Alignment, TIA)。
- 在TIA任务中,随机选择一些文本行,并在其文档图像上的对应图像区域进行遮盖, 称此操作为“遮盖”,以避免与MVLM中的“掩码”操作混淆。
- 预训练期间,在编码器输出之上构建了一个分类层。该层根据文本令牌是否被遮盖(即,[Covered]或[Not Covered])预测每个文本令牌的标签,并计算二元交叉熵损失。
- 考虑到输入图像的分辨率有限,且某些文档元素(如图表中的符号和线条)可能看起来像被遮盖的文本区域,寻找单词大小的遮盖图像区域的任务可能会存在噪声。因此,遮盖操作是在行级别进行的。
- 当MVLM和TIA同时执行时,MVLM中被掩蔽的令牌的TIA损失不予考虑。这防止了模型学习从[MASK]到[Covered]这种无用但直观的对应关系。
1.2.3 TIM
- Text-Image Matching(TIM):使用[CLS]来判断给出的图片特征与文本特征是否属于同一个页面。
- 为了帮助模型学习文档图像与文本内容之间的对应关系,采用了较为粗粒度的跨模态对齐任务,即文本-图像匹配(Text-Image Matching, TIM)。
- 将[CLS]位置的输出表示送入一个分类器,以预测图像和文本是否来自同一文档页面。正常的配对输入被视为正样本。
- 为了构建负样本,图像要么被另一文档的页面图像替换,要么被移除。
- 为防止模型通过寻找任务特定特征来作弊,对负面样本中的图像也执行相同的掩码和遮盖操作。在负面样本中,TIA的目标标签全部设置为[Covered]。
1.2.4 预训练数据
-
为了预训练和评估LayoutLMv2模型,作者从富含视觉元素的文档理解领域中选择了广泛的数据集。
-
使用IIT-CDIP作为预训练数据集。
1.3 模型微调
- 在文档级别分类任务RVL-CDIP中,使用[CLS]输出以及池化的视觉令牌表示作为全局特征。
- 对于提取式问答任务DocVQA及其他四个实体提取任务,在LayoutLMv2输出的文本部分上
构建特定任务的头部层
。在DocVQA论文中,实验结果显示,在SQuAD数据集上微调过的BERT模型比原始BERT模型表现更优。受此启发,增加了一个额外的设置:首先在问题生成(Question Generation, QG)数据集上微调LayoutLMv2,随后再在DocVQA数据集上微调。这个QG数据集包含近百万对由训练于SQuAD数据集的生成模型产生的问题-答案对。
1.4 LayoutXLM模型结构
- LayoutXLM是 LayoutLMv2的多语言扩展版本。为了准确评估LayoutXLM,论文中还引入了一个多语言表单理解基准数据集,名为XFUND,该数据集包含了7种语言(中文、日语、西班牙语、法语、意大利语、德语、葡萄牙语)的表单理解样本,并为每种语言的手工标注了键值对。
- 论文链接:https://arxiv.org/pdf/2104.08836
- LayoutXLM预训练策略,同LayoutLMv2
- 该框架如下图所示:
- 模型接收来自三种不同模态的信息,即文本、布局和图像,分别使用文本嵌入、布局嵌入和视觉嵌入层进行编码。文本和图像嵌入被连接在一起,然后加上布局嵌入以获得输入嵌入。
- 输入嵌入通过带有空间感知自注意力机制的多模态Transformer进行编码。
- 最后,输出的上下文表示可以用于后续的任务特定层。
1.5 VI-LayoutXLM
-
百度在PP-StructureV2中,针对 LayoutXLM 进行改进,得到了VI-LayoutXLM。
-
论文链接:https://arxiv.org/pdf/2210.05391
-
模型部分改进如下:
- LayoutLMv2 以及 LayoutXLM 中引入视觉骨干网络,用于提取视觉特征,并与后续的 text embedding 进行联合,作为多模态的输入 embedding。但是该模块为基于 ResNet_x101_64x4d 的特征提取网络,特征抽取阶段耗时严重。
- 因此,
移除视觉特征提取模块
,同时仍然保留文本、位置以及布局等信息,最终发现针对 LayoutXLM 进行改进,下游 SER 任务精度无损,针对 LayoutLMv2 进行改进,下游 SER 任务精度仅降低2.1%,而模型大小减小了约340M。
2 VI-LayoutXLM在发票数据集上的应用
-
关键信息抽取 (Key Information Extraction, KIE)
指的是是从文本或者图像中,抽取出关键的信息。- 针对文档图像的关键信息抽取任务作为OCR的下游任务,存在非常多的实际应用场景,如表单识别、车票信息抽取、身份证信息抽取等。
- 文档图像中的KIE一般包含2个子任务,示意图如下图所示。
SER: 语义实体识别 (Semantic Entity Recognition)
,对每一个检测到的文本进行分类,如将其分为姓名,身份证。如下图中的黑色框和红色框。RE: 关系抽取 (Relation Extraction)
,对每一个检测到的文本进行分类,如将其分为问题 (key) 和答案 (value) 。然后对每一个问题找到对应的答案,相当于完成key-value的匹配过程。如下图中的红色框和黑色框分别代表问题和答案,黄色线代表问题和答案之间的对应关系。
-
除了
视觉特征无关的多模态预训练模型结构
,paddleocr中在KIE任务上,还有两个主要的优化策略:TB-YX:考虑阅读顺序的文本行排序逻辑
- 文本阅读顺序对于信息抽取与文本理解等任务至关重要,传统多模态模型中,没有考虑不同 OCR 工具可能产生的不正确阅读顺序,而模型输入中包含位置编码,阅读顺序会直接影响预测结果
- 在预处理中,对文本行按照从上到下,从左到右(YX)的顺序进行排序,为防止文本行位置轻微干扰带来的排序结果不稳定问题,在排序的过程中,引入位置偏移阈值 Th,对于 Y 方向距离小于 Th 的2个文本内容,使用 X 方向的位置从左到右进行排序。
UDML:联合互学习知识蒸馏策略
- UDML(Unified-Deep Mutual Learning)联合互学习是 PP-OCRv2 与 PP-OCRv3 中采用的对于文本识别非常有效的提升模型效果的策略。
- 在训练时,引入2个完全相同的模型进行互学习,计算2个模型之间的互蒸馏损失函数(DML loss),同时对 transformer 中间层的输出结果计算距离损失函数(L2 loss)。
- 使用该策略,最终 XFUND 数据集上,SER 任务 F1 指标提升0.6%,RE 任务 F1 指标提升5.01%。
-
KIE常用思路有如下两种:
-
一种是SER:
- 直接使用SER,获取关键信息的类别;常用于关键信息类别固定的场景。
- 以身份证场景为例, 关键信息一般包含
姓名
、性别
、民族
等,我们直接将对应的字段标注为特定的类别即可,如下图所示:
-
注意:
- 标注过程中,对于无关于KIE关键信息的文本内容,均需要将其标注为
other
类别,相当于背景信息。如在身份证场景中,如果我们不关注性别信息,那么可以将“性别”与“男”这2个字段的类别均标注为other
。 - 标注过程中,需要以文本行为单位进行标注,无需标注单个字符的位置信息。
数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。
- 标注过程中,对于无关于KIE关键信息的文本内容,均需要将其标注为
-
一种是SER+RE:
- 联合使用SER+RE,先利用SER找到key和value,然后再利用RE进行匹配;常用于关系类别不固定的场景。
- 以身份证场景为例, 关键信息一般包含
姓名
、性别
、民族
等关键信息。在SER阶段,我们需要识别所有的question (key) 与answer (value) 。每个字段的类别信息(label
字段)可以是question、answer或者other(与待抽取的关键信息无关的字段)
- 在RE阶段,需要标注每个字段的的id与连接信息,如下图所示:
- 标注过程中,如果value是多个字符,那么linking中可以新增一个key-value对,如
[[0, 1], [0, 2]]
- 数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。
- 标注过程中,如果value是多个字符,那么linking中可以新增一个key-value对,如
-
我们参考案例:https://aistudio.baidu.com/projectdetail/4823162(
项目里提供了发票数据集
),来对VI-LayoutXLM模型有更深的认识。
-
2.1 语义实体识别 (SER)
2.1.1 模型构建
-
我这里不用命令行执行,在
paddleocr\tests
目录下创建一个py文件执行训练过程 -
我们复制一份
paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.yml
文件到paddleocr\tests\configs进行修改(参考上面项目链接进行修改),发票数据集在上面项目中已提供,模型部分的配置文件如下:Architecture: model_type: &model_type "kie" name: DistillationModel algorithm: Distillation Models: Teacher: pretrained: freeze_params: false return_all_feats: true model_type: *model_type algorithm: &algorithm "LayoutXLM" Transform: Backbone: name: LayoutXLMForSer pretrained: True # 会利用paddle-nlp加载预训练模型 # one of base or vi mode: vi checkpoints: num_classes: &num_classes 5 # 采用BIO的标注,训练需要修改 Student: pretrained: freeze_params: false return_all_feats: true model_type: *model_type algorithm: *algorithm Transform: Backbone: name: LayoutXLMForSer pretrained: True # one of base or vi mode: vi checkpoints: num_classes: *num_classes
-
通过下面的py文件,我们就可以愉快的查看源码了。
def train_kie_token_ser_demo():
from tools.train import program, set_seed, main
# 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.yml
config, device, logger, vdl_writer = program.preprocess(is_train=True)
###############修改配置(也可在yml文件中修改)##################
# 评估频率
config["Global"]["eval_batch_step"] = [0, 200]
# log的打印频率
config["Global"]["print_batch_step"] = 50
# 训练的epochs
config["Global"]["epoch_num"] = 1
# 随机种子
seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
set_seed(seed)
###############模型训练##################
main(config, device, logger, vdl_writer, seed)
def train_kie_token_re_demo():
from tools.train import program, set_seed, main
# 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\re_vi_layoutxlm_xfund_zh_udml.yml
config, device, logger, vdl_writer = program.preprocess(is_train=True)
###############修改配置(也可在yml文件中修改)##################
# 评估频率
config["Global"]["eval_batch_step"] = [0, 200]
# log的打印频率
config["Global"]["print_batch_step"] = 50
# 训练的epochs
config["Global"]["epoch_num"] = 1
# 随机种子
seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
set_seed(seed)
###############模型训练##################
main(config, device, logger, vdl_writer, seed)
if __name__ == '__main__':
train_kie_token_ser_demo()
# train_kie_token_re_demo()
LayoutXLMForTokenClassification
- 首先,利用LayoutXLMModel提取特征(文本、布局信息)
- 然后,利用文本部分的特征进行BIO多分类
# paddleocr.ppocr.modeling.backbones.vqa_layoutlm.py
class LayoutXLMForTokenClassification(LayoutXLMPretrainedModel):
def __init__(self, config: LayoutXLMConfig):
super(LayoutXLMForTokenClassification, self).__init__(config)
self.num_classes = config.num_labels
self.layoutxlm = LayoutXLMModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.num_classes)
......
def forward(
self,
input_ids=None,
bbox=None,
image=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
labels=None,
):
# 1、经过12层的Transformer Block Encoder
outputs = self.layoutxlm(
input_ids=input_ids,
bbox=bbox,
image=image,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
)
seq_length = input_ids.shape[1]
# sequence out and image out
# 2、进行BIO多分类
# sequence_output: (bs, 561, 768) -> (bs, 512, 768) -> (bs, 512, 5)
sequence_output = outputs[0][:, :seq_length]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
hidden_states = {
f"hidden_states_{idx}": outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)
}
if self.training:
outputs = (logits, hidden_states)
else:
outputs = (logits,)
......
return outputs
LayoutXLMModel
这里我们主要看下LayoutXLMModel模型中,文本的embedding和视觉部分的embedding。
-
文本的embedding:
-
word_embeddings:对tokenizer后的input_ids进行word_embeddings,shape变化:(bs, 512) -> (bs, 512, 768)
-
position_embeddings(1D position embedding):对文本部分的position_ids进行embeding,shape变化:(bs, 512) -> (bs, 512, 768)。这里,文本和视觉的position_embeddings是共享的。
-
spatial_position_embeddings:这里shape变化为(bs, 512, 4) -> (bs, 512, 768),是将每一个bbox信息的(x_min, y_min, x_max, y_max, h, w)编码,然后concat得到,代码如下所示。注意:如果一个bbox内的文字,被切分为多个token,那么这些token的bbox信息是一致的。
# paddlenlp.transformers.layoutxlm.modeling.py def _cal_spatial_position_embeddings(self, bbox): try: # (bs, embdedding_dim) -> (bs, embdedding_dim, 128) left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) # (bs, embdedding_dim) -> (bs, embdedding_dim, 128) upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) # (bs, embdedding_dim) -> (bs, embdedding_dim, 128) right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) # (bs, embdedding_dim) -> (bs, embdedding_dim, 128) lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) except IndexError as e: raise IndexError("The :obj:`bbox`coordinate values should be within 0-1000 range.") from e # (bs, embdedding_dim) -> (bs, embdedding_dim, 128) h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1]) # (bs, embdedding_dim) -> (bs, embdedding_dim, 128) w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0]) # [x_min, y_min, x_max, y_max, h, w] concat -> (bs, embdedding_dim, 128*6) spatial_position_embeddings = paddle.concat( [ left_position_embeddings, upper_position_embeddings, right_position_embeddings, lower_position_embeddings, h_position_embeddings, w_position_embeddings, ], axis=-1, ) return spatial_position_embeddings
-
token_type_embeddings:这里的token_type_ids全为0,shape变化为(bs, 512) -> (bs, 512, 768)
-
-
视觉部分的embedding:
- position_embeddings(1D position embedding):
shape变化为(bs, 49) -> (bs, 49, 768)
。视觉部分的position ids为:[0, 1, 2, …, 48] -> (bs, 49)。这里虽然去除了视觉提取,但是position ids按照图像224×224经过降采样32倍后的feature map:7×7进行生成。这里,文本和视觉的position_embeddings是共享的; - spatial_position_embeddings:视觉部分布局信息,即bbox的生成的核心逻辑是:7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token。
shape变化为(bs, 49, 4) -> (bs, 49, 768)
; - visual_segment_embedding
- position_embeddings(1D position embedding):
-
最终,将文本的embedding和视觉部分的embedding送入到12层的Transformer Encoder Block提取特征。
# paddlenlp.transformers.layoutxlm.modeling.py
@register_base_model
class LayoutXLMModel(LayoutXLMPretrainedModel):
def __init__(self, config: LayoutXLMConfig):
super(LayoutXLMModel, self).__init__(config)
self.config = config
self.use_visual_backbone = config.use_visual_backbone
self.has_visual_segment_embedding = config.has_visual_segment_embedding
self.embeddings = LayoutXLMEmbeddings(config)
if self.use_visual_backbone is True:
self.visual = VisualBackbone(config)
self.visual.stop_gradient = True
self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)
if self.has_visual_segment_embedding:
self.visual_segment_embedding = self.create_parameter(
shape=[
config.hidden_size,
],
dtype=paddle.float32,
)
self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)
self.encoder = LayoutXLMEncoder(config)
self.pooler = LayoutXLMPooler(config)
def _calc_visual_bbox(self, image_feature_pool_shape, bbox, visual_shape):
"""
视觉部分布局信息,即bbox的生成:
- image_feature_pool_shape:(7, 7, 256)
- 文字token的bbox信息:(bs, 512, 4)
- visual_shape:[bs, 49]
"""
# 首先,生成一个序列[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000]
# 然后,离散化为[0, 1000],即[0, 142, 285, 428, 571, 714, 857, 1000]
visual_bbox_x = (
paddle.arange(
0,
1000 * (image_feature_pool_shape[1] + 1),
1000,
dtype=bbox.dtype,
)
// image_feature_pool_shape[1]
)
visual_bbox_y = (
paddle.arange(
0,
1000 * (image_feature_pool_shape[0] + 1),
1000,
dtype=bbox.dtype,
)
// image_feature_pool_shape[0]
)
expand_shape = image_feature_pool_shape[0:2] # (7, 7)
# 7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token
# visual_bbox shape = (7×7, 4)
visual_bbox = paddle.stack(
[
visual_bbox_x[:-1].expand(expand_shape),
visual_bbox_y[:-1].expand(expand_shape[::-1]).transpose([1, 0]),
visual_bbox_x[1:].expand(expand_shape),
visual_bbox_y[1:].expand(expand_shape[::-1]).transpose([1, 0]),
],
axis=-1,
).reshape([expand_shape[0] * expand_shape[1], paddle.shape(bbox)[-1]])
# 扩展到bs个样本, (7×7, 4) -> (bs, 7×7, 4)
visual_bbox = visual_bbox.expand([visual_shape[0], visual_bbox.shape[0], visual_bbox.shape[1]])
return visual_bbox
def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids):
"""
文本部分进行embeddings:
word_embeddings
+ position_embeddings(文本和视觉的position_embeddings是共享的)
+ spatial_position_embeddings
+ token_type_embeddings
"""
# (bs, 512) -> (bs, 512, 768)
words_embeddings = self.embeddings.word_embeddings(input_ids)
# (bs, 512) -> (bs, 512, 768)
position_embeddings = self.embeddings.position_embeddings(position_ids)
# (bs, 512, 4) -> (bs, 512, 768)
spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)
# (bs, 512) -> (bs, 512, 768)
token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
# 4种embedding相加
embeddings = words_embeddings + position_embeddings + spatial_position_embeddings + token_type_embeddings
# LayerNorm + dropout
embeddings = self.embeddings.LayerNorm(embeddings)
embeddings = self.embeddings.dropout(embeddings)
return embeddings
def _calc_img_embeddings(self, image, bbox, position_ids):
"""
视觉部分进行embedding:
position_embeddings(文本和视觉的position_embeddings是共享的)
+ spatial_position_embeddings
+ visual_segment_embedding
"""
use_image_info = self.use_visual_backbone and image is not None
# (bs, 49) -> (bs, 49, 768)
position_embeddings = self.embeddings.position_embeddings(position_ids)
# (bs, 49, 4) -> (bs, 49, 768)
spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)
if use_image_info is True:
visual_embeddings = self.visual_proj(self.visual(image.astype(paddle.float32)))
embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
else:
# embedding相加
embeddings = position_embeddings + spatial_position_embeddings
if self.has_visual_segment_embedding:
# self.visual_segment_embedding shape = (768)
embeddings += self.visual_segment_embedding
# visual_LayerNorm + visual_dropout
embeddings = self.visual_LayerNorm(embeddings)
embeddings = self.visual_dropout(embeddings)
return embeddings
def forward(
self,
input_ids=None,
bbox=None,
image=None,
token_type_ids=None,
position_ids=None,
attention_mask=None,
head_mask=None,
output_hidden_states=False,
output_attentions=False,
):
input_shape = paddle.shape(input_ids)
visual_shape = list(input_shape)
visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
# 视觉部分的bbox的生成
# 视觉token被视为均匀划分的网格
# 生成的bbox信息:feature_map(7×7)网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token
visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, visual_shape)
# 1、2D position embedding(文本部分bbox+视觉部分bbox)
# (bs, 512, 4) + (bs, 49, 4) -> (bs, 561, 4)
final_bbox = paddle.concat([bbox, visual_bbox], axis=1)
if attention_mask is None:
attention_mask = paddle.ones(input_shape)
if self.use_visual_backbone is True:
# 使用视觉部分的backbone
visual_attention_mask = paddle.ones(visual_shape)
else:
# 移除视觉特征提取模块,mask全设置为0
visual_attention_mask = paddle.zeros(visual_shape)
attention_mask = attention_mask.astype(visual_attention_mask.dtype)
# concat后attention_mask:(bs, 512) + (bs, 49) -> (bs, 561)
final_attention_mask = paddle.concat([attention_mask, visual_attention_mask], axis=1)
if token_type_ids is None:
token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64)
# 2、1D position embedding(文本部分+视觉部分) (bs, 512) + (bs, 49) -> (bs, 561)
if position_ids is None:
# 文本部分的position embedding
seq_length = input_shape[1]
position_ids = self.embeddings.position_ids[:, :seq_length]
position_ids = position_ids.expand(input_shape)
# 视觉部分的position embedding
# [0, 1, 2, ..., 48] -> (bs, 49)
visual_position_ids = paddle.arange(0, visual_shape[1]).expand([input_shape[0], visual_shape[1]])
final_position_ids = paddle.concat([position_ids, visual_position_ids], axis=1)
if bbox is None:
bbox = paddle.zeros(input_shape + [4])
# 3、 text embedding & visual (bs, 512, 768) + (bs, 49, 768) -> (bs, 561, 768)
# 文本部分进行embdedding (bs, 512) -> (bs, 512, 768)
text_layout_emb = self._calc_text_embeddings(
input_ids=input_ids,
bbox=bbox,
token_type_ids=token_type_ids,
position_ids=position_ids,
)
# 视觉部分进行embedding(注意此时没有image,仅有视觉的bbox以及position_ids)
visual_emb = self._calc_img_embeddings(
image=image,
bbox=visual_bbox,
position_ids=visual_position_ids,
)
final_emb = paddle.concat([text_layout_emb, visual_emb], axis=1)
# (bs, 561) -> (bs, 1, 1, 561)
extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
else:
head_mask = [None] * self.config.num_hidden_layers
# 经过Transformer Encoder Block(12层)
encoder_outputs = self.encoder(
final_emb, # 文本&视觉部分的embedding , shape=(bs, 561, 768)
extended_attention_mask, # attention_mask , shape=(bs, 1, 1, 561)
bbox=final_bbox, # 2D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561, 4)
position_ids=final_position_ids, # 1D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561)
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# sequence_output shape = (bs, 561, 768)
sequence_output = encoder_outputs[0]
# pooled_output shape = (bs, 768)
pooled_output = self.pooler(sequence_output)
return sequence_output, pooled_output, encoder_outputs[1]
2.1.2 损失计算
- 由于使用了
UDML:联合互学习知识蒸馏策略
,损失计算的配置如下:
Loss:
name: CombinedLoss # ppocr.losses.combined_loss.CombinedLoss
loss_config_list:
- DistillationVQASerTokenLayoutLMLoss: # GT loss ppocr.losses.distillation_loss.DistillationVQASerTokenLayoutLMLoss
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: backbone_out
num_classes: *num_classes
- DistillationSERDMLLoss: # DML loss ppocr.losses.distillation_loss.DistillationSERDMLLoss
weight: 1.0
act: "softmax"
use_log: true
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
- DistillationVQADistanceLoss: # S5 loss ppocr.losses.distillation_loss.DistillationVQADistanceLoss
weight: 0.5
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: hidden_states_5
name: "loss_5"
- DistillationVQADistanceLoss: # S8 loss ppocr.losses.distillation_loss.DistillationVQADistanceLoss
weight: 0.5
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: hidden_states_8
name: "loss_8"
- 如下所示,在DistillationModel中,Teacher和Student模型分别进行前向过程
# paddleocr.ppocr.modeling.architectures.distillation_model.py
class DistillationModel(nn.Layer):
def __init__(self, config):
"""
the module for OCR distillation.
args:
config (dict): the super parameters for module.
"""
super().__init__()
self.model_list = []
self.model_name_list = []
for key in config["Models"]:
model_config = config["Models"][key]
freeze_params = False
pretrained = None
if "freeze_params" in model_config:
freeze_params = model_config.pop("freeze_params")
if "pretrained" in model_config:
pretrained = model_config.pop("pretrained")
model = BaseModel(model_config)
if pretrained is not None:
load_pretrained_params(model, pretrained)
if freeze_params:
for param in model.parameters():
param.trainable = False
self.model_list.append(self.add_sublayer(key, model))
self.model_name_list.append(key)
def forward(self, x, data=None):
result_dict = dict()
# 执行所有模型的前向过程, 例如:Teacher和Student模型
for idx, model_name in enumerate(self.model_name_list):
result_dict[model_name] = self.model_list[idx](x, data)
return result_dict
- 在CombinedLoss中遍历配置的损失函数,分别计算损失,最后相加最为总损失
# paddleocr.ppocr.losses.combined_loss.py
class CombinedLoss(nn.Layer):
"""
CombinedLoss:
a combionation of loss function
"""
def __init__(self, loss_config_list=None):
super().__init__()
self.loss_func = []
self.loss_weight = []
assert isinstance(loss_config_list, list), "operator config should be a list"
......
def forward(self, input, batch, **kargs):
# input包含Teacher模型以及Student模型的输出结果
# batch是批次数据,里面包含label
loss_dict = {}
loss_all = 0.0
# 遍历配置的所有的损失函数,计算损失
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx]
loss = {key: loss[key] * weight for key in loss}
if "loss" in loss:
loss_all += loss["loss"]
else:
loss_all += paddle.add_n(list(loss.values()))
loss_dict.update(loss)
loss_dict["loss"] = loss_all
return loss_dict
-
我们看下具体配置的损失函数:
-
DistillationVQASerTokenLayoutLMLoss的实质就是每个模型分别计算NER任务的CrossEntropyLoss,即GT loss:
class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss): def __init__(self, num_classes, model_name_list=[], key=None, name="loss_ser"): super().__init__(num_classes=num_classes) self.model_name_list = model_name_list self.key = key self.name = name def forward(self, predicts, batch): loss_dict = dict() # 遍历Teacher模型、Student模型 for idx, model_name in enumerate(self.model_name_list): # 先从predicts取出相关模型的预测结果 out = predicts[model_name] # 然后,从out中取出key(即配置文件中配置的backbone_out)的值 if self.key is not None: out = out[self.key] # 调用父类,计算损失 loss = super().forward(out, batch) loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"] return loss_dict # DistillationVQASerTokenLayoutLMLoss的父类 class VQASerTokenLayoutLMLoss(nn.Layer): def __init__(self, num_classes, key=None): super().__init__() self.loss_class = nn.CrossEntropyLoss() self.num_classes = num_classes self.ignore_index = self.loss_class.ignore_index self.key = key def forward(self, predicts, batch): if isinstance(predicts, dict) and self.key is not None: predicts = predicts[self.key] labels = batch[5] # (bs, 512) attention_mask = batch[2] # (bs, 512) if attention_mask is not None: active_loss = ( attention_mask.reshape( [ -1, ] ) == 1 ) # active_output_shape = (bs, 512, 5) -> (bs*512, 5) active_output = predicts.reshape([-1, self.num_classes])[active_loss] # active_label_shape = bs*512 active_label = labels.reshape( [ -1, ] )[active_loss] # 交叉熵损失函数 loss = self.loss_class(active_output, active_label) else: loss = self.loss_class( predicts.reshape([-1, self.num_classes]), labels.reshape( [ -1, ] ), ) return {"loss": loss}
-
DistillationSERDMLLoss实质是计算Techaer和Student模型之间的互蒸馏损失函数,即KL散度。
class DistillationSERDMLLoss(DMLLoss): """ """ def __init__( self, act="softmax", use_log=True, num_classes=7, model_name_pairs=[], key=None, name="loss_dml_ser", ): super().__init__(act=act, use_log=use_log) assert isinstance(model_name_pairs, list) self.key = key self.name = name self.num_classes = num_classes self.model_name_pairs = model_name_pairs def forward(self, predicts, batch): loss_dict = dict() # 遍历Teacher模型、Student模型 for idx, pair in enumerate(self.model_name_pairs): # 取出Teacher模型以及Student模型中的结果 out1 = predicts[pair[0]] out2 = predicts[pair[1]] if self.key is not None: # 取出backbone_out out1 = out1[self.key] out2 = out2[self.key] out1 = out1.reshape([-1, out1.shape[-1]]) out2 = out2.reshape([-1, out2.shape[-1]]) attention_mask = batch[2] if attention_mask is not None: active_output = ( attention_mask.reshape( [ -1, ] ) == 1 ) out1 = out1[active_output] out2 = out2[active_output] # 调用父类的方法 loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1, out2) return loss_dict # DistillationSERDMLLoss的父类 class DMLLoss(nn.Layer): """ DMLLoss """ def __init__(self, act=None, use_log=False): super().__init__() if act is not None: assert act in ["softmax", "sigmoid"] if act == "softmax": self.act = nn.Softmax(axis=-1) elif act == "sigmoid": self.act = nn.Sigmoid() else: self.act = None self.use_log = use_log self.jskl_loss = KLJSLoss(mode="kl") def _kldiv(self, x, target): """ 计算两个概率分布之间的KL散度: KL散度的公式是 KL(P||Q) = ΣP(x) * log(P(x)/Q(x)),这里将其重写为ΣP(x)*(log(P(x)) - log(Q(x))) 即target * (paddle.log(target + eps) - x) """ eps = 1.0e-10 loss = target * (paddle.log(target + eps) - x) # batch mean loss loss = paddle.sum(loss) / loss.shape[0] return loss def forward(self, out1, out2): if self.act is not None: out1 = self.act(out1) + 1e-10 out2 = self.act(out2) + 1e-10 if self.use_log: # 计算KL散度 # for recognition distillation, log is needed for feature map log_out1 = paddle.log(out1) log_out2 = paddle.log(out2) loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0 else: # for detection distillation log is not needed loss = self.jskl_loss(out1, out2) return loss
-
DistillationVQADistanceLoss,本质是对 transformer 中间层的输出结果计算距离损失函数(L2 loss)
# DistillationVQADistanceLoss的父类 class DistanceLoss(nn.Layer): """ DistanceLoss: mode: loss mode """ def __init__(self, mode="l2", **kargs): super().__init__() assert mode in ["l1", "l2", "smooth_l1"] if mode == "l1": self.loss_func = nn.L1Loss(**kargs) elif mode == "l2": self.loss_func = nn.MSELoss(**kargs) elif mode == "smooth_l1": self.loss_func = nn.SmoothL1Loss(**kargs) def forward(self, x, y): return self.loss_func(x, y)
其他部分,诸如数据集的加载、构建优化器、创建评估函数、加载预训练模型、模型训练等,大家可以查看源码,不再赘述。
-
2.2 关系抽取(RE)
- 我们这里,看下模型的构建部分代码,其他代码,大家可以查看源码,不再赘述。
# paddlenlp.transformers.layoutxlm.modeling.py
class LayoutXLMForRelationExtraction(LayoutXLMPretrainedModel):
def __init__(self, config: LayoutXLMConfig):
super(LayoutXLMForRelationExtraction, self).__init__(config)
self.layoutxlm = LayoutXLMModel(config)
self.extractor = REDecoder(config.hidden_size, config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
......
def forward(
self,
input_ids,
bbox,
image=None,
attention_mask=None,
entities=None,
relations=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
labels=None,
):
# 1、经过12层的Transformer Block Encoder
outputs = self.layoutxlm(
input_ids=input_ids, # (bs, 512)
bbox=bbox, # (bs, 512, 4)
image=image, # None
attention_mask=attention_mask, # (bs, 512)
token_type_ids=token_type_ids, # (bs. 512)
position_ids=position_ids, # None
head_mask=head_mask, # None
)
seq_length = input_ids.shape[1]
# 最后一层输出
# sequence_output_shape = (bs, 512, 768)
sequence_output = outputs[0][:, :seq_length]
sequence_output = self.dropout(sequence_output)
# 2、计算loss和预测关系
loss, pred_relations = self.extractor(sequence_output, entities, relations)
hidden_states = [outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)]
hidden_states = paddle.stack(hidden_states, axis=1)
# 3、返回结果
res = dict(loss=loss, pred_relations=pred_relations, hidden_states=hidden_states)
return res
-
主要代码在REDecoder中
- 首先,构建构建关系对的正负样本
- 然后,获取关系头(question)、关系尾(answer)对应的特征信息
- 获取关系头(即question)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系头(question)经过Embedding后的结果(shape=(100, 768))进行concat
- 获取关系尾(即answer)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系尾(answer)经过Embedding后的结果(shape=(100, 768))进行concat
- 利用提取到的head_repr、tail_repr特征信息进行关系分类
- 最后,利用预测结果,计算交叉熵损失等
- 下面,给出一个relations和entities示例,方便理解。
class REDecoder(nn.Layer):
def __init__(self, hidden_size=768, hidden_dropout_prob=0.1):
super(REDecoder, self).__init__()
self.entity_emb = nn.Embedding(3, hidden_size)
# 100代表:100个关系对
# (100, 1536) -> (100, 768) -> (100, 384)
projection = nn.Sequential(
nn.Linear(hidden_size * 2, hidden_size),
nn.ReLU(),
nn.Dropout(hidden_dropout_prob),
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Dropout(hidden_dropout_prob),
)
self.ffnn_head = copy.deepcopy(projection)
self.ffnn_tail = copy.deepcopy(projection)
# (100, 384) -> (100, 2)
self.rel_classifier = BiaffineAttention(hidden_size // 2, 2)
self.loss_fct = CrossEntropyLoss()
def build_relation(self, relations, entities):
"""
relations_shape = (bs, 262145, 2)
entities_shape = (bs, 513, 3)
注:
relations第1个数组代表实际长度,例如:[10, 10],代表:关系对(QUESTION-ANSWER)只有10个,其他为填充
entities第1个数组代表实际长度,例如:[20, 20, 20],代表:实例(QUESTION或ANSWER)只有20个,其他为填充
"""
batch_size, max_seq_len = paddle.shape(entities)[:2]
# new_relations_shape = (bs, 513*513, 3), 初始化为-1
new_relations = paddle.full(
shape=[batch_size, max_seq_len * max_seq_len, 3], fill_value=-1, dtype=relations.dtype
)
for b in range(batch_size):
if entities[b, 0, 0] <= 2:
entitie_new = paddle.full(shape=[512, 3], fill_value=-1, dtype=entities.dtype)
entitie_new[0, :] = 2
entitie_new[1:3, 0] = 0 # start
entitie_new[1:3, 1] = 1 # end
entitie_new[1:3, 2] = 0 # label
entities[b] = entitie_new
# 实体label_shape为: [2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]
# all_possible_relations1为: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19] QUESTION
# all_possible_relations2为: [0 , 3 , 5 , 7 , 9 , 11, 15, 16, 17, 18] ANSWER
entitie_label = entities[b, 1 : entities[b, 0, 2] + 1, 2]
all_possible_relations1 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)
all_possible_relations1 = all_possible_relations1[entitie_label == 1]
all_possible_relations2 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)
all_possible_relations2 = all_possible_relations2[entitie_label == 2]
# 所有可能的关系:all_possible_relations_shape:(100, 2)
# [
# [1, 0], [1, 3], ... , [1, 18],
# [2, 0], [2, 3], ... , [2, 18],
# ......
# [19, 0], [19, 3], ... , [19, 18]
# ]
all_possible_relations = paddle.stack(
paddle.meshgrid(all_possible_relations1, all_possible_relations2), axis=2
).reshape([-1, 2])
if len(all_possible_relations) == 0:
all_possible_relations = paddle.full_like(all_possible_relations, fill_value=-1, dtype=entities.dtype)
all_possible_relations[0, 0] = 0
all_possible_relations[0, 1] = 1
# relation_head: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19]
# relation_tail: [0 , 3 , 5 , 7 , 9 , 11, 17, 15, 16, 18]
relation_head = relations[b, 1 : relations[b, 0, 0] + 1, 0]
relation_tail = relations[b, 1 : relations[b, 0, 1] + 1, 1]
# positive_relations_shape: (10, 2)
positive_relations = paddle.stack([relation_head, relation_tail], axis=1)
# (100, 2) -> (100, 10, 2)
all_possible_relations_repeat = all_possible_relations.unsqueeze(axis=1).tile(
[1, len(positive_relations), 1]
)
# (100, 2) -> (100, 10, 2)
positive_relations_repeat = positive_relations.unsqueeze(axis=0).tile([len(all_possible_relations), 1, 1])
# mask shape = (100, 10)
mask = paddle.all(all_possible_relations_repeat == positive_relations_repeat, axis=2)
# 获取关系对负样本
# negative_mask = paddle.any(mask, axis=1) is False
negative_mask = ~paddle.any(mask, axis=1)
negative_relations = all_possible_relations[negative_mask]
# 获取关系对正样本
# positive_mask = paddle.any(mask, axis=0) is True
positive_mask = paddle.any(mask, axis=0)
positive_relations = positive_relations[positive_mask]
if negative_mask.sum() > 0:
# positive_relations_shape = (10, 2)
# negative_relations_shape = (90, 2)
# reordered_relations_shape = (100, 2)
reordered_relations = paddle.concat([positive_relations, negative_relations])
else:
reordered_relations = positive_relations
relation_per_doc_label = paddle.zeros([len(reordered_relations), 1], dtype=reordered_relations.dtype)
relation_per_doc_label[: len(positive_relations)] = 1
# relation_per_doc shape: (100, 3)
"""
relation_per_doc =
[[1 , 0 , 1 ],# 正样本
[2 , 3 , 1 ],
[4 , 5 , 1 ],
......
[19, 18, 1 ],
[1 , 3 , 0 ],# 负样本
[1 , 5 , 0 ],
......
]
"""
relation_per_doc = paddle.concat([reordered_relations, relation_per_doc_label], axis=1)
assert len(relation_per_doc[:, 0]) != 0
# 第1个元素记录正负样本的长度信息,例如:[100, 100, 100]
new_relations[b, 0] = paddle.shape(relation_per_doc)[0].astype(new_relations.dtype)
# 将正负样本放到new_relations中
new_relations[b, 1 : len(relation_per_doc) + 1] = relation_per_doc
# new_relations.append(relation_per_doc)
return new_relations, entities
def get_predicted_relations(self, logits, relations, entities):
"""
logits: 预测得到的关系概率, 例如:shape = (100, 2)
relations: shape = (100, 3)
entities: shape = (513, 3)
"""
pred_relations = []
for i, pred_label in enumerate(logits.argmax(-1)):
if pred_label != 1:
continue
rel = paddle.full(shape=[7, 2], fill_value=-1, dtype=relations.dtype)
rel[0, 0] = relations[:, 0][i]
rel[1, 0] = entities[:, 0][relations[:, 0][i] + 1]
rel[1, 1] = entities[:, 1][relations[:, 0][i] + 1]
rel[2, 0] = entities[:, 2][relations[:, 0][i] + 1]
rel[3, 0] = relations[:, 1][i]
rel[4, 0] = entities[:, 0][relations[:, 1][i] + 1]
rel[4, 1] = entities[:, 1][relations[:, 1][i] + 1]
rel[5, 0] = entities[:, 2][relations[:, 1][i] + 1]
rel[6, 0] = 1
pred_relations.append(rel)
return pred_relations
def forward(self, hidden_states, entities, relations):
"""
hidden_states_shape:(bs, 512, 768)
entities_shape: (bs, 513, 3) , 其中:513 = 512 + 1,第一个元素记录长度信息
relations_shape: (bs, 262145, 2),其中:262145 = 512*512 + 1,第一个元素记录长度信息
"""
batch_size, max_length, _ = paddle.shape(entities)
# 1、构建关系的正负样本
# relations_shape: (bs, 263169, 3) , 其中: 263169 = 513 * 513
# entities_shape: (bs, 513, 3)
relations, entities = self.build_relation(relations, entities)
loss = 0
# 所有预测关系结果
all_pred_relations = paddle.full(
shape=[batch_size, max_length * max_length, 7, 2], fill_value=-1, dtype=entities.dtype
)
for b in range(batch_size):
# 2、获取关系头(question)、关系尾(answer)对应的特征信息
# 取出正负样本关系对, relation_shape = (100, 3)
relation = relations[b, 1 : relations[b, 0, 0] + 1]
# 获取关系头(question)、关系尾(answer)、以及关系标签(1表示question和answer是一对,即正样本, 0表示负样本)
head_entities = relation[:, 0]
tail_entities = relation[:, 1]
relation_labels = relation[:, 2]
# 每一个实体(question或answer)在input_ids中开始的索引
# 例如: [0 , 3 , 4 , 8 , 14 , 16 , 23 , 29 , 34 , 37 , 60 , 65 , 82 , 84 ,
# 87 , 90 , 91 , 96 , 102, 106]
entities_start_index = paddle.to_tensor(entities[b, 1 : entities[b, 0, 0] + 1, 0])
# 获取每个实体类型编号,1表示question,2表示answer
# 例如:[2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]
entities_labels = paddle.to_tensor(entities[b, 1 : entities[b, 0, 2] + 1, 2])
# 获取关系头(即question)在input_ids中开始的索引,为了后面获取对应token的hidden_states
head_index = entities_start_index[head_entities]
# 获取关系头(question)对应的实体类型编号
head_label = entities_labels[head_entities]
# 关系头(question)经过Embedding, head_label_repr_shape = (100, 768)
head_label_repr = self.entity_emb(head_label)
# 获取关系尾(即answer)在input_ids中开始的索引,为了后面获取对应token的hidden_states
tail_index = entities_start_index[tail_entities]
# 获取关系尾(answer)对应的实体类型编号
tail_label = entities_labels[tail_entities]
# 关系尾(answer)经过Embedding, tail_label_repr_shape = (100, 768)
tail_label_repr = self.entity_emb(tail_label)
# 获取关系头(question)开始token的hidden_states, tmp_hidden_states shape: (100, 768)
tmp_hidden_states = hidden_states[b][head_index]
if len(tmp_hidden_states.shape) == 1:
tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)
# concat, head_repr_shape = (100, 1536)
head_repr = paddle.concat((tmp_hidden_states, head_label_repr), axis=-1)
# 获取关系尾(answer)开始token的hidden_states, tmp_hidden_states shape: (100, 768)
tmp_hidden_states = hidden_states[b][tail_index]
if len(tmp_hidden_states.shape) == 1:
tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)
# concat, tail_repr_shape = (100, 1536)
tail_repr = paddle.concat((tmp_hidden_states, tail_label_repr), axis=-1)
# 3、利用提取到的head_repr、tail_repr进行关系分类
# heads_shape = (100, 1536) -> (100, 384)
# tails_shape = (100, 1536) -> (100, 384)
heads = self.ffnn_head(head_repr)
tails = self.ffnn_tail(tail_repr)
# 结合双线性层和线性层,实现对两个输入向量的复杂交互建模
# logits_shape = (100, 2)
logits = self.rel_classifier(heads, tails)
# 4、计算交叉熵损失
loss += self.loss_fct(logits, relation_labels)
pred_relations = self.get_predicted_relations(logits, relation, entities[b])
if len(pred_relations) > 0:
pred_relations = paddle.stack(pred_relations)
all_pred_relations[b, 0, :, :] = paddle.shape(pred_relations)[0].astype(all_pred_relations.dtype)
all_pred_relations[b, 1 : len(pred_relations) + 1, :, :] = pred_relations
return loss, all_pred_relations
- 关于模型的预测代码(使用OCR结果进行预测等),可以参考https://aistudio.baidu.com/projectdetail/4823162。