pytorch 3 计算图

news2024/9/21 14:27:53

计算图结构

**加粗样式**

分析:

  1. 起始节点 a
  2. b = 5 - 3a
  3. c = 2b + 3
  4. d = 5b + 6
  5. e = 7c + d^2
  6. f = 2e
  7. 最终输出 g = 3f - o(其中 o 是另一个输入)

前向传播

前向传播按照上述顺序计算每个节点的值。

反向传播过程

反向传播的目标是计算损失函数(这里假设为 g)对每个中间变量和输入的偏导数。从右向左进行计算:

  1. ∂g/∂o = -1
  2. ∂g/∂f = 3
  3. ∂f/∂e = 2
  4. ∂e/∂c = 7
  5. ∂e/∂d = 2d
  6. ∂d/∂b = 5
  7. ∂c/∂b = 2
  8. ∂b/∂a = -3

链式法则应用

使用链式法则计算出 g 对每个变量的全导数:

  1. dg/df = ∂g/∂f = 3
  2. dg/de = (∂g/∂f) * (∂f/∂e) = 3 * 2 = 6
  3. dg/dc = (dg/de) * (∂e/∂c) = 6 * 7 = 42
  4. dg/dd = (dg/de) * (∂e/∂d) = 6 * 2d
  5. dg/db = (dg/dc) * (∂c/∂b) + (dg/dd) * (∂d/∂b)
    = 42 * 2 + 6 * 2d * 5
    = 84 + 60d
  6. dg/da = (dg/db) * (∂b/∂a)
    = (84 + 60d) * (-3)
    = -252 - 180d

最终梯度

最终得到 g 对输入 a 和 o 的梯度:

  • dg/da = -252 - 180d
  • dg/do = -1

代码实现

静态图

import math

class Node:
    """
    表示计算图中的一个节点。
    每个节点都可以存储一个值、梯度,并且知道如何计算前向传播和反向传播。
    """
    def __init__(self, value=None):
        self.value = value  # 节点的值
        self.gradient = 0   # 节点的梯度
        self.parents = []   # 父节点列表
        self.forward_fn = lambda: None  # 前向传播函数
        self.backward_fn = lambda: None  # 反向传播函数

    def __add__(self, other):
        """加法操作"""
        return self._create_binary_operation(other, lambda x, y: x + y, lambda: (1, 1))

    def __mul__(self, other):
        """乘法操作"""
        return self._create_binary_operation(other, lambda x, y: x * y, lambda: (other.value, self.value))

    def __sub__(self, other):
        """减法操作"""
        return self._create_binary_operation(other, lambda x, y: x - y, lambda: (1, -1))

    def __pow__(self, power):
        """幂运算"""
        result = Node()
        result.parents = [self]
        def forward():
            result.value = math.pow(self.value, power)
        def backward():
            self.gradient += power * math.pow(self.value, power-1) * result.gradient
        result.forward_fn = forward
        result.backward_fn = backward
        return result

    def _create_binary_operation(self, other, forward_op, gradient_op):
        """
        创建二元操作的辅助方法。
        用于简化加法、乘法和减法的实现。
        """
        result = Node()
        result.parents = [self, other]
        def forward():
            result.value = forward_op(self.value, other.value)
        def backward():
            grads = gradient_op()
            self.gradient += grads[0] * result.gradient
            other.gradient += grads[1] * result.gradient
        result.forward_fn = forward
        result.backward_fn = backward
        return result

def topological_sort(node):
    """
    对计算图进行拓扑排序。
    确保在前向和反向传播中按正确的顺序处理节点。
    """
    visited = set()
    topo_order = []
    def dfs(n):
        if n not in visited:
            visited.add(n)
            for parent in n.parents:
                dfs(parent)
            topo_order.append(n)
    dfs(node)
    return topo_order

# 构建计算图
a = Node(2)  # 假设a的初始值为2
o = Node(1)  # 假设o的初始值为1

# 按照给定的数学表达式构建计算图
b = Node(5) - a * Node(3)
c = b * Node(2) + Node(3)
d = b * Node(5) + Node(6)
e = c * Node(7) + d ** 2
f = e * Node(2)
g = f * Node(3) - o

# 前向传播
sorted_nodes = topological_sort(g)
for node in sorted_nodes:
    node.forward_fn()

# 反向传播
g.gradient = 1  # 设置输出节点的梯度为1
for node in reversed(sorted_nodes):
    node.backward_fn()

# 打印结果
print(f"g = {g.value}")
print(f"dg/da = {a.gradient}")
print(f"dg/do = {o.gradient}")

# 验证手动计算的结果
d_value = 5 * b.value + 6
expected_dg_da = -252 - 180 * d_value
print(f"Expected dg/da = {expected_dg_da}")
print(f"Difference: {abs(a.gradient - expected_dg_da)}")

动态图

import math

class Node:
    """
    表示计算图中的一个节点。
    实现了动态计算图的核心功能,包括前向计算和反向传播。
    """
    def __init__(self, value, children=(), op=''):
        self.value = value  # 节点的值
        self.grad = 0       # 节点的梯度
        self._backward = lambda: None  # 反向传播函数,默认为空操作
        self._prev = set(children)  # 前驱节点集合
        self._op = op  # 操作符,用于调试

    def __add__(self, other):
        """加法操作"""
        other = other if isinstance(other, Node) else Node(other)
        result = Node(self.value + other.value, (self, other), '+')
        def _backward():
            self.grad += result.grad
            other.grad += result.grad
        result._backward = _backward
        return result

    def __mul__(self, other):
        """乘法操作"""
        other = other if isinstance(other, Node) else Node(other)
        result = Node(self.value * other.value, (self, other), '*')
        def _backward():
            self.grad += other.value * result.grad
            other.grad += self.value * result.grad
        result._backward = _backward
        return result

    def __pow__(self, other):
        """幂运算"""
        assert isinstance(other, (int, float)), "only supporting int/float powers for now"
        result = Node(self.value ** other, (self,), f'**{other}')
        def _backward():
            self.grad += (other * self.value**(other-1)) * result.grad
        result._backward = _backward
        return result

    def __neg__(self):
        """取反操作"""
        return self * -1

    def __sub__(self, other):
        """减法操作"""
        return self + (-other)

    def __truediv__(self, other):
        """除法操作"""
        return self * other**-1

    def __radd__(self, other):
        """反向加法"""
        return self + other

    def __rmul__(self, other):
        """反向乘法"""
        return self * other

    def __rtruediv__(self, other):
        """反向除法"""
        return other * self**-1

    def tanh(self):
        """双曲正切函数"""
        x = self.value
        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)
        result = Node(t, (self,), 'tanh')
        def _backward():
            self.grad += (1 - t**2) * result.grad
        result._backward = _backward
        return result

    def backward(self):
        """
        执行反向传播,计算梯度。
        使用拓扑排序确保正确的反向传播顺序。
        """
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        
        self.grad = 1  # 设置输出节点的梯度为1
        for node in reversed(topo):
            node._backward()  # 对每个节点执行反向传播

def main():
    """
    主函数,用于测试自动微分系统。
    构建一个计算图,执行反向传播,并验证结果。
    """
    # 构建计算图
    a = Node(2)
    o = Node(1)
    b = Node(5) - a * 3
    c = b * 2 + 3
    d = b * 5 + 6
    e = c * 7 + d ** 2
    f = e * 2
    g = f * 3 - o

    # 反向传播
    g.backward()

    # 打印结果
    print(f"g = {g.value}")
    print(f"dg/da = {a.grad}")
    print(f"dg/do = {o.grad}")

    # 验证手动计算的结果
    d_value = 5 * b.value + 6
    expected_dg_da = -252 - 180 * d_value
    print(f"Expected dg/da = {expected_dg_da}")
    print(f"Difference: {abs(a.grad - expected_dg_da)}")

if __name__ == "__main__":
    main()

解释:

  1. Node 类代表计算图中的一个节点,包含值、梯度、父节点以及前向和反向传播函数。
  2. 重载的数学运算符 (__add__, __mul__, __sub__, __pow__) 允许直观地构建计算图。
  3. _create_binary_operation 方法用于创建二元操作,简化了加法、乘法和减法的实现。
  4. topological_sort 函数对计算图进行拓扑排序,确保正确的计算顺序。
import math

class Node:
    """
    表示计算图中的一个节点。
    实现了动态计算图的核心功能,包括前向计算和反向传播。
    """
    def __init__(self, value, children=(), op=''):
        self.value = value  # 节点的值
        self.grad = 0       # 节点的梯度
        self._backward = lambda: None  # 反向传播函数,默认为空操作
        self._prev = set(children)  # 前驱节点集合
        self._op = op  # 操作符,用于调试

    def __add__(self, other):
        """加法操作"""
        other = other if isinstance(other, Node) else Node(other)
        result = Node(self.value + other.value, (self, other), '+')
        def _backward():
            self.grad += result.grad
            other.grad += result.grad
        result._backward = _backward
        return result

    def __mul__(self, other):
        """乘法操作"""
        other = other if isinstance(other, Node) else Node(other)
        result = Node(self.value * other.value, (self, other), '*')
        def _backward():
            self.grad += other.value * result.grad
            other.grad += self.value * result.grad
        result._backward = _backward
        return result

    def __pow__(self, other):
        """幂运算"""
        assert isinstance(other, (int, float)), "only supporting int/float powers for now"
        result = Node(self.value ** other, (self,), f'**{other}')
        def _backward():
            self.grad += (other * self.value**(other-1)) * result.grad
        result._backward = _backward
        return result

    def __neg__(self):
        """取反操作"""
        return self * -1

    def __sub__(self, other):
        """减法操作"""
        return self + (-other)

    def __truediv__(self, other):
        """除法操作"""
        return self * other**-1

    def __radd__(self, other):
        """反向加法"""
        return self + other

    def __rmul__(self, other):
        """反向乘法"""
        return self * other

    def __rtruediv__(self, other):
        """反向除法"""
        return other * self**-1

    def tanh(self):
        """双曲正切函数"""
        x = self.value
        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)
        result = Node(t, (self,), 'tanh')
        def _backward():
            self.grad += (1 - t**2) * result.grad
        result._backward = _backward
        return result

    def backward(self):
        """
        执行反向传播,计算梯度。
        使用拓扑排序确保正确的反向传播顺序。
        """
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        
        self.grad = 1  # 设置输出节点的梯度为1
        for node in reversed(topo):
            node._backward()  # 对每个节点执行反向传播

def main():
    """
    主函数,用于测试自动微分系统。
    构建一个计算图,执行反向传播,并验证结果。
    """
    # 构建计算图
    a = Node(2)
    o = Node(1)
    b = Node(5) - a * 3
    c = b * 2 + 3
    d = b * 5 + 6
    e = c * 7 + d ** 2
    f = e * 2
    g = f * 3 - o

    # 反向传播
    g.backward()

    # 打印结果
    print(f"g = {g.value}")
    print(f"dg/da = {a.grad}")
    print(f"dg/do = {o.grad}")

    # 验证手动计算的结果
    d_value = 5 * b.value + 6
    expected_dg_da = -252 - 180 * d_value
    print(f"Expected dg/da = {expected_dg_da}")
    print(f"Difference: {abs(a.grad - expected_dg_da)}")

if __name__ == "__main__":
    main()

解释:

  1. Node 类是核心,它代表计算图中的一个节点,并实现了各种数学运算。

  2. 每个数学运算(如 __add__, __mul__ 等)都创建一个新的 Node,并定义了相应的反向传播函数。

  3. backward 方法实现了反向传播算法,使用拓扑排序确保正确的计算顺序。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2039288.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

看demo学算法之 线性回归模型

嗨!今天我们来聊聊如何用Python构建一个简单的线性回归模型。这个过程就像给数据配对舞一样,让它们在舞池里找到最佳位置。准备好了吗?让我们开始吧!🚀 第一步:数据准备 首先,我们要准备一些数…

基因家族Motif 分析

一、名词解释 Motif分析是一种在生物信息学和计算生物学中广泛应用的技术,用于识别DNA、RNA或蛋白质序列中具有生物学功能的短保守序列模式(motif)。这些motif通常与特定的生物学功能相关,如DNA中的转录因子结合位点、RNA中的剪接…

vue项目名修改、webstorm和idea创建的项目重命名、重构项目、修改项目名称

一、需求 就是创建了一个项目,后期需要重命名,怎么办?----> 直接修改?肯定不行,因为里面有些配置也需要修改,假如你只改文件夹名称的话,里面配置都没修改,后期可能会出问题。 二…

专栏十七:如何选择你的单细胞亚群的分辨率--chooseR

好久没更,没想到还是有小伙伴订阅,那就更一个最近看到的问题 1.缘起 是因为在文章Single-cell RNA sequencing and spatial transcriptomics reveal cancer-associated fibroblasts in glioblastoma with protumoral effects(https://doi.org/10.1172/JCI147087.)中看到 也…

机械行业数字化生产供应链产品解决方案(三)

在机械行业数字化生产供应链产品解决方案中,通过融合物联网(IoT)技术、数据分析平台与智能自动化系统,实现生产设备和供应链的全方位数字化管理,能够实时监控生产过程、预测维护需求并优化生产调度,同时利用…

FPGA资源评估

FPGA资源评估 文章目录 FPGA资源评估前言一、资源评估1.1 资源有哪些1.2 资源统计 二、 FPGA 的基本结构三、 更为复杂的 FPGA 架构 前言 一、资源评估 大家在项目中一般会要遇到需要资源评估的情况,例如立了新项目,前期需要确定使用什么FPGA片子&…

06 集合

1.集合类的体系结构 接口:Colltion(单列) List(可重复) Set(不可重复) Map(双列) 实现类: ArrayList,LinkedList HashSet,TreeSet HashMap 2. Collection集合 Collection集合概述 1.是单列集合的顶层接口 2.JDK不提供该接口的任何直接实现,提供具体的子接口(Set和List)实…

Leetcode JAVA刷刷站(11)盛最多水的容器

一、题目概述 二、思路方向 这个问题是经典的“盛最多水的容器”问题,通常使用双指针法来解决。基本思路是,我们初始化两个指针,一个指向数组的起始位置,另一个指向数组的末尾位置。然后,我们计算当前两个指针所指向…

工业智能网关在汽车制造企业的应用价值及功能-天拓四方

随着工业互联网的飞速发展,工业智能网关作为连接物理世界与数字世界的桥梁,正逐渐成为制造业数字化转型的核心组件。本文将以一家汽车制造企业的实际使用案例为蓝本,深入解析工业智能网关在实际应用中的价值、功能及其实操性。 一、背景与挑…

Java语言程序设计——篇十三(1)

🌿🌿🌿跟随博主脚步,从这里开始→博主主页🌿🌿🌿 欢迎大家:这里是我的学习笔记、总结知识的地方,喜欢的话请三连,有问题可以私信🌳🌳&…

Leetcod编程基础0到1-基础实现内容(个人解法)(笔记)

以下为个人解法,欢迎提供不同思路 1768. 交替合并字符串 题目:给你两个字符串 word1 和 word2 。请你从 word1 开始,通过交替添加字母来合并字符串。如果一个字符串比另一个字符串长,就将多出来的字母追加到合并后字符串的末尾&…

凹凸纹理概念

1、凹凸纹理 纹理除了可以用来进行颜色映射外,另外一种常见的应用就是进行凹凸映射。凹凸映射的目的是使用一张纹理来修改模型表面的法线,让我们不需要增加顶点,而让模型看起来有凹凸效果。原理:光照的计算都会利用法线参与计算&…

winform 大头针实现方法——把窗口钉在最上层

平时我们再使用成熟的软件的时候,会发现有个大头针的功能挺不错的。就是点一下大头针,窗口就会钉住,一直保持在最上面一层,这样可以一边设置参数,一边观察这个窗口里面的变化,比较方便。下面我就来简单实现…

进阶SpringBoot之首页和图标定制

idea 快捷键: ctrl shift "" 使缩起来的代码展开 ctrl shitf "-" 使代码缩起 WebMvcAutoConfiguration.class:可以看到需要有一个 index.html 文件映射到首页 private Resource getIndexHtmlResource(Resource location) {tr…

关于SpringMVC的一点学习笔记

关于SpringMVC的一点学习笔记 1、 maven依赖/目录结构2、配置文件3、从前端请求开始4、Controller5、Service6、Dao7、mybatis8、utils公共类9、 分页查询 QueryPageBean / PageResult10、静态页面Freemarker用在经常访问但不经常变化的页面场景中11、Reids12、Echarts13、认证…

后台管理权限自定义按钮指令v-hasPermi

第一步:在src下面建立一个自定义指令文件,放自定义指令方法 permission.js文件: /*** v-hasPermi 操作权限处理*/import store from "/store";export default {inserted(el, binding) {const { value } binding;//从仓库里面获取到后台给的数组const permission s…

软件设计之MySQL(2)

软件设计之MySQL(2) 此篇应在JavaSE之后进行学习: 路线图推荐: 【Java学习路线-极速版】【Java架构师技术图谱】 Navicat可以在软件管家下载 使用navicat连接mysql数据库创建数据库、表、转储sql文件,导入sql数据 学习内容: 基础的SELECT语…

数据分析:宏基因组数据的荟萃分析

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 介绍 宏基因组数据的荟萃分析是一种综合多个独立宏基因组研究结果的方法,目的是揭示不同人群或样本中微生物群落的共同特征和差异。这种方法特别适用…

ubantu安装python3.10

1.从官网下载安装 1.1安装依赖 sudo apt update sudo apt install build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libsqlite3-dev wget libbz2-dev1.2从官网下载源文件 wget https://www.python.org/ftp/pyth…

设计资讯 | 巴黎 2024 年奥运会“另一个自我”以 DAB 汽车定制电动摩托车的形式亮相

巴黎 2024 年奥运会运动作为定制电动摩托车 DAB Motors 融入了2024 年巴黎奥运会的精神,通过其定制电动摩托车诠释了奥运会的五环。这些车辆由其服务部门 DAB Custom Studio (DCS) 提供,颜色编码与奥运五环一样。每种颜色代表一项运动:蓝色代…