用BERT做命名实体识别任务

news2024/12/24 2:36:29

命名实体识别NER任务是NLP的一个常见任务,

它是Named Entity Recognization的简称。

简单地说,就是识别一个句子中的各种 名称实体。

诸如:人名,地名,机构 等。

例如对于下面这句话:

小明对小红说:"你听说过安利吗?"

它的NER抽取结果如下:

[{'entity': 'person',
  'word': '小明',
  'start': 0,
  'end': 2},
 {'entity': 'person',
  'word': '小红',
  'start': 3,
  'end': 5},
 {'entity': 'organization',
  'word': '安利',
  'start': 12,
  'end': 14}]

本质上NER是一个token classification任务, 需要把文本中的每一个token做一个分类。

那些不是命名实体的token,一般用大'O'表示。

值得注意的是,由于有些命名实体是由连续的多个token构成的,为了避免有两个连续的相同的命名实体无法区分,需要对token是否处于命名实体的开头进行区分。

例如,对于下面这句话。

我爱北京天安门

如果我们不区分token是否为命名实体的开头的话,可能会得到这样的token分类结果。

我(O) 爱(O) 北(Loc) 京(Loc) 天(Loc) 安(Loc) 门(Loc)

然后我们做后处理的时候,把类别相同的token连起来,会得到一个location实体 '北京天安门'。

但是,’北京‘ 和 ’天安门‘ 是两个不同的location实体,把它们区分开来更加合理一些. 因此我们可以这样对token进行分类。

我(O) 爱(O) 北(B-Loc) 京(I-Loc) 天(B-Loc) 安(I-Loc) 门(I-Loc)

我们用 B-Loc表示这个token是一个location实体的开始token,用I-Loc表示这个token是一个location实体的内部(包括中间以及结尾)token.

这样,我们做后处理的时候,就可以把 B-loc以及它后面的 I-loc连成一个实体。这样就可以得到’北京‘ 和 ’天安门‘ 是两个不同的location的结果了。

区分token是否是entity开头的好处是我们可以把连续的同一类别的的命名实体进行区分,坏处是分类数量会几乎翻倍(n+1->2n+1)。

在许多情况下,出现这种连续的同命名实体并不常见,但为了稳妥起见,区分token是否是entity开头还是十分必要的。

一,准备数据

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook代码和车道线数据集下载链接。

import numpy as np 
import pandas as pd 

from transformers import BertTokenizer
from torch.utils.data import DataLoader,Dataset 
from transformers import DataCollatorForTokenClassification
import datasets

1,数据加载

datadir = "./data/cluener_public/"

train_path = datadir+"train.json"
val_path = datadir+"dev.json"

dftrain = pd.read_json(train_path,lines=True)
dfval = pd.read_json(train_path,lines=True)

entities = ['address','book','company','game','government','movie',
              'name','organization','position','scene']

label_names = ['O']+['B-'+x for x in entities]+['I-'+x for x in entities]

id2label = {i: label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}
text = dftrain["text"][43]
label = dftrain["label"][43]

print(text)
print(label)
世上或许有两个人并不那么喜欢LewisCarroll的原著小说《爱丽斯梦游奇境》(
{'book': {'《爱丽斯梦游奇境》': [[31, 39]]}, 
'name': {'LewisCarroll': [[14, 25]]}}

2,文本分词

from transformers import BertTokenizer
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
tokenized_input = tokenizer(text)
print(tokenized_input["input_ids"])
[101, 686, 677, 2772,..., 518, 113, 102]
#可以从id还原每个token对应的字符组合
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
for t in tokens:
    print(t)
[CLS]
世
上
或
许
有
两
个
人
并
不
那
么
喜
欢
[UNK]
的
原
著
小
说
《
爱
丽
斯
梦
游
奇
境
》
(
[SEP]

3,标签对齐

可以看到,经过文本分词后的token长度与文本长度并不相同,

主要有以下一些原因导致:一是BERT分词后会增加一些特殊字符如 [CLS],[SEP]

二是,还会有一些英文单词的subword作为一个 token. (如这个例子中的 'charles'),

此外,还有一些未在词典中的元素被标记为[UNK]会造成影响。

因此需要给这些token赋予正确的label不是一个容易的事情。

我们分两步走,第一步,把原始的dict形式的label转换成字符级别的char_label

第二步,再将char_label对齐到token_label

# 把 label格式转化成字符级别的char_label
def get_char_label(text,label):
    char_label = ['O' for x in text]
    for tp,dic in label.items():
        for word,idxs in dic.items():
            idx_start = idxs[0][0]
            idx_end = idxs[0][1]
            char_label[idx_start] = 'B-'+tp
            char_label[idx_start+1:idx_end+1] = ['I-'+tp for x in range(idx_start+1,idx_end+1)]
    return char_label
char_label = get_char_label(text,label)
for char,char_tp in zip(text,char_label):
    print(char+'\t'+char_tp)
世	O
上	O
或	O
许	O
有	O
两	O
个	O
人	O
并	O
不	O
那	O
么	O
喜	O
欢	O
L	B-name
e	I-name
w	I-name
i	I-name
s	I-name
C	I-name
a	I-name
r	I-name
r	I-name
o	I-name
l	I-name
l	I-name
的	O
原	O
著	O
小	O
说	O
《	B-book
爱	I-book
丽	I-book
斯	I-book
梦	I-book
游	I-book
奇	I-book
境	I-book
》	I-book
(	O
def get_token_label(text, char_label, tokenizer):
    tokenized_input = tokenizer(text)
    tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
    
    iter_tokens = iter(tokens)
    iter_char_label = iter(char_label)  
    iter_text = iter(text.lower()) 

    token_labels = []

    t = next(iter_tokens)
    char = next(iter_text)
    char_tp = next(iter_char_label)

    while True:
        #单个字符token(如汉字)直接赋给对应字符token
        if len(t)==1:
            assert t==char
            token_labels.append(char_tp)   
            try:
                char = next(iter_text)
                char_tp = next(iter_char_label)
            except StopIteration:
                pass  

        #添加的特殊token如[CLS],[SEP],排除[UNK]
        elif t in tokenizer.special_tokens_map.values() and t!='[UNK]':
            token_labels.append('O')              


        elif t=='[UNK]':
            token_labels.append(char_tp) 
            #重新对齐
            try:
                t = next(iter_tokens)
            except StopIteration:
                break 

            if t not in tokenizer.special_tokens_map.values():
                while char!=t[0]:
                    try:
                        char = next(iter_text)
                        char_tp = next(iter_char_label)
                    except StopIteration:
                        pass    
            continue

        #其它长度大于1的token,如英文token
        else:
            t_label = char_tp
            t = t.replace('##','') #移除因为subword引入的'##'符号
            for c in t:
                assert c==char or char not in tokenizer.vocab
                if t_label!='O':
                    t_label=char_tp
                try:
                    char = next(iter_text)
                    char_tp = next(iter_char_label)
                except StopIteration:
                    pass    
            token_labels.append(t_label) 

        try:
            t = next(iter_tokens)
        except StopIteration:
            break  
            
    assert len(token_labels)==len(tokens)
    return token_labels
token_labels = get_token_label(text,char_label,tokenizer)
for t,t_label in zip(tokens,token_labels):
    print(t,'\t',t_label)
[CLS] 	 O
世 	 O
上 	 O
或 	 O
许 	 O
有 	 O
两 	 O
个 	 O
人 	 O
并 	 O
不 	 O
那 	 O
么 	 O
喜 	 O
欢 	 O
[UNK] 	 B-name
的 	 O
原 	 O
著 	 O
小 	 O
说 	 O
《 	 B-book
爱 	 I-book
丽 	 I-book
斯 	 I-book
梦 	 I-book
游 	 I-book
奇 	 I-book
境 	 I-book
》 	 I-book
( 	 O
[SEP] 	 O

4,构建管道

dftrain.head()

51f8d021b4b43b9f4cb3db995417b477.png

def make_sample(text,label,tokenizer):
    sample = tokenizer(text)
    char_label = get_char_label(text,label)
    token_label = get_token_label(text,char_label,tokenizer)
    sample['labels'] = [label2id[x] for x in token_label]
    return sample
from tqdm import tqdm 
train_samples = [make_sample(text,label,tokenizer) for text,label in 
                 tqdm(list(zip(dftrain['text'],dftrain['label'])))]
val_samples = [make_sample(text,label,tokenizer) for text,label in 
                 tqdm(list(zip(dfval['text'],dfval['label'])))]
100%|██████████| 10748/10748 [00:06<00:00, 1717.47it/s]
100%|██████████| 10748/10748 [00:06<00:00, 1711.10it/s]
ds_train = datasets.Dataset.from_list(train_samples)
ds_val = datasets.Dataset.from_list(val_samples)
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
dl_train = DataLoader(ds_train,batch_size=8,collate_fn=data_collator)
dl_val = DataLoader(ds_val,batch_size=8,collate_fn=data_collator)
for batch in dl_train:
    break

二,定义模型

from transformers import BertForTokenClassification

net = BertForTokenClassification.from_pretrained(
    model_name,
    id2label=id2label,
    label2id=label2id,
)

#冻结bert基模型参数
for para in net.bert.parameters():
    para.requires_grad_(False)

print(net.config.num_labels) 

#模型试算
out = net(**batch)
print(out.loss) 
print(out.logits.shape)

ad597c4424d147c8f854813634a263b7.png

三,训练模型

import torch 
from torchkeras import KerasModel 

#我们需要修改StepRunner以适应transformers的数据集格式

class StepRunner:
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.eval()
    
    def __call__(self, batch):
        
        out = self.net(**batch)
        
        #loss
        loss= out.loss
        
        #preds
        preds =(out.logits).argmax(axis=2) 
    
        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
        
        all_loss = self.accelerator.gather(loss).sum()
        
        labels = batch['labels']
        
        #precision & recall
        
        precision =  (((preds>0)&(preds==labels)).sum())/(
            torch.maximum((preds>0).sum(),torch.tensor(1.0).to(preds.device)))
        recall =  (((labels>0)&(preds==labels)).sum())/(
            torch.maximum((labels>0).sum(),torch.tensor(1.0).to(labels.device)))
    
        
        all_precision = self.accelerator.gather(precision).mean()
        all_recall = self.accelerator.gather(recall).mean()
        
        f1 = 2*all_precision*all_recall/torch.maximum(
            all_recall+all_precision,torch.tensor(1.0).to(labels.device))
        
        #losses
        step_losses = {self.stage+"_loss":all_loss.item(), 
                       self.stage+'_precision':all_precision.item(),
                       self.stage+'_recall':all_recall.item(),
                       self.stage+'_f1':f1.item()
                      }
        
        #metrics
        step_metrics = {}
        
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
KerasModel.StepRunner = StepRunner
optimizer = torch.optim.AdamW(net.parameters(), lr=3e-5)

keras_model = KerasModel(net,
                   loss_fn=None,
                   optimizer = optimizer
                   )
keras_model.fit(
    train_data = dl_train,
    val_data= dl_val,
    ckpt_path='bert_ner.pt',
    epochs=50,
    patience=5,
    monitor="val_f1", 
    mode="max",
    plot = True,
    wandb = False,
    quiet = True
)

713beb26fe8fd7b653e5e0aacab9a0d5.png

四,评估模型

from torchmetrics import Accuracy
acc = Accuracy(task='multiclass',num_classes=21)
acc = keras_model.accelerator.prepare(acc)

dl_test = keras_model.accelerator.prepare(dl_val)
net = keras_model.accelerator.prepare(net)
from tqdm import tqdm 
for batch in tqdm(dl_test):
    with torch.no_grad():
        outputs = net(**batch)
        
    labels = batch['labels']
    labels[labels<0]=0
    #preds
    preds =(outputs.logits).argmax(axis=2) 
    acc.update(preds,labels)
acc.compute()  #这里的acc包括了 ’O‘的分类结果,存在高估。
tensor(0.9178, device='cuda:0')

五,使用模型

我们可以使用pipeline来串起整个预测流程.

注意我们这里使用内置的'simple'这个aggregation_strategy,

把应该归并的token自动归并成一个entity.

from transformers import pipeline
recognizer = pipeline("token-classification", 
                      model=net, tokenizer=tokenizer, aggregation_strategy='simple')
net.to('cpu');
recognizer('小明对小红说,“你听说过安利吗?”')
[{'entity_group': 'name',
  'score': 0.6913842,
  'word': '小 明',
  'start': None,
  'end': None},
 {'entity_group': 'name',
  'score': 0.58951116,
  'word': '小 红',
  'start': None,
  'end': None},
 {'entity_group': 'name',
  'score': 0.74060774,
  'word': '安 利',
  'start': None,
  'end': None}]

六,保存模型

保存model和tokenizer之后,我们可以用一个pipeline加载,并进行批量预测。

net.save_pretrained("ner_bert")
tokenizer.save_pretrained("ner_bert")
('ner_bert/tokenizer_config.json',
 'ner_bert/special_tokens_map.json',
 'ner_bert/vocab.txt',
 'ner_bert/added_tokens.json')
recognizer = pipeline("token-classification", 
                      model="ner_bert",
                      aggregation_strategy='simple')
recognizer('小明对小红说,“你听说过安利吗?”')
[{'entity_group': 'name',
  'score': 0.6913842,
  'word': '小 明',
  'start': 0,
  'end': 2},
 {'entity_group': 'name',
  'score': 0.58951116,
  'word': '小 红',
  'start': 3,
  'end': 5},
 {'entity_group': 'name',
  'score': 0.74060774,
  'word': '安 利',
  'start': 12,
  'end': 14}]

公众号后台回复关键词:torchkeras,获取本文notebook源码数据集以及更多有趣炼丹范例~

d3ebc76c457278f76c27ceba3990b48d.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/686520.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

佩戴舒适度的蓝牙耳机品牌有哪些?佩戴舒适性蓝牙耳机排行榜推荐

​对于年轻人来说&#xff0c;耳机使用场景丰富&#xff0c;时尚追求度高&#xff0c;喜好的音乐类型也是多种多样&#xff0c;需求侧重也不尽相同。下面我来推荐几款相当不错的蓝牙耳机给大家&#xff0c;总会有喜欢那款&#xff01; 一、南卡OE PRO开放式耳机 南卡OE PRO是国…

带你阅读 Flutter Demo(flutter 保姆级入门教程)

dart、flutter Flutter Demo 解析 - 文章信息 - Author: Jack Lee (jcLee95) Visit me at: https://jclee95.blog.csdn.netEmail: 291148484163.com. Shenzhen ChineAddress of this article:https://blog.csdn.net/qq_28550263/article/details/xxxxxx 【介绍】&#xff1a;本…

RFID技术的革新与应用:连接智能物联网的关键

在日益数字化的时代&#xff0c;物联网&#xff08;IoT&#xff09;技术正迅速发展&#xff0c;并为我们的生活带来了无数的便利。而射频识别&#xff08;RFID&#xff09;技术作为物联网的关键支撑之一&#xff0c;正在推动着智能化、自动化的进程。本文将深入探讨RFID技术的基…

INTERSPEECH 2023论文|基于自监督学习表示的具有持久性口音记忆的口音识别

论文题目&#xff1a; Self-supervised Learning Representation based Accent Recognition with Persistent Accent Memory 作者列表&#xff1a; 李睿&#xff0c;谢志伟&#xff0c;徐海华&#xff0c;彭亦周&#xff0c;刘和鑫&#xff0c;黄浩&#xff0c;Chng Eng Sio…

神州设备IPV6路由综合运用

实训拓扑图 一、基本配置: SW-1: SW-1>ena SW-1#conf SW-1(config)#vlan 10;100 SW-1(config)#int l1 SW-1(config-if-loopback1)#ip add 1.1.1.1 255.255.255.255 SW-1(config-if-loopback1)#ipv6 add 2001:1::1/128 SW-1(config-if-loopback1)#exit

前端系列18集-权限,nginx成功,屏幕分辨率,vue3

vue3.0 使用原生websocket通信 // Websoket连接成功事件const websocketonopen (res: any) > {console.log("WebSocket连接成功", res);};// Websoket接收消息事件const websocketonmessage (res: any) > {console.log("数据", res);};// Websoket…

【从零开始学习C++ | 第二十二篇】C++新增特性(下)

目录 前言&#xff1a; 类型推导&#xff1a; constexpr关键字&#xff1a; 初始化列表&#xff1a; 基于范围的for循环&#xff1a; 智能指针之unique ptr Lambda表达式&#xff1a; 总结&#xff1a; 前言&#xff1a; 本文我们将继续介绍 C 11 新增十大特性的剩余…

解决前端容器不能充满屏幕

解决前端容器不能充满屏幕 px、rpx、em、rem、vw、vh各种像素单位的区别 css3新单位vw、vh、vmin、vmax的使用详解 学习element-UI写管理系统的页面&#xff0c;发现当菜单栏都收缩起来&#xff0c;结果是这样的 红色框是容器里每个板块的布局&#xff0c;但是容器下面却有空白…

如何处理兼容性测试中的变更管理?

如何处理兼容性测试中的变更管理? 在进行软件测试的过程中&#xff0c;兼容性测试是&#xfeff;非常重要的一环。然而&#xff0c;在进行兼容性测试时&#xff0c;由于涉及到不同平台、不同设备的适配问题&#xff0c;可能会出现许多变更管理的情况。这时候&#xff0c;如果没…

阿里企业邮箱收费标准_企业邮箱费用明细表

阿里云企业邮箱收费标准&#xff08;免费版/标准/尊享/集团&#xff09;&#xff0c;2023阿里云企业邮箱收费标准&#xff0c;免费版企业邮箱0元&#xff0c;标准版企业邮箱540元一年&#xff08;原价600元一年&#xff09;&#xff0c;企业邮箱尊享版1400元一年&#xff0c;9折…

4-移动端适配-1

01-移动 Web 基础 谷歌模拟器 模拟移动设备&#xff0c;方便查看页面效果 屏幕分辨率 分类&#xff1a; 物理分辨率&#xff1a;硬件分辨率&#xff08;出厂设置&#xff09;逻辑分辨率&#xff1a;软件 / 驱动设置 结论&#xff1a;制作网页参考 逻辑分辨率 视口 作用&a…

微信小程序项目实例——2048小游戏

今日推荐&#x1f481;‍♂️ 第一次听廖俊涛的歌是他首次出现在明日之子舞台上的那首《谁》 到现在这首歌成了我网易云收藏的十几首歌中的一首&#xff0c;也是听的最多的一首 怎么形容呢&#x1f914;算不上惊艳&#xff0c;却百听不厌&#x1f442; &#x1f52e;&#x1…

直播美颜SDK的商业化应用:开发者需要注意的关键问题

直播美颜SDK是当前直播行业中十分热门的技术之一&#xff0c;它可以为直播平台提供高质量的美颜效果&#xff0c;提升直播用户的使用体验和观看体验。随着直播市场的不断扩大和竞争的加剧&#xff0c;越来越多的直播平台开始使用美颜SDK以提高自身的用户黏性和用户体验。那么&a…

二叉树OJ题:LeetCode--100.相同的树

朋友们、伙计们&#xff0c;我们又见面了&#xff0c;本期来给大家解读一下LeetCode中第100道二叉树OJ题&#xff0c;如果看完之后对你有一定的启发&#xff0c;那么请留下你的三连&#xff0c;祝大家心想事成&#xff01; 数据结构与算法专栏&#xff1a;数据结构与算法 个 人…

脑机接口:运动想象简介

脑机接口&#xff1a;运动想象简介 0. 脑机接口1. 运动想象2. 信号处理2.1 信号采集2.2 信号预处理2.3 特征提取2.4 分类识别 3. EEG波段介绍4. 脑电图电极定位5. 总结 0. 脑机接口 脑机接口&#xff08;Brain-Computer Interface&#xff0c; BCI&#xff09;&#xff1a;它是…

MATLAB 之 可视化图形用户界面设计

这里写目录标题 一、可视化图形用户界面设计1. 图形用户界面设计窗口1.1 图形用户界面设计模板1.2 图形用户界面设计窗口 2. 可视化图形用户界面设计工具1.1 对象属性检查器2.2 菜单编辑器2.3 工具栏编辑器2.4 对齐对象工具2.5 对象浏览器2.6 Tab 键顺序编辑器 3. 可视化图形用…

途乐证券|股票XR是什么意思?买股票为什么赚不到钱?

股票市场上有时会出现一些股票在其名称前加上英文字母的情况&#xff0c;比如XD、XR等。那么股票XR是什么意思&#xff1f;买股票为什么赚不到钱&#xff1f;途乐证券为大家准备了相关内容&#xff0c;以供参考。 股票XR是什么意思&#xff1f; 股票名称中带有XR是表示股票在进…

yolov5-cls部署之onnx导出

本文旨在介绍说明yolov5自带的分类如何导出动态的batch的onnx。其中输出两种形式&#xff1a; 形式&#xff08;1&#xff09;&#xff1a;导出带softmax映射到概率的 形式&#xff08;2&#xff09;&#xff1a;导出不带softmax的&#xff0c;这个也是官方默认的方式 一、动…

连接服务器,再连接VSCode

一、 创建账号&#xff0c;查找公钥 通过命令窗口 a. 打开你的 git bash 窗口 b. 进入 .ssh 目录&#xff1a;cd ~/.ssh c. 找到 id_rsa.pub 文件&#xff1a;ls d. 查看公钥&#xff1a;cat id_rsa.pub 或者 vim id_rsa.pub 查看本机 ssh 公钥&#xff0c;生成公钥 二、用…

Sangfor华东天勇战队:mybatis-plus demo

基本依赖添加&#xff0c;表创建&#xff0c;启动类&#xff0c;测试类 引入依赖&#xff1a; <!-- mybatis-plus 依赖--> <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version…