Paddle进阶实战系列(一):保险文本视觉认知问答

news2024/11/30 12:34:37

保险文本视觉认知问答

1.项目介绍

1.1背景

随着人工智能技术的逐渐成熟,计算机视觉、语音、自然语言处理等技术在金融行业的应用从广度和深度上都在加速,这不仅降低了金融机构的运营和风险成本,而且有助于提升客户的满意度,比如利用NLP 技术实现智能问答解决方案,帮助用户即使没有复杂的金融背景知识也能快速找到自己需要的信息,而在寿险、产险、健康险等保险的理赔流程和客户服务环节中,存在大量扫描文档,例如医疗票据、费用清单、病例等。对这些扫描文档进行文字检测与识别,并且提取出结构化信息,可以用于极速理赔、个人健康管理等业务场景。

在保险领域,用户常见的问题占了60%~70%,这部分重复性工作费时费力,需要更有效率的处理方式。智能问答能够准确理解用户的意图,并直接给出精确的答案,极大节省了用户及工作人员的时间。

1.2.项目任务分析

本次任务需要将提供面向保险场景的扫描图片数据集,利用OCR技术自动识别影像资料后,再通过AI智能判断所识别文字的内在逻辑,回答关于图片的自然语言问题。问题的答案是可以从图片中提取的任何文本/标记。

输入: 保险场景的扫描文档(例如:医疗票据)+ 自然语言提问(例如:病人服用的药品清单有什么?)

输出: 对应自然语言提问的事实性答案.

  1. 分析: 根据不同文档图片数据集进行OCR识别,对问题和回答进行建模,保证一定准确率,技术涉及到OCR+NLP

  2. 难点: 照片拍摄角度不同,字体混合手写,一张图片可能由多张票据混合,背景噪声影响识别效果,考虑使用多种模型对比。

项目链接:见文末!

项目框架:使用paddle框架将OCR与NLP技术相结合

PaddleOCR流程

1.3 参考资料:

基于Paddle实现baseline项目参考:AIWN保险文本视觉认知问答

优秀方案参考:DocVQA冠军方案分享

PaddleOCR官方资料库 

PaddleNLP官方资料库 

2.数据集

2.1数据简介

  • 本次大赛提供的数据集使用的扫描文件类型包括票据、说明、报告等20 多种。混合了印刷、打字和手写的内容。
  • 训练集有5000余张左右原始扫描文件及对应的 4万余个自然语言问答标注。提供的数据均已做了标注及脱敏。

2.1.1 训练集

  • 训练集数据包括:

    • image:包含所有原始扫描文件图像
    • train.csv:问答训练库,包含序号(index)、问题 ID(quesiton_id)、图片名称(filename)、问题(question_text)、答案(answer_text),共 5 列
    • readme:数据说明文档

    字段说明:

  • 训练集用于模型训练,数据字段包括以下内容:

    1、index:序号

    2、question_id:问题的唯一id标识

    3、filename:问题对应的唯一图片名称

    4、question_text:问题描述

    5、answer_text:问题对应的唯一答案

2.1.2 测试集

  • 测试集数据规模为1000张左右原始扫描文件及对应的7000个自然语言问题,数据内容样例同训练集。

  • 测试集包含以下3个文件:

    • image:包含所有原始扫描文件图像
    • test1.csv:问答测试库,包含序号(index)、问题 ID(quesiton_id)、图片路径(filename)、问题(question_text),共 4 列
    • readme:数据说明文档
  • 测试集用于模型验证,需提交问题对应答案结果,数据字段包括以下内容:

    1、index:序号

    2、question_id:问题的唯一id标识

    3、filename:问题对应的唯一图片名称

    4、question_text:问题描述

2.2数据展示

  • 样例一: 

    • 提问: 西药费的金额是多少? 回答: 140.16
      提问: 140.16元购买了什么药品? 回答: {甲}缘沙坦胶囊{基}
  • 样例二:

    • 提问: 这是一份关于什么药品的说明? 回答: 十三味疏肝胶囊
      提问: 药品的有效期是多久? 回答: 1.5年

3.项目代码

使用PaddleOCR+PaddleNLP实现代码

参考项目原地址: https://github.com/datawhalechina/competition-baseline/tree/master/competition/AIWIN2021

3.1安装环境依赖包

# 安装paddleocr和paddlenlp
!pip install --user paddleocr==2.0.4 paddlenlp==2.0.0rc18
!pip list
!pip install pandas pillow matplotlib Ipython
#解压数据集
!tar -xf data/data83016/dataset.tar -C data
import pandas as pd
from PIL import Image
import codecs
import os

import matplotlib.pyplot as plt
# from IPython.display import set_matplotlib_formats
# %matplotlib inline
# set_matplotlib_formats('svg') # 输出为svg

df = pd.read_csv('data/train-utf8.csv')
df['filename'] = 'data/image/' + df['filename'] # 改为本地路径

3.2 OCR阶段

ocr阶段生成位置及内容:

注:Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换参数依次为chenfrenchgermankoreanjapan

from paddleocr import PaddleOCR
import pandas as pd
from PIL import Image
import codecs
import os

import matplotlib.pyplot as plt

ocr = PaddleOCR(use_angle_cls=True, lang="ch", enable_mkldnn=True) # need to run only once to download and load model into memory

df = pd.read_csv('data/train-utf8.csv')
df['filename'] = 'data/image/' + df['filename'] # 改为本地路径

for path in df['filename'].unique():
    print(path)
    if os.path.exists('result/' + os.path.basename(path)[:-4] + '.txt'):
        continue
    
    result = ocr.ocr(path, cls=True)

    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    
    for box, txt in zip(boxes, txts):
        with codecs.open('result/' + os.path.basename(path)[:-4] + '.txt', 'a') as up:
            up.write('{0}\t{1}\n'.format(box, txt))

In [ ]

#创建结果路径
!mkdir result

In [ ]

#运行上述脚本,时间比较长,耐心等待即可
!python ocr.py

输出结果展示:

[[394.0, 51.0], [459.0, 51.0], [459.0, 75.0], [394.0, 75.0]] 橘红丸

[[34.0, 125.0], [343.0, 125.0], [343.0, 137.0], [34.0, 137.0]] 说明书来源:广东宏兴集团股份有限公司宏兴制药厂

[[77.0, 169.0], [150.0, 169.0], [150.0, 186.0], [77.0, 186.0]] 【药品名称】

[[178.0, 169.0], [249.0, 169.0], [249.0, 186.0], [178.0, 186.0]] 【通用名称】

[[300.0, 169.0], [343.0, 169.0], [343.0, 186.0], [300.0, 186.0]] 橘红丸

[[178.0, 210.0], [248.0, 210.0], [248.0, 223.0], [178.0, 223.0]] 【汉语拼音】

[[300.0, 210.0], [379.0, 210.0], [379.0, 223.0], [300.0, 223.0]] JuhongWan

结果分析示例:

In [ ]

df.head(10)
   index question_id                                           filename  \
0      1      Q00001  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
1      2      Q00002  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
2      3      Q00003  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
3      4      Q00004  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
4      5      Q00005  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
5      6      Q00006  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
6      7      Q00007  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
7      8      Q00008  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
8      9      Q00009  data/image/c850b0d7018d127989d1b20d0f7118d66f5...   
9     10      Q00010  data/image/AHEFGLB18921EAAA75R7_20210301111254...   

                                   question_text         answer_text  
0                                        这是什么药品?               茶碱缓释片  
1                                     本说明书来源于哪里?        黑龙江鼎恒升药业有限公司  
2                                     本品可通过什么屏障?                  胎盘  
3                                   说明书上方正中是什么字?               茶碱缓释片  
4                                       左上角是什么字?  说明书来源:黑龙江鼎恒升药业有限公司  
5                                   老年用药是下一项是什么?              药物相互作用  
6  Theophylline Sustainde-release Tablets是药品的什么?                 英文名  
7                                        茶碱是指什么?                主要成份  
8                                   198.18是指什么数?                 分子量  
9                        太平洋产险全国统一保险消费投诉电话是哪个号码?           95500-3-4  

In [ ]

Image.open(df['filename'].iloc[0])

<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=852x2471 at 0x7FB02B3C36D0>

In [ ]

codecs.open('result/' + os.path.basename(df['filename'].iloc[0])[:-4] + '.txt').readlines()[:10]
['[[374.0, 51.0], [476.0, 51.0], [476.0, 74.0], [374.0, 74.0]]\t茶碱缓释片\n',
 '[[33.0, 124.0], [274.0, 124.0], [274.0, 137.0], [33.0, 137.0]]\t说明书来源:黑龙江鼎恒升药业有限公司\n',
 '[[231.0, 167.0], [304.0, 170.0], [303.0, 187.0], [231.0, 184.0]]\t【通用名称】\n',
 '[[133.0, 170.0], [202.0, 170.0], [202.0, 185.0], [133.0, 185.0]]\t【药品名称】\n',
 '[[354.0, 169.0], [424.0, 169.0], [424.0, 187.0], [354.0, 187.0]]\t茶碱缓释片\n',
 '[[232.0, 209.0], [292.0, 209.0], [292.0, 226.0], [232.0, 226.0]]\t【商品名】\n',
 '[[231.0, 245.0], [291.0, 248.0], [290.0, 266.0], [231.0, 264.0]]\t【英文名】\n',
 '[[355.0, 248.0], [600.0, 248.0], [600.0, 264.0], [355.0, 264.0]]\tTheophylline Sustained-release Tablets\n',
 '[[233.0, 286.0], [304.0, 286.0], [304.0, 303.0], [233.0, 303.0]]\t【汉语拼音】\n',
 '[[356.0, 288.0], [488.0, 288.0], [488.0, 301.0], [356.0, 301.0]]\tChajian HuanshiPian\n']

3.3 自然语言处理

3.3.1 模型1:规则匹配

In [ ]

codecs.open('result/' + os.path.basename(df['filename'].iloc[80])[:-4] + '.txt').readlines()[:15]
['[[150.0, 57.0], [470.0, 57.0], [470.0, 83.0], [150.0, 83.0]]\tPERSDNAL\n',
 '[[456.0, 57.0], [694.0, 57.0], [694.0, 83.0], [456.0, 83.0]]\tRESUME\n',
 '[[205.0, 276.0], [293.0, 276.0], [293.0, 304.0], [205.0, 304.0]]\t五百丁\n',
 '[[576.0, 276.0], [696.0, 276.0], [696.0, 307.0], [576.0, 307.0]]\t基本信息\n',
 '[[212.0, 328.0], [519.0, 328.0], [519.0, 341.0], [212.0, 341.0]]\t一句话介绍自己,告诉HR为什么选择你而不是别人\n',
 '[[575.0, 332.0], [633.0, 332.0], [633.0, 355.0], [575.0, 355.0]]\t画24岁\n',
 '[[576.0, 367.0], [629.0, 367.0], [629.0, 389.0], [576.0, 389.0]]\t国汉族\n',
 '[[577.0, 402.0], [657.0, 402.0], [657.0, 422.0], [577.0, 422.0]]\t广东广州\n',
 '[[35.0, 416.0], [154.0, 416.0], [154.0, 443.0], [35.0, 443.0]]\t(国)求职意向\n',
 '[[574.0, 436.0], [699.0, 430.0], [700.0, 453.0], [575.0, 459.0]]\tC13800138000\n',
 '[[36.0, 474.0], [132.0, 474.0], [132.0, 494.0], [36.0, 494.0]]\t幼儿园老师\n',
 '[[574.0, 471.0], [694.0, 466.0], [695.0, 486.0], [575.0, 490.0]]\t区bd@500d.me\n',
 '[[575.0, 533.0], [695.0, 533.0], [695.0, 562.0], [575.0, 562.0]]\t()技能特长\n',
 '[[34.0, 561.0], [155.0, 561.0], [155.0, 591.0], [34.0, 591.0]]\t()教育背景\n',
 '[[35.0, 619.0], [122.0, 619.0], [122.0, 633.0], [35.0, 633.0]]\t2013.9-至今\n']

In [ ]

df[df['question_text'].apply(lambda x: '邮箱是多少' in x)]
# 1、ocr识别图片
# 2、问题的意图
# 3、问题和ocr的结果进行匹配
       index question_id                                           filename  \
96        97      Q00097  data/image/e506b03f95cfc0b0649e4edcdb2076300a9...   
376      377      Q00377  data/image/a11b4e53ee1b706c0a01c626d4b53ee7712...   
1954    1955      Q01955  data/image/e09b52455b9e351cf62b8537f4d06208a9c...   
2088    2089      Q02089                           data/image/word_1145.png   
2812    2813      Q02813  data/image/03d0ce91ee87f4939e64470c700d69a9058...   
2850    2851      Q02851  data/image/d326457bd0d87670c10e232ceef5a0ffecc...   
2919    2920      Q02920  data/image/03d0ce91ee87f4939e64470c700d69a9058...   
3153    3154      Q03154  data/image/fa964e762d3d2ab7595931c1d9bdd628475...   
3169    3170      Q03170  data/image/e8c8044dd0ba4c1b7665be4005b6835f314...   
3283    3284      Q03284  data/image/ecd226c3b1db5dec169dad321465287ffad...   
5002    5003      Q05003  data/image/054260010acde733be26cd74ad7fff4b77b...   
7032    7033      Q07033  data/image/c4b40cc2dc55ad0ce1909db20ab1d29fa36...   
7471    7472      Q07472  data/image/ba12ed9453422ec07f2866a3e69d7701af5...   
7538    7539      Q07539  data/image/cb3eb2eb4f42bdac18dd9634c08687422b8...   
8862    8863      Q08863                           data/image/word_1117.png   
9725    9726      Q09726  data/image/d3dfd339afdfd79102cd5dc3508ef106dfc...   
12592  12593      Q12593  data/image/d792e6f57fc699e729122b938777eee60b2...   
13625  13626      Q13626  data/image/3266144112911b1370cbe9b0ebb78bce86c...   
13767  13768      Q13768  data/image/f25e0d56fbe2f2cd7227bcadb30f3c5baca...   
30950  30951      Q30951  data/image/AHEFBZ1Y2021M010251ATEMP_5249414_1.jpg   
31092  31093      Q31093  data/image/11525741dee150c477b4cc598d934afa5d4...   
31379  31380      Q31380  data/image/201907_71922e65-a8da-4b09-844d-8e55...   
34193  34194      Q34194  data/image/4c42b2a394fea24bb8e41a2e6a0f35376bf...   
34369  34370      Q34370  data/image/79dbb85c53a538fc9fe1d38dcc22d949920...   
35070  35071      Q35071  data/image/201907_f0116bfd-948c-4b3b-8bb9-6f06...   
35241  35242      Q35242  data/image/57cedc752dfc0f0037e2fc3771e51abe687...   
35249  35250      Q35250  data/image/3dce2f8ef58180c531cbc8c9a271be4d30e...   
35771  35772      Q35772  data/image/201907_f1d0da38-c6dc-426a-b4ee-a712...   
40771  40772      Q40772  data/image/201907_f4b4c7c1-0c3e-41ed-afa3-1ddf...   

         question_text           answer_text  
96              邮箱是多少?            bd@500d.me  
376          五百丁邮箱是多少?            bd@500d.me  
1954        五百丁的邮箱是多少?            bd@500d.me  
2088       这张简历的邮箱是多少?         666666@qq.com  
2812     该证券中吴立的邮箱是多少?        wuli1@tfzq.com  
2850        五百丁的邮箱是多少?            bd@500d.me  
2919      图中杨烨辉的邮箱是多少?    yangyehui@tfzq.com  
3153        五百丁的邮箱是多少?            bd@500d.me  
3169        五百丁的邮箱是多少?            bd@500d.me  
3283        五百丁的邮箱是多少?            bd@500d.me  
5002   研究助理:薛绍阳的邮箱是多少?  sueshaoyang@mszq.com  
7032   五百丁在简历上留的邮箱是多少?            bd@500d.me  
7471        五百丁的邮箱是多少?            bd@500d.me  
7538         五百丁邮箱是多少?             bd@500.me  
8862         此人的邮箱是多少?      qmjianli@163.com  
9725         五百丁邮箱是多少?            bd@500d.me  
12592       五百丁的邮箱是多少?            bd@500d.me  
13625       五百丁的邮箱是多少?            bd@500d.me  
13767     五百丁的电子邮箱是多少?            bd@500d.me  
30950     投保人的电子邮箱是多少?       54564560134.com  
31092     分析师金敏的邮箱是多少?        jinm@ctsec.com  
31379      锤子的电子邮箱是多少?        6464646@qq.com  
34193       五百丁的邮箱是多少?            bd@500d.me  
34369       五百丁的邮箱是多少?            bd@500d.me  
35070      这张简历的邮箱是多少?           1234@qq.com  
35241       五百丁的邮箱是多少?            bd@500d.me  
35249       五百丁的邮箱是多少?            bd@500d,me  
35771           邮箱是多少?       13800642@qq.com  
40771     求职者锤子的邮箱是多少?         docer @qq.com  

In [ ]

import re
# 对于所有的数据集,迭代每一行
# 步骤1:判断OCR是否识别
for row in df.iloc[:].iterrows():
    qs = row[1].question_text
    # ocr是否识别成功
    if not os.path.exists('ocr_result/'+os.path.basename(row[1]['filename'])[:-4] + '.txt'):
        continue
    
    # 读取ocr识别结果
    ocrs = codecs.open('ocr_result/'+os.path.basename(row[1]['filename'])[:-4] + '.txt').readlines()[:]
    # 文字
    ocr_text = [x.split('\t')[1].strip() for x in ocrs]
    # 文本框
    ocr_box = [x.split('\t')[0].strip() for x in ocrs]
    
    if re.findall('什么药品', qs):
#         pass
        print(row[1].answer_text, '\t', ocr_text[0])
        
    elif re.findall('说明书来源于哪里', qs):
        candicate_text = list(set([x for x in ocr_text if '说明书' in x]))
        candicate_text = [x for x in candicate_text if '说明书' in x][0]
        candicate_text = candicate_text.replace('说明书', '').replace('来源', '').replace(':', '')
        print(row[1].answer_text, candicate_text)
        pass
    elif re.findall('什么大学什么专业', qs):
        candicate_text = list(set([x for x in ocr_text if '大学' in x and '专业' in x]))
        print(row[1].answer_text, candicate_text[0])
    elif re.findall('什么大学', qs):
        candicate_text = list(set([x for x in ocr_text if re.findall('大学', x)]))
        if len(candicate_text) == 0:
            continue
        
        # print(row[1].answer_text, candicate_text[0])
    elif re.findall('什么专业', qs):
        candicate_text = list(set([x for x in ocr_text if re.findall('本科', x)]))
        if len(candicate_text) == 0:
            continue
        
        # print(row[1].answer_text, candicate_text[0])
#     elif re.findall('电话是多少', qs):
#         continue
#         # break
        
    elif re.findall('邮箱', qs):
        candicate_text = list(set([x for x in ocr_text if re.findall('@', x)]))
        if len(candicate_text) == 0:
            continue
        print(row[1].answer_text, candicate_text[0])

    # 没有匹配成功怎么办
    # XX之后是什么?,最近的框里面的文本进行回答
    
    # box信息,字的大小信息,字号
    
    # ocr结果
    # XX
    # YY

        # break
    # break

模型3.3.2 Bert

LayoutLM: Pre-training of Text and Layout for Document Image Understanding

介绍:LayoutLM利用文本分布的板式信息和识别到的文字信息,基于bert进行大规模预训练,然后在SER和RE任务进行微调;LayoutLMv2在LayoutLM的基础上,将图像视觉信息引入预训练阶段,对多模态信息进行更好的融合;LayoutXLM将LayoutLMv2扩展到多语言。

适用场景:针对卡证、票据等场景的信息提取、关系抽取、文档视觉问答任务。

参考资料:

论文:https://arxiv.org/pdf/1912.13318.pdf

https://huggingface.co/transformers/model_doc/layoutlm.html

更多模型选择请参考:https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/transformers.md

Bert模型训练

可以直接使用本项目中转换后的文本数据训练,也可先执行下面命令生成自己的数据,训练( 注意:生成数据时间较长,请耐心等待 )

In [ ]

# !python gen_dataset.py
import paddle
import paddlenlp as ppnlp

from functools import partial
from paddlenlp.data import Stack, Dict, Pad
from utils import prepare_train_features, prepare_validation_features, evaluate


############参数配置###############
# 模型名称
MODEL_NAME = "bert-wwm-chinese"
# 根据官方文档可使用更多的模型,例如:BERT,ERNIE, RoBERTa等,之后可考虑集成的方法
# MODEL_NAME = "roberta-wwm-ext"

# 最大文本长度
max_seq_length = 512
# 文本滑动窗口步幅
doc_stride = 128

# 训练过程中的最大学习率
learning_rate = 2e-5 
# 训练轮次
epochs = 12
# 数据批次大小
batch_size = 32  # 根据显存大小更改
# 学习率预热比例
warmup_proportion = 0.1
# 权重衰减系数,类似模型正则项策略,避免模型过拟合
weight_decay = 0.01


#############模型################
# 加载模型
# 请根据模型名称查看官方文档文档更换接口
model = ppnlp.transformers.BertForQuestionAnswering.from_pretrained(MODEL_NAME)
# model = ppnlp.transformers.RobertaForQuestionAnswering.from_pretrained(MODEL_NAME)


# 加载 tokenizer
# 请根据文档更换接口
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained(MODEL_NAME)
# tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained(MODEL_NAME)


#############数据###############
# 加载数据集
# 如果是自己生成的数据请更换为自己数据的路径
train_ds = ppnlp.datasets.load_dataset('dureader_robust', data_files='data/data83268/train.json')
dev_ds = ppnlp.datasets.load_dataset('dureader_robust', data_files='data/data83268/dev.json')

# 数据滑窗处理
train_trans_func = partial(prepare_train_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)

train_ds.map(train_trans_func, batched=True)

dev_trans_func = partial(prepare_validation_features, 
                           max_seq_length=max_seq_length, 
                           doc_stride=doc_stride,
                           tokenizer=tokenizer)
                           
dev_ds.map(dev_trans_func, batched=True)

# 数据读取器配置
train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=batch_size, shuffle=True)

train_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    "start_positions": Stack(dtype="int64"),
    "end_positions": Stack(dtype="int64")
}): fn(samples)

train_data_loader = paddle.io.DataLoader(
    dataset=train_ds,
    batch_sampler=train_batch_sampler,
    collate_fn=train_batchify_fn,
    return_list=True)

dev_batch_sampler = paddle.io.BatchSampler(
    dev_ds, batch_size=batch_size, shuffle=False)

dev_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
}): fn(samples)

dev_data_loader = paddle.io.DataLoader(
    dataset=dev_ds,
    batch_sampler=dev_batch_sampler,
    collate_fn=dev_batchify_fn,
    return_list=True)


#############优化器配置#############
# 学习率策略
num_training_steps = len(train_data_loader) * epochs
lr_scheduler = ppnlp.transformers.LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]

# 设置优化器
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)


#############损失函数################
class CrossEntropyLossForSQuAD(paddle.nn.Layer):
    def __init__(self):
        super(CrossEntropyLossForSQuAD, self).__init__()

    def forward(self, y, label):
        start_logits, end_logits = y   # both shape are [batch_size, seq_len]
        start_position, end_position = label
        start_position = paddle.unsqueeze(start_position, axis=-1)
        end_position = paddle.unsqueeze(end_position, axis=-1)
        start_loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=start_logits, label=start_position, soft_label=False)
        start_loss = paddle.mean(start_loss)
        end_loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=end_logits, label=end_position, soft_label=False)
        end_loss = paddle.mean(end_loss)

        loss = (start_loss + end_loss) / 2
        return loss


#############模型训练################
# 实例化 loss
criterion = CrossEntropyLossForSQuAD()
global_step = 0

# 训练
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        global_step += 1
        input_ids, segment_ids, start_positions, end_positions = batch
        logits = model(input_ids=input_ids, token_type_ids=segment_ids)
        loss = criterion(logits, (start_positions, end_positions))

        if global_step % 100 == 0 :
            print("global step %d, epoch: %d, batch: %d, loss: %.5f" % (global_step, epoch, step, loss))
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()

    evaluate(model=model, data_loader=dev_data_loader) 

# 保存
model.save_pretrained('/home/aistudio/checkpoint')
tokenizer.save_pretrained('/home/aistudio/checkpoint')

In [91]

#运行上训练述代码,可以用四卡跑,若内存溢出可以调低batchsize,训练时间较长可以提前终止。
!python train.py

OCR获得测试集文本数据

In [1]

#注意:如果使用aistudio平台这里需要自行去官网下载数据并上传
#!python gen_test.py

Bert模型预测结果

import paddle
import paddlenlp as ppnlp

from functools import partial
from paddlenlp.data import Dict, Pad
from utils import prepare_validation_features, predict

############参数配置###############
# 模型名称
MODEL_NAME = "bert-wwm-chinese"
# 根据官方文档可使用更多的模型,例如:BERT,ERNIE, RoBERTa等,之后可考虑集成的方法
# MODEL_NAME = "roberta-wwm-ext"

# 最大文本长度
max_seq_length = 512
# 文本滑动窗口步幅
doc_stride = 128

# 训练过程中的最大学习率
learning_rate = 3e-5
# 训练轮次
epochs = 4
# 数据批次大小
batch_size = 32
# 学习率预热比例
warmup_proportion = 0.1
# 权重衰减系数,类似模型正则项策略,避免模型过拟合
weight_decay = 0.01

#############模型################
# 加载模型
model = ppnlp.transformers.BertForQuestionAnswering.from_pretrained("训练得到的checkpoint文件夹")
# model = ppnlp.transformers.RobertaForQuestionAnswering.from_pretrained(MODEL_NAME)
# 更新参数
# state_dict = paddle.load('checkpoints/model_state.pdparams')
# model.state_dict(state_dict)
# 加载 tokenizer
# 请根据文档更换接口
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained("训练得到的checkpoint文件夹")
# tokenizer = ppnlp.transformers.RobertaTokenizer.from_pretrained(MODEL_NAME)

#############数据###############
# 加载数据集
dev_ds = ppnlp.datasets.load_dataset('dureader_robust', data_files='ocr_result/test.json')


dev_trans_func = partial(prepare_validation_features,
                         max_seq_length=max_seq_length,
                         doc_stride=doc_stride,
                         tokenizer=tokenizer)

dev_ds.map(dev_trans_func, batched=True)

# 数据读取器配置
dev_batch_sampler = paddle.io.BatchSampler(
    dev_ds, batch_size=batch_size, shuffle=False)

dev_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
}): fn(samples)

dev_data_loader = paddle.io.DataLoader(
    dataset=dev_ds,
    batch_sampler=dev_batch_sampler,
    collate_fn=dev_batchify_fn,
    return_list=True)

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]

# 预测
predict(model=model, data_loader=dev_data_loader)
#运行上述推理预测代码
#
!python infer.py

4.项目总结

针对保险文本视觉认知问答项目,本文使用paddle框架将OCR与NLP技术相结合,根据不同文档图片数据集进行OCR识别,对问题和回答进行建模 。本次学习收获主要是对语言处理有进一步了解,熟悉规则匹配和nlp的Bert语言模型,后面会在ENRIE、GPT-2等模型做对比实验,结合模型参数和训练策略,选择准确率最高的模型作为最终项目应用。

本项目链接:保险文本视觉认知问答 - 飞桨AI Studio

参考链接:保险文本视觉认知问答竞赛(Baseline) - 飞桨AI Studio

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/152309.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Rhec第二次作业

两台机器&#xff1a;第一台机器作为客户端&#xff0c;第二台机器作为服务器&#xff0c;在第一台使用rhce用户免密登录第二台机器准备两台虚拟机并保证可以ping通网络&#xff0c;ip地址不一样第一台机器配置创建rhce用户在rhce用户上&#xff0c;创建密钥对并将公钥发送给第…

JDK源码(二)ConcurrentHashMap-JDK1.7

1.背景 并发编程中&#xff0c;ConcurrentHashMap是一个使用度非常高的数据结构。 优点: 线程安全相比于HashTable和Collections.synchronizedMap()效率高&#xff0c;使用了分段锁技术。 2.ConcurrentHashMap数据结构 Segment Segment继承了ReentrantLock&#xff0c;所以它…

Java设计模式-桥接模式Bridge

传统模式 案例 要求对不同手机类型的不同品牌实现操作编程(比如:开机、关机、上网&#xff0c;打电话等)&#xff0c;如图: 类图 问题 扩展性问题(类爆炸)&#xff0c;如果我们再增加手机的样式(旋转式)&#xff0c;就需要增加各个品牌手机的类&#xff0c;同样如果我们…

【Kotlin】标准库函数 ① ( apply 标准库函数 | let 标准库函数 )

文章目录一、apply 标准库函数二、let 标准库函数Kotlin 语言中 , 在 Standard.kt 源码中 , 为所有类型定义了一批标准库函数 , 所有的 Kotlin 类型都可以调用这些函数 ; 一、apply 标准库函数 Kotlin 标准库函数 中的 apply 函数 , 该函数可以看作 实例对象 的 配置函数 , 传…

现货黄金术语汇总

有的投资者可能并不是新手&#xff0c;可能之前对股票投资的一些术语有一定的了解甚至说是经验。但是转到现货黄金市场的时候&#xff0c;还是对很多术语感到很陌生&#xff0c;下面小编针对一些与股票不一样的现货黄金术语进行介绍&#xff0c;希望可以帮助投资者们尽快"…

前言技术之swagger

一.前后端分离的特点前后端分离是的前端与后端之间的职责更加明确 后台&#xff1a; 负责业务处理 前端&#xff1a; 负责显示逻辑 在这种情况下&#xff0c;前端和后端可以分别交付给专业的开发人员去做&#xff0c;所以是必须要定义前后端直接的对接 接口&#xff0c;否则各自…

liblas读取点云,设置半透明

一&#xff0c;用Liblas读取点云数据&#xff0c;获取点云位置和颜色 二&#xff0c;将点云位置和颜色分别代入geode的位置数组和颜色数组 三&#xff0c;用glsl设置半透明。需要注意的是 1&#xff0c;颜色数组是用attribute,所以要用 geom->setVertexAttribArray(10, colo…

Wireshark抓包分析DHCP

1、DHCP简介动态主机设置协议&#xff08;英语&#xff1a;Dynamic Host Configuration Protocol&#xff0c;DHCP&#xff09;是一个局域网的网络协议&#xff0c;使用UDP协议工作&#xff0c;主要有两个用途&#xff1a;用于内部网或网络服务供应商自动分配IP地址&#xff1b…

从0到1完成一个Vue后台管理项目(十五、作业列表、表格数据方法的封装)

往期 从0到1完成一个Vue后台管理项目&#xff08;一、创建项目&#xff09; 从0到1完成一个Vue后台管理项目&#xff08;二、使用element-ui&#xff09; 从0到1完成一个Vue后台管理项目&#xff08;三、使用SCSS/LESS&#xff0c;安装图标库&#xff09; 从0到1完成一个Vu…

解决虚拟机安装 VMware Tools 灰色无法点击问题

环境&#xff1a; 主机 OS: Windows 11 虚拟机平台: VMware Workstation 17 Pro 虚拟系统: Ubuntu 22.04 1. 问题 安装好 Linux 系统后&#xff0c;想要适配桌面大小等其它功能就需要安装 VMware Tools 这个工具&#xff0c;最简单的办法就是通过虚拟机平台的 “一键安装”&a…

Java之ATM系统

目录项目介绍系统准备&#xff0c;首页设计总结开户功能总结用户登录总结用户操作页设计、查询账户、退出账户功能用户存钱取款功能转账功能密码修改、销户源代码项目介绍 系统准备&#xff0c;首页设计 总结 1、用户的账户信息&#xff0c;系统如何表示的? 定义账户类Accoun…

【CVHub】现代目标检测故事 | 40+目标检测网络架构大盘点!从基础架构ResNet到最强检测器Yolov7再到最新部署神器GhostNetV2

本文来源“CVHub”公众号&#xff0c;侵权删&#xff0c;干货满满。 作者丨派派星 来源丨CVHub 原文链接&#xff1a;现代目标检测故事 | 40种网络架构大盘点&#xff01;从基础架构ResNet到最强检测器Yolov7再到最新部署神器GhostNetV2 导读 目标检测是指在图像或视频中分…

PHY6230 高性价比低功耗高性能 集成32-bit MCU BLE5.2+2.4G芯片

PHY6230 是一款高性价比低功耗高性能Bluetooth LE 5.2系统级芯片&#xff0c;集成32-bit高性能低功耗MCU&#xff0c;16KB OTP&#xff0c;8KB Retention SRAM和64KB ROM&#xff0c;可选EEPROM。内置高性能多模射频收发机最大发射功率10dBm&#xff0c;BLE 1Mbps速率下接收灵敏…

快手发布2022直播生态报告,运营人速览

1、快手电商推出2023年直播间联合补贴活动1月5日&#xff0c;快手电商推出2023年直播间联合补贴活动。该活动主要目的是助力主播完成更高销售额&#xff0c;报名成功后&#xff0c;平台将对直播间内的一部分活跃用户发放10%-16%折扣率的满减优惠券&#xff0c;成本由平台和主播…

【BUG解决方案】jQuery数组中包含数据,但通过 .length 获得的数组长度始终为0

0. BUG展示 var lels []; for (var i 0; i < maxDevNums 1; i) {lels.push([]); } $.ajax({type : "post",async : true,url : "/sc/comb/history/data",data : {},dataType : "json",success : function (result) {if (result) {for (le…

FFmpeg 集成 x265 编译及解码

x265 是一个免费的软件库和应用程序&#xff0c;用于将视频流编码为 H.265/MPEG-H HEVC 压缩格式&#xff0c;并在 GNU GPL 条款下发布。 FFmpeg 为了支持 H.265 编、解码可以集成 x265 编译&#xff0c;在编译 FFmpeg 之前需要先编译 x265&#xff0c;但并不是所有的版本都能…

Python一轮知识拾遗

目录 字符串格式化 %格式符 format字符串格式化 三元条件运算符 可迭代对象 break和continue语句 enumerate函数 序列封包 序列解包 部分序列解包 append.列表和extend.列表的区别 字符串格式化 通过字符串的格式化&#xff0c;可以输出特定格式的字符串。 (1) 格式化…

为什么要申报绿色工厂?

一、什么是绿色工厂&#xff1f; 绿色工厂是指实现了用地集约化、生产洁净化、废物资源化、能源低碳化的工厂。 二、为什么要申报绿色工厂&#xff1f; 1、政策导向&#xff0c;发展趋势 发展绿色工厂是顺应全球绿色发展的大趋势&#xff0c;符合国家政策导向。 2、荣誉称号…

[笔记]Windows Cyswin ssh配置及远程控制

文章目录前言一、配置1.1 安装 Cygwin1.2 Cygwin安装时搜索安装ssh1.3 添加cygwin安装目录至Path环境变量1.4 配置 SSHD 服务1.5 添加 sshd连接账号二、使用2.1 使用配置的连接账号进行登录2.2 连接远程主机三、常见问题3.1 ssh on cygwin和openssh 冲突 提示 Host key verific…

新增血缘关系功能,色彩地图支持标记功能,DataEase开源数据可视化分析平台v1.18.0发布

2023年1月9日&#xff0c;DataEase开源数据可视化分析平台正式发布v1.18.0版本。 这一版本的功能升级包括&#xff1a;数据集方面&#xff0c;定时任务采用分页的方式拉取数据&#xff0c;减少资源消耗&#xff1b;仪表板方面&#xff0c;新增仪表板主题&#xff0c;以满足不同…