8 从0开始学PyTorch | PyTorch中自动计算梯度、使用优化器

news2024/9/28 11:12:40

上一节,我们写了很多代码,但是不知道你有没有注意,那些代码看起来跟PyTorch关系并不是很大啊,貌似很多都是Python原生代码?

如果你是这样的感觉,那我要告诉你,你感觉的没有错。前面主要在于机制的理解,我们实际上用手动的方式实现了一遍模型求解的过程,主要的改进就是使用了PyTorch里面的tensor数据结构,但是这还不够,PyTorch提供了很多强大的功能,当然不只是在处理tensor上面的,接下来我们就要看一下,PyTorch提供的方便我们运算的各种功能。

自动计算梯度

上次我们用手动求导计算梯度,可是你别忘了,那个包浆的温度计变换只需要2个参数,而如果有10亿个参数,那用手可是求导不过来啊。不要怕,PyTorch给出了自动求导机制。在PyTorch中,可以存储张量的生产路径,包括一个张量经过了何种计算,得到的结果有哪些,借助这个能力,对于我们用到的tensor,就可以找到它的爷爷tensor和它的爷爷的爷爷tensor,并且自动对这些操作求导,所有这些只需要你的一句“我需要autograd功能”。

我们来看一下实现方式,如果你已经把上一节的代码关了,没关系,这里我们从头写起,包括原始数据,紧接着是模型函数和loss函数,最后是给params初始化,这里唯一的区别就是,我们之前的初始化参数是这么写的:
params = torch.tensor([1.0, 0.0])
现在我们在tensor方法内部加上了一个参数requires_grad,并给它赋值True

%matplotlib inline
import numpy as np
import torch
torch.set_printoptions(edgeitems=2)

t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0,
                    3.0, -4.0, 6.0, 13.0, 21.0])
t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,
                    33.9, 21.8, 48.4, 60.4, 68.4])
t_un = 0.1 * t_u

def model(t_u, w, b):
    return w * t_u + b

def loss_fn(t_p, t_c):
    squared_diffs = (t_p - t_c)**2
    return squared_diffs.mean()

#唯一改变
params = torch.tensor([1.0, 0.0], requires_grad=True)

加入这个requires_grad=True之后,意味着所有后续跟params相关的调用和操作记录都会被保留下来,任何一个经过params变换得到的新的tensor都可以追踪它的变换记录,如果它的变换函数是可微的,导数的值会被自动放进params的grad属性中。

让我们看一下代码

loss = loss_fn(model(t_u, *params), t_c)
loss.backward() #对loss进行反向传播

#输出params的梯度看看
params.grad 
outs:tensor([4517.2969,   82.6000])

对于上面输出的grad有没有一丝熟悉的味道?这个结果跟我们之前第一次执行手动编写的grad函数时的结果是一样的,也就是说这里的自动方法跟我们前面手动编写的逻辑可以认为是一样的。

image.png

值得注意的是,我们实际的运算往往不是这么简单的,可能会涉及到若干个requires-grad为True的张量进行运算,在这种情况下,PyTorch会把整个计算图上的损失的导数,并把这些结果累加到grad属性中。

这里涉及到一个计算图的概念,大意是在PyTorch底层为tensor及运算构建了一个图关系,前面说到的关于反向传播也都是基于这个图上的存储关系进行的。这关系到PyTorch底层的运行逻辑,这里我们先不做太多的探讨,如果你对PyTorch的底层运行逻辑感兴趣可以进行深度的学习,否则,在这里我们还是先来看看它到底怎么去用的问题。

在调用backward()的时候,将会把导数累加在叶节点上,如果提前调用backward(),则会再次调用backward(),每个叶节点上的梯度将在上一次迭代中计算的梯度之上累加(求和),这会导致梯度计算的结果出错。

这里所谓的叶节点你可以认为就是我们最开始的那个tensor,在反向传播的路径里面它处于最末端,所以可以称为叶节点,当然在计算图上面,我们也可以把一些中间节点强行设置成叶节点,当然这就会使得梯度不再向下传倒。

如果要防止这个问题发生,我们需要在每次迭代的时候手动的把梯度置为零。这看起来或多或少有点麻烦,为啥不自动在迭代的时候清零呢?按官方解释就是增加代码的灵活性,万一你需要不清零呢?所以这个事情我们先把它记住,不去探究深层次原因,将来如果遇到自然就能明白了。

if params.grad is not None:
    params.grad.zero_()

这时候我们使用自动反向传播机制来改写我们之前的代码

with torch.no_grad的作用
在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。

def training_loop(n_epochs, learning_rate, params, t_u, t_c):
    for epoch in range(1, n_epochs + 1):
        if params.grad is not None:
            params.grad.zero_()
        
        t_p = model(t_u, *params) 
        loss = loss_fn(t_p, t_c)
        loss.backward()
        
        with torch.no_grad():  
            params -= learning_rate * params.grad

        if epoch % 500 == 0:
            print('Epoch %d, Loss %f' % (epoch, float(loss)))
            
    return params

输出可以自己看一下,跟之前并没有什么区别,loss徘徊在2.9左右,不同的是我们让PyTorch自动的处理了梯度计算。

优化器

然后我们再来看另一个可以优化的地方。就是关于参数更新这块,
params -= learning_rate * params.grad
我们这里采用的通过计算梯度,并按照梯度方向更新参数,这个计算称作梯度下降方法,而且是最原始的批量梯度下降方法。在每一个epoch,所有训练样本都会用于计算梯度,这个方案很稳妥,但是如果我们的样本很多的时候就不妙了,比如说计算一次就需要耗费大量的时间。

在PyTorch中提供了一个optim模块,里面收集了很多种优化方法

dir() 函数不带参数时,返回当前范围内的变量、方法和定义的类型列表;带参数时,返回参数的属性、方法列表。如果参数包含方法dir(),该方法将被调用。如果参数不包含dir(),该方法将最大限度地收集参数信息。

import torch.optim as optim

dir(optim)
outs:
['ASGD',
 'Adadelta',
 'Adagrad',
 'Adam',
 'AdamW',
 'Adamax',
 'LBFGS',
 'NAdam',
 'Optimizer',
 'RAdam',
 'RMSprop',
 'Rprop',
 'SGD',
 'SparseAdam',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '_functional',
 '_multi_tensor',
 'lr_scheduler',
 'swa_utils']

这些大写字母开头的就是优化算法,可以看到非常著名的adam,adamx等等。关于每个优化器都是怎么去优化的,这里就先不讨论了,我们先看优化器怎么用。

image.png

优化器接收参数tensor,读取他们的grad属性并对其执行更新的操作,然后再把接力棒交给模型。

接下来让我们使用优化器来实现梯度下降。我们使用了一个叫SGD的优化器,这个称为随机梯度下降,这个方法是每次计算只随机采用一个样本,大大降低了计算成本。

params = torch.tensor([1.0, 0.0], requires_grad=True)
learning_rate = 1e-5
optimizer = optim.SGD([params], lr=learning_rate)

t_p = model(t_u, *params)
loss = loss_fn(t_p, t_c)
loss.backward()

optimizer.step() #调用step()方法,就会更新params的值

params
outs:
tensor([ 9.5483e-01, -8.2600e-04], requires_grad=True)

正如我们前面所说,我们需要在执行反向传播前手动的把梯度归零,然后再次把训练跑起来。

def training_loop(n_epochs, optimizer, params, t_u, t_c):
    for epoch in range(1, n_epochs + 1):
        t_p = model(t_u, *params) 
        loss = loss_fn(t_p, t_c)
        
        optimizer.zero_grad() #手动调用梯度归零方法
        loss.backward()
        optimizer.step()

        if epoch % 500 == 0:
            print('Epoch %d, Loss %f' % (epoch, float(loss)))
            
    return params

params = torch.tensor([1.0, 0.0], requires_grad=True)
learning_rate = 1e-2
optimizer = optim.SGD([params], lr=learning_rate)

training_loop(
    n_epochs = 5000, 
    optimizer = optimizer,
    params = params, 
    t_u = t_un,
    t_c = t_c)

outs:
Epoch 500, Loss 7.860115
Epoch 1000, Loss 3.828538
Epoch 1500, Loss 3.092191
Epoch 2000, Loss 2.957698
Epoch 2500, Loss 2.933134
Epoch 3000, Loss 2.928648
Epoch 3500, Loss 2.927830
Epoch 4000, Loss 2.927679
Epoch 4500, Loss 2.927652
Epoch 5000, Loss 2.927647
tensor([  5.3671, -17.3012], requires_grad=True)

这个地方你可以把优化器换成你喜欢的一个其他优化器来试试,当然你也可以去了解一下每个优化器都有什么特点,然后跑起来看看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/688225.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

下面告诉你音频转换工具有哪些

今天我想和大家聊一聊音频转换工具。你是不是有时候想把一首酷炫的歌曲转换成你喜欢的音频格式,或者想把录音文件转成可编辑的格式?别担心,这里有一些超赞的音频转换工具,可以帮你解决这些问题!无论是从MP3到WAV&#…

武汉大学计算机考研分析

关注我们的微信公众号 姚哥计算机考研 更多详情欢迎咨询 武汉大学(A-)考研难度(☆☆☆☆☆) 武汉大学计算机考研招生学院是计算机学院、国家网络安全学院和测绘遥感信息工程国家重点实验室。目前均已出拟录取名单。 武汉大学计…

Redis的3大特殊数据类型(1)-BitMap

BitMap(位图/位数组)是Redis2.2.0版本中引入的一种新数据类型,该数据类型本质是一个仅含0和1的二进制字符串。因此可以把 Bitmap 想象成一个以位为单位的数组,数组的每个单元只能存储 0 和 1,数组的下标在 Bitmap 中叫做偏移量 offset&#x…

volatile关键字和ThreadLocal

作用: 1.线程的可见性:当一个线程修改一个共享变量时,另外一个线程能读到这个修改的值。 2. 顺序一致性:禁止指令重排序。 线程之间的共享变量存储在主内存中(Main Memory)中,每个线程都一个都…

StarRocks Friends 上海站活动回顾(含 PPT 下载链接)

6月17日, StarRocks & Friends 上海站活动如期而至,近百位社区小伙伴参与交流活动;针对 StarRocks 存算分离、StarRocks 在业界的应用实践、以及 StarRocks 与 BI 结合、湖仓一体规划等话题展开激烈的交流互动。 本文总结了技术交流活动…

未来的彩电,彩电的未来

疫情后的首个线上大促已经结束,“史上投入最大618”也没能抵住彩电市场整体的需求疲软。 根据奥维云网线上推总数据,2023年618期间,中国彩电线上市场零售量规模为249.9万台,同比下降12.9%;零售额规模为79.7亿元&#…

配电柜(箱)使用防雷浪涌保护器的作用和方案

配电箱是电力系统中的重要组成部分,负责将电力从供电系统输送到各个电器设备。然而,由于天气状况和其他因素的影响,电力系统可能会受到雷击引起的浪涌电压的威胁。为了保护配电箱和其中的设备免受浪涌电压的破坏,我们需要在配电箱…

Redis中3大特殊数据结构(2)-HyperLogLog

HyperLogLog算法是法国人Philippe Flajolet 教授发明的一种基数计数概率算法,每个 HyperLogLog 键只需要花费 12 KB 内存,就可以计算接近 2^64 个不同元素的基数。HyperLogLog 适用于大数据量的去重统计,HyperLogLog 提供不精确的去重计数方案…

基于Java+Swing实现餐厅点餐系统

基于JavaSwing实现餐厅点餐系统 一、系统介绍二、系统展示1.主页2.点菜3.下单4.结算5.销售情况(管理员) 三、系统实现四、其他系统五、获取源码 一、系统介绍 该系统针对两个方面的用户,一个是用餐客户,另一个是餐厅管理员。将功…

iOS 17 beta 2有哪些BUG?iOS 17 beta 2推荐升级吗?

虽然iOS 17 beta 2 带来了大量的功能更新,但毕竟是测试版,海量的适配BUG也一同随之而来。 想升级iOS 17 beta 2的用户不妨先查看下目前存在的问题汇总! 一:存储空间更小了 升级beta1后存储空间缩小了大概3G左右,bet…

k8s网络通信

详解Kubernetes网络模型 Kubernetes 是为运行分布式集群而建立的,分布式系统的本质使得网络成为 Kubernetes 的核心和必要组成部分,了解 Kubernetes 网络模型可以使你能够正确运行、监控和排查应用程序故障。 网络所涉及的内容很多,拥有许多…

人人都能生成火爆全网的最不像二维码的二维码!

Sealos 公众号已接入了 GPT-4,完全免费!欢迎前来调戏👇 最近有人展示了使用 Stable Diffusion 创建的艺术二维码。这些二维码是使用定制训练的 ControlNet模型生成的。 但是操作门槛有点高。 你需要 GPU,还需要学习如何使用 Stabl…

diffusion model(二)—— DDIM技术小结

论文地址:Denoising Diffusion Implicit Models github地址:https://github.com/ermongroup/ddim 背景 去噪扩散概率模型 (DDPM1) 在没有对抗训练的情况下实现了高质量的图像生成,但其采样过程依赖马尔可夫假设,需要较多的时间…

SoapUI实践:自动化测试、压力测试、持续集成

因为项目的原因,前段时间研究并使用了 SoapUI 测试工具进行自测开发的 api。下面将研究的成果展示给大家,希望对需要的人有所帮助。 如果你想学习自动化测试,我这边给你推荐一套视频,这个视频可以说是B站播放全网第一的自动化测试…

Android View的渲染过程

原文链接 Android View的渲染过程 对于安卓开发猿来说,每天都会跟布局打交道,那么从我们写的一个布局文件,到运行后可视化的视图页面,这么长的时间内到底 发生了啥呢?今天我们就一起来探询这一旅程。 View tree的创建…

A Survey on In-context Learning

这是LLM系列相关综述文章的第二篇,针对《A Survey on In-context Learning》的翻译。 上下文学习综述 摘要1 引言2 概述3 定义和公式4 模型预热4.1 监督上下文训练4.2 半监督上下文训练 5 示例设计5.1 示例组织5.1.1 示例选择5.1.2 示例排序 5.2 示例形式化5.2.1 指…

Segment Anything Model Geospatial (SAM-Geo) 创建交互式地图

SAM-Geo是一个用于地理空间数据的Python 包,可在 PyPI 和 conda-forge 上使用。本节教程是SAM-Geo官网的一个教程,根据输入提示范围创建mask遮罩。后面还有一种基于提示词创建的方式,如只输出房屋、道路、树木等,下一期我们专门写…

IEEE Transactions的模板中,出现subfig包和fontenc包冲突的问题,怎么解决?

IEEE Transactions的模板中,出现subfig包和fontenc包冲突的问题,怎么解决? 本文章记录如何在IEEE Transactions的模板中,出现了subfig包和fontenc包冲突的问题,该怎么解决。 目录 IEEE Transactions的模板中&#xff…

ubuntu系统解压.rar文件问题与解决办法

ubuntu20.04解压rar文件 问题解决办法 问题 在ubuntu系统中,直接解压rar文件可能会报错,或者一直在提取文件中,无法结束。 例如直接右件该rar文件,将文件提取到此处 一直显示这个,无法结束 解决办法 需要安装一些软…

架构师进阶之路 - 架构优化为什么难

目录 业务迭代和技术优化难以兼顾 缺少“上帝”视角思维 系统架构腐化 缺少架构师视角 系统迭代机制 设计规范把控 最近在组织团队内的系统架构优化,总而言之就是难,至于为什么难我这边总结了以下六个方面,记录一下自己的架构师进阶之路吧。&…