97. BERT微调、自然语言推理数据集以及代码实现

news2024/11/23 0:07:07

1. 微调BERT

在这里插入图片描述

2. 句子分类

在这里插入图片描述

3. 命名实体识别

在这里插入图片描述

4. 问题回答

在这里插入图片描述

5. 总结

  • 即使下游任务各有不同,使用BERT微调时只需要增加输出层
  • 但根据任务的不同,输入的表示,和使用的BERT特征也会不一样

6. 自然语言推理数据集

斯坦福自然语言推断语料库(Stanford Natural Language Inference,SNLI)]是由500000多个带标签的英语句子对组成的集合 。我们在路径../data/snli_1.0中下载并存储提取的SNLI数据集。

import os
import re
import torch
from torch import nn
from d2l import torch as d2l

d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

data_dir = d2l.download_extract('SNLI')

6.1 读取数据集

原始的SNLI数据集包含的信息比我们在实验中真正需要的信息丰富得多。因此,我们定义函数read_snli以仅提取数据集的一部分,然后返回前提、假设及其标签的列表。

def read_snli(data_dir, is_train):
def read_snli(data_dir, is_train):
    """将SNLI数据集解析为前提、假设和标签"""
    def extract_text(s):
        # 删除我们不会使用的信息
        s = re.sub('\\(', '', s)
        s = re.sub('\\)', '', s)
        # 用一个空格替换两个或多个连续的空格
        s = re.sub('\\s{2,}', ' ', s)
        return s.strip()
    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    with open(file_name, 'r') as f:
        # rows是一个list,其中包含多个list,这每个list就是由多个字符串组成,如:
        # rows[i]:['contradiction',
        # '( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )',
        # '( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )',]
        # 第一个字符串会指示属于哪个标签
        rows = [row.split('\t') for row in f.readlines()[1:]]
    
    # 第2个字符串表示premises(前提)
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    # 第3个字符串表示hypotheses(假设)
    hypotheses = [extract_text(row[2]) for row in rows if row[0] \
                in label_set]
    # 第一个字符串表示label(标签),再通过label_set得到具体的数字
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels

现在让我们打印前3对前提和假设,以及它们的标签(“0”“1”和“2”分别对应于“蕴涵”“矛盾”和“中性”)。

train_data = read_snli(data_dir, is_train=True)
# train_data[0]表示premises,[:3]表示premises前三个
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print('前提:', x0)
    print('假设:', x1)
    print('标签:', y)

运行结果:

在这里插入图片描述

训练集约有550000对,测试集约有10000对。下面显示了训练集和测试集中的三个标签“蕴涵”“矛盾”和“中性”是平衡的。

test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
    # data[2]是label这一数组,row就取到每一个label,一个label也能代表一行
    # 计算训练集和测试集中所有的label的0,1,2的数量
    print([[row for row in data[2]].count(i) for i in range(3)])

运行结果:

在这里插入图片描述

6.2 定义用于加载数据集的类

下面我们来定义一个用于加载SNLI数据集的类。类构造函数中的变量num_steps指定文本序列的长度,使得每个小批量序列将具有相同的形状。换句话说,在较长序列中的前num_steps个标记之后的标记被截断,而特殊标记“< pad>”将被附加到较短的序列后,直到它们的长度变为num_steps。通过实现__getitem__功能,我们可以任意访问带有索引idx的前提、假设和标签。

class SNLIDataset(torch.utils.data.Dataset):
    """用于加载SNLI数据集的自定义数据集"""
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = d2l.tokenize(dataset[0])
        all_hypothesis_tokens = d2l.tokenize(dataset[1])
        if vocab is None:
            self.vocab = d2l.Vocab(all_premise_tokens + \
                all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print('read ' + str(len(self.premises)) + ' examples')

    def _pad(self, lines):
        return torch.tensor([d2l.truncate_pad(
            self.vocab[line], self.num_steps, self.vocab['<pad>'])
                         for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)

6.3 整合代码

现在,我们可以调用read_snli函数和SNLIDataset类来下载SNLI数据集,并返回训练集和测试集的DataLoader实例,以及训练集的词表。值得注意的是,我们必须使用从训练集构造的词表作为测试集的词表。因此,在训练集中训练的模型将不知道来自测试集的任何新词元。

def load_data_snli(batch_size, num_steps=50):
    """下载SNLI数据集并返回数据迭代器和词表"""
    num_workers = d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            shuffle=False,
                                            num_workers=num_workers)
    return train_iter, test_iter, train_set.vocab

在这里,我们将批量大小设置为128时,将序列长度设置为50,并调用load_data_snli函数来获取数据迭代器和词表。然后我们打印词表大小。

train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)

运行结果:

在这里插入图片描述

现在我们打印第一个小批量的形状。与情感分析相反,我们有分别代表前提和假设的两个输入X[0]和X[1]。

for X, Y in train_iter:
    print(X[0].shape)
    print(X[1].shape)
    print(Y.shape)
    break

运行结果:

在这里插入图片描述

7. BERT微调代码实现

本节将下载一个预训练好的小版本的BERT,然后对其进行微调,以便在SNLI数据集上进行自然语言推断。

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

7.1 加载预训练的BERT

我们已经在 WikiText-2数据集上预训练BERT(请注意,原始的BERT模型是在更大的语料库上预训练的)。原始的BERT模型有数以亿计的参数。在下面,我们提供了两个版本的预训练的BERT:“bert.base”与原始的BERT基础模型一样大,需要大量的计算资源才能进行微调,而“bert.small”是一个小版本,以便于演示。

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
                             '225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
                              'c72329e68a732bef0452e4b96a1c341c8910f81f')

两个预训练好的BERT模型都包含一个定义词表的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现了以下load_pretrained_model函数来加载预先训练好的BERT参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                          num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.download_extract(pretrained_model)
    # 定义空词表以加载预定义词表
    vocab = d2l.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir,
        'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(
        vocab.idx_to_token)}
    bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],
                         ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
                         num_heads=4, num_layers=2, dropout=0.2,
                         max_len=max_len, key_size=256, query_size=256,
                         value_size=256, hid_in_features=256,
                         mlm_in_features=256, nsp_in_features=256)
    # 加载预训练BERT参数
    bert.load_state_dict(torch.load(os.path.join(data_dir,
                                                 'pretrained.params')))
    return bert, vocab

为了便于在大多数机器上演示,我们将在本节中加载和微调经过预训练BERT的小版本(“bert.small”)。在练习中,我们将展示如何微调大得多的“bert.base”以显著提高测试精度。

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
    'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
    num_layers=2, dropout=0.1, max_len=512, devices=devices)

7.2 微调BERT的数据集

对于SNLI数据集的下游任务自然语言推断,我们定义了一个定制的数据集类SNLIBERTDataset。在每个样本中,前提和假设形成一对文本序列,并被打包成一个BERT输入序列。

片段索引用于区分BERT输入序列中的前提和假设。利用预定义的BERT输入序列的最大长度(max_len),持续移除输入文本对中较长文本的最后一个标记,直到满足max_len。为了加速生成用于微调BERT的SNLI数据集,我们使用4个工作进程并行生成训练或测试样本。

class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]

        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        pool = multiprocessing.Pool(4)  # 使用4个进程
        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (torch.tensor(all_token_ids, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                             * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # 为BERT输入中的'<CLS>'、'<SEP>'和'<SEP>'词元保留位置
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)

下载完SNLI数据集后,我们通过实例化SNLIBERTDataset类来生成训练和测试样本。这些样本将在自然语言推断的训练和测试期间进行小批量读取。

# 如果出现显存不足错误,请减少“batch_size”。在原始的BERT模型中,max_len=512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                  num_workers=num_workers)

运行结果:

在这里插入图片描述

7.3 微调BERT

用于自然语言推断的微调BERT只需要一个额外的多层感知机,该多层感知机由两个全连接层组成(请参见下面BERTClassifier类中的self.hiddenself.output)。这个多层感知机将特殊的“< cls>”词元的BERT表示进行了转换,该词元同时编码前提和假设的信息(为自然语言推断的三个输出):蕴涵、矛盾和中性。

class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(256, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        return self.output(self.hidden(encoded_X[:, 0, :]))

在下文中,预训练的BERT模型bert被送到用于下游应用的BERTClassifier实例net中。在BERT微调的常见实现中,只有额外的多层感知机(net.output)的输出层的参数将从零开始学习。预训练BERT编码器(net.encoder)和额外的多层感知机的隐藏层(net.hidden)的所有参数都将进行微调。

net = BERTClassifier(bert)

回想一下,在 sec_bert中,MaskLM类NextSentencePred类在其使用的多层感知机中都有一些参数。这些参数是预训练BERT模型bert中参数的一部分,因此是net中的参数的一部分。然而,这些参数仅用于计算预训练过程中的遮蔽语言模型损失和下一句预测损失。这两个损失函数与微调下游应用无关,因此当BERT微调时,MaskLMNextSentencePred中采用的多层感知机的参数不会更新(陈旧的,staled)。

为了允许具有陈旧梯度的参数,标志ignore_stale_grad=Truestep函数d2l.train_batch_ch13中被设置。我们通过该函数使用SNLI的训练集(train_iter)和测试集(test_iter)对net模型进行训练和评估。由于计算资源有限,训练和测试精度可以进一步提高:我们把对它的讨论留在练习中。

lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
    devices)

运行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/189409.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BP神经网络算法实现

目录 一、实验数学原理 二、实验算法和实验步骤 三、结果分析 1. 均方误差变化的影响 2. 迭代次数变化的影响 3. 学习效率变化的影响 四、预测 一、实验数学原理 激活函数&#xff1a; 一般使用S形函数&#xff08;即sigmoid函数&#xff09;&#xff0c;比如可以使用log-…

【爬虫系列】Python如何实现进度条效果?

一、需求 在爬取数据过程中&#xff0c;发现不看输出日志是不知道当前的爬取进度&#xff0c;而单纯靠控制台输出日志信息也不方便判断。因此&#xff0c;就想办法给爬取过程加个进度条&#xff0c;实时展示当前的爬取进度。 有了这个需求和想法之后&#xff0c;那如何实现呢…

k8s中不同名称空间下的pod无法解析服务名

1、背景 公司的项目需要使用容器化部署&#xff0c;为了更好的维护和管理&#xff0c;我将各个项目按照命名空间进行隔离开&#xff0c;但是却发现存在一些问题 不同的系统间需要项目调用&#xff0c;而且是按照服务名进行调用&#xff0c;但是却导致不同名称空间下pod无法解析…

mac 快应用开发工具 真机调试 usb调试 提示Error:没有找到Android设备

项目场景&#xff1a; 项目场景&#xff1a;mac使用快应用开发工具连接Android手机 问题描述 显示错误没有找到Android设备 原因分析&#xff1a; adb连接的问题 解决方案&#xff1a; 1.确保手机开启开发者模式 2.确保手机与mac的连接线能传输数据&#xff0c;有的线只能…

python+moviepy音视频处理(一):基本操作

目录 视频处理 视频加载和输出 视频转换gif 视频裁剪 视频音量调节 去掉视频声音 视频中的音频提取与替换 获取视频属性 倍数播放视频 截取视频某帧为封面 多视频拼接 音频处理 替换视频文件的音频 多个音频文件拼接 安装&#xff1a;pip install moviepy 中文官…

【自学Docker】Docker stats命令

Docker stats命令 大纲 docker stats命令教程 docker stats 命令可以用于动态显示 Docker容器 的资源消耗情况&#xff0c;包括&#xff1a;CPU、内存、网络I/O。docker stats命令也可以指定已停止的容器&#xff0c;但是不会返回任何信息。 docker stats命令语法 haicoder…

Windows下载安装Nignx

下载 下载地址&#xff1a;http://nginx.org/en/download.html 下载完成以后,得到nginx压缩包; Nginx启动 方式一&#xff1a;可执行文件启动 双击nginx.exe启动 现在,我们打开任务管理器,如果发现nginx进程存在,说明启动完成; 方式二&#xff1a;命令行启动 进入nginx所在…

【Mysql第三期 基本查询语句结构】

文章目录1. SQL概述1.1 SQL背景知识1.2SQL 分类2. SQL语言的规则与规范2.1 基本规则2.2 SQL大小写规范 &#xff08;建议遵守&#xff09;2.3 注 释2.4 命名规则&#xff08;暂时了解&#xff09;3.基本的SELECT语句3.1 查询基本结构3.2 列的别名3.3 去除重复行扩展windows cmd…

同步FIFO设计verilog设计及仿真

同步FIFO设计 1.功能定义: 用16*8 RAM实现一个同步先进先出(FIFO)队列设计。由写使能端控制该数据流的写入FIFO,并由读使能控制FIFO中数据的读出。写入和读出的操作由时钟的上升沿触发。当FIFO的数据满和空的时候分别设置相应的高电平加以指示。 2.顶层信号定义: 信号名…

最小生成树问题(Prim算法和Kruskal算法)

问题引入&#xff1a; 这算是一道模板题了&#xff0c;只不过这次在做的时候感觉又学到了些新的东西&#xff0c;之前都是数据结构里学的&#xff0c;因为用惯了C&#xff0c;所以就想摆脱那些邻接数组之类的写法&#xff0c;用STL试一下&#xff0c;在其中把我遇到的一些问题写…

【论文翻译】边缘应用中加速卷积神经网络的剪枝算法综述

摘要 随着卷积神经网络&#xff08;CNN&#xff09;模型大小的增加&#xff0c;模型压缩和加速技术对于在边缘设备上部署这些模型变得至关重要。在本文中&#xff0c;我们对修剪进行了全面的调查&#xff0c;这是一种主要的压缩策略&#xff0c;可以从CNN模型中删除非关键或冗…

iOS_Memory Leak 内存泄露治理

1、内存分类 官方文档介绍 app 的内存分三类&#xff1a; Leaked memory&#xff1a;Memory unreferenced by your application that cannot be used again or freed (also detectable by using the Leaks instrument) Abandoned memory&#xff1a;Memory still referenced b…

设计模式 - 结构型模式_桥接模式

文章目录结构型模式概述CaseBad ImplBetter Impl小结结构型模式 结构型模式主要是解决如何将对象和类组装成较大的结构&#xff0c; 并同时保持结构的灵活和⾼效。 结构型模式包括&#xff1a;适配器、桥接、组合、装饰器、外观、享元、代理&#xff0c;这7类 概述 桥接模式的…

2023牛客寒假算法基础集训营4

A-清楚姐姐学信息论 链接&#xff1a;登录—专业IT笔试面试备考平台_牛客网 来源&#xff1a;牛客网 不同进制对于信息的表示效率不同&#xff0c;清楚姐姐最近学习了信息论中使用不同进制表示信息的方法&#xff0c;她现在想要比较两种不同进制表示信息时&#xff0c;谁的…

从软件角度看PCIe设备的硬件结构

从软件角度看PCIe设备的硬件结构 文章目录从软件角度看PCIe设备的硬件结构参考资料&#xff1a;一、 PCIe接口引脚二、 从软件角度理解硬件接口2.1 PCI/PCIe地址空间转换2.2 PCIe上怎么传输地址、数据三、 PCIe系统的硬件框图致谢参考资料&#xff1a; 《PCI Express Technolo…

ElasticSearch概念与架构原理

文章目录一、概述二、ElasticSearch架构原理三、ElasticSearch搜索入门一、概述 ElasticSearch简介 简介 ES是建立在Lucene基础之上的分布式准实时搜索引擎&#xff0c;它所提供的诸多功能中有一大优点&#xff0c;就是实时性好。比如&#xff1a;在业务需求中&#xff0c;新增…

计算机图形学 第7章 自由曲线曲面

先说好&#xff0c;第八章不学。 目录学习目标曲线与曲面的表示形式插值与逼近Bezier曲线定义一次Bezier曲线二次Bezier曲线⭐⭐⭐三次Bezier曲线⭐⭐⭐三次Bezier曲线的Bernstein基函数&#xff1a;Bernstein基函数的性质Bezier曲线的性质de Casteljau算法几何作图法绘制Bezie…

Struts2之拦截器

Struts2之拦截器1、Struts2体系架构1.1、执行流程1.2、核心接口和类1.3、流程简图2、Struts2拦截器2.1、使用拦截器的目的2.2、拦截器的简介2.3、拦截器的工作原理2.4、拦截器的使用2.4.1、创建自定义拦截器2.4.2、struts.xml中定义和配置拦截器2.4.3、Struts2默认拦截器2.4.4、…

Leetcode.2319 判断矩阵是否是一个 X 矩阵

题目链接 Leetcode.2319 判断矩阵是否是一个 X 矩阵 Rating : 1201 题目描述 如果一个正方形矩阵满足下述 全部 条件&#xff0c;则称之为一个 X矩阵 &#xff1a; 矩阵对角线上的所有元素都 不是 0 矩阵中所有其他元素都是 0 给你一个大小为 n x n的二维整数数组 grid&#…

ElasticSearch - 旅游酒店案例es功能实现

目录 案例 搜索与分页功能 条件过滤功能 附近的酒店功能 广告置顶功能 HotelService(es操作)总览 案例 搜索与分页功能 案例需求&#xff1a;实现旅游的酒店搜索功能&#xff0c;完成关键字搜索和分页实现步骤如下&#xff1a;1.定义实体类&#xff0c;接收前端请求实体…