神经网络中的优化方法

news2024/10/6 12:23:42

 一、引入

在传统的梯度下降优化算法中,如果碰到平缓区域,梯度值较小,参数优化变慢 ,遇到鞍点(是指在某些方向上梯度为零而在其他方向上梯度非零的点。),梯度为 0,参数无法优化,碰到局部最小值。实践中使用的小批量梯度下降法(mini-batch SGD)因其梯度估计的噪声性质,有时能够使模型脱离这些点。

💥为了克服这些困难,研究者们提出了多种改进策略,出现了一些对梯度下降算法的优化方法:Momentum、AdaGrad、RMSprop、Adam 等。

二、指数加权平均

我们最常见的算数平均指的是将所有数加起来除以数的个数,每个数的权重是相同的。加权平均指的是给每个数赋予不同的权重求得平均数。指数加权平均是一种数据处理方式,它通过对历史数据应用不同的权重来减少过去数据的影响,并强调近期数据的重要性。

[ vt = beta * v{t-1} + (1 - beta) * theta_t] 

比如:明天气温怎么样,和昨天气温有很大关系,而和一周前的气温关系就小一些。 

vt​ 是第 𝑡 天的平均温度值,𝜃𝑡​ 是第 𝑡t 天的实际观察值,而 𝛽 是一个可调节的超参数(通常 0<𝛽<1)。这个公式表明,当前的平均值是前一天平均值与当天实际值的加权平均。

  • β 调节权重系数,该值越大平均数越平缓。

 我们接下来通过一段代码来看下结果,随机产生进 30 天的气温数据:

import torch
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'


# 实际平均温度
def test01():

    # 固定随机数种子
    torch.manual_seed(0)

    # 产生30天的随机温度
    temperature = torch.randn(size=[30]) * 10
    print(temperature)

    # 绘制平均温度
    days = torch.arange(1, 31, 1)
    plt.plot(days, temperature, color='r')
    plt.scatter(days, temperature)
    plt.show()

# 指数加权平均温度
def test02(beta=0.8):

    # 固定随机数种子
    torch.manual_seed(0)
    # torch.initial_seed()
    # 产生30天的随机温度
    temperature = torch.randn(size=[30,]) * 10
    print(temperature)

    exp_weight_avg = []
    for idx, temp in enumerate(temperature, 1):

        # 第一个元素的的 EWA 值等于自身
        if idx == 1:
            exp_weight_avg.append(temp)
            continue

        # 第二个元素的 EWA 值等于上一个 EWA 乘以 β + 当前气氛乘以 (1-β)
        new_temp = exp_weight_avg[idx - 2] * beta + (1 - beta) * temp
        exp_weight_avg.append(new_temp)


    days = torch.arange(1, 31, 1)
    plt.plot(days, exp_weight_avg, color='r')
    plt.scatter(days, temperature)
    plt.show()


if __name__ == '__main__':

    test01()
    test02()

这是test01执行后产生的实际值: 

 我们再看一下指数平均后的值:

🔎指数加权平均绘制出的气氛变化曲线更加平缓; β 的值越大,则绘制出的折线越加平缓; 

三、Momentum

我们通过对指数加权平均的知识来研究Momentum优化方法💢

  • 鞍点:梯度为零的点,损失函数的梯度在所有方向上都接近或等于零。由于梯度为零,标准梯度下降法在此将无法继续优化参数。
  • 平缓区域:这些区域的梯度值较小,导致参数更新缓慢。虽然这意味着算法接近极小值点,但收敛速度会变得非常慢。

当梯度下降碰到 “峡谷” 、”平缓”、”鞍点” 区域时,参数更新速度变慢,Momentum 通过指数加权平均法,累计历史梯度值,进行参数更新,越近的梯度值对当前参数更新的重要性越大。

Momentum优化方法是对传统梯度下降法的一种改进:

Momentum优化算法的核心思想是在一定程度上积累之前的梯度信息,以此来调整当前的梯度更新方向。这种方法可以在一定程度上减少训练过程中的摆动现象,使得学习过程更加平滑,从而可能使用较大的学习率而不必担心偏离最小值太远。 

梯度计算公式:Dt = β * St-1 + (1- β) * Dt

在面对梯度消失、鞍点等问题时,Momentum能够改善SGD的表现,帮助模型跳出局部最小值或平坦区域;如果当处于鞍点位置时,由于当前的梯度为 0,参数无法更新。但是 Momentum梯度下降算法已经在先前积累了一些梯度值,很有可能使得跨过鞍点。

由于 mini-batch 普通的梯度下降算法,每次选取少数的样本梯度确定前进方向,可能会出现震荡,使得训练时间变长。Momentum 使用移动加权平均,平滑了梯度的变化,使得前进方向更加平缓,有利于加快训练过程。一定程度上有利于降低 “峡谷” 问题的影响。

Momentum方法的实现案例: 

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
model = nn.Linear(10, 1)

# 定义损失函数
criterion = nn.MSELoss()

# 定义优化器,并设置momentum参数为0.9
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 模拟数据
inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)

# 训练模型
for epoch in range(10):
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 10, loss.item()))

Momentum 算法可以理解为是对梯度值的一种调整,我们知道梯度下降算法中还有一个很重要的学习率,Momentum 并没有学习率进行优化。 

四、AdaGrad

💥Momentum 算法是对梯度值调整,使得模型可以更好的进行参数更新,AdaGrad算法则是对学习率,即每次更新走的步长,进行调整更新~

AdaGrad 通过对不同的参数分量使用不同的学习率,AdaGrad 的学习率总体会逐渐减小,这是因为 AdaGrad算法认为:在起初时,我们距离最优目标仍较远,可以使用较大的学习率,加快训练速度,随着迭代次数的增加,学习率逐渐下降。

🗨️计算步骤如下:

  1. 初始化学习率 α、初始化参数 θ、小常数 σ = 1e-6
  2. 初始化梯度累积变量 s = 0
  3. 从训练集中采样 m 个样本的小批量,计算梯度 g
  4. 累积平方梯度 s = s + g ⊙ g,⊙ 表示各个分量相乘

 

AdaGrad通过这种方式实现了对每个参数的个性化学习率调整,使得在参数空间较平缓的方向上可以取得更大的进步,而在陡峭的方向上则能够变得更加平缓,从而加快了训练速度( 如果累计梯度值s大的话,学习率就会小一点)        

使用Python实现AdaGrad算法的API代码:

import torch

class AdaGrad:
    def __init__(self, params, lr=0.01, epsilon=1e-8):
        self.params = list(params)
        self.lr = lr
        self.epsilon = epsilon
        self.cache = [torch.zeros_like(param) for param in self.params]

    def step(self):
        for i, param in enumerate(self.params):
            self.cache[i] += param.grad.data ** 2
            param.data -= self.lr * param.grad.data / (torch.sqrt(self.cache[i]) + self.epsilon)

💥AdaGrad 缺点是可能会使得学习率过早、过量的降低,导致模型训练后期学习率太小,较难找到最优解。

五、RMSProp

RMSProp(Root Mean Square Prop)是一种常用的自适应学习率优化算法,是对 AdaGrad 的优化,最主要的不同是,RMSProp使用指数移动加权平均梯度替换历史梯度的平方和。

  1. 初始化学习率 α、初始化参数 θ、小常数 σ = 1e-6
  2. 初始化参数 θ
  3. 初始化梯度累计变量 s
  4. 从训练集中采样 m 个样本的小批量,计算梯度 g
  5. 使用指数移动平均累积历史梯度

 

RMSProp 与 AdaGrad 最大的区别是对梯度的累积方式不同,对于每个梯度分量仍然使用不同的学习率。RMSProp 通过引入衰减系数 β,控制历史梯度对历史梯度信息获取的多少,使得学习率衰减更加合理一些。 

import numpy as np

def rmsprop(params, grads, learning_rate=0.01, decay_rate=0.9, epsilon=1e-8):
    cache = {}
    for key in params.keys():
        cache[key] = np.zeros_like(params[key])

    for key in params.keys():
        cache[key] = decay_rate * cache[key] + (1 - decay_rate) * grads[key] ** 2
        params[key] -= learning_rate * grads[key] / (np.sqrt(cache[key]) + epsilon)

    return params

params是一个字典,包含了模型的参数;grads是一个字典,包含了参数对应的梯度;learning_rate是学习率;decay_rate是衰减系数;epsilon是一个很小的正数,用于防止除以零。

六、Adam 

💯Adam 结合了两种优化算法的优点:RMSProp(Root Mean Square Prop)和Momentum。Adam在深度学习中被广泛使用,因为它能够自动调整学习率,特别适合处理大规模数据集和复杂模型。

Adam的关键特点:

  1. 一阶矩估计(First Moment):梯度的均值,类似于Momentum中的velocity term,用于指示梯度在何时变得非常剧烈。

  2. 二阶矩估计(Second Moment):梯度的未中心化方差,类似于RMSProp中的平方梯度的指数移动平均值,用于指示梯度变化的范围。

我们在平时使用中会经常用到次方法,在PyTorch中就是optim.Adam方法,不再是optim.SGD方法:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的线性模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)  # 假设输入维度是10,输出维度是1

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = SimpleModel()

# 定义损失函数
criterion = nn.MSELoss()  # 均方误差损失函数

# 创建优化器,设定学习率为0.001,参数beta1默认为0.9,beta2默认为0.999
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 假设有一个输入数据x和对应的目标y
x = torch.randn(32, 10)  # 批量大小为32,每个样本的输入维度是10
y = torch.randn(32, 1)   # 批量大小为32,每个样本的输出维度是1

# 前向传播
outputs = model(x)

# 计算损失
loss = criterion(outputs, y)

# 清空之前所有的梯度
optimizer.zero_grad()

# 反向传播
loss.backward()

# 更新模型参数
optimizer.step()

# 打印损失值
print("Loss: ", loss.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1640259.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Springboot的滑雪场管理系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的滑雪场管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&a…

【linuxC语言】守护进程

文章目录 前言一、守护进程的介绍二、开启守护进程总结 前言 在Linux系统中&#xff0c;守护进程是在后台运行的进程&#xff0c;通常以服务的形式提供某种功能&#xff0c;如网络服务、系统监控等。守护进程的特点是在启动时脱离终端并且在后台运行&#xff0c;它们通常不与用…

如何使用免费软件从Mac恢复音频文件?

要从Mac中删除任何文件&#xff0c;背后是有原因的。大多数Mac用户都希望增加Mac中的空间&#xff0c;这就是为什么他们更喜欢从驱动器中删除文件以便出现一些空间的原因。一些Mac用户错误地删除了该文件&#xff0c;无法识别这是一个重要文件。例如&#xff0c;他们错误地从Ma…

I/O体系结构和设备驱动程序

I/O体系结构 为了确保计算机能够正常工作&#xff0c;必须提供数据通路&#xff0c;让信息在连接到个人计算机的CPU、RAM和I/O设备之间流动。这些数据通路总称为总线&#xff0c;担当计算机内部主通信通道的作用。 所有计算机都拥有一条系统总线&#xff0c;它连接大部分内部…

ps科研常用操作,制作模式图 扣取想要的内容元素photoshop

复制想要copy的图片&#xff0c; 打开ps---file-----new &#xff0c;ctrolv粘贴图片进入ps 选择魔棒工具&#xff0c;点击想要去除的白色区域 然后&#xff0c;cotrol shift i&#xff0c;反选&#xff0c; ctrol shiftj复制&#xff0c;复制成功之后&#xff0c;一定要改…

【Java EE】Mybatis之XML详解

文章目录 &#x1f38d;配置数据库连接和MyBatis&#x1f340;写持久层代码&#x1f338;添加mapper接口&#x1f338;添加UserInfoXMLMapper.xml&#x1f338;单元测试 &#x1f332;CRUD&#x1f338;增(Insert)&#x1f338;删(Delete)&#x1f338;改(Update)&#x1f338;…

CMake:静态库链接其他动态库或静态库(九)

1、项目结构 对于下面这样一个项目 把calc模块做成静态或者动态库把sort模块做成静态库然后再sort模块中的*.cpp调用calc模块生成的库即可&#xff08;这样就制作了一个静态库引用动态或者静态库&#xff09;test模块用于测试sort模块中的内容 . ├── calc │ ├── ad…

ThreeJS:本地部署官网文档与案例

部署方式 部署之前请确保已经配置好node.js环境。 1. 下载ThreeJS源码 ThreeJS的GitHub地址&#xff1a;GitHub - mrdoob/three.js: JavaScript 3D Library.&#xff0c;可以简单查看ThreeJS当前版本&#xff1a;r164&#xff0c; 我们可以选择对应的版本&#xff08;此处为r1…

【跟马少平老师学AI】-【神经网络是怎么实现的】(七-2)word2vec模型

一句话归纳&#xff1a; 1&#xff09;CBOW模型&#xff1a; 2c个向量是相加&#xff0c;而不是拼接。 2&#xff09;CBOW模型中的哈夫曼树&#xff1a; 从root开始&#xff0c;向左为1&#xff0c;向右为0。叶子结点对应词有中的一个词。每个词对应唯一的编码。词编码不等长。…

计算机等级考试2级(Python)知识点整理

计算机等级考试2级&#xff08;Python&#xff09;知识点整理 1.基础知识点&#xff08;记忆、理解&#xff09; 第1讲Python概述 01. 源代码 02. 目标代码 03. 编译和解释 04. 程序的基本编写方法 第2讲 Python语言基础&#xff08;一&#xff09; 01. 用缩进表示代码…

深入理解网络原理1

文章目录 前言一、网络初识1.1 IP地址1.2 端口号1.3 协议1.4 五元组1.5 协议分层 二、TCP/IP五层协议三、封装和分用四、客户端vs服务端4.1 交互模式4.2 常见的客户端服务端模型4.3 TCP和UDP差别 前言 随着时代的发展&#xff0c;越来越需要计算机之间互相通信&#xff0c;共享…

前端基础学习html(1)

1.标题标签.h1,h2...h6 2.段落标签p 换行标签br 3.加粗strong(b) /倾斜em(i) /删除 del(s) /下划线ins(u) 4.盒子&#xff1a;div //一行一个 span//一行多个 5.img :src alt title width height border 图片src引用&#xff1a;相对路径 上级/同级/中级 绝对路径&#xff…

地图产业的困局与破局:高精地图“上车”难 轻量化渐成主流方案 | 最新快讯

《科创板日报》5月3日讯&#xff08;编辑 邱思雨&#xff09; 近期&#xff0c;特斯拉与百度的“绯闻”成为智驾、地图行业的焦点。 有媒体消息称&#xff0c;特斯拉将与百度地图独家深度定制车道级高辅地图。《科创板日报》记者也获悉&#xff0c;自5月1日起&#xff0c;百度…

【C语言实现贪吃蛇】(内含源码)

前言&#xff1a;首先在实现贪吃蛇小游戏之前&#xff0c;我们要先了解Win32 API的有关知识 1.Win32 API Windows这个多作业系统除了协调应用程序的执行、分配内存、管理资源之外&#xff0c;它同时也是一个很大的服务中心&#xff0c;调佣这个中心的各种服务&#xff08;每一…

前端面试和一些建议

最近公司在招前端&#xff0c;我有跟着一起参与面试。我们主要负责面试的人&#xff0c;不会问那些什么闭包&#xff0c;原型链&#xff0c;他觉得那些东西在我们日常开发中用不到&#xff0c;问的基本都是一些工作中的问题。这些问题不是每次都问&#xff0c;但也就问这些了。…

基于Unity+Vue通信交互的WebGL项目实践

unity-webgl 是无法直接向vue项目进行通信的&#xff0c;需要一个中间者 jslib 文件 jslib当作中间者&#xff0c;unity与它通信&#xff0c;前端也与它通信&#xff0c;在此基础上三者之间进行了通信对接 看过很多例子&#xff1a;介绍的都不是很详细&#xff0c;不如自己写&…

(39)4.29数据结构(栈,队列和数组)栈

#include<stdlib.h> #include<stdio.h> #define MaxSize 10 #define Elemtype int 1.栈的基本概念 2.栈的基本操作 typedef struct { Elemtype data[MaxSize]; int top; }Sqstack;//初始化栈 void InitStack(Sqstack& S) { S.top -1; //初始化…

4G小车的公网直播推流

一直想做一个小车, 可以通过4G推流, 没想到现在很多云服务提供商, SRS云服务器已经可以一键搭建了. 硬件方面, 就是一个1126驮着一个3516, 1126负责4G连接, 转流到Intenet, 3516负责vi_venc_rtsp 思路如下, 我的1126的摄像头一直没能横过来, 所以就不用1126的摄像头了, 先用35…

Redis-概述-安装-基本知识

Redis概述 Redis是什么 Redis&#xff08;Remote Dictionary Server 远程字段服务&#xff09;是一个开源的使用ANSI C语言编写、支持网 络、内存亦可持久化的key-value数据库&#xff0c;并提供多种语言的API。Redis是一个key-value存储系统&#xff0c;它支持存储的value类型…

好用的AI工具推荐与案例分析

你用过最好用的AI工具有哪些&#xff1f; 简介&#xff1a;探讨人们在使用AI工具时&#xff0c;最喜欢的和认为最好用的工具是哪些&#xff0c;展示AI技术的实际应用和影响。 方向一&#xff1a;常用AI工具 在选择常用AI工具时&#xff0c;可以根据不同的应用场景和需求来挑选…