深度学习之优化器(简要总结)

news2024/9/9 1:30:42

优化器是用于训练神经网络模型的关键组件,它们决定了模型参数如何根据损失函数的梯度进行更新。不同的优化器具有不同的特性和适用场景。

下面将介绍几种常见的深度学习优化器,以及基于pytorch版本的定义和使用方法。

1.SGD(Stochastic Gradient Descent)

随机梯度下降是最基础的优化算法。它沿着梯度的反方向更新参数,每次更新只考虑单个样本或小批量样本的梯度。优点是简单易懂,容易实现。缺点是收敛速度比较慢,且容易陷入局部最优解。

import torch
import torch.optim as optim

# 定义模型
model = Modeel() 
#自定义的网络模型

# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.005)

SGD优化器参数如下:

params (iterable of parameters):
作用:指定要优化的参数列表。
类型:可迭代对象,通常是模型的 model.parameters() 返回的结果。

lr (float):
作用:学习率,控制参数更新的步长大小。
默认值:0.01。
学习率的选择对模型的训练效果影响极大,通常需要通过实验来进行调整。较小的学习率可能导致收敛速度慢,而较大的学习率可能导致震荡或无法收敛。

momentum (float, optional):
作用:动量因子,用于加速SGD在相关方向上的移动,有助于加快收敛速度并减少震荡。
默认值:0。
通常设置在0.9左右,但根据问题的性质和实验情况可以进行微调。过高的动量可能导致模型在局部最小值周围无法稳定。

dampening (float, optional):
作用:动量的抑制因子,用于抑制动量的振荡。
默认值:0。
适用条件:仅当设置了动量时才有效。

weight_decay (float, optional):
作用:权重衰减(L2惩罚),用于在每次更新时惩罚较大的权重值,有助于防止过拟合。
默认值:0。
建议:根据数据集和模型复杂度调整。过高的权重衰减可能导致欠拟合。

nesterov (bool, optional):
作用:是否使用Nesterov动量。
默认值:False。
Nesterov动量在一些情况下可以提供更好的收敛性能,特别是在梯度较稀疏的情况下。

lr_decay (float, optional):
作用:学习率衰减因子。每个epoch结束后,学习率会乘以这个因子。
默认值:0。

2.AdaGrad(Adaptive Gradient)

AdaGrad优化器的优点是能够自适应地调整学习率,对于不同的参数可以根据其历史梯度信息进行个性化的学习率调整,对稀疏数据效果较好。缺点是由于学习率不断减小,可能会在后期导致学习率过小,使得训练提前结束,无法达到最优解。

import torch
import torch.optim as optim

# 定义模型
model = Model()

# 定义优化器
optimizer = optim.Adagrad(model.parameters(), lr=0.01, weight_decay=0.0005)

AdaGrad优化器参数如下:

params (iterable of parameters):
作用:指定要优化的参数列表。
类型:可迭代对象,通常是模型的 model.parameters() 返回的结果。

lr (float):
作用:学习率,控制参数更新的步长大小。
默认值:0.01。
学习率的选择对模型的训练效果影响极大,通常需要通过实验来进行调整。

lr_decay (float, optional):
作用:学习率衰减因子。每个epoch结束后,学习率会乘以这个因子。
默认值:0。

weight_decay (float, optional):
作用:权重衰减(L2惩罚),用于在每次更新时惩罚较大的权重值,有助于防止过拟合。
默认值:0。
根据数据集和模型复杂度调整。过高的权重衰减可能导致欠拟合。

eps (float, optional):
作用:为了数值稳定性而添加到分母中的小值,避免除以零。
默认值:1e-10。
一般情况下不需要更改这个值,除非遇到数值稳定性的问题。

3.RMSProp(Root Mean Square Propagation)

RMSProp是AdaGrad的改进版本。它通过引入一个衰减系数来限制历史梯度信息的累积,以解决AdaGrad在长时间训练中学习率过快下降的问题。

import torch
import torch.optim as optim

# 定义模型
model = Model()

# 定义优化器
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9, eps=1e-8, weight_decay=0.005)


RMSProp优化器参数如下:

params (iterable of parameters):
作用:指定要优化的参数列表。
类型:可迭代对象,通常是模型的 model.parameters() 返回的结果。

lr (float):
作用:学习率,控制参数更新的步长大小。
默认值:0.01。

alpha (float, optional):
作用:平滑常数,默认为0.99。
一般情况下不需要更改这个值。它是用来计算RMSProp中平方梯度的移动平均值的指数衰减率。

eps (float, optional):
作用:为了数值稳定性而添加到分母中的小值,避免除以零。
默认值:1e-8。
一般情况下不需要更改这个值,除非遇到数值稳定性的问题。

weight_decay (float, optional):
作用:权重衰减(L2惩罚),用于在每次更新时惩罚较大的权重值,有助于防止过拟合。
默认值:0。
根据数据集和模型复杂度调整。过高的权重衰减可能导致欠拟合。

momentum (float, optional):
作用:动量因子,用于加速优化过程。
默认值:0。
通常不需要设置动量,因为RMSProp本身已经包含了动量的效果。

4.Adam (Adaptive Moment Estimation)

Adam优化器结合了动量优化器和RMSProp的思想,能够自适应地调整每个参数的学习率,收敛速度较快,在很多情况下能够取得比较好的性能。但是,它可能会出现权重方差估计过高的情况,导致在某些情况下收敛效果不如SGD。

import torch
import torch.optim as optim

# 定义模型
model = Model()

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.005)


Adam优化器参数如下:


params (iterable of parameters):
作用:指定要优化的参数列表。
类型:可迭代对象,通常是模型的 model.parameters() 返回的结果。

lr (float):
作用:学习率,控制参数更新的步长大小。
默认值:0.001。

betas (Tuple[float, float], optional):
作用:用于计算梯度的一阶矩估计(均值)和二阶矩估计(未中心化的方差)的系数。
默认值:(0.9, 0.999)。
通常情况下不需要更改这个值。第一个元素是一阶矩估计的衰减率(动量),第二个元素是二阶矩估计的衰减率。

eps (float, optional):
作用:为了数值稳定性而添加到分母中的小值,避免除以零。
默认值:1e-8。
建议:一般情况下不需要更改这个值,除非遇到数值稳定性的问题。

weight_decay (float, optional):
作用:权重衰减(L2惩罚),用于在每次更新时惩罚较大的权重值,有助于防止过拟合。
默认值:0。
根据数据集和模型复杂度调整。过高的权重衰减可能导致欠拟合。

一般来说,Adam和SGD这两个优化器是我们训练网络模型的首选,在很多情况下,都能够取得不错的效果。当然,在实际问题中,我们也需要根据具体问题和数据特点,灵活选择和调整优化器,以达到最佳的训练效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1959196.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CTF学习笔记汇总(非常详细)零基础入门到精通,收藏这一篇就够了

CTF学习笔记汇总 Part.01 Web 01 SSRF 主要攻击方式如下: 01 对外网、服务器所在内网、本地进行端口扫描,获取一些服务的banner信息。 02 攻击运行在内网或本地的应用程序。 03 对内网Web应用进行指纹识别,识别企业内部的资产信息。 …

Studying-代码随想录训练营day45| 115.不同的子序列、583. 两个字符串的删除操作、72. 编辑距离、编辑距离总结篇

第45天,子序列part03,编辑距离💪(ง •_•)ง,编程语言:C 目录 115.不同的子序列 583. 两个字符串的删除操作 72. 编辑距离 编辑距离总结篇 115.不同的子序列 文档讲解:代码随想录不同的子序列 视频讲…

高效能程序员的9个习惯

最近看了一本关于敏捷软件开发实践的指南,他文中主要是在帮助软件开发者和团队提升工作效率、提高产品质量,并建立良好的工作文化和协作模式。以下是根据目录整理出的一段总结: 书名:《敏捷之道》 本书深入探讨了敏捷开发的核心原…

从 1 到 100 万+连接数,DigitalOcean 负载均衡的架构演进

在前不久,DigitalOcean 全球负载均衡器(GLB)Beta版正式上线。该解决方案能给客户的跨区域业务带来更好的支持,可以增强应用程序的弹性,消除单点故障,并大幅降低终端用户的延迟。这是 DigitalOcean 负载均衡…

Python写UI自动化--playwright(pytest.ini配置)

在 pytest.ini 文件中配置 playwright 的选项可以更好地控制测试执行的过程。 在终端输入pytest --help,可以找到playwright的配置参数 目录 1. --browser{chromium,firefox,webkit} 2. --headed 3. --browser-channelBROWSER_CHANNEL 4. --slowmoSLOWMO 5. …

魔众文库-PHP文库管理系统

魔众文库是一套基于PHPMYSQL开发的适用于多平台的文档管理系统,提供doc、ppt、excel、pdf、压缩包、图片、CAD 等资源的在线预览和下载,文件被转换为H5或图片格式,文字放大无失真,响应速度更快速对SEO更友好,收录更快、…

NFTScan | 07.22~07.28 NFT 市场热点汇总

欢迎来到由 NFT 基础设施 NFTScan 出品的 NFT 生态热点事件每周汇总。 周期:2024.07.22~ 2024.07.28 NFT Hot News 01/ 数据:NFT 系列 Liberty Cats 地板价突破 70000 MATIC 7 月 22 日,据 Magic Eden 数据,NFT 系列 Liberty C…

内网隧道学习笔记

1.基础: 一、端口转发和端口映射 1.端口转发是把一个端口的流量转发到另一个端口 2.端口映射是把一个端口映射到另一个端口上 二、http代理和socks代理 1.http带那里用http协议、主要工作在应用层,主要用来代理浏览网页。 2.socks代理用的是socks协议、…

c# string记录

c# srting 的操作例子 在C#中,string 类型是一个不可变(immutable)的引用类型,表示文本。由于它的不可变性,对字符串的任何修改操作实际上都会返回一个新的字符串实例。以下是一些常见的 string 操作例子: …

Hvv第二周,喝了3瓶红牛,心慌、头晕,我还行么?

Hvv第二周了,你们的物资挥霍的怎么样了啊?今天看到群里有小伙伴说喝了3瓶红牛,结果现在搞得头晕晕的,很慌。 Hvv物资来由 这不仅让我想来聊聊护网物资的来由和发展,也让后来进入网安这个行业的小伙伴了解一下&#xf…

【Plotly-驯化】一文教您画出Plotly中动态可视化饼图:pie技巧

【Plotly-驯化】一文教您画出Plotly中动态可视化饼图:pie技巧 本次修炼方法请往下查看 🌈 欢迎莅临我的个人主页 👈这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地! 🎇 免费获取相关内…

逻辑漏洞复现(pikachu靶场,大米cms)

逻辑漏洞 漏洞介绍 1.成因 逻辑漏洞是指由于程序逻辑不严或逻辑太复杂,导致一些逻辑分支不能够正常处理或处理错误,一般出现任意密码修改(没有旧密码验证)、越权访问、密码找回、交易支付金额等。 2. 分析 对常见的漏洞进行过…

Qt Phonon多媒体框架详解及简单实例分享

目录 1、Phonon 简介 2、Phonon基本类 2.1、VideoPlayer类 2.2、MediaObject类 2.3、Phonon::createPath() 2.4、AudioOutput类 2.5、VideoWidget Class 2.6、SeekSlider类 2.7、VolumeSlider类 3、Phonon 完整使用实例 4、总结 C++软件异常排查从入门到精通系列教程…

ChatGPT小狐狸AI付费创作系统v3.0.3+前端

小狐狸GPT付费体验系统的开发基于国外很火的ChatGPT,这是一种基于人工智能技术的问答系统,可以实现智能回答用户提出的问题。相比传统的问答系统,ChatGPT可以更加准确地理解用户的意图,提供更加精准的答案。同时,小狐狸…

项目管理“四管”法则

在项目管理中,“四管”的具体内容可能因不同的项目管理框架和实践而有所不同。但一般而言,它们可以概括为与项目成功密切相关的四个关键管理领域。以下是项目管理中“四管”: 一、人力资源管理(管人) 项目团队是项目…

AMQP-核心概念-终章

本文参考以下链接摘录翻译: https://www.rabbitmq.com/tutorials/amqp-concepts 连接(Connections) AMQP 0-9-1连接通常是长期保持的。AMQP 0-9-1是一个应用级别的协议,它使用TCP来实现可靠传输。连接使用认证且可以使用TLS保护…

Python 进行数据可视化(Matplotlib, Seaborn)

数据可视化是数据科学和分析中的重要工具,它通过图形表示数据,使得复杂的数据变得易于理解和分析。在Python中,最常用的两个数据可视化库是Matplotlib和Seaborn。 Matplotlib 1. 简介 Matplotlib是一个用于生成二维图形的Python库。它提供…

深入浅出消息队列----【阶段总结篇】

深入浅出消息队列----【阶段总结篇】 总览nameSrvBrokerproducer(生产者)consumer(消费者) 串联起来 本文仅是文章笔记,整理了原文章中重要的知识点、记录了个人的看法 文章来源:编程导航-鱼皮【yes哥深入浅…

小间距 LED 显示屏:引领显示技术新潮流

在现代显示技术领域,小间距LED显示屏以其先进的像素点控技术和卓越的显示效果,正逐渐成为市场的新宠。在此为您详细解析小间距LED显示屏相较于传统DLP背投显示屏的优势所在。 1、显示像素的完整性更高 在室内中高端显示市场中,DLP背投显示曾占…