从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)

news2025/1/19 20:22:25

🚩🚩🚩Hugging Face 实战系列 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3

3 数据加载函数

def load_dataset(logger, args):
    logger.info("loading training dataset")
    train_path = args.train_path
    with open(train_path, "rb") as f:
        train_list = pickle.load(f)
    train_dataset = CPMDataset(train_list, args.max_len)
    return train_dataset
  1. 日志报告加载训练数据
  2. 训练数据路径
  3. 将以二进制形式存储的数据文件
  4. 使用 pickle 加载到内存中的 train_list 变量中
  5. 加载CPMDataset包,将train_list从索引转化为torch tensor
  6. 返回tensor
    在这里插入图片描述

4 训练函数

def train(model, logger, train_dataset, args):
    train_dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,
        drop_last=True
    )
    logger.info("total_steps:{}".format(len(train_dataloader)* args.epochs))
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs
    optimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )# 设置warmup
    logger.info('start training')
    train_losses = []   # 记录每个epoch的平均loss
    for epoch in range(args.epochs):
        train_loss = train_epoch(
            model=model, train_dataloader=train_dataloader,
            optimizer=optimizer, scheduler=scheduler,
            logger=logger, epoch=epoch, args=args)
        train_losses.append(round(train_loss, 4))
        logger.info("train loss list:{}".format(train_losses))

    logger.info('training finished')
    logger.info("train_losses:{}".format(train_losses))
  1. 训练函数
  2. 制作Dataloader
  3. 制作Dataloader
  4. 制作Dataloader
  5. 制作Dataloader
  6. 日志添加信息Dataloader*epochs的数量
  7. 记录数据长度到t_total变量中
  8. 指定优化器
  9. 学习率衰减策略,从transformers包中调用现成的get_linear_schedule_with_warmup方法
  10. 设置warmup等参数
  11. 学习率衰减策略
  12. 日志添加信息开始训练
  13. 记录所有epoch的训练损失,以求每个epoch的平均loss
  14. 遍历每个epoch
  15. 指定一个我们自己写的train_epoch函数1
  16. train_epoch函数2
  17. train_epoch函数3
  18. train_epoch函数4
  19. 记录损失,只保存4位小数
  20. 记录日志信息训练损失
  21. 记录日志信息训练完成
  22. 最后一句是在日志中保存所有损失吗?

5 迭代训练函数train_epoch

def train_epoch(model, train_dataloader, optimizer, scheduler, logger, epoch, args):
    model.train()
    device = args.device
    ignore_index = args.ignore_index
    epoch_start_time = datetime.now()
    total_loss = 0  # 记录下整个epoch的loss的总和
    epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量
    epoch_total_num = 0  # 每个epoch中,预测的word的总数量
    for batch_idx, (input_ids, labels) in enumerate(train_dataloader):
        try:
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            outputs = model.forward(input_ids, labels=labels)
            logits = outputs.logits
            loss = outputs.loss
            loss = loss.mean()
            batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)
            epoch_correct_num += batch_correct_num
            epoch_total_num += batch_total_num
            batch_acc = batch_correct_num / batch_total_num
            total_loss += loss.item()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            if (batch_idx + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            if (batch_idx + 1) % args.log_step == 0:
                logger.info(
                    "batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(
                        batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))
            del input_ids, outputs
        except RuntimeError as exception:
            if "out of memory" in str(exception):
                logger.info("WARNING: ran out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                logger.info(str(exception))
                raise exception
    epoch_mean_loss = total_loss / len(train_dataloader)
    epoch_mean_acc = epoch_correct_num / epoch_total_num
    logger.info(
        "epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))
    logger.info('saving model for epoch {}'.format(epoch + 1))
    model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(model_path)
    logger.info('epoch {} finished'.format(epoch + 1))
    epoch_finish_time = datetime.now()
    logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))

    return epoch_mean_loss
  1. train_epoch函数
  2. 指定训练模式
  3. 训练设备
  4. 需要忽略的索引
  5. 当前epoch开启的具体时间
  6. 当前epoch的loss总和
  7. 当前epoch预测词正确的总数量
  8. 每个epoch需要预测的测的总数量
  9. for训练从train_dataloader遍历取数据
  10. 捕捉异常
  11. 输入词的索引数据进入训练设备
  12. 标签数据进入训练设备
  13. 输入数据经过前向传播得到输出
  14. 经过softmax后的输出
  15. 得到损失
  16. 平均损失
  17. 通过calculate_acc函数统计该batch的预测token的正确数与总数
  18. 统计该epoch的预测token的正确数
  19. 统计该epoch的预测token的总数
  20. 计算该batch的accuracy
  21. 获得损失值的标量累加到当前epoch总损失
  22. 如果当前的梯度累加步数大于1
  23. 对当前累加的损失对梯度累加步数求平均
  24. 损失反向传播
  25. 梯度裁剪:梯度裁剪的目的是控制梯度的大小,防止梯度爆炸的问题。在训练神经网络时,梯度可能会变得非常大,导致优化算法出现数值不稳定的情况。裁剪梯度就是将梯度的整体范数限制在一个特定的阈值之内
  26. 达到梯度累加的次数后
  27. 更新参数
  28. 更新学习率
  29. 梯度清零
  30. 梯度累加次数为0时,也就是参数更新时
  31. 记录日志
  32. 记录的各个参数占位符
  33. 占位符对应的各个变量
  34. 删除两个变量,释放内存
  35. 捕捉到异常
  36. 如果异常的信息中的字符串包含内存不足的问题,也就是显卡内存不足
  37. 将该问题添加到日志信息
  38. 当显卡内存占用过多时
  39. 手动释放显卡内存
  40. 如果不是显卡内存不足
  41. 记录日志
  42. 返回异常
  43. 记录当前epoch的平均loss
  44. 记录当前epoch的平均accuracy
  45. 日志记录信息
  46. 记录的信息为当前epoch索引、损失、准确率
  47. 日志记录信息,当前保存的模型以及对于的epoch索引
  48. 保存模型的地址
  49. 如果地址不存在
  50. 创建该地址
  51. 确保得到不是外壳对象
  52. 保存模型
  53. 日志记录信息训练完成
  54. 记录完成时间
  55. 记录当前epoch训练所花费时间
  56. 返回epoch平均损失

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1290686.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Edge调用Aria2下载

一、准备工作 1、Edge浏览器:Windows系统自带或点击下载;   2、Aria2 gui:点击github下载或自行搜索下载其他版本; 二、启动Aria2 gui 解压下载的Aria2 gui到任意目录,点击“Aria2c启动器”或“AriaNg启动器”皆可。…

翻译: 大语言模型LLMs能做什么和不能做什么 保存笔记What LLMs can and cannot do

生成式 AI 是一项惊人的技术,但它并非万能。在这个视频中,我们将仔细看看大型语言模型(LLM)能做什么,不能做什么。我们将从我发现的一个有用的心理模型开始,了解它能做什么,然后一起看看 LLM 的…

webrtc网之sip转webrtc

OpenSIP是一个开源的SIP(Session Initiation Protocol)服务器,它提供了一个可扩展的基础架构,用于建立、终止和管理VoIP(Voice over IP)通信会话。SIP是一种通信协议,用于建立、修改和终止多媒体…

如何实现同一画面显示不同的2个视频

有时候我们想将2个视频拼接在一起,让这2个视频并排或上下显示,以在同一屏幕上同时播放,这样可以进行视频里面内容的对比或者引起他人的注意力。 如果您想创作这种分屏的视频,将2个或者多个不同的视频放在一个屏幕上,是…

提取B站视频

1、将视频链接粘贴到下面的网站,下载视频到本地。 贝贝BiliBili - B站视频下载 2、使用剪映打开视频,导入视频,导出字幕文件SRT 剪映专业版-全能易用的桌面端剪辑软件-轻而易剪 上演大幕 3、上传SRT文件,解析出来即可 it365 字…

串口程序(1)-接收多个字节程序设计

数据寄存器 关键的标志位 通过该宏定义可以开启对应的串口中断,之前用该宏定义代替标准库函数USART_ITConfig(USART1, USART_IT_RXNE, ENABLE); //使能接收中断 HAL库程序 1.串口发送程序 HAL库串口发送一个/一组数据是很简单的,可以直接调用HAL_UART…

【9】PyQt对话框

目录 1. QMessageBox 2. QIputDialog 对话框是为了更好地实现人与程序的交互 对话框主要是完成特定场景下的功能,比如删除确认等 QDialog的子类有QMessageBox、QFileDialog、QFontDialog、QInputDialog等 1. QMessageBox QMessageBox是普通的对话框 代码示例: …

什么是数据清洗、特征工程、数据可视化、数据挖掘与建模?

1.1什么是数据清洗、特征工程、数据可视化、数据挖掘与建模? 视频为《Python数据科学应用从入门到精通》张甜 杨维忠 清华大学出版社一书的随书赠送视频讲解1.1节内容。本书已正式出版上市,当当、京东、淘宝等平台热销中,搜索书名即可。内容涵…

【Openstack Train】十六、swift安装

OpenStack Swift是一个分布式对象存储系统,它可以为大规模的数据存储提供高可用性、可扩展性和数据安全性。Swift是OpenStack的一个核心组件,它允许用户将大量的数据存储在云上,并且可以随时访问、检索和管理这些数据。 Swift的设计目标是为了…

深入理解Sentinel系列-1.初识Sentinel

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理🔥如果感觉博主的文章还不错的话&#xff…

搜维尔科技:Varjo如何提高汽车设计和驾驶测试的生产力

增强和虚拟现实技术有助于提高汽车、航空航天、工业生产等各个领域的工人生产力。尽管这些应用程序的上下文通常相当具体,但其中许多用例的某些方面是通用的。 在本文中,我们将具体探讨基于LP-RESEARCH的LPVR操作系统的 Varjo头戴式显示器的姿态跟踪主题…

linux虚拟机Virtualbox的下载安装及vagrant镜像下载安装

Virtualbox下载安装以及创建及简单使用一个虚拟机 1.开启电脑cpu虚拟机 以戴尔G3为例 找到电脑设置–>更新与安全–>恢复 这个步骤也可以在电脑开机时一直按键esc(或者F1、或者F2、或者deleete)都可以进入BIOS 进入BIOS 完成以上步骤就可以开启电脑cpu虚拟机了 …

Django回顾 - 6 Ajax

【1】Ajax 定义: 异步Javscript和XML 作用: Javascript语言与服务器(django)进行异步交互,传输的数据为XML(当然,传输的数据不只是XML,现在更多使用json数据) 同步交互和异步交互: 1、同步交互&…

Word文件设置了只读模式,为什么还能编辑?

Word文档设置了只读模式,为什么还可以编辑呢?,不过当我们进行保存的时候会发现,word提示需要重命名并选择新路径才能够保存,是因为什么呢?今天我们学习一下如何解决问题。 这种操作,即使可以编辑…

香港科技大学广州|机器人与自主系统学域博士招生宣讲会—北京专场!!!(暨全额奖学金政策)

在机器人和自主系统领域实现全球卓越—机器人与自主系统学域 硬核科研实验室,浓厚创新产学研氛围! 教授亲临现场,面对面答疑解惑助攻申请! 一经录取,享全额奖学金1.5万/月! 时间:2023年12月09日…

华为配置流量抑制示例

如拓扑图所示,SwitchA作为二层网络到三层路由器的衔接点,需要限制二层网络转发的广播、未知组播和未知单播报文,防止产生广播风暴,同时限制二三层网络转发的已知组播和已知单播报文,防止大流量冲击。 配置思路 用如下…

vue中实现数字+英文字母组合键盘

完整代码 <template><div class"login"><div click"setFileClick">欢迎使用员工自助终端</div><el-dialog title"初始化设置文件打印消耗品配置密码" :visible.sync"dialogSetFile" width"600px&quo…

数据库原理: 笛卡儿积

笛卡儿积&#xff08;Cartesian Product&#xff09;是集合论中的一个概念&#xff0c;也在数据库中的查询操作中经常使用。笛卡儿积是指两个集合&#xff08;或更多集合&#xff09;之间所有可能的组合。如果有两个集合A和B&#xff0c;它们的笛卡儿积记作A B&#xff0c;表示…

DevExpress WinForms Pivot Grid组件,一个类似Excel的数据透视表控件(一)

界面控件DevExpress WinForms的Pivot Grid组件是一个类似Excel的数据透视表控件&#xff0c;用于多维(OLAP)数据分析和跨选项卡报表。众多的布局自定义选项使您可以完全控制其UI&#xff0c;无与伦比的以用户为中心的功能使其易于部署。 DevExpress WinForms有180组件和UI库&a…

MongoDB知识总结

这里写自定义目录标题 MongoDB基本介绍MongoDB基本操作数据库相关集合相关增删改查 MongoDB基本介绍 简单介绍 MongoDB是一个基于分布式文件存储的数据库。由C语言编写。旨在为WEB应用提供可扩展的高性能数据存储解决方案。 MongoDB是一个介于关系数据库和非关系数据库之间的产…