李沐70_bert微调——自学笔记

news2024/11/28 6:53:54

微调BERT

1.BERT滴哦每一个词元返回抽取了上下文信息的特征向量

2.不同的任务使用不同的特性

句子分类

将cls对应的向量输入到全连接层分类

命名实体识别

1.识别应该词元是不是命名实体,例如人名、机构、位置

2.将非特殊词元放进全连接层分类

问题回答

1.给定应该问题和描述文字,找出一个片段作为回答

2.对片段中的每个词元预测它是不是回答的开头或结束

总结

1.即使下游任务各有不同,使用BERT微调时君只需要增加输出层

2.但根据任务的不同,输入的表示,和使用Bert特征也会不一样

!pip install d2l==0.17.6  ### 很重要,不要下载错了,对于colab

Natural Language Inference and Dataset

斯坦福自然语言推断(SNLI)数据集

import os
import re
import torch
from torch import nn
from d2l import torch as d2l

d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

data_dir = d2l.download_extract('SNLI')

读取数据集

定义函数read_snli以仅提取数据集的一部分,然后返回前提、假设及其标签的列表。*

def read_snli(data_dir, is_train):
    """将SNLI数据集解析为前提、假设和标签"""
    def extract_text(s):
        # 删除我们不会使用的信息
        s = re.sub('\\(', '', s)
        s = re.sub('\\)', '', s)
        # 用一个空格替换两个或多个连续的空格
        s = re.sub('\\s{2,}', ' ', s)
        return s.strip()
    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    with open(file_name, 'r') as f:
        rows = [row.split('\t') for row in f.readlines()[1:]]
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    hypotheses = [extract_text(row[2]) for row in rows if row[0] \
                in label_set]
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels

打印前3对前提和假设,以及它们的标签(“0”“1”和“2”分别对应于“蕴涵”“矛盾”和“中性”)。

train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print('前提:', x0)
    print('假设:', x1)
    print('标签:', y)
前提: A person on a horse jumps over a broken down airplane .
假设: A person is training his horse for a competition .
标签: 2
前提: A person on a horse jumps over a broken down airplane .
假设: A person is at a diner , ordering an omelette .
标签: 1
前提: A person on a horse jumps over a broken down airplane .
假设: A person is outdoors , on a horse .
标签: 0

训练集约有550000对,测试集约有10000对。下面显示了训练集和测试集中的三个标签“蕴涵”“矛盾”和“中性”是平衡的。

test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
    print([[row for row in data[2]].count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]

定义一个用于加载SNLI数据集的类。类构造函数中的变量num_steps指定文本序列的长度,使得每个小批量序列将具有相同的形状。

class SNLIDataset(torch.utils.data.Dataset):
    """用于加载SNLI数据集的自定义数据集"""
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = d2l.tokenize(dataset[0])
        all_hypothesis_tokens = d2l.tokenize(dataset[1])
        if vocab is None:
            self.vocab = d2l.Vocab(all_premise_tokens + \
                all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print('read ' + str(len(self.premises)) + ' examples')

    def _pad(self, lines):
        return torch.tensor([d2l.truncate_pad(
            self.vocab[line], self.num_steps, self.vocab['<pad>'])
                         for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)

我们可以调用read_snli函数和SNLIDataset类来下载SNLI数据集,并返回训练集和测试集的DataLoader实例,以及训练集的词表。

def load_data_snli(batch_size, num_steps=50):
    """下载SNLI数据集并返回数据迭代器和词表"""
    num_workers = d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            shuffle=False,
                                            num_workers=num_workers)
    return train_iter, test_iter, train_set.vocab

在这里,我们将批量大小设置为128时,将序列长度设置为50,并调用load_data_snli函数来获取数据迭代器和词表。然后我们打印词表大小。

train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples


/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(





18678

现在我们打印第一个小批量的形状。与情感分析相反,我们有分别代表前提和假设的两个输入X[0]和X[1]。

for X, Y in train_iter:
    print(X[0].shape)
    print(X[1].shape)
    print(Y.shape)
    break
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()


torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])

Fine-Tuning BERT

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

加载预训练的BERT

提供了两个版本的预训练的BERT:“bert.base”与原始的BERT基础模型一样大,需要大量的计算资源才能进行微调,而“bert.small”是一个小版本,以便于演示。

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
                             '225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
                              'c72329e68a732bef0452e4b96a1c341c8910f81f')

两个预训练好的BERT模型都包含一个定义词表的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现了以下load_pretrained_model函数来加载预先训练好的BERT参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                          num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.download_extract(pretrained_model)
    # 定义空词表以加载预定义词表
    vocab = d2l.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir,
        'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(
        vocab.idx_to_token)}
    bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],
                         ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
                         num_heads=4, num_layers=2, dropout=0.2,
                         max_len=max_len, key_size=256, query_size=256,
                         value_size=256, hid_in_features=256,
                         mlm_in_features=256, nsp_in_features=256)
    # 加载预训练BERT参数
    bert.load_state_dict(torch.load(os.path.join(data_dir,
                                                 'pretrained.params')))
    return bert, vocab

加载和微调经过预训练BERT的小版本(“bert.small”)

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
    'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
    num_layers=2, dropout=0.1, max_len=512, devices=devices)
Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...

微调BERT数据集

定义了一个定制的数据集类SNLIBERTDataset。在每个样本中,前提和假设形成一对文本序列,并被打包成一个BERT输入序列,如 图15.6.2所示。回想 14.8.4节,片段索引用于区分BERT输入序列中的前提和假设。利用预定义的BERT输入序列的最大长度(max_len),持续移除输入文本对中较长文本的最后一个标记,直到满足max_len。为了加速生成用于微调BERT的SNLI数据集,我们使用4个工作进程并行生成训练或测试样本。

class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]

        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        pool = multiprocessing.Pool(4)  # 使用4个进程
        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (torch.tensor(all_token_ids, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                             * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # 为BERT输入中的'<CLS>'、'<SEP>'和'<SEP>'词元保留位置
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)

下载完SNLI数据集后,我们通过实例化SNLIBERTDataset类来生成训练和测试样本。这些样本将在自然语言推断的训练和测试期间进行小批量读取。

# 如果出现显存不足错误,请减少“batch_size”。在原始的BERT模型中,max_len=512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                  num_workers=num_workers)
read 549367 examples
read 9824 examples

微调BERT

用于自然语言推断的微调BERT只需要一个额外的多层感知机,该多层感知机由两个全连接层组成(请参见下面BERTClassifier类中的self.hidden和self.output)。这个多层感知机将特殊的“”词元的BERT表示进行了转换,该词元同时编码前提和假设的信息为自然语言推断的三个输出:蕴涵、矛盾和中性。

class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(256, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        return self.output(self.hidden(encoded_X[:, 0, :]))

预训练的BERT模型bert被送到用于下游应用的BERTClassifier实例net中。

net = BERTClassifier(bert)

为了允许具有陈旧梯度的参数,标志ignore_stale_grad=True在step函数d2l.train_batch_ch13中被设置。我们通过该函数使用SNLI的训练集(train_iter)和测试集(test_iter)对net模型进行训练和评估。

lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
    devices)
loss 0.520, train acc 0.791, test acc 0.780
2536.5 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1631560.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android --- 常见UI组件

TextView 文本视图 设置字体大小&#xff1a;android:textSize"20sp" 用sp 设置颜色&#xff1a;android:textColor"#00ffff" 设置倍距(行距)&#xff1a;android:lineSpacingMultiplier"2" 设置具体行距&#xff1a;android:lineSpacingExtra&q…

闲话 ASP.NET Core 数据校验(二):FluentValidation 基本用法

前言 除了使用 ASP.NET Core 内置框架来校验数据&#xff0c;事实上&#xff0c;通过很多第三方框架校验数据&#xff0c;更具优势。 比如 FluentValidation&#xff0c;FluentValidation 是第三方的数据校验框架&#xff0c;具有许多优势&#xff0c;是开发人员首选的数据校验…

kyuubi、sparksql部署实战与连接

一、下载spark和kyuubi的软件包 spark官网下载 https://spark.apache.org/downloads.html kyuubi官网下载 https://www.apache.org/dyn/closer.lua/kyuubi/kyuubi-1.9.0/apache-kyuubi-1.9.0-bin.tgz 二、部署spark 1、spark配置spark-env.sh YARN_CONF_DIR/opt/cloudera…

大象机器人开源协作机械臂myCobot 630 全面升级!

1. 开篇概述 在快速发展的机器人技术领域中&#xff0c;Elephant Robotics的myCobot 600已经证明了其在教育、科研和轻工业领域的显著适用性。作为一款具备六自由度的机械臂&#xff0c;myCobot 600以其600mm的工作半径和2kg的末端负载能力&#xff0c;满足了多样化的操作需求。…

paddleocr C++生成dll

目录 编译完成后修改内容: 新建ppocr.h头文件 注释掉main.cpp内全部内容&#xff0c;将下面内容替换进去。ppocr.h需要再环境配置中包含进去头文件 然后更改配置信息&#xff0c;将exe换成dll 随后右击重新编译会在根目录生成dll,lib文件。 注意这些dll一个也不能少。生成…

HTML5(2)

目录 一.列表、表格、表单 1.列表标签 2.表格 4.无语义的布局标签 5.字符实体 6.综合案例--1 7.综合案例--表单 一.列表、表格、表单 1.列表标签 1.1 无序列表 1.2 有序列表 1.3 定义列表 定义列表一般用于网页底部的帮助中心 2.表格 2.1 2.2 表格结构标签 shiftaltf 格…

计算机网络之传输层TCP\UDP协议

UDP协议 用户数据报协议UDP概述 UDP只在IP数据报服务之上增加了很少功能&#xff0c;即复用分用和差错检测功能 UDP的主要特点&#xff1a; UDP是无连接的&#xff0c;减少开销和发送数据之前的时延 UDP使用最大努力交付&#xff0c;即不保证可靠交付&#xff0c;可靠性由U…

《Fundamentals of Power Electronics》——全桥型隔离降压转换器

以下是关于全桥型隔离降压转换器的相关知识点&#xff1a; 全桥变压器隔离型降压转换器如下图所示。 上图展示了一个具有二次侧绕组中心抽头的版本&#xff0c;该电路常用于产生低输出电压。二次侧绕组的上下两个绕组可以看作是两个单独的绕组&#xff0c;因此可以看成是具有变…

#const成员

基于写的日期类的实现&#xff0c;去分析const成员的问题。 需要用到下面的两个函数。 这个时候我们去构建一个const Date,如下&#xff0c;这是不和法的&#xff0c;为什么呢&#xff1f; 其实这里涉及一个权限放大的问题&#xff0c;之前就说了&#xff0c;权限可以缩小&…

项目部署总结

1、安装jdk 第一步&#xff1a;上传jdk压缩安装包到服务器 第二步&#xff1a;将压缩安装包解压 tar -xvf jdk-8uXXX-linux-x64.tar.gz 第三步&#xff1a;配置环境变量 编辑/etc/profile文件&#xff0c;在文件末尾添加以下内容&#xff1a; export JAVA_HOME/path/to/j…

应急学院物联网应急安全产教融合基地解决方案

第一章 背景 1.1物联网应急安全产教融合发展概况 物联网应急安全产教融合发展是当前社会发展的重要趋势。随着物联网技术的广泛应用&#xff0c;应急安全领域对人才的需求日益迫切。因此&#xff0c;产教融合成为培养高素质、专业化人才的关键途径。在这一背景下&#xff0c;…

Kotlin泛型之 循环引用泛型(A的泛型是B的子类,B的泛型是A的子类)

IDE(编辑器)报错 循环引用泛型是我起的名字&#xff0c;不知道官方的名字是什么。这个问题是我在定义Android 的MVP时提出来的。具体是什么样的呢&#xff1f;我们看一下我的基础的MVP定义&#xff1a; interface IPresenter<V> { fun getView(): V }interface IVie…

云原生Kubernetes: K8S 1.29版本 部署Sonarqube

一、实验 1.环境 &#xff08;1&#xff09;主机 表1 主机 主机架构版本IP备注masterK8S master节点1.29.0192.168.204.8 node1K8S node节点1.29.0192.168.204.9node2K8S node节点1.29.0192.168.204.10已部署Kuboard &#xff08;2&#xff09;master节点查看集群 1&…

【已解决】Python Selenium chromedriver Pycharm闪退的问题

概要 根据不同的业务场景需求&#xff0c;有时我们难免会使用程序来打开浏览器进行访问。本文在pycharm中使用selenium打开chromedriver出现闪退问题&#xff0c;根据不断尝试&#xff0c;最终找到的问题根本是版本问题。 代码如下 # (1) 导入selenium from selenium import …

RDD算子使用----transformation转换操作

RDD算子使用 transformation转换操作 map算子 rdd.map(p: A > B):RDD&#xff0c;对rdd集合中的每一个元素&#xff0c;都作用一次该func函数&#xff0c;之后返回值为生成元素构成的一个新的RDD。 val sc new SparkContext(conf)//map 原集合*7val list 1 to 7//构建…

纯血鸿蒙APP实战开发——全局状态保留能力弹窗

全局状态保留能力弹窗 介绍 全局状态保留能力弹窗一种很常见的能力&#xff0c;能够保持状态&#xff0c;且支持全局控制显隐状态以及自定义布局。使用效果参考评论组件 效果图预览 使用说明 使用案例参考短视频案例 首先程序入口页对全局弹窗初始化&#xff0c;使用Globa…

Redis基本數據結構 ― List

Redis基本數據結構 ― List 介紹常用命令範例1. 將元素推入List中2. 取得List內容3. 彈出元素 介紹 Redis中的List結構是一個雙向鏈表。 LPUSH LPOP StackLPUSH RPOP QueueLPUSH BRPOP Queue(消息隊列) 常用命令 命令功能LPUSH將元素推入列表左端RPUSH將元素推入列表右…

ZooKeeper 环境搭建详细教程之三(真集群)

ZooKeeper 搭建详细步骤之三(真集群) ZooKeeper 搭建详细步骤之二(伪集群模式) ZooKeeper 搭建详细步骤之一(单机模式) ZooKeeper 及相关概念简介 真集群搭建 搭建 ZooKeeper 真集群涉及多个步骤,包括准备环境、配置文件设置、启动服务以及验证集群状态。 以下是一个简…

如何基于Zookeeper实现注册中心模型?

在分布式系统中&#xff0c;通常会存在几十个甚至上百个服务&#xff0c;开发人员可能甚至都无法明确系统中到底有哪些服务正在运行。另一方面&#xff0c;我们很难同时确保所有服务都不出现问题&#xff0c;也很难保证当前的服务部署方式不做调整和优化。由于自动扩容、服务重…

天冕科技亮相第十七届深圳国际金融博览会!

第十七届深圳国际金融博览会在深圳会展中心正式开幕&#xff0c;天冕科技跟随南山区组团集体亮相&#xff0c;充分展现金融活力。此次金博会&#xff0c;南山区政府共遴选了包括天冕科技在内的三家优秀金融科技企业组团参展&#xff0c;以特色与创新的案例展示了辖区金融业发展…