【NLP相关】基于现有的预训练模型使用领域语料二次预训练

news2024/12/28 21:32:40

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

BERT

【NLP相关】基于现有的预训练模型使用领域语料二次预训练

在自然语言处理领域,预训练模型已经成为了最为热门和有效的技术之一。预训练模型通过在大规模文本语料库上进行训练,可以学习到通用的语言模型,然后可以在不同的任务上进行微调。但是,预训练模型在领域特定任务上的表现可能不够好,因为它们是在通用语言语料库上进行训练的。为了提高在特定领域的任务中的性能,我们可以使用领域语料库对预训练模型进行二次预训练。

本篇博客将介绍如何基于现有的预训练模型使用领域语料二次预训练。我们将以 PyTorch 和 Transformers 库为基础,以医学文本分类任务为例,来详细说明二次预训练的过程。

1. 模型介绍

在本篇博客中,我们使用的预训练模型是 BERT(Bidirectional Encoder Representations from Transformers)。BERT 是一种基于 Transformer 的预训练模型,由 Google 团队开发。它在多个自然语言处理任务上取得了最先进的结果,例如文本分类、命名实体识别和问答系统等。

BERT 模型是一种双向的 Transformer 模型,能够有效地处理自然语言序列。它将文本输入嵌入到向量空间中,并在此基础上进行自监督训练,以学习通用的语言表示。在预训练完成后,BERT 模型可以进行微调,以适应不同的自然语言处理任务。

2. 代码实现

2.1 数据预处理

在开始二次预训练之前,我们需要准备领域特定的语料库。在这里,我们使用的是医学文本分类数据集,其中包含了一些医学文章的标题和摘要,并且每个文本都被标记为一个预定义的类别。

首先,我们需要将原始文本数据拆分为单个句子,并将其标记化处理。我们可以使用 Hugging Face 的 tokenizer 来完成这个任务。

from transformers import BertTokenizer

# 加载 BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

class MedicalDataset(Dataset):
    def __init__(self, tokens, max_length=128):
        self.tokens = tokens
        self.max_length = max_length

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        # 获取句子对
        tokens = self.tokens[idx]

        # 将句子对拼接成一个序列,并将其标记化处理
        input_ids = tokenizer.encode(
            tokens[0], tokens[1],
            add_special_tokens=True, max_length=self.max_length,
            truncation_strategy='longest_first'
        )
        attention_mask = [1] * len(input_ids)

        # 填充序列长度
        padding = [0] * (self.max_length - len(input_ids))
        input_ids += padding
        attention_mask += padding

        # 返回 input_ids 和 attention_mask
        return torch.LongTensor(input_ids), torch.LongTensor(attention_mask)

2.2 二次预训练

在数据预处理之后,我们可以开始进行二次预训练了。在这里,我们将使用 Hugging Face 的 Transformers 库,以及 PyTorch 框架来实现二次预训练。

首先,我们需要加载预训练的 BERT 模型。在这里,我们使用的是 bert-base-uncased 模型,它是一个基于英文的预训练模型。我们还需要定义一些训练参数,例如学习率和批大小等。

from transformers import BertForPreTraining, AdamW
from torch.utils.data import DataLoader

# 加载预训练的 BERT 模型
model = BertForPreTraining.from_pretrained('bert-base-uncased')

# 定义训练参数
epochs = 3
batch_size = 16
learning_rate = 2e-5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将模型移动到 GPU 上
model.to(device)

接下来,我们需要加载领域特定的语料库,并将其转换为 PyTorch 数据集。在这里,我们使用的是 PyTorch 中的 Dataset 类。我们还需要将数据集加载到 PyTorch 的数据加载器中,以便进行训练。

# 加载领域特定的语料库
with open('medical_data.txt') as f:
    sentences = f.readlines()

# 将语料库转换为 PyTorch 数据集
dataset = MedicalDataset(sentences)

# 将数据集加载到 PyTorch 的数据加载器中
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

在准备好数据集之后,我们可以开始训练模型了。我们将使用 AdamW 优化器和交叉熵损失函数来训练模型。在每个 epoch 完成之后,我们会对模型进行一次测试,并计算准确率和损失函数值。最后,我们将保存训练好的模型。

# 定义优化器和损失函数
optimizer = AdamW(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()

# 训练模型
for epoch in range(epochs):
    total_loss = 0
    total_correct = 0
    total_samples = 0

    # 遍历数据集
    for i, batch in enumerate(loader):
        # 将输入数据和标签移动到 GPU 上
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
    # 将模型设置为训练模式
    model.train()

    # 计算模型的输出
    outputs = model(input_ids, attention_mask=attention_mask)

    # 计算损失函数值
    loss = criterion(outputs.logits.view(-1, 2), outputs.labels.view(-1))

    # 清除之前的梯度
    optimizer.zero_grad()

    # 反向传播和优化
    loss.backward()
    optimizer.step()

    # 统计训练信息
    total_loss += loss.item()
    total_samples += input_ids.size(0)
    total_correct += torch.sum(torch.argmax(outputs.logits, dim=-1) == outputs.labels.view(-1)).item()

    # 输出训练信息
    if (i + 1) % 100 == 0:
        print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Accuracy: %.2f%%'
              % (epoch + 1, epochs, i + 1, len(loader),
                 total_loss / total_samples, total_correct / total_samples * 100))

# 在每个 epoch 完成之后进行一次测试
with torch.no_grad():
    total_loss = 0
    total_correct = 0
    total_samples = 0

    # 将模型设置为评估模式
    model.eval()

    # 遍历测试数据集
    for i, batch in enumerate(test_loader):
        # 将输入数据和标签移动到 GPU 上
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # 计算模型的输出
        outputs = model(input_ids, attention_mask=attention_mask)

        # 计算损失函数值
        loss = criterion(outputs.logits.view(-1, 2), outputs.labels.view(-1))

        # 统计测试信息
        total_loss += loss.item()
        total_samples += input_ids.size(0)
        total_correct += torch.sum(torch.argmax(outputs.logits, dim=-1) == outputs.labels.view(-1)).item()

    # 输出测试信息
    print('Epoch [%d/%d], Test Loss: %.4f, Test Accuracy: %.2f%%'
          % (epoch + 1, epochs, total_loss / total_samples, total_correct / total_samples * 100))
	#保存训练好的模型
	torch.save(model.state_dict(), 'medical_bert.pth')

3. 案例解析

假设我们要对医学领域中的文本进行二次预训练。我们可以使用已经预训练好的 BERT 模型,并使用医学领域的语料库进行二次预训练。

首先,我们需要将医学领域的语料库进行预处理。在这里,我们可以使用 NLTK 库来进行分词和词形还原等操作。我们还可以将语料库中的每个句子转换为 BERT 输入格式。

import nltk
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from transformers import BertTokenizer

# 加载 BERT 分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 加载 NLTK 分词器和词形还原
nltk.download('punkt')
nltk.download('wordnet')
lemmatizer = WordNetLemmatizer()

加载医学领域的语料库
with open('medical_corpus.txt', 'r') as f:
	corpus = f.read()

#对每个句子进行分词、词形还原和转换为 BERT 输入格式
sentences = []
for sentence in nltk.sent_tokenize(corpus):
	words = nltk.word_tokenize(sentence)
	words = [lemmatizer.lemmatize(word) for word in words]
	words = [word.lower() for word in words]
	tokens = tokenizer.encode_plus(words,
									add_special_tokens=True,
									max_length=512,
									padding='max_length',
									truncation=True)
	sentences.append((tokens['input_ids'], tokens['attention_mask']))

接下来,我们需要使用这些句子对 BERT 模型进行二次预训练。为此,我们需要定义一个新的数据加载器,将这些句子传递给模型进行训练。

from torch.utils.data import Dataset, DataLoader

class MedicalDataset(Dataset):
    def __init__(self, sentences):
        self.sentences = sentences
    
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        return self.sentences[idx]

# 定义数据加载器
loader = DataLoader(MedicalDataset(sentences),
                    batch_size=16,
                    shuffle=True)

现在,我们可以开始对 BERT 模型进行二次预训练了。我们可以使用与之前相同的训练代码。

# 定义训练函数
def train(model, loader, optimizer, device):
    model.train()
    for batch in loader:
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        
        optimizer.zero_grad()
        
        loss, _ = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=True)[:2]
        loss.backward()
        
        optimizer.step()

# 加载预训练的 BERT 模型
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 定义优化器
optimizer = AdamW(model.parameters(), lr=5e-5)

# 进行二次预训练
for epoch in range(num_epochs):
    train(model, train_loader, optimizer, device)

    # 每个 epoch 结束后测试模型的性能
    perplexity = evaluate(model, test_loader, device)
    print(f'Epoch {epoch+1}, perplexity: {perplexity:.3f}')

    # 保存模型
    model_path = f'model_epoch{epoch+1}.pt'
    torch.save(model.state_dict(), model_path)

这里定义了一个 train 函数来训练模型。这个函数接收一个模型、一个数据加载器、一个优化器和一个设备作为输入。它会将模型设为训练模式,并且在每个批次上运行前向传播、计算损失、计算梯度和更新参数。

接下来,我们加载预训练的 BERT 模型,并将其移动到所选设备上。我们使用 AdamW 优化器,并将学习率设置为 5e-5。

最后,我们使用一个简单的 for 循环来进行二次预训练。在每个 epoch 结束时,我们会在测试集上评估模型,并打印出 perplexity 指标。我们还会将模型保存在磁盘上,以便以后进行检索。


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/389119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《七》JavaScript 中的作用域、作用域链、执行上下文、执行上下文栈

JS 引擎会在执行所有代码之前,先在堆内存中创建一个全局对象(Global Object、GO),包含 String、Math、Date、parseInt() 等属性和方法。所有作用域都可以访问这个全局对象。 在浏览器中 Global Object 就是 Window 对象。 执行上…

不用机器学习不用大数据,给你讲通ChatGPT的深层原理

ChatGPT现在看来已经异常火爆了,很多人已经熟知,并且开始练习使用或者开始利用他开始实践了。但仍然有很多人在观望,在疑惑,今天狗哥不用那些高端大气的机器学习亦或是大数据还给你讲通ChatGPT深层到底是个啥逻辑。 目录 1. 聊家…

CV——dy83 接昨天的论文中DAM模块:压缩-激励的宽残差网络在图像分类中的应用

压缩-激励的宽残差网络在图像分类中的应用(ICIP 2019)1. INTRODUCTION2. PROPOSED METHODS2.1 总体框架2.2 通道的重要性3. EXPERIMENTS3.1 Datasets3.2 训练和测试的设置3.3 分类结果及分析4. CONCLUSIONSQUEEZE-AND-EXCITATION WIDE RESIDUAL NETWORKS…

CSS 选择器以及CSS常用属性

目录 🐇今日良言:可以不光芒万丈,但不要停止发光 🐯一、写CSS的三种方法 🐯二、CSS选择器的常见用法 🐯三、CSS常用属性 🐇今日良言:可以不光芒万丈,但不要停止发光 🐯一、写CSS的三种方法 CSS的基本语…

目标检测开源数据集汇总

导 读本文汇总了一些开源目标检测类的数据集,附下载链接。多显著性对象数据集数据集链接:http://m6z.cn/5AsmXB本数据集共有 1224 张图像来自四个公共图像数据集:COCO、VOC07、ImageNet 和 SUN。Amazon Mechanic Turk 工作人员将每个图像标记…

Firebase入门使用 01

官网 firebase.google.com 解决问题 firebase 帮助解决 数据库 和 API之间的问题 这样我们就可以 集中精力开创应用。 快速上手样例指南 https://github.com/firebase 提供的服务 其中80%用不到,下面是一些我们可以用到的服务。 Authentication:用户认证管理…

Qt安装与使用经验分享;无.pro文件;无QTextCodec file;Qt小试;界面居中;无缝;更换Qt图标;更换Qt标题。

1、切换安装下载源 《Qt安装教程》先推荐一篇安装文章:《Qt安装教程》 Qt 5.15 之后已经不提供离线安装包了,就是那个 3.7G 的 exe 安装包。请看官方说明,所以只能用在线安装包。 1,下载在线安装包 QT 在线安装包链接&#xff…

基于WSL2和Clion搭建Win下C开发环境

系列文章目录 一、基于WSL2和Clion搭建Win下C开发环境 二、make、makeFile、CMake、CMakeLists的使用 三、全面、详细、通俗易懂的C语言语法和标准库 文章目录系列文章目录前言WSL2安装WSL常用命令VSCode连接WSLroot密码以systemd启动配置sshClion结语前言 Win下C语言开发环境…

zabbix-API对接实录:关键基础设施数据清洗和封装函数(php数组函数、数据清洗、数据结构化)

系列文章目录 Zabbix监控系统PHP-API开发测试实录Zabbix监控系统开发(2):JSON多维数组筛选字段是否包含字符串的解决方案Zabbix物联网可视化开发文档 文章目录系列文章目录前言一、zabbix-API数据爬虫二、主机ID封装接口1.封装API接口2.数据处理封装函数三、组ID封装接口1.格式…

汽车 Automotive > T-BOX GNSS高精定位测试相关知识

参考:https://en.wikipedia.org/wiki/Global_Positioning_SystemGPS和GNSS的关系GPS(Global Positioning System),全球定位系统是美国军民两用的导航定位卫星系统,GPS包含双频信号,频点L1、L2和L5GNSS&…

RecyclerView ViewType二级

实现效果描述: 1、点击recyclerview中item,列表下方出现其他样式的item,作为子item,如下所示 所需要的java文件和xml文件有: 1、创建FoldAdapteradapter, 在FoldAdapter中,定义两种不同的类型&#xff…

Allegro如何将Waived掉的DRC显示或隐藏操作指导

Allegro如何将Waived掉的DRC显示或隐藏操作指导 在用Allegro做PCB设计的时候,如果遇到正常的DRC,可以用Waive的命令将DRC不显示,如下图 当DRC被Waive掉的时候,如何将DRC再次显示出来。类似下图效果 具体操作如下 点击Display

linux下strace的使用

strace是一款用于跟踪Linux系统调用和信号的工具,可以帮助开发者排除程序运行时的问题。 具体来说,strace可以跟踪一个程序执行时所涉及到的系统调用,包括读写文件、网络通信、进程管理、内存管理等操作,通过分析程序运行过程中发…

JavaWeb--JSP案例

JSP案例8 案例8.1 环境准备8.1.1 创建工程8.1.2 创建包8.1.3 创建表8.1.4 创建实体类8.1.5 准备mybatis环境8.2 查询所有8.2.1 编写BrandMapper8.2.2 编写工具类8.2.3 编写BrandService8.2.4 编写Servlet8.2.5 编写brand.jsp页面8.2.6 测试8.3 添加8.3.1 编写BrandMapper方法8.…

ARM uboot 的移植0-从三星官方 uboot 开始移植的准备工作

一、移植前的准备工作 1、三星移植过的uboot源代码准备 (1) 三星对于 S5PV210 的官方开发板为 SMDKV210,对应的移植过的 uboot 是:三星官方为210移植过的uboot和kernel/android_uboot_smdkv210.tar.bz2。 (2) 这个源代码网上是下载不到的,…

Leetcode.2397 被列覆盖的最多行数

题目链接 Leetcode.2397 被列覆盖的最多行数 Rating : 1719 题目描述 给你一个下标从 0 开始的 m x n二进制矩阵 mat和一个整数 cols,表示你需要选出的列数。 如果一行中,所有的 1 都被你选中的列所覆盖,那么我们称这一行 被覆盖…

RabbitMQ的使用以及整合到SpringBoot中

RabbitMQ的使用以及整合到SpringBoot中 一、比较: (1)、传统请求服务器: (2)、通过MQ去操作数据库: 通过MQ去操作数据库,从而达到削峰的效果; 问题现象: (1)、海量数据; (2)、高并发&#…

Python如何获取弹幕?给你介绍两种方式

前言 弹幕可以给观众一种“实时互动”的错觉,虽然不同弹幕的发送时间有所区别,但是其只会在视频中特定的一个时间点出现,因此在相同时刻发送的弹幕基本上也具有相同的主题,在参与评论时就会有与其他观众同时评论的错觉。 在国内…

【SQLAlchemy】第二篇——连接失效及连接池

一、背景 为了节约资源,MySQL会对建立的连接进行监控,当某些连接处于不活跃状态的时间超过一个阈值时,则关闭它们。 用户可以执行show variables like %wait_timeout%;来查看这个阈值: 可以看到,在默认的情况下&…

Multi-modal Graph Contrastive Learning for Micro-video Recommendation

模型总览如下: 解决问题:同种重要性对待每种模态,可能使得得到的特征表示次优,例如过度强调学习到的表示中的特定模态。以MMGCN为例,下图为MMGCN模型总览。 如上图所示MMGCN在每种模态上构建用户-物品二部图&#xff0…