【漫话机器学习系列】087.常见的神经网络最优化算法(Common Optimizers Of Neural Nets)

news2025/2/10 23:15:36

常见的神经网络优化算法

1. 引言

在深度学习中,优化算法(Optimizers)用于更新神经网络的权重,以最小化损失函数(Loss Function)。一个高效的优化算法可以加速训练过程,并提高模型的性能和稳定性。本文介绍几种常见的神经网络优化算法,包括随机梯度下降(SGD)、带动量的随机梯度下降(Momentum SGD)、均方根传播算法(RMSProp)以及自适应矩估计(Adam),并提供相应的代码示例。

2. 常见的优化算法

2.1 随机梯度下降(Stochastic Gradient Descent, SGD)

随机梯度下降(SGD)是最基本的优化算法,其更新规则如下:

其中:

  • w 代表网络参数(权重);
  • α 是学习率(Learning Rate),控制更新步长;
  • ∇L(w) 是损失函数相对于权重的梯度。

代码示例(使用 PyTorch 实现 SGD)

import torch
import torch.nn as nn
import torch.optim as optim

# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降

# 训练步骤
for epoch in range(100):
    optimizer.zero_grad()  # 清空梯度
    inputs = torch.tensor([[1.0]], requires_grad=True)
    targets = torch.tensor([[2.0]])

    outputs = model(inputs)
    loss = criterion(outputs, targets)  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数

    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果

Epoch [0/100], Loss: 4.9142
Epoch [10/100], Loss: 2.1721
Epoch [20/100], Loss: 0.9601
Epoch [30/100], Loss: 0.4244
Epoch [40/100], Loss: 0.1876
Epoch [50/100], Loss: 0.0829
Epoch [60/100], Loss: 0.0366
Epoch [70/100], Loss: 0.0162
Epoch [80/100], Loss: 0.0072
Epoch [90/100], Loss: 0.0032


2.2 带动量的随机梯度下降(Momentum SGD)

带动量的 SGD 在 SGD 的基础上加入动量(Momentum),用于加速收敛并减少震荡:


其中:

  • 是累积的梯度,类似于物理中的动量;
  • β 是动量系数(通常取 0.9)。

代码示例(Momentum SGD)

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(100):
    optimizer.zero_grad()
    inputs = torch.tensor([[1.0]], requires_grad=True)
    targets = torch.tensor([[2.0]])
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果 

Epoch [0/100], Loss: 3.0073
Epoch [10/100], Loss: 1.3292
Epoch [20/100], Loss: 0.5875
Epoch [30/100], Loss: 0.2597
Epoch [40/100], Loss: 0.1148
Epoch [50/100], Loss: 0.0507
Epoch [60/100], Loss: 0.0224
Epoch [70/100], Loss: 0.0099
Epoch [80/100], Loss: 0.0044
Epoch [90/100], Loss: 0.0019

优点:

  • 缓解了 SGD 震荡问题,提高收敛速度;
  • 在非凸优化问题中表现更好。

2.3 均方根传播算法(RMSProp)

RMSProp 通过自适应调整学习率来加速训练,并缓解震荡问题:


其中:

  • 是梯度平方的滑动平均;
  • β 是衰减系数(一般取 0.9);
  • ϵ 是一个很小的数,防止除零错误。

代码示例(RMSProp)

import torch
import torch.nn as nn
import torch.optim as optim

# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.9)

for epoch in range(100):
    optimizer.zero_grad()
    inputs = torch.tensor([[1.0]], requires_grad=True)
    targets = torch.tensor([[2.0]])
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果

Epoch [0/100], Loss: 1.1952
Epoch [10/100], Loss: 0.5887
Epoch [20/100], Loss: 0.3333
Epoch [30/100], Loss: 0.1731
Epoch [40/100], Loss: 0.0752
Epoch [50/100], Loss: 0.0239
Epoch [60/100], Loss: 0.0043
Epoch [70/100], Loss: 0.0003
Epoch [80/100], Loss: 0.0000
Epoch [90/100], Loss: 0.0000

优点:

  • 适用于非平稳目标函数;
  • 能有效处理不同特征尺度的问题;
  • 在 RNN(循环神经网络)等任务上表现较好。

2.4 自适应矩估计(Adam, Adaptive Moment Estimation)

Adam 结合了动量法(Momentum)和 RMSProp,同时考虑梯度的一阶矩(平均值)和二阶矩(方差):



其中:

  • ​ 是梯度的一阶矩估计;
  • ​ 是梯度的二阶矩估计;
  • ​ 分别控制一阶矩和二阶矩的指数衰减率(通常取 0.9 和 0.999)。

代码示例(Adam)

import torch
import torch.nn as nn
import torch.optim as optim

# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    optimizer.zero_grad()
    inputs = torch.tensor([[1.0]], requires_grad=True)
    targets = torch.tensor([[2.0]])
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

输出结果 

Epoch [0/100], Loss: 3.6065
Epoch [10/100], Loss: 2.8894
Epoch [20/100], Loss: 2.2642
Epoch [30/100], Loss: 1.7359
Epoch [40/100], Loss: 1.3021
Epoch [50/100], Loss: 0.9555
Epoch [60/100], Loss: 0.6855
Epoch [70/100], Loss: 0.4805
Epoch [80/100], Loss: 0.3287
Epoch [90/100], Loss: 0.2192

优点:

  • 结合 Momentum 和 RMSProp 的优势;
  • 适用于大规模数据集和高维参数优化;
  • 具有自适应学习率,适用于不同类型的问题。

3. 选择合适的优化算法

优化算法特点适用场景
SGD计算简单,但容易震荡适用于大规模数据,适合凸优化问题
Momentum SGD增加动量,减少震荡,加速收敛适用于复杂深度神经网络
RMSProp自适应调整学习率,适用于非平稳问题适用于 RNN、强化学习等
Adam结合 Momentum 和 RMSProp,自适应学习率适用于大多数深度学习任务

4. 结论

在神经网络训练过程中,优化算法的选择对最终的模型性能有重要影响。SGD 是最基础的优化方法,而带动量的 SGD 在收敛速度和稳定性上有所提升。RMSProp 适用于非平稳目标函数,而 Adam 结合了 Momentum 和 RMSProp 的优势,成为当前最流行的优化算法之一。

不同任务可能需要不同的优化算法,通常的建议是:

  • 对于简单的凸优化问题,可以使用 SGD。
  • 对于深度神经网络,可以使用 Momentum SGD 或 Adam。
  • 对于 RNN 和强化学习问题,RMSProp 是一个不错的选择。

合理选择优化算法可以显著提升模型训练的效率和效果!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2296016.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JVM详解四】执行引擎

一、概述 Java程序运行时,JVM会加载.class字节码文件,但是字节码并不能直接运行在操作系统之上,而JVM中的执行引擎就是负责将字节码转化为对应平台的机器码让CPU运行的组件。 执行引擎是JVM核心的组成部分之一。可以把JVM架构分成三部分&am…

route 与 router 之间的差别

简述&#xff1a; router&#xff1a;主要用于处理一些动作&#xff0c; route&#xff1a;主要获得或处理一些数据&#xff0c;比如地址、参数等 例&#xff1a; videoInfo1.vue&#xff1a; <template><div class"video-info"><h3>二级组件…

SamWaf开源轻量级的网站应用防火墙(安装包),私有化部署,加密本地存储的数据,易于启动,并支持 Linux 和 Windows 64 位和 Arm64

一、SamWaf轻量级开源防火墙介绍 &#xff08;文末提供下载&#xff09; SamWaf网站防火墙是一款适用于小公司、工作室和个人网站的开源轻量级网站防火墙&#xff0c;完全私有化部署&#xff0c;数据加密且仅保存本地&#xff0c;一键启动&#xff0c;支持Linux&#xff0c;Wi…

极客说|利用 Azure AI Agent Service 创建自定义 VS Code Chat participant

作者&#xff1a;卢建晖 - 微软高级云技术布道师 「极客说」 是一档专注 AI 时代开发者分享的专栏&#xff0c;我们邀请来自微软以及技术社区专家&#xff0c;带来最前沿的技术干货与实践经验。在这里&#xff0c;您将看到深度教程、最佳实践和创新解决方案。关注「极客说」&a…

windows + visual studio 2019 使用cmake 编译构建静、动态库并调用详解

环境 windows visual studio 2019 visual studio 2019创建cmake工程 1. 静态库.lib 1.1 静态库编译生成 以下是我创建的cmake工程文件结构&#xff0c;只关注高亮文件夹部分 libout 存放编译生成的.lib文件libsrc 存放编译用的源代码和头文件CMakeLists.txt 此次编译CMak…

【kafka实战】05 Kafka消费者消费消息过程源码剖析

1. 概述 Kafka消费者&#xff08;Consumer&#xff09;是Kafka系统中负责从Kafka集群中拉取消息的客户端组件。消费者消费消息的过程涉及多个步骤&#xff0c;包括消费者组的协调、分区分配、消息拉取、消息处理等。本文将深入剖析Kafka消费者消费消息的源码&#xff0c;并结合…

[EAI-033] SFT 记忆,RL 泛化,LLM和VLM的消融研究

Paper Card 论文标题&#xff1a;SFT Memorizes, RL Generalizes: A Comparative Study of Foundation Model Post-training 论文作者&#xff1a;Tianzhe Chu, Yuexiang Zhai, Jihan Yang, Shengbang Tong, Saining Xie, Dale Schuurmans, Quoc V. Le, Sergey Levine, Yi Ma 论…

算法与数据结构(字符串相乘)

题目 思路 这道题我们可以使用竖式乘法&#xff0c;从右往左遍历每个乘数&#xff0c;将其相乘&#xff0c;并且把乘完的数记录在nums数组中&#xff0c;然后再进行进位运算&#xff0c;将同一列的数进行相加&#xff0c;进位。 解题过程 首先求出两个数组的长度&#xff0c;…

DeepSeek从入门到精通:全面掌握AI大模型的核心能力

文章目录 一、DeepSeek是什么&#xff1f;性能对齐OpenAI-o1正式版 二、Deepseek可以做什么&#xff1f;能力图谱文本生成自然语言理解与分析编程与代码相关常规绘图 三、如何使用DeepSeek&#xff1f;四、DeepSeek从入门到精通推理模型推理大模型非推理大模型 快思慢想&#x…

【异常解决】在idea中提示 hutool 提示 HttpResponse used withoud try-with-resources statement

博主介绍&#xff1a;✌全网粉丝22W&#xff0c;CSDN博客专家、Java领域优质创作者&#xff0c;掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围&#xff1a;SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…

【Uniapp-Vue3】UniCloud云数据库获取指定字段的数据

使用where方法可以获取指定的字段&#xff1a; let db uniCloud.database(); db.collection("数据表").where({字段名1:数据, 字段名2:数据}).get({getOne:true}) 如果我们不在get中添加{getOne:true}&#xff0c;在只获取到一个数据res.result.data将会是一个数组&…

信息科技伦理与道德3-2:智能决策

2.2 智能推荐 推荐算法介绍 推荐系统&#xff1a;猜你喜欢 https://blog.csdn.net/search_129_hr/article/details/120468187 推荐系统–矩阵分解 https://blog.csdn.net/search_129_hr/article/details/121598087 案例一&#xff1a;YouTube推荐算法向儿童推荐不适宜视频 …

Visual Studio 2022 中使用 Google Test

要在 Visual Studio 2022 中使用 Google Test (gtest)&#xff0c;可以按照以下步骤进行&#xff1a; 安装 Google Test&#xff1a;确保你已经安装了 Google Test。如果没有安装&#xff0c;可以通过 Visual Studio Installer 安装。在安装程序中&#xff0c;找到并选择 Googl…

WGCLOUD监控系统部署教程

官网地址&#xff1a;下载WGCLOUD安装包 - WGCLOUD官网 第一步、环境配置 #安装jdk 1、安装 EPEL 仓库&#xff1a; sudo yum install -y epel-release 2、安装 OpenJDK 11&#xff1a; sudo yum install java-11-openjdk-devel 3、如果成功&#xff0c;你可以通过运行 java …

协议-WebRTC-HLS

是什么&#xff1f; WebRTC&#xff08;Web Real-Time Communication&#xff09; 实现 Web 浏览器和移动应用程序之间通过互联网直接进行实时通信。允许点对点音频、视频和数据共享&#xff0c;而无需任何插件或其他软件。WebRTC 广泛用于构建视频会议、语音通话、直播、在线游…

MySQL系列之数据类型(String)

导览 前言一、字符串类型知多少 1. 类型说明2. 字符和字节的转换 二、字符串类型的异同 1. CHAR & VARCHAR2. BINARY & VARBINARY3. BLOB & TEXT4. ENUM & SET 结语精彩回放 前言 MySQL数据类型第三弹闪亮登场&#xff0c;欢迎关注O。 本篇博主开始谈谈MySQ…

【C++高并发服务器WebServer】-15:poll、epoll详解及实现

本文目录 一、poll二、epoll2.1 相对poll和select的优点2.2 epoll的api2.3 epoll的demo实现2.5 epoll的工作模式 一、poll poll是对select的一个改进&#xff0c;我们先来看看select的缺点。 我们来看看poll的实现。 struct pollfd {int fd; /* 委托内核检测的文件描述符 */s…

git提交到GitHub问题汇总

1.main->master git默认主分支是maser&#xff0c;如果是按照这个分支名push&#xff0c;GitHub会出现两个branch&#xff0c;与预期不符 解决方案&#xff1a;更改原始主分支名为main git config --global init.defaultBranch main2.git&#xff1a;OpenSSL SSL_read: SS…

CNN-GRU卷积神经网络门控循环单元多变量多步预测,光伏功率预测(Matlab完整源码和数据)

代码地址&#xff1a;CNN-GRU卷积神经网络门控循环单元多变量多步预测&#xff0c;光伏功率预测&#xff08;Matlab完整源码和数据) CNN-GRU卷积神经网络门控循环单元多变量多步预测&#xff0c;光伏功率预测 一、引言 1.1、研究背景和意义 随着全球能源危机和环境问题的日…

编译原理面试问答

编译原理面试拷打 1.编译原理的基本概念 编译原理是研究如何将高级程序语言转换为计算机可执行代码的理论与技术&#xff0c;其核心目标是实现高效、正确的代码翻译。 **编译器&#xff1a;**将源代码转化为目标代码&#xff08;机器码、字节码等&#xff09;。一次翻译整个程…