AI学习指南深度学习篇-Adam的Python实践

news2024/9/20 16:22:03

AI学习指南深度学习篇-Adam的Python实践

在深度学习领域,优化算法是影响模型性能的关键因素之一。Adam(Adaptive Moment Estimation)是一种广泛使用的优化算法,因其在多种问题上均表现优异而被广泛使用。本文将深入探讨Adam优化器,并提供详细的代码示例,展示如何在Python的深度学习库(如TensorFlow和PyTorch)中实现Adam,进行模型训练以及调参过程。

引言

优化算法的选择会影响深度学习模型的收敛速度和最终性能。Adam算法不仅结合了动量(Momentum)的优点,还引入了自适应学习率,这使得其在许多任务中表现良好。本文将通过实际代码示例介绍Adam的实现和调参过程,让读者能够在自己的项目中有效应用这一算法。

Adam优化器概述

2.1 公式推导

Adam优化器的核心思想是计算梯度的动量以及梯度的平方动量,并利用这两个动量来调整学习率。Adam的更新公式如下:

  1. 初始化参数

    • ( m t = 0 ) ( m_t = 0 ) (mt=0)(一阶矩估计)
    • ( v t = 0 ) ( v_t = 0 ) (vt=0)(二阶矩估计)
    • ( t = 0 ) ( t = 0 ) (t=0)(时间步长)
    • ( β 1 , β 2 ) ( \beta_1, \beta_2 ) (β1,β2)(通常取值为0.9,0.999)
    • ( ϵ ) ( \epsilon ) (ϵ)(通常取小值以避免除零错误)
  2. 参数更新
    [ t = t + 1 ] [ t = t + 1 ] [t=t+1]
    [ m t = β 1 ⋅ m t − 1 + ( 1 − β 1 ) ⋅ g t ] [ m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t ] [mt=β1mt1+(1β1)gt]
    [ v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ g t 2 ] [ v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 ] [vt=β2vt1+(1β2)gt2]
    [ m ^ t = m t 1 − β 1 t ] [ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} ] [m^t=1β1tmt]
    [ v ^ t = v t 1 − β 2 t ] [ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} ] [v^t=1β2tvt]
    [ θ t = θ t − 1 − α v ^ t + ϵ ⋅ m ^ t ] [ \theta_{t} = \theta_{t-1} - \frac{\alpha}{\hat{v}_t + \epsilon} \cdot \hat{m}_t ] [θt=θt1v^t+ϵαm^t]

2.2 参数说明

  • 学习率 ( ( α ) ) ((\alpha)) ((α)):控制每次更新的步幅,通常初始值设为0.001。
  • ( β 1 ) (\beta_1) (β1) ( β 2 ) (\beta_2) (β2):分别控制一阶矩和二阶矩的衰减率。
  • ( ϵ ) (\epsilon) (ϵ):通常设为 ( 1 0 − 8 ) (10^{-8}) (108),避免在计算时出现除零错误。

在TensorFlow中使用Adam

3.1 环境准备

确保你的计算环境中安装了TensorFlow和其他必要的库:

pip install tensorflow numpy matplotlib

3.2 数据加载

我们将使用Keras提供的MNIST手写数字数据集作为示例:

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 数据预处理
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

3.3 构建模型

我们将定义一个简单的神经网络模型:

def create_model():
    model = models.Sequential()
    model.add(layers.Flatten(input_shape=(28, 28)))
    model.add(layers.Dense(128, activation="relu"))
    model.add(layers.Dropout(0.2))
    model.add(layers.Dense(10, activation="softmax"))
    return model

3.4 训练模型

使用Adam优化器训练模型:

model = create_model()

# 编译模型
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# 训练模型
history = model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.2)

3.5 调整超参数

可以通过以下方式调整超参数,比如修改学习率或尝试不同的批大小:

from tensorflow.keras.optimizers import Adam

# 创建自定义Adam优化器
adam = Adam(learning_rate=0.001)

# 重新编译模型
model.compile(optimizer=adam, 
              loss="categorical_crossentropy", 
              metrics=["accuracy"])

# 重新训练模型
history = model.fit(x_train, y_train, epochs=10, batch_size=64, validation_split=0.2)

在PyTorch中使用Adam

4.1 环境准备

确保你的计算环境中安装了PyTorch和其他必要的库:

pip install torch torchvision numpy matplotlib

4.2 数据加载

与TensorFlow类似,我们将使用同样的数据集:

import torch
from torchvision import datasets, transforms
from torch import nn, optim

# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 加载MNIST数据集
trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

testset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

4.3 构建模型

PyTorch模型构建如下:

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.shape[0], -1)  # 展平操作
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = SimpleNN()

4.4 训练模型

使用Adam优化器训练模型的示例如下:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 10
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        optimizer.zero_grad()  # 清空梯度
        output = model(images)  # 前向传播
        loss = criterion(output, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {running_loss/len(trainloader)}")

4.5 调整超参数

在PyTorch中,你也可以像在TensorFlow中那样调整超参数,下面是修改学习率的例子:

# 创建自定义Adam优化器
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# 重新训练模型
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {running_loss/len(trainloader)}")

结论

Adam优化器因其良好的自适应性和快速的收敛能力,成为深度学习中最流行的优化算法之一。在TensorFlow和PyTorch等深度学习框架中,Adam均被用户广泛应用。本文详细介绍了在这两种框架中使用Adam优化器进行模型训练的完整流程,并展示了如何在训练过程中灵活调整超参数。希望这篇文章能帮助你更好地理解和应用Adam优化器。尽管TensorFlow和PyTorch有其独特之处,但选用合适的优化器对于模型的最终表现仍然至关重要。在实际应用中,建议尝试多种优化算法并进行超参数调整,以获得最佳的训练效果。

如果想了解更深入的Adam算法工作原理或其他优化算法的使用,请关注后续更新,继续学习更多的深度学习内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2149392.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

某锂电厂房项目密集母线槽上红外测温的案例分享

1 行业背景 在政策和技术推动下,锂电产业迅速发展,产业规模持续扩大,同时对供电设备的可靠性要求提高。密集型母线槽作为厂房重要电力传输设备若出现触头温升过高,可能导致停电甚至烧毁等故障,会对生产线安全和企业效…

Java反序列化漏洞分析

相关学习资料# http://www.freebuf.com/vuls/90840.htmlhttps://security.tencent.com/index.php/blog/msg/97http://www.tuicool.com/articles/ZvMbInehttp://www.freebuf.com/vuls/86566.htmlhttp://sec.chinabyte.com/435/13618435.shtmlhttp://www.myhack58.com/Article/ht…

【Qt笔记】QTabWidget控件详解

目录 引言 一、基本功能 二、核心属性 2.1 标签页管理 2.2 标签位置 2.3 标签形状 2.4 标签可关闭性 2.5 标签可移动性 三、信号与槽 四、高级功能 4.1 动态添加和删除标签页 4.2 自定义标签页的关闭按钮行为 4.3 标签页的上下文菜单 五、样式设置 六、应用示例…

git使用“保姆级”教程1——简介及配置项设置

一、git介绍 Git是一个开源的分布式版本控制系统,用于:敏捷高效地处理任何或小或大的项目。Git 是Linus Torvalds 为了帮助管理Linux内核开发而开发的一个开放源码的版本控制软件。版本控制: 版本控制(Revision control&#xff…

鸿蒙环境服务端签名直传文件到OSS

本文介绍如何在鸿蒙环境下将文件上传到OSS。 背景信息 鸿蒙环境是当下比较流行的操作环境,与服务端签名直传的原理类似,鸿蒙环境上传文件到OSS是利用OSS提供的PutObject接口来实现文件上传到OSS。关于PutObject的详细介绍,请参见PutObject。…

大厂常问的MySQL事务隔离到底怎么回答

什么是事务 事务就是一组原子性的SQL查询,或者说一个独立的工作单元。事务内的语句,要么全部执行成功,要么全部执行失败。 关于事务银行系统的应用是解释事务必要性的一个经典例子。 假设一个银行的数据库有两张表:支票表&#x…

OpenAI o1大模型:提示词工程已死

OpenAI 最近发布了最新大模型 o1,通过强化学习训练来执行复杂的推理任务,o1 在多项基准测试中展现了博士级别的推理能力,甚至在某些情况下可以与人类专家相媲美。 当你使用 o1 的时候,会发现文档中多了一项提示词建议。 翻译一下&…

OBB-最小外接矩形包围框-原理-代码实现

前言 定义:OBB是相对于物体方向对齐的包围盒,不再局限于坐标轴对齐,因此包围点云时更加紧密。优点:能够更好地贴合物体形状,减少空白区域。缺点:计算较为复杂,需要计算物体的主方向&#xff0c…

二叉树的遍历【C++】

对于二叉树系列的题,必须要会遍历二叉树。 遍历的有:深度优先:前序、中序、后序,广度优先:层序遍历 什么序是指处理根节点在哪个位置,比如前序是指处理节点顺序:根左右。 接下来要说明的是&…

深入浅出Docker

1. Docker引擎 Docker引擎是用来运行和管理容器的核心软件。通常人们会简单的将其指代为Docker或Docker平台。 基于开放容器计划(OCI)相关的标准要求,Docker引擎采用了模块化的设计原则,其组件是可替换的。 Docker引擎由如下主…

从理论到实践:全面指导企业实现数字化转型的战略路径

全球企业数字化转型的必然性 在全球范围内,数字化转型成为了企业战略中的核心命题。随着云计算、大数据、人工智能等新兴技术的快速发展,企业的运营模式、管理体系及客户体验正在发生深刻的变革。数字技术不仅为企业带来了新的商业机会,还使…

【Elasticsearch】-图片向量化存储

需要结合深度学习模型 1、pom依赖 注意结尾的webp-imageio 包&#xff0c;用于解决ImageIO.read读取部分图片返回为null的问题 <dependency><groupId>org.openpnp</groupId><artifactId>opencv</artifactId><version>4.7.0-0</versio…

【2025】中医药健康管理小程序(安卓原生开发+用户+管理员)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

海报制作哪个软件好?这些在线工具不容错过

国庆节的脚步越来越近&#xff0c;不少公司正计划利用这个时机开展一些特别的庆典活动。 在这些活动中&#xff0c;海报作为一种传统的宣传方式&#xff0c;仍然是不可或缺的。但在制作海报时&#xff0c;我们可能会遇到创意瓶颈、时间限制或者预算约束等问题。 幸运的是&…

高棉语翻译神器上线!中柬互译,OCR识别,语音翻译一应俱全,《柬埔寨语翻译通》App

全新的高棉语翻译神器已经正式上架&#xff01; 无论你是安卓还是iOS用户&#xff0c;现在都可以轻松开始使用&#xff0c;开启你的翻译之旅&#xff01; 这款应用不仅仅是一个简单的翻译工具&#xff0c;它还支持中文与高棉语的双向翻译。翻译结果可以语音播放&#xff0c;翻…

AI服务器是什么?为什么要用AI服务器?

AI服务器的定义 AI服务器是一种专门为人工智能应用设计的服务器&#xff0c;它采用异构形式的硬件架构&#xff0c;通常搭载GPU、FPGA、ASIC等加速芯片&#xff0c;利用CPU与加速芯片的组合来满足高吞吐量互联的需求&#xff0c;为自然语言处理、计算机视觉、机器学习等人工智…

企业微信-前往服务商后台页面对接解决方案

序 我会告诉你在哪里点我会告诉你在哪里配置点下去他只返回auth_code的&#xff0c;我怎么登录 正文 他是在这个位置 是这样&#xff0c;应用授权安装第三方应用后&#xff0c;企业微信&#xff08;管理员角色&#xff09;是可以从pc端企业后台点第三方应用的。 如果我没记…

【余弦相似度】

余弦相似度 又称为余弦距离&#xff0c;利用两个向量之间的夹角的余弦值来衡量两个向量的余弦相似度&#xff0c;两个向量夹角越小&#xff0c;余弦值越接近1。 向量模&#xff08;向量长度&#xff09;计算方法&#xff1a; n维向量的相似度计算&#xff1a; 余弦相似度的取…

黑盒测试 | 挖掘.NET程序中的反序列化漏洞

通过不安全反序列化漏洞远程执行代码 img 今天&#xff0c;我将回顾 OWASP 的十大漏洞之一&#xff1a;不安全反序列化&#xff0c;重点是 .NET 应用程序上反序列化漏洞的利用。 &#x1f4dd;$ _序列化_与_反序列化 序列化是将数据对象转换为字节流的过程&#xff0c;字节流…

Entity更新坐标不闪烁需采用setCallbackPositions方法赋值

问题描述&#xff1a; 1.new mars3d.graphic.PolygonEntity({在更新点位高度模拟水面上身的时候&#xff0c;会存在闪烁 2.当把addDemoGraphic4添加到图层后&#xff0c;addDemoGraphic1水位变化不闪烁&#xff0c;把addDemoGraphic4注释后&#xff0c;addDemoGraphic1闪烁。…