使用python从头开始预训练RoBERTa模型

news2025/1/18 19:06:26

本文将介绍如何使用Hugging Face库从头开始构建一个预训练Transformer模型。该模型称为 KantaiBERT。

#@title Step 1: Loading the Dataset
#1.Load kant.txt using the Colab file manager
#2.Downloading the file from GitHubant
!curl -L https://raw.githubusercontent.com/Denis2054/Transformers-for-NLP-2nd-Edition/master/Chapter04/kant.txt --output "kant.txt"

目录结构

#@title Step 2:APRIL 2023 UPDATE: Installing Hugging Face Transformers
'''
# We won't need TensorFlow here
!pip uninstall -y tensorflow
# Install `transformers` from master
!pip install git+https://github.com/huggingface/transformers
!pip list | grep -E 'transformers|tokenizers'
# transformers version at notebook update --- 2.9.1
# tokenizers version at notebook update --- 0.7.0
'''
#@title Step 3: Training a Tokenizer
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
paths = [str(x) for x in Path(".").glob("**/*.txt")]
# Initialize a tokenizer
tokenizer = ByteLevelBPETokenizer()

# Customize training
tokenizer.train(files=paths, vocab_size=52_000, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])
#@title Step 4: Saving the files to disk
import os
token_dir = '/root/information_needs/KantaiBERT'
if not os.path.exists(token_dir):
  os.makedirs(token_dir)
tokenizer.save_model('KantaiBERT')
#@title Step 5 Loading the Trained Tokenizer Files 
from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing

tokenizer = ByteLevelBPETokenizer(
    "./KantaiBERT/vocab.json",
    "./KantaiBERT/merges.txt",
)
print(tokenizer.encode("The Critique of Pure Reason.").tokens)
# ['The', 'ĠCritique', 'Ġof', 'ĠPure', 'ĠReason', '.']
print(tokenizer.encode("The Critique of Pure Reason."))
# Encoding(num_tokens=6, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
tokenizer._tokenizer.post_processor = BertProcessing(
    ("</s>", tokenizer.token_to_id("</s>")),
    ("<s>", tokenizer.token_to_id("<s>")),
)
tokenizer.enable_truncation(max_length=512)
#@title Checking that PyTorch Sees CUDA
import torch
torch.cuda.is_available()
#@title Step 7: Defining the configuration of the Model
from transformers import RobertaConfig

config = RobertaConfig(
    vocab_size=52_000,
    max_position_embeddings=514,
    num_attention_heads=12,
    num_hidden_layers=6,
    type_vocab_size=1,
)
print(config)
# RobertaConfig {
#  "attention_probs_dropout_prob": 0.1,
#  "bos_token_id": 0,
#  "classifier_dropout": null,
#  "eos_token_id": 2,
#  "hidden_act": "gelu",
#  "hidden_dropout_prob": 0.1,
#  "hidden_size": 768,
#  "initializer_range": 0.02,
#  "intermediate_size": 3072,
#  "layer_norm_eps": 1e-12,
#  "max_position_embeddings": 514,
#  "model_type": "roberta",
#  "num_attention_heads": 12,
#  "num_hidden_layers": 6,
#  "pad_token_id": 1,
#  "position_embedding_type": "absolute",
#  "transformers_version": "4.45.2",
#  "type_vocab_size": 1,
#  "use_cache": true,
#  "vocab_size": 52000
#}
#@title Step 8: Re-creating the Tokenizer in Transformers
from transformers import RobertaTokenizer
tokenizer = RobertaTokenizer.from_pretrained("./KantaiBERT", max_length=512)
#@title Step 9: Initializing a Model From Scratch
from transformers import RobertaForMaskedLM
model = RobertaForMaskedLM(config=config)
print(model)
print(model.num_parameters())
# => 84,095,008 parameters
#@title Exploring the Parameters
LP=list(model.parameters())
lp=len(LP)
print(lp)
for p in range(0,lp):
  print(LP[p])
#@title Counting the parameters
np=0
for p in range(0,lp):#number of tensors
  PL2=True
  try:
    L2=len(LP[p][0]) #check if 2D
  except:
    L2=1             #not 2D but 1D
    PL2=False
  L1=len(LP[p])      
  L3=L1*L2
  np+=L3             # number of parameters per tensor
  if PL2==True:
    print(p,L1,L2,L3)  # displaying the sizes of the parameters
  if PL2==False:
    print(p,L1,L3)  # displaying the sizes of the parameters

print(np)              # total number of parameters
#@title Step 10: Building the Dataset

from transformers import LineByLineTextDataset
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="/root/information_needs/kantai/kant.txt",
    block_size=128,
)
#@title Step 11: Defining a Data Collator
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
#@title Step 12: Initializing the Trainer
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./KantaiBERT",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    save_steps=10_000,
    save_total_limit=2,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
#@title Step 13: Pre-training the Model
trainer.train()

在这里插入图片描述

#@title Step 14: Saving the Final Model(+tokenizer + config) to disk
trainer.save_model("./KantaiBERT")
#@title Step 15: Language Modeling with the FillMaskPipeline
from transformers import pipeline
fill_mask = pipeline(
    "fill-mask",
    model="./KantaiBERT",
    tokenizer="./KantaiBERT"
)
fill_mask("Human thinking involves human <mask>.")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2214200.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux学习第一天

目录 1.引入 计算机的组成&#xff08;图解&#xff09; 操作系统是什么 操作系统的功能 操作系统的组成&#xff08;图解&#xff09; 操作系统内核的功能 常见的操作系统 2.Libux的学习 Linux的特点 Linux应用领域 搭建Linux学习环境 下载 创建虚拟机 新建虚拟机…

短视频矩阵开发,抖音新机遇(技术开发框架解析)

开发前言&#xff1a; 抖音短视频矩阵系统技术开发框架主要利用了VUE&#xff0c; Spring Boot、Django等技术。本技术文档适用于短视频矩阵源码的开发和部署。 #短视频矩阵源码开发部署 #抖音矩阵源码开发 #抖音矩阵源码 #抖音矩阵开发 抖音短视频矩阵系统的技术开发框架可以…

P1320压缩技术(续集版

P1320压缩技术&#xff08;续集版 感觉这题还是蛮难的对我来说&#xff0c;通过这题我才知道原来字符串输入不碰到空格就会一起输进来 我参考了一写题解自己又写了自己的解法&#xff0c;vs中的scanf_s和scanf()用法不太一样&#xff0c;之前按scanf写法写一直在报错&#xff…

彻底掌握Android中的Lifecycle

彻底掌握Android中的Lifecycle Lifecycle 是一个生命周期感知型组件&#xff0c;属于 Jetpack 组件库中的一部分&#xff0c;其核心功能是将组件&#xff08;如Activity 和 Fragment&#xff09;的生命周期状态通知给观察者&#xff08;LifecycleObserver&#xff09;。观察者…

指针 + 数组 较为复杂凌乱的 【笔试题】

2024 - 10 - 10 - 笔记 - 25 作者(Author): 郑龙浩 / 仟濹(CSDN 账号名) 【指针 数组】的 各种题型(笔试题) 来自于鹏哥的网课&#xff0c;我做一下笔记 119. 【C语言进阶】笔试题详解&#xff08;4&#xff09;_哔哩哔哩_bilibili ① 题 #include <stdio.h> int m…

VUE 开发——Vue学习(三)—— 智慧商城项目

目录 解释各个模块 api接口模块&#xff1a;发送ajax请求的接口模块utils工具模块&#xff1a;自己封装的一些工具方法模块components组件模块&#xff1a;全局通用的组件router路由模块&#xff1a;封装要所有路由views&#xff1a;各个页面assets&#xff1a;各种资源 van…

JAVA软开-面试经典题(7)-字符串常量池

字符串常量池 1.定义&#xff1a;字符串常量池&#xff08;String Constant Pool&#xff09;&#xff0c;用于存放字符串常量的运行时内存结构&#xff0c;其底层的实现为Hashtable。 【注意】 在JDK1.6之前&#xff0c;字符串常量池中只会存放具体的String实例&#xff0c;在…

MySQL基础探秘(3)

前面那篇文章是简单介绍了往数据库中插入数据&#xff0c;以及对数据进行有些改动。 但是&#xff0c;细想下&#xff0c;数据能够无限制&#xff0c;无约束进行插入吗&#xff1f; emm……显然是不行的&#xff0c;不然数据就乱套了&#xff0c;看起来不美观。 所以要对数据…

Axure详细介绍及功能对比,常用版本选择和替代软件分享

Axure是一款专门用于原型设计和交互设计的专业软件&#xff0c;广泛应用于用户界面&#xff08;UI&#xff09;和用户体验&#xff08;UX&#xff09;设计领域。它的主要功能是帮助产品经理、设计师以及开发人员创建具有互动性的原型&#xff0c;以便展示和测试各种应用、网站或…

CST学习笔记(二)Floquet模式激励设置

CST学习笔记&#xff08;二&#xff09;Floquet模式激励设置 在CST中我们常常使用Floquet模式来仿真频率选择表面(FSS)或者超材料等&#xff0c;但是我们设置好Zmax的floquet模式数量后&#xff0c;启动仿真&#xff0c;会发现S参数一栏中有很多我们不想要看的S参数&#xff0…

海南聚广众达电子商务咨询有限公司解锁流量密码

在这个瞬息万变的数字时代&#xff0c;电商行业如同一股不可阻挡的洪流&#xff0c;正以前所未有的速度重塑着商业版图。而在这股浪潮中&#xff0c;抖音电商以其独特的魅力&#xff0c;迅速崛起为一颗璀璨的新星&#xff0c;吸引了无数商家与创业者的目光。海南聚广众达电子商…

【题解】【动态规划01背包问题】—— [NOIP2012 普及组] 摆花

【题解】【动态规划01背包问题】—— [NOIP2012 普及组] 摆花 [NOIP2012 普及组] 摆花题目描述输入格式输出格式输入输出样例输入 #1输出 #1 提示 解法1.二维 d p dp dp1.1.思路解析1.2.AC代码 解法2.一维 d p dp dp2.1.思路解析2.2.AC代码 3.扩展:前缀和优化 [NOIP2012 普及组…

python基础知识(十一)面向过程,面向对象,对象属性,魔法方法,继承,私有权限

目录 面向过程是什么 什么是面向对象&#xff1f; 面向对象的三大特性&#xff1a; 继承 多态 类 对象 self关键字 对象属性 类外面访问属性 类内部获取属性 魔法方法 无参init()方法 有参init()方法 str()方法 del()方法 继承基础 什么是继承 单继承 多继…

Javascript笔试题目(六)

1.如何使用JS实现Promise 对象?请写出具体代码 Promise其实也不难-CSDN博客 Javascript 手写一个Promise_javascript中手写promise ?-CSDN博客 Promise其实也不难-CSDN博客 题目要求我们使用JavaScript实现一个Promise对象。对此我们可以基于Promise/A规范的要求进行实现Prom…

面试-2024年7月16号

面试-2024年7月16号 自我介绍Mysql主从复制是做了一个什么样的集群&#xff1f;在Mysql的使用过程中遇到过哪些问题&#xff1f;mysql迁移具体步骤mysql漏洞修复是怎么做的。mysql的容灾方案&#xff08;灾备恢复机制&#xff09;。redis多节点怎么部署的redis的备份与恢复、迁…

电源中的“冷地”和“热地”

最近硬件同事在弄开关电源相关项目&#xff0c;由于其第一次做开关电源&#xff0c;并不懂冷地和热地的区别&#xff0c;出现示波器探头接地夹夹“热地”导致实验室多次跳闸&#xff0c;最严重时把板子给炸了。为了了解冷地和热地的如何辨别以及为什么热地带电这些知识&#xf…

【从零开发Mybatis】引入XNode和XPathParser

引言 在上文&#xff0c;我们发现直接使用 DOM库去解析XML 配置文件&#xff0c;非常复杂&#xff0c;也很不方便&#xff0c;需要编写大量的重复代码来处理 XML 文件的读取和解析&#xff0c;代码可读性以及可维护性相当差&#xff0c;使用起来非常不灵活。 因此&#xff0c…

JavaEE 多线程第二节 (多线程的简单实现Thread/Runable)

1. 创建线程&#xff08;继承 Thread 类&#xff09;步骤&#xff1a; 继承 Thread 类&#xff1a; 创建一个类并继承 Thread 类&#xff0c;然后重写 run() 方法&#xff0c;在该方法中写入线程执行的代码 class MyThread extends Thread {Overridepublic void run()…

数据恢复超简单!9 个方法任你选!小白也能轻易恢复数据!

在当今数字化的世界中&#xff0c;数据恢复的重要性日益凸显。无论是工作中的重要文档&#xff0c;还是生活中的珍贵照片和视频&#xff0c;一旦丢失&#xff0c;都可能给我们带来极大的困扰。别担心&#xff0c;下面为大家介绍 9 个超简单的数据恢复方法&#xff0c;让小白也能…

C++基础面试题 | 什么是C++中的运算符重载?

文章目录 回答重点&#xff1a;示例&#xff1a; 运算符重载的基本规则和注意事项&#xff1a; 回答重点&#xff1a; C的运算符重载是指可以为自定义类型&#xff08;如类或结构体&#xff09;定义运算符的行为&#xff0c;使其像内置类型一样使用运算符。通过重载运算符&…