pytorch深度学习基础 7(简单的的线性训练,SGD与Adam优化器)

news2024/9/22 13:24:50

接下来小编来讲一下一些优化器在线性问题中的简单使用使用,torch模块中有一个叫optim的子模块,我们可以在其中找到实现不同优化算法的类

SGD随机梯度下降

基本概念

  • 定义:随机梯度下降(SGD)是一种梯度下降形式,对于每次前向传递,都会从总的数据集中随机选择一批数据,即批次大小1。
  • 参数更新过程:这个参数的更新过程可以描述为随机梯度下降法,随机梯度下降(SGD)是一种简单但非常有效的方法,多用于支持向量机,逻辑回归(LR)等凸损失函数下的线性分类器的学习3。

实现步骤

  1. 随机抽样:从总的数据集中随机抽样一批数据1。
  2. 计算梯度:前向和后向运行网络,计算梯度(根据抽样的数据)1。
  3. 应用更新:应用梯度下降更新1。
  4. 重复循环:重复上述步骤,直到出现收敛情况或者循环被其他机制暂停(即迭代次数)1。

特点与优势

  • 提高训练速度:随机梯度下降算法可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性3。
  • 更快收敛:相比于批量梯度,这样的方法更快收敛,因此使用也比较广泛4。
    import numpy as np
    import torch
    torch.set_printoptions(edgeitems=2, linewidth=75) # 设置打印格式
    
    # 初始化参数
    t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0,
                        8.0, 3.0, -4.0, 6.0, 13.0, 21.0])
    t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,
                        33.9, 21.8, 48.4, 60.4, 68.4])
    t_un = 0.1 * t_u # 归一化处理,防止梯度爆炸
    
    
    def model(t_u, w, b):
        return w * t_u + b
    
    def loss_fn(t_p, t_c):
        squared_diffs = (t_p - t_c)**2
        return squared_diffs.mean()
    
    import torch.optim as optim # 导入优化器模块
    # dir(optim)
    '''
    params = torch.tensor([1.0, 0.0], requires_grad=True)  # 创建一个params张量
    learning_rate = 1e-5 # 学习率
    optimizer = optim.SGD([params], lr=learning_rate) # 定义了一个使用随机梯度下降(Stochastic Gradient Descent, SGD)
                                                        # 算法的优化器,并设置了学习率'''
    params = torch.tensor([1.0, 0.0], requires_grad=True)
    learning_rate = 1e-2
    optimizer = optim.SGD([params], lr=learning_rate)
    t_p = model(t_u, *params)
    loss = loss_fn(t_p, t_c)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() # 用于更新模型参数的一个关键函数 通常在每个训练迭代(epoch)中调用一次,
                     # 以根据损失函数的梯度调整模型的权重和偏置,从而最小化损失函数。
    
    def training_loop(n_epochs, optimizer, params, t_u, t_c):
        for epoch in range(1, n_epochs + 1):
            t_p = model(t_u, *params)
            loss = loss_fn(t_p, t_c)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            if epoch % 500 == 0:
                print('Epoch %d, Loss %f' % (epoch, float(loss)))
    
        return params
    
    
    
    training_loop(
        n_epochs = 5000,
        optimizer = optimizer,
        params = params,
        t_u = t_un,
        t_c = t_c)

    运行结果如下

 可以明显看出来,SGD的损失值下降还是非常快速的,收敛非常快

Adam优化器

Adam是一个更加复杂的优化器,其中学习率是自适应设置的,此外,它对参数的缩放不太敏感,我们可以不对数据进行归一化处理,甚至可以把学习率设置为1e-1

上面的代码的函数保持不变,我们只需要修改优器的选择就行

Adam 优化器(自适应矩估计优化器)

基本概念

Adam 优化器是一种自适应学习率的优化算法,结合了动量梯度下降和 RMSprop 算法的思想。它通过自适应地调整每个参数的学习率,从而在训练过程中加速收敛。1 2

使用方法

  1. 初始化参数:首先需要初始化 Adam 优化器的参数,包括学习率、动量因子、指数衰减率等。这些参数的选择通常需要经验和实验来确定。3
  2. 计算梯度:在每个训练迭代中,需要计算损失函数对各个参数的梯度。这可以通过反向传播算法来实现。3
  3. 更新参数:使用 Adam 优化器的更新公式来更新模型的参数。Adam 优化器的更新公式包括两个主要的步骤:计算梯度的一阶矩估计和二阶矩估计,然后将它们结合起来对参数进行更新。3
  4. 调整学习率:在训练过程中,可以根据需要动态调整学习率。例如,可以使用学习率衰减策略来提高模型在训练后期的稳定性和泛化能力。3

相关技巧和注意事项

  • 参数调节:不同的问题可能适合不同的 Adam 优化器参数设置。可以通过尝试不同的参数组合来找到最佳的性能。3
  • 正则化:在使用 Adam 优化器时,可以结合正则化技术来降低模型的过拟合风险。例如,可以使用 L1 正则化或 L2 正则化来约束模型的复杂度。3
  • 批量归一化:在深度学习中,批量归一化是一种常用的技术,它可以加速训练过程并提高模型的泛化能力。可以在使用 Adam 优化器时结合批量归一化技术来进一步优化模型。3

总之,Adam 优化器是一种强大而灵活的优化算法,在机器学习和深度学习任务中广泛应用。

这里有个小问题,Adam优化器被称为自适应优化器,可是为什么还需要设置一个学习率呢,其实这两者之间并不矛盾

它根据参数的更新历史来调整每个参数的学习率。尽管Adam具有自适应的特性,但在使用时仍然需要设置一个全局的学习率(也称为初始学习率或基础学习率)。这个全局学习率对Adam的优化过程有重要影响,它控制着参数更新的总体规模。

Adam优化器会根据参数的梯度、梯度的平方以及参数的更新历史来计算每个参数的自适应学习率。然而,这个自适应学习率是在全局学习率的基础上进行调整的。如果全局学习率设置得太高,可能会导致参数更新过大,进而使模型变得不稳定;如果全局学习率设置得太低,则可能导致模型训练过慢,无法有效收敛。

因此,在使用Adam优化器时,仍然需要谨慎选择全局学习率,以确保模型能够稳定且有效地进行训练。在实际应用中,通常会通过实验来调整全局学习率,以找到最适合当前任务和学习数据的值。

# Adam优化器
params = torch.tensor([1.0, 0.0], requires_grad=True)
learning_rate = 1e-1
optimizer = optim.Adam([params], lr=learning_rate) # <1>

training_loop(
    n_epochs = 5000,
    optimizer = optimizer,
    params = params,
    t_u = t_u, # <2>
    t_c = t_c

可以看到,这里传入的参数不再是归一化后的t_un,而是原始的数据t_u,可见Adam优化器确实对参数的缩放不太敏感,那我们来看看训练的效果吧

 对比前面的SGD,Adam的效果也不差,所以在实际情况中还需要选择适合自己模型的优化器,选择正确的优化器可以显著提高模型的训练效果和收敛速度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2068781.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql中出现错误1138-Invalid use of NULL value

问题&#xff1a;1138-Invalid use of NULL value 解决&#xff1a; 问题是当前字段中&#xff0c;有null的值&#xff0c;简单来说就是&#xff0c;你表里有空值&#xff0c;不能设置不为空&#xff01;&#xff01;&#xff01; 把空的值删掉重新设计就好了

第一次重大人工智能失败刚刚发生

这终于发生了。我们迎来了第一家真正意义上的 AI 公司惨败。 Inflection是一家由比尔盖茨、埃里克施密特、微软等人投资的公司&#xff0c;它成为第一家被冲进马桶的生成式人工智能相关公司。 他们最重要的产品是Pi&#xff0c;ChatGPT 的竞争对手&#xff0c;专注于成为友好且…

SpringAOP使用详解

AOP使用详解 首先创建maven项目 添加依赖在pom.xml里 创建三层结构和spring.xml文件&#xff0c;只要用到注解就得写扫描包在spring.xml里 上篇文章的知识点总结 对上篇文章excution详细解释 如果把前置通知修改成这个代表只有带有Logger注解的才会生效 合并注解的方法用&…

Windows权限维持实战

目录 介绍步骤 介绍 在攻击过程中中对于拿到的shell或钓上来的鱼&#xff0c;目前比较流行用CS做统一管理&#xff0c;但实战中CS官方没有集成一键权限维持的功能&#xff0c;为了将该机器作为一个持久化的据点&#xff0c;种植一个具备持久化的后门&#xff0c;从而随时可以连…

ffmpeg最新5.1.6版本源码安装

一、编译安装需要的开源编码格式&#xff1a; 首先在编译安装这些开源编码格式之前,我们要明白为啥需要他们&#xff1a; aacx264x265 为啥需要呢&#xff1f;如果你对ffmpeg稍微了解的话&#xff0c;ffmpeg本身是一个框架&#xff0c;自身默认并没有支持这三种编码格式&…

Vue3 后台管理系统项目 前端部分

这里写目录标题 1 创建Vue3项目1.1 相关链接1.2 Vue Router1.3 Element1.4 scss1.5 mitt1.6 axios1.7 echarts1.8 配置vite.config.js 2 CSS部分2.1 样式穿透2.2 :style &#xff1a;在样式中使用插值语法 3. ElementUI3.1 rules&#xff1a; 数据验证3.2 修改element.style中的…

《计算机网络-期末模拟卷》

一、分析题&#xff08;每题 6 分&#xff0c;共 36 分&#xff09; 1.请分析图 1-1 所展示的是哪种互联网接入技术&#xff0c;分析此接入技术的优势&#xff0c;并介绍你所了解的其他互联网接入技术。&#xff08;至少写 3 个&#xff09; 二、计算问答题&#xff08;每题 5…

docker应用

打包传输 1.将镜像打包 #查看帮助文件 docker --help #找到save&#xff0c;可以将镜像保存为一个tar包 docker save --help #查看save使用方式 #查看现有的镜像 docker images # docker save --output centos.tar centos:latest ls ...centos.tar... 可以将tar发送给其他用户…

JS基础进阶3-DOM事件

DOM事件流 一、定义 DOM事件流指的是从页面接收事件的顺序。这个路径包括了事件的捕获阶段、目标阶段和冒泡阶段。 图片来源黑马pink老师PPT 二、事件流阶段 DOM事件流涉及三个主要阶段&#xff1a; 捕获阶段&#xff08;Capturing Phase&#xff09;&#xff1a; 事件从…

QtChart1-基础入门

Qt Charts概述 Qt Charts模块是一组易于使用的图标组件&#xff0c;它基于Qt的Graphice View架构&#xff0c;其核心组件是QChartView和QChart。QChartView的父类是QGraphicsView&#xff0c;就是Graphics View架构中的视图组件&#xff0c;所以&#xff0c;QChartView是用于显…

Apollo9.0 PNC源码学习之Planning模块—— Lattice规划(六):横纵向运动轨迹评估

参考文章: (1)Apollo6.0代码Lattice算法详解——Part6:轨迹评估及碰撞检测对象构建 (2)自动驾驶规划理论与实践Lattice算法详解 0 前言 横纵向运动轨迹的评估,主要通过构建定点巡航和定点停车两个场景下,对纵向运动参考速度、加速度、加加速度的大小进行检验和过滤,然…

Vue3源码调试-第一篇

前言 相信大家在前端从业生涯中都会被问到过&#xff0c;有了解过Vue源码吗&#xff1f;我是没有的&#xff0c;所以今天就来读好吧&#xff0c;浅浅读一下&#xff0c;顺便记录一下。 那究竟怎么读&#xff0c;从哪里读&#xff0c;我就不啰嗦了&#xff0c;直接给大家一个链…

python dash框架 油气田可视化软件设计文档

V1.1:机器学习框架(神经网络) 时间范围优化 表格布局优化 添加前端设计元素布局 V1.0&#xff1a;基础布局和对应计算函数 要求 首先第一部分是通过神经网络预测天然气流量&#xff0c;其中输入开始时间和截止时间是为了显示这一段时间内的天然气流量预测结果 第二部分&…

小型空气净化器可以除猫毛吗?宠物空气净化器评测推荐

前段时间我有个病人&#xff0c;诊断出来肺结节&#xff0c;他第一反应就是说他家养着好几只猫&#xff0c;会不会是吸入宠物毛发导致的肺结节。有些结节确实跟宠物有关系&#xff0c;如果是对毛发过敏、或者是对排泄物过敏&#xff0c;养宠物就会增加患结节的概率。不过就算是…

推荐一款AI智能编程助手CodeGeeX

最近&#xff0c;使用了一款AI智能编程助手CodeGeeX&#xff0c;感觉还不错&#xff0c;推荐给大家。 官网地址&#xff1a;https://codegeex.cn/ 一、安装教程 IDEA中安装插件&#xff1a;https://codegeex.cn/downloadGuide#idea VSCode中安装插件&#xff1a;https://codege…

八股(5)——数据库

八股&#xff08;5&#xff09;——数据库 4.1 数据库基础什么是数据库, 数据库管理系统, 数据库系统, 数据库管理员?什么是元组, 码, 候选码, 主码, 外码, 主属性, 非主属性&#xff1f;主键和外键有什么区别?为什么不推荐使用外键与级联&#xff1f;什么是 ER 图&#xff1…

TD学习笔记————中级教程总结(中)

目录 四、生成艺术 问题: CHOP TO放置后直接报错 附着不上线 五、Python Lists 与 Python Dictionaries 问题: 使用for的格式要求 显示numRows错误 List中表格定义报错 六、Replicate 与 Instance 问题: 传递处理好的噪音后不变化 Renderpass区分线和字时不起作用…

安科瑞智能物联网关:重塑能源管理新纪元,远程智控尽在“掌”握

在数字化转型浪潮中&#xff0c;能源管理与工业自动化领域正经历着前所未有的变革。安科瑞智能物联网关-智能通信管理机不仅重新定义了智能监控与保护装置的通信管理模式&#xff0c;更为能源数据采集与远程控制提供了前所未有的高效解决方案。 安科瑞智能物联网关&#xff0c…

qt开发环境搭建Qt Creator并创建Demo项目

一 Qt Creator工具下载&#xff1a;工具下载链接&#xff1a; Index of /archive/online_installers/4.8 (qt.io) 下载后点击安装&#xff0c;没有账号得先注册一个账号&#xff0c;如下图&#xff0c;然后点击下一步 随便填点&#xff0c;我填"abc"&#xff0c;然…

Unity读取Android本地图片

unity读取Android本地图片 一、安卓读取路径 安卓路径&#xff1a;“file:///storage/emulated/0/”自己图片的路径 例&#xff1a;“file:///storage/emulated/0/small.jpg” 二、unity搭建 使用UI简单搭个界面 三、新建一个脚本 代码内容如下 using System.Collectio…