深度学习02-pytorch-08-自动微分模块

news2025/1/15 12:48:40

​​​​​​​

其实自动微分模块,就是求相当于机器学习中的线性回归损失函数的导数。就是求梯度。

反向传播的目的: 更新参数, 所以会使用到自动微分模块。

神经网络传输的数据都是 float32 类型。 

案例1:

代码功能概述:

该代码展示了如何在 PyTorch 中使用 自动微分(Autograd) 计算损失函数相对于权重 w 和偏置 b 的梯度。这是机器学习模型训练中非常重要的步骤,因为这些梯度将用于更新模型的参数,从而最小化损失函数

import torch

# 1. 当x为标量时,梯度的计算
def test01():
    x = torch.tensor(5)  # 输入变量x为标量5
    # 目标值
    y = torch.tensor(0.)  # 目标输出y设置为0
    
    # 设置要更新的权重 和 偏置的初始值
    w = torch.tensor(1., requires_grad=True, dtype=torch.float32)  # 权重w初始化为1,并启用梯度计算
    b = torch.tensor(3., requires_grad=True, dtype=torch.float32)  # 偏置b初始化为3,并启用梯度计算
    
    # 设置网络的输出值
    z = x * w + b  # 计算线性模型的输出 z = x*w + b (等同于线性回归的公式)
    
    # 设置损失函数,并进行损失的计算
    loss = torch.nn.MSELoss()  # 使用均方误差(MSE)作为损失函数
    loss1 = loss(z, y)  # 计算损失,z 是模型的预测值,y 是目标值
    
    # 自动微分,计算损失函数相对于w和b的梯度
    loss1.backward()  # 反向传播计算梯度
    
    # backward 函数计算的梯度值会存储在张量的grad 变量中
    print("w的梯度", w.grad)  # 打印出损失函数对 w 的梯度
    print("b的梯度", b.grad)  # 打印出损失函数对 b 的梯度
    
test01() 

w的梯度 tensor(80.)
b的梯度 tensor(16.)

代码讲解:

    1.    输入与目标值:
    •    x = torch.tensor(5):输入为 x = 5,表示输入的特征值。
    •    y = torch.tensor(0.):目标输出 y 设置为 0,这是我们希望模型最终预测得到的值。
    2.    参数的初始化:
    •    w = torch.tensor(1., requires_grad=True):初始化权重 w 为 1,requires_grad=True 启用对 w 的梯度计算。
    •    b = torch.tensor(3., requires_grad=True):初始化偏置 b 为 3,同样启用对 b 的梯度计算。
requires_grad=True 的作用是让 PyTorch 知道我们想对这些参数进行梯度计算。
    3.    模型计算:
    •    z = x * w + b:计算模型的输出,类似于线性回归的公式。z 是模型的预测输出。
    4.    损失函数:
    •    loss = torch.nn.MSELoss():选择均方误差(MSE)作为损失函数,用于衡量预测值 z 与目标值 y 之间的误差。
    •    loss1 = loss(z, y):计算损失值,z 是模型预测输出,y 是目标值。

MSE 的公式为:

\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (z_i - y_i)^2

在这个例子中,由于我们只使用了一个数据点,损失计算为:

\text{Loss} = (z - y)^2 = (x \cdot w + b - 0)^2

    5.    反向传播:
    •   loss1.backward():通过调用 backward(),PyTorch 会自动计算损失函数对 w 和 b 的梯度。这个过程称为反向传播(Backpropagation)。梯度的计算基于链式法则,PyTorch 会自动追踪所有的计算操作,计算各个参数对损失的导数。


    6.    梯度输出:
    •    w.grad:存储了损失函数对 w 的梯度。
    •    b.grad:存储了损失函数对 b 的梯度。

案例2:

import torch

def test02():
    # 输入张量 2x5,表示 2 个样本,每个样本有 5 个特征
    x = torch.ones(2, 5)  # 输入数据,全部初始化为 1
    
    # 目标输出张量 2x3,表示我们希望模型预测的输出有 3 个类别
    y = torch.zeros(2, 3)  # 目标输出,初始化为 0
    
    # 设置可更新的权重和偏置的初始值
    # 权重 w 的形状是 5x3,表示输入特征为 5,输出类别为 3
    w = torch.randn(5, 3, requires_grad=True)  # 随机初始化权重,启用梯度计算
    
    # 偏置 b 的形状是 3,表示每个输出类别有一个偏置
    b = torch.randn(3, requires_grad=True)  # 随机初始化偏置,启用梯度计算
    
    # 计算网络的输出,z = x * w + b
    # x 的形状是 2x5,w 的形状是 5x3,矩阵乘法后的结果 z 的形状是 2x3
    z = torch.matmul(x, w) + b  # 矩阵乘法和偏置加法
    
    # 设置损失函数,并计算损失
    # 这里使用均方误差(MSE),z 是预测值,y 是目标值
    loss_fn = torch.nn.MSELoss()  # 损失函数为均方误差
    loss = loss_fn(z, y)  # 计算损失,输出一个标量值
    
    # 自动微分,计算损失函数相对于 w 和 b 的梯度
    loss.backward()  # 反向传播,计算梯度
    
    # 打印权重和偏置的梯度,梯度值存储在 grad 属性中
    print("w 的梯度:\n", w.grad)  # 打印权重 w 的梯度
    print("b 的梯度:\n", b.grad)  # 打印偏置 b 的梯度

# 调用函数进行计算
test02()

自动微分 (Autograd) 的工作原理:

    •    PyTorch 中的 Autograd 是自动微分引擎,它会记录所有张量的计算历史,并根据这些计算图自动执行反向传播,计算参数的梯度。
    •    在向前计算过程中,PyTorch 构建了一个动态计算图(计算图是有向无环图 DAG)。当你调用 .backward(),计算图会根据链式法则从损失开始计算每个变量的梯度。
    •    计算的梯度会存储在对应张量的 .grad 属性中,然后可以使用这些梯度来更新模型的参数。

总结:

    •    w.grad 和 b.grad 的值告诉我们,若我们改变 w 或 b,损失函数会如何变化。
    •    梯度的计算对于优化模型非常重要,因为我们会使用这些梯度来更新权重和偏置,使得损失函数最小化。

PyTorch 中的 自动微分模块 是通过 autograd 实现的,这是 PyTorch 中的核心功能之一,它可以帮助用户在神经网络的训练过程中自动计算梯度。autograd 模块使得实现反向传播和梯度计算变得非常简单和高效。

核心概念

  1. Tensor: PyTorch 的张量 (Tensor) 是自动微分系统的基本单位。如果将 Tensorrequires_grad 属性设置为 True,则 PyTorch 会开始跟踪所有与该张量相关的操作,并在反向传播时自动计算该张量的梯度。

  2. Computational Graph (计算图): PyTorch 会构建一个动态图,记录张量的所有操作。这个图是有向无环图(DAG),图中的每个节点代表一个变量,边代表该变量上发生的操作。当你调用 .backward() 时,PyTorch 会根据计算图自动计算每个张量的梯度。

  3. 梯度 (Gradient): 如果一个张量参与了计算并且 requires_grad=True,在反向传播时可以通过 .grad 属性获取其梯度值。

  4. 反向传播: 通过 tensor.backward() 来执行反向传播计算张量的梯度,默认情况下会对标量进行求导。

使用案例

  1. 创建一个张量并启用梯度跟踪:

    import torch
    ​
    # 创建一个张量,并启用梯度跟踪
    x = torch.tensor([[2.0, 3.0]], requires_grad=True)

  2. 执行一些操作:

    y = x * 3
    z = y.sum()
    print(z)

  3. 反向传播:

    z.backward()  # 对 z 求导
    print(x.grad)  # 查看 x 的梯度

    输出:

    tensor([[3., 3.]])

    在这个例子中,z = x * 3z.backward() 计算了 zx 的梯度,结果为 3

PyTorch 自动微分的几个重要点:

  1. requires_grad=True: 如果需要对某个张量求导,必须将其 requires_grad 属性设置为 True,否则在反向传播时 PyTorch 不会计算该张量的梯度。

  2. grad_fn: 每个跟踪计算的张量都有一个 grad_fn 属性,代表该张量的创建方式和跟踪的操作。例如,如果你对一个张量做了加法操作,它的 grad_fn 就会显示 AddBackward0

    print(y.grad_fn)  # <MulBackward0 object at 0x...>

  3. .backward(): backward() 方法会根据计算图反向传播,自动计算梯度。

  4. 梯度累加: 每次调用 backward() 时,梯度会被累加到 .grad 中,因此在多次反向传播之前,最好手动将 .grad 清零,使用 x.grad.zero_()

autograd 的典型使用场景

  • 神经网络训练:通过 autograd,我们可以在每次迭代时计算损失函数的梯度,然后使用这些梯度更新网络的参数。

  • 自定义梯度计算:可以通过创建复杂的操作来自动推导梯度。

Example: 简单的线性回归

import torch
​
# 生成数据
x = torch.randn(10, 1, requires_grad=True)
y = 3 * x + 2
​
# 定义损失函数
loss = (x - y).pow(2).mean()
​
# 反向传播
loss.backward()
​
# 查看 x 的梯度
print(x.grad)

在这个例子中,loss.backward() 会自动计算 xloss 的梯度。

总结

  • PyTorch 的自动微分机制通过 autograd 实现,用户只需要将张量的 requires_grad 设置为 True,在执行反向传播时,PyTorch 会自动计算张量的梯度。

  • 通过自动构建计算图,autograd 能够跟踪张量上的所有操作,动态计算梯度,极大地方便了深度学习模型的训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2154850.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python篇】深入机器学习核心:XGBoost 从入门到实战

文章目录 XGBoost 完整学习指南&#xff1a;从零开始掌握梯度提升1. 前言2. 什么是XGBoost&#xff1f;2.1 梯度提升简介 3. 安装 XGBoost4. 数据准备4.1 加载数据4.2 数据集划分 5. XGBoost 基础操作5.1 转换为 DMatrix 格式5.2 设置参数5.3 模型训练5.4 预测 6. 模型评估7. 超…

重生之我们在ES顶端相遇第14 章 - ES 节点类型

文章目录 前言Coordinating nodeMaster-eligible nodeData nodeCoordinating only nodeRemote-eligible nodeMachine learning node 前言 通过前面的学习&#xff0c;我们已经初步的掌握了 ES 的大部分用法。 后面的篇章会介绍 ES 集群相关的内容。 本文着重介绍 ES 节点类型&…

华为HarmonyOS地图服务 3 - 如何开启和展示“我的位置”?

一. 场景介绍 本章节将向您介绍如何开启和展示“我的位置”功能&#xff0c;“我的位置”指的是进入地图后点击“我的位置”显示当前位置点的功能。效果如下&#xff1a; 二. 接口说明 “我的位置”功能主要由MapComponentController的方法实现&#xff0c;更多接口及使用方法…

软考高级:逻辑地址和物理地址转换 AI解读

一、题目 设某进程的段表如下所示&#xff0c;逻辑地址&#xff08; &#xff09;可以转换为对应的物理地址。 A. &#xff08;0&#xff0c;1597&#xff09;、&#xff08;1&#xff0c;30&#xff09;和&#xff08;3&#xff0c;1390&#xff09; B. &#xff08;0&…

Vue3 中组件传递 + css 变量的组合

文章目录 需求效果如下图所示代码逻辑代码参考 需求 开发一个箭头组件&#xff0c;根据父组件传递的 props 来修改 css 的颜色 效果如下图所示 代码逻辑 代码 父组件&#xff1a; <Arrow color"red" />子组件&#xff1a; <template><div class&…

3DMAX乐高积木插件LegoBlocks使用方法

3DMAX乐高积木插件LegoBlocks&#xff0c;用户可以通过控件调整和自定义每个乐高积木的外观和大小。 【适用版本】 3dMax2009或更高版本&#xff08;不仅限于此范围&#xff09; 【安装方法】 3DMAX乐高积木插件无需安装&#xff0c;使用时直接拖动插件脚本文件到3dMax视口中…

TS 运行环境

1、TS Playground&#xff08;在线&#xff09; TS Playground 是一个在线 TypeScript 编辑器&#xff0c;它允许你编写、共享和学习 TypeScript 代码。 2、Stackblitz&#xff08;在线&#xff09; StackBlitz 是面向web开发人员的基于浏览器的协作IDE。StackBlitz消除了耗时…

握手传输 状态机序列检测(记忆科技笔试题)_2024年9月2日

发送模块循环发送0-7&#xff0c;在每个数据传输完成后&#xff0c;间隔5个clk&#xff0c;发送下一个 插入寄存器打拍处理&#xff0c;可以在不同的时钟周期内对信号进行同步&#xff0c;从而减少亚稳态的风险。 记忆科技笔试题&#xff1a;检测出11011在下一个时钟周期输出…

Python | 读取.dat 文件

写在前面 使用matlab可以输出为 .dat 或者 .mat 形式的文件&#xff0c;之前介绍过读取 .mat 后缀文件&#xff0c;今天正好把 .dat 的读取也记录一下。 读取方法 这里可以使用pandas库将其作为一个dataframe的形式读取进python&#xff0c;数据内容格式如下&#xff0c;根据…

VulnHub-Narak靶机笔记

Narak靶机笔记 概述 Narak是一台Vulnhub的靶机&#xff0c;其中有简单的tftp和webdav的利用&#xff0c;以及motd文件的一些知识 靶机地址&#xff1a; https://pan.baidu.com/s/1PbPrGJQHxsvGYrAN1k1New?pwda7kv 提取码: a7kv 当然你也可以去Vulnhub官网下载 一、nmap扫…

zabbix7.0容器化部署测试--(1)准备容器镜像

本文为zabbix7.0容器化部署测试系统文档之一&#xff0c;准备容器镜像。拟测试数据库后台为PostgreSQL16并启用timescaledb插件。 一、准备数据库容器镜像 因为不确定zabbix7.0对数据库timescaledb插件的版本要求&#xff0c;准备了现个镜像版本 1、准备timescaledb-2.14.2插…

linux 基础(一)mkdir、ls、vi、ifconfig

1、linux简介 linux是一个操作系统&#xff08;os: operating system&#xff09; 中国有没有自己的操作系统&#xff08;华为鸿蒙HarmonyOS&#xff0c;阿里龙蜥(Anolis) OS 8、百度DuerOS都有&#xff09; 计算机组的组成&#xff1a;硬件软件 硬件&#xff1a;运算器&am…

【速成Redis】03 Redis 五大高级数据结构介绍及其常用命令 | 消息队列、地理空间、HyperLogLog、BitMap、BitField

前言&#xff1a; 上篇博客我们讲到redis五大基本数据类型&#xff08;也是就下图的第一列&#xff09;。 【速成Redis】02 Redis 五大基本数据类型常用命令-CSDN博客文章浏览阅读1k次&#xff0c;点赞24次&#xff0c;收藏10次。该篇适用于速成redis。本篇我们将讲解&#…

MySQL | 知识 | NULL值是怎么存储的

NULL值有哪些语法影响 我们使用mysql时&#xff0c;使用 xx !aa 这种条件为什么无法筛选出值为NULL的字段呢。 是的&#xff0c;MySQL 中null 值确实无法通过这种条件筛选出来&#xff0c;因为 null 值的定义就跟普通值不一样。 拿官网的例子来说&#xff1a; mysql> INSE…

在Java中基于GeoTools的Shapefile读取乱码的问题解决办法

目录 前言 1、Shapefile属性字段编码的情况&#xff1a; 一、Shp文件常见的字符集编码 1、System编码 2、ISO-8859-1编码 3、UTF-8编码 二、GeoTools解析实战 1、未进行字符处理 2、乱码问题的解决 3、转码支持 4、属性字段编码结果 三、总结 前言 文件编码&#x…

RabbitMQ:交换机详解(Fanout交换机、Direct交换机、Topic交换机)

♥️作者&#xff1a;小宋1021 &#x1f935;‍♂️个人主页&#xff1a;小宋1021主页 ♥️坚持分析平时学习到的项目以及学习到的软件开发知识&#xff0c;和大家一起努力呀&#xff01;&#xff01;&#xff01; &#x1f388;&#x1f388;加油&#xff01; 加油&#xff01…

【笔记】第二节 轧制、热处理和焊接工艺

2.2 钢轨的轧制工艺 坯料进厂按标准验收, 然后装加热炉加热, 加热好的钢坯经高压水除鳞后进行轧制。轧出的钢轨经锯切、打印到中央冷床冷却, 然后装缓冷坑进行缓冷。缓冷后的钢轨进行矫直、轨端加工和端头淬火。钢轨入库前逐根进行探伤和外观检查。 钢轨的轧制 #mermaid-svg-…

【Delphi】使用 TPrototypeBindSource 和 LiveBindings 设计器示例

本教程展示了如何使用 LiveBindings Designer 和 TPrototypeBindSource 可视化地创建控件之间的 LiveBindings&#xff0c;以便快速开发只需很少或无需源代码的应用程序。 注意&#xff1a; TPrototypeBindSource 可用于为项目中的 LiveBindings 生成样本数据。在应用程序原型化…

公私域互通下的新商机探索:链动2+1模式、AI智能名片与S2B2C商城小程序的融合应用

摘要&#xff1a;在数字化时代&#xff0c;公私域流量的有效融合已成为企业获取持续增长动力的关键。本文旨在探讨如何通过链动21模式、AI智能名片以及S2B2C商城小程序源码的综合运用&#xff0c;实现公私域流量的高效互通&#xff0c;进而为门店创造巨大商机。通过分析这些工具…

前后端跨域问题及其在ThinkPHP中的解决方案

在现代Web开发中&#xff0c;前后端分离的架构越来越普遍&#xff0c;但这也带来了跨域问题。跨域指的是在一个域下的网页试图请求另一个域的资源&#xff0c;浏览器出于安全考虑会限制这种行为。本文将探讨如何在ThinkPHP中解决跨域问题。 #### 1. 什么是跨域&#xff1f; 跨…