biaffine model:Named Entity Recognition as Dependency Parsing

news2024/12/28 5:03:43

论文名称:Named Entity Recognition as Dependency Parsing

论文地址:https://www.aclweb.org/anthology/2020.acl-main.577/

前提说明

本文主要参考了以下资料

  • nlp_paper_study_information_extraction/code_pytorch.md at main · km1994/nlp_paper_study_information_extraction (github.com)
  • suolyer/PyTorch_BERT_Biaffine_NER: 论文复现《Named Entity Recognition as Dependency Parsing》 (github.com)

借助于第二个资料里的仓库,也是很顺利的跑出了该模型

摘要

  • 动机:NER研究关注于flat NER,而忽略了nested NER
  • 方法:在本文中,使用基于图的依存关系解析中的思想,以通过biaffine model为模型提供全局的输入视图。biaffine model 对句子中的开始标记和结束标记进行评分,使用该标记来探索所有跨度,以便该模型能够准确地预测命名实体
  • 工作介绍:在这项工作中,我们将NER重新确定为开始和结束索引的任务,并为这些定义的范围分配类别,我们的系统在多层Bi-LSTM之上使用biaffine模型,将分数分配给句子中所有可能的跨度。此后,我们不用构建依赖关系树,而是根据侯选树的分数对它们进行排序,然后返回符合 Flat 或 Nested NER约束排名最高的树span
  • 实验结果:
  • 我们根据三个嵌套的NER基准(ACE 2004,ACE 2005,GENIA)和五个扁平的NER语料库(CONLL 2002(荷兰语,西班牙语),CONLL 2003(英语,德语)和ONTONOTES)对系统进行了评估。结果表明,我们的系统在所有三个嵌套的NER语料库和所有五个平坦的NER语料库上均取得了SoTA结果,与以前的SoTA相比,实际收益高达2.2%的绝对百分比。

一、数据处理模块

1. 1原始数据格式

{"text": "当希望工程救助的百万儿童成长起来,科教兴国蔚然成风时,今天有收藏价值的书你没买,明日就叫你悔不当初!", 
 "entity_list": []
}
{"text": "藏书本来就是所有传统收藏门类中的第一大户,只是我们结束温饱的时间太短而已。", 
 "entity_list": []
}
{"text": "因有关日寇在京掠夺文物详情,藏界较为重视,也是我们收藏北京史料中的要件之一。", 
 "entity_list": [{"type": "ns", "argument": "北京"}]
}
...

1.2数据预处理模块

1.2 .1数据加载load_data(file_path)
def load_data(file_path):
    with open(file_path, 'r', encoding='utf8') as f:
        lines = f.readlines()
        sentences = []
        arguments = []
        for line in lines:
            data = json.loads(line)
            text,entity_list = data['text'],data['entity_list']
            args_dict={}
            if entity_list != []:
                for entity in entity_list:
                    entity_type,entity_argument = entity['type'],entity['argument']

                    if entity_type not in args_dict.keys():
                        args_dict[entity_type] = [entity_argument]
                    else:
                        args_dict[entity_type].append(entity_argument)
                sentences.append(text)
                arguments.append(args_dict)
        return sentences, arguments
  • 获取原始数据

  • 返回entity_list不为 [] 的数据

  • 返回sentences、arguments,格式如下

    print(f"sentences[0:2]:{sentences[0:2]}")
    print(f"arguments[0:2]:{arguments[0:2]}")
    
    sentences[0:2]:['因有关日寇在京掠夺文物详情,藏界较为重视,也是我们收藏北京史料中的要件之一。', 
    			   '我们藏有一册1945年6月油印的《北京文物保存保管状态之调查报告》,调查范围涉及故宫、历博、古研所、北大清华图书馆、北图、日伪资料库等二十几家,言及文物二十万件以上,洋洋三万余言,是珍贵的北京史料。']
    
    arguments[0:2]:[{'ns': ['北京']}, 
    			   {'ns': ['北京', '故宫', '历博', '北大清华图书馆', '北图', '北京'], 'nt': ['古研所']}]
    
1.2.2 数据编码encoder(sentence,argument)
# step 1:获取 Bert tokenizer
tokenizer=tools.get_tokenizer()
# step 2: 获取 label 到 id 间  的 映射表;
label2id,id2label,num_labels = tools.load_schema()

def encoder(sentence, argument):
    from utils.arguments_parse import args
    # step 3:利用 tokenizer 对 sentence 进行 编码
    encode_dict = tokenizer.encode_plus(sentence,
                                        max_length=args.max_length,
                                        pad_to_max_length=True)
    encode_sent = encode_dict['input_ids']
    token_type_ids = encode_dict['token_type_ids']
    attention_mask = encode_dict['attention_mask']
    
	# step 4:span_mask 生成
    zero = [0 for i in range(args.max_length)]
    span_mask=[ attention_mask for i in range(sum(attention_mask))]
    span_mask.extend([ zero for i in range(sum(attention_mask),args.max_length)])

    # step 5:span_label 生成
    span_label = [0 for i in range(args.max_length)]
    span_label = [span_label for i in range(args.max_length)]
    span_label = np.array(span_label)
    for entity_type,args in argument.items():
        for arg in args:
            encode_arg = tokenizer.encode(arg)
            start_idx = tools.search(encode_arg[1:-1], encode_sent)
            end_idx = start_idx + len(encode_arg[1:-1]) - 1
            span_label[start_idx, end_idx] = label2id[entity_type]+1 # 在span_label这个矩阵中,1代表nr,2代表ns,3代表nt

    return encode_sent, token_type_ids, attention_mask, span_label, span_mask
  • 获取Bert tokenizer、获取label到id间的映射表

  • encode_plus后的编码信息

    • input_ids:单词在词典中的编码
    • token_type_ids:区分两个句子的编码
    • attention_mask:指定对哪些词进行self-Attention操作
    encode_dict:
    {
        'input_ids': [101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
        'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    }
    
  • span_mask:形状为[max_len,max_len]。这个行、列都是这个句子表示,掩码机制扩展到二维上

  • span_label:形状为[max_len,max_len]。用于定位实体span在句子中的位置[开始位置,结束位置],span在矩阵中行代表开始,列代表结束,里面的值就是该span所对应的类型

    >>>
    import numpy as np
    span_label = [0 for i in range(10)]
    span_label = [span_label for i in range(10)]
    span_label = np.array(span_label)
    start = [1, 3, 7]
    end  = [ 2,9, 9]
    label2id = [1,2,4]
    for i in range(len(label2id)):
        span_label[start[i], end[i]] = label2id[i]  
    
    >>> 
    array( [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
    > 注:行号 为 start,列号 为 end,值 为 label2id
    
1.2.3 数据预处理主函数 data_pre(file_path)
  • 加载数据,对数据进行编码,转化为训练数据格式
def data_pre(file_path):
    sentences, arguments = load_data(file_path)
    data = []
    for i in tqdm(range(len(sentences))): ##一条条句子读取
        encode_sent, token_type_ids, attention_mask, span_label, span_mask = encoder(
            sentences[i], arguments[i])
        tmp = {}
        tmp['input_ids'] = encode_sent
        tmp['input_seg'] = token_type_ids
        tmp['input_mask'] = attention_mask
        tmp['span_label'] = span_label
        tmp['span_mask'] = span_mask
        data.append(tmp)

    return data

1.3 数据转为MyDataset对象

将数据转化为 torch.tensor 类型

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        one_data = {
            "input_ids": torch.tensor(item['input_ids']).long(),
            "input_seg": torch.tensor(item['input_seg']).long(),
            "input_mask": torch.tensor(item['input_mask']).float(),
            "span_label": torch.tensor(item['span_label']).long(),
            "span_mask": torch.tensor(item['span_mask']).long()
        }
        return one_data

1.4 构建数据迭代器

def yield_data(file_path):
    tmp = MyDataset(data_pre(file_path))
    return DataLoader(tmp, batch_size=args.batch_size, shuffle=True)

二、模型构建 模块

模型主要由 embedding layer、BiLSTM、biaffine model四部分组成

embedding layers

  1. BERT: 遵循 (Kantor and Globerson, 2019) 的方法来获取目标令牌的上下文相关嵌入,每侧有64个周围令牌;
  2. character-based word embeddings:使用 CNN 编码 characters of the tokens
# 获取 Bert tokenizer
tokenizer=tools.get_tokenizer()

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        super().__init__()
        self.roberta_encoder = BertModel.from_pretrained(pre_train_dir)
        self.roberta_encoder.resize_token_embeddings(len(tokenizer))
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        bert_output = self.roberta_encoder(input_ids=input_ids, 
                                            attention_mask=input_mask, 
                                            token_type_ids=input_seg) 
        encoder_rep = bert_output[0]# {batch_size,max_seq_len,hidden_size=768}
        ...

BiLSTM

拼接 char emb 和 word emb,并输入到 BiLSTM,以获得 word 表示;

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        super().__init__()
        ...
        self.lstm=torch.nn.LSTM(input_size=768,hidden_size=768, \
                        num_layers=1,batch_first=True, \
                        dropout=0.5,bidirectional=True)
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        ...
        encoder_rep,_ = self.lstm(encoder_rep)# encoder_rep : {batch_size,max_seq_len,hidden_size * 2}
        ...

FFNN

从BiLSTM获得单词表示形式后,我们应用两个单独的FFNN为 span 的开始/结束创建不同的表示形式(hs / he)。对 span 的开始/结束使用不同的表示,可使系统学会单独识别 span 的开始/结束。与直接使用LSTM输出的模型相比,这提高了准确性,因为实体开始和结束的上下文不同

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        ...
        self.start_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=2*768, out_features=128),
            torch.nn.ReLU()
        )
        self.end_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=2*768, out_features=128),
            torch.nn.ReLU()
        )
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        ...
        start_logits = self.start_layer(encoder_rep) # {batch_size,max_seq_len,out_features}
        end_logits = self.end_layer(encoder_rep) # {batch_size,max_seq_len,out_features}
        ...

biaffine model

句子上使用biaffine模型来创建 l×l×c 评分张量(rm),其中l是句子的长度,c 是 NER 类别的数量 +1(对于非实体)

  • si和ei是 span i 的开始和结束索引
  • Um 是 d×c×d 张量
  • Wm是2d×c矩阵
  • bm是偏差
# NER类别数量 + 1(对于非实体)
num_label = num_labels+1

class biaffine(nn.Module):
    def __init__(self, in_size, out_size, bias_x=True, bias_y=True):
        super().__init__()
        self.bias_x = bias_x
        self.bias_y = bias_y
        self.out_size = out_size
        self.U = torch.nn.Parameter(torch.Tensor(in_size + int(bias_x),out_size,in_size + int(bias_y))) 
    def forward(self, x, y):# {batch_size,max_seq_len,out_features}
        if self.bias_x:
            x = torch.cat((x, torch.ones_like(x[..., :1])), dim=-1)# {batch_size,max_seq_len,out_features + 1}
        if self.bias_y:
            y = torch.cat((y, torch.ones_like(y[..., :1])), dim=-1)
        bilinar_mapping = torch.einsum('bxi,ioj,byj->bxyo', x, self.U, y)# {bacth_size,max_seq_len,out_features,num_label}
        return bilinar_mapping
class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        ...
        self.biaffne_layer = biaffine(128,num_label)
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        ...
        span_logits = self.biaffne_layer(start_logits,end_logits)# {bacth_size,max_seq_len,out_features,num_label}
        span_logits = span_logits.contiguous()
        ...

冲突解决

张量 r_m 提供在 s_i≤e_i 的约束下(实体的起点在其终点之前)可以构成命名实体的所有可能 span 的分数。我们为每个跨度分配一个NER类别

然后,我们按照其类别得分 (r_m(i_{y’})) 降序对所有其他“非实体”类别的 span 进行排序,并应用以下后处理约束:对于嵌套的NER,只要选择了一个实体就不会与排名较高的实体发生冲突。对于 实体 i与其他实体 j ,如果 s_i<s_j≤e_i<e_j 或 s_j<s_i≤e_j<e_i ,那么这两个实体冲突。此时只会选择类别得分较高的 span

eg:
在 句子 : In the Bank of China 中, 实体 the Bank 的 边界与 实体 Bank of China 冲突,
注:对于 flat NER,我们应用了一个更多的约束,其中包含或在排名在它之前的实体之内的任何实体都将不会被选择。我们命名实体识别器的学习目标是为每个有效范围分配正确的类别(包括非实体)。

损失函数

因为该任务属于 多类别分类问题:

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        ...
        span_prob = torch.nn.functional.softmax(span_logits, dim=-1)# {bacth_size,max_seq_len,out_features,num_label}

        if is_training:
            return span_logits
        else:
            return span_prob

三、学习率衰减模块

class WarmUp_LinearDecay:
    def __init__(self, optimizer: optim.AdamW, init_rate, warm_up_epoch, decay_epoch, min_lr_rate=1e-8):
        self.optimizer = optimizer
        self.init_rate = init_rate
        self.epoch_step = train_data_length / args.batch_size
        self.warm_up_steps = self.epoch_step * warm_up_epoch
        self.decay_steps = self.epoch_step * decay_epoch
        self.min_lr_rate = min_lr_rate
        self.optimizer_step = 0
        self.all_steps = args.epoch*(train_data_length/args.batch_size)

    def step(self):
        self.optimizer_step += 1
        if self.optimizer_step <= self.warm_up_steps:
            rate = (self.optimizer_step / self.warm_up_steps) * self.init_rate
        elif self.warm_up_steps < self.optimizer_step <= self.decay_steps:
            rate = self.init_rate
        else:
            rate = (1.0 - ((self.optimizer_step - self.decay_steps) / (self.all_steps-self.decay_steps))) * self.init_rate
            if rate < self.min_lr_rate:
                rate = self.min_lr_rate
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self.optimizer.step()

四 、 损失函数定义

1.span_loss 损失函数定义

核心思想:对于模型学习到的所有实体的 start 和 end 位置,构造首尾实体匹配任务,即判断某个 start 位置是否与某个end位置匹配为一个实体,是则预测为1,否则预测为0,相当于转化为一个二分类问题,正样本就是真实实体的匹配,负样本是非实体的位置匹配

import torch
from torch import nn
from utils.arguments_parse import args
from data_preprocessing import tools
label2id,id2label,num_labels=tools.load_schema()
num_label = num_labels+1

class Span_loss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_func = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self,span_logits,span_label,seq_mask):
        '''
        span_logits : {batch_size,max_seq_len,out_features,num_label}
        span_label : {batch_size,max_seq_len,out_features=128}
        span_mask : {batch_size,max_seq_len,max_seq_len=128}
        '''
        span_label = span_label.view(size=(-1,))# {batch_size * max_seq_len * out_features}
        span_logits = span_logits.view(size=(-1, num_label)) # {batch_size * max_seq_len * out_features,num_labels}
        span_loss = self.loss_func(input=span_logits, target=span_label) # {batch_size * max_seq_len * out_features}
        span_mask = seq_mask.view(size=(-1,)) # {batch_size * max_seq_len * out_features}
        span_loss *=span_mask
        avg_se_loss = torch.sum(span_loss) / seq_mask.size()[0]
        # avg_se_loss = torch.sum(sum_loss) / bsz
        return avg_se_loss

参考论文:[1910.11476] A Unified MRC Framework for Named Entity Recognition (arxiv.org)

focal_loss损失函数定义

  • 目标:解决分类问题中类别不平衡、分类难度差异的一个 loss
  • 思路:降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘

Focal loss是在交叉熵损失函数基础上进行的修改,首先回顾二分类交叉上损失

y’是经过激活函数的输出,所以在0-1之间。可见普通的交叉熵对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。此时的损失函数在大量简单样本的迭代过程中比较缓慢且可能无法优化至最优。那么Focal loss是怎么改进的呢?

首先在原有的基础上加了一个因子,其中gamma>0使得减少易分类样本的损失。使得更关注于困难的、错分的样本。

例如gamma为2,对于正类样本而言,预测结果为0.95肯定是简单样本,所以(1-0.95)的gamma次方就会很小,这时损失函数值就变得更小。而预测概率为0.3的样本其损失相对很大。对于负类样本而言同样,预测0.1的结果应当远比预测0.7的样本损失值要小得多。对于预测概率为0.5时,损失只减少了0.25倍,所以更加关注于这种难以区分的样本。这样减少了简单样本的影响,大量预测概率很小的样本叠加起来后的效应才可能比较有效。

此外,加入平衡因子alpha,用来平衡正负样本本身的比例不均:

只添加alpha虽然可以平衡正负样本的重要性,但是无法解决简单与困难样本的问题。

lambda调节简单样本权重降低的速率,当lambda为0时即为交叉熵损失函数,当lambda增加时,调整因子的影响也在增加。实验发现lambda为2是最优

import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    '''Multi-class Focal loss implementation'''
    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """
        input: [N, C]
        target: [N, ]
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        logpt = (1 - pt) ** self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
        return loss

参考论文:https://arxiv.org/pdf/1708.02002.pdf

五、模型训练

def train():
    # setp1:获取训练所需数据
    train_data = data_prepro.yield_data(args.train_path)
    test_data = data_prepro.yield_data(args.test_path)

    # step2 : 模型定义
    model = myModel(pre_train_dir=args.pretrained_model_path, dropout_rate=0.5).to(device)
    # model.load_state_dict(torch.load(args.checkpoints))

    # step3 : 优化函数定义
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay_rate': 0.0}
    ]
    optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args.learning_rate)

    schedule = WarmUp_LinearDecay(
                optimizer = optimizer, 
                init_rate = args.learning_rate,
                warm_up_epoch = args.warm_up_epoch,
                decay_epoch = args.decay_epoch
            )

    # step4 : 损失函数函数定义
    span_loss_func = span_loss.Span_loss().to(device)
    span_acc = metrics.metrics_span().to(device)

    # step5 : 训练
    step = 0
    best = 0
    for epoch in range(args.epoch):
        for item in train_data:
            step += 1
            # 模型输入
            input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"] # {batch_size,max_seq_len}
            span_label,span_mask = item['span_label'],item["span_mask"] # {batch_size,max_seq_len,max_seq_len}
            optimizer.zero_grad()

            # 模型训练
            span_logits = model( 
                input_ids=input_ids.to(device), 
                input_mask=input_mask.to(device),
                input_seg=input_seg.to(device),
                is_training=True
            ) # span_logits:{batch_size,max_seq_len,out_features,num_label}

            # span损失
            span_loss_v = span_loss_func(span_logits,span_label.to(device),span_mask.to(device))
            loss = span_loss_v
            loss = loss.float().mean().type_as(loss)

            # 反向传播
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm)
            schedule.step()
            # optimizer.step()

            # 打印此时模型的效果
            if step%100 == 0:
                span_logits = torch.nn.functional.softmax(span_logits, dim=-1)
                recall,precise,span_f1=span_acc(span_logits,span_label.to(device))
                logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, span_f1 %.4f'% (epoch,step,loss,recall,precise,span_f1))

            # 测试
        with torch.no_grad():
            count=0
            span_f1=0
            recall=0
            precise=0

            for item in test_data:
                count+=1
                input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
                span_label,span_mask = item['span_label'],item["span_mask"]

                optimizer.zero_grad()
                span_logits = model( 
                    input_ids=input_ids.to(device), 
                    input_mask=input_mask.to(device),
                    input_seg=input_seg.to(device),
                    is_training=False
                    ) 
                tmp_recall,tmp_precise,tmp_span_f1=span_acc(span_logits,span_label.to(device))
                span_f1+=tmp_span_f1
                recall+=tmp_recall
                precise+=tmp_precise
                
            span_f1 = span_f1/count
            recall=recall/count
            precise=precise/count

            logger.info('-----eval----')
            logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, span_f1 %.4f'% (epoch,step,loss,recall,precise,span_f1))
            logger.info('-----eval----')
            if best < span_f1:
                best=span_f1
                torch.save(model.state_dict(), f=args.checkpoints)
                logger.info('-----save the best model----')

参考

nlp_paper_study_information_extraction/code_pytorch.md at main · km1994/nlp_paper_study_information_extraction (github.com)

实体识别之Biaffine双仿射注意力机制 - 知乎 (zhihu.com)

1 | 原来这也叫Dependency Parsing - 知乎 (zhihu.com)

Biaffine for NER:Named Entity Recognition as Dependency Parsing - 知乎 (zhihu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/32012.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ASEMI肖特基二极管SBT40100VFCT规格,SBT40100VFCT封装

编辑-Z ASEMI肖特基二极管SBT40100VFCT参数&#xff1a; 型号&#xff1a;SBT40100VFCT 最大重复峰值反向电压&#xff08;VRRM&#xff09;&#xff1a;100V 最大平均正向整流输出电流&#xff08;IF&#xff09;&#xff1a;40A 峰值正向浪涌电流&#xff08;IFSM&#…

使用kubeadm搭建高可用集群-k8s相关组件及1.16版本的安装部署

本文是向大家分享k8s相关组件及1.16版本的安装部署&#xff0c;它能够让大家初步了解k8s核心组件的原理及k8s的相关优势&#xff0c;有兴趣的同学可以部署安装下。 什么是kubernetes kubernetes是Google 开源的容器集群管理系统&#xff0c;是大规模容器应用编排系统&#xff…

ubuntu下jupyter notebook设置远程访问

1. 安装anaconda 推荐安装anaconda&#xff0c;安装后就会包含jupyter notebook 使用命令conda list或者pip list查看jupyter notebook包&#xff0c;这里不多介绍 2. 生成默认配置文件 在ubuntu环境下&#xff0c;安装jupyter notebook后&#xff0c;用户主目录中会有一个…

DRU-Net--一种用于医学图像分割的高效深度卷积神经网络

Title:DRU-NET: AN EFFICIENT DEEP CONVOLUTIONAL NEURAL NETWORK FOR MEDICAL IMAGE SEGMENTATION 摘要 本文的网络结构是受ResNet和DenseNet两个网络的启发而提出的。与ResNet相比本文的方法增加了额外的跳跃连接&#xff0c;但使用的模型参数要比DenseNet少的多。 基于先…

【创建型设计模式-单例模式】一文搞懂单例模式的使用场景及代码实现的7种方式(全)

1.什么是单例模式 在了解单例模式前&#xff0c;我们先来看一下它的定义&#xff1a; 确保一个类只有一个实例&#xff0c;而且自行实例化并且自行向整个系统提供这个实例&#xff0c;这个类称为单例类&#xff0c;它提供全局访问的方法&#xff0c; 单例模式是一种对象的创建型…

北京东物流,南顺丰速运

配图来自Canva可画 众所周知&#xff0c;“双11”是一年一度的物流高峰期&#xff0c;但2022年“双11”当日快递业务量并未达到预期水平&#xff0c;全年增速创下新低。据了解&#xff0c;“双11”当日业务量为5.52亿件&#xff0c;同比下滑了20.69%&#xff0c;而11月1日至11…

什么是CISAW认证?有什么价值?

随着信息技术的快速发展和信息化应用的不断深入&#xff0c;信息技术、产品及网络已经融入社会经济生活的方方面面&#xff0c;但同时信息安全问题也越来越突出。面对严峻的信息安全形势&#xff0c;我国将信息安全上升至国家战略&#xff0c;相继出台了一系列政策法规。那大家…

IDEA好用插件推荐

一、MavenHelper 当Maven Helper 插件安装成功后&#xff0c;打开项目中的pom文件&#xff0c;下面就会多出一个试图Dependency Analyzer 切换到此试图即可进行相应操作&#xff1a; Conflicts&#xff08;查看冲突&#xff09;All Dependencies as List&#xff08;列表形式…

数据仓库规范

模型设计 模型设计概述 为什么需要模型设计&#xff1f; Linux 的创始人 Torvalds 有 一段关于“什么才是优秀程序员”的话:“烂程序员关心的是代码&#xff0c;好程序员关心的是数据结构和它们之间的关系”&#xff0c;其阐述了数据模型的重要性。有了适合业务和基础数据存…

python 中__init__ 作用

__init__的作用&#xff1a; &#xff08;1&#xff09;声明包 &#xff08;2&#xff09;预加载模块内容 &#xff08;1&#xff09;声明包 python项目结构中&#xff0c;普通目录下无__init__文件&#xff1b;而包下是有__init__文件的。 python 项目结构是按目录来组织的…

R语言结课及Matlab开始

R语言结课 我们R语言的学习这节课下课就结束了&#xff0c;接下来进行Matlab的学习。下面我会说一下R的结课任务及如何考试&#xff0c;以及我自己整理的Matlab安装教程。 R的结课作业&#xff1a;周二上课时提到的两个回归模型课程总结&#xff08;老师说作业总结主要是作业…

如何运用java代码操作Redis

目录 1、java如何连接Redis&#xff1f; 1.1.启动Redis服务 1.2.导入相关Redis依赖 1.3.java代码进行连接 2、java连接Redis 2.1.String 2.1.1.设值 2.1.2.拿值 2.1.3.删除 2.1.4.修改 2.1.5.给键值对设置过期时间 2.1.6.获取键值对剩余的存活时间 2.2.哈希&#xff08;Hash&a…

jacoco单测报告怎么同步到sonarqube

sonarqube支持多种代码覆盖率的报告展示&#xff0c;最常用的当属jacoco报告&#xff0c;那么jacoco的报告怎么同步到我们的sonarqube中呢&#xff1f; 我们先看看jacoco的offline模式&#xff08;单元测试&#xff09;报告生成的流程 根据上图我们需要生成单测报告&#xff0…

Apollo 应用与源码分析:CyberRT-工具与命令

概念 cyberRT包括一个可视化工具cyber_visualizer和两个命令行工具cyber_monitor和cyber_recorder。 注意&#xff1a;使用这些工具需要apollo docker环境 并且Cyber RT 中提供了一些命令工具&#xff0c;可以方便快捷的解决上述问题&#xff0c;本部分内容就主要介绍这些命…

Clion学习

看看Cmake是个什么&#xff1f; 他是个构建管理工具 一个比较OK的图 cmake_minimum_required(VERSION 3.15)#指定了最小的Cmake版本 project(jcdd)#指定了项目名称 set(CMAKE_CXX_STANDARD 14) add_executable(jcdd main.cpp)#输出可执行文件的名称安装第三方库&#xff…

图解来啦!机器学习工业部署最佳实践!10分钟上手机器学习部署与大规模扩展 ⛵

&#x1f4a1; 作者&#xff1a;韩信子ShowMeAI &#x1f4d8; 机器学习实战系列&#xff1a;https://www.showmeai.tech/tutorials/41 &#x1f4d8; 深度学习实战系列&#xff1a;https://www.showmeai.tech/tutorials/42 &#x1f4d8; 本文地址&#xff1a;https://www.sho…

【MyBatis】动态SQL

if标签 CarMapper.java /*** 多条件查询* param brand 品牌* param guidePrice 指导价* param carType 汽车类型* return*/List<Car> selectByMultiCondition(Param("brand") String brand,Param("guidePrice") Double guidePrice,Param("car…

MySQL基础篇之MySQL概述

01、MySQL概述 1.1、数据库相关概念 1、数据库相关概念 名称解释说明简称数据库存储数据的仓库&#xff0c;数据是有组织的进行存储DataBase&#xff08;DB&#xff09;数据库管理系统操纵和管理数据库的大型软件DataBase Management System&#xff08;DBMS&#xff09;SQL…

ky使用教程(基于fetch的小巧优雅js的http客服端)

1.前言 react项目更加倾向于使用原生的fetch请求方式&#xff0c;而ky正是底层使用fetch的api做请求。github星数是8.2K&#xff0c;源码地址是&#xff1a;GitHub - sindresorhus/ky: &#x1f333; Tiny & elegant JavaScript HTTP client based on the browser Fetch A…

树上背包dp

“我们终其一生不过是为了一个AC罢了” 软件安装 嗯…这个题又调了一个下午&#xff0c;不过俺的确对dp方程有了一些理解 这个题没啥难的&#xff0c;不过是这个转移方程不太好想&#xff0c;过于抽象了&#xff0c;之前一直不理解树上背包是啥&#xff0c;现在理解了&#xff…