7从0开始学PyTorch | PyTorch中求导、梯度、学习率、归一化

news2024/11/26 16:32:32

今天我们继续,接着昨天的进度。
先回顾一下上一小节,我学到了构建起一个模型函数和一个损失函数,然后我们使用人眼观察损失,并手动调整模型参数。然而看起来,我们虽然看到了损失,但我们调整参数的方案跟损失并没有太大的关系,而是随机的进行了调整,那么有没有什么方法能够衡量我们的参数该往什么方向去调整呢?是该调大还是调小呢?这里就涉及到一个梯度的概念了。

梯度(gradient)

百科给梯度的定义是这样的,反正我是没太看得懂。大学数学学得知识也忘得差不多了。

梯度的本意是一个向量(矢量),表示某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处沿着该方向(此梯度的方向)变化最快,变化率最大(为该梯度的模)。

简单来说对于函数的某个特定点,它的梯度就表示从该点出发,函数值增长最为迅猛的方向(direction of greatest increase of a function)。而我们这里所要求的就是loss关于参数w和偏置b的梯度,然后沿着这个梯度去修正我们的w和b。

这里举个例子,想像当年上大学的时候,女生宿舍楼就在你的窗户后面,于是你买了一个望远镜,想窥探一下女生宿舍的生活细节。夜黑风高,对面的女生宿舍楼都亮起了灯光。这个时候你掏出了新买的望远镜,但是你发现看不清楚,可以说非常模糊,这时候你摸索着望远镜,上面有两个旋钮,一个是可以调整清晰度,一个是调整放大倍率(这里可以看做是w和b),这时候你发现向左拧按钮w就更模糊了,向右拧w就清楚一点,所以在当前这个点,向右就是你要的梯度。你发现向右可以变清楚,于是你喜出望外,大力往右一拧,貌似有那么一瞬间清晰,但是又变得模糊起来,这个时候向左就是你要的梯度。

把这个事情转换成数学公式,就是计算loss对于每一个参数的导数,然后在一个具体点位获得的矢量就是梯度结果。

image.png

根据求导的链式法则,有如下结果

d loss_fn / d w = (d loss_fn / d t_p) * (d t_p / d w)
对参数b同样适用
d loss_fn / d b = (d loss_fn / d t_p) * (d t_p / d b)

image.png

这个时候,我们写成代码

def dloss_fn(t_p, t_c): #loss对t_p求导
    dsq_diffs = 2 * (t_p - t_c) / t_p.size(0)  # <1>
    return dsq_diffs

def dmodel_dw(t_u, w, b): #t_p对w求导
    return t_u

def dmodel_db(t_u, w, b): #t_p对b求导
    return 1.0

梯度函数

def grad_fn(t_u, t_c, t_p, w, b):
    dloss_dtp = dloss_fn(t_p, t_c)
    dloss_dw = dloss_dtp * dmodel_dw(t_u, w, b)
    dloss_db = dloss_dtp * dmodel_db(t_u, w, b)
    return torch.stack([dloss_dw.sum(), dloss_db.sum()])  

对于stack方法的官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

说一句,我理解这个计算梯度的过程就叫反向传播。

学习率

通过求梯度,我们能知道该往哪个方向走是loss下降最快的,然而,到底该迈多大的步子呢?想到这个问题,我不禁又掏出了望远镜。当我们发现大力拧望远镜的时候,很容易就拧过去了,一下子就从一种模糊状态变成了另一种模糊状态,中间有一个清晰的妹子图像一闪而过。于是我决定控制好抖动的双手,每次只拧一点点,这样我才能够在看清妹子的时候停下来。

所以,这里面涉及到的一个概念就是学习率(learning rate)。前面我们通过梯度确定了参数的调整方向,然后我们用学习率来调整步子的大小,其实就是在梯度上面乘以一个系数,比如说w = w - learing_rate * grad作为我们下次尝试的参数。

可以想到的是,如果学习率定的太大,可能很难收敛,就像你的望远镜一直在两种不同的模糊状态中变来变去,而你的学习率定的太小,也会很难收敛,比如你每次只转动0.0001毫米,估计对面的女生都毕业了你也没转到清楚的地方。因此这个学习率也是一个玄学,有时候在搞模型的时候会让你有意想不到的情况发生。当然,很多时候也有一些参考值,比如设定为1e-5。

如何优化

说到这里,我们训练的前期准备都差不多完成了,接下来开启炼丹过程。

这里有一个概念就是epoch,话说epoch原意是时代,这里其实就是循环训练了几次的意思,一个epoch就是一次训练修正参数的过程,这原创人真的能整活,一个循环跑下来,一个时代就过去了。不过想想也是,深度模型实在是太费资源,如果你资源不充足,跑的是真慢,大厂有句老话,一杯茶,一包烟,一个模型跑一天。

def training_loop(n_epochs, learning_rate, params, t_u, t_c):

    for epoch in range(1, n_epochs + 1):
        w, b = params

        t_p = model(t_u, w, b)  # 正向传播,数据输入模型获得预测结果
        loss = loss_fn(t_p, t_c) #计算预测结果和真实值的损失
        grad = grad_fn(t_u, t_c, t_p, w, b)  # 反向传播,求损失关于参数的梯度

        params = params - learning_rate * grad #参数调整

        print('Epoch %d, Loss %f' % (epoch, float(loss))) 
            
    return params

# 定义好了训练迭代的方案,开始跑训练
training_loop(
    n_epochs = 100,  #100个时代
    learning_rate = 1e-2,  #学习率初始化
    params = torch.tensor([1.0, 0.0]),  #参数初始化
    t_u = t_u, 
    t_c = t_c)

搞了这么多,结果还是出问题了,你猜怎么着,看起来这效果一点也不好啊,这每轮训练不光没有降低损失,反而让损失越来越大,到了第11代直接溢出了,好嘛。

image.png

这里你想到什么问题,就是我们前面说的学习率过大了,那我们就把学习率调小一点,其他的不变,把学习率改到1e-5,同时把grad和params也输出看一下。

看前5次迭代,明显效果好多了,至少loss是在下降的。

Epoch 1, Loss 1763.884766
params:tensor([ 9.5483e-01, -8.2600e-04])
grad:tensor([4517.2964,   82.6000])
Epoch 2, Loss 1565.761353
params:tensor([ 0.9123, -0.0016])
grad:tensor([4251.5220,   77.9184])
Epoch 3, Loss 1390.265503
params:tensor([ 0.8723, -0.0023])
grad:tensor([4001.3838,   73.5123])
Epoch 4, Loss 1234.812378
params:tensor([ 0.8346, -0.0030])
grad:tensor([3765.9622,   69.3654])
Epoch 5, Loss 1097.112793
params:tensor([ 0.7992, -0.0037])
grad:tensor([3544.3916,   65.4625])

看最后一代,loss降到了29,其实到第64代的时候,loss就已经在29-30之间徘徊了,看起来这就是当前的一个极限水平了。

Epoch 100, Loss 29.114819
params:tensor([ 0.2340, -0.0165])
grad:tensor([11.1109,  3.2240])

那么还有什么地方是可以优化的呢?我们观察一下结果,在params上,参数w和参数b基本上有10倍的差距,而我们使用同一个学习率那么可能导致一些问题,如果说这个学习率对较大的那个参数比较合适,那么比较小的那个肯定是属于优化过慢,而如果学习率比较适合较小的那个参数,那么较大的那个就属于步子太大可能不稳定。这个时候我们自然想到的是给每一个参数设定一个不同的学习率,但是这个成本很高,至少目前看起来是很高,因为我们在深度模型里可能会有几十亿的参数,那就需要有几十亿的学习率。

反过来,这里有一个比较简单的方案,既然调整学习率不方便,那么我们就想别的办法。比如说做输入数据的归一化。因为参数和数据合并起来构成一项,如果我们把所有维度的输入数据都限定到一个固定的区间中,那么学习率的影响也应该是类似的。

这里我们做个简单的尝试,把t_u都缩小10倍,使用params来承接输出结果

t_un = 0.1 * t_u
params=training_loop(
    n_epochs = 100,  #100个时代
    learning_rate = 1e-5,  #学习率初始化
    params = torch.tensor([1.0, 0.0]),  #参数初始化
    t_u = t_un, 
    t_c = t_c)

结果呢,到了100代loss才降到74,而且观察前100,loss是稳定下降的,这说明我们的学习率太小了,这个时候可以增大epoch,或者增大学习率。

Epoch 100, Loss 74.637016
params:tensor([1.0753, 0.0102])
grad:tensor([-73.1205,  -9.8467])

当把epoch改到3000的时候,loss下降到了30,还不如之前效果好,看来还得加大epoch

Epoch 3000, Loss 30.896944
params:tensor([2.0809, 0.0991])
grad:tensor([-13.0219,   0.7548])

或者我们把学习率改到1e-2试一下,epoch为100,这个时候可以看到loss已经降到了22,说明我们的优化起到了效果。

Epoch 100, Loss 22.148710
params:tensor([ 2.7553, -2.5162])
grad:tensor([-0.4446,  2.5165])

这个时候让我们双管齐下,学习率使用1e-2,迭代次数3000次,可以看到这时候loss虽然还没有到0,但是跟之前比起已经非常小了,只有2.9,当然我们的数据本来就有一些误差,所以肯定到不了0。

Epoch 3000, Loss 2.928648
params:tensor([  5.3489, -17.1980])
grad:tensor([-0.0032,  0.0182])

最后,让我们把我们预测完的模型图像绘制出来,就是一个直线

这里面用到一个新的参数传入方式“*”,就像下面代码里写的,t_p = model(t_un, *params),这里是解包方法,意味着接受到的参数params中的元素作为单独的参数传入,等同于model(t_un, params[0],params[1])

%matplotlib inline
from matplotlib import pyplot as plt

t_p = model(t_un, *params)  

fig = plt.figure(dpi=600)
plt.xlabel("Temperature (°Fahrenheit)")
plt.ylabel("Temperature (°Celsius)")
plt.plot(t_u.numpy(), t_p.detach().numpy())
plt.plot(t_u.numpy(), t_c.numpy(), 'o')
plt.savefig("temp_unknown_plot.png", format="png")  

image.png

总结一下今天这一小节,关于构建模型这一部分,我们了解了梯度这个概念,知道了怎么计算梯度以及梯度下降方法用于更新参数,然后了解了学习率以及学习率对更新参数的影响。最后学了一点点优化方法,比如像归一化数据,如何修改学习率,增大epoch等等,每天进步一点点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/665716.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python学习】—认识Python与环境搭建(一)

【Python学习】—认识Python与环境搭建&#xff08;一&#xff09; 本章节思维导图如下&#xff1a; 一、Python解释器 首先一个基本原理就是&#xff1a;计算机只认识二进制&#xff0c;0和1 Python解释器&#xff0c;是一个计算机程序&#xff0c;用来翻译Python代码&…

十六、Docker Swarm的介绍和使用

一、Swarm简介 1、swarm介绍 Dockere Swarm是Docker公司推出的用来管理docker集群的编排工具&#xff0c;代码开源在https://github.com/docker/swarm&#xff0c; 它是将一群Docker宿主机变成一个单一的虚拟主机&#xff0c;提供了标准的 Docker API&#xff0c;所有任何已经与…

MySQL8超详细安装教程

MySQL的下载与安装 一、MySQL8下载 MySQL Community Server 社区版本&#xff0c;开源免费&#xff0c;自由下载&#xff0c;但不提供官方技术支持&#xff0c;适用于大多数普通用户。 MySQL Enterprise Edition 企业版本&#xff0c;需付费&#xff0c;不能在线下载&#x…

VUE——Vue CLI的原理与基本使用

摘要 Vue CLI 是一个基于 Vue.js 进行快速开发的完整系统&#xff0c;提供&#xff1a; 通过 vue/cli 实现的交互式的项目脚手架。通过 vue/cli vue/cli-service-global 实现的零配置原型开发。一个运行时依赖 (vue/cli-service)&#xff0c;该依赖&#xff1a; 可升级&…

互联网企业更需要线上版的产品手册

互联网企业在不断发展变化的市场中&#xff0c;需要更加灵活和快速地适应市场需求&#xff0c;因此&#xff0c;线上版的产品手册对于互联网企业来说是非常重要的。 互联网企业更需要线上版的产品手册的原因 互联网用户更喜欢在线文档 互联网用户更喜欢在线文档&#xff0c;…

中创|没人比我更懂!马斯克发出警告:人类要小心人工智能

马斯克在过去十年对AI的态度一直非常鲜明&#xff0c;很早就对这个问题有深入地思考&#xff1a; 2014 “我们对AI要非常小心&#xff0c;这可能是我们最大的存在威胁。” 2016 “AI的未来发展方向可能并不乐观&#xff0c;起码不会所有结果都是好的。” 2017 “AI会比地…

中原银行 OLAP 架构实时化演进

中原银行 OLAP 架构实时化演进 1. OLAP 实时化建设背景2. OLAP 全链路实时化3. OLAP 实时化探索4. 未来探索方向 中原银行成立于 2014 年&#xff0c;是河南省唯一的省级法人银行&#xff0c;2017 年在香港联交所主板上市&#xff0c;2022 年 5 月经中国银保监会批准正式吸收合…

【问题解决】 网关代理Nginx 301暴露自身端口号

一般项目上常用Nginx做负载均衡和静态资源服务器&#xff0c;本案例中项目上使用Nginx作为静态资源服务器出现了很奇怪的现象&#xff0c;我们一起来看看。 “诡异”的现象 部署架构如下图&#xff0c;Nginx作为静态资源服务器监听8080端口&#xff0c;客户浏览器通过API网关…

跟晓月一起学:mysql中常用的命令汇总

前言 本文主要讲解了MySQL中常用的命令&#xff0c;感谢师父的耐心指导&#xff0c;师父博客&#xff1a;https://zmedu.blog.csdn.net 本文是对MySQL常用的两个命令的总结&#xff0c;一个是select &#xff0c;一个是show命令&#xff0c;很多时候我们监控MySQL需要监控MyS…

ABB 5SHY35L4520 AC10272001R0101/5SXE10-0181 IGCT模块

ABB 5SHY35L4520 AC10272001R0101/5SXE10-0181 IGCT模块 ABB 5SHY35L4520 AC10272001R0101/5SXE10-0181 IGCT模块 2、DCS的软件系统 DCS的软件体系如图2所示&#xff0c;通常可以为用户提供相当丰富的功能软件模块和功能软件包&#xff0c;控制工程师利用DCS提供的组态软件&…

STM32 USART串口

什么是串口 串口是串行接口 (Serial Interface)的简称&#xff0c;它是指数据一位一位地顺序传送&#xff0c;其特点是通信线路简单&#xff0c;只要一对传输线就可以实现双向通信&#xff08;可以直接利用电话线作为传输线&#xff09;&#xff0c;从而大大降低了成本&#xf…

优秀的 Verilog/FPGA开源项目介绍(三十七)- MATH库

DSP介绍 数字信号处理&#xff08; Digital Signal Processing)技术广泛地应用于通信与信息系统、信号与信息处理、自动控制、 雷达、军事、航空航天、医疗、家用电器等许多领域。DSP 技术可以快速地对采集的信号进行量化、变换、滤波、估值 、增强、压缩、识别等处理&#xff…

2023 linux驱动中probe函数的返回值,返回0成功。返回负数则失败,这个时候驱动向系统申请的有关资源都会被释放,如中断号,申请的内存等。实际测试。

一、在linux 驱动里面申请一个gpip&#xff0c;&#xff0c;gpip2b4 变换是 76 &#xff0c;dts 如下&#xff1a; m117b45 {compatible "xxx,m117b";reg <0x45>;pinctrl-names "default";pinctrl-0 <&m117b_gpio>;pwdn-gpios <&a…

数据库数据更新:从内存到磁盘,一步步揭开数据的神秘面纱!

大家好&#xff0c;我是小米&#xff01;今天我要和大家分享一下数据库数据更新的流程。作为一名热衷于技术分享的小伙伴&#xff0c;我希望通过本篇文章&#xff0c;帮助大家更好地理解数据库数据更新的过程。废话不多说&#xff0c;让我们开始吧&#xff01; 获取数据 在数据…

PM3328BP-6电源模块PIONEER MAGNETICS

PM3328BP-6电源模块PIONEER MAGNETICS PM3328BP-6电源模块PIONEER MAGNETICS DCS中的先进控制技术 DCS在控制上的最大特点是依靠各种控制、运算模块的灵活组态&#xff0c;可实现多样化的控制策略以满足不同情况下的需要&#xff0c;使得在单元组合仪表实现起来相当繁琐与复杂…

基于GO实现的简易博客,附源码

1、简介 此博客系统主要是基于GO、Gin、Gorm进行开发&#xff0c;以及采用lay-ui框架进行前端界面的开发&#xff0c;项目包含功能众多&#xff0c;基本上涵盖了博客系统的大部分需求。 此项目适合开发者练手学习&#xff0c;同时也适合高校毕业设计的作品。 以下对作品进行…

#经验分享#消防电源强切故障

工业园火灾报警控制器显示&#xff0c;13#厂房电源强切报故障&#xff0c;经过紧急处理&#xff0c;成功解决了故障问题。 据了解&#xff0c;故障原因是71#强切模块被修复大门时损坏模块破碎无法进行修复&#xff0c; 只留有接线底座&#xff0c;测试并检查底座线路正常。 坏…

如何查看jar包的官网地址

https://mvnrepository.com/ 使用artifactId搜索 点击要查看的版本 查看HomePage LicenseApache 2.0CategoriesJSON LibrariesTagsformatjsonOrganizationAlibaba GroupHomePageGitHub - alibaba/fastjson2: &#x1f684; FASTJSON2 is a Java JSON library with excellent…

Citespace和vosviewer文献计量学可视化SCI论文高效写作方法

【基于Citespace和vosviewer文献计量学相关论文 】 文献计量学是指用数学和统计学的方法&#xff0c;定量地分析一切知识载体的交叉科学。它是集数学、统计学、文献学为一体&#xff0c;注重量化的综合性知识体系。特别是&#xff0c;信息可视化技术手段和方法的运用&#xff…