学习率设置(写给自己看)

news2025/1/9 16:32:05

现往你的.py文件上打上以下代码:

import torch
import numpy as np
from torch.optim import SGD
from torch.optim import lr_scheduler
from torch.nn.parameter import Parameter

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, lr=0.1)

然后在最后的循环打上以下代码:

epochs=100
for epoch in (1,epochs+1):
    train()
    test()
    lr_schedulers.step()

这里的train和test是你的训练和测试调用的函数。

学习率参数很难调节,针对图像分类任务,一般使用的是:

1.阶梯型衰减,

就是在指定的批次上降低指定倍数,比如如果100个epoch,设置在1/3和3/4处学习率减小一倍,这种有两种实现方式:

方式一:

lr_schedulers=lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

方式二:

epochs=100
for epoch in (1,epochs+1):
    if epoch%30 == 0:
        lr = lr*0.1
    train()
    test()
    lr_schedulers.step()

2.MultiStepLR:多个不同速率的衰减

方式一:

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.5)

方式二:

for epoch in (1,epochs+1):
    if epoch == 30:
        lr = lr*0.1
    if epoch == 40:
        lr = lr*0.5
    train()
    test()
    lr_schedulers.step()

3.指数型下降的学习率调节器

公式:

 curr_rate:当前的学习率

 init_rate:初始的学习率

gamma:衰减系数

epochs:计数器,从0计数到训练的迭代次数

decay_step:控制衰减速度

公式表达的含义其实很明显,gamma衰减系数代表的就是衰减函数的形状,>1学习率就增长了,<1学习率就衰减了。代码实现:

X = []
Y = []
# 初始学习率
learning_rate = 0.1
# 衰减系数
decay_rate = 0.1
# decay_steps控制衰减速度
# 如果decay_steps大一些,(global_step / decay_steps)就会增长缓慢一些
#   从而指数衰减学习率decayed_learning_rate就会衰减得慢一些
#   否则学习率很快就会衰减为趋近于0
decay_steps = 60
# 迭代轮数
global_steps = 120
# 指数学习率衰减过程
for global_step in range(0,global_steps):
    decayed_learning_rate = learning_rate * decay_rate**(global_step / decay_steps)
    X.append(global_step / decay_steps)
    Y.append(decayed_learning_rate)
    if global_step==0 or global_step==global_steps-1:
        print("global step: %d, learning rate: %f" % (global_step,decayed_learning_rate))
    
fig = plt.figure(1)
ax = fig.add_subplot(1,1,1)
curve = ax.plot(X,Y,'b',label="learning rate")
ax.legend()
ax.set_xlabel("epochs / decay_steps")
ax.set_ylabel("learning_rate")

你通过设置初始学习率和最后想要下降到的学习率试着模拟一下。 效果还是不错的。

 实现方式:

实质上pytorch里面有:

scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

但是和上面的公式是有出入的,他的实现方式其实就是当前的学习率乘以gamma系数值,所以在最后学习率肯定会同样的衰减率torch里面下降的是比上面的快的,所以有两种策略,第一种调整gamma系数然后打印每次的学习率的数值调整到自己想要的学习率大小,即:

我i试了试改成0.96差不多就可以了。

第二种就是把上面的方式封装成一个函数,在for循环里每次调用他,封装成函数就可以使用

LambdaLR学习策略

了,它可以自定义函数,实现方式如下:

# 初始学习率
learning_rate = 0.1
# 衰减系数
decay_rate = 0.1
# decay_steps 控制衰减速度
decay_steps = 60
# 迭代轮数
global_steps = 120


# 自定义指数衰减函数
def exponential_decay(initial_lr, decay_rate, decay_steps, global_step):
    return initial_lr * decay_rate**(global_step / decay_steps)



scheduler = LambdaLR(optimizer, lr_lambda=lambda step: exponential_decay(learning_rate, decay_rate, decay_steps, step))

# 记录学习率的变化
lr_history = []

# 模拟训练过程
for epoch in range(global_steps):
    # 执行训练步骤
    # ...

    # 记录当前学习率
    current_lr = optimizer.param_groups[0]['lr']
    lr_history.append(current_lr)

    # 更新学习率
    scheduler.step()

这个函数就非常的方便,像是上面的多阶段衰减也可以使用这个函数进行实现。

 OneCycleLR

scheduler=lr_scheduler.OneCycleLR(optimizer,max_lr=0.1,pct_start=0.5,total_steps=120,div_factor=10,final_div_factor=10)

可视化 OneCycleLR:

import torch
from torch.optim.lr_scheduler import OneCycleLR
import matplotlib.pyplot as plt

# 定义神经网络和优化器
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

net = SimpleNet()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# 定义 OneCycleLR 学习率调度器
scheduler = OneCycleLR(optimizer, max_lr=0.1, pct_start=0.5, total_steps=120, div_factor=10, final_div_factor=10)

# 记录学习率的变化
lr_history = []

# 模拟训练过程
for epoch in range(120):
    # 执行训练步骤
    # ...

    # 记录当前学习率
    current_lr = optimizer.param_groups[0]['lr']
    lr_history.append(current_lr)

    # 更新学习率
    scheduler.step()

# 绘制学习率变化曲线
plt.plot(range(120), lr_history, label="learning rate")
plt.xlabel("epochs")
plt.ylabel("learning rate")
plt.legend()
plt.show()

最后一个余弦退火学习率衰减CosineAnnealingLR

CosineAnnealingLR是余弦退火学习率,T_max是周期的一半,最大学习率在optimizer中指定,最小学习率为eta_min。这里同样能够帮助逃离鞍点。值得注意的是最大学习率不宜太大,否则loss可能出现和学习率相似周期的上下剧烈波动。

基本上的选择方式是选择1/4个余弦函数的周期。

可视化:

这里官方文档的公式说明讲的很清晰,自行学习吧: 

Parameters 参数

  • optimizer (Optimizer) - 包装优化器。

  • T_max (int) - 最大迭代次数。

  • eta_min (float) - 最低学习率。默认值:0。

  • last_epoch (int) - 上一个纪元的索引。默认值:-1。

  • verbose (bool) – 如果 True ,则在每次更新时向 stdout 打印一条消息。默认值: False .

今天的学习就到这里,散会!

ps:最近心情有点糟糕,六级+期末考试+实验出了些问题,好累,今晚好好睡一觉吧,晚安各位。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1283675.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis 分布式锁测试

一、前提依赖&#xff08;除去SpringBoot项目基本依赖外&#xff09;&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId> </dependency><!-- 配置使用redis启动…

进入软件的世界

选择计算机 上高中的时候&#xff0c;因为沉迷于网络游戏&#xff0c;于是对计算机产生了浓厚的兴趣&#xff0c;但是那个时候对于计算机的了解还是非常肤浅的。上大学的时候&#xff0c;也就义无反顾的选择了计算机专业&#xff0c;其实并不是一个纯粹的计算机专业&#xff0…

代码随想录算法训练营第五十五天【动态规划part15】 | 392.判断子序列、115.不同的子序列

392.判断子序列 题目链接 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 求解思路 也可以用双指针来做。 动规五部曲 1.确定dp数组及其下标含义 以下标i-1为结尾的字符串s&#xff0c;和以下标j-1为结尾的字符串t&#xff0c;相同子序列的长度…

Mybatis中的设计模式

Mybatis中的设计模式 Mybatis中使用了大量的设计模式。 以下列举一些看源码时&#xff0c;觉得还不错的用法&#xff1a; 创建型模式 工厂方法模式 DataSourceFactory 通过不同的子类工厂&#xff0c;实例化不同的DataSource TransactionFactory 通过不同的工厂&#xff…

【雷电模拟器桥接问题解决方法】

1.ROOT权限开启 2.开启网络桥接模式&#xff0c;选择静态IP设置&#xff0c;点击安装桥接网卡&#xff0c;填写IP地址&#xff08;注意&#xff1a;IP地址要与host主机在同一IP段内&#xff09; 3.重启后 adb shell就能进入到模拟器控制台中了&#xff0c;如果出现以下内容&…

记一次若依二开的简单流程

记一次若依二开的简单流程 前言: 搞Java后端的应该都知道若依框架&#xff0c;是一个十分强大且功能齐全的开源的快速开发平台&#xff0c;且毫无保留给个人及企业免费使用。很多中小型公司会直接在该系统上进行二次开发使用。本文记录一次使用若依二开零编码的简单实现&#…

JFrog----软件的SBOM分析简介

文章目录 什么是SBOM&#xff1f;SBOM分析的重要性SBOM分析的过程结语 什么是SBOM&#xff1f; SBOM&#xff0c;全称是“软件物料清单”&#xff0c;它像是一个详尽的清单&#xff0c;列出了构成特定软件的所有组件&#xff0c;包括库、模块、包等。这就像是制造业中的物料清…

iOS ------ UICollectionView

一&#xff0c;UICollectionView的简介 UICollectionView是iOS6之后引入的一个新的UI控件&#xff0c;它和UITableView有着诸多的相似之处&#xff0c;其中许多代理方法都十分类似。简单来说&#xff0c;UICollectionView是比UITbleView更加强大的一个UI控件&#xff0c;有如下…

C语言中如何取一串比特中的特定位的比特

#include <iostream> #include <bitset> using namespace std; /* 向右的移位操作相当于丢掉最后的几位&#xff0c;然后剩下的位数进行“与”运算即可。 */ int main() {int a 0x2FB7; //0x2FB70010 1111 1011 0111char end3 (a >> 4) & 0x07; //取a…

Javaweb之Vue路由的详细解析

5 Vue路由 5.1 路由介绍 将资代码/vue-project(路由)/vue-project/src/views/tlias/DeptView.vue拷贝到我们当前EmpView.vue同级&#xff0c;其结构如下&#xff1a; 此时我们希望基于4.4案例中的功能&#xff0c;实现点击侧边栏的部门管理&#xff0c;显示部门管理的信息&am…

“影响力”经济:抖音为什么更值得商家、达人长期深耕?

文&#xff5c;新熔财经 作者&#xff5c;叶一城 数亿的活跃用户&#xff0c;简单而自然的切入方式&#xff0c;快速、高频的执行效率&#xff0c;让抖音对电商界的冲击无可阻挡。 这背后&#xff0c;流量玩法登峰造极&#xff0c;是很多人的直接观感。 但实际上&#xff0…

FL Studio 21.2.1.3859中文破解版及FL Studio怎么录制

FL Studio 21.2.1.3859中文破解版是一个数字音频工作站 (DAW)。该软件借助各种编辑工具、插件和效果&#xff0c;让您可以录制、混音和掌握高度复杂的音乐作品。FL Studio 21还允许您注册和编辑 MIDI 文件&#xff0c;您可以在众多可用乐器之一上演奏这些文件。FL Studio 拥有 …

【VRTK】【VR开发】【Unity】10-连续移动

课程配套学习资源下载 https://download.csdn.net/download/weixin_41697242/88485426?spm=1001.2014.3001.5503 【概述】 连续移动与瞬移有如下不同: 连续移动不容易打断沉浸对于新手或者不适应者来说更容易晕动 我对玩家的建议:连续移动前后左右可以用摇杆,转向用自己…

java常用知识点记忆

类的继承与多态 类的继承不支持多重继承非private 方法才可以被覆盖覆盖的方法要求&#xff0c;子类中的方法的名字&#xff0c;参数列表&#xff0c;返回类型与父类相同方法的重载是在一个类中定义方法名字相同&#xff0c;但是参数列表不同的方法要是在子类中定义了与父类名字…

Huawei FusionSphere FusionCompte FusionManager

什么是FusionSphere FusionSphere 解决方案不独立发布软件&#xff0c;由各配套部件发布&#xff0c;请参 《FusionSphere_V100R005C10U1_版本配套表_01》。 目前我们主要讨论FusionManager和FusionCompute两个组件。 什么是FusionCompte FusionCompute是华为提供的虚拟化软…

深度学习训练 tricks(持续更新)

妈妈&#xff0c;我的炼丹炉子炸啦&#xff08;不是&#xff09; 妈妈&#xff0c;我的深度学习模型训练好了&#xff01; 本文持续更新&#xff0c;如果有什么你知道的深度学习模型训练技巧&#xff0c;可以在评论区提出&#xff0c;我会加进来的。 文章目录 weight decaywe…

3DMM模型

目录 BFMBFM_200901_MorphableModel.matexp_pca.bintopology_info.npyexp_info.npy BFM BFM_2009 01_MorphableModel.mat from scipy.io import loadmat original_BFM loadmat("01_MorphableModel.mat") # dict_keys: [__header__, __version__, __globals__, # …

C++ 文件操作之配置文件读取

C 文件操作之配置文件读取 在项目应用时常常会涉及一些调参工作&#xff0c;如果项目封装成了.exe或者.dll&#xff0c;那么频繁调参多次编译是一件十分低效的事情&#xff0c;如果代码算法或者逻辑是一定的&#xff0c;那么参数完全可以通过读入配置文件来获取之前在用C - op…

SpringBoot药品进销存管理系统(诊所管理系统)(乡村药店管理系统)

SSM毕设分享 SpringBoot药品进销存管理系统(诊所管理系统)(乡村药店管理系统) 1 项目简介 Hi&#xff0c;各位同学好&#xff0c;这里是郑师兄&#xff01; 今天向大家分享一个毕业设计项目作品【SpringBoot药品进销存管理系统(诊所管理系统)(乡村药店管理系统)】 师兄根据实…

ROS-ROS通信机制-话题通信

文章目录 一、话题通信基础知识二、话题通信基本操作2-1 C2-2 Python2-3 C与python节点通信 三、自定义msg3-1 自定义msg3-2 C实现自定义msg调用3-3 Python实现自定义msg调用 一、话题通信基础知识 话题通信实现模型是比较复杂的&#xff0c;该模型如下图所示,该模型中涉及到三…