NLP实践——知识图谱问答模型FiD

news2025/2/26 13:45:45

NLP实践——知识图谱问答模型FiD

  • 0. 简介
  • 1. 模型结构
  • 2. 召回
  • 3. 问答
  • 4. 结合知识的问答

0. 简介

好久没有更新了,今天介绍一个知识图谱问答(KBQA)模型,在此之前我一直在用huggingface的Pipeline中提供的QA模型,非常方便但是准确性不是特别好。今天介绍的这个模型是Facebook在2021年就已经提出来的FiD(Fusion-in-Decoder),发表在ACL上。

论文地址: https://aclanthology.org/2021.eacl-main.74.pdf
项目地址:https://github.com/facebookresearch/FiD

其实我原本是想看EMNLP2022中的一篇文章,也已经开源。这个项目叫Grape,其基本思想实在FiD的基础上,采用两个T5 Encoder,并且在解码之前利用query和候选文本中的实体,构建GNN,在节点上做了Attention以增强Encoder的表征。

论文地址: https://arxiv.org/pdf/2210.02933.pdf

但是grape的这个项目我在实验的时候,遇到了一点环境配置上的问题,作者采用了一个比较冷门的dgl版本,这个版本在Linux_x86_64系统上没有官方编译过,于是我尝试自己编译,又遇到了一堆cmake和gcc版本的问题,于是放弃尝试。但顺着Grape的论文,找到了FiD这一项目。

1. 模型结构

所谓KBQA,也就是在问答模型的基础上,除了给定原文的信息之外,还考虑知识库中其他的预料信息。这个模型的原理很简单,就是一个生成模型,加上召回任务。

也就是先利用一个召回模型,在知识库中召回若干与给定的原文相关的文本,然后再将问题分别与原文以及相关文本进行拼接,拼接后的结果分别进行编码,再将编码的结果进行concat,最终把concat的结果给到Decoder,由Decoder生成答案。

模型结构
采用的基础模型是T5,分别在两个数据集NaturalQuestions和TriviaQA上进行了训练,数据和训练好的模型均可在git上找到。

2. 召回

召回这部分其实没有什么东西,在官方的git中,就是采用bert-base做了一下编码,我没有跟着它的做法,感兴趣的同学可以自行阅读retrieval相关的py文件。

这里我是觉得自己编码更方便一些,可以直接采用Sentence transformer的预训练模型,或者你自己训练的什么编码模型,另外做成faiss或者milvus索引的话,效率还会高不少。关于Sentence transformer,在好久之前的这篇博客中也介绍过。

3. 问答

虽然这个模型是KBQA模型,但是git上似乎也没有直接给出Fusion的那部分代码。这里我们不妨自己先写一个预测方法,利用它训练好的模型来实现QA的功能。

由于它本身其实就是一个T5模型,所以只要你对transformers模块比较熟悉的话,可以很轻易的写出预测方法。

首先我们加载一下模型和tokenizer:

from transformers import AutoTokenizer
from src.model import FiDT5  # 注意引用时的目录,引不进来就直接把这个类复制过来

# 从git上下载你想要尝试的模型,比如nq,把文件都放在一个目录里,然后用from_pretrained读取它
model = FiDT5.from_pretrained('your_path_to_Fid_model/nq_reader_base/')
tokenizer = AutoTokenizer.from_pretrained('t5-large')  # 联网下载,或提前下载好放在本地目录

# 然后eval一下,关掉dropout和BN,如果你比较叛逆,不关也是可以的
model.eval()

接下来我们写一个简单的预测方法,就可以实现QA了。

def predict(model, tokenizer, question, title, context, device='cpu'):
    """
    预测
    :param model: T5模型
    :param tokenizer: 分词器
    :param question: 问题
    :param title: 标题,没有的话可以给空字符
    :param context: 正文
    :param device: 在cpu还是cuda上执行
    ---------------
    ver: 2023-01-12
    by: changhongyu
    """
    if device.startswith('cuda'):
        model.to(device)
    combined_text = "question: " + question + "title: " + title + "context: " + context
    inputs = tokenizer(combined_text, max_length=1024, return_tensors='pt')
    test_outputs = model.generate(
        input_ids=inputs['input_ids'].unsqueeze(0).to(device),
        attention_mask=inputs['attention_mask'].unsqueeze(0).to(device),
        max_length=50,
    )
    answer = tokenizer.decode(test_outputs[0])
    
    return answer

来测试一下效果:

predict(
    model, 
    tokenizer,
    "Who is Russia's new commander",
    "Russia Ukraine War Live Updates: Russia changes commanders again in Ukraine",
    """09:20 (IST) Jan 12 Ukrainian military analyst Oleh Zhdanov said the situation in Soledar was "approaching that of critical" "The Ukrainian armed forces are holding their positions. About one half of the town is under our control. Fierce fighting is going on near the town centre," he said on YouTube.However, Zhdanov told Ukrainian television that if Russian forces seized Soledar or nearby Bakhmut it would be more a political victory than military. 09:18 (IST) Jan 12 Russian private military firm Wagner Group said its capture of the salt mining town Soledar in eastern Ukraine was complete- a claim denied by Ukraine 09:08 (IST) Jan 12 Russia changes commanders again in Ukraine Moscow named a new commander for its invasion of Ukraine. Russian Defence Minister Sergei Shoigu on Wednesday appointed Chief of the General Staff Valery Gerasimov as overall commander for what Moscow calls its "special military operation" in Ukraine, now in its 11th month.The change effectively demoted General Sergei Surovikin, who was appointed only in October to lead the invasion and oversaw heavy attacks on Ukraine's energy infrastructure. 06:40 (IST) Jan 12 Russia, Ukraine agree new prisoner swap in Turkey Russia and Ukraine on Wednesday agreed a new prisoner swap during rare talks in Turkey during which they also discussed the creation of a "humanitarian corridor" in the war zone. Ukraine's human rights ombudsman Dmytro Lubinets met his Russian counterpart Tatyana Moskalkova on the sidelines of an international conference in Ankara attended by Turkish President Recep Tayyip Erdogan. 06:39 (IST) Jan 12 President Volodymyr Zelenskyy urged NATO on Wednesday to do more than just promise Ukraine its door is open at a July summit, saying Kyiv needs "powerful steps" as it tries to join the military alliance. 06:39 (IST) Jan 12 Russian forces shelled 13 settlements in and around Kharkiv region largely returned to Ukrainian hands in September and October, the Ukrainian military said. 06:38 (IST) Jan 12 Russia's war on Ukraine latest: Russia puts top general in charge of invasion Russia ordered its top general on Wednesday to take charge of its faltering invasion of Ukraine in the biggest shake-up yet of its malfunctioning military command structure after months of battlefield setbacks. 06:37 (IST) Jan 12 Zelenskyy says Russian war won't become WWIII Ukraine will stop Russian aggression and the conflict won't turn into World War III, President Volodymyr Zelenskiy said as his forces battled to keep control of Soledar and Bakhmut in the eastern Donetsk region. The Kremlin had positioned the most experienced units from the Wagner military-contracting company near Soledar, according to Ukrainian operational command spokesman Serhiy Cherevatyi."""
)

模型给出的回答符合预期:

Valery Gerasimov

4. 结合知识的问答

官方的代码中好像没有给出这部分内容,所以我根据论文的思路简单实现了一下,简而言之就是在召回之后,将目标文档的编码结果与召回的参考文档的编码结果进行拼接,然后再统一进行解码即可。

def predict_with_reference(model, tokenizer, question, title, context, reference_title, reference_context, device='cpu'):
    """
    预测
    :param model: T5模型
    :param tokenizer: 分词器
    :param question: 问题
    :param title: 标题,没有的话可以给空字符
    :param context: 正文
    :param reference_title: 召回文本的标题
    :param reference_context: 召回文本的正文
    :param device: 在cpu还是cuda上执行
    ---------------
    ver: 2023-01-12
    by: changhongyu
    """
    if device.startswith('cuda'):
        model.to(device)
    combined_text = "question: " + question + "title: " + title + "context: " + context
    combined_refer = "question: " + question + "title: " + reference_title + "context: " + reference_context
    query_inputs = tokenizer(combined_text, max_length=1024, return_tensors='pt')
    refer_inputs = tokenizer(combined_refer, max_length=1024, return_tensors='pt')
    test_outputs = model.generate(
        input_ids=torch.cat([query_inputs['input_ids'].unsqueeze(0), refer_inputs['input_ids'].unsqueeze(0)], dim=2).to(device),
        attention_mask=torch.cat([query_inputs['attention_mask'].unsqueeze(0), refer_inputs['attention_mask'].unsqueeze(0)], dim=2).to(device),
        max_length=50,
    )
    answer = tokenizer.decode(test_outputs[0])
    
    return answer

然后来测试一下效果:

假设我们有一篇地震相关的新闻:

text = """The death toll in Syria and Turkey from the earthquake has passed 12,000, with the number of injured exceeding 100,000, while hundreds of thousands have been displaced. In Turkey, at least 9,000 have been killed and nearly 60,000 people have been injured, authorities said on Wednesday. The death toll in Syria stands at more than 3,000, according to the Syrian Observatory for Human Rights, while Syrian state media reported more than 298,000 people have been displaced."""

以及在知识库里召回的一篇叙相关的介绍:

reference = """Syria (Arabic: سوريا‎, romanized: Sūriyā), officially the Syrian Arab Republic (Arabic: الجمهورية العربية السورية‎, romanized: al-Jumhūrīyah al-ʻArabīyah as-Sūrīyah), is a country in Western Asia, bordering Lebanon to the southwest, the Mediterranean Sea to the west, Turkey to the north, Iraq to the east, Jordan to the south, and Israel to the southwest. A country of fertile plains, high mountains, and deserts, Syria is home to diverse ethnic and religious groups, including Syrian Arabs, Kurds, Turkemens, Assyrians, Armenians, Circassians, Mandeans and Greeks. Religious groups include Sunnis, Christians, Alawites, Druze, Isma'ilis, Mandeans, Shiites, Salafis, Yazidis, and Jews. Arabs are the largest ethnic group, and Sunnis the largest religious group."""

然后进行问答:

predict_with_reference(
    model, 
    tokenizer,
    question="where is Syria.",
    title="Earthquake death toll exceeds 12,000 as Turkey, Syria seek help.",
    context=text,
    reference_title="Syria",
    reference_context=reference,
)

模型给出的回答是:

'Western Asia'

答案也是符合预期的。

如果是召回多篇文档,理论上将predict_with_reference这个方法的reference都改成list,然后再拼接的时候把结果组合起来就可以了,感兴趣的同学可以自己尝试一下。

以上就是本文的全部内容了,在ChatGPT时代下,KBQA这个话题似乎有点“过时”了,但是这对于练习NLP基础任务和理解attention的运作还是很有帮助的。如果这篇文章对你有帮助,欢迎一键三连加关注,也欢迎评论区或私信交流,我们下期再见。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/347527.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

低代码和零代码的有什么不同?如何区分?

低代码开发平台和零代码平台的区别是什么?一个例子就能讲清楚! 周末你外出露营,在野外需要搭一个帐篷。有两种方法: 一种是最原始的搭帐篷方法,即有隔水布、外账、内账、营柱骨架等等......另一种是直接“封装好”的…

OpenCV-PyQT项目实战(5)项目案例01:图像模糊

欢迎关注『OpenCV-PyQT项目实战 Youcans』系列,持续更新中 OpenCV-PyQT项目实战(1)安装与环境配置 OpenCV-PyQT项目实战(2)QtDesigner 和 PyUIC 快速入门 OpenCV-PyQT项目实战(3)信号与槽机制 …

PySpark实战一之入门

1、PySpark的编程模型 分三个模块: 数据输入:通过SparkContext对象,完成数据输入 数据处理计算:输入数据后得到RDD对象,对RDD对象的成员方法进行迭代计算 数据输出:最后通过RDD对象的成员方法&#xff0…

互联网行业固定资产智能化解决方案为企业降本增效

互联网行业的固定资产数量和种类往往比较多,来源可能是租赁、购入、调拨等。主要分为:办公设备、电子设备、服务器等。固定资产是互联网企业的重要资产之一,是企业持续经营的物质基础。因此,对于实物资产的管理尤为重要。 互联网…

搭建zookeeper高可用集群详细步骤

目录 一、虚拟机设置 1.新建一台虚拟机并克隆三台,配置自定义 2.修改四台虚拟机的主机名并立即生效 3.修改四台虚拟机的网络信息 4.重启四台虚拟机的网络服务并测试网络连接 5.重启四台虚拟机,启动后关闭四台虚拟机的防火墙 6.在第一台虚拟机的/e…

TripleCross:一款功能强大的Linux eBPF安全研究工具

关于TripleCross TripleCross是一款功能强大的Linux eBPF安全研究工具,该工具提供了后门、C2、代码库注入、执行劫持、持久化和隐蔽执行等功能。 功能介绍 1、使用一个代码库注入模块通过往进程的虚拟内存中写入命令来执行恶意代码; 2、提供了一个行劫…

波卡2022年第四季度报告

本文将介绍Messari最新发布的波卡Polkadot 2022年第四季度报告内容。 1 Messari已经发布关于波卡Polkadot最新的报告:显示了2022年第四季度的日活账户增加了64%,新用户增长49%。 2 Messari指出,波卡中继链在2022第四季度的环比增长令人印象…

JavaScript 保留关键字

文章目录JavaScript 保留关键字JavaScript 标准JavaScript 保留关键字JavaScript 对象、属性和方法Java 保留关键字Windows 保留关键字HTML 事件句柄非标准 JavaScriptJavaScript 保留关键字 在 JavaScript 中,一些标识符是保留关键字,不能用作变量名或函…

100行Pytorch代码实现三维重建技术神经辐射场 (NeRF)

提起三维重建技术,NeRF是一个绝对绕不过去的名字。这项逆天的技术,一经提出就被众多研究者所重视,对该技术进行深入研究并提出改进已经成为一个热点。不到两年的时间,NeRF及其变种已经成为重建领域的主流。本文通过100行的Pytorch…

部门新来个00后卷王,太让人崩溃了,想离职了....

在职场上,什么样的人最让人反感? 是技术不好的人吗? 并不是。技术不好的同事,我们可以帮他。 是技术太强的人吗? 也不是。技术很强的同事,可遇不可求,向他学习还来不及呢。 真正让人反感的…

【uniapp】getOpenerEventChannel().once 接收参数无效的解决方案

uniapp项目开发跨平台应用常会遇到接收参数无效的问题,无法判断是哪里出错了,这里是讲替代的方案,现有三种方案可选。 原因 一般我们是这样处理向另一个页面传参,代码是这样写的 //... let { title, type, rank } args; uni.n…

STM32 HAL库-定时器中断

STM32 HAL库-定时器中断一、STM32F407定时器介绍定时器计算公式二、CubeMX配置定时器三、基本定时器中断配置流程1)开启定时器时钟2)初始化定时器参数,设置自动重装值,分频系数,计数方式等3)使能定时器更新中断&#x…

Ubuntu 系统 OpenCV 4 无法打开视频文件解决方案

目录 一、我的运行环境 二、问题描述 三、问题定位及分析 四、解决方案 一、我的运行环境 设备NVIDIA Jetson Nano处理器ARMv8 Processor rev 1 (v8l) 4 GPUNVIDIA Tegra X1 (nvgpu)/integrated操作系统ubuntu 18.04 LTSOpenCV版本4.6.0语言C 二、问题描述 之前一直用的O…

8 冒泡排序

文章目录1 基本介绍1 代码实现1.1 java1.1 scala1 基本介绍 冒泡排序(Bubble Sorting)的基本思想是:通过对待排序序列从前向后(从下标较小的元素开始),依次比较相邻元素的值,若发现逆序则交换,使…

存储管理(6)

存储管理 1 程序的装入与链接 编译:源代码——目标代码 链接:目标代码所需库函数装入模块 装入:将装入模块装入内存,该过程也叫做地址重定位,也称地址映射 地址空间: 源程序经编译后得到的目标程序&…

Leetcode 1223. 掷骰子模拟【动态规划】

有一个骰子模拟器会每次投掷的时候生成一个 1 到 6 的随机数。 不过我们在使用它时有个约束,就是使得投掷骰子时,连续 掷出数字 i 的次数不能超过 rollMax[i](i 从 1 开始编号)。 现在,给你一个整数数组 rollMax 和一…

WebDAV之葫芦儿·派盘+NMM

NMM 支持WebDAV方式连接葫芦儿派盘。 推荐一款文件管理器,可以对手机中的文件进行多方面的管理,支持语法高亮和ftp等远程的文件的管理。支持从WebDav服务器连接葫芦儿派盘服务下载文件和上传文件。 NMM文本编辑器是一款文件管理器,在功能上面更加的适合于一些编程人员进行使…

2023年应该了解的黑客知识

网络犯罪的艺术处于不断变化和演变的状态。与这些趋势保持同步是网络安全人员工作的重要组成部分。 今天的现代网络安全必须确保他们始终为下一个大趋势做好准备并保持领先于对手。 当我们开始迈向 2023 年时,安全格局与一年前相比已经发生了变化,更不…

Spark on hive Hive on spark

文章目录Spark on hive & Hive on sparkHive 架构与基本原理Spark on hiveHive on sparkSpark on hive & Hive on spark Hive 架构与基本原理 Hive 的核心部件主要是 User Interface(1)和 Driver(3)。而不论是元数据库&a…

webpack(高级)--性能优化-代码分离

webpack webpack性能优化 优化一:打包后的结果 上线时的性能优化 (比如分包处理 减少包体积 CDN服务器) 优化二:优化打包速度 开发或者构建优化打包速度 (比如exclude cache-loader等) 大多数情况下我们侧…