计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战

news2025/2/5 12:21:40

前一篇文章,Tensor 基本操作5 device 管理,使用 GPU 设备 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

PyTorch 计算图和 Autograd

  • 微积分之于机器学习
  • Computational Graphs 计算图
  • Autograd 自动求导
  • 一个训练过程及 no_grad 的使用
    • 示例代码
    • 执行结果
      • 生成数据
      • 第一轮后
      • 第二轮后
      • 第十轮后
  • 更多计算图的知识
    • 更为复杂点的计算图的样子
    • 自动求导有关的参数
  • Links

微积分之于机器学习

机器学习的主要工作原理,就是万事万物存在规律,而我们使用机器来完成参数评估。参数评估的过程是随机梯度下降,也就是任意选择起点,然后使用微积分技术指导我们调优,找到一组最优参数值。

这就像我们爬山,面对众多的山峰,我们从不同的出发点出发,不断的朝着山顶前进,最终,我们即便起点不同,都可以达到山顶 - 通向山顶的路有多条。另外一方面,我们可能来到了不同的山顶。

在我们爬山的过程中,如何选择下一步呢?这时,就是微积分大显身手的时候了。

在机器学习中,对参数优化的过程,使用了大量微积分的运算,PyTorch 能成为通用性的机器学习框架,就在于不同的机器学习任务底层的数学原理是一致的,而 PyTorch 内置了这些标准化的数学运算,在 PyTorch 中,除了 Tensor 外,还有两个关键的概念:

  • 计算图
  • 自动求导

Computational Graphs 计算图

神经网络是由很多神经元组成的网络,最简单的神经网络就是只包含一个线性神经元的神经网络,理解这个最简单的神经网络,有助于理解任何复杂的神经网络。

z = x ∗ w + b z = x * w + b z=xw+b

注意:这里没有添加激活函数,这个神经元是一个简单的线性神经元。
在这里插入图片描述
计算过程:

  1. 加权输出 z 与理想输出 y 之间,使用交叉熵(CE)计算出损失(loss)
  2. 然后基于 loss 计算梯度 grad
  3. 基于梯度更新 w 和 b

这个计算过程,可以用一张图表达,一个图就是由节点以及边组成,边上定义操作符。同时,这个计算过程会在训练中发生多次,因为梯度下降算法是 SGD 迭代运算。

PyTorch 为了让每次运算可以更灵活,比如使用 Dropout 随机丢弃一些神经元,PyTorch 实现了每次运算动态的生成这张图 - 动态计算图1。也就是说,对于每次运算,PyTorch 会生成一个计算图并附着计算状态。

Autograd 自动求导

附着状态,最主要的目的就是实现自动求导。因为每个节点都是一个变量,变量和变量之间通过操作符相互依赖,而操作符和变量构成的函数式,就可以实现求导,根据链式法则,实现计算图中,每个变量的导数的计算。

在上图,只有一个线性神经元的情况下,PyTorch 的自动求导是如何工作的呢?参考下面的代码。

import torch

# 定义输入和理想输出
x = torch.ones(5)   # input tensor
y = torch.zeros(3)  # expected output

# 定义参数
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

# 定义模型,并进行一次运算
z = torch.matmul(x, w)+b

# 定义损失函数,并得到单次的损失
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

# 进行反向传播,并得到梯度
loss.backward()
print(w.grad)
print(b.grad)

如此一来,参数更新将变得非常简单。计算图允许每次迭代传入不同的操作符等,实现训练过程更灵活的配置。计算图保留了运算过程中的 Tensor、操作符、操作符对应的导函数。当 loss.backward() 调用时,顺序的调用自动求导变量的导函数,得到 .grad 梯度值。

一个训练过程及 no_grad 的使用

现在我们看一个例子,通过一个简单的模型,了解训练中,自动求导机制是如何工作的。

示例代码


'''
autograd
'''
import plotly.graph_objects as go
import plotly.express as px
from torch import nn
import numpy as np
import torch
import math

# 输入变量 x,理想输出 yt(生成 y 的函数就是要拟合的模型) 
X  = torch.tensor(np.linspace(-10, 10, 1000))
y  = 1.5 * torch.sin(X) + 1.2 * torch.cos(X/4) # 真实的模型
yt = y + np.random.normal(0, 1, 1000)

# vis
def plotter(X, y, yhat=None, title=None):
    with torch.no_grad():
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=X, y=y, mode='lines',    name='y'))
        fig.add_trace(go.Scatter(x=X, y=yt, mode='markers', marker=dict(size=4), name='yt'))
        if yhat is not None: fig.add_trace(go.Scatter(x=X, y=yhat, mode='lines', name='yhat'))
        fig.update_layout(template='none', title=title)
        fig.show()

plotter(X, y, title='Data Generating Process')

# 计算模型的实际输出,这里前提是假设知道变量 X 和函数 sin|cos, 而不知道参数 theta
def fit_model(theta:torch.tensor=torch.rand(3, requires_grad=True)):
    return theta[0] * X + theta[1] * torch.sin(X) + theta[2] * torch.cos(X/4)

# 随机初始化参数,开启自动求导
theta = torch.randn(3, requires_grad=True)

# 损失函数和优化器
loss_fn  = nn.MSELoss()                         # MSE loss
optimizer = torch.optim.SGD([theta], lr=0.01)   # build optimizer 

# 迭代训练
epochs = 500
for i in range(epochs):
    yhat = fit_model(theta)  # 计算实际输出
    loss = loss_fn(y, yhat)  # 将实际输出和理想输出传入损失函数,得到损失 loss
    loss.backward()          # 反向传播,完成 .grad 梯度的计算
    optimizer.step()         # 基于梯度完成参数更新 
    optimizer.zero_grad()    # 本轮计算完成,将梯度值归零,否则下次计算损失并调用 backward 导致梯度累计 
    if i % (epochs/10) == 0: # 验证及输出调试信息 
        msg = f"loss: {loss.item():>7f} theta: {theta.detach().numpy()}"
        yhat = fit_model(theta)
        plotter(X, y, yhat.detach(), title=f"loss: {loss.item():>7f} theta: {theta.detach().numpy().round(3)}")

执行结果

生成数据

创建了一个假数据:

  • 分布在象限中的点就是 x,y
  • 象限中的曲线,就是符合设想的模型,我们看最终的机器学习的模型,能否拟合这条曲线
    在这里插入图片描述

第一轮后

初始化后,实际模型和理想模型差距很大。注意,此时 theta 和目标参数差距很大。
在这里插入图片描述

第二轮后

经过两次迭代,差距在缩小。

在这里插入图片描述

第十轮后

又经过了几轮训练,此时,我们发现图中已经分辨不出来,但是从 theta 的值,我们还可以看到一点差距,这已经证明,机器学习拟合上了目标空间。
在这里插入图片描述

更多计算图的知识

更为复杂点的计算图的样子

在训练中,生成的 DAG 类似如下。
在这里插入图片描述

自动求导有关的参数

# 做一个计算图
x = torch.rand(1)
b = torch.rand(1, requires_grad=True)
w = torch.rand(1, requires_grad=True)
y = w * x  # y 是一个新的 tensor

# 检查 y 是否是叶子节点,这里 y 是输出,也就是 root 节点而不是 leaf 节点
print(y.is_leaf)

# 反向传播 
y.backward(retain_graph=True)  # retain_graph=True,保留计算图中的状态,https://discuss.pytorch.org/t/use-of-retain-graph-true/179658
print(w.grad) # 查看梯度

Links

  • How Computational Graphs are Constructed in PyTorch
  • How Computational Graphs are Executed in PyTorch
  • PyTorch’s Dynamic Graphs (Autograd)
  • Automatic Differentiation with torch.autograd
  • Autograd mechanics

  1. PyTorch 使用 DAG 有向无环图这种格式存储计算图,其中输入的 Tensor 称为叶子节点(leaves),输出的 Tensor 称为根节点(roots)。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2292309.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

接入DeepSeek大模型

接入DeepSeek 下载并安装Ollamachatbox 软件配置大模型 下载并安装Ollama 下载并安装Ollama, 使用参数ollama -v查看是否安装成功。 输入命令ollama list, 可以看到已经存在4个目录了。 输入命令ollama pull deepseek-r1:1.5b, 下载deepse…

【论文复现】粘菌算法在最优经济排放调度中的发展与应用

目录 1.摘要2.黏菌算法SMA原理3.改进策略4.结果展示5.参考文献6.代码获取 1.摘要 本文提出了一种改进粘菌算法(ISMA),并将其应用于考虑阀点效应的单目标和双目标经济与排放调度(EED)问题。为提升传统粘菌算法&#xf…

UE Bridge混合材质工具

打开虚幻内置Bridge 随便点个材质点右下角图标 就能打开材质混合工具 可以用来做顶点绘制

基于 yolov8_pyqt5 自适应界面设计的火灾检测系统 demo:毕业设计参考

基于 yolov8_pyqt5 自适应界面设计的火灾检测系统 demo:毕业设计参考 【毕业设计参考】基于yolov8-pyqt5自适应界面设计的火灾检测系统demo.zip资源-CSDN文库 【毕业设计参考】基于yolov8-pyqt5自适应界面设计的火灾检测系统demo.zip资源-CSDN文库 一、项目背景 …

Linux 传输层协议 UDP 和 TCP

UDP 协议 UDP 协议端格式 16 位 UDP 长度, 表示整个数据报(UDP 首部UDP 数据)的最大长度如果校验和出错, 就会直接丢弃 UDP 的特点 UDP 传输的过程类似于寄信 . 无连接: 知道对端的 IP 和端口号就直接进行传输, 不需要建立连接不可靠: 没有确认机制, 没有重传机制; 如果因…

chrome浏览器chromedriver下载

chromedriver 下载地址 https://googlechromelabs.github.io/chrome-for-testing/ 上面的链接有和当前发布的chrome浏览器版本相近的chromedriver 实际使用感受 chrome浏览器会自动更新,可以去下载最新的chromedriver使用,自动化中使用新的chromedr…

第一个Qt开发实例(一个Push Button按钮和两个Label)【包括如何在QtCreator中创建新工程、代码详解、编译、环境变量配置、测试程序运行等】

目录 Qt开发环境QtCreator的安装、配置在QtCreator中创建新工程在Forms→mainwindow.ui中拖曳出我们要的图形按钮查看拖曳出按钮后的代码为pushButton这个图形添加回调函数编译工程关闭开发板上QT的GUI(选做)禁止LCD黑屏(选做)设置Qt运行的环境变量运行Qt程序如何让程序在系统启…

【react+redux】 react使用redux相关内容

首先说一下,文章中所提及的内容都是我自己的个人理解,是我理逻辑的时候,自我说服的方式,如果有问题有补充欢迎在评论区指出。 一、场景描述 为什么在react里面要使用redux,我的理解是因为想要使组件之间的通信更便捷…

【435. 无重叠区间 中等】

题目: 给定一个区间的集合 intervals ,其中 intervals[i] [starti, endi] 。返回 需要移除区间的最小数量,使剩余区间互不重叠 。 注意 只在一点上接触的区间是 不重叠的。例如 [1, 2] 和 [2, 3] 是不重叠的。 示例 1: 输入: intervals …

文献学习笔记:中风醒脑液(FYTF-919)临床试验解读:有效还是无效?

【中风醒脑液(FYTF-919)临床试验解读:有效还是无效?】 在发表于 The Lancet (2024 年 11 月 30 日,第 404 卷)的临床研究《Traditional Chinese medicine FYTF-919 (Zhongfeng Xingnao oral pr…

vue2语法速通

首先,git clone下来的项目要npm install下载依赖,如果是vue项目,运行通常npm run serve或者npm run dev vue速通一下 使用vite创建项目(较快) npm create vite 配置文件 src/ ├── assets/ # 存放…

【商品库存管理——差分、前缀和】

题目 代码 #include <bits/stdc.h> using namespace std; const int N 3e510; int l[N], r[N], b[N]; int s1[N], s0[N]; int main() {int n, m;cin >> n >> m;for(int i 1; i < m; i){cin >> l[i] >> r[i];b[l[i]], b[r[i]1]--;}int a 0…

Linux基本指令2

07.man指令&#xff08;重要&#xff09;&#xff1a; Linux的命令有很多参数&#xff0c;我们不可能全记住&#xff0c;我们可以通过查看联机手册获取帮助。访问Linux手册页的命令是 man 语法: man [选项] 命令 man ls查看ls指令更多的说明。 man man&#xff1a; man指令就…

Android学习19 -- 手搓App

1 前言 之前工作中&#xff0c;很多时候要搞一个简单的app去验证底层功能&#xff0c;Android studio又过于重型&#xff0c;之前用gradle&#xff0c;被版本匹配和下载外网包折腾的堪称噩梦。所以搞app都只有找应用的同事帮忙。一直想知道一些简单的app怎么能手搓一下&#x…

人工智能导论-第3章-知识点与学习笔记

参考教材3.2节的内容&#xff0c;介绍什么是自然演绎推理&#xff1b;解释“肯定后件”与“否定前件”两类错误的演绎推理是什么意义&#xff0c;给出具体例子加以阐述。参考教材3.3节的内容&#xff0c;介绍什么是文字&#xff08;literal&#xff09;&#xff1b;介绍什么是子…

DeepSeek 的含金量还在上升

大家好啊&#xff0c;我是董董灿。 最近 DeepSeek 越来越火了。 网上有很多针对 DeepSeek 的推理测评&#xff0c;除此之外&#xff0c;也有很多人从技术的角度来探讨 DeepSeek 带给行业的影响。 比如今天就看到了一篇文章&#xff0c;探讨 DeepSeek 在使用 GPU 进行模型训练…

【Linux系统】信号:信号保存 / 信号处理、内核态 / 用户态、操作系统运行原理(中断)

理解Linux系统内进程信号的整个流程可分为&#xff1a; 信号产生 信号保存 信号处理 上篇文章重点讲解了 信号的产生&#xff0c;本文会讲解信号的保存和信号处理相关的概念和操作&#xff1a; 两种信号默认处理 1、信号处理之忽略 ::signal(2, SIG_IGN); // ignore: 忽略#…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.6 广播机制核心算法:维度扩展的数学建模

2.6 广播机制核心算法&#xff1a;维度扩展的数学建模 目录/提纲 #mermaid-svg-IfELXmhcsdH1tW69 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-IfELXmhcsdH1tW69 .error-icon{fill:#552222;}#mermaid-svg-IfELXm…

硬件产品经理:需求引力模型(DGM)

目录 1、DGM 模型简介 2、理论核心&#xff1a;打破传统线性逻辑 3、三大定律 第一定律&#xff1a;暗物质需求法则 第二定律&#xff1a;引力井效应 第三定律&#xff1a;熵减增长律 4、落地工具包 工具1&#xff1a;需求密度热力图 工具3&#xff1a;摩擦力歼灭清单…

Guided Decoding (借助FSM,有限状态自动机)

VLLM对结构化输出的支持&#xff1a; vllm/docs/source/features/structured_outputs.md at main vllm-project/vllm GitHub VLLM对tool call的支持&#xff1a; vllm/docs/source/features/tool_calling.md at main vllm-project/vllm GitHub 以上指定输出格式&#xf…