自然语言处理实战项目8- BERT模型的搭建,训练BERT实现实体抽取识别的任务

news2024/10/2 10:37:21

大家好,我是微学AI,今天给大家介绍一下自然语言处理实战项目8- BERT模型的搭建,训练BERT实现实体抽取识别的任务。BERT模型是一种用于自然语言处理的深度学习模型,它可以通过训练来理解单词之间的上下文关系,从而为下游任务提供高质量的语言表示。它的结构是由多个Transformer编码器组成的,而Transformer编码器是由多个自注意力机制组成的。在训练中,模型通过预测遮盖的单词和判断两个句子之间的关系来提高语言表示的准确性。在实体识别任务中,BERT模型可以作为特征提取器使用,将每个单词的上下文相关的向量表示输入到分类器中完成实体识别。

一、BERT模型的框架

BERT的基础结构是多层的Transformer编码器架构。Transformer是一种自注意力机制,允许模型在不同的词语之间捕获重要的关系。具体而言,BERT使用自注意力头为文本序列中的每个单词生成一个向量表示,同时捕捉了整个句子的上下文信息。这些向量表示可以从底层到更高层进行组合,从而允许模型学习更加复杂的语义结构。

BERT模型有两种主要的预训练模型:
1.BERT-Base:包含12层(Encoder layers)、12个自注意力头(Attention heads)和768个隐藏层大小(Hidden size),总共有约 110M 个参数。
2.BERT-Large:包含 24层(Encoder layers)、16个自注意力头(Attention heads)和1024个隐藏层大小(Hidden size),总共约340M个参数。

 二、BERT的预训练

BERT的预训练主要分为两个阶段:预训练和微调。

2.1 预训练:

在预训练阶段,BERT的创新之举是利用大量无标注文本进行双向训练。在这个阶段,BERT引入了两个预训练任务:掩码语言建模(MLM) 和 下一句预测(NSP)。
 掩码语言建模(MLM):在训练过程中,输入句子中的一部分单词被随机替换为一个特殊的掩码符号(MASK)。模型的目标是根据句子其他部分的上下文信息,预测被掩码的单词。这种训练方式使模型能够学习双向的语义信息。
下一句预测(NSP):这个任务旨在让模型学习理解句子之间的关系。给定一对句子,模型需要预测第二个句子是否紧跟在第一个句子之后。这个任务有助于BERT更好地应对需要理解多个句子关系的任务,如问答和自然语言推理等。

2.2 微调:

在预训练阶段完成后,BERT模型已经学会了丰富的语义表示。然后,在实际的NLP任务中,我们需要对预训练的BERT进行微调。微调阶段只需较少的标注数据就能使模型针对具体任务进行优化。
在微调过程中,通常在BERT模型的顶部添加一个任务相关的神经网络层,如全连接层、卷积层等。然后连同BERT整个模型进行端到端的微调训练。在训练时,利用带标签的数据计算损失,利用梯度下降进行参数更新。经过微调后,BERT能够根据特定任务生成更有针对性的结果。

三、训练BERT实现实体抽取识别的任务

以下是一个使用PyTorch和Hugging Face Transformers库的BERT模型进行中文命名实体识别任务的完整代码,第三方库载入,数据样例导入。

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForTokenClassification, AdamW
from sklearn.model_selection import train_test_split
from tqdm import tqdm

def load_data_from_txt(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    data = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        token, label = line.split()
        data.append((token, label))
    return data

file_path = 'ner_data.txt'
data = load_data_from_txt(file_path)
print(data)

ner_data.txt文件数据样例:

李 B-PER
华 I-PER
山 I-PER
是 O
一 O
个 O
优 O
秀 O
的 O
程 O
序 O
员 O
。 O
阿 B-ORG
里 I-ORG
巴 I-ORG
巴 I-ORG
是 O
一 O
家 O
著 O
名 O
的 O
中 B-LOC
国 I-LOC
公 O
司 O
。 O
陈 B-PER
明 I-PER
在 O
北 B-LOC
京 I-LOC
上 O
了 O
一 O
所 O
大 O
学 O
。 O

模型加载与训练:

# 预处理
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
label_map = {"B-PER": 0, "I-PER": 1, "B-ORG": 2, "I-ORG": 3, "B-EDU":4,"I-EDU":5, "B-LOC":6,"I-LOC":7,"O": 8}
inverse_label_map = {v: k for k, v in label_map.items()}

class NERDataset(Dataset):
    def __init__(self, data, tokenizer, label_map):
        self.data = data
        self.tokenizer = tokenizer
        self.label_map = label_map

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        token, label = self.data[idx]
        input_ids = tokenizer.encode(token, add_special_tokens=False)
        label_id = self.label_map[label]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(label_id, dtype=torch.long)

dataset = NERDataset(data, tokenizer, label_map)
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False)

# 模型
model = BertForTokenClassification.from_pretrained("bert-base-chinese", num_labels=len(label_map))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 优化器
optimizer = AdamW(model.parameters(), lr=5e-5)

# 训练
num_epochs = 8
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_correct = 0
    total_count = 0

    for batch in tqdm(train_loader):
        input_ids, labels = batch
        input_ids = input_ids.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += (outputs.logits.argmax(-1) == labels).sum().item()
        total_count += labels.size(0)

    avg_loss = total_loss / total_count
    accuracy = total_correct / total_count
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

运行结果:

100%|██████████| 97/97 [00:39<00:00,  2.47it/s]
  0%|          | 0/97 [00:00<?, ?it/s]Epoch 1/10, Loss: 1.4691, Accuracy: 0.5979
100%|██████████| 97/97 [00:40<00:00,  2.42it/s]
Epoch 2/10, Loss: 1.3695, Accuracy: 0.6598
100%|██████████| 97/97 [00:39<00:00,  2.48it/s]
Epoch 3/10, Loss: 1.2924, Accuracy: 0.5979
100%|██████████| 97/97 [00:39<00:00,  2.48it/s]
  0%|          | 0/97 [00:00<?, ?it/s]Epoch 4/10, Loss: 1.3100, Accuracy: 0.6701
100%|██████████| 97/97 [00:37<00:00,  2.59it/s]
Epoch 5/10, Loss: 1.2179, Accuracy: 0.6598
100%|██████████| 97/97 [00:40<00:00,  2.39it/s]
  0%|          | 0/97 [00:00<?, ?it/s]Epoch 6/10, Loss: 0.9726, Accuracy: 0.6495
100%|██████████| 97/97 [00:39<00:00,  2.46it/s]
  0%|          | 0/97 [00:00<?, ?it/s]Epoch 7/10, Loss: 1.0536, Accuracy: 0.6186
100%|██████████| 97/97 [00:40<00:00,  2.42it/s]
Epoch 8/10, Loss: 0.9458, Accuracy: 0.6907

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/584879.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spring boot--web响应

2. 响应 前面我们学习过HTTL协议的交互方式&#xff1a;请求响应模式&#xff08;有请求就有响应&#xff09; 那么Controller程序呢&#xff0c;除了接收请求外&#xff0c;还可以进行响应。 2.1 ResponseBody 在我们前面所编写的controller方法中&#xff0c;都已经设置了…

spring集成mybatis

目录 (1)新建javaEE web项目 ​(2)加入相关依赖的坐标 (3) 创建相应的包和类 (4) 配置spring和mybatis的配置文件 在resources中建mybatis-config.xml 在 resources中建spring.xml 在 resources中建db.xml 在 resources中建config.propertis 集成mybatis配置 ,导入myb…

MyBatis参数传递(提供ParamNameResolver类来进行参数封装)源码分析

MyBatis接口方法中可以接收各种各样的参数&#xff0c;MyBatis底层对于这些参数进行不同的封装处理方式。 单个参数&#xff1a;实体类、Map集合、Collection、List、Array以及其他类型。 多个参数&#xff1a;Param注解定义的名称要与sql语句中参数占位符中的名称相同。 这里…

RDD缓存有哪些特点?

RDD之间进行相互迭代计算(Transformation的转换)&#xff0c;当执行开启后&#xff0c;新RDD的生成&#xff0c;代表老RDD的消失。RDD的数据是过程数据&#xff0c;只在处理的过程中存在&#xff0c;一旦处理完成&#xff0c;就不见了。这个特性可以最大化的利用资源&#xff0…

【CSAPP】Binarybomb 实验(phase_1-6+secret_phase)

Binarybomb 实验&#xff08;phase_1-6secret_phase&#xff09; 实验内容 一个“binary bombs”&#xff08;二进制炸弹&#xff0c;下文将简称为炸弹&#xff09;是一个Linux可执行C程序&#xff0c;包含了7个阶段&#xff08;phase1~phase6和一个隐藏阶段&#xff09;。炸…

【CANoe示例分析】0002_SOMEIPDemo

该工程由Vector官方提供,作为仿真SOME/IP节点的示例。Demo中介绍了两种仿真SOME/IP节点的方法,一种是基于arxml数据库的仿真,另一种是没有数据库(arxml、fibex)的仿真。 无论是哪种形式的仿真,如果想要通过CAPL程序接收或者发送SOME/IP信息,都需要添加交互 层的信息,这…

不懂就问:年薪百万的程序员是怎么做到的?

很多人对程序员的第一反应就是“工资高”。 从行业平均薪酬来看&#xff0c;“程序员”相关专业的收入确实更高一点。 但是&#xff0c;“程序员”内部薪资却存在着很大的差异&#xff0c;多数人月薪在1-2万&#xff0c;一线城市可以达到3-5万&#xff0c;而顶级程序员&#…

探索Java面向对象编程的奇妙世界(六)

⭐ 多态(polymorphism)⭐ 对象的转型(casting)⭐ 抽象类⭐ 接口 interface ⭐ 多态(polymorphism) 多态指的是同一个方法调用&#xff0c;由于对象不同可能会有不同的行为。现实生活中&#xff0c;同一个方法&#xff0c;具体实现会完全不同。 比如&#xff1a;同样是调用人“吃…

回归方程的显著性检验——F检验

回归方程的显著性检验——F检验 9.2 回归方程的显著性检验 (edu-edu.com.cn) 概念 记号&#xff1a; y i y_i yi​&#xff1a;真实值&#xff0c;观测值 y ˉ \bar{y} yˉ​&#xff1a;真实值的平均值 y ^ \hat{y} y^​&#xff1a;估计值&#xff0c;预测值 几个差&#x…

Activiti、Flowable与CCFlow的选型对比

前言 工作流是什么&#xff0c;这个问题我们就不在此进行解释了&#xff0c;这里我们主要讲解一下Activiti、Flowable和CCFlow三款工作流的对比&#xff0c;为大家选型时做一些参考。 Activiti和Flowable大家可能多少都听说过&#xff0c;都是国外的工作流引擎&#xff0c;都…

Axure教程—单色面积图(中继器)

本文将教大家如何用AXURE制作单色面积图 一、效果介绍 如图&#xff1a; 预览地址&#xff1a;https://icg26y.axshare.com/ 下载地址&#xff1a;https://download.csdn.net/download/weixin_43516258/87837919?spm1001.2014.3001.5503 二、功能介绍 简单填写中继器内容即…

软件设计师总结-含括学习方法和学习过程,可参考

目录 考前备战宏观    心路历程-感受    学习阶段-计划的安排 微观一、课本和视频的学习    本阶段的目的:    侧重点    涉及的学习方法&#xff08;最后有如何使用这些方法&#xff09;    学习结果 二、32小时通关辅助前面的知识点    本阶段的…

提升企业管理效率的利器——ADManager Plus

在当今信息时代&#xff0c;企业的规模和复杂性不断增长&#xff0c;管理各个方面变得愈发具有挑战性。而在企业管理中&#xff0c;活跃目录&#xff08;Active Directory&#xff09;起着至关重要的作用。它是一种用于组织内部的用户、计算机、组和其他对象进行集中管理的目录…

javascript中的this与函数讲解

前言 javascript中没有块级作用域&#xff08;es6以前&#xff09;&#xff0c;javascript中作用域分为函数作用域和全局作用域。并且&#xff0c;大家可以认为全局作用域其实就是Window函数的函数作用域&#xff0c;我们编写的js代码&#xff0c;都存放在Window函数内&#x…

Hack The Box-Redeemer关卡

TASK 1 任务 1 Which TCP port is open on the machine? 计算机上打开了哪个 TCP 端口&#xff1f; 6379TASK 2 任务 2 Which service is running on the port that is open on the machine? 计算机上打开的端口上运行哪个服务&#xff1f; redisTASK 3 任务 3 What typ…

java学习——java学习进度一String类1(学习记录——供回溯)

String 分割字符串 split( ) String s "1,2,3,4"; //未使用split分割前 System.out.println(s.length());//使用split分割后 String[] ssplit s.split(","); System.out.println(ssplit.length);split( , ) //两个参数都有的时候&#xff0c;第一个为用…

视频编辑软件:迅捷视频工具箱

这是一款功能强大、易于使用的视频编辑工具&#xff0c;支持视频剪辑、视频转换、音频转换、视频压缩、视频水印、字幕贴图等实用功能&#xff0c;可以帮助你制作出高质量的视频作品。&#xff08;传送门&#xff1a;https://www.xunjiepdf.com/xjspgjx&#xff09; 功能简介 …

Linux:CentOS:进程查看和控制

查看 ps 查看静态的进程统计信息top查看动态的进程排名信息pgrep根据特定条件查询进程 PID 信息pstree以树形结构列出进程信息 S ---休眠 R ---运行 Z ---僵死&#xff08;应予以手动终止&#xff09; < ---高优先级 N ---低优先级 …

FrameLayout+LinearLayout实现首页底部菜单

1.布局样式 2.main.xml代码 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_par…

Socket(二)

文章目录 1. Socket地址2. 代理服务器3. 获取Socket的信息4. 关闭还是连接5. toString() 1. Socket地址 SocketAddress类表示一个连接端点&#xff0c;这个一个空的抽象类&#xff0c;除了一个默认构造函数外&#xff0c;没有其他方法。当前只支持TCP/IP Socket&#xff0c;实…