解决deepspeed框架的bug:不保存调度器状态,模型训练重启时学习率从头开始

news2024/10/6 16:27:00

deepspeed存在一个bug,即在训练时不保存调度器状态,因此如果训练中断后再重新开始训练,调度器还是会从头开始而不是接着上一个checkpoint的调度器状态来训练。这个bug在deepspeed的github中也有其他人提出:https://github.com/microsoft/DeepSpeed/issues/3875
因此我们需要写一个保存调度器状态的代码,才可以解决这个问题。
具体方法是加一个callback类,专门负责保存调度器的状态以及在训练重新开始时加载调度器的状态:

class SchedulerStateCallback(TrainerCallback):
    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if os.environ.get("RANK", "0") == "0":
            scheduler = kwargs['lr_scheduler']
            scheduler_state = scheduler.state_dict()
            save_path = os.path.join(args.output_dir, SCHEDULER_NAME)
            torch.save(scheduler_state, save_path)
            #优化器状态已经被deepspeed框架保存了,所以这里没必要再保存
            # optimizer = kwargs['optimizer']
            # optimizer_state = optimizer.state_dict()
            # save_path = os.path.join(args.output_dir, OPTIMIZER_NAME)
            # torch.save(optimizer_state, save_path)
        #torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # 当训练开始时,尝试加载最近的调度器状态
        # load_path = os.path.join(args.output_dir, OPTIMIZER_NAME)
        # if os.path.exists(load_path):
        #     optimizer = kwargs['optimizer']
        #     optimizer_state = torch.load(load_path)
        #     optimizer.load_state_dict(optimizer_state)
        load_path = os.path.join(args.output_dir, SCHEDULER_NAME)
        if os.path.exists(load_path):
            scheduler = kwargs['lr_scheduler']
            scheduler_state = torch.load(load_path)
            scheduler.load_state_dict(scheduler_state)

解决效果如下,我们可以看到,在chaeckpoint10重新开始训练的时候,学习率是接着之前的学习率开始的(5.5e-7),而不是从头开始(0.5e-7):
在这里插入图片描述在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/978273.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

清理Maven仓库中下载失败的文件

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

【SpringBoot】统一功能处理

目录 🎃1 拦截器 🎀1.1 拦截器的代码实现 🎨1.2 拦截器的实现原理 🧶2 拦截器应用——登录验证 🦺3 异常统一处理 🎭4 统一数据返回格式 🧤4.1 为什么需要统一数据返回格式 🧣4.2 统…

Cisco Packet Tracer入门篇

💐 🌸 🌷 🍀 🌹 🌻 🌺 🍁 🍃 🍂 🌿 🍄🍝 🍛 🍤 📃个人主页 :阿然成长日记 …

Python中的文件I/O操作:常见问题与解决方案

在Python编程中,文件I/O操作是常见的任务。本文将介绍一些关于Python文件I/O操作的常见问题及其解决方案,并提供详细的代码示例。 1、问题:如何正确地打开和关闭文件? 解决方案:使用with语句可以确保文件在操作完成后…

查漏补缺 - ES6

目录 1,let 和 const1,会产生块级作用域。2,如何理解 const 定义的变量不可被修改? 2,数组3,对象1,Object.is()2,属性描述符3,常用API4,得到除某个属性之外的新对象。 4…

华为云云服务器评测|使用Docker可视化Portainer部署Yolov5项目进行AI识别

目录 初始化配置使用Xshell连接 项目准备 docker-compose Dockerfile .dockerignore 在服务器中启动Docker项目 初始化配置使用Xshell连接 因为我比较喜欢用xshell来操作服务器,如果你是使用华为在线的CloudShell或其他方式,可以跳过第一步的连接…

【Redis专题】Redis持久化、主从与哨兵架构详解

目录 前言课程目录一、Redis持久化1.1 RDB快照(Snapshot):二进制文件基本介绍开启/关闭方式触发方式bgsave的写时复制(COW,Copy On Write)机制优缺点 1.2 AOF(append-only file)&…

Git—版本控制系统

git版本控制系统 1、什么是版本控制2、常见的版本控制工具3、版本控制分类3.1、本地版本控制3.2、集中版本控制 SVN3.3、分布式版本控制 Git 4、Git与SVN的主要区别5、Git环境配置6、启动Git7、常用的Linux命令8、Git配置9、设置用户名与邮箱(用户标识,必…

数学建模--逻辑回归算法的Python实现

首先感谢CSDN上发布吴恩达的机器学习逻辑回归算法任务的各位大佬. 通过大佬的讲解和代码才勉强学会. 这篇文章也就是简单记录一下过程和代码. CSDN上写有关这类文章的大佬有很多,大家都可以多看一看学习学习. 机器学习方面主要还是过程和方法. 这篇文章只完成了线性可分方面的任…

Mac Homebrew中常用的 Brew 命令

Mac 中常用的 Brew 命令集 Brew(Homebrew)是一个强大的包管理器,用于在 macOS 上安装、更新和管理各种软件包。它使得在 Mac 上安装开发工具、应用程序和库变得轻松和便捷。本博客将介绍一些在 Mac 中常用的 Brew 命令,以帮助您更…

SpringMVC_SSM整合

一、回顾SpringMVC访问接口流程 1.容器加载分析 容器分析 手动注册WebApplicationContext public class ServletConfig extends AbstractDispatcherServletInitializer {Overrideprotected WebApplicationContext createServletApplicationContext() {//获取SpringMVC容器An…

UDP的可靠性传输

UDP系列文章目录 第一章 UDP的可靠性传输-理论篇(一) 第二章 UDP的可靠性传输-理论篇(二) 文章目录 UDP系列文章目录前言1.TCP 和UDP格式对比2.UDP分片原理3.UDP 传输层应该注意问题4.MTU5.UDP 分片机制设计重点 一、ARQ协议什么…

华为OD机考算法题:食堂供餐

目录 题目部分 解析与思路 代码实现 题目部分 题目食堂供餐题目说明某公司员工食堂以盒饭方式供餐。为将员工取餐排队时间降低为0,食堂的供餐速度必须要足够快。现在需要根据以往员工取餐的统计信息,计算出一个刚好能达成排队时间为0的最低供餐速度。…

PPO算法

PPO算法 全称Proximal Policy Optimization,是TRPO(Trust Region Policy Optimization)算法的继承与简化,大大降低了实现难度。原论文 算法大致流程 首先,使用已有的策略采样 N N N条轨迹,使用这些轨迹上的数据估计优势函数 A ^ …

算法做题记录

一、递推 95.费解的开关 #include<iostream> #include<cstring> using namespace std;const int N 8;char a[N][N],s[N][N]; int T; int ans20,cnt; int dir[5][2]{1,0,-1,0,0,1,0,-1,0,0};void turn(int x,int y) {for(int i0;i<5;i){int xx xdir[i][0];in…

数学建模--Topsis评价方法的Python实现

目录 1.算法流程简介 2.算法核心代码 3.算法效果展示 1.算法流程简介 """ TOPSIS(综合评价方法):主要是根据根据各测评对象与理想目标的接近程度进行排序. 然后在现有研究对象中进行相对优劣评价。 其基本原理就是求解计算各评价对象与最优解和最劣解的距离…

文字验证码:简单有效的账号安全守卫!

前言 文字验证码不仅是一种简单易懂的验证方式&#xff0c;同时也是保护您的账号安全的重要工具。通过输入正确的文字组合&#xff0c;您可以有效地确认自己的身份&#xff0c;确保只有真正的用户才能访问您的账号。 HTML代码 <script src"https://cdn6.kgcaptcha.…

rust编译出错:error: failed to run custom build command for `ring v0.16.20`

安装 Visual Studio&#xff0c;确保选择 —.NET 桌面开发、使用 C 的桌面开发和通用 Windows 平台开发。显示已安装的工具链rustup show。然后通过运行更改和设置工具链rustup default stable-x86_64-pc-windows-msvc。 另外是想用clion进行调试rust 需要你按下面配置即可解…

【Spring MVC】统一功能处理

一、登录验证 登录验证通过拦截器实现&#xff0c;拦截器就是在用户访问服务器时&#xff0c;预先拦截检查一下用户的访问请求。 没有拦截器时&#xff0c;用户访问服务器的流程是&#xff1a;用户–>controller–>service–>Mapper。有拦截器时&#xff0c;用户访问…

自旋锁和读写锁

目录 一、自旋锁 1.自旋锁和挂起等待锁 2.自旋锁的接口 二、读写锁 1.读者写者模型与读写锁 2.读写锁接口 3.加锁的原理 4.读写优先级 一、自旋锁 1.自旋锁和挂起等待锁 互斥锁的类型有很多&#xff0c;我们之前使用的锁实际上是互斥锁中的挂起等待锁。互斥锁比较有代…