Pytorch优化器全总结(四)常用优化器性能对比 含代码

news2025/1/11 14:05:14

目录

写在前面

一、优化器介绍

1.SGD+Momentum

2.Adagrad

3.Adadelta

4.RMSprop

5.Adam

6.Adamax

7.AdaW

8.L-BFGS

二、优化器对比


优化器系列文章列表

Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad

Pytorch优化器全总结(二)Adadelta、RMSprop、Adam、Adamax、AdamW、NAdam、SparseAdam

Pytorch优化器全总结(三)牛顿法、BFGS、L-BFGS 含代码

写在前面

        常用的优化器我已经用三篇文章介绍完了,现在我将对比一下这些优化器的收敛速度。

        下面我将简单介绍一下要对比的优化器,每种我只用一到两句话介绍,详细介绍请跳转上面的链接,每种优化器都详细介绍过。

一、优化器介绍

1.SGD+Momentum

        带动量 的SGD 优化算法,Momentum通过将当前梯度与过去梯度加权平均,来获取即将更新的梯度,有助于在相关方向上加速SGD并抑制振荡。

2.Adagrad

        每个时间步长对每个参数使用不同的学习率。 引入了梯度的二阶矩,二阶矩是迄今为止所有梯度值的平方和,二阶矩的越大,代表步长的不确定性越大,学习率就越小,反之学习率越大。 

3.Adadelta

        对于每个维度,用梯度平方的指数加权平均代替了全部梯度的平方和,避免了后期更新时更新幅度逐渐趋近于0的问题。

        用更新量的平方的指数加权平均来动态得代替了全局的标量的学习率,避免了对学习率的敏感。

4.RMSprop

        与Adadelta同一时期,等价于实现了Adadelta的第一个改动。

5.Adam

        同时使用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。 一阶矩来控制模型更新的方向,二阶矩控制步长(学习率)。

6.Adamax

        在Adam的基础上,为学习率的上限限制了范围。将Adam的二范数(二阶矩估计)推广到无穷范数,因为无穷范数,就是取向量的最大值,这就为学习率的上限提供了一个更简单的范围。

7.AdaW

        使用adam+权重衰减的方式解决了adam+L2正则化表现不佳的问题。

8.L-BFGS

        牛顿法是基于迭代的二阶优化方法,对于高维的应用场景,求二阶导变得不可行;BFGS对牛顿法做了改进,用一阶导和一个基于迭代的矩阵H模拟海森矩阵,从而降低计算的复杂度;BFGS虽然对牛顿法做了优化,但是H的存储空间至少为N(N+1)/2(N为特征维数),需要的存储空间将是非常巨大的,L-BFGS采用加窗的方式,通过存储前m次迭代的少量数据来替代前一次的H矩阵,从而大大减少数据的存储空间。

二、优化器对比

        下面我们将对比SGD、SGD+Momentum、Adagrad、Adadelta、RMSprop、Adam、Adamax、AdaW、L-BFGS的收敛速度。

        代码如下:


import torch
import torch.utils.data as Data
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt

# 超参数
LR = 0.01
BATCH_SIZE = 32
EPOCH = 12

# 生成假数据
# torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)  # x data (tensor), shape(100, 1)
# 0.2 * torch.rand(x.size())增加噪点
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))

# 定义数据库
dataset = Data.TensorDataset(x, y)

# 定义数据加载器
loader = Data.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)


# 定义pytorch网络
class Net(torch.nn.Module):
    def __init__(self, n_features, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_features, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        y = self.predict(x)
        return y

# 定义不同的优化器网络
net_SGD = Net(1, 10, 1)
net_Momentum = Net(1, 10, 1)
net_Adagrad = Net(1, 10, 1)
net_Adadelta = Net(1, 10, 1)
net_RMSprop = Net(1, 10, 1)
net_Adam = Net(1, 10, 1)
net_Adamax = Net(1, 10, 1)
net_AdamW = Net(1, 10, 1)
net_LBFGS = Net(1, 10, 1)

# 选择不同的优化方法
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.9)
opt_Adagrad = torch.optim.Adagrad(net_Adagrad.parameters(), lr=LR)
opt_Adadelta = torch.optim.Adadelta(net_Adadelta.parameters(), lr=LR)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
opt_Adamax = torch.optim.Adamax(net_Adamax.parameters(), lr=LR, betas=(0.9, 0.99))
opt_AdamW = torch.optim.AdamW(net_AdamW.parameters(), lr=LR, betas=(0.9, 0.99))
opt_LBFGS = torch.optim.LBFGS(net_LBFGS.parameters(), lr=LR, max_iter=10, max_eval=10)



nets = [net_SGD, net_Momentum, net_Adagrad, net_Adadelta, net_RMSprop, net_Adam, net_Adamax, net_AdamW, net_LBFGS]
optimizers = [opt_SGD, opt_Momentum, opt_Adagrad, opt_Adadelta, opt_RMSprop, opt_Adam, opt_Adamax, opt_AdamW, opt_LBFGS]

# 选择损失函数
loss_func = torch.nn.MSELoss()

# 不同方法的loss
loss_SGD = []
loss_Momentum = []
loss_Adagrad = []
loss_Adadelta = []
loss_RMSprop = []
loss_Adam = []
loss_Adamax = []
loss_AdamW = []
loss_LBFGS = []

# 保存所有loss
losses = [loss_SGD, loss_Momentum, loss_Adagrad, loss_Adadelta, loss_RMSprop, loss_Adam, loss_Adamax, loss_AdamW, loss_LBFGS]

# 执行训练
for epoch in range(EPOCH):
    for step, (batch_x, batch_y) in enumerate(loader):
        var_x = Variable(batch_x)
        var_y = Variable(batch_y)
        for net, optimizer, loss_history in zip(nets, optimizers, losses):
            if isinstance(optimizer, torch.optim.LBFGS):
                def closure():
                    y_pred = net(var_x)
                    loss = loss_func(y_pred, var_y)
                    optimizer.zero_grad()
                    loss.backward()
                    return loss
                loss = optimizer.step(closure)
            else:
                # 对x进行预测
                prediction = net(var_x)
                # 计算损失
                loss = loss_func(prediction, var_y)
                # 每次迭代清空上一次的梯度
                optimizer.zero_grad()
                # 反向传播
                loss.backward()
                # 更新梯度
                optimizer.step()
            # 保存loss记录
            loss_history.append(loss.data)

# 画图
labels = ['SGD', 'Momentum', 'Adagrad', 'Adadelta', 'RMSprop', 'Adam', 'Adamax', 'AdamW', 'LBFGS']
for i, loss_history in enumerate(losses):
    plt.plot(loss_history, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.show()

         从图中可以看到,Adam、Adamax、AdaW、L-BFGS收敛速度要更快,当然这次实验只代表一般情况下的结果,项目中还是要以实际效果为准,大家在实际项目中还是要多试几种,选择适合自己的。

        算法的性能比较就介绍到这里,收藏关注不迷路。

优化器系列文章列表

Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad

Pytorch优化器全总结(二)Adadelta、RMSprop、Adam、Adamax、AdamW、NAdam、SparseAdam

Pytorch优化器全总结(三)牛顿法、BFGS、L-BFGS 含代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/165835.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

设计模式学习(七):Factory Method工厂模式

目录 一、什么是Factory Method模式 二、Factory Method示例代码 2.1 类之间的关系 2.2 Product类 2.3 Factory类 2.4 IDCard类 2.5 IDCardFactory类 2.6 用于测试的Main类 2.7 运行结果 三、拓展思路的要点 3.1 框架与具体加工 3.2 使用模式与开发人员之间的沟通 …

(12)go-micro微服务JWT跨域认证

文章目录一 JWT介绍二 JWT优缺点三 JWT使用1. 导包和数据定义2.生成JWT3.解析JWT4.完整代码四 最后一 JWT介绍 JWT 英文名是 Json Web Token ,是一种用于通信双方之间传递安全信息的简洁的、URL安全的表述性声明规范,经常用在跨域身份验证。 JWT 以 JS…

拐点检测常用算法总结

目录概览问题定义符号定义研究方法损失函数概览 问题定义 拐点检测名为 change point detection,对于一条不平缓的时间序列曲线,认为存在一些时间点 (t1,t2,...,tk)( t_1 , t_2 , . . . , t_k )(t1​,t2​,...,tk​) ,使得曲线在这些点对应…

Java SPI的介绍、JDBC中SPI的应用、自己实现一个SPI应用

目录1. Java SPI介绍2. Java SPI的运行流程3. Java SPI在JDBC中的应用4. Java SPI的三大规范要素5. 自己实现一个SPI应用5.1 Service接口5.2 运营商1的Service Provider5.3 运营商2的Service Provider5.3 手机使用网络1. Java SPI介绍 SPI(Service Provider Interface)是一种基…

在别墅大宅中打造全屋智能,总共需要几步?

关于智能家居,很多读者可能会想起一些不那么愉快的回忆:2014年左右的智能家居浪潮,涌现出了众多带蓝牙互联功能的家电产品,但数据无法互联互通、单品体验升级有限,加上一些企业竞争失败产品不再更新,留给消…

EXCEL工具介绍

目录1. 锁定功能2. 固定1. 锁定功能 锁定:F4 公式引用单元格,有“相对引用”与“绝对引用” 美元符号“ $ ”在excel公式中的作用是在“绝对引用”时,锁定行号或列标(单元格地址由列标行号组成,如A1,A为列…

国产软件不惧微软,WPS力扛大旗,新型办公软件争相助力

金山作为程序员的“黄埔军校”,输入了一批批互联网大佬,小米的雷军、哔哩哔哩的陈睿、蓝港互动的王峰等都师出金山。WPS作为金山拳头作品,有着“国民软件”美誉,功能强大,使用体验一点不输微软word,然而有一…

打工人必学的法律知识(三)——《中华人民共和国劳动争议调解仲裁法》

PS : 对与日常工作比较紧密的部分做摘录 中华人民共和国劳动争议调解仲裁法http://www.npc.gov.cn/npc/c198/200712/756d4eceb95e420a87c97545a58d931c.shtml 目录 一、调解 二、仲裁 三、申请和受理 四、开庭和裁决 五、附则 第六条 发生劳动争议&#xff0…

python镜像设置

winr 输入 %USERPROFILE% 新建pip目录,目录内新建pip.ini 输入: [global] index-urlhttp://mirrors.aliyun.com/pypi/simple/ trusted-hostmirrors.aliyun.com

计算机网络概括

1 前言计算机网络是指将位于不同地理位置,但具有独立功能的多台设备,通过通信设备和线路连接起来,在网络操作系统,网络管理软件、网络通信协议的协调管理下,实现资源共享和信息传递的计算机系统。简单来说,…

STM32模拟SPI总线读写RFID模块RC522

STM32模拟SPI总线读写RFID模块RC522 RC522是一款NXP 公司的支持ISO14443A协议的高频(13.56MHz)RFID射频芯片。RFID有ID和IC两种卡应用类型,RC522属于IC卡的应用类型。NFC则属于增强的IC卡类型,支持双向通信和更多类型的协议。 I…

es官网翻译之Exploring Your Cluster

Exploring Your Cluster 探索你的集群 The rest api rest 风格的 api Now that we have our node (and cluster) up and running, the next step is to understand how to 现在 我们已经将我们自己的节点(和集群) 启动并运行着, 下一个步骤是知道如何 communicate with it…

Java面试题每日10问(12)

1. What is String Pool? String pool is the space reserved in the heap memory that can be used to store the strings.The main advantage of using the String pool is whenever we create a string literal; the JVM checks the “string constant pool” first.If th…

速度为单GPU1.6倍,kaggle双GPU(ddp模式)加速pytorch攻略

accelerate 是huggingface开源的一个方便将pytorch模型迁移到 GPU/multi-GPUs/TPU/fp16 模式下训练的小巧工具。和标准的 pytorch 方法相比,使用accelerate 进行多GPU DDP模式/TPU/fp16 训练你的模型变得非常简单(只需要在标准的pytorch训练代码中改动不几行代码就可…

linux基功系列之man帮助命令实战

文章目录前言一、man命令介绍二、常用参数2.1 语法2.2 常用参数2.3 man页面的操作命令man命令使用案例1. 直接查看手册2. -aw 参数找到可以被查询的章节2.3 一次性查阅所有章节2.4 搜索手册页2.5 -L 设置查询语言总结前言 linux系统中的命令数量有上千的,即使是常用…

前端——周总结系列二

1 JS数组排序sort()方法 不传参数排序,默认根据Unicode排序 附录 传参数,使用比较函数,自己定义比较规则 简单数组排序 // 升序 function ascSort(a, b) {return a - b; } // 降序 function ascSort(a, b) {return b - a; }数组对象排序…

算法leetcode|31. 下一个排列(rust重拳出击)

文章目录31. 下一个排列:样例 1:样例 2:样例 3:提示:分析:题解:rustgoccpythonjava31. 下一个排列: 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如&#xff0…

ROS2机器人编程简述humble-第二章-First Steps with ROS2 .1

ROS2机器人编程简述新书推荐-A Concise Introduction to Robot Programming with ROS2学习笔记流水账-推荐阅读原书。第二章主要就是一些ROS的基本概念,其实ROS1和ROS2的基本概念很多都是类似的。ROS2机器人个人教程博客汇总(2021共6套)如何更…

Linux chgrp 命令

Linux chgrp(英文全拼:change group)命令用于变更文件或目录的所属群组。与 chown 命令不同,chgrp 允许普通用户改变文件所属的组,只要该用户是该组的一员。在 UNIX 系统家族里,文件或目录权限的掌控以拥有…

(一)Jenkins部署、基础配置

目录 1、前言 1.1、Jenkins是什么 1.2、jenkins有什么用 2、 Jenkins安装 2.1、jdk安装 2.2、安装Jenkins 3、Jenkins配置 3.1、解锁Jenkins 3.2、插件安装 3.3、创建管理员 3.4、实例配置 4、汉化 4.1、下载Locale插件 4.2、设置为中文 5、设置中文失效解决步骤 1…