使用无标注的数据训练Bert

news2024/7/6 18:40:35

文章目录

  • 1、准备用于训练的数据集
  • 2、处理数据集
  • 3、克隆代码
  • 4、运行代码
  • 5、将ckpt模型转为bin模型使其可在pytorch中运用

Bert官方仓库:https://github.com/google-research/bert

1、准备用于训练的数据集

此处准备的是BBC news的数据集,下载链接:https://www.kaggle.com/datasets/gpreda/bbc-news
原数据集格式(.csv):
在这里插入图片描述

2、处理数据集

训练Bert时需要预处理数据,将数据处理成https://github.com/google-research/bert/blob/master/sample_text.txt中所示格式,如下所示:
在这里插入图片描述
数据预处理代码参考:

import pandas as pd

# 读取BBC-news数据集
df = pd.read_csv("../../bbc_news.csv")
# print(df['title'])
l1 = []
l2 = []
cnt = 0
for line in df['title']:
    l1.append(line)

for line in df['description']:
    l2.append(line)
# cnt=0
f = open("test1.txt", 'w+', encoding='utf8')
for i in range(len(l1)):
    s = l1[i] + " " + l2[i] + '\n'
    f.write(s)
    # cnt+=1
    # if cnt>10: break
f.close()
# print(l1)

处理完后的BBC news数据集格式如下所示:
在这里插入图片描述

3、克隆代码

使用git克隆仓库代码
http:

git clone https://github.com/google-research/bert.git

或ssh:

git clone git@github.com:google-research/bert.git

4、运行代码

先下载Bert模型:BERT-Base, Uncased
该文件中有以下文件:
在这里插入图片描述
运行代码:
在Teminal中运行:

python create_pretraining_data.py \
  --input_file=./sample_text.txt(数据集地址) \
  --output_file=/tmp/tf_examples.tfrecord(处理后数据集保存的位置) \
  --vocab_file=$BERT_BASE_DIR/vocab.txt(vocab.txt文件位置) \
  --do_lower_case=True \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --masked_lm_prob=0.15 \
  --random_seed=12345 \
  --dupe_factor=5

训练模型:

python run_pretraining.py \
  --input_file=/tmp/tf_examples.tfrecord(处理后数据集保存的位置) \
  --output_dir=/tmp/pretraining_output(训练后模型保存位置) \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json(bert_config.json文件位置) \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt(如果要从头开始的预训练,则去掉这行) \
  --train_batch_size=32 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=20 \
  --num_warmup_steps=10 \
  --learning_rate=2e-5

训练完成后模型输出示例:

***** Eval results *****
  global_step = 20
  loss = 0.0979674
  masked_lm_accuracy = 0.985479
  masked_lm_loss = 0.0979328
  next_sentence_accuracy = 1.0
  next_sentence_loss = 3.45724e-05

要注意应该能够在至少具有 12GB RAM 的 GPU 上运行,不然会报错显存不足。
使用未标注数据训练BERT

5、将ckpt模型转为bin模型使其可在pytorch中运用

上一步训练好后准备好训练出来的model.ckpt-20.index文件和Bert模型中的bert_config.json文件

创建python文件convert_bert_original_tf_checkpoint_to_pytorch.py:

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""


import argparse

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

在Terminal中运行以下命令:

python convert_bert_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path Models/chinese_L-12_H-768_A-12/bert_model.ckpt.index(.ckpt.index文件位置) \
--bert_config_file Models/chinese_L-12_H-768_A-12/bert_config.json(bert_config.json文件位置)  \
--pytorch_dump_path  Models/chinese_L-12_H-768_A-12/pytorch_model.bin(输出的.bin模型文件位置)

以上命令最好在一行中运行:

python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json  --pytorch_dump_path  pytorch_model.bin

然后就可以得到bin文件了
在这里插入图片描述

【BERT for Tensorflow】本地ckpt文件的BERT使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/489340.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python | 人脸识别系统 — UI界面设计

本博客为人脸识别系统的UI界面设计代码解释 人脸识别系统博客汇总:人脸识别系统-博客索引 项目GitHub地址:【待】 注意:阅读本博客前请先参考以下博客 工具安装、环境配置:人脸识别系统-简介 阅读完本博客后可以继续阅读&#xff…

不用下载就能使用的4款轻量在线PS工具

PS是一种非常熟悉的设计工具,也是一种在设计领域占有重要地位的软件,如常见的产品设计、平面设计或摄影后期设计,几乎与PS的使用密不可分。PS本身也有很多功能,每个人的日常设计图纸、图纸修复等工作都可以用PS完成。 但PS有很多…

yolov8 OpenCV DNN 部署 推理报错

yolov8是yolov5作者发布的新作品 目录 1、下载源码 2、下载权重 3、配置环境 4、导出onnx格式 5、OpenCV DNN 推理 1、下载源码 git clone https://github.com/ultralytics/ultralytics.git 2、下载权重 git clone https://github.com/ultralytics/assets/releases/dow…

MySQL知识学习05(InnoDB存储引擎对MVCC的实现)

1、一致性非锁定读和锁定读 一致性非锁定读 对于 一致性非锁定读(Consistent Nonlocking Reads) ,通常做法是加一个版本号或者时间戳字段,在更新数据的同时版本号 1 或者更新时间戳。查询时,将当前可见的版本号与对…

K8S资源-configmap创建六种方式

云原生实现配置分离重要实现方式 两者都是用来存储配置文件,configmap存储通用的配置文件,secret存储需要加密的配置文件。 将配置文件configmap挂在到pod上 创建configmap 1.基于配置文件目录创建configmap kubectl create cm cmdir --from-fileconf…

医学图像分割之U-Net

一、背景及问题 在过去两年中,在很多视觉识别任务重,深度卷积网络的表现优于当时最先进的方法。但这些深度卷积网络的发展受限于网络模型的大小以及训练数据集的规模。虽然这个限制有过突破,也是在更深的网络、更大的数据集中产生的更好的性能…

【redis】redis的缓存过期淘汰策略

【redis】redis的缓存过期淘汰策略 文章目录 【redis】redis的缓存过期淘汰策略前言一、面试题二、redis内存满了怎么办?1、redis默认内存是多少?在哪查看?如何修改?在conf配置文件中可以查看 修改,内存默认是0redis的默认内存有…

使用意图intent构建一个多活动的Android应用

安卓意图Intent是Android应用组件(Activity、Service、Broadcast Receiver)之间进行交互的一种重要方式。Intent允许启动一个活动、启动一个服务、传递广播等。Intent使应用能够响应系统及其他应用的动作。Intent使用的主要目的有: 1、 启动Activity:可以启动自己应用内的Activ…

DDPM--生成扩散模型

DDPM–生成扩散模型 Github: https://github.com/daiyizheng/Deep-Learning-Ai/blob/master/AIGC/Diffusion.ipynb DDPM 是当前扩散模型的起点。在本文中,作者建议使用马尔可夫链模型,逐步向图像添加噪声。 函数 q ( x t ∣ x t − 1 ) q(x_t | x_t-1…

java获取真实ip的方法

在网络中,如果不想被人监听,那么就需要获取 IP地址了,在电脑中我们可以使用到 ip地址获取工具,那么如何在 Java中获取真实的 IP地址呢? 1、首先我们需要先准备一台电脑,然后将电脑进行联网; 2、…

ChatGPT带你一起了解C语言中的fseek()

fseek函数用于将文件指针移动到指定位置。它的原型如下: c int fseek(FILE *stream, long offset, int whence); 其中,stream是文件指针,offset是偏移量,whence是起始位置。 偏移量offset可以是正数、负数或零。 如果是正数&a…

Java --- springboot2数据响应与内容协商

目录 一、数据响应与内容协商 1.1、响应json 1.1.1、返回值解析器 1.1.2、springMVC支持的返回值类型 1.1.3、HttpMessageConverter原理 1.2、内容协商 1.2.1、引入依赖 1.2.2、 postman分别测试返回json和xml 1.2.3、开启浏览器参数方式内容协商功能 1.3、自定义 Message…

持续测试:DevOps时代质量保证的关键

在 DevOps 时代,持续测试已成为质量保证的一个重要方面。近年来,软件开发方法论发生了快速转变。随着 DevOps 的出现,已经发生了向自动化和持续集成与交付 (CI/CD) 的重大转变。传统的质量保证方法已不足以满足现代软件开发实践的需求。持续测…

Java——二叉树的深度

题目链接 牛客网在线oj题——二叉树的深度 题目描述 输入一棵二叉树,求该树的深度。从根结点到叶结点依次经过的结点(含根、叶结点)形成树的一条路径,最长路径的长度为树的深度,根节点的深度视为 1 。 数据范围&am…

记一次产线打印json导致的redis连接超时

服务在中午十一点上线后,服务每分钟发出三到四次redis连接超时告警。错误信息为: Dial err:dial tcp: lookup xxxxx: i/o timeout 排查过程 先是检查redis机器的情况,redis写入并发数较大,缓存中保留了一小时大概400w条数据。red…

java学习之第十章作业

目录 第一题 第二题 第三题 第四题 第五题 第六题 代码的问题点 第七题 第八题 第一题 package homework;public class HomeWork01 {public static void main(String[] args) {Car c new Car();//创建新对象,没有实参Car c1 new Car(100);//1.创建一个新的…

Windows11开启远程桌面和修改远程端口

该示例适用于大部分的Windows平台,示例基于Windows 11。操作系统:Windows 11 专业版。远程桌面默认使用TCP协议,默认端口为3389,修改后为13389。 一、开启远程桌面 控制面板-->系统与安全-->系统-->允许远程访问 二、修…

牛客网_华为机试题_HJ23 删除字符串中出现次数最少的字符

写在前面: 题目链接:牛客网_华为机试题_HJ23 删除字符串中出现次数最少的字符 编程语言:C 难易程度:简单 一、题目描述 描述 实现删除字符串中出现次数最少的字符,若出现次数最少的字符有多个,则把出现次数…

09 虚拟机配置-虚拟机描述

文章目录 09 虚拟机配置-虚拟机描述9.1 概述9.2 元素介绍9.3 配置示例 09 虚拟机配置-虚拟机描述 9.1 概述 本节介绍虚拟机domain根元素和虚拟机名称的配置。 9.2 元素介绍 domain:虚拟机XML配置文件的根元素,用于配置运行此虚拟机的hypervisor的类型…

英语中主语从句的概念及其用法,例句(不断更新)

主语从句的原理 主语从句是一种充当整个句子主语的从句,主语从句构成的句子,是要以引导词开头的。它可以用名词性从属连词、关系代词或关系副词引导。主语从句通常位于谓语动词之前,用于表示动作、状态或事件的主体。 以下是一些常用的引导主…