Bert 在 OCNLI 训练微调

news2025/1/24 14:01:59

目录

  • 0 资料
  • 1 预训练权重
  • 2 wandb
  • 3 Bert-OCNLI
    • 3.1 目录结构
    • 3.2 导入的库
    • 3.3 数据集
      • 自然语言推断
      • 数据集路径
      • 读取数据集
      • 数据集样例展示
      • 数据集类别统计
      • 数据集类
      • 加载数据
    • 3.4 Bert
    • 3.4 训练
  • 4 训练微调结果
    • 3k
    • 10k
    • 50k

0 资料

【数据集微调】

阿里天池比赛 微调BERT的数据集(“任务1:OCNLI–中文原版自然语言推理”)

数据集地址:https://tianchi.aliyun.com/competition/entrance/531841/information

由于这个比赛已经结束,原地址提交不了榜单看测试结果,请参照下面的信息,下载数据集、提交榜单测试。

  • “任务1:OCNLI–中文原版自然语言推理”数据集的GitHub地址:https://github.com/CLUEbenchmark/OCNLI

  • 榜单提交地址:https://www.cluebenchmarks.com/index.html

  • 榜单提交步骤:

    • 打开“榜单提交地址”,点击“立即测评”——填写相关信息(github地址填https://github.com/CLUEbenchmark/CLUE,其他信息任意填)。
    • 上传一个.zip压缩文件,在压缩文件里存放我们模型预测结果的文件。
    • 点击提交。
  • 【注意】预测结果文件的格式:https://storage.googleapis.com/cluebenchmark/tasks/clue_submit_examples.zip

15.4. 自然语言推断与数据集:https://zh-v2.d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-and-dataset.html

15.7. 自然语言推断:微调BERT:https://zh-v2.d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-bert.html#id3

保姆级教程,用PyTorch和BERT进行文本分类:https://zhuanlan.zhihu.com/p/524487313

1 预训练权重

在国内,一般是手动下载预训练权重,而非网络自动下载。

我们将用到 chinese-macbert-base 这个预训练文件,下载网址如下:

https://huggingface.co/hfl/chinese-macbert-base/tree/main

除了叉掉的,其余都要下载。
在这里插入图片描述

2 wandb

pip install wandb

WandB 是一个用于实验跟踪、版本控制和结果可视化的工具,主要用于机器学习项目。
wandb使用教程(一):基础用法:https://zhuanlan.zhihu.com/p/493093033

3 Bert-OCNLI

3.1 目录结构

在这里插入图片描述

3.2 导入的库

import os
import torch
from torch import nn
import pandas as pd
from transformers import BertModel, BertTokenizer
from torch.optim import Adam
from tqdm import tqdm

3.3 数据集

自然语言推断

自然语言推断(natural language inference)主要研究 假设(hypothesis)是否可以从前提(premise)中推断出来, 其中两者都是文本序列。 换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:

蕴涵(entailment):假设可以从前提中推断出来。

矛盾(contradiction):假设的否定可以从前提中推断出来。

中性(neutral):所有其他情况。

自然语言推断也被称为识别文本蕴涵任务。 例如,下面的一个文本对将被贴上“蕴涵”的标签,因为假设中的“表白”可以从前提中的“拥抱”中推断出来。

前提:两个女人拥抱在一起。

假设:两个女人在示爱。

下面是一个“矛盾”的例子,因为“运行编码示例”表示“不睡觉”,而不是“睡觉”。

前提:一名男子正在运行Dive Into Deep Learning的编码示例。

假设:该男子正在睡觉。

第三个例子显示了一种“中性”关系,因为“正在为我们表演”这一事实无法推断出“出名”或“不出名”。

前提:音乐家们正在为我们表演。

假设:音乐家很有名。

自然语言推断一直是理解自然语言的中心话题。它有着广泛的应用,从信息检索到开放领域的问答。为了研究这个问题,我们将首先研究一个流行的自然语言推断基准数据集。

数据集路径

# 数据集路径
data_dir = 'OCNLI/data/ocnli'

读取数据集

# 读ocnli,两个参数,data_dir是数据集的路径,is_train为bool类型,True代表训练,False代表验证
def read_ocnli(data_dir, is_train):
    # 将ocnli解析为前提、假设、标签
    # labels_map是标签映射,0、1、2代表三类,3代表无法分类(或者应该去除的数据)。
    labels_map = {'entailment':0, 'neutral':1, 'contradiction':2, '-': 3}
    file_name = os.path.join(data_dir, 'train.3k.json' if is_train else 'dev.json')
    rows = pd.read_json(file_name, lines=True)
    
    premises = [sentence1 for sentence1 in rows['sentence1'] ]  # 前提
    hypotheses = [sentence2 for sentence2 in rows['sentence2'] ] # 假设
    # if label != '-' 是为了去除无法分类的标签
    labels = [labels_map[label] for label in rows['label'] if label != '-'] # 标签
    return premises, hypotheses, labels
    

数据集样例展示

# 样例展示
train_data = read_ocnli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print("前提:", x0)
    print("假设:", x1)
    print("标签:", y)

结果:

前提: 现在,我代表国务院,向大会报告政府工作,请予审议,并请全国政协委员提出意见
假设: 全国政协委员无权提出建议
标签: 2
前提: 不过以后呢,两年增加一次工资.
假设: 多年之后工资很高
标签: 1
前提: 一万块,嗯那头盔要八千.
假设: 说话的人很有钱
标签: 1

数据集类别统计

# 类别数据统计
val_data = read_ocnli(data_dir, is_train=False)

label_set = [0, 1, 2]

for data in [train_data, val_data]:
    print([
        [
            row for row in data[2]
        ].count(i) for i in label_set
    ])

结果:

[974, 1054, 966]
[947, 1103, 900]

数据集类

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
class OCNLI_Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        sentence1 = [sentence1 for sentence1 in dataset[0]]
        sentence2 = [sentence2 for sentence2 in dataset[1]]
        # 用 _ 将前提和假设拼接在一起,但这应该不是好的做法
        sentence1_2 = ['{}_{}'.format(a, b) for a, b in zip(sentence1, sentence2)]
        self.texts = [tokenizer(
            sentence, 
            padding='max_length', 
            # bert最大可以设置到512,对OCNLI的统计计算中,
            # 发现所有数据没有超过128,max_length越大,计算量越大
            max_length = 128, 
            truncation=True,
            return_tensors="pt"
        ) for sentence in sentence1_2 ] 
        self.labels = torch.tensor(dataset[2])
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]

加载数据

train_set = OCNLI_Dataset(read_ocnli(data_dir, True))
test_set = OCNLI_Dataset(read_ocnli(data_dir, False))
print(len(train_set))
# for train_input, train_label in train_set:
#     print(train_input)
#     print(train_label)
#     input()

结果:

3000

3.4 Bert

class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 3) # 这里的3代表输出的类别
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer

3.4 训练

def train(model, train_data, val_data, learning_rate, epochs):
    # 通过Dataset类获取训练和验证集
    train, val = OCNLI_Dataset(train_data), OCNLI_Dataset(val_data)
    # DataLoader根据batch_size获取数据,训练时选择打乱样本
    train_dataloader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=32)
    # 判断是否使用GPU
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    # 开始进入训练循环
    for epoch_num in range(epochs):
        # 定义两个变量,用于存储训练集的准确率和损失
        total_acc_train = 0
        total_loss_train = 0
        # 进度条函数tqdm
        for train_input, train_label in tqdm(train_dataloader):
            train_label = train_label.to(device)
            mask = train_input['attention_mask'].to(device)
            input_id = train_input['input_ids'].squeeze(1).to(device)
            # 通过模型得到输出
            output = model(input_id, mask)
            # 计算损失
            batch_loss = criterion(output, train_label)
            # input()
            total_loss_train += batch_loss.item()
            # print("total_loss_train:",total_loss_train)
            # 计算精度
            acc = (output.argmax(dim=1) == train_label).sum().item()
            total_acc_train += acc
            # 模型更新
            model.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # ------ 验证模型 -----------
            # 定义两个变量,用于存储验证集的准确率和损失
            total_acc_val = 0
            total_loss_val = 0
            # 不需要计算梯度
            with torch.no_grad():
                # 循环获取数据集,并用训练好的模型进行验证
                for val_input, val_label in val_dataloader:
                    # 如果有GPU,则使用GPU,接下来的操作同训练
                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)
  
                    output = model(input_id, mask)

                    batch_loss = criterion(output, val_label)
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc

    
        print(
            f'''Epochs: {epoch_num + 1} 
          | Train Loss: {total_loss_train / len(train): .3f} 
          | Train Accuracy: {total_acc_train / len(train): .3f} 
          | Val Loss: {total_loss_val / len(train): .3f} 
          | Val Accuracy: {total_acc_val / len(train): .3f}''')     
        print("total_loss_train:",total_loss_train)
        print("total_acc_train:",total_acc_train)
        print("total_loss_val:",total_loss_val)
        print("total_acc_val:",total_acc_val)
        print("len(train_data):",len(train))          
EPOCHS = 50
model = BertClassifier()
LR = 1e-6
train(model, read_ocnli(data_dir, True), read_ocnli(data_dir, False), LR, EPOCHS)

在这里插入图片描述

4 训练微调结果

3k

10k

50k

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1658462.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode HOT 100刷题总结

文章目录 1 哈希1.1 1-1.两数之和🟢1.2 2-49.字母异位词分组🟡1.3 3-128.最长连续序列🟡 2 双指针2.1 4-283.移动零🟢2.2 6-15.三数之和🟡2.3 7-11.盛最多水的容器🟡2.4 8-42.接雨水🔴 3 滑动窗…

一、计算机基础(Java零基础一)

🌻🌻目录 一、🌻🌻剖析学习Java前的疑问🌻🌻1.1 零基础学习编程1.2 英语不好能学吗?1.3 理解慢能学好吗?1.4 现在学Java晚吗?1.5 Java 和 Python 还有 Go 的选择1.6 Java…

WINDOWS下zookeeper突然无法启动但是端口未占用的解决办法(用了WSL)

windows下用着用着时候突然zookeeper启动不了了。netstat查也没有找到端口占用,就是起不来。控制台报错 java.lang.reflect.UndeclaredThrowableException: nullat org.springframework.util.ReflectionUtils.rethrowRuntimeException(ReflectionUtils.java:147) ~…

Semaphore

文章目录 基本使用Semaphore 应用-改进数据库连接池Semaphore 原理1. 加锁解锁流程2. 源码分析 基本使用 信号量,用来限制能同时访问共享资源的线程上限。 public static void main(String[] args) {// 1. 创建 semaphore 对象Semaphore semaphore new Semaphore(…

let命令

let 命令 let 与 var 二者区别: 作用域不同:变量提升(Hoisting):临时性死区重复声明: 联系:举例说明: 块级作用域 块级作用域的关键字使用 var(无块级作用域)…

C++中的std::bind深入剖析

目录 1.概要 2.原理 3.源码分析 3.1._Binder分析 3.2._CALL_BINDER的实现 4.总结 1.概要 std::bind是C11 中的一个函数模板,用于创建一个可调用对象(函数对象或者函数指针)的绑定副本,其中一部分参数被固定为指定值&#xf…

​​​​【收录 Hello 算法】4.4 内存与缓存

目录 4.4 内存与缓存 4.4.1 计算机存储设备 4.4.2 数据结构的内存效率 4.4.3 数据结构的缓存效率 4.4 内存与缓存 在本章的前两节中,我们探讨了数组和链表这两种基础且重要的数据结构,它们分别代表了“连续存储”和“分散存储”两种物理…

Qt常用基础控件总结

一、按钮部件 按钮部件共同特性 Qt 用于描述按钮部件的类、继承关系、各按钮的名称和样式,如下图: 助记符:使用字符"&“可在为按钮指定文本标签时设置快捷键,在&之后的字符将作为快捷键。比如 “A&BC” 则 Alt+B 将成为该按钮的快捷键,使用”&&qu…

铁山靠之数学建模 - Matlab入门

Matlab基础 1. Matlab界面与基本操作1.1 matlab帮助系统1.2 matlab命令1.3 matlab功能符号1.4 matlab的数据类型1.5 函数计算1.6 matlab向量1.7 matlab多项式1.8 M文件1.9 函数文件1.10 matlab的程序结构1.11 echo、warning和error函数1.12 交互输入1.13 程序调试1.14 设置断点…

游戏陪玩平台app小程序H5源码交付游戏陪玩接单软件游戏陪玩源码 陪玩小程序陪玩工作室运营模式陪玩管理系统游戏陪玩工作室怎么做

提供陪玩平台源码,陪玩系统源码,陪玩app源码,团队各部门配备齐全,分工明确,及时对接开发进度,保证开发效率 一、陪玩平台源码的功能介绍 1、派单大厅:陪玩系统源码的派单大厅内支持用户通过语音连麦的方式…

idea已配置的git仓库地址 更换新的Git仓库地址 教程

文章目录 目录 文章目录 更改流程 小结 概要更改流程技术细节小结 概要 先在idea控制台走一下流程 先将本地的git仓库删除 1. 查看当前远程仓库地址: 在终端或命令行中,导航到你的项目目录,并运行以下命令查看当前的远程仓库地址&#xff…

QT+MYSQL数据库处理

1、打印Qt支持的数据库驱动&#xff0c;看是否有MYSQL数据库驱动 qDebug() << QSqlDatabase::drivers(); 有打印结果可知&#xff0c;没有MYSQL数据库的驱动 2、下载MYSQL数据库驱动&#xff0c;查看下面的文章配置&#xff0c;亲测&#xff0c;可以成功 Qt6 配置MySQL…

智能BI(后端)-- 系统异步化

文章目录 系统问题分析什么是异步化&#xff1f;业务流程分析标准异步化的业务流程系统业务流程 线程池为什么需要线程池&#xff1f;线程池两种实现方式线程池的参数线程池的开发 项目异步化改造 系统问题分析 问题场景&#xff1a;调用的服务能力有限&#xff0c;或者接口的…

phpstudy(MySQL启动又立马停止)问题的解决办法

方法一&#xff1a;查看本地安装的MySQL有没有启动 1.鼠标右击开始按钮选择计算机管理 2.点击服务和应用程序 3.找到服务双击 4.找到MySQL服务 5.双击查看是否启动&#xff0c;如启动则停止他&#xff0c;然后确定&#xff0c;重新打开phpstudy,启动Mysql. 方法二&#xff…

OpenHarmony 实战开发——3.1 Release + Linux 原厂内核Launcher起不来问题分析报告

1、关键字 Launcher 无法启动&#xff1b;原厂内核&#xff1b;Access Token ID&#xff1b; 2、问题描述 芯片&#xff1a;rk3566&#xff1b;rk3399 内核版本&#xff1a;Linux 4.19&#xff0c;是 RK 芯片原厂发布的 rk356x 4.19 稳定版内核 OH 版本&#xff1a;OpenHa…

net7部署经历

1、linux安装dotnet命令&#xff1a; sudo yum install dotnet-sdk-7.0 或者直接在商店里安装 2、配置反向代理 127.0.0.1:5000》localhost 访问后报错 原因&#xff1a;数据表驼峰名&#xff0c; 在windows的数据表不区分大小写&#xff0c;但是在linux里面是默认区分的&…

xiuno(修罗)知乎模板二开优化魔板仿网盘资源社–模板加全套插件

使用说明 以服务器为例搭建教程 ①先安装 PHP7.1 版本 再安装数据库 Mysql ②解压文件&#xff1a;xiunobbs_4.0.4&#xff08;解压到根目录&#xff09;.zip ③解压②完成后找到【plugin】文件夹再解压&#xff1a;plugin(解压到 plugin 文件夹).zip 设置伪静态代码在上面&am…

记录如何查询域名txt解析是否生效

要查询域名的TXT记录&#xff0c;可以使用nslookup命令。具体步骤如下&#xff1a;12 打开命令行终端。输入命令 nslookup -qttxt 域名&#xff0c;将"域名"替换为你要查询的实际域名。执行命令后&#xff0c;nslookup会返回域名的TXT记录值。 如何查询域名txt解析是…

【C++后端项目】负载均衡OJ服务器

文章目录 一、演示项目二、所用技术与开发环境所用技术开发环境 三、项目宏观结构I. 风格&#xff1a;仿leetcodeII. 结构&#xff1a;Browser-Server模式III. 编写思路&#xff1a;编译服务 -> OJ服务 -> 前端设计 四、关于Git分支管理✨4.1 Git 分支结构4.2 Git 分支命…

【linux】主分区,扩展分区,逻辑分区,动态分区,引导分区,标准分区

目录 主分区&#xff0c;扩展分区&#xff0c;逻辑分区 主分区和引导分区 主分区&#xff0c;扩展分区&#xff0c;逻辑分区&#xff08;标准分区&#xff09; 硬盘一般划分为一个“主分区”和“扩展分区”&#xff0c;然后在扩展分区上再分成数个逻辑分区。 磁盘主分区扩展…