240621_昇思学习打卡-Day3-余弦退火+周期性重启+warm up

news2024/11/16 12:01:05

240621_昇思学习打卡-Day3-余弦退火+周期性重启+warm up

先展示一个完整的余弦退火+周期性重启+warm up调整学习率的流程(横轴为epoch,纵轴为学习率):

image-20240622001826281

我们换一个收敛较慢的图进行详细说明:

image-20240622002540209

Warm up

在神经网络刚开始训练时,梯度较大,如果一开始就设置比较大的学习率的话,训练会极不稳定,导致不能得到较好的收敛效果,所以我们需要在最开始训练时将学习率保持在一个比较低的水平,让梯度先收敛到一定程度,然后再把学习率增大,可以有效提高收敛效果。这个过程称为网络训练的预热(warm up)

    def init_lr(self):
        """
        初始化每个参数组的学习率。
        """
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

余弦退火

使用余弦函数可以达到一个较好的学习率衰减效果,具体来说,随着x的增加余弦值首先缓慢下降,然后加速下降,在即将到达极值点时收敛速度,缓慢靠近,这种下降模式与学习率配合可以得到很好的效果。这个过程就叫余弦退火

image-20240622003616284
在这里插入图片描述
这一段本来是写的字,但是传到csdn不知道为什么编码就乱了,半夜了也不想深究了,直接截图上来吧

image-20240622005556314

在这里插入图片描述

这部分的核心代码如下:

    def get_lr(self):
        """
        计算当前步骤的学习率。

        返回:
        - list: 当前步骤的学习率列表。
        """
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # 线性warmup阶段
            return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            # 余弦退火阶段,这么多代码其实就是上面那个公式
            return [base_lr + (self.max_lr - base_lr) \
                    * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

周期性重启

在进行“一轮”(学习率收敛到较低点)学习之后,可能达到的极值点只是局部极小值,并不是全局最小值,现在就需要用到:学习率突然增大,跳出局部极小值,去寻找全局最小值,这个过程称为周期性重启

这里我们梳理一下整个流程:

在模型训练最开始时,因此时梯度较大,我们的学习率需要保持一个较低得水平,在梯度得到一定程度的收敛之后(比如两轮),学习率开始急速增大(warm up),然后为了靠近极值点,采用余弦退火进行学习率的调整,在学习率调整到较低点时(假设八轮),因此时不能保证是否为全局最小值,我们需要再让他跳出这个极值点,再去找有没有更优的极值点,这就是周期性重启(十轮一重启)。

以下就是周期性重启的核心代码:

	 def step(self, epoch=None):
        """
        执行一步学习率调度。

        参数:
        - epoch (int, 可选): 当前周期。如果为None,则使用内部的last_epoch值并自增。
        """
        if epoch is None:
            # 自增步骤并检查是否需要重置周期
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            # 根据给定的epoch更新周期和步骤
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        # 更新最大学习率
        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

完整代码如下:

# 导入优化器模块和PyTorch的神经网络模块
import torch.optim as optim
import torch
import torch.nn as nn
from torch.optim import SGD
import math
import matplotlib.pyplot as plt

# 定义CosineAnnealingWarmupRestarts学习率调度器类
class CosineAnnealingWarmupRestarts(optim.lr_scheduler._LRScheduler):
    """
    Cosine Annealing Warmup Restarts学习率调度器。

    参数:
    - optimizer (Optimizer): 包装的优化器。
    - first_cycle_steps (int): 第一个周期的步数。
    - cycle_mult (float): 周期步数的乘数。默认: 1.0。
    - max_lr (float): 第一个周期的最大学习率。默认: 0.1。
    - min_lr (float): 最小学习率。默认: 0.001。
    - warmup_steps (int): 线性warmup的步数。默认: 0。
    - gamma (float): 周期间最大学习率的减少率。默认: 1.0。
    - last_epoch (int): 上一个周期的索引。默认: -1。
    """
    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 first_cycle_steps: int,
                 cycle_mult: float = 1.,
                 max_lr: float = 0.1,
                 min_lr: float = 0.001,
                 warmup_steps: int = 0,
                 gamma: float = 1.,
                 last_epoch: int = -1
                 ):
        # 确保warmup步骤少于第一个周期的步骤
        assert warmup_steps < first_cycle_steps

        # 初始化各种参数
        self.first_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle_mult = cycle_mult    # cycle steps magnification
        self.base_max_lr = max_lr   # first max learning rate
        self.max_lr = max_lr    # max learning rate in the current cycle
        self.min_lr = min_lr    # min learning rate
        self.warmup_steps = warmup_steps    # warmup step size
        self.gamma = gamma  # decrease rate of max learning rate by cycle

        # 当前周期的步数和周期计数
        self.cur_cycle_steps = first_cycle_steps    # first cycle step size
        self.cycle = 0  # cycle count
        self.step_in_cycle = last_epoch     # step size of the current cycle

        # 调用父类构造函数
        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)

        # 初始化学习率
        self.init_lr()

    # 初始化学习率的方法
    def init_lr(self):
        """
        初始化每个参数组的学习率。
        """
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    # 计算当前学习率的方法
    def get_lr(self):
        """
        计算当前步骤的学习率。

        返回:
        - list: 当前步骤的学习率列表。
        """
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # 线性warmup阶段
            return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            # Cosine Annealing阶段
            return [base_lr + (self.max_lr - base_lr) \
                    * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    # 执行一步学习率调度的方法
    def step(self, epoch=None):
        """
        执行一步学习率调度。

        参数:
        - epoch (int, 可选): 当前周期。如果为None,则使用内部的last_epoch值并自增。
        """
        if epoch is None:
            # 自增步骤并检查是否需要重置周期
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            # 根据给定的epoch更新周期和步骤
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        # 更新最大学习率
        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

# 创建一个简单的线性模型和SGD优化器
# 构建一个简单的模型和优化器
model = nn.Linear(10, 1)  # 简单的线性层作为示例模型
optimizer = SGD(model.parameters(), lr=0.1)  # 初始化优化器,lr参数会被调度器覆盖

# 实例化学习率调度器
# 实例化学习率调度器
scheduler = CosineAnnealingWarmupRestarts(optimizer,
                                         first_cycle_steps=10,
                                         cycle_mult=1.,
                                         max_lr=0.01,
                                         min_lr=0.001,
                                         warmup_steps=2,
                                         gamma=0.9)

# 打印初始学习率
print("Initial LR:", scheduler.get_lr())

# 记录学习率的变化
loss_list=[]
# 模拟训练过程中的学习率变化
# 模拟几个周期的训练步骤
for epoch in range(100):  # 总共运行25个epoch
    scheduler.step()  # 更新学习率
    lrs = scheduler.get_lr()
    loss_list.append(lrs)
    print(f"Epoch {epoch}: LR(s) -> {lrs}")

# 绘制学习率变化图
x=list(range(100))
plt.figure()
plt.plot(x,loss_list)
plt.show()

打卡记录:

image-20240622013340465

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1859446.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

若依框架下拉单选框根据js动态加载,如何使select2的下拉搜素功能同时生效(达到select下拉框的样式不变的效果)

直接上代码&#xff0c;不废话 $(select[name"sealType"]).change(function (event) {let value event.target.valuequeeryDeptListBySealType(value)})// 获取科目信息function queeryDeptListBySealType(value){$.ajax({type: "post",url: prefix &quo…

【Linux】Linux基础开发工具(yum)

Linux 软件包管理器 yum 什么是软件包 在Linux下安装软件, 一个通常的办法是下载到程序的源代码, 并进行编译, 得到可执行程序.但是这样太麻烦了, 于是有些人把一些常用的软件提前编译好, 做成软件包(可以理解成windows上的安 装程序)放在一个服务器上, 通过包管理器可以很方便…

__FILE__ 一个非常实用的宏

经常在VS和QtCreator这两个开发环境之间切换的同志肯定会发现这两个开发环境生成的可执行程序的文件路径不一样&#xff0c;VS是在项目文件目录里面&#xff0c;而qt creator是在和项目文件夹同一目录下。如下图所示&#xff1a; QtCreator: VS: 这就导致了一个问题,若要获取项…

面向对象的编程思想

面向对象的编程思想 一、什么是面向对象&#xff1f; 面向对象编程的核心思想是把构成问题的各个事物分解成各个对象&#xff0c;建立对象的目的不是为了完成一个步骤&#xff0c;而是为了描述一个事物在解决问题的过程中经历的步骤和行为。对象作为程序的基本单位&#xff0…

如何开发、使用 Starter

开发 第一步&#xff1a;创建starter工程hello-spring-boot-starter并配置pom.xml文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchem…

达梦(DM8)数据库表空间的备份与还原(联机备份) 四

一、表空间的备份 1、备份表空间的命令操作 backup tablespace main backupset /home/dmdba/dmdata/DAMENG/bak/full_back_01 ; 2、检查表空间的备份文件 select sf_bakset_check(disk,/home/dmdba/dmdata/DAMENG/bak/full_back_01); 二、表空间的还原 1、修改表空间位脱机…

[MySql]两阶段提交

文章目录 什么是binlog使用binlog进行恢复的流程 什么是redolog缓冲池redologredolog结构 两阶段提交 什么是binlog binlog是二进制格式的文件&#xff0c;用于记录用户对数据库的修改&#xff0c;可以作用于主从复制过程中数据同步以及基于时间点的恢复&#xff08;PITR&…

Java如何设置Map过期时间的的几种方法

一、技术背景 在实际的项目开发中&#xff0c;我们经常会使用到缓存中间件&#xff08;如redis、MemCache等&#xff09;来帮助我们提高系统的可用性和健壮性。 但是很多时候如果项目比较简单&#xff0c;就没有必要为了使用缓存而专门引入Redis等等中间件来加重系统的复杂性…

【uniapp】HBuilderx中uniapp项目运行到微信小程序报错Error: Fail to open IDE

HBuilderx中uniapp项目运行到微信小程序报错Error: Fail to open IDE 问题描述 uniapp开发微信小程序&#xff0c;在HBuilderx中运行到微信开发者工具时报错Error: Fail to open IDE 解决方案 1. 查看微信开发者工具端服务端口是否开放 打开微信开发者工具选择&#xff1…

艺术家电gorenje x 设计上海丨用设计诠释“生活的艺术”

2024年6月19日—22日&#xff0c;艺术家电gorenje亮相“设计上海”2024&#xff0c;以“gorenje是家电更是艺术品”为题&#xff0c;为人们带来融入日常的艺术之美。设计上海2024不但汇集了国内外卓越设计品牌和杰出独立设计师的家具设计作品&#xff0c;还联合国内外多名设计师…

【linux学习十六】网络管理

网络管理器(NetworkManager)是一个动态网络的控制器与配置系统&#xff0c;它用于当网络设备可用时保持设备和连接开启并激活 默认情况下&#xff0c;CentOS/RHEL7已安装网络管理器&#xff0c;并处于启用状态。 认识网卡 ens32 ens33 ens34 ens35 一.ip相关 查询网络状态 sy…

数据结构——二分算法

二分查找 1. 在排序数组中查找元素的第一个和最后一个位置 代码实现&#xff1a; /*** Note: The returned array must be malloced, assume caller calls free().*/int binarySearch(int *nums, int numsSize, int target) {int l 0, r numsSize - 1; while (l <…

viper:一款中国人写的红队服务器——记一次内网穿透练习

1. viper Viper 是中国人自主编写的一款红队服务器&#xff0c;提供图形化的操作界面&#xff0c;让用户使用浏览器即可进行内网渗透&#xff0c;发布在语雀官方地址 提供了很全面的官方文档&#xff0c;包括四大部分&#xff0c;分别是使用手册、模块文档、博客文章、开发手册…

高中数学:数列-错位相减法与裂项相消法求数列的和

一、错位相减法 设&#xff0c;an是等差数列&#xff0c;bn是等比数列&#xff0c;那么{an*bn}构成一个新的数列 这个新数列的求和公式&#xff0c;就可以用错位相减法求解。 练习 例题1 解析&#xff1a; 第一问 第二问 二、裂项相消法 1、裂项的几种常见形式 形式1…

Junit4测试基本应用(白盒测试)

Junit4测试基本应用&#xff08;白盒测试&#xff09; 一、实验目的 掌握Junit的基本操作&#xff0c;进行较简单的单元测试。 二、Junit4测试的使用 1. 创建java项目JUnitText 我使用的Eclipse&#xff0c;在左侧Package Explorer(包资源管理器)右键&#xff0c;新建Java …

物联网 IoT 收录

物联网IoT日常知识收录 thingsboard, nodered是国际大品牌&#xff0c; iotgateway是国内的&#xff0c; 几个scada, pyscada, json-scada都还不错&#xff0c;比较一下。thingsboard-gateway是python系的&#xff0c;如果你愿意&#xff0c;可以用这个作为公司的物联网网关。…

全网最强剖析Spring AOP底层原理

相信各位读者对于Spring AOP的理解都是一知半解&#xff0c;只懂使用&#xff0c;却不懂原理。网上关于Spring AOP的讲解层出不穷&#xff0c;但是易于理解&#xff0c;让人真正掌握原理的文章屈指可数。笔者针对这一痛点需求&#xff0c;决定写一篇关于Spring AOP原理的优质博…

视频监控解决方案:视频平台升级技术方案(下)

目录 1 项目概况 2 项目需求 2.1 视频感知资源扩充 2.2 视频支撑能力升级 2.3 视频应用能力升级 3 技术设计方案 3.1系统总体架构 3.2视频感知资源扩充设计 3.3 视频支撑能力升级设计 3.4 视频应用能力升级设计 3.4.1视频资源目录管理 3.4.2标签管理 3.4.3设备智能…

万亿国债野外图传——天通卫星图传设备类目推荐

在远离都市喧嚣的辽阔自然中&#xff0c;户外工业作业以其独特的重要性日益凸显&#xff0c;涵盖了从高山峻岭的地质勘探、森林资源调查到广袤草原的生态监测等众多领域。然而传统监测方法不能全面覆盖&#xff0c;冰雪覆盖的山区和偏远地区的电力设施状况以及野生动物等户外状…

多功能推拉力测试机可实现焊球推力测试

LB-8100A 多功能推拉力测试机广泛应于与 LED 封装测试、IC 半导体封 装测试、TO 封装测试、IGBT 功率模块封装测试、光电子元器件封装测试、汽 车领域、航天航空领域、军工产品测试、研究机构的测试及各类院校的测试 研究等应用。 多功能推拉力测试机设置主要结构&#xff1a;…