GPT-2 语言模型 - 模型训练

news2025/4/16 1:40:47

本节代码是一个完整的机器学习工作流程,用于训练一个基于GPT-2的语言模型。下面是对这段代码的详细解释:

文件目录如下

1. 初始化和数据准备

  • 设置随机种子

    random.seed(1002)

    确保结果的可重复性。

  • 定义参数

    test_rate = 0.2
    context_length = 128
    • test_rate:测试集占总数据集的比例。

    • context_length:模型处理的文本长度。

  • 获取数据文件

    all_files = glob(pathname=os.path.join("data","*"))

    使用 glob 获取 data 目录下的所有文件。

  • 划分数据集

    test_file_list = random.sample(all_files, int(len(all_files) * test_rate))
    train_file_list = [i for i in all_files if i not in test_file_list]

    将数据集随机划分为训练集和测试集。

  • 加载数据集

    raw_datasets = load_dataset("csv", data_files={"train": train_file_list, "vaild": test_file_list}, cache_dir="cache_data")

    使用 datasets 库加载 CSV 格式的数据集,并缓存到 cache_data 目录。

2. 数据预处理

  • 初始化分词器

    tokenizer = BertTokenizerFast.from_pretrained("D:/bert-base-chinese")
    tokenizer.add_special_tokens({"bos_token":"[begin]","eos_token":"[end]"})

    从本地路径加载预训练的 BERT 分词器,并添加自定义的开始和结束标记。

  • 数据预处理

    tokenize_datasets = raw_datasets.map(tokenize, batched=True, remove_columns=raw_datasets["train"].column_names)

    使用 map 方法对数据集进行预处理,将文本转换为模型可接受的格式。

    • tokenize 函数对文本进行分词和截断。

    • batched=True 表示批量处理数据。

    • remove_columns 删除原始数据集中的列。

3. 模型配置和初始化

  • 模型配置

    config = GPT2Config.from_pretrained("config",
                                        vocab_size=len(tokenizer),
                                        n_ctx=context_length,
                                        bos_token_id=tokenizer.bos_token_id,
                                        eos_token_id=tokenizer.eos_token_id,
                                        )

    加载预训练的 GPT-2 配置,并根据分词器的词汇表大小和上下文长度进行调整。

  • 初始化模型

    model = GPT2LMHeadModel(config)
    model_size = sum([t.numel() for t in model.parameters()])
    print(f"model_size: {model_size/1000/1000} M")

    根据配置初始化 GPT-2 语言模型,并计算模型参数的总数,打印模型大小(以兆字节为单位)。

4. 训练设置

  • 数据整理器

    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    使用 DataCollatorForLanguageModeling 整理训练数据,设置 mlm=False 表示不使用掩码语言模型。

  • 训练参数

    args = TrainingArguments(
        learning_rate=1e-5,
        num_train_epochs=100,
        per_device_train_batch_size=10,
        per_device_eval_batch_size=10,
        eval_steps=2000,
        logging_steps=2000,
        gradient_accumulation_steps=5,
        weight_decay=0.1,
        warmup_steps=1000,
        lr_scheduler_type="cosine",
        save_steps=100,
        output_dir="model_output",
        fp16=True,
    )

    配置训练参数,包括学习率、训练轮数、批大小、评估间隔等。

  • 初始化训练器

    trianer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenize_datasets["train"],
        eval_dataset=tokenize_datasets["vaild"]
    )
  • 启动训练

    trianer.train()

    使用 Trainer 类启动模型训练。

需复现完整代码

from glob import glob
import os
from torch.utils.data import Dataset
from datasets import load_dataset
import random
from transformers import BertTokenizerFast
from transformers import GPT2Config
from transformers import GPT2LMHeadModel
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer,TrainingArguments

def tokenize(element):
    outputs = tokenizer(element["content"],truncation=True,max_length=context_length,return_overflowing_tokens=True,return_length=True)

    input_batch = []

    for length,input_ids in zip(outputs["length"],outputs["input_ids"]):

        if length == context_length:
            input_batch.append(input_ids)

    return {"input_ids":input_batch}

if __name__ == "__main__":
    random.seed(1002)
    test_rate = 0.2
    context_length = 128

    all_files = glob(pathname=os.path.join("data","*"))

    test_file_list = random.sample(all_files,int(len(all_files)*test_rate))
    train_file_list = [i for i in all_files if i not in test_file_list]

    raw_datasets = load_dataset("csv",data_files={"train":train_file_list,"vaild":test_file_list},cache_dir="cache_data")


    tokenizer = BertTokenizerFast.from_pretrained("D:/bert-base-chinese")
    tokenizer.add_special_tokens({"bos_token":"[begin]","eos_token":"[end]"})

    tokenize_datasets = raw_datasets.map(tokenize,batched=True,remove_columns=raw_datasets["train"].column_names)

    config = GPT2Config.from_pretrained("config",
                                        vocab_size=len(tokenizer),
                                        n_ctx=context_length,
                                        bos_token_id = tokenizer.bos_token_id,
                                        eos_token_id = tokenizer.eos_token_id,
                                        )

    model = GPT2LMHeadModel(config)
    model_size = sum([ t.numel() for t in model.parameters()])
    print(f"model_size: {model_size/1000/1000} M")

    data_collator = DataCollatorForLanguageModeling(tokenizer,mlm=False)

    args = TrainingArguments(
        learning_rate=1e-5,
        num_train_epochs=100,
        per_device_train_batch_size=10,
        per_device_eval_batch_size=10,
        eval_steps=2000,
        logging_steps=2000,
        gradient_accumulation_steps=5,
        weight_decay=0.1,
        warmup_steps=1000,
        lr_scheduler_type="cosine",
        save_steps=100,
        output_dir="model_output",
        fp16=True,
    )

    trianer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenize_datasets["train"],
        eval_dataset=tokenize_datasets["vaild"]
    )

    trianer.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2335630.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

科技项目验收测试包括哪些内容?有什么作用?

在现代科技快速发展的背景下,科技项目的验收测试已成为项目管理中的重要环节。科技项目验收测试是一种系统性的方法,旨在评估一个科技项目是否达到预定的技术指标和要求,确认项目的完成质量。该测试通常在项目实施完成后进行,通过…

websoket 学习笔记

目录 基本概念 工作原理 优势 应用场景 HTTP协议与 webSoket协议之间的对比 消息推送场景 1. 轮询(Polling) 2. 长轮询(Long Polling) 3. 服务器发送事件(Server-Sent Events, SSE) 4. WebSocket…

博途 TIA Portal之1200做从站与汇川EASY的TCP通讯

上篇我们写到了博途做主站与汇川EASY的通讯。通讯操作起来很简单,当然所谓的简单,也是相对的,如果操作成功一次,那么后面就很容易了, 如果操作不成功,就会很遭心。本篇我们将1200做从站,与汇川EASY做主站进行TCP的通讯。 1、硬件准备 1200PLC一台,带调试助手的PC机一…

【数据结构_6下篇】有关链表的oj题

思路: 1.分别求出这两个链表的长度 2.创建两个引用,指向两个链表的头节点;找到长度长的链表,让她的引用先走差值步数 3.让这两个引用,同时往后走,每个循环各自走一步 然后再判定两个引用是否指向同一个…

vscode+wsl 运行编译 c++

linux 的 windows 子系统(wsl)是 windows 的一项功能,可以安装 Linux 的发行版,例如(Ubuntu,Kali,Arch Linux)等,从而可以直接在 windows 下使用 Linux 应用程序&#xf…

关于 Spring Boot 微服务解决方案的对比,并以 Spring Cloud Alibaba 为例,详细说明其核心组件的使用方式、配置及代码示例

以下是关于 Spring Boot 微服务解决方案的对比,并以 Spring Cloud Alibaba 为例,详细说明其核心组件的使用方式、配置及代码示例: 关于 Spring Cloud Alibaba 致力于提供微服务开发的一站式解决方案! https://sca.aliyun.com/?spm7145af80…

VS 基于git工程编译版本自动添加版本号

目录 概要 实现方案 概要 最近在用visual Studio 开发MFC项目时,需要在release版本编译后的exe文件自动追加版本信息。 由于我们用的git工程管理,即需要基于最新的git 提交来打版本。 比如: MFCApplication_V1.0.2_9.exe 由于git 提交信…

pytorch软件封装

封装代码,通过传入文件名,即可输出类别信息 上一章节,我们做了关于动物图像的分类,接下来我们把程序封装,然后进行预测。 单张图片的predict文件 predict.py 按着路径,导入单张图片做预测from torchvis…

【多线程-第四天-自己模拟SDWebImage的下载图片功能-看SDWebImage的Demo Objective-C语言】

一、我们打开之前我们写的异步下载网络图片的项目,把刚刚我们写好的分类拖进来 1.我们这个分类包含哪些文件: 1)HMDownloaderOperation类, 2)HMDownloaderOperationManager类, 3)NSString+Sandbox分类, 4)UIImageView+WebCache分类, 这四个文件吧,把它们拖过来…

电脑提示“找不到mfc140u.dll“的完整解决方案:从原因分析到彻底修复

当你启动某个软件或游戏时,突然遭遇"无法启动程序,因为计算机中丢失mfc140u.dll"的错误提示,这确实令人沮丧。mfc140u.dll是Microsoft Foundation Classes(MFC)库的重要组成部分,属于Visual C Re…

图像变换方式区别对比(Opencv)

1. 变换示例 import cv2 import matplotlib.pyplot as plotimg cv2.imread(url) img_cut img[100:200, 200:300] img_rsize cv2.resize(img, (50, 50)) (hight,width) img.shape[:2] rotate_matrix cv2.getRotationMatrix2D((hight//2, width//2), 50, 1) img_wa cv2.wa…

图像颜色空间对比(Opencv)

1. 颜色转换 import cv2 import matplotlib.pyplot as plotimg cv2.imread("tmp.jpg") img_r cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_g cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img_h cv2.cvtColor(img, cv2.COLOR_BGR2HSV) img_l cv2.cvtColor(img, cv2.C…

每天学一个 Linux 命令(15):man

可访问网站查看,视觉品味拉满:http://www.616vip.cn/15/index.html 每天学一个 Linux 命令(15):man 命令简介 man(Manual)是 Linux 中最核心的命令之一,用于查看命令、系统调用、库函数等的手册文档。它是用户和开发者获取帮助的核心工具,几乎覆盖了系统中的所有功…

必刷算法100题之计算右侧小于当前元素的个数

题目链接 315. 计算右侧小于当前元素的个数 - 力扣(LeetCode) 题目解析 计算数组里面所有元素右侧比它小的数的个数, 并且组成一个数组,进行返回 算法原理 归并解法(分治) 当前元素的后面, 有多少个比我小(降序) 我们要找到第一比左边小的元素, 这样…

Python依赖注入完全指南:高效解耦、技术深析与实践落地

Python依赖注入完全指南:高效解耦、技术深析与实践落地 摘要 依赖注入(DI)不仅是一种设计技术,更是一种解耦的艺术。它通过削减模块间的强耦合性,为系统提供了更高的灵活性和可测试性,特别是在 FastAPI 等…

深度学习ResNet模型提取影响特征

大家好,我是带我去滑雪! 影像组学作为近年来医学影像分析领域的重要研究方向,致力于通过从医学图像中高通量提取大量定量特征,以辅助疾病诊断、分型、预后评估及治疗反应预测。这些影像特征涵盖了形状、纹理、灰度统计及波形变换等…

【Qt】Qt Creator开发基础:项目创建、界面解析与核心概念入门

🍑个人主页:Jupiter. 🚀 所属专栏:QT 欢迎大家点赞收藏评论😊 目录 Qt Creator 新建项⽬认识 Qt Creator 界⾯项⽬⽂件解析Qt 编程注意事项认识对象模型(对象树)Qt 窗⼝坐标体系 Qt Creator 新…

制造业项目管理如何做才能更高效?制造企业如何选择适配的数字化项目管理系统工具?

一、制造企业项目管理过程中面临的痛点有哪些? 制造企业在项目管理过程中面临的痛点通常涉及跨部门协作、资源调配、数据整合、风险控制等多个维度,且与行业特性(如离散制造vs流程制造)紧密相关。 进度失控多项目资源冲突信息孤…

Python批量处理PDF图片详解(插入、压缩、提取、替换、分页、旋转、删除)

目录 一、概述 二、 使用工具 三、Python 在 PDF 中插入图片 3.1 插入图片到现有PDF 3.2 插入图片到新建PDF 3.3 批量插入多张图片到PDF 四、Python 提取 PDF 图片及其元数据 五、Python 替换 PDF 图片 5.1 使用图片替换图片 5.2 使用文字替换图片 六、Python 实现 …

七种驱动器综合对比——《器件手册--驱动器》

九、驱动器 名称 功能与作用 工作原理 优势 应用 隔离式栅极驱动器 隔离式栅极驱动器用于控制功率晶体管(如MOSFET、IGBT、SiC或GaN等)的开关,其核心功能是将控制信号从低压侧传输到高压侧的功率器件栅极,同时在输入和输出之…