【深度学习】Pytorch教程(十三):PyTorch数据结构:5、张量的梯度计算:变量(Variable)、自动微分、计算图及其可视化

news2025/3/3 4:12:12

文章目录

  • 一、前言
  • 二、实验环境
  • 三、PyTorch数据结构
    • 1、Tensor(张量)
      • 1. 维度(Dimensions)
      • 2. 数据类型(Data Types)
      • 3. GPU加速(GPU Acceleration)
    • 2、张量的数学运算
      • 1. 向量运算
      • 2. 矩阵运算
      • 3. 向量范数、矩阵范数、与谱半径详解
      • 4. 一维卷积运算
      • 5. 二维卷积运算
      • 6. 高维张量
    • 3、张量的统计计算
    • 4、张量操作
      • 1. 张量变形
      • 2. 索引
      • 3. 切片
      • 4. 张量修改
    • 5、张量的梯度计算
      • 0. 变量(Variable)
      • 1. 自动微分
        • a. 单参数
        • b. 多参数
      • 2. 计算图
      • 3. 神经网络模型
      • 4. 可视化计算图结构

一、前言

  本文将介绍张量的梯度计算,包括变量(Variable)、自动微分、计算图及其可视化等

二、实验环境

  本系列实验使用如下环境

conda create -n DL python==3.11
conda activate DL
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
conda install pydot

三、PyTorch数据结构

1、Tensor(张量)

  Tensor(张量)是PyTorch中用于表示多维数据的主要数据结构,类似于多维数组,可以存储和操作数字数据。

1. 维度(Dimensions)

  Tensor(张量)的维度(Dimensions)是指张量的轴数或阶数。在PyTorch中,可以使用size()方法获取张量的维度信息,使用dim()方法获取张量的轴数。

在这里插入图片描述

2. 数据类型(Data Types)

  PyTorch中的张量可以具有不同的数据类型:

  • torch.float32或torch.float:32位浮点数张量。
  • torch.float64或torch.double:64位浮点数张量。
  • torch.float16或torch.half:16位浮点数张量。
  • torch.int8:8位整数张量。
  • torch.int16或torch.short:16位整数张量。
  • torch.int32或torch.int:32位整数张量。
  • torch.int64或torch.long:64位整数张量。
  • torch.bool:布尔张量,存储True或False。

【深度学习】Pytorch 系列教程(一):PyTorch数据结构:1、Tensor(张量)及其维度(Dimensions)、数据类型(Data Types)

3. GPU加速(GPU Acceleration)

【深度学习】Pytorch 系列教程(二):PyTorch数据结构:1、Tensor(张量): GPU加速(GPU Acceleration)

2、张量的数学运算

  PyTorch提供了丰富的操作函数,用于对Tensor进行各种操作,如数学运算、统计计算、张量变形、索引和切片等。这些操作函数能够高效地利用GPU进行并行计算,加速模型训练过程。

1. 向量运算

【深度学习】Pytorch 系列教程(三):PyTorch数据结构:2、张量的数学运算(1):向量运算(加减乘除、数乘、内积、外积、范数、广播机制)

2. 矩阵运算

【深度学习】Pytorch 系列教程(四):PyTorch数据结构:2、张量的数学运算(2):矩阵运算及其数学原理(基础运算、转置、行列式、迹、伴随矩阵、逆、特征值和特征向量)

3. 向量范数、矩阵范数、与谱半径详解

【深度学习】Pytorch 系列教程(五):PyTorch数据结构:2、张量的数学运算(3):向量范数(0、1、2、p、无穷)、矩阵范数(弗罗贝尼乌斯、列和、行和、谱范数、核范数)与谱半径详解

4. 一维卷积运算

【深度学习】Pytorch 系列教程(六):PyTorch数据结构:2、张量的数学运算(4):一维卷积及其数学原理(步长stride、零填充pad;宽卷积、窄卷积、等宽卷积;卷积运算与互相关运算)

5. 二维卷积运算

【深度学习】Pytorch 系列教程(七):PyTorch数据结构:2、张量的数学运算(5):二维卷积及其数学原理

6. 高维张量

3、张量的统计计算

【深度学习】Pytorch教程(九):PyTorch数据结构:3、张量的统计计算详解

4、张量操作

1. 张量变形

【深度学习】Pytorch教程(十):PyTorch数据结构:4、张量操作(1):张量变形操作

2. 索引

3. 切片

【深度学习】Pytorch 教程(十一):PyTorch数据结构:4、张量操作(2):索引和切片操作

4. 张量修改

【深度学习】Pytorch 教程(十二):PyTorch数据结构:4、张量操作(3):张量修改操作(拆分、拓展、修改)

5、张量的梯度计算

0. 变量(Variable)

  Variable(变量)是早期版本中的一种概念,用于自动求导(autograd)。从PyTorch 0.4.0版本开始,Variable已经被弃用,自动求导功能直接集成在张量(Tensor)中,因此不再需要显式地使用Variable。
  在早期版本的PyTorch中,Variable是一种包装张量的方式,它包含了张量的数据、梯度和其他与自动求导相关的信息。可以对Variable进行各种操作,就像操作张量一样,而且它会自动记录梯度信息。然后,通过调用.backward()方法,可以对Variable进行反向传播,计算梯度,并将梯度传播到相关的变量。

import torch
from torch.autograd import Variable
 
# 创建一个Variable
x = Variable(torch.tensor([2.0]), requires_grad=True)
 
# 定义一个计算图
y = x ** 2 + 3 * x + 1
 
# 进行反向传播
y.backward()
 
# 获取梯度
gradient = x.grad
print("梯度:", gradient)  # 输出: tensor([7.])

1. 自动微分

  PyTorch 使用自动微分机制来计算梯度,当定义一个 Tensor 对象时,可以通过设置 requires_grad=True 来告诉 PyTorch 跟踪相关的计算,并使用 backward() 方法来计算梯度:

a. 单参数
import torch


x = torch.tensor([2.0], requires_grad=True)


def f(x):
    return 5 * x ** 2 + 2 * x - 1


# 计算函数的输出
y = f(x)

# 使用backward()方法计算梯度
y.backward()

# 获取梯度值
gradient = x.grad
print(gradient)

输出:

tensor([22.])

y ′ = 10 x + 2 22 = 10 ∗ 2 + 2 y'=10x+2\\22=10*2+2 y=10x+222=102+2

b. 多参数
import torch


def f(x, w, b):
    return 1 / (torch.exp(-(w * x + b)) + 1)


# 设置输入参数
x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 计算函数f(x;w,b)的值
result = f(x, w, b)

# 使用自动微分计算梯度
result.backward()
print(x.grad)  # 求关于x的梯度
print(w.grad)  # 求关于w的梯度
print(b.grad)  # 求关于b的梯度

在这里插入图片描述

2. 计算图

  计算图是一种用来表示数学运算过程的图形化结构,它将数学计算表达为节点和边的关系,提供了一种直观的方式来理解和推导复杂的数学运算过程。在深度学习中,计算图帮助我们理解模型的训练过程,直观地把握损失函数对模型参数的影响,同时为反向传播算法提供了理论基础。现代深度学习框架如 PyTorch 和 TensorFlow 都是基于计算图的理论基础构建出来的。
f ( x ; w , b ) = 1 e − ( w x + b ) + 1 f(x;w,b)=\frac{1}{e^{-(wx+b)}+1} f(x;w,b)=e(wx+b)+11 x = 1 , w = 0 , b = 0 x=1,w=0,b=0 x=1,w=0,b=0
在这里插入图片描述
在这里插入图片描述
  计算图包括两种节点:计算节点(Compute Node)和数据节点(Data Node)。

  • 数据节点:表示输入数据、参数或中间变量,在计算图中通常用圆形结点表示。数据节点始终是叶节点,它们没有任何输入,仅表示数据。

  • 计算节点:表示数学运算过程,它将输入的数据节点进行数学运算后输出结果。在计算图中通常用方形结点表示。计算节点可以有多个输入和一个输出。反向传播算法中的梯度计算正是通过计算节点来实现的。

一个完整的计算图可以分为正向传播和反向传播两个阶段:

  • 正向传播(Forward Propagation):输入数据经过计算节点逐层传播,最终得到输出结果

  • 反向传播(Backward Propagation):首先根据损失函数计算输出结果与真实标签之间的误差,然后利用链式法则,逐个计算每个计算节点对应的输入的梯度,最终得到参数的梯度信息

3. 神经网络模型

略~详见后文模块(Module) 部分

import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x, w, b):
        return 1 / (torch.exp(-(w * x + b)) + 1)


model = MyModel()

# 设置输入参数
x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# Forward pass
output = model(x, w, b)

# Backward pass
output.backward()

# Access the gradients
print(x.grad)  # Gradient with respect to x
print(w.grad)  # Gradient with respect to w
print(b.grad)  # Gradient with respect to b

4. 可视化计算图结构

  以下内容完全由GPT生成:

import torch
import onnx
from onnx.tools.net_drawer import GetPydotGraph
import pydot


class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x, w, b):
        return 1 / (torch.exp(-(w * x + b)) + 1)


model = MyModel()

x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 导出为ONNX格式
torch.onnx.export(model, (x, w, b), "f.onnx", verbose=True, input_names=["x", "w", "b"], output_names=["output"])

# 可视化计算图结构
model = onnx.load("f.onnx")
pydot_graph = GetPydotGraph(model.graph, name=model.graph.name, rankdir="TB"
                            # rankdir="LR"
                            )  # Set rankdir to "LR" for left to right orientation
pydot_graph.write_dot("f_computation_graph.dot")

# 将dot文件转换为PNG格式的图像
(graph,) = pydot.graph_from_dot_file('f_computation_graph.dot')
graph.write_png('f_computation_graph.png')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1472912.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

景联文科技:引领战场数据标注服务,赋能态势感知升级

自21世纪初,信息化战争使战场环境变得更为复杂和难以预测,持续涌入的海量、多样化、多来源和高维度数据,加大了指挥员的认知负担,使其需要具备更强的数据处理能力。 同时,计算机技术和人工智能技术的飞速发展&#xff…

EMQX Enterprise 5.5 发布:新增 Elasticsearch 数据集成

EMQX Enterprise 5.5.0 版本已正式发布! 在这个版本中,我们引入了一系列新的功能和改进,包括对 Elasticsearch 的集成、Apache IoTDB 和 OpenTSDB 数据集成优化、授权缓存支持排除主题等功能。此外,新版本还进行了多项改进以及 B…

Gemma谷歌(google)开源大模型微调实战(fintune gemma-2b)

Gemma-SFT Gemma-SFT(谷歌, Google), gemma-2b/gemma-7b微调(transformers)/LORA(peft)/推理 项目地址 https://github.com/yongzhuo/gemma-sft全部weights要用fp32/tf32, 使用fp16微调十几或几十的步数后大概率lossnan;(即便layer-norm是fp32也不行, LLaMA就没有这个问题, …

SpringBoot/Java中OCR实现,集成Tess4J实现图片文字识别

场景 Tesseract Tesseract是一个开源的光学字符识别(OCR)引擎,它可以将图像中的文字转换为计算机可读的文本。 支持多种语言和书面语言,并且可以在命令行中执行。它是一个流行的开源OCR工具,可以在许多不同的操作系…

【Vue3】插槽使用和animate使用

插槽使用 插槽slot匿名插槽具名插槽插槽作用域简写 动态插槽transition动画组件自定义过渡class类名如何使用animate动画库组件动画生命周期appear transition- group过渡列表 插槽slot 插槽就是子组件中提供给父组件使用的一个占位符父组件可以在这个占位符智能填充任何模板代…

pytest-配置项目不同环境URL

pytest自动化中,在不同环境进行测试,可以将项目中的url单独抽取出来,通过pytest.ini配置文件实现(类似postman中的“Environments”) 使用步骤: 1)安装pytest-base-url插件 pytest-base-url …

【Flink精讲】Flink状态及Checkpoint调优

RocksDB大状态调优 RocksDB 是基于 LSM Tree 实现的(类似 HBase) ,写数据都是先缓存到内存中, 所以 RocksDB 的写请求效率比较高。 RocksDB 使用内存结合磁盘的方式来存储数据,每 次获取数据时,先从内存中 …

Mac下载安装配置运行MySQL

一、打开官网 MySQL :: Download MySQL Community Serverhttps://dev.mysql.com/downloads/mysql/ 1、根据自己的电脑版本下载相对应的MySQL版本,Mac分为ARM和X86两个不同的架构 ​ 不知道自己电脑是ARM还是X86的,如下操作进行查询 uname -a 我的电脑是…

2024年Apache DolphinScheduler RoadMap:引领开源调度系统的未来

非常欢迎大家来到Apache DolphinScheduler社区!随着开源技术在全球范围内的快速发展,社区的贡献者 “同仁” 一直致力于构建一个强大而活跃的开源调度系统社区,为用户提供高效、可靠的任务调度和工作流管理解决方案。 在过去的一段时间里&…

代码随想录|学习工具分享

工具分享 画图 https://excalidraw.com/ 大家平时刷题可以用这个网站画草稿图帮助理解!如果看题解很蒙或者思路不清晰的时候,跟着程序处理流程画一个图,90%的情况下都可以解决问题! 数据结构可视化 https://www.cs.usfca.edu/…

【Java设计模式】一、工厂模式、建造者模式、原型设计模式

文章目录 1、简单工厂模式2、工厂方法模式3、抽象工厂模式4、建造者模式5、原型设计模式 设计模式即总结出来的一些最佳实现。23种设计模式可分为三大类: 创建型模式:隐藏了创建对象的过程,通过逻辑方法进行创建对象,而不是直接n…

逆序或者正序打印一个数的每一位数,递归实现(C语言)

从键盘上输入一个不多于5位(包括5位)的正整数,要求 (1)求出它是几位数;(2)分别输出每一位数字(3)按逆序输出各位数字 (1)求出它是几位…

Linux浅学笔记04

目录 Linux实用操作 Linux系统下载软件 yum命令 apt systemctl命令 ln命令 日期和时区 IP地址 主机名 网络传输-下载和网络请求 ping命令 wget命令 curl命令 网络传输-端口 进程 ps 命令 关闭进程命令: 主机状态监控命令 磁盘信息监控&#xff1a…

【MQ05】异常消息处理

异常消息处理 上节课我们已经学习到了消息的持久化和确认相关的内容。但是,光有这些还不行,如果我们的消费者出现问题了,无法确认,或者直接报错产生异常了,这些消息要怎么处理呢?直接丢弃?这就是…

深入理解计算机系统学习笔记

第三章 程序的机器级表示 3.2.1 机器级代码 对于机器级编程来说,其中两种抽象尤为重要。第一种是由捍令集体系结构或指令集架构(Instruction Set Architecture, ISA)来定义机器级程序的 格式和行为,它定义了处理器状态、指令的格式&#xf…

在Ubuntu上为ARM 8处理器安装Python 3.10.4虚拟环境指南

在Ubuntu上为ARM 8处理器安装Python 3.10.4虚拟环境指南 安装Anaconda或Miniconda: 首先,您需要从官方网站下载适用于ARM架构的Anaconda或Miniconda安装包。下载完成后,在终端中使用bash Anaconda3-2019.10-Linux-armv8.sh(文件…

将仓库A中的部分提交迁移到仓库B中

结论: 使用git format-patchgit am即可实现 使用场景: 例如仓库A这里有5个提交记录,commitid1, commitid2, commitid3, commitid4,commitid5 仓库B想用仓库A中提交的代码,手动改比较慢,当改动较多的时候…

2.26 Qt day4+5 纯净窗口移动+绘画事件+Qt实现TCP连接服务+Qt实现连接数据库

思维导图 Qt实现TCP连接 服务器端&#xff1a; widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include<QTcpServer>//服务器端类 #include<QTcpSocket>//客户端类 #include<QMessageBox>//消息对话框类 #include<QList>//链…

2024-02-26(Spark,kafka)

1.Spark SQL是Spark的一个模块&#xff0c;用于处理海量结构化数据 限定&#xff1a;结构化数据处理 RDD的数据开发中&#xff0c;结构化&#xff0c;非结构化&#xff0c;半结构化数据都能处理。 2.为什么要学习SparkSQL SparkSQL是非常成熟的海量结构化数据处理框架。 学…

实践航拍小目标检测,基于轻量级YOLOv8n开发构建无人机航拍场景下的小目标检测识别分析系统

关于无人机相关的场景在我们之前的博文也有一些比较早期的实践&#xff0c;感兴趣的话可以自行移步阅读即可&#xff1a; 《deepLabV3Plus实现无人机航拍目标分割识别系统》 《基于目标检测的无人机航拍场景下小目标检测实践》 《助力环保河道水质监测&#xff0c;基于yolov…