【PyTorch】第四节:梯度下降算法

news2024/12/24 21:21:30

 作者🕵️‍♂️:让机器理解语言か

专栏🎇:PyTorch

描述🎨:PyTorch 是一个基于 Torch 的 Python 开源机器学习库。

寄语💓:🐾没有白走的路,每一步都算数!🐾 

介绍

       本实验主要对梯度下降算法的基本原理进行了讲解,然后使用手写梯度下降算法解决了线性回归问题。最后对 PyTorch 中的反向传播函数进行了讲解并利用该函数简明快速的完成了损失的求导与模型的训练。

知识点🍉

  • 🍓线性回归
  • 🍓梯度下降算法
  • 🍓损失函数

人工梯度下降算法

这里我们使用梯度下降算法来对线性回归问题进行求解。

线性回归问题

线性回归问题其实就是寻找一条合适的直线(y=wx)用以表示所有的数据点,如下:

        如上图所示,线性回归问题其实就是求解上面的线性函数中最佳的 w 值。合适 y=wx 函数可以表示标签 Y 和数据 X 之间的关系,进而预测新的 x 所对应的 预测值 y_{pre_i},其中 y_{pre_i} = w\cdot x_i。那么我们应当用什么来衡量最佳的 w 值呢?

        我们一般认为预测值y_{pre_i} 与真实值 y_i的距离越小,那么该函数就越好,w 的值就越趋近于最佳。在机器学习中,这种计算距离的函数有一个名字,叫做损失函数。定义如下:

import numpy as np

# 所有点的预测值和实际值的距离的平方和,再取平均值(这种距离叫做欧氏距离)。
def loss(y, y_pred):
    return ((y_pred - y)**2).mean()
#测试代码
y_pred = np.array([1,2])
y = np.array([1,1])
loss(y, y_pred)
# 0.5

综上,线性回归问题其实就是求解损失函数最小的情况下的 w 值。

梯度下降算法

        梯度下降算法是一种用于求解函数极小值的方法。

        我们可以把梯度下降算法类比为一个下山的过程。假设一个人被困在山上,需要快速从山上下来,走到山底。 但是,由于山中浓雾很大,导致可视度很低。因此,我们无法确定下山的路径,只能看到周围的一些信息。也就是说我们需要走一步看一步再走一步。此时,我们就可以使用到梯度下降的算法。如下:

        我们需要找的损失函数也可以看做一座山,我们的目标就是找到这座山的最小值,即山底。每走一步,我们都会重新找山脚的方向。因为沿着山脚方向走能够使我们最快到达山脚的位置。

由于梯度表示的是函数上升最快的方向,因此梯度的反方向也应该是函数下降最快的方向。我们每次到了一个新的位置,就会就计算该位置的梯度,找到下一步下山最快的方向。

        根据梯度和当前位置更新下一次所在位置的数学表达式如下:

\theta^1=\theta^0-\alpha \cdot \triangledown J(\theta^0)

        上面式子展示了损失函数 J(θ) 的最小值的求解过程。

        其中 \theta^0 表示当前所在位置,\theta^1 表示下一步的位置,\alpha 表示步长(即一次更新的距离),-\triangledown J(\theta^0) 表示损失函数的梯度的相反方向。

        我们可以将损失函数值 J 定义为欧氏距离,如下:

J = \frac{1}{N}\sum_{i=1}^N(w\cdot x_i - y_i)^2

        损失函数关于 w 的梯度为(此时我们需要求的是 w 的最佳值, w 为变量,因此求损失关于 w 的梯度):

\frac{\partial{J}}{\partial w} = \frac{1}{N}\sum_{i=1}^N(2wx_{i}^2-2x_{i}y_{i})

根据上面的梯度公式,让我们来定义损失函数的梯度计算公式:

#返回dJ/dw
def gradient(x, y, w):
    return np.mean(2*w*x*x-2*x*y)
## 测试代码
x = np.array([1,2])
y = np.array([1,1])
gradient(x, y, 2)
# 7.0

人工实现梯度下降算法(需要推导梯度公式)

假设 w 为损失函数需要求的变量,那么梯度下降算法的具体步骤如下:

  1. 随机初始化一个 w 的值。
  2. 在该 w 下进行正向传播,得到所有 x 的预测值\bg_white y_{pre}
  3. 通过实际的值 y 和预测值 \bg_white y_{pre} 计算损失
  4. 通过损失计算梯度 dw
  5. 更新ww = w-lr\cdot dw,其中lr为步长(学习率),可自定义具体的值。
  6. 重复步骤 2−5,直到损失降到较小位置。 

首先,让我们先来定义一下梯度下降算法所需要的数据集和变量值:

# 正向传播,计算预测值
def forward(x):
    return w * x
# 定义数据集合和 w 的初始化
X = np.array([1, 2, 3, 4], dtype=np.float32)
Y = np.array([2, 4, 6, 8], dtype=np.float32)
w = 0.0
# 定义步长和迭代次数
learning_rate = 0.01
n_iters = 20

接下来,让我们根据上面步骤,利用梯度下降算法求解一元回归函数中的 w 的值:

for epoch in range(n_iters):
    # 正向传播
    y_pred = forward(X)
    # 计算损失
    l = loss(Y, y_pred)
    # 计算梯度
    dw = gradient(X, Y, w)
    # 更新权重 w
    w -= learning_rate * dw

    if epoch % 2 == 0:    # 每两次训练输出一次
        print(f'epoch {epoch+1}: w = {w:.3f}, loss = {l:.8f}')
     
print(f'根据训练模型预测,当 x =5 时,y 的值为: {forward(5):.3f}')

       从结果可以很清晰的函数,我们利用梯度下降算法,不断的降低损失的值,寻找最佳的权重 w。当损失不再发生变化时,证明我们已经找到了一个较小的值,此时的 w 就是较佳权重(根据结果可以看到,w 的值无限接近于 2)。即线性函数 y=2x 可以很好的表示上面的数据集合。

利用 PyTorch 实现梯度下降算法

        由于线性函数的损失函数的梯度公式很容易被推导出来,因此我们能够手动的完成梯度下降算法。但是,在很多机器学习中,模型的函数表达式是非常复杂的,这个时候手动定义该函数的梯度函数需要很强的数学功底。因此,这里我们使用上一个实验中所用的后向传播函数来实现梯度下降算法,求解最佳权重 w。 

        首先,让我们来定义数据集合以及 w 的初始值,并将其设置为可以求偏导的张量

import torch
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)
#初始化张量 w
w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
# 定义步长和迭代次数
learning_rate = 0.01
n_iters = 20

        接下来让我们使用 .backward() 直接求解梯度:

 for epoch in range(n_iters):
    y_pred = forward(X)
    l = loss(Y, y_pred)
    # 无需定义梯度求解的函数,直接求解梯度
    l.backward()
    # 利用梯度下降更新参数
    with torch.no_grad():    # 停止张量计算
        # w.grad :返回 w 的梯度
        w.data -= learning_rate * w.grad
    
    # 清空梯度
    w.grad.zero_()

    if epoch % 2 == 0:
        print(f'epoch {epoch+1}: w = {w.item():.3f}, loss = {l.item():.8f}')
print(f'根据训练模型预测,当 x =5 时,y 的值为: {forward(5):.3f}')

         可以看到,利用 PyTorch 进行的梯度下降的结果和人工梯度下降结果一致。我们可以通过 PyTorch 中的 .backward(),简洁明了的求取任何复杂函数的梯度,大大的节约了我们公式推导的时间。

实验总结🔑

        当然,本实验只是利用 .backward()对损失进行了求导,其实 PyTorch 中还有很多用于梯度下降算法的工具包。我们可以使用这些工具包完成损失函数的定义、损失的求导以及权重的更新等各种操作。在下一个实验中,我们将对这些工具包进行详细的讲解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/417308.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java--HtmlUnit--模拟浏览器操作--自动化操作浏览器--自动登录校园网为案例

写在前面: 闲来无事,因为宿舍每次嫌登录校园网有点免费。然后想着能不能一键自动化实现。然后更麻烦了,哈哈哈。不过倒是写一次代码就可以了。 可能不是特别系统,因为资料太少了。都是案例驱动找的资料。花了3大节课才搞完了。 会…

Redis运维之swap影响及解决方案

一、操作系统SWAP swap空间对于操作系统来说比较重要,当我们使用操作系统的时候,如果系统内存不足,常常会将一部分内存数据页进行swap操作,以解决临时的内存困境。swap空间由磁盘提供,对于高并发场景下,sw…

全球土壤湿度数据获取方法

土壤湿度亦称土壤含水率,表示土壤干湿程度的物理量。是土壤含水量的一种相对变量。通常用土壤含水量占干土重的百分数是示,亦称土壤质量湿度,如用土壤水分容积占土壤总容积的百分数表示,则称土壤容积湿度。通常说的土壤湿度&#…

Vivado中VIO IP核的使用

Vivado中VIO IP核的使用一、写在前面二、VIO IP核配置三、VIO联调四、写在后面一、写在前面 Vivado中的VIO(Virtual Input/Output) IP核是一种用于调试和测试FPGA设计的IP核。它允许设计者通过使用JTAG接口读取和写入FPGA内部的寄存器,从而检…

【JavaEE】关于synchronized总结-Callable用法及JUC的常见问题

博主简介:想进大厂的打工人博主主页:xyk:所属专栏: JavaEE初阶synchronized原理是什么?synchronized到底有什么特点,synchronized的锁策略是什么,是怎么变化的呢?本篇文章总结出, Synchronized 具有以下特性…

【Java|golang】1041. 困于环中的机器人

在无限的平面上,机器人最初位于 (0, 0) 处,面朝北方。注意: 北方向 是y轴的正方向。 南方向 是y轴的负方向。 东方向 是x轴的正方向。 西方向 是x轴的负方向。 机器人可以接受下列三条指令之一: “G”:直走 1 个单位 “L”&…

Markdown 语法大全

Markdown是一种轻量级标记语言,常用于撰写博客、文档、论文等。它可以让你使用易读易写的纯文本格式来编写文档,然后通过转换成有效的HTML文档进行发布。以下是Markdown常用的语法: 这里写目录标题标题列表引用一级引用嵌套引用粗体和斜体删除…

技术复盘(1)--redis

技术复盘--redis技术复盘(1)--redis资料地址准备工作发展史redis-windowsredis-windows-说明redis-centos7安装jdk安装redisredis-key基本命令redis-string命令redis-list命令redis-set命令redis-hash命令redis-zset命令redis-geospatial命令redis-hyperloglog命令redis-bitmap…

【Linux驱动开发】024 INPUT子系统

一、前言 按键、鼠标、键盘、触摸屏等都属于输入(input)设备,Linux 内核为此专门做了一个叫做 input子系统的框架来处理输入事件。输入设备本质上还是字符设备,只是在此基础上套上了 input 框架,用户只需要负责上报输入事件,比如…

文本聚类与摘要,让AI帮你做个总结

你好,我是徐文浩。 上一讲里,我们用上了最新的ChatGPT的API,注册好了HuggingFace的账号,也把我们的聊天机器人部署了出去。希望通过这个过程,你对实际的应用开发过程已经有了充足的体验。那么这一讲里,我们…

[目标识别-论文笔记]Object Detection in Videos by Short and Long Range Object Linking

文章标题:2018_Cite13_Tang——Object Detection in Videos by Short and Long Range Object Linking 这篇论文也被叫做“2019_Cite91_TPAMI_Tang——Object Detection in Videos by High Quality Object Linking” 如果这篇博客对你有帮助,希望你 点赞…

ES索引库操作

文章目录1、对索引库的操作:创建、删除、查看2、文档操作3、 RestClient操作索引库4、利用RestClient实现文档的CRUD5、 批量导入功能有了索引库相当于数据库database,而接下来,就是需要索引库中的类型了,也就是数据库中的表&…

nssctf web入门(1)

这里通过nssctf的题单web安全入门来写,会按照题单详细解释每题。题单在NSSCTF中。 想入门ctfweb的可以看这个系列,之后会一直出这个题单的解析,题目一共有28题,打算写10篇。 [SWPUCTF 2021 新生赛]jicao [SWPUCTF 2021 新生赛]j…

RL4RS,离线强化学习,无模型强化学习等等资源汇总

发现好文章: 强化学习推荐系统综述:Reinforcement Learning based Recommender Systems: A Survey 强化学习图鉴|你与最优策略之间,可能还差一本离线强化学习秘籍 科学应用强化学习创新论文洞察 https://hub.baai.ac.cn/view/18…

【论文精读】PP-YOLOE: An evolved version of YOLO

文章目录前言一、可扩展的 Backbone 和 Neck二、更高效的标签分配策略 TAL (Task Alignment Learning)三、更简洁有效的 ET-Head (Efficient Task-aligned Head)前言 百度飞桨团队发布了 PP-YOLOE,与其他 YOLO 系列算法相比,其具有更强的性能、更丰富灵…

8.2 正态总体的参数的检验

学习目标: 如果我要学习正态总数的参数检验,我会按照以下步骤进行学习: 学习正态分布的基本知识:正态分布是统计学中非常重要的概率分布之一,掌握其基本知识包括概率密度函数、期望值、方差、标准差等是非常重要的。 …

Prometheus - Grafana 监控 MySQLD Linux服务器 demo版

目录 首先是下载Prometheus 下载和安装 配置Prometheus 查看监控数据 监控mysql demo 部署 mysqld_exporter 组件 配置 Prometheus 获取监控数据 -------------------------------------- 安装和使用Grafana 启动Grafana -------------------------------------- 配…

MySQL5.5安装图解

一、MYSQL的安装 1、打开下载的mysql安装文件mysql-5.5.27-win32.zip,双击解压缩,运行“setup.exe” 2、选择安装类型,有“Typical(默认)”、“Complete(完全)”、“Custom(用户自定义)”三个选项,选择“Cu…

VSD Viewer for Mac,Visio绘图文件阅读器

VSD Viewer for Mac版是mac上一款非常强大的Visio绘图文件阅读器,它为打开和打印Visio文件提供了简单的解决方案。可以显示隐藏的图层,查看对象的形状数据,预览超链接。还可以将Visio转换为包含图层,形状数据和超链接的PDF文档。 …

【状态估计】基于增强数值稳定性的无迹卡尔曼滤波多机电力系统动态状态估计(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…