经典关系抽取(一)CasRel(层叠式指针标注)在DuIE2.0数据集上的应用

news2025/1/4 19:52:48

经典关系抽取(一)CasRel(层叠式指针标注)在DuIE2.0数据集上的应用

  • 关系抽取(Relation Extraction)就是从一段文本中抽取出(主体,关系,客体)这样的三元组,用英文表示是 (subject, relation, object) 。
  • 从关系抽取的定义可以看出,关系抽取主要要做两件事:
    • 识别文本中的subject和object(实体识别)
    • 判断这两个实体属于哪种关系(关系分类)
  • 在解决关系抽取这个任务时,按照模型的结构分为两种,一种是 Pipeline,另一种是 Joint Model。
    • 如果将关系抽取的两个任务分离,先进行实体识别,再进行关系分类,就是 Pipeline 模型。一般认为Pipeline会出现误差传播的情况,也就是实体识别的误差,会影响后面的关系分类任务,但关系分类任务,又无法对实体识别造成的误差进行修正。因为两个任务是独立的,关系分类的loss,无法反向传播给实体识别阶段。
    • 为了优化 Pipeline 存在的问题,就有了 Joint Model(多任务学习的联合抽取模型),本质上还是 Pipeline 的编码方式,也就是先解码出实体,再去解码关系。但跟 Pipeline 不同的是,Joint Model 只计算一次损失,所以loss反向传播,也会对实体识别的误差进行修正
  • 今天,我们介绍一种利用层叠式指针标注(CasRel、大神苏剑林二作)进行关系抽取的模型。
    • 论文链接:https://arxiv.org/abs/1909.03227
    • 论文代码:https://github.com/weizhepei/CasRel(keras实现)
    • 苏神博客介绍:https://kexue.fm/archives/7161

1 实体关系的joint抽取模型简介

  • 我们已经知道pipeline方式主要存在下面几个问题:
    • pipeline方式存在误差积累,还会增加计算复杂度(实体冗余计算)
    • pipeline方式存在交互缺失,忽略实体和关系两个任务之间的内在联系
  • 因此现在关系抽取大都采取joint方式,实体关系的joint抽取模型可分为下面2大类。

1.1 多任务学习

多任务学习(共享参数的联合抽取模型)。多任务学习机制中,实体和关系共享同一个网络编码,但本质上仍然是采取pipeline的解码方式(故仍然存在误差传播问题)。近年来的大部分joint都采取这种共享参数的模式,集中在魔改各种Tag框架和解码方式。下面是几种典型网络结构:

  • 多头选择。构建的关系分类器对每一个实体pair进行关系预测(N为序列长度,C为关系类别总数),输入的实体pair其实是每一个抽取实体的最后一个token。

    • 论文地址:Joint entity recognition and relation extraction as a multi-head selection problem

    • 模型架构:

      在这里插入图片描述

  • 层叠式指针标注。将关系看作是SPO(Subject-Prediction-Object)抽取,先抽取主体Subject,然后对主体感知编码,最后通过层叠式的指针网络抽取关系及其对应的Object。

    • 论文地址:A Novel Cascade Binary Tagging Framework for Relational Triple Extraction
    • 模型架构:

    在这里插入图片描述

  • Span-level NER。通过片段(span)排列抽取实体,然后提取实体对进行关系分类。

    • 论文地址:Span-based Joint Entity and Relation Extraction with Transformer Pre-training(1909)
    • 模型架构:

    在这里插入图片描述

1.2 结构化预测

结构化预测(联合解码的联合抽取模型)

  • 结构化预测则是一个全局优化问题,在推断的时候能够联合解码实体和关系(而不是像多任务学习那样,先抽取实体、再进行关系分类)。
  • 结构化预测的joint模型也有很多,比如:
    • 统一的序列标注框架。Joint extraction of entities and relations based on a novel tagging scheme
    • 多轮QA+强化学习。Entity-Relation Extraction as Multi-Turn Question Answering

2 层叠式指针标注CasRel简介

2.1 背景介绍

  • 以往的方法大多将关系建模为实体対上的一个离散的标签:

    • 首先通过命名实体识别(NER)确定出句子中所有的实体
    • 然后学习一个关系分类器在所有的实体对上做RC,最终得到我们所需的关系三元组
  • 但是这样会出现下面的问题:

    • 类别分布不平衡。多数提取出来的实体对之间无关系
    • 同一实体参与多个有效关系(如下图的重叠三元组),分类器可能会混淆。如果没有足够的训练样例,分类器难以区分(有的多标签分类难以工作就是因为这个)。

    在这里插入图片描述

  • CasRel框架最核心思想:把关系(Relation)建模为将头实体(Subject)映射到尾实体(Object)的函数,而不是将其视为实体对上的标签。

我们不学习关系分类器: f ( s u b j e c t , o b j e c t ) − > r 而是学习关系特定的尾实体 ( o b j e c t ) 标注器 : f r ( s u b j e c t ) − > o b j e c t 我们不学习关系分类器:f(subject,object)->r\\ 而是学习关系特定的尾实体(object)标注器:f_r(subject)->object 我们不学习关系分类器:f(subject,object)>r而是学习关系特定的尾实体(object)标注器:fr(subject)>object

  • 每个尾实体标注器都将在给定关系和头实体的条件下识别出所有可能的尾实体

  • CASREL的做法:

    • 编码器:Bert
    • 序列标注:抽出头实体
    • 关系特定的尾实体标注:对每一个头实体,针对其可能的关系抽取其尾实体,如果不存在尾实体,就无此关系。

2.2 CasRel的网络结构

  • 关系三元组抽取问题被分解为如下的两步过程:

    • 首先,我们确定出句子中所有可能的头实体;
      • 在下图中,我们识别到三个头实体分别为:[Jackie R. Brown]、[Washington]、[United States Of America]
    • 然后针对每个头实体,我们使用关系特定的标注器来同时识别出所有可能的关系和对应的尾实体。
      • 在下图中,头实体[Jackie R. Brown]在特定关系[Birth_place]中识别到了2个尾实体[Washington]和[United States Of America],那么三元组为:(Jackie R. Brown, Birth_place, Washington)、(Jackie R. Brown, Birth_place, United States Of America)
      • 头实体[Washington]在特定关系[Capital_of]中识别到了1个尾实体[United States Of America],那么三元组为:(Washington, Capital_of, United States Of America)
      • 头实体[United States Of America]在所有的关系中都没有识别尾实体,因此没有三元组。
  • 如下图所示,CASREL模型由三个模块构成:

    • BERT编码层模块(BERT Encoder)。

      • 这个是固定操作,就是将token经过Bert后转换为词向量。
    • 主体标记模块(Subject Tagger)。

      • 采用两个相同的二元分类器分别检测主体的起始位置和结束位置。当置信度大于阈值标记为1,小于阈值标记为0。

      在这里插入图片描述
      在这里插入图片描述

      • 主体标记模块损失函数为BCELoss。
    • 特定关系下客体的标记模块(Relation-Specific Object Taggers)。

      • 结构和subject tagger相同,每个关系都有一个object tagger。

      在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

2.3 模型局限性

  • 因为实体就近匹配,无法解决实体嵌套的场景,
"""
假设有下面的嵌套实体(叶圣陶和叶圣陶散文选集):
       叶 圣  陶 散 文  选  集
start  1  0  0  0  0   0  0
end    0  0  1  0  0   0  1
"""
    ......
    # 1、先筛选出大于阈值的实体头和实体尾的位置
	start = torch.where(subject_preds[0, :, 0] > 0.6)[0]
    end = torch.where(subject_preds[0, :, 1] > 0.5)[0]
    # 2、主体抽取的实现代码
    subjects = []
    for i in start:
        j = end[end >= i]
        if len(j) > 0:
            j = j[0]
            subjects.append((i.item(), j.item()))
     # 依据代码的逻辑可以看到只能提取到【叶圣陶】实体        
  • 该模型不适用于长段落、篇章级别信息抽取,因为bert位置编码512位,关系无法跨句子;

3 CasRel在DuIE2.0数据集上的应用

  • 这里我们使用一款简洁的训练框架:bert4torch
# bert4torch框架作者知乎:https://www.zhihu.com/people/li-bo-53-72/posts
# https://github.com/Tongjilibo/bert4torch
pip install bert4torch==0.5.0
  • Bert预训练模型依然是哈工大开源的chinese-macbert-base

    • 预训练模型huggingface链接:
      https://huggingface.co/hfl/chinese-macbert-base/tree/main
    • 预训练模型魔搭社区链接:
      https://modelscope.cn/models/dienstag/chinese-macbert-base/files
  • 百度DuIE2.0中文关系抽取数据集:https://www.luge.ai/#/luge/dataDetail?id=5

  • 另外分享一份百度Knowledge Extraction比赛数据集:https://aistudio.baidu.com/datasetdetail/177191

3.1 加载数据集

我们先来看一条训练数据:

{
    "text": "《邪少兵王》是冰火未央写的网络小说连载于旗峰天下",
    "spo_list": [
        {
            "predicate": "作者", // 关系
            "object_type": {
                "@value": "人物"
            },
            "subject_type": "图书作品",
            "object": {
                "@value": "冰火未央" // 客体
            },
            "subject": "邪少兵王"    // 主体
        }
    ]
}

我们需要将上面的单条数据转换为下面形式:

// 我们需要将上面的单条数据转换为
//{'text': text, 'spo_list': [(s, p, o)]}
{
	'text': "《邪少兵王》是冰火未央写的网络小说连载于旗峰天下", 
	'spo_list': [('邪少兵王', '作者', '冰火未央')]
}
// 然后将text进行编码,获取token_ids和segment_ids,还需要获取主体及客体在token_ids中的位置
{
    'text': '《邪少兵王》是冰火未央写的网络小说连载于旗峰天下', 
    'spo_list': [('邪少兵王', '作者', '冰火未央')], 
    'token_ids': [101, 517, 6932, 2208, 1070, 4374, 518, 3221, 1102, 4125, 3313, 1925, 1091, 4638, 5381, 5317, 2207, 6432, 6825, 6770, 754, 3186, 2292, 1921, 678, 102],
    'segment_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
    // 主体在token_ids中的位置(2, 5),左右都闭
    // 客体在token_ids中的位置(8, 11)以及关系的id=7
    'spoes': {(2, 5): [(8, 11, 7)]}
}

具体代码实现如下,在调试时候我们可以选取一部分数据进行测试:

#! -*- coding:utf-8 -*-
import json
import platform

import numpy as np
from bert4torch.layers import LayerNorm
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, ListDataset
from bert4torch.callbacks import Callback
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn

maxlen = 128
batch_size = 64


# 获取当前操作系统的名称
os_name = platform.system()

# 设置模型路径及数据集路径
if os_name == "Windows":
    # 这里使用哈工大开源的chinese-macbert-base
    config_path = r'D:\python\models\berts\chinese-macbert-base\config.json'
    checkpoint_path = r'D:\python\models\berts\chinese-macbert-base\pytorch_model.bin'
    dict_path = r'D:\python\models\berts\chinese-macbert-base\vocab.txt'
    # 数据集使用百度开源的DuIE2.0
    train_data_path = r'D:\python\datas\nlp_ner\DuIE2.0\duie_train.json'
    dev_data_path = r'D:\python\datas\nlp_ner\DuIE2.0\duie_dev.json'
    schema_path = r'D:\python\datas\nlp_ner\DuIE2.0\duie_schema.json'
    print("当前执行环境是 Windows...")
elif os_name == "Linux":
    config_path = r'/root/autodl-fs/models/chinese-macbert-base/config.json'
    checkpoint_path = r'/root/autodl-fs/models/chinese-macbert-base/pytorch_model.bin'
    dict_path = r'/root/autodl-fs/models/chinese-macbert-base/vocab.txt'

    train_data_path = r'/root/autodl-fs/data/nlp_ai/nlp_ner/duie2/duie_train.json'
    dev_data_path = r'/root/autodl-fs/data/nlp_ai/nlp_ner/duie2/duie_dev.json'
    schema_path = r'/root/autodl-fs/data/nlp_ai/nlp_ner/duie2/duie_schema.json'
    print("当前执行环境是 Linux...")
else:
    raise ValueError("当前执行环境不是 Windows 也不是 Linux")



device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 加载标签字典
predicate2id, id2predicate = {}, {}

# with open(r'D:\python\datas\nlp_ner\knowledge_extraction\all_50_schemas', encoding='utf-8') as f:
with open(schema_path, encoding='utf-8') as f:
    for l in f:
        l = json.loads(l)
        if l['predicate'] not in predicate2id:
            id2predicate[len(predicate2id)] = l['predicate']
            predicate2id[l['predicate']] = len(predicate2id)

# 建立分词器
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# 解析样本
def get_spoes(text, spo_list):
    '''单独抽出来,这样读取数据时候,可以根据spoes来选择跳过
    '''
    def search(pattern, sequence):
        """从sequence中寻找子串pattern
        如果找到,返回第一个下标;否则返回-1。
        """
        n = len(pattern)
        for i in range(len(sequence)):
            if sequence[i:i + n] == pattern:
                return i
        return -1

    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
    # 整理三元组 {s: [(o, p)]}
    spoes = {}
    for s, p, o in spo_list:
        s = tokenizer.encode(s)[0][1:-1]
        p = predicate2id[p]
        o = tokenizer.encode(o)[0][1:-1]
        s_idx = search(s, token_ids)
        o_idx = search(o, token_ids)
        if s_idx != -1 and o_idx != -1:
            s = (s_idx, s_idx + len(s) - 1)
            o = (o_idx, o_idx + len(o) - 1, p)
            if s not in spoes:
                spoes[s] = []
            spoes[s].append(o)
    return token_ids, segment_ids, spoes

# 加载数据集
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filename):
        """加载数据
        单条格式:{'text': text, 'spo_list': [(s, p, o)]}
        """
        D = []
        with open(filename, encoding='utf-8') as f:
            for l in tqdm(f):
                l = json.loads(l)
                labels = [(spo['subject'], spo['predicate'], spo['object']['@value']) for spo in l['spo_list']]
                token_ids, segment_ids, spoes = get_spoes(l['text'], labels)
                if spoes:
                    D.append({'text': l['text'], 'spo_list': labels, 'token_ids': token_ids, 
                              'segment_ids': segment_ids, 'spoes': spoes})
                # 这里可以选一部分数据进行测试    
                # if len(D) > 100:
                #     break
        print(f'loaded data nums = {len(D)} from {filename}')
        return D

3.2 构建DataLoader

  • 利用collate_fn方法构建subject标签、object标签,并进行填充
  • 注意:这里是随机选一个subject
def collate_fn(batch):
    batch_token_ids, batch_segment_ids = [], []
    batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
    for d in batch:
        token_ids, segment_ids, spoes = d['token_ids'], d['segment_ids'], d['spoes']
        if spoes:
            # 1、subject标签
            subject_labels = np.zeros((len(token_ids), 2))
            for s in spoes:
                subject_labels[s[0], 0] = 1  # subject首
                subject_labels[s[1], 1] = 1  # subject尾
            # 随机选一个subject(这里没有实现错误!这就是想要的效果!!)
            start, end = np.array(list(spoes.keys())).T
            start = np.random.choice(start)
            end = np.random.choice(end[end >= start])
            subject_ids = (start, end)
            # 2、对应的object标签
            # 后面sequence_padding方法中是对第1维数据进行填充,因此需要把len(token_ids)放在第1维
            object_labels = np.zeros((len(token_ids), len(predicate2id), 2))
            for o in spoes.get(subject_ids, []):
                object_labels[o[0], o[2], 0] = 1
                object_labels[o[1], o[2], 1] = 1
            # 构建batch
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_subject_labels.append(subject_labels)
            batch_subject_ids.append(subject_ids)
            batch_object_labels.append(object_labels)
    # 填充
    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    batch_subject_labels = torch.tensor(sequence_padding(batch_subject_labels), dtype=torch.float, device=device)
    batch_subject_ids = torch.tensor(batch_subject_ids, dtype=torch.long, device=device)
    batch_object_labels = torch.tensor(sequence_padding(batch_object_labels), dtype=torch.float, device=device)
    batch_attention_mask = (batch_token_ids != tokenizer._token_pad_id)
    return [batch_token_ids, batch_segment_ids, batch_subject_ids], [batch_subject_labels, batch_object_labels, batch_attention_mask]

# 构造训练集的DataLoader
train_dataset = MyDataset(train_data_path)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# 构造测试集的DataLoader
valid_dataset = MyDataset(dev_data_path)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_fn)

3.3 定义CasRel模型

  • 注意:这里使用Conditional Layer Normalization将subject融入到object的预测中
  • 常见的归一化可以参考:
    Pytorch常用的函数(六)常见的归一化总结(BatchNorm/LayerNorm/InsNorm/GroupNorm)
  • 我们看下condLayerNorm实现的核心代码:
# bert4torch/layers/core.py
class LayerNorm(nn.Module):
    def __init__(self, hidden_size:int, eps:float=1e-12, conditional_size:Union[bool, int]=False, weight:bool=True, bias:bool=True, 
                 norm_mode:Literal['normal', 'torch_buildin', 'rmsnorm']='normal', **kwargs):
        """ layernorm层,自行实现是为了兼容conditianal layernorm,使得可以做条件文本生成、条件分类等任务

            :param hidden_size: int, layernorm的神经元个数
            :param eps: float
            :param conditional_size: int, condition layernorm的神经元个数; 详情:https://spaces.ac.cn/archives/7124
            :param weight: bool, 是否包含权重
            :param bias: bool, 是否包含偏置
            :param norm_mode: str, `normal`, `rmsnorm`, `torch_buildin`
        """
        ......
        # 条件layernorm, 用于条件文本生成
        if conditional_size:
            # 这里采用全零初始化, 目的是在初始状态不干扰原来的预训练权重
            self.dense1 = nn.Linear(conditional_size, hidden_size, bias=False)
            self.dense1.weight.data.uniform_(0, 0)
            self.dense2 = nn.Linear(conditional_size, hidden_size, bias=False)
            self.dense2.weight.data.uniform_(0, 0)

    def forward(self, hidden_states:torch.FloatTensor, cond:Optional[torch.Tensor]=None):
        ......
        # 自行实现的LayerNorm
        else:
            u = hidden_states.mean(-1, keepdim=True)
            s = (hidden_states - u).pow(2).mean(-1, keepdim=True)
            o = (hidden_states - u) / torch.sqrt(s + self.eps)

        if not hasattr(self, 'weight'):
            self.weight = 1

        if self.conditional_size and (cond is not None):
            # 会进入此分支
            for _ in range(len(hidden_states.shape) - len(cond.shape)):
                # cond就是将subject的start和end向量concat一起
                # [bs, hidden_size*2] -> [bs, 1, hidden_size*2]
                cond = cond.unsqueeze(dim=1)
            # 通过下面代码将subject融入到object的预测中
            # o是经过原始LayerNorm后的值
            # (self.weight + self.dense1(cond))相当于gamma
            # self.dense2(cond)相当于beta
            output = (self.weight + self.dense1(cond)) * o + self.dense2(cond)
        else:
            output = self.weight * o

        if hasattr(self, 'bias') and (self.bias is not None):
            output += self.bias
        return output.type_as(hidden_states)
# 定义bert上的模型结构
class Model(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.bert = build_transformer_model(config_path, checkpoint_path)
        self.linear1 = nn.Linear(768, 2)
        """
        苏神博客:https://spaces.ac.cn/archives/7124
        在Bert等Transformer模型中,主要的Normalization方法是Layer Normalization
            ,所以很自然就能想到将对应的β和γ变成输入条件的函数,来控制Transformer模型的生成行为
            ,这就是Conditional Layer Normalization的线索思路。
        这里通过Conditional Layer Normalization将subject融入到object的预测中    
        """
        self.condLayerNorm = LayerNorm(hidden_size=768, conditional_size=768*2)
        self.linear2 = nn.Linear(768, len(predicate2id)*2)

    @staticmethod
    def extract_subject(inputs):
        """
        根据subject_ids从output中取出subject的向量表征
        :param inputs: [seq_output, subject_ids]
                seq_output是模型预测的每个token属于start和end概率结果,shape=(bs, seq_len, hidden_size)
                subject_ids是主体的id,shape=(bs, 2)
        :return: subject的向量表征 shape = (bs, hidden_size * 2)
        """
        output, subject_ids = inputs
        # torch.gather函数:https://blog.csdn.net/qq_44665283/article/details/134088676
        # 1、取出主体start和end向量的向量表征
        # start shape = (bs, 1, hidden_size)
        start = torch.gather(output, dim=1, index=subject_ids[:, :1].unsqueeze(2).expand(-1, -1, output.shape[-1]))
        # end shape = (bs, 1, hidden_size)
        end = torch.gather(output, dim=1, index=subject_ids[:, 1:].unsqueeze(2).expand(-1, -1, output.shape[-1]))
        # 2、concat
        subject = torch.cat([start, end], 2)
        return subject[:, 0]

    def forward(self, *inputs):
        """
        :param inputs: collate_fn函数会将batch数据集组装成train_X, train_y
                        这里的inputs就是train_X,即[batch_token_ids, batch_segment_ids, batch_subject_ids]
        :return: 主体和客体的预测值
        """
        token_ids, segment_ids, subject_ids = inputs
        # 预测subject
        seq_output = self.bert(token_ids, segment_ids)  # [bs, seq_len, hidden_size]
        subject_preds = (torch.sigmoid(self.linear1(seq_output)))**2  # [btz, seq_len, 2]

        # 传入subject,预测object
        subject = self.extract_subject([seq_output, subject_ids])
        # Note: 通过Conditional Layer Normalization将subject融入到object的预测中
        # 理论上应该用LayerNorm前的,但是这样只能返回各个block顶层输出,这里和keras实现不一致
        output = self.condLayerNorm(seq_output, subject)
        # 进行客体的分类预测
        output = (torch.sigmoid(self.linear2(output)))**4
        # (bs, seq_len, len(predicate2id)*2) -> (bs, seq_len, len(predicate2id), 2)
        object_preds = output.reshape(*output.shape[:2], len(predicate2id), 2)

        return [subject_preds, object_preds]
    
    def predict_subject(self, inputs):
        """
        主体预测
        :param inputs: [batch_token_ids, batch_segment_ids, batch_subject_ids]
        :return: [seq_output, subject_preds]
                 每个token的向量表示:seq_output shape = [bs, seq_len, hidden_size]
                 每个token属于主体的start和end的概率:subject_preds shape = [bs, seq_len, 2]
        """
        self.eval()
        with torch.no_grad():
            seq_output = self.bert(inputs[:2])  # [bs, seq_len, hidden_size]
            subject_preds = (torch.sigmoid(self.linear1(seq_output)))**2  # [bs, seq_len, 2]
        return [seq_output, subject_preds]
    
    def predict_object(self, inputs):
        """
        客体预测
        :param inputs: 主体预测的输出[seq_output, subject_preds]
        :return: 客体预测结果 shape = (bs, seq_len, len(predicate2id), 2)
        """
        self.eval()
        with torch.no_grad():
            seq_output, subject_ids = inputs
            # 根据subject_ids从output中取出subject的向量表征
            subject = self.extract_subject([seq_output, subject_ids])
            # 通过Conditional Layer Normalization将subject融入到object的预测中
            output = self.condLayerNorm(seq_output, subject)
            # 客体预测的置信度
            output = (torch.sigmoid(self.linear2(output)))**4
            object_preds = output.reshape(*output.shape[:2], len(predicate2id), 2)
        return object_preds


train_model = Model().to(device)

3.4 定义损失函数、评估函数

  • 损失函数就是主体BCELoss+客体BCELoss
  • 评估函数,计算f1、precision、recall
class BCELoss(nn.BCELoss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, inputs, targets):
        """
        loss计算
        :param inputs: Model模型forward方法的输出内容,即[subject_preds, object_preds]
        :param targets: collate_fn函数会将batch数据集组装成train_X, train_y
                        这里的targets就是train_y,即[batch_subject_labels, batch_object_labels, batch_attention_mask]
        :return: 主体loss+客体loss
        """
        subject_preds, object_preds = inputs
        subject_labels, object_labels, mask = targets

        # sujuect部分loss
        subject_loss = super().forward(subject_preds, subject_labels)
        subject_loss = subject_loss.mean(dim=-1)
        subject_loss = (subject_loss * mask).sum() / mask.sum()
        # object部分loss
        object_loss = super().forward(object_preds, object_labels)
        object_loss = object_loss.mean(dim=-1).sum(dim=-1)
        object_loss = (object_loss * mask).sum() / mask.sum()
        return subject_loss + object_loss


train_model.compile(loss=BCELoss(reduction='none'), optimizer=optim.Adam(train_model.parameters(), 1e-5))


def extract_spoes(text):
    """抽取输入text所包含的三元组
    """
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
    token_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    segment_ids = torch.tensor([segment_ids], dtype=torch.long, device=device)

    # 抽取subject,主体头和尾的阈值分别为:0.6和0.5
    seq_output, subject_preds = train_model.predict_subject([token_ids, segment_ids])
    subject_preds[:, [0, -1]] *= 0  # 首cls, 尾sep置为0
    # 1、先筛选出大于阈值的实体头和实体尾的位置
    start = torch.where(subject_preds[0, :, 0] > 0.6)[0]
    end = torch.where(subject_preds[0, :, 1] > 0.5)[0]
    # 2、主体抽取的实现代码
    subjects = []
    for i in start:
        j = end[end >= i]
        if len(j) > 0:
            j = j[0]
            subjects.append((i.item(), j.item()))

    # 3、如果存在主体,就构造(主体、关系、客体)三元组
    if subjects:
        spoes = []
        seq_output = seq_output.repeat([len(subjects)]+[1]*(len(seq_output.shape)-1))
        subjects = torch.tensor(subjects, dtype=torch.long, device=device)
        # 传入subject,抽取object和predicate
        object_preds = train_model.predict_object([seq_output, subjects])
        object_preds[:, [0, -1]] *= 0
        for subject, object_pred in zip(subjects, object_preds):
            # 客体的阈值
            start = torch.where(object_pred[:, :, 0] > 0.6)
            end = torch.where(object_pred[:, :, 1] > 0.5)
            for _start, predicate1 in zip(*start):
                for _end, predicate2 in zip(*end):
                    if _start <= _end and predicate1 == predicate2:
                        spoes.append(
                            (
                                (mapping[subject[0]][0], mapping[subject[1]][-1]),
                                predicate1.item(),
                                (mapping[_start][0], mapping[_end][-1])
                            )
                        )
                        break
        return [(text[s[0]:s[1] + 1], id2predicate[p], text[o[0]:o[1] + 1])
                for s, p, o, in spoes]
    else:
        return []


class SPO(tuple):
    """用来存三元组的类
    表现跟tuple基本一致,只是重写了 __hash__ 和 __eq__ 方法,
    使得在判断两个三元组是否等价时容错性更好。
    """
    def __init__(self, spo):
        self.spox = (
            tuple(tokenizer.tokenize(spo[0])),
            spo[1],
            # tuple(tokenizer.tokenize(spo[2])),
            tuple(tokenizer.tokenize(spo[2])),
        )

    def __hash__(self):
        return self.spox.__hash__()

    def __eq__(self, spo):
        return self.spox == spo.spox


def evaluate(data):
    """评估函数,计算f1、precision、recall
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    f = open('dev_pred.json', 'w', encoding='utf-8')
    pbar = tqdm()
    for d in data:
        # 预测三元组集合
        R = set([SPO(spo) for spo in extract_spoes(d['text'])])
        # 真实三元组集合
        T = set([SPO(spo) for spo in d['spo_list']])
        X += len(R & T)
        Y += len(R)
        Z += len(T)
        # 计算f1、precision、recall指标
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
        pbar.update()
        pbar.set_description(
            'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
        )
        s = json.dumps({
            'text': d['text'],
            'spo_list': list(T),
            'spo_list_pred': list(R),
            'new': list(R - T),
            'lack': list(T - R),
        },
                       ensure_ascii=False,
                       indent=4)
        f.write(s + '\n')
    pbar.close()
    f.close()
    return f1, precision, recall


class Evaluator(Callback):
    """评估与保存
    """
    def __init__(self):
        self.best_val_f1 = 0.

    def on_epoch_end(self, steps, epoch, logs=None):
        # optimizer.apply_ema_weights()
        f1, precision, recall = evaluate(valid_dataset.data)
        if f1 >= self.best_val_f1:
            self.best_val_f1 = f1
            train_model.save_weights('best_model.pt')
        # optimizer.reset_old_weights()
        print(
            'f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
            (f1, precision, recall, self.best_val_f1)
        )


if __name__ == '__main__':
    # 训练
    if True:
        evaluator = Evaluator()
        train_model.fit(train_dataloader, steps_per_epoch=None, epochs=5, callbacks=[evaluator])
    # 预测并评估
    else:
        train_model.load_weights('best_model.pt')
        f1, precision, recall = evaluate(valid_dataset.data)
# batch_size = 64设置下,训练过程中的资源消耗(2080Ti)
# 全量数据进行训练,训练一个epoch消耗时间25min左右
(transformers) root@autodl-container-adbc11ae52-f2ebff02:~/autodl-tmp/NLP_AI/a_re# nvidia-smi 
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:B2:00.0 Off |                  N/A |
|111%   74C    P2   240W / 250W |  10940MiB / 11264MiB |     96%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
# 这里只训练了5个epochs,有兴趣的可以多训练一些批次
# 训练5个epochs的效果如下
f1: 0.74309, precision: 0.78018, recall: 0.70936

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1923158.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Windows】硬链接和软链接(OneDrive同步指定目录?)

文章目录 一、场景带入二、Windows下的硬链接和软链接2.1 硬链接&#xff08;Hard Link&#xff09;2.2 软链接&#xff08;符号链接&#xff0c;Symbolic Link&#xff09;2.3 软链接和快捷方式2.4 应用场景 三、OneDrive中的应用3.1 错误姿势3.2 好像可行的尝试3.3 合理的解决…

SpringBoot使用Redisson操作Redis及使用场景实战

前言 在SpringBoot使用RedisTemplate、StringRedisTemplate操作Redis中&#xff0c;我们介绍了RedisTemplate以及如何SpringBoot如何通过RedisTemplate、StringRedisTemplate操作Redis。 RedisTemplate的好处就是基于SpringBoot自动装配的原理&#xff0c;使得整合redis时比较…

51单片机(STC8H8K64U/STC8051U34K64)_RA8889_8080参考代码(v1.3)

硬件&#xff1a;STC8H8K64U/STC8051U34K64 RA8889开发板 硬件跳线变更为并口8080模式&#xff0c;PS00x&#xff0c;R143&#xff0c;R142不接&#xff0c;R141无关 8080接口电路连接图&#xff1a; 实物连接图&#xff1a; RA8889开发板外接MCU连接器之引脚定义&…

防火巡查记录卡数字化平台

防火巡查记录卡数字化平台 利用凡尔码搭建防火巡查记录卡数字化平台是一个高效且实用的解决方案&#xff0c;能够显著提升防火巡查的效率和管理水平。替代纸质巡检造成的数据丢失等困扰。 一、如何注册凡尔码平台 百度搜索“凡尔码”找到平台地址即可注册开通。凡尔码平台通…

二叉树层序遍历?秒了!

废话不多说&#xff0c;直接上题&#xff0c;涉及到二叉树层序遍历的题目大部分都可以用这个方法&#xff1a; 示例&#xff1a;力扣102 二叉树的层序遍历 给你二叉树的根节点 root &#xff0c;返回其节点值的 层序遍历 。 &#xff08;即逐层地&#xff0c;从左到右访问所有…

实验8 视图创建与管理实验

一、实验目的 理解视图的概念。掌握创建、更改、删除视图的方法。掌握使用视图来访问数据的方法。 二、实验内容 在job数据库中&#xff0c;有聘任人员信息表&#xff1a;Work_lnfo表&#xff0c;其表结构如下表所示&#xff1a; 其中表中练习数据如下&#xff1a; 1.‘张明…

数据结构之单链表(赋源码)

数据结构之单链表 线性表 线性表的顺序存储结构&#xff0c;有着较大的缺陷 插入和删除操作需要移动大量元素。会耗费很多时间增容需要申请空间&#xff0c;拷贝数据&#xff0c;释放旧空间。会有不小的消耗即使是使用合理的增容策略&#xff0c;实际上还会浪费许多用不上的…

【Oracle】实验五 PL_SQL编程

【实验目的】 熟悉PL/SQL的数据类型和书写规则熟悉控制结构和游标的使用编写和运行函数、过程和触发器 【实验内容】 编写脚本文件&#xff0c;调试运行脚本文件&#xff0c;并记录结果。 本地子程序的编写及调试 1、编写一个PL/SQL块&#xff0c;功能用于打印学生信息。整…

【学习css1】flex布局-页面footer部分保持在网页底部

中间内容高度不够屏幕高度撑不开的页面时候&#xff0c;页面footer部分都能保持在网页页脚&#xff08;最底部&#xff09;的方法 1、首先上图看显示效果 2、奉上源码 2.1、html部分 <body><header>头部</header><main>主区域</main><foot…

深入解析香橙派 AIpro开发板:功能、性能与应用场景全面测评

文章目录 引言香橙派AIpro开发板介绍到手第一感觉开发板正面开发板背面 性能性能概况性能体验 应用场景移植操作系统香橙派 AIpro开发板支持哪些操作系统&#xff1f;烧写操作系统到SD卡中启动开发板的步骤查看系统提供的事例程序体验——开发的简洁性 视频播放展示ffmpeg简介f…

【Python3】自动化测试_用Playwright发送API请求

一、创建APIRequestContex实例 # 连接到 APIRequest&#xff0c;可用于 Web API 测试的 API。 myRequest myPlaywright.request# 创建APIRequestContext实例&#xff0c;该实例可用于发送 Web 请求 myRequestContext myRequest.new_context() myRequest.new_context(**kwargs…

【MySQL】8.复合查询

复合查询 一.基本查询回顾(新增子查询)二.多表查询三.自连接四.子查询1.单列单行子查询2.单列多行子查询——三个关键字3.多列子查询4.在 from 子句中使用子查询 五.合并查询六.总结 一.基本查询回顾(新增子查询) //1.查询工资高于500或岗位为MANAGER的雇员&#xff0c;同时还…

js逆向-webpack-python

网站&#xff08;base64&#xff09;&#xff1a;aHR0cHM6Ly93d3cuY29pbmdsYXNzLmNvbS96aA 案例响应解密爬取&#xff08;webpack&#xff09; 1、找到目标url 2、进行入口定位&#xff08;此案例使用 ‘decrypt(’ 关键字搜索 &#xff09; 3、找到位置进行分析 --t 为 dat…

【软件工具】VMware Workstation Pro 15.5安装

1、双击运行安装包程序 2、接受许可证协议 3、选择安装位置&#xff0c;建议非中文无空格&#xff0c;增强型键盘驱动程序可选 4、按照自身使用习惯勾选产品更新和客户体验提升计划 5、快捷方式 6、开始安装 7、稍等会儿(可以玩会儿手机) 8、可输入许可证也可直接完成&#xff…

《ElementUI/Plus 基础知识》el-tree 之修改可拖拽节点的高亮背景和线

前言 收到需求&#xff0c;PM 觉得可拖拽节点的高亮背景和线样式不明显&#xff01;CSS 样式得改&#xff01; 注意&#xff1a;下述方式适用于ElementUI el-tree 和 ElementPlus el-tree&#xff01; 修改 拖拽被叠加节点的背景色和文字 关键类名 is-drop-inner .el-tree…

几何距离与函数距离:解锁数据空间中的奥秘

几何距离&#xff1a;直观的空间度量 几何距离&#xff0c;顾名思义&#xff0c;是我们在几何学中熟悉的距离概念&#xff0c;如欧几里得距离、曼哈顿距离和切比雪夫距离等。这些距离度量直接反映了数据点在多维空间中的位置关系。 欧几里得距离&#xff1a;最为人熟知的几何距…

conda install问题记录

最近想用代码处理sar数据&#xff0c;解放双手。 看重了isce这个处理平台&#xff0c;在安装包的时候遇到了一些问题。 这一步持续了非常久&#xff0c;然后我就果断ctrlc了 后面再次进行尝试&#xff0c;出现一大串报错&#xff0c;不知道是不是依赖项的问题 后面看到说mam…

langchain-runnable底层原理

文章目录 langchainlangchain生态介绍langchainLCELrunnablerunnable基础能力介绍invokebatchstreamainvokeabatchastream__or__、__ror__pipeget_nameInputType (属性)OutputType (属性)input_schema (属性)output_schema (属性) langchain langchain生态介绍 langchain是一个…

Min P Sampling: Balancing Creativity and Coherence at High Temperature阅读笔记

上一篇文章是关于大语言模型的调参数&#xff0c;写了temperature这个参数近期的一个工作。那接下来&#xff0c;就不得不再来讲讲top-p这个参数啦。首先还是上文章&#xff0c;同样是非常新的一个工作&#xff0c;2024年7月1日submit的呢。 文章链接&#xff1a;https://arxi…

NLP任务:情感分析、看图说话

我可不向其他博主那样拖泥带水&#xff0c;我有代码就直接贴在文章里&#xff0c;或者放到gitee供你们参考下载&#xff0c;虽然写的不咋滴&#xff0c;废话少说&#xff0c;上代码。 gitee码云地址&#xff1a; 卢东艺/pytorch_cv_nlp - 码云 - 开源中国 (gitee.com)https:/…