程序员学长 | 超强!六大优化算法全总结

news2024/9/24 9:19:56

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:超强!六大优化算法全总结

今天我们将详细讨论一下用于训练神经网络(深度学习模型)时使用的一些常见优化技术(优化器)。

主要包括:

  • Gradient Descent Algorithm

  • Momentum

  • Adagrad

  • RmsProp

  • Adadelta

  • Adam

图片

梯度下降算法

梯度下降算法是一种基于迭代的优化算法,用于寻找函数(成本函数)的最小值。

在深度学习模型中常常用来在反向传播过程中更新神经网络的权重。梯度下降的目标是找到使成本函数值最小的参数值。

梯度下降的组成部分

  • 成本函数:

    这是需要最小化的函数。在机器学习中,成本函数通常度量模型预测值与实际值之间的差异。常用的成本函数有均方误差 (MSE)、均方根误差 (RMSE)、平均绝对误差 (MAE) 等。

  • epochs :

    它是一个超参数,表示要运行的迭代次数,即为更新模型参数而计算梯度的次数

  • 学习率

    它是一个超参数,指的是更新步长的大小。如果太大,算法会发散而不是收敛。如果太小,算法需要大量迭代才能收敛,并且可能会遇到梯度消失问题。

执行步骤

梯度下降算法的不同变体

梯度下降算法有几种变体,包括:

批量梯度下降

批量梯度下降使用整个数据集计算梯度。

该方法提供了稳定的收敛性和一致的误差梯度,但对于大型数据集来说计算成本高昂且速度缓慢。

随机梯度下降 (SGD)

SGD 使用单个随机选择的数据点来估计梯度。

虽然它可以更快并且能够逃脱局部最小值,但由于其固有的随机性,它的收敛模式更加不稳定,可能导致成本函数的振荡。

小批量梯度下降

小批量梯度下降在上述两种方法之间取得了平衡。

它使用数据集的子集(或“小批量”)计算梯度。该方法利用矩阵运算的计算优势来加速收敛,并在批量梯度下降的稳定性和 SGD 的速度之间提供折衷方案。

Momentum

Momentum 是一种用于加速梯度下降算法的技术。它在某些方面类似于物理学中的动量概念,因此得名。Momentum 帮助算法在正确的方向上移动,同时减少“摆动”。

在数学上,Momentum 方法修改了标准梯度下降算法中的参数更新规则。在标准梯度下降中,每个参数在每次迭代中根据梯度的相反方向更新。而在使用 Momentum 的情况下,会考虑过去梯度的一部分,这样参数的更新不仅取决于当前梯度,还取决于之前梯度的累积。这可以通过一个动量项来实现,该动量项是过去梯度的加权平均。

具体来说,Momentum 算法在更新参数时使用以下公式。

Momentum 的主要好处包括

  • 加速学习过程,特别是在梯度曲面的方向一致时。

  • 减少震荡,有助于更平稳地收敛到最小值。

  • 在某些情况下,有助于逃脱局部最小值。

Adagrad

Adagrad 是一种自适应的学习率优化算法,专为处理稀疏数据而设计。在传统的梯度下降算法中,全局学习率应用于所有的参数更新,而 Adagrad 允许每个参数有不同的学习率,以便自动调整学习率,这对于处理不同频率的特征是非常有用的。

Adagrad 算法的关键点在于累积过去所有梯度的平方和,用这个累积值来调节每个参数的学习率。这意味着对于出现频率较高的特征,它们的累积梯度会很大,因此学习率会降低;对于出现频率较低的特征,它们的累积梯度小,学习率则相对较高。

Adagrad的优点包括

  • 不需要手动调整学习率,算法会自动进行调整。

  • 能够很好地处理稀疏数据。

  • 对于不同频率的参数可以进行有效的学习。

然而,Adagrad 也有一些局限性

  • 学习率是单调递减的,随着时间的推移,学习率可能会过早和过度地减小到0,这会导致训练过程提前结束。

  • 累积平方梯度在算法运行过程中不断累加,可能会导致分母过大,使得学习率过小。

RmsProp

RMSprop 是一种自适应学习率优化算法,被广泛用于训练各种类型的神经网络。RMSprop 旨在解决 AdaGrad 算法在训练深度神经网络时面临的一些问题,特别是学习率快速下降的问题。

关键概念

  1. 平方梯度累积: 在每次迭代中,RMSprop 首先计算梯度的平方,并将其累积到一个指数衰减的平均中。这个平均可以看作是最近梯度大小的移动平均。这有助于平衡梯度更新,特别是当梯度在不同方向的大小差异很大时。

  2. 自适应学习率: RMSprop 通过这种方式调整学习率,使其对于每个参数都是不同的。具体来说,它通过梯度的移动平均值来调整每个参数的更新。如果一个参数的梯度持续大,那么它的学习率会减少;反之,则增加。

  3. 防止学习率过快减小: 与 AdaGrad 相比,RMSprop 不会让学习率过快减小。这是因为它不是累积所有过去梯度的平方,而是仅仅关注最近的梯度。这使得 RMSprop 在长期训练中更加有效,尤其是对于非凸优化问题。

在数学上,RMSprop 的更新规则如下:

RMSprop 的这些特性使它在深度学习特别是循环神经网络(RNN)的训练中非常有效。它可以自动调整学习率,从而在训练的不同阶段有效地优化模型。

Adadelta

Adadelta 是一种基于梯度的优化算法,专门用于深度学习网络的训练。它是 AdaGrad 算法的扩展,旨在解决 AdaGrad 在训练过程中学习率单调递减的问题。Adadelta 的主要特点是不需要一个全局学习率,而是根据参数的更新历史来调整学习率。

在数学上,Adadelta 的更新规则如下

Adam

Adam 是一种用于机器学习和人工智能领域的强大优化算法,它结合了 Momentum 和 RMSprop 两种优化算法的特点。该算法基于自适应矩估计,利用梯度的一阶矩和二阶矩在训练过程中动态调整学习率。这有助于确保算法快速收敛并避免陷入局部最优。

Adam 的一个主要优点是它需要很少的内存并且只需要一阶梯度,这使其成为随机优化的有效方法。它也非常适合解决大量数据或参数的问题。

算法概念

下面是 Adam 算法中涉及的关键公式和步骤。

Adam 的优势在于每个参数的学习率是自适应调整的,这使得它在实践中对超参数的选择相对不敏感,并且适用于大多数非凸优化问题。此外,Adam 通常表现出比其他自适应学习率方法更快的收敛性能。

案例分享

import keras # Import the Keras library
from keras.datasets import mnist # Load the MNIST dataset
from keras.models import Sequential # Initialize a sequential model
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
import numpy as np
# Load the MNIST dataset from Keras
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Print the shape of the training and test data
print(x_train.shape, y_train.shape)

# Reshape the training and test data to 4 dimensions
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)

x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

# Define the input shape
input_shape = (28, 28, 1)

# Convert the labels to categorical format
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Convert the pixel values to floats between 0 and 1
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

# Normalize the pixel values by dividing them by 255
x_train /= 255
x_test /= 255

# Define the batch size and number of classes
batch_size = 60
num_classes = 10

# Define the number of epochs to train the model for
epochs = 10


def build_model(optimizer):
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

optimizers=['SGD','Adagrad','RMSprop','Adadelta','Adam']
histories ={opt:build_model(opt).fit(x_train,y_train,batch_size=batch_size, 
epochs=epochs,verbose=1,validation_data=(x_test,y_test)) for opt in optimizers}

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1996407.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FlexBV电路查看软件

FlexBV - Macbook, iPhone, PC/Laptop & Electronics BoardViewer with PDF Cross Referencing 免费。 支持tvw,cad格式。 支持Windows,Linux,Mac。 而且我发现cad格式是文本的!意味着可以自由编辑!

springboot窝窝酒店管理系统-计算机毕业设计源码91798

摘 要 随着时代的进步与发展,互联网技术的应用也变得日益广泛。窝窝酒店管理系统在当今社会体系中扮演了一个非常重要的角色,它能大大地提高效率并减少了资源上的浪费。本文首先介绍了窝窝酒店管理系统的优势以及重要性;然后描述了这个系统的…

学习鸿蒙-构建私有仓储

1.选择 鸿蒙提供ohpm-repo工具用于构建本地私有仓储 ohpm-repo下载 2.环境配置 安装node,ohpm-repo 支持 node.js 18.x 及以上版本 node最新版本下载 3.配置文件及运行 1.解压 ohpm-repo 私仓工具包 2.进入 ohpm-repo 解压目录的 conf 目录内,打开 c…

PyTorch深度学习框架

最近放假在超星总部河北燕郊园区实习,本来是搞前后端开发岗位的,然后带我的副总老大哥比较关照我,了解我的情况后得知我大三选的方向是大数据,于是建议我学学python、Hadoop,Hadoop我看了一下内容比较多,而…

从概念到落地:全面解析DApp项目开发的核心要素与未来趋势

随着区块链技术的迅猛发展,去中心化应用程序(DApp)逐渐成为Web3时代的重要组成部分。DApp通过智能合约和分布式账本技术,提供了无需信任中介的解决方案,这种去中心化的特性使其在金融、游戏、社交等多个领域得到了广泛…

金融行业如何高效管理新媒体矩阵

金融行业作为经济体系的重要一环,受到社会多方关注和监管。 前有“985大一投行实习日常”的短视频引发大众热议,后有某机构女员工自爆事件牵扯出多家金融机构,将金融行业一度推到了舆论的风口浪尖。 这两件事的接连出现,也把金融新…

飞天发布时刻:大数据AI平台产品升级发布

7月24日,阿里云飞天发布时刻产品发布会围绕阿里云大数据AI平台的新能力和新产品进行详细介绍。人工智能平台PAI、云原生大数据计算服务MaxCompute、开源大数据平台E-MapReduce、实时数仓Hologres、阿里云Elasticsearch、向量检索Milvus等产品均带来了相关发布的深度…

C++必修:STL之forward_list与list的使用

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:C学习 贝蒂的主页:Betty’s blog 1. forward_list与list forward_list 是 C 11 引入的一种容器,它是一…

LQR横向控制及融合PID纵向控制C++实现

目录 简介一、现代控制理论1.1 经典控制理论和现代控制理论的区别1.2 全状态反馈控制系统 二、LQR控制器2.1 连续时间2.1.1 Q、R矩阵的选取2.1.2 推导过程2.1.3 连续时间下的LQR算法步骤 2.2 离散时间2.2.1 连续LQR和离散LQR的区别2.2.2离散时间下的LQR算法步骤 三、LQR实现自动…

AI大模型之旅--安装向量库milvus

milvus,向量索引库 1.milvus部署 milvus的官方文档中看到最新版本的部署方式Install Milvus Standalone with Docker Compose curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh &#xf…

stm32f103c8t6与TB6612FNG解耦测试

stm32f103c8t6与TB6612FNG解耦测试 本文操作方式: 忽略底层,只做上层, 所以前面全部照搬步骤,重在调试 文章目录 stm32f103c8t6与TB6612FNG解耦测试本文操作方式:创建基本工程(1)跳转此链接,创建(2)创建电机驱动文件夹(3)PWM原理(4)电机转动控制 oled调试和key调试(5)OLED转速…

C++:奇异递归模板模式(CRTP模式)

奇异递归模板模式 文章目录 奇异递归模板模式理论说明CRTP模式的功能静态多态强制静态接口编译时多态优化解释 理论说明 奇异递归模板模式(Curiously Recurring Template Pattern, CRTP) 是一种设计模式,其原理很简单: 继承者将自…

工业三防平板赋能自动化产线打造工厂智慧管理

随着工业4.0时代的到来,智能制造成为了众多企业转型升级的必然选择。而MES系统作为智能制造的核心环节,能够有效地整合生产数据,提升生产效率,并实现工厂运营的数字化管理。然而,传统的MES系统大多依赖于PC端操作&…

关于vs调试的一些基本技巧方法,建议新手学习

文章目录 1.Debug 和 Release2.VS的调试快捷键3.对程序的监视和内存观察3.1监视3.2内存 4.编程常见错误归类4.1编译型错误4.2链接型错误4.3运行时错误 1.Debug 和 Release 在我们使用的编译器 vs 中,这个位置有两个选项,分别为Debug和Release&#xff0c…

开源应用:AI监测如何成为社会安全的智能盾牌

社会背景 随着社会的快速发展,社会安全管理正站在一个新时代的门槛上。社会对安全管理的需求不断增长,传统的安全措施已难以满足现代社会的需求。AI技术以其独特的数据处理和模式识别能力,正在成为我们社会安全的智能盾牌。 AI大模型识别功能…

【牛客】2024暑期牛客多校6 补题记录

文章目录 A - Cake(树上dp)B - Cake 2(暴力)D - Puzzle: Wagiri(tarjan)F - Challenge NPC 2(构造)H - Genshin Impacts Fault(签到)I - Intersecting Interv…

利用扩散模型DDPM生成高分辨率图像|(一)DDPM模型构建

利用扩散模型DDPM生成高分辨率图像(生成高保真图像项目实践) Mindspore框架利用扩散模型DDPM生成高分辨率图像|(一)关于denoising diffusion probabilistic model (DDPM)模型 Mindspore框架利用扩散模型DD…

数字音频工作站(DAW)FL Studio 24.1.1.4239中文破解版

FL Studio 24.1.1.4239中文破解版是一款功能强大的数字音频工作站(DAW),它广泛应用于音乐创作和音乐制作领域。FL Studio是由比利时软件公司Image-Line开发的音乐制作软件,它拥有丰富的音效、合成器、采样器、鼓机等工具。FL Stud…

stm32cubemx+ADC的多通道轮询数据采集和DMA数据采集实现,亲测可用

ADC是单片机的重要组成,也是存在一定的难点。 一、多通道轮询数据采集。 1、配置时钟,用的无源晶振。 2、SW烧写方式 添加USART 3、ADC选择了四个通道 其中两个是采集电压,另外两个是采集芯片内部温度和参考电压。 4、配置采集模式 这里是…

萌啦数据官网丨萌啦ozon数据分析工具官网

在当今这个数据驱动的时代,电子商务的蓬勃发展离不开精准的数据分析与洞察。对于在OZON平台上耕耘的商家而言,掌握市场趋势、优化产品布局、提升运营效率成为了赢得竞争的关键。正是在这样的背景下,萌啦数据官网应运而生,作为一款…