Pytorch优化器Optimizer

news2024/11/25 1:44:11

优化器Optimizer

什么是优化器

pytorch的优化器:管理并更新模型中可学习参数的值,使得模型输出更接近真实标签

导数:函数在指定坐标轴上的变化率
方向导数:指定方向上的变化率(二元及以上函数,偏导数)
梯度:一个向量,方向是使得方向导数取得最大值的方向

Pytorch的Optimizer

在这里插入图片描述

参数

  • defaults:优化器超参数
  • state:参数的缓存,如momentum的缓存
  • param_groups:管理的参数组
  • _step_count:记录更新次数,学习率调整中使用

基本方法:

  • zero_grad():清空所管理参数的梯度
    在这里插入图片描述

pytorch特性:张量梯度不会自动清零

  • step():执行一步更新

  • add_param_group():添加参数组
    在这里插入图片描述

  • state_dict():获取优化器当前状态信息字典
    在这里插入图片描述

  • load_state_dict():加载状态信息字典

使用代码帮助理解和学习

import os
import torch
import torch.optim as optim


BASE_DIR = os.path.dirname(os.path.abspath(__file__))


weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))

# 需要传入一个可迭代对象
optimizer = optim.SGD([weight], lr=1)

print("weight before step:{}".format(weight.data))
optimizer.step()
print("weight after step:{}".format(weight.data))

weight before step:tensor([[-0.0606, -0.3197],
        [ 1.4949, -0.8007]])
weight after step:tensor([[-1.0606, -1.3197],
        [ 0.4949, -1.8007]])

weight = weight - lr * weight.grad
上面学习率是1,把学习率改为0.1试一下

optimizer = optim.SGD([weight], lr=0.1)

weight before step:tensor([[ 0.3901,  0.2167],
        [-0.3428, -0.7151]])
weight after step:tensor([[ 0.2901,  0.1167],
        [-0.4428, -0.8151]])

接着上面的代码,我们再看一下add_param_group方法

# add_param_group方法
print("optimizer.param_groups is \n{}".format(optimizer.param_groups))

w2 = torch.randn((3, 3), requires_grad=True)
optimizer.add_param_group({"params": w2, "lr": 0.0001})
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

optimizer.param_groups is 
[{'params': [tensor([[ 0.1749, -0.2018],
        [ 0.0080,  0.3517]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

optimizer.param_groups is
[{'params': [tensor([[ 0.1749, -0.2018],
        [ 0.0080,  0.3517]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, 
 {'params': [tensor([[ 0.4538, -0.8521, -1.3081],
        [-0.0158, -0.2708,  0.0302],
        [-0.3751, -0.1052, -0.3030]], requires_grad=True)], 'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

关于zero_grad()step()state_dict()load_state_dict()这几个方法比较简单就不再赘述。

SGD随机梯度下降

learning_rate学习率

在这里插入图片描述

这里学习率为1,可以看到并没有达到梯度下降的效果,反而y值越来越大,这是因为更新的步伐太大。

在这里插入图片描述

我们以y = 4*x^2这个函数举例,将y值作为要优化的损失值,那么梯度下降的过程就是为了找到y的最小值(即此函数曲线的最小值);如果我们把学习率设置为0.2,就可以得到这样一个梯度下降的图

def func(x):
    return torch.pow(2*x, 2)

x = torch.tensor([2.], requires_grad=True)
iter_rec, loss_rec, x_rec = list(), list(), list()
lr = 0.2
max_iteration = 20

for i in range(max_iteration):
    y = func(x)
    y.backward()

    print("iter:{}, x:{:8}, x.grad:{:8}, loss:{:10}".format(
        i, x.detach().numpy()[0], x.grad.detach().numpy()[0], y.item()
    ))
    x_rec.append(x.item())
    x.data.sub_(lr * x.grad)
    x.grad.zero_()
    iter_rec.append(i)
    loss_rec.append(y.item())

plt.subplot(121).plot(iter_rec, loss_rec, '-ro')
plt.xlabel("Iteration")
plt.ylabel("Loss value")

x_t = torch.linspace(-3, 3, 100)
y = func(x_t)
plt.subplot(122).plot(x_t.numpy(), y.numpy(), label="y = 4*x^2")
plt.grid()

y_rec = [func(torch.tensor(i)).item() for i in x_rec]
plt.subplot(122).plot(x_rec, y_rec, '-ro')
plt.legend()
plt.show()

在这里插入图片描述

这里其实存在一个下降速度更快的学习率,那就是0.125,一步就可以将loss更新为0,这是因为我们已经了这个函数表达式,而在实际神经网络模型训练的过程中,是不知道所谓的函数表达式的,所以只能选取一个相对较小的学习率,然后以训练更多的迭代次数来达到最优的loss。

在这里插入图片描述

动量(Momentum,又叫冲量)

结合当前梯度与上一次更新信息,用于当前更新

为什么会出现动量这个概念?

当学习率比较小时,往往更新比较慢,通过引入动量,使得后续的更新受到前面更新的影响,可以更快的进行梯度下降。

指数加权平均:当前时刻的平均值(Vt)与当前参数值(θ)和前一时刻的平均值(Vt-1)的关系。

在这里插入图片描述

根据上述公式进行迭代展开,因为0<β<1,当前时刻的平均值受越近时刻的影响越大(更近的时刻其所占的权重更高),越远时刻的影响越小,我们可以通过下面作图来看到这一变化。

import numpy as np
import matplotlib.pyplot as plt


def exp_w_func(beta, time_list):
    return [(1-beta) * np.power(beta, exp) for exp in time_list]


beta = 0.9
num_point = 100
time_list = np.arange(num_point).tolist()

weights = exp_w_func(beta, time_list)

plt.plot(time_list, weights, '-ro', label="Beta: {}\n = B * (1-B)^t".format(beta))
plt.xlabel("time")
plt.ylabel("weight")
plt.legend()
plt.title("exponentially weighted average")
plt.show()

在这里插入图片描述

这里β是一个超参数,设置不同的值,其对于过去时刻的权重计算如下图

beta_list = [0.98, 0.95, 0.9, 0.8]
w_list = [exp_w_func(beta, time_list) for beta in beta_list]
for i, w in enumerate(w_list):
    plt.plot(time_list, w, label="Beta: {}".format(beta_list[i]))
    plt.xlabel("time")
    plt.ylabel("weight")
plt.legend()
plt.show()

在这里插入图片描述

从图中可以得到这一结论:β值越小,记忆周期越短,β值越大,记忆周期越长

pytorch中带有momentum参数的更新公式

在这里插入图片描述

对于y=4*x^2这个例子,在没有momentum时,我们对比学习率分别为0.01和0.03会发现,0.03收敛的更快。

在这里插入图片描述

如果我们给learning_rate=0.01增加momentum参数,会发现其可以先一步0.03的学习率到达loss的较小值,但是因为动量较大的因素,在达到了最小值后还会反弹到一个大的值。

在这里插入图片描述

Pytorch中的优化器

optim.SGD

主要参数:

  • params:管理的参数组
  • lr:学习率
  • momentum:动量系数,贝塔
  • weight_decayL2正则化系数
  • nesterov:是否采用NAG,默认False

optim.Adagrad:自适应学习率梯度下降法

optim.RMSprop:Adagrad的改进

optim.Adadelta:Adagrad的改进

optim.Adam:RMSprop结合Momentum

optim.Adamax:Adam增加学习率上限

optim.SparseAdam:稀疏版的Adam

optim.ASGD:随机平均梯度下降

optim.Rprop:弹性反向传播

optim.LBFGS:BFGS的改进

学习率调整

前期学习率大,后期学习率小

pytorch中调整学习率的基类

class _LRScheduler

主要属性:

  • optimizer:关联的优化器
  • last_epoch:记录epoch数
  • base_lrs:记录初始学习率

主要方法:

  • step():更新下一个epoch的学习率
  • get_lr():虚函数,计算下一个epoch的学习率

StepLR

等间隔调整学习率

主要参数:

  • step_size:调整间隔数
  • gamma:调整系数

调整方式:lr = lr * gamma

import torch
import torch.optim as optim
import matplotlib.pyplot as plt


LR = 0.1
iteration = 10
max_epoch = 200

weights = torch.randn((1,), requires_grad=True)
target = torch.zeros((1, ))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)


scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  # 设置学习率下降策略

lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):
    lr_list.append(scheduler_lr.get_lr())
    epoch_list.append(epoch)

    for i in range(iteration):
        loss = torch.pow((weights-target), 2)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    scheduler_lr.step()

plt.plot(epoch_list, lr_list, label='Step LR Scheduler')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.legend()
plt.show()

在这里插入图片描述

MultiStepLR

功能:按给定间隔调整学习率

主要参数:

  • milestones:设定调整时刻数
  • gamma:调整系数

调整方式:lr = lr * gamma

# MultiStepLR
milestones = [50, 125, 160]
scheduler_lr = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

只需要改变这里代码,其他部分与StepLR中基本一致

在这里插入图片描述

ExponentialLR

功能:按指数衰减调整学习率

主要参数:

  • gamma:指数的底

调整方式:lr = lr * gamma ** epoch

# Exponential LR
gamma = 0.95
scheduler_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)


在这里插入图片描述

CosineAnnealingLR

功能:余弦周期调整学习率

主要参数:

  • T_max:下降周期
  • eta_min:学习率下限

调整方式:

在这里插入图片描述

# CosineAnnealingLR
t_max = 50
scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0)

在这里插入图片描述

ReduceLRonPlateau

功能:监控指标,当指标不再变化则调整学习率

主要参数:

  • mode:min/max,两种模式,min观察下降,max观察上升
  • factor:调整系数
  • patience:“耐心”,接受几次不变化
  • cooldown:“冷却时间”,停止监控一段时间
  • verbose:是否打印日志
  • min_lr:学习率下限
  • eps:学习率衰减最小值
# Reduce LR on Plateau
loss_value = 0.5
accuray = 0.9

factor = 0.1
mode = 'min'
patience = 10
cooldown = 10
min_lr = 1e-4
verbose = True

scheduler_lr = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, mode=mode, patience=patience,
                                                    cooldown=cooldown, min_lr=min_lr, verbose=verbose)
for epoch in range(max_epoch):
    for i in range(iteration):
        optimizer.step()
        optimizer.zero_grad()
    # if epoch == 5:
        # loss_value = 0.4

    # 把要监控的指标传进去
    scheduler_lr.step(loss_value)

Epoch    12: reducing learning rate of group 0 to 1.0000e-02.
Epoch    33: reducing learning rate of group 0 to 1.0000e-03.
Epoch    54: reducing learning rate of group 0 to 1.0000e-04.


LambdaLR

功能:自定义调整策略

主要参数:

  • lr_lambda:function or list
# lambda LR

lr_init = 0.1
weights_1 = torch.randn((6, 3, 5, 5))
weights_2 = torch.ones((5, 5))

optimizer = optim.SGD([
    {'params': [weights_1]},
    {'params': [weights_2]}
], lr=lr_init)

lambda1 = lambda epoch: 0.1 ** (epoch // 20)
lambda2 = lambda epoch: 0.95 ** epoch


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):
    for i in range(iteration):
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step()

    lr_list.append(scheduler.get_lr())
    epoch_list.append(epoch)

    print('epoch: {:5d}, lr:{}'.format(epoch, scheduler.get_lr()))


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/401534.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows安装docker-小白用【避坑】【伸手党福利】

目录实操开启 Hyper-V 和容器特性下载docker安装dockercmd中&#xff0c;使用命令测试是否成功报错解决办法&#xff1a;下载linux模拟器wsl&#xff1a;双击打开docker重新打开cmd&#xff0c;输入命令&#xff0c;成功显示sever和clinet实操 开启 Hyper-V 和容器特性 控制面…

项目进度管理:项目经理应该怎么做?

项目经理的职责是非常清晰的、界面分明的。项目经理经常忙碌的原因是&#xff0c;缺乏规划&#xff0c;觉得很多业务都跟自己相关&#xff0c;但不知道到底要做哪些工作&#xff0c;没有把多个角色分清楚。 1、目标。 项目目标是实施项目所要达到的期望结果&#xff0c;一个明…

基于嵌入式linux的OpenSSL源码移植(基于arm64)

SSL是Secure Sockets Layer&#xff08;安全套接层协议&#xff09;的缩写&#xff0c;可以在Internet上提供秘密性传输。Netscape公司在推出第一个Web浏览器的同时&#xff0c;提出了SSL协议标准。其目标是保证两个应用间通信的保密性和可靠性,可在服务器端和用户端同时实现支…

类和对象(一)

类和对象&#xff08;一&#xff09; C并不是纯面向对象语言 C是面向过程和面向对象语言的&#xff01; 面向过程和面向对象初步认识&#xff1a; C语言是面向过程的&#xff0c;关注的是过程&#xff0c;分析出求解问题的步骤&#xff0c;通过函数调用逐步解决问题。 C是基…

驱动程序开发:FTP服务器和OpenSSH的移植与搭建、以及一些笔记

目录一、FTP服务器移植与搭建1、在ubuntu下安装vsftpd2、在window下安装FileZilla3、移植vsftpd到开发板上4、Filezilla 连接测试5、注意点二、开发板 OpenSSH 移植与使用1、移植 zlib 库2、移植 openssl 库3、移植 openssh 库4、openssh 使用测试三、关于u-boot上的操作及根文…

数据表(二) - 数据表的制作方式

本篇来介绍下数据表的几种制作数据的方式。Excel是大部分数值策划选择用的填数工具&#xff0c;因为Excel是天生为数据处理而生&#xff0c;而Excel转为什么格式就需要选择了。最简单的就是直接将Excel里的数据复制黏贴到文本文件作为游戏数据。这种简单快捷的方式任何人都能做…

DatenLord前沿技术分享 No.20

达坦科技专注于打造新一代开源跨云存储平台DatenLord&#xff0c;致力于解决多云架构、多数据中心场景下异构存储、数据统一管理需求等问题&#xff0c;以满足不同行业客户对海量数据跨云、跨数据中心高性能访问的需求。喷泉码具有极高的纠错能力&#xff0c;且具有低延迟、地复…

CnOpenData·A股上市企业数字化转型指数数据

一、数据简介 企业数字化转型是近年来中国社会各界重点关注的领域&#xff0c;但基础数据的不完善在很大程度上制约了相关科学研究的开展。构建合理、科学的数字化转型指标体系有利于学者定量地研究企业数字化的相关问题&#xff0c;也有利于衡量企业的数字化水平。广东金融学院…

Linux驱动开发

一、驱动分类Linux中包含三大类驱动&#xff1a;字符设备驱动、块设备驱动和网络设备驱动。其中字符设备驱动是最大的一类驱动&#xff0c;因为字符设备最多&#xff0c;从led到I2C、SPI、音频等都属于字符设备驱动。块设备驱动和网络设备驱动都要比字符设备驱动复杂。因为其比…

标度不变性(scale invariance)与无标度(scale-free)概念辨析

文章目录标度标度种类名义标度序级标度等距标度比率标度常用标度方法不足标度不变性标度不变&#xff08;Scale-invariant&#xff09;曲线和自相似性&#xff08;self-similarity&#xff09;射影几何分形随机过程中的标度不变性标度不变的 Tweedie distribution普适性&#x…

蓝牙及其安全技术概述

作者 | 陆杰 上海控安可信软件创新研究院研发工程师 来源 | 鉴源实验室 01 背 景 汽车已成为现代社会生活不可或缺的一部分。车辆蓝牙[1]安全非常重要&#xff0c;因为未经保护的蓝牙连接可能会被黑客利用来获取车辆的敏感信息、控制车辆等&#xff0c;从而对车辆的安全和车主…

12N65-ASEMI高压MOS管12N65

编辑-Z 12N65在TO-220封装里的静态漏极源导通电阻&#xff08;RDS(ON)&#xff09;为0.68Ω&#xff0c;是一款N沟道高压MOS管。12N65的最大脉冲正向电流ISM为48A&#xff0c;零栅极电压漏极电流(IDSS)为10uA&#xff0c;其工作时耐温度范围为-55~150摄氏度。12N65功耗&#x…

【项目精选】基于Java的愤怒的小鸟游戏的设计与实现(视频+论文+源码)

点击下载源码 基本功能包括&#xff1a;新游戏、载入游戏、控制帮助、退出游戏等。本系统结构如下&#xff1a; &#xff08;1&#xff09;新游戏&#xff1a; 需要输入你的昵称&#xff1b; 选择难度&#xff1a;容易、中等、困难、噩梦(每个级别都有5个关卡) &#xff08;2&…

【Fabric 超级账本学习【3】Fabric2.4 使用Tape进行吞吐量量性能测试】

如果想测试一下超级账本fabric对某个合约函数的执行时间是多少&#xff0c;简单地可以通过打印合约函数开始执行时间和结束执行时间来计算时间差就可以了。 Tape 是一款轻量级 Hyperledger Fabric 性能测试工具。 tape的github地址&#xff1a;https://github.com/Hyperledge…

Spark的基本概念与架构

一、Spark简介 Spark 是一种与 Hadoop 相似的开源集群计算环境&#xff0c;但是两者之间还存在一些不同之处&#xff0c;这些有用的不同之处使 Spark 在某些工作负载方面表现得更加优越&#xff0c;换句话说&#xff0c;Spark 启用了内存分布数据集&#xff0c;除了能够提供交…

(二十五)操作系统--读者·写者问题

文章目录一、问题描述二、问题分析1&#xff0e;关系分析2&#xff0e;整理思路3&#xff0e;设置信号量4.注意三、代码实现1.代码2.改进代码四、总结一、问题描述 有读者和写者两组并发进程&#xff0c;共享一个文件&#xff0c;当两个或两个以上的读进程同时访问共享数据时不…

战斗力最强排行榜:10-30人团队任务管理工具

工欲善其事&#xff0c;必先利其器。在高效的任务执行过程中&#xff0c;选择灵活轻便的项目管理工具来提升工作效率、适应快速多变的发展诉求&#xff0c;对团队来说&#xff0c;至关重要。但是如果团队不大&#xff0c;企业对这块的预算又有限&#xff0c;大型的团队任务管理…

abc 联合索引查 bc索引到底走不走索引?

今天面试的时候&#xff0c;面试官有问到这个问题我说不会&#xff0c;可是面试官说走&#xff0c;网上也众说纷纭&#xff0c;那到底会不会走呢&#xff1f; 先看官网解释不会走&#xff1a; https://dev.mysql.com/doc/refman/8.0/en/multiple-column-indexes.html SELECT *…

响应式操作实战案例

Project Reactor 框架 在Spring Boot 项目 Maven 中添加依赖管理。 <dependency><groupId>io.projectreactor</groupId><artifactId>reactor-core</artifactId> </dependency><dependency><groupId>io.projectreactor</g…

AT32F437制作Bootloader然后实现Http OTA升级

首先创建一个AT32F437的工程&#xff0c;然后发现调试工程配置这里的型号和创建工程选的型号不一致&#xff0c;手动更改一下&#xff0c;使用PW Link下载程序的话还要配置一下pyocd.exe的路径。 打开drv_clk.c文件的调试功能看下系统时钟频率。 项目使用的是AT32F437VMT7芯片&…