深入探索智能未来:文本生成与问答模型的创新融合

news2024/10/6 2:27:45

深入探索智能未来:文本生成与问答模型的创新融合

1.Filling Model with T5

1.1背景介绍

该项目用于将句子中 [MASK] 位置通过生成模型还原,以实现 UIE 信息抽取中 Mask Then Filling 数据增强策略。

Mask Then Fill 是一种基于生成模型的信息抽取数据增强策略。对于一段文本,我们其分为「关键信息段」和「非关键信息段」,包含关键词片段称为「关键信息段」。下面例子中标粗的为 关键信息片段,其余的为 非关键片段

大年三十 我从 北京 的大兴机场 飞回成都

我们随机 [MASK] 住一部分「非关键片段」,使其变为:

大年三十 我从 北京 [MASK] 飞回成都

随后,将改句子喂给 filling 模型(T5-Fine Tuned)还原句子,得到新生成的句子:

大年三十 我从 北京 首都机场作为起点,飞回成都

  • 环境安装

本项目基于 pytorch + transformers 实现,运行前请安装相关依赖包:

pip install -r ../requirements.txt
  • 数据集准备

项目中提供了一部分示例数据,数据来自DuIE数据集中的文本数据,数据在 data/

若想使用 自定义数据 训练,只需要仿照示例数据构建带 [MASK] 的文本即可,你也可以使用 parse_data.py 快速生成基于 词粒度 的训练数据:

"Bortolaso Guillaume,法国籍[MASK]"中[MASK]位置的文本是:	运动员
"歌曲[MASK]是由歌手海生演唱的一首歌曲"中[MASK]位置的文本是:	《情一动心就痛》
...

每一行用 \t 分隔符分开,第一部分部分为 带[MASK]的文本,后一部分为 [MASK]位置的原始文本(label)

1.2. 模型训练

修改训练脚本 train.sh 里的对应参数, 开启模型训练:

python train.py \
    --pretrained_model "uer/t5-base-chinese-cluecorpussmall" \
    --save_dir "checkpoints/t5" \
    --train_path "data/train.tsv" \
    --dev_path "data/dev.tsv" \
    --img_log_dir "logs" \
    --img_log_name "T5-Base-Chinese" \
    --batch_size 128 \
    --max_source_seq_len 128 \
    --max_target_seq_len 32 \
    --learning_rate 1e-4 \
    --num_train_epochs 20 \
    --logging_steps 50 \
    --valid_steps 500 \
    --device cuda:0

正确开启训练后,终端会打印以下信息:

...
 0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 21.28it/s]
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 350134
    })
    dev: Dataset({
        features: ['text'],
        num_rows: 38904
    })
})
...
global step 2400, epoch: 1, loss: 7.44746, speed: 0.82 step/s
global step 2450, epoch: 1, loss: 7.42028, speed: 0.82 step/s
global step 2500, epoch: 1, loss: 7.39333, speed: 0.82 step/s
Evaluation bleu4: 0.00578
best BLEU-4 performence has been updated: 0.00026 --> 0.00578
global step 2550, epoch: 1, loss: 7.36620, speed: 0.81 step/s
...

logs/T5-Base-Chinese.png 文件中将会保存训练曲线图:

1.3 模型预测

完成模型训练后,运行 inference.py 以加载训练好的模型并应用:

 if __name__ == "__main__":
    masked_texts = [
        '"《μVision2单片机应用程序开发指南》是2005年2月[MASK]图书,作者是李宇"中[MASK]位置的文本是:'
    ]
    inference(masked_texts)
python inference.py

得到以下推理结果:

maksed text: 
[
    '"《μVision2单片机应用程序开发指南》是2005年2月[MASK]图书,作者是李宇"中[MASK]位置的文本是:'
]
output: 
[
    ',中国工业出版社出版的'
]

2.问答模型(Text-Generation, T5 Based)

2.1 背景介绍

问答模型是指通过输入一个「问题」和一段「文章」,输出「问题的答案」。

问答模型分为「抽取式」和「生成式」,抽取式问答可以使用 [UIE] 训练,这个实验中我们将使用「生成式」模型来训练一个问答模型。

我们选用「T5」作为 backbone,使用百度开源的「QA数据集」来训练得到一个生成式的问答模型。

  • 环境安装

本项目基于 pytorch + transformers 实现,运行前请安装相关依赖包:

pip install -r ../requirements.txt

2.2 数据集准备

项目中提供了一部分示例数据,数据是百度开源的问答数据集,数据在 data/DuReaderQG

若想使用自定义数据训练,只需要仿照示例数据构建数据集即可:

{"context": "违规分为:一般违规扣分、严重违规扣分、出售假冒商品违规扣分,淘宝网每年12月31日24:00点会对符合条件的扣分做清零处理,详情如下:|温馨提醒:由于出售假冒商品24≤N<48分,当年的24分不清零,所以会存在第一年和第二年的不同计分情况。", "answer": "12月31日24:00", "question": "淘宝扣分什么时候清零", "id": 203}
{"context": "生长速度 头发是毛发中生长最快的毛发,一般每天长0.27—0.4mm,每月平均生长约1.0cm,一年大概长10—14cm。但是,头发不可能无限制的生长,一般情况下,头发长至50—60cm,就会脱落再生新发。", "answer": "0.27—0.4mm", "question": "头发一天能长多少", "id": 328}
...

每一行为一个数据样本,json 格式。

其中,"context" 代表参考文章,question 代表问题,"answer" 代表问题答案。

2.3 模型训练

修改训练脚本 train.sh 里的对应参数, 开启模型训练:

python train.py \
    --pretrained_model "uer/t5-base-chinese-cluecorpussmall" \
    --save_dir "checkpoints/DuReaderQG" \
    --train_path "data/DuReaderQG/train.json" \
    --dev_path "data/DuReaderQG/dev.json" \
    --img_log_dir "logs/DuReaderQG" \
    --img_log_name "T5-Base-Chinese" \
    --batch_size 32 \
    --learning_rate 1e-4 \
    --max_source_seq_len 256 \
    --max_target_seq_len 32 \
    --learning_rate 5e-5 \
    --num_train_epochs 50 \
    --logging_steps 10 \
    --valid_steps 500 \
    --device "cuda:0"

正确开启训练后,终端会打印以下信息:

...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 650.73it/s]
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 14520
    })
    dev: Dataset({
        features: ['text'],
        num_rows: 984
    })

global step 10, epoch: 1, loss: 9.39613, speed: 1.60 step/s
global step 20, epoch: 1, loss: 9.39434, speed: 1.71 step/s
global step 30, epoch: 1, loss: 9.39222, speed: 1.72 step/s
global step 40, epoch: 1, loss: 9.38739, speed: 1.63 step/s
global step 50, epoch: 1, loss: 9.38296, speed: 1.63 step/s
global step 60, epoch: 1, loss: 9.37982, speed: 1.71 step/s
global step 70, epoch: 1, loss: 9.37385, speed: 1.71 step/s
global step 80, epoch: 1, loss: 9.36876, speed: 1.69 step/s
global step 90, epoch: 1, loss: 9.36209, speed: 1.72 step/s
global step 100, epoch: 1, loss: 9.35349, speed: 1.70 step/s
...

logs/DuReaderQG 文件下将会保存训练曲线图:

2.4 模型推理

完成模型训练后,运行 inference.py 以加载训练好的模型并应用:

...

if __name__ == '__main__':
    question = '治疗宫颈糜烂的最佳时间'
    context = '专家指出,宫颈糜烂治疗时间应选在月经干净后3-7日,因为治疗之后宫颈有一定的创面,如赶上月经期易发生感染。因此患者应在月经干净后3天尽快来医院治疗。同时应该注意,术前3天禁同房,有生殖道急性炎症者应治好后才可进行。'
    inference(qustion=question, context=context)

运行推理程序:

python inference.py

得到以下推理结果:

Q: "治疗宫颈糜烂的最佳时间"
C: "专家指出,宫颈糜烂治疗时间应选在月经干净后3-7日,因为治疗之后宫颈有一定的创面,如赶上月经期易发生感染。因此患者应在月经干净后3天尽快来医院治疗。同时应该注意,术前3天禁同房,有生殖道急性炎症者应治好后才可进行。"
A: "答案:月经干净后3-7日"

项目链接:https://github.com/HarderThenHarder/transformers_tasks/blob/main/answer_generation/readme.md

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/893723.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

科学规划假期学习,猿辅导《暑假一本通》获用户好评

一直以来&#xff0c;有效利用寒、暑假期查漏补缺、解决偏科问题、初步养成好的自主学习习惯等是很多家长对学生的期望。但当前市面上教辅品类繁多&#xff0c;内容质量却参差不齐。据北京开卷统计数据显示&#xff0c;2022年前三季度零售市场上的教辅图书超过8000种&#xff0…

高级AI赋能Fortinet FortiXDR解决方案

扩展检测和响应 (XDR&#xff1a;Extended Detection and Response) 解决方案旨在帮助组织整合分布式安全技术&#xff0c;更有效地识别和响应活动的威胁。虽然 XDR 是一种新的技术概念&#xff0c;但其构建基础是端点检测和响应 (EDR&#xff1a;Endpoint Detection and Respo…

CW4L2-3A-S电源滤波器

CW4L2-3A-T CW4L2-6A-T CW4L2-10A-T CW4L2-20A-T CW4L2-3A-S CW4L2-6A-S CW4L2-10A-S CW4L2-20A-S 安装位置应靠近电源线入口处&#xff0c;尽可能滤除沿电源线侵入和窜出的电磁干扰。 确保滤波器外壳与设备机箱良好电接触&#xff0c;并接好地线。 滤波器的输入输出…

PoseiSwap 更新质押系统,并将在 8 月18 日开启“Trident ”快照

自 DeFi Summer 后&#xff0c;DeFi 设施整体的形态并未发生本质的变化&#xff0c;我们看到 DeFi 应用仍旧不具向外长期捕获价值、用户的能力&#xff0c;老旧叙事导致 DeFi 赛道整体的发展停滞不前。伴随着行业进入到下行周期&#xff0c;DeFi 赛道的资金、用户不断出逃&…

政务、商务数据资源有效共享:让数据上“链”,记录每一个存储过程!

数据上链是目前“区块链”最常见的场景。因为链上所有参与方都分享了统一的事实来源&#xff0c;所有人都可以即时获得最新的信息&#xff0c;数据可用不可见。因此&#xff0c;不同参与方之间的协作效率得以大幅提高。同时&#xff0c;因为区块链上的数据难以篡改&#xff0c;…

猿辅导与中街1946联手推出“冷知识冰棍”,带来学习新体验

为了给孩子们的暑假学习加点“料”&#xff0c;猿辅导近日脑洞大开&#xff0c;和中街1946携手推出了“冷知识冰棍”&#xff0c;以数学、英语、语文、科学4个科目为外包装&#xff0c;分别对应草莓山楂、青提菠萝、茉莉蓝莓和蜜桃乌龙等4种口味&#xff0c;为孩子们开启了夏日…

学习笔记」左偏树

dist 的性质 对于一棵二叉树&#xff0c;我们定义左孩子或右孩子为空的节点为外节点&#xff0c;定义外节点的 distdist 为 11&#xff0c;空节点的 distdist 为 00&#xff0c;不是外节点也不是空节点的 distdist 为其到子树中最近的外节点的距离加一。 一棵根的 distdist 为…

多线段的研究

1.AutoCAD分为二维多线段(命令pline)&#xff0c;三维多线段(3dpoly)。 1.1&#xff1a;二维多线段 对应类Acdb2dPolyline&#xff0c;有起点&#xff0c;末点的宽度。由Acdb2dVertex&#xff08;顶点组成&#xff09; 1.2:三维多义线&#xff0c;(三维多义线的顶点没有凸度&a…

地理测绘基础知识(3)-观测与遮挡

在上一篇文章中&#xff0c;我们介绍了椭球模型下的一系列基础的坐标操作。本节&#xff0c;介绍观测与遮挡问题。 观测主要用于从观察点A观测大地标准点B&#xff0c;用来解决观测的仰角、方位角与大地坐标系之间的关系。 在没有GPS卫星的时代&#xff0c;为了测量一个位置的…

ipad手写笔一定要买苹果的吗?适合学生党电容笔推荐

暑假接近尾声&#xff0c;不少学生党开始为开学而做准备了。如果你想要一个与iPad相匹配的电容笔&#xff0c;可以买一个Apple Pencil吧。但事实上&#xff0c;这个苹果产品性能比较出色&#xff0c;卖的还是很好的。但是平替电容笔也是个不错的选择&#xff0c;而且价格也很合…

构建LLM应用程序时需要了解的5件事

推荐&#xff1a;使用 NSDT场景编辑器 助你快速搭建可二次编辑的3D应用场景 1.幻觉 使用LLM时应注意的主要方面之一是幻觉。在LLM的背景下&#xff0c;幻觉是指产生不真实的&#xff0c;不正确的&#xff0c;无意义的信息。LLM非常有创意&#xff0c;它们可以用于不同的领域&am…

亚马逊如何登录多个买家号?如何防止账号关联?

如果有多买家账号需要登录使用&#xff0c;以下是在同一设备上登录多个买家账号的一般步骤&#xff1a; 1、登出当前账号&#xff1a;如果您已经登录了一个买家账号&#xff0c;首先需要退出该账号。在页面右上角&#xff0c;通常会看到一个"Hello, [您的用户名]"&a…

终端安全无忧!迅软科技助力母婴用品企业保护隐私信息

客户简要介绍 某母婴用品企业是专业的婴幼儿用品综合制造厂商&#xff0c;是总部设在上海&#xff0c;致力于研发集安全性、舒适性、功能性以及环保于一体的产品。 企业的重要诉求 公司内部奶瓶、纸尿裤等产品的销售数据以及新品设计图片要避免外传被竞争对手拿到&#xff0c;需…

Java“牵手”根据关键词搜索(分类搜索)1688商品列表页面数据获取方法,1688API实现批量商品数据抓取示例

1688商城是一个网上购物平台&#xff0c;售卖各类商品&#xff0c;包括服装、鞋类、家居用品、美妆产品、电子产品等。要获取1688商品列表和商品详情页面数据&#xff0c;您可以通过开放平台的接口或者直接访问1688商城的网页来获取商品详情信息。以下是两种常用方法的介绍&…

yaml语法规则

1.语法规则 大小写敏感属性层级关系使用多行描述&#xff0c;每行结尾使用冒号结束使用缩进表示层级关系&#xff0c;同层级左侧对齐&#xff0c;只允许使用空格&#xff08;不允许 使用Tab键&#xff09;属性值前面添加空格&#xff08;属性名与属性值之间使用冒号空格作为分…

Azure VM上意外禁用NIC如何还原恢复

创建一个windows虚拟机&#xff0c;并远程连接管理员的方式打开powershell 首先查看虚拟网卡&#xff0c;netsh interface show interface 然后禁用虚拟网卡 ,netsh interface set interface Ethernet disable 去Azure虚拟机控制台&#xff0c;打开串行控制台 控制台中键入cmd,…

如何使用 Docker Compose 运行 OSS Wordle 克隆

了解如何使用 Docker Compose 在五分钟内运行您自己的流行 Wordle 克隆实例。您将如何部署 Wordle&#xff1f; Wordle在 2021 年底发布后席卷了互联网。对于许多人来说&#xff0c;这仍然是一种早晨的仪式&#xff0c;与一杯咖啡和一天的开始完美搭配。作为一名 DevOps 工程师…

MongoDB 安装 linux

本文介绍一下MongoDB的安装教程。 系统环境&#xff1a;CentOS7.4 可以用 cat /etc/redhat-release 查看本机的系统版本号 一、MongoDB版本选择 当前最新的版本为7.0&#xff0c;但是由于7.0版本安装需要升级glibc2.25以上,所以这里我暂时不安装该版本。我们选择的是6.0.9版本…

Leetcode每日一题:1388. 3n 块披萨(2023.8.18 C++)

目录 1388. 3n 块披萨 问题描述&#xff1a; 实现代码与解析&#xff1a; 动态规划 原理思路&#xff1a; 1388. 3n 块披萨 问题描述&#xff1a; 给你一个披萨&#xff0c;它由 3n 块不同大小的部分组成&#xff0c;现在你和你的朋友们需要按照如下规则来分披萨&am…

【MT32F006】MT32F006之HT1628驱动LED

本文最后修改时间&#xff1a;2023年03月30日 一、本节简介 本文介绍如何使用MT32F006连接HT1628芯片驱动LED。 二、实验平台 库版本&#xff1a;V1.0.0 编译软件&#xff1a;MDK5.37 硬件平台&#xff1a;MT32F006开发板&#xff08;主芯片MT32F006&#xff09; 仿真器&a…