BERT在GLUE数据集构建任务

news2025/1/13 17:27:07

0 Introduction

谷歌开源的BERT项目在Github上,视频讲解可以参考B站上的一个视频

1 GLUE部分基准数据集介绍

  • GLUE数据集官网
  • GLUE数据集下载,建议下载运行这个download_glue_data.py文件进行数据集的下载,如果链接无法打开,运行下面代码,运行后,会自动下载GLUE数据集到本地项目文件夹中,所包含的数据集有CoLA,diagnostic,MNLI,MRPC,QNLI,QQP,RTE,SST-2,STS-B,WNLI等,关于这些数据集的详细中文介绍,参考这篇博客,本例是在MRPC数据集上构建任务。
  • 关于微软的MRPC数据集:本例中是在MRPC数据集上进行构建的,因为MRPC数据集较小,只有3600多条文本数据,但如下面代码中的注释所说,由于版权问题,不再托管MRPC数据集,需要手动下载。下载方式:首先去官网,下载到MSRParaphraseCorpus.msi文件,双击安装后,会产生一个文件夹,里面即包含了MPRC数据。
    数据集搞定后,文件结构如下图
    在这里插入图片描述
    以下是用于下载GLUE数据集的脚本文件download_glue_data.py,如果下载数据集有困难,可以去百度网盘下载
    链接:https://pan.baidu.com/s/1D_AJ_GgWgaPuYbror_jUNg
    提取码:9k9r
    –来自百度网盘超级会员V4的分享
''' Script for downloading all GLUE data.

Note: for legal reasons, we are unable to host MRPC.
You can either use the version hosted by the SentEval team, which is already tokenized, 
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
You should then rename and place specific files in a folder (see below for an example).

mkdir MRPC
cabextract MSRParaphraseCorpus.msi -d MRPC
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
rm MRPC/_*
rm MSRParaphraseCorpus.msi

1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
'''

import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile

import urllib as URLLIB
import urllib.response
import urllib.parse
import io
# from six.moves import urllib


TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://dl.fbaipublicfiles.com/glue/data/CoLA.zip',
             "SST":'https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
             "QQP":'https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip',
             "STS":'https://dl.fbaipublicfiles.com/glue/data/STS-B.zip',
             "MNLI":'https://dl.fbaipublicfiles.com/glue/data/MNLI.zip',
             "QNLI":'https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip',
             "RTE":'https://dl.fbaipublicfiles.com/glue/data/RTE.zip',
             "WNLI":'https://dl.fbaipublicfiles.com/glue/data/WNLI.zip',
             "diagnostic":'https://dl.fbaipublicfiles.com/glue/data/AX.tsv'}

MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'

def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    if task == "MNLI":
        print("\tNote (12/10/20): This script no longer downloads SNLI. You will need to manually download and format the data to use SNLI.")
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")

def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        try:
            mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
            mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
            URLLIB.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
            URLLIB.request.urlretrieve(MRPC_TEST, mrpc_test_file)
        except urllib.error.HTTPError:
            print("Error downloading MRPC")
            return
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file

    with io.open(mrpc_test_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf-8') as test_fh:
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))

    try:
        URLLIB.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
    except KeyError or urllib.error.HTTPError:
        print("\tError downloading standard development IDs for MRPC. You will need to manually split your data.")
        return

    dev_ids = []
    with io.open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding='utf-8') as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    with io.open(mrpc_train_file, encoding='utf-8') as data_fh, \
         io.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf-8') as train_fh, \
         io.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf-8') as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
                
    print("\tCompleted!")
    
def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
    print("\tCompleted!")
    return

def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks

def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
        os.mkdir(args.data_dir)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))

2 下载BERT项目

  1. 用Git工具从BERT开源项目上把项目完整克隆下来,如何使用Git工具从GitHub或者Gitee上克隆项目,参见这里
  2. 创建一个用于该项目的虚拟环境,参见这里
  3. 在该虚拟环境下,下载安装该项目所需的依赖,即requirements.txt,安装方式,在激活当前虚拟环境的情况下,执行pip install -r D:\Code\BERT\bert\requirements.txt,其中D:\Code\BERT\bert\requirements.txt所在路径

3. 运行BERT项目

运行run_classifier.py文件,该文件为带参数文件,而且参数较多,传入参数方式有若干种。
1. 方法1
在PyCharm->Edit Configuration->Parameters中输入--task_name=MRPC --do_train=true --do_eval=true --data_dir=../GLUE/glue_data/MRPC --vocab_file=../uncased_L-12_H-768_A-12/vocab.txt --bert_config_file=../uncased_L-12_H-768_A-12/bert_config.json --init_checkpoint=../uncased_L-12_H-768_A-12/bert_model.ckpt --max_seq_length=128 --train_batch_size=8 --learning_rate=2e-5 --num_train_epochs=3.0 --output_dir=../output
2. 方法2
上面都设置好后,即可运行run_classifier.py,大约需要十几秒(一张2080TI的显卡),模型训练结果在保存在预先创建的output文件夹下

遇到的问题

  1. 下载的BERT项目中,requirements.txt如下
tensorflow >= 1.11.0   # CPU Version of TensorFlow.
# tensorflow-gpu  >= 1.11.0  # GPU version of TensorFlow.

由于直接执行了pip install -r D:\Code\BERT\bert\requirements.txt,因此安装了最新版本的tensorflow==2.11,出现了一大堆API问题,列举2个如下

  • AttributeError: module ‘tensorflow._api.v2.train‘ has no attribute ‘Optimizer‘ -->> 改为tf.keras.optimizers.Optimizer
  • AttributeError: module 'tensorflow' has no attribute 'flags'–>>改为flags = tf.compat.v1.flags.FLAGS

修改后,API问题依旧存在,而且越改越多,隧把创建的虚拟环境删除了,重新创建,修改了requirements.txt如下,安装了GPU版本的tensorflow==1.13.1版本。在选择tensorflow版本时,注意要和自己的CUDA版本对应,如何查看自己的CUDA版本,参考这里

# tensorflow == 1.13.1   # CPU Version of TensorFlow.
tensorflow-gpu  == 1.13.1  # GPU version of TensorFlow.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/607812.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

七、Gitee码云的注册及使用(二)

1、创建远程仓库 (1)登录Gitee.com,点击右上角 号,再点击新建仓库。 (2)填写仓库名称,路径,仓库介绍 (3)选择是否开源 (4)初始化仓库 开源许可证:主要包括开源是否可以随意转载,开源但不能商业使用&…

yolov8改进大作战,开箱即用,提供yolov8魔术师专栏代码

1.yolov8魔术师专栏介绍 开箱即用:提供 yolov8魔术师专栏 代码,方便直接使用,无需自己重新添加引起的一些bug问题: https://blog.csdn.net/m0_63774211/category_12289773.html?spm1001.2014.3001.5482 专栏内容如下&#xff…

【深入浅出Spring Security(五)】自定义过滤器进行前后端登录认证

自定义过滤器 一、自定义过滤器自定义登录认证过滤器自定义 LoginFilter配置 LoginFilter测试 二、总结 一、自定义过滤器 在【深入浅出Spring Security(二)】Spring Security的实现原理 中小编阐述了默认加载的过滤器,里面有些过滤器有时并…

OpenGL实现第一个窗口-三角形

1.简介 此代码是基于QtOpenGL实现的,但是大部分的代码是OpenGL,Qt封装了一些类,方便使用。 2.准备工作 QOpenGLWidget提供了三个便捷的虚函数,可以重写,用来重写实现典型的OpenGL任务。不需要GLFW。 paintGL&#…

【C语言】Visual Studio社区版安装配置环境(保姆级图文)

目录 1. 官网下载社区版2. 选择安装项目2.1 点击使用C的桌面开发2.2 语言包选择简体中文2.3 设置安装位置 3. 创建新项目3.1 点击创建新项目3.2 点击空项目,下一步3.3 设置项目名称路径3.4 创建项目 4. 测试例程总结 欢迎关注 『C语言』 系列,持续更新中…

代码随想录 二叉树 Java (一)

文章目录 (简单)144. 二叉树的前序遍历(简单)94. 二叉树的中序遍历(简单)145. 二叉树的后序遍历二叉树的统一遍历方法(参考代码随想录)(中等)102. 二叉树的层…

横岗茂盛村旧改,已立项,一期已拆平。

项目位于龙岗区横岗街道红棉路与茂盛路交汇处,距离轨道3号线横岗站约700米。 茂盛片区城市更新单元规划(草案)已经在近日公示,该旧改被纳入《2012年深圳市城市更新单元计划第五批计划》,2019年曾被暂停,20…

Redis实战14-分布式锁基本原理和不同实现方式对比

在上一篇文章中,我们知道了,当在集群环境下,synchronized关键字实现的JVM级别锁会失效的。那么怎么解决这个问题呢?我们可以使用分布式锁来解决。本文咱们就来介绍分布式锁基本原理以及不同实现方式对比。 我们先来回顾&#xff…

【深度学习】混合精度训练与显存分析

混合精度训练与显存分析 ​ 关于参数精度的介绍可以见文章https://zhuanlan.zhihu.com/p/604338403 相关博客 【深度学习】混合精度训练与显存分析 【深度学习】【分布式训练】Collective通信操作及Pytorch示例 【自然语言处理】【大模型】大语言模型BLOOM推理工具测试 【自然语…

(论文阅读)Chain-of-Thought Prompting Elicits Reasoningin Large Language Models

论文地址 https://openreview.net/pdf?id_VjQlMeSB_J 摘要 我们探索如何生成一个思维链——一系列中间推理步骤——如何显著提高大型语言模型执行复杂推理的能力。 特别是,我们展示了这种推理能力如何通过一种称为思维链提示的简单方法自然地出现在足够大的语言模…

2023 更新版:苏生不惑开发过的那些原创工具和脚本

苏生不惑第431 篇原创文章,将本公众号设为星标,第一时间看最新文章。 4年来苏生不惑这个公众号已经写了400多篇原创文章,去年分享过文章更新版:整理下苏生不惑开发过的那些工具和脚本 ,今年再更新下我开发过的原创工具…

【Python开发】FastAPI 07:Depends 依赖注入

在 FastAPI 中,Depends 是一个依赖注入系统,用于注入应用程序中所需的依赖项,通过 Depends,我们可以轻松地将依赖项注入到 FastAPI 路由函数中。简单来说,Depends 依赖注入的目的就是将代码重复最小! 目录 …

Vue学习3

文章目录 Vuex工作原理配置环境各种函数mapState对象写法数组写法 MapGetterMapMutations对象写法数组写法 Mapaction总结 模块化模块化1总结 Vuex 工作原理 那三个要通过store管理 配置环境 使用import时,回先执行Import中的代码,在后面的也会提前。 index.js…

Vscode利用ssh登录ubuntu开发环境下,代码不能跳转问题解决

0 开发环境 环境:VScode remote ssh 虚拟机Ubuntu22.04 1 问题记录 在win环境下,Vscode可以实现代码跳转。但是,在利用VScode的ssh登录Ubuntu下,代码不能进行跳转。 网上看到很多帖子,有的更改settings.json&…

【Ubuntu】保姆级图文介绍双系统win10卸载Ubuntu16.04

文章目录 删除Ubuntu分区数据删除Ubuntu启动项 这段时间想将前几年安装的Ubuntu16.04版本升级到Ubuntu20.04。 折腾了一番,升级失败了。想着还不如卸载了重新安装Ubuntu20.04。 由于Ubuntu16.04在升级过程中出现了一些问题,导致进不去Ubuntu系统。因此只…

tinkerCAD入门操作(2):移动、旋转和缩放对象

tinkerCAD入门操作:移动、旋转和缩放对象 介绍 现在您已经学会了如何在工作平面上旋转,是时候真正开始处理对象了。 在本课中,您将了解有关对象物理属性的更多信息。 放置一个盒子 我们需要一个对象来操作。让我们从一个盒子开始。在提示…

使用Druid数据源并查看监控页面

💧 使 用 D r u i d 数 据 源 并 查 看 监 控 信 息 \color{#FF1493}{使用Druid数据源并查看监控信息} 使用Druid数据源并查看监控信息💧 🌷 仰望天空,妳我亦是行人.✨ 🦄 个人主页——微风撞见云的博客&…

百度狂问3小时,大厂offer到手,小伙真狠!(百度面试真题)

前言: 在40岁老架构师尼恩的(50)读者社群中,经常有小伙伴,需要面试 百度、头条、美团、阿里、京东等大厂。 下面是一个小伙伴成功拿到通过了百度三次技术面试,小伙伴通过三个多小时技术拷问,最…

Docker镜像存储

前言 在之前的文章中有说过容器目录的隔离机制. 今天来分析一下镜像的文件系统. Docker 已经用了很久了, 也知道镜像存储的时候是分层存储的(从docker pull时分层下载就能看出), 但是具体是如何将多层进行聚合并生成最终展示的文件, 这个过程从未深究过. 既然不知道, 又难掩好…

chatgpt赋能python:Python反向切片:介绍与例子

Python反向切片:介绍与例子 Python是一种高级编程语言,具有简单易懂的语法和高效的运行速度,以及丰富的标准库和第三方库。其中一项有趣的功能是Python反向切片,它能够用一种简单而有效的方式处理列表(list&#xff0…