【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】

news2025/1/16 21:14:03

本文将使用一个来自NASA测试不同飞机机翼噪音的数据集,通过梯度下降、随机梯度下降、小批量随机梯度下降这3种优化算法进行模型训练,比较3种训练结果的差异。

目录

  • 1. 梯度下降、随机梯度下降、小批量随机梯度下降区别
  • 2. 读取训练数据
  • 3. 从零实现3种梯度算法并进行训练
    • 3.1 梯度下降训练结果
    • 3.2 随机梯度下降将结果
    • 3.3 小批量随机梯度下降结果
  • 4 .使用Pytorch的optim.SGD实现梯度下降优化算法
    • 4.1 梯度下降训练结果
    • 4.2 随机梯度下降将结果
    • 4.3 小批量随机梯度下降结果
  • 5. 总结

1. 梯度下降、随机梯度下降、小批量随机梯度下降区别

梯度下降:在每一次迭代中,梯度下降使用整个训练数据集来计算梯度,一个epoch周期内参数只更新一次。

随机梯度下降:在每次迭代中,只随机采样一个样本来计算梯度,一个epoch周期内会进行样本数目次参数更新。

小批量随机梯度下降:在每次迭代中随机均匀采样多个样本来组成一个小批量来计算梯度,一个epoch周期内会进行(样本数目/批量大小)次的参数更新。

2. 读取训练数据

获取数据集方法,关注GZH:阿旭算法与机器学习,回复“梯度下降”即可。

该数据集为NASA的测试不同飞机机翼噪音的数据集,数据集一共包含1503个样本,每个样本包含5个特征与1个标签,下面我们将使用该数据集的前1,500个样本进行模型的训练,并比较各个优化算法的区别。

数据集展示:

在这里插入图片描述

%matplotlib inline
import numpy as np
import time
import torch
from torch import nn, optim
import sys
import d2lzh_pytorch as d2l

def get_data_ch7():  
    data = np.genfromtxt('./data/airfoil_self_noise.dat', delimiter='\t')
    # 标准化数据
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    return torch.tensor(data[:1500, :-1], dtype=torch.float32), \
    torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本包含5个特征)

features, labels = get_data_ch7()
features.shape # torch.Size([1500, 5])

3. 从零实现3种梯度算法并进行训练

下面实现一个通用的训练函数,它初始化一个线性回归模型,然后可以使用梯度下降、随机梯度下降和小批量随机梯度下降算法来训练模型。

# 参数优化器
def sgd(params, states, hyperparams):
    for p in params:
        p.data -= hyperparams['lr'] * p.grad.data
# 训练函数        
def train_ch7(optimizer_fn, states, hyperparams, features, labels,
              batch_size=10, num_epochs=2):
    # 初始化模型,初始化一个线性回归模型
    net, loss = d2l.linreg, d2l.squared_loss
    
    w = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),
                           requires_grad=True)
    b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)

    def eval_loss():
        return loss(net(features, w, b), labels).mean().item()

    ls = [eval_loss()]
    data_iter = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)
    
    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            l = loss(net(X, w, b), y).mean()  # 使用平均损失
            
            # 梯度清零
            if w.grad is not None:
                w.grad.data.zero_()
                b.grad.data.zero_()
                
            l.backward()
            optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())  # 每100个样本记录下当前训练误差
    # 打印结果和作图
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    d2l.set_figsize()
    d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    d2l.plt.xlabel('epoch')
    d2l.plt.ylabel('loss')

3.1 梯度下降训练结果

当批量大小为样本总数1,500时,使用的是梯度下降。梯度下降的1个迭代周期对模型参数只迭代1次。可以看到6次迭代后目标函数值(训练损失)的下降趋向了平稳。

def train_sgd(lr, batch_size, num_epochs=2):
    train_ch7(sgd, None, {'lr': lr}, features, labels, batch_size, num_epochs)

train_sgd(1, 1500, 6)

输出:

loss: 0.245426, 0.013536 sec per epoch

在这里插入图片描述

3.2 随机梯度下降将结果

当批量大小为1时,优化使用的是随机梯度下降。随机梯度下降中,每处理一个样本会更新一次自变量(模型参数),一个迭代周期里会对自变量进行1,500次更新。可以看到,目标函数值的下降在1个迭代周期后就变得较为平缓。

train_sgd(0.005, 1)

输出:

loss: 0.246051, 0.531435 sec per epoch

在这里插入图片描述

虽然随机梯度下降和梯度下降在一个迭代周期里都处理了1,500个样本,但实验中随机梯度下降的一个迭代周期耗时更多。这是因为随机梯度下降在一个迭代周期里做了更多次的自变量迭代,而且单样本的梯度计算难以有效利用矢量计算。

3.3 小批量随机梯度下降结果

当批量大小为10时,优化使用的是小批量随机梯度下降。它在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

train_sgd(0.05, 10)

输出:

loss: 0.242805, 0.078792 sec per epoch

在这里插入图片描述

4 .使用Pytorch的optim.SGD实现梯度下降优化算法

在PyTorch里可以直接通过创建optimizer实例来调用优化算法。这能让实现更简洁。下面实现一个通用的训练函数,它通过优化算法的函数optimizer_fn和超参数optimizer_hyperparams来创建optimizer实例。

def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,
                    batch_size=10, num_epochs=2):
    # 初始化模型
    net = nn.Sequential(
        nn.Linear(features.shape[-1], 1)
    )
    loss = nn.MSELoss()
    optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)

    def eval_loss():
        return loss(net(features).view(-1), labels).item() / 2

    ls = [eval_loss()]
    data_iter = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)

    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            # 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2
            l = loss(net(X).view(-1), y) / 2 
            
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())
    # 打印结果和作图
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    d2l.set_figsize()
    d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    d2l.plt.xlabel('epoch')
    d2l.plt.ylabel('loss')

下面重复第3小节中的实验。

4.1 梯度下降训练结果

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, batch_size=1500, num_epochs=6)

输出:

loss: 0.701703, 0.013035 sec per epoch

在这里插入图片描述

4.2 随机梯度下降将结果

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, batch_size=1, num_epochs=2)

输出:

loss: 0.288860, 0.586868 sec per epoch

在这里插入图片描述

4.3 小批量随机梯度下降结果

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, batch_size=10, num_epochs=2)

输出:

loss: 0.242063, 0.075203 sec per epoch

在这里插入图片描述

5. 总结

  • 小批量随机梯度每次随机均匀采样一个小批量的训练样本来计算梯度。
  • 通常,小批量随机梯度在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

如果文章内容对你有帮助,感谢点赞+关注!

关注下方GZH:阿旭算法与机器学习,回复:“梯度下降”即可获取本文数据集、源码与项目文档,欢迎共同学习交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/145094.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多线程与高并发(16)——线程池原理(ThreadPoolExecutor源码)

本文从ThreadPoolExecutor源码来理解线程池原理。 ThreadPoolExecutor使用了AQS、位操作、CAS操作等。在看这篇文章之前,需要具备以下知识: 多线程与高并发(6)——CAS详解(包含ABA问题) 多线程与高并发&…

腾讯三面:进程写文件过程中,进程崩溃了,文件数据会丢吗?

进程写文件(使用缓冲 IO)过程中,写一半的时候,进程发生了崩溃,会丢失数据吗? 答案,是不会的。 因为进程在执行 write (使用缓冲 IO)系统调用的时候,实际上是…

企业宣传片制作配音,我们该从哪里找?

优秀的品质的配音是制作优质企业视频必不可少的硬件条件。因此,许多公司视频配音或旁白声音是由专门从事配音行业的人员配音的。 首先是在宣传视频中配音的作用 1.宣传视频的配音为您建立企业形象 2.宣传视频的配音将为您打开市场 3.宣传视频的配音将使您的宣传…

深入理解Synchronized

Synchronized 底层原理 Synchronized的语义底层是通过一个 Monitor 的对象来完成,其实wait/notify等方法也依赖于 Monitor 对象,这就是为什么只有在同步的块中,拿到锁之后,才能调用wait/notify等方法,否则会抛出java.…

AI助力产品质量检验,基于YOLO实现瓷砖缺陷问题检测识别

在我之前的文章中也写过很多关于生产质检相关的实践文章,一直觉得这块是比较有意思的应用方向,做出来的模型能够以一种更加直观贴切的形式展现出来,瓷砖缺陷问题检测识别也是一个比较老的话题了,今天还是想拿出来具体实践做一下&a…

Golang.org/x库初探1——image库

Golang有一个很有意思的官方库,叫golang.org/x,x可能是extends,experimental,总之是一些在官方库中没有,但是又很有用的库。最近花点时间把这里有用的介绍一下。 Image库 提供更多的图像格式 golang.org/x/image库整…

Linux 网络驱动

1. linux 里面驱动三巨头:字符设备驱动、块设备驱动、网络设备驱动。2.嵌入式网络硬件分为两部分: MAC 和 PHY。如果一款芯片数据手册说自己支持网络,一般都是说的这款 SOC 内置 MAC, MAC 类似 I2C 控制器、SPI 控制器一样的外设。…

Java三大技术平台是什么?

为了使软件开发人员、服务提供商和设备生产商可以针对特定的市场进行开发,SUN公司将Java划分为三个技术平台,它们分别是 JavaSE、 JavaEE和 JavaME。Java SE( Java Platform Standard Edition)标准版,是为开发普通桌面和商务应用程序提供的解…

零宽断言正则表达式替换方案

一、背景 safari浏览器不支持零宽断言正则表达式 二、解决方案 使用其他正则替换零宽断言正则&#xff08;包含&#xff1a;(?<)正向肯定预查、(?<!)正向否定预查、(?)反向肯定预查、(?!)反向否定预查&#xff09; 三、涉及场景 1、仅校验&#xff0c;不取值 如表…

首汽约车驶向极速统一之路!出行平台如何基于StarRocks构建实时数仓?

作者&#xff1a;王满&#xff0c;高级数据架构工程师首汽约车&#xff08;以下简称 “首约”&#xff09;是首汽集团为响应交通运输部号召&#xff0c;积极拥抱互联网&#xff0c;推动传统出租车行业转型升级&#xff0c;加强建设交通强国而打造的网约车出行平台。 在用车服务…

KernelSU: 内核 ROOT 方案, KernelSU KernelSU KernelSU 新的隐藏root防止检测 封号方案

大约一年多以前&#xff0c;我在一篇讲Android 上 ROOT 的过去、现在和未来https://mp.weixin.qq.com/s?__bizMjM5Njg5ODU2NA&mid2257499009&idx1&sn3cfce1ea7deb6e0e4f2ac170cffd7cc1&scene21#wechat_redirect 的文章中提到&#xff1a; 我认为&#xff0c;随…

三菱FX5U 多个表格运行指令 DRVTBL

简述该指令可以用GX Works3预先在表格数据中设定的控制方式的动作&#xff0c;&#xff08;连续或步进&#xff09; 执行多行。 本文演示了步进执行多行。指令解释2.1梯形图中的指令第一个参数&#xff1a;输出脉冲的轴编号 &#xff0c;K1,K2,K3,K4... 第二个参数&#xff1a;…

ESP8266 Windows开发环境搭建(IDE1.5)好用不骗人

最近一个项目需要用ESP8266&#xff0c;找了很多文章进行环境搭建编译都很问题&#xff0c;不是make Menuconfig 不出来&#xff0c;就是编译报错&#xff0c;现总结如下。 我在自己电脑上没弄出来&#xff0c;就安装了一个虚拟机很干净的环境没有其它开发环境影响。 提前去官…

逆向入门|全国建筑市场监管公共服务平台JS逆向

看了志远的公开课&#xff0c;自己做一下练手。 全国建筑市场监管公共服务平台&#xff08;四库一平台&#xff09; 先点到 数据这里打开f12看一眼 第一个就是 https://jzsc.mohurd.gov.cn/api/webApi/dataservice/query/comp/list?pg1&pgsz15&total450 取这个地址…

线段树讲解

0、引入 假设给定一个长度为 1001 的数组&#xff0c;即下标 0 到 1000。 现在需要完成 3 个功能&#xff1a; add(1, 200, 6); //给下标 1 到 200 的每个数都加 6&#xff1b; update(7, 375, 4); //下标 7 到 375 的数全部修改为 4 query(3, 999); //下标 3 到 999 所有数…

深入理解如何利用PWM驱动舵机:ESP32驱动DS1115舵机

深入理解如何利用PWM驱动舵机&#xff1a;ESP32驱动DS1115舵机DS1115舵机技术规格举例说明之前做了一个项目&#xff0c;关于ESP32驱动DS1115舵机&#xff0c;但是在项目运行的过程中由于学艺不精&#xff0c;导致电机抽搐 &#x1f635;‍&#x1f4ab;&#xff0c;所以特意拜…

声纹识别可靠评测

分享嘉宾 | 李蓝天 文稿整理 | William 1 Introduction 声纹识别的发展&#xff0c;非常迅猛&#xff0c;在一些基准上取得了不错的效果&#xff0c;但如果将其部署到一个实际的应用系统里面&#xff0c; 从应用方的反馈来看&#xff0c;纹识别在很多场景里的鲁棒性并不理想。…

聚观早报 | 亚马逊将裁员17000人;苹果砍单MacBook等产品线架构

今日要闻&#xff1a;亚马逊将裁员17000人&#xff1b;苹果砍单MacBook等产品线&#xff1b;京东科技调整组织架构&#xff1b;小米x徕卡团队获技术大奖&#xff1b;必应搜索或将纳入ChatGPT亚马逊将裁员17000人 1 月 5 日消息&#xff0c;知情人士称&#xff0c;亚马逊新一轮裁…

正版授权|FastStone Capture 专业屏幕截图录屏工具软件 商业版,支持商业用途。

现在截图对每个人来说都是一个必不可少的功能。QQ软件截图、360游览器截图等都是相对简单快速的途径。但是如果你对截图有更多的要求&#xff0c;那么这里推荐一款截图软件&#xff0c;它就是FastStone Capture。这个对于商城老用户来说&#xff0c;几乎是接近人手一份。强大的…

【VUE3】保姆级基础讲解(六)Axios库

目录 Axios介绍与原生的差异 发送常见的请求和配置选项 1、发送request请求 baseURL &#xff1a; 2、发送get请求 3、发送post请求 axios.all Axios创建新的实例 请求和响应拦截 请求拦截 响应拦截 Axios介绍与原生的差异 Axios其实就是一个网络请求库 与原生的差异&…