【一起撸个DL框架】4 反向传播求梯度

news2025/1/18 3:32:56

CSDN个人主页:清风莫追
欢迎关注本专栏:《一起撸个DL框架》

文章目录

  • 4 反向传播求梯度🥥
    • 4.1 简介
    • 4.2 导数与梯度
    • 4.3 链式法则
    • 4.4 示例:y=2x+1的梯度

4 反向传播求梯度🥥

4.1 简介

上一篇:【一起撸个DL框架】3 前向传播

前面我们已经介绍了前向传播,而本节即将介绍的反向传播中的自动微分机制,可以说是深度学习框架的一个核心功能。因为计算图中的参数正是按照着梯度的指引来更新的。

4.2 导数与梯度

说到“梯度”与“导数”这两个概念,有些同学可能已经有些模糊了。在一元函数的情况下,两者几乎可以混为一谈,然而在多元函数的情况下梯度与导数的概念是有区别的。例如二元函数 f ( x , y ) f(x,y) f(x,y)

  • 它沿着二维平面内的每一个方向都会有一个方向导数,方向导数的结果是一个数值,代表沿着该方向的变化率
  • 而它只有一个梯度,梯度的结果是一个向量 ▽ f ( x , y ) = ( ∂ f ∂ x , ∂ f ∂ y ) \triangledown f(x,y)=(\frac{\partial f}{\partial x},\frac{\partial f}{\partial y}) f(x,y)=(xf,yf) ,它同时指示了两个信息:变化率最大的方向和相应的变化率。

4.3 链式法则

简单描述就是“嵌套函数的导,等于各层求导的乘积”。下面是数学的表达: z = f ( y ) , y = g ( x ) z=f(y),y=g(x) z=f(y),y=g(x) ,则 ∂ z ∂ x = ∂ z ∂ y ∂ y ∂ x \frac{\partial z}{\partial x}=\frac{\partial z}{\partial y} \frac{\partial y}{\partial x} xz=yzxy 。计算图就相当于多层的嵌套函数,线性函数可以表示成加法和乘法节点的嵌套。

problem: 梯度不是相对于函数而言的吗?例如 ▽ f ( w , x ) \triangledown f(w,x) f(w,x)。为什么会有“w的梯度”这种概念呢?

在计算图中求一个节点的梯度,只需要将结果节点对子节点的梯度子节点对自己的梯度乘起来就可以了。

在这里插入图片描述
图1:链式法则

4.4 示例:y=2x+1的梯度

为了实现反向传播,我们需要在节点类中加入几个方法。

class Node:
    def __init__(self, parent1=None, parent2=None) -> None:
        self.parent1 = parent1
        self.parent2 = parent2
        self.value = None

        self.grad = None  # 在其它结点求梯度时可能再次用到本结点的梯度
        self.children = []

        parents = [self.parent1, self.parent2]
        for parent in parents:
            if parent is not None:
                parent.children.append(self)

    def set_value(self, value):
        self.value = value
    def compute(self):
        '''抽象方法,在具体的类中重写'''
        pass
    def forward(self):
        for parent in [self.parent1, self.parent2]:
            if parent.value is None:
                parent.forward()
        self.compute()
        return self.value
    def get_parent_grad(self, parent):
        '''求本节点对于父节点的梯度,抽象方法'''
        pass
    def get_grad(self):
        '''求结果节点对本节点的梯度'''
        # 结果结点返回单位值,而不是self.value
        if not self.children:
            return 1
        if self.grad is not None:
            return self.grad
        else:
            self.grad = 0
            for i in range(len(self.children)):
                grad1 = self.children[i].get_parent_grad(parent=self)  # 子节点对自己的梯度
                grad2 = self.children[i].get_grad()  # 结果节点对子节点的梯度
                self.grad += grad1 * grad2
            return self.grad


class Varrible(Node):
    def __init__(self) -> None:
        super().__init__()

class Add(Node):
    def __init__(self, parent1=None, parent2=None) -> None:
        super().__init__(parent1, parent2)
    def compute(self):
        self.value = self.parent1.value + self.parent2.value
    def get_parent_grad(self, parent):
        return 1

class Mul(Node):
    def __init__(self, parent1=None, parent2=None) -> None:
        super().__init__(parent1, parent2)
    def compute(self):
        self.value = self.parent1.value * self.parent2.value
    def get_parent_grad(self, parent):
        '''从parent1,2改成parents时,需要重写'''
        if parent == self.parent1:
            return self.parent2.value
        elif parent == self.parent2:
            return self.parent1.value
        else:
            raise "get which is not a parent of mul node"
        

在之前的基础上,我们在Node类中新增了两个属性self.gradself.children,它们都将在get_grad()方法中发挥作用。

然后我们还在Node类中新增了两个方get_parent_grad()get_grad(),分别用来求子节点对于子节点的父节点(即本节点)的梯度,和结果节点对于本节点的梯度。get_parent_grad()是一个抽象方法,还需在Add类和Mul类中进行具体的实现,实现的方式比较简单,大家阅读源代码即可。

关于get_grad()方法,在本节点没有子节点时,就判断本节点为结果节点,梯度设置为单位值1。在本节点梯度已经存在时,就不再进行递归求值,而直接返回已经保存的梯度值。为什么要保存梯度值?一个节点可以有多个需要求梯度的父节点,这种情况下就会多次用到同一个节点的梯度,每次都递归到结果节点来求值显然是不必要的,于是我们使用空间来换时间。

在下面的示例代码中,我们求 w w w x x x的梯度,就两次用到了mul节点的梯度。

if __name__ == '__main__':
    # 搭建计算图: y=2x+1
    x1 = od.core.Varrible()
    w1 = od.core.Varrible()
    mul = od.ops.Mul(x1, w1)
    b = od.core.Varrible()
    add = od.ops.Add(mul, b)
    # 给参数赋值
    w1.set_value(2)
    b.set_value(1)
    # 使用计算图计算
    x1.set_value(int(input("请输入x:")))
    y = add.forward()
    print(f"y: {y}")
    # 反向传播求梯度
    w_grad = w1.get_grad()
    x_grad = x1.get_grad()
    print(f"w_grad: {w_grad}, x_grad: {x_grad}")

    '''
    请输入x:3
    y: 7
    w_grad: 3, x_grad: 2
    '''

然后我们再次搭建了函数 y = 2 x + 1 y=2x+1 y=2x+1的计算图,首先进行了前向传播,目的是检查计算图是否正确实现了函数y=2x+1的功能。然后进行反向传播求得了 w w w x x x的梯度。

这一节代码的结构已经逐渐变得复杂,一些细节的设计可能需要反复揣摩才能明白,同时代码也还存在一些有待改进的地方,例如Node类的__init__()方法中,每个节点只会有两个父节点的设定其实是不太合适的。


下节预告:动手实现自适应线性单元

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/473952.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python标准数据类型-String(字符串)

✅作者简介:CSDN内容合伙人、阿里云专家博主、51CTO专家博主、新星计划第三季python赛道Top1 📃个人主页:hacker707的csdn博客 🔥系列专栏:零基础入门篇 💬个人格言:不断的翻越一座又一座的高山…

MATLAB符号运算(七)

目录 1、实验目的: 2、实验内容: 1、实验目的: 1)掌握定义符号对象和创建符号表达式的方法; 2)掌握符号运算基本命令和规则; 3)掌握符号表达式的运算法则以及符号矩阵运算&#xf…

大型Saas系统的权限体系设计(二)

X0 上期回顾 上文《大型Saas系统的权限体系设计(一)》提到2B的Saas系统的多层次权限体系设计的难题,即平台、平台的客户、客户的客户,乃至客户的客户的客户如何授权,这个可以通过“权限-角色-岗位”三级结构来实现。 但这个只是功能权限&am…

mac免费杀毒软件哪个好用?如何清理mac系统需要垃圾

CleanMyMac x是一款功能强大的Mac系统优化清理工具,使用旨在帮助用户更加方便的清理您系统中的所有垃圾,从而加快电脑运行速度,保持最佳性能,更加稳定、流畅、快速!!! CleanMyMac X无疑是目前m…

C++内存管理基础

文章目录 前言1. C/C内存分布2. C语言中动态内存管理方式3. C中动态内存管理3.1 new/delete操作内置类型3.2 new和delete操作自定义类型 4. operator new与operator delete函数4.1 operator new与operator delete函数(重点) 5. new和delete的实现原理5.1…

hana odata batch

sap 博客有写 odata batch 处理前,先看一张图 In this blog post,we are going to see how to send a Odata Batch Request to the SAP Cloud for Customer system using POSTMAN Tool. Answers to expect from this post? How to use batch request in the POS…

『python爬虫』04. 爬虫需要知道的HTTP协议知识(保姆级图文)

目录 1. HTTP协议是什么?2. HTTP协议结构3. 爬⾍需要的请求头和响应头内容总结 欢迎关注 『python爬虫』 专栏,持续更新中 欢迎关注 『python爬虫』 专栏,持续更新中 1. HTTP协议是什么? HTTP协议, Hyper Text Transfer Protocol…

2023独立站能不能做FP?看完这篇你就懂了

现在已经快2023年中了,2023年已经过去了1/3,但还是有人在问特货产品能不能做独立站,还是有不少人在观望。心动不如行动啊朋友们!要是想在跨境独立站做出一番事业来,建议现在立马行动起来,趁早在FP独立站领域…

工厂能耗管理系统linux嵌入式边缘网关

随着工业智能化进程的不断推进,能源能耗管理已成为企业经营中一个重要的环节。而在能源能耗管理场景下,边缘计算机发挥了越来越重要的角色。本文将介绍边缘计算机的功能特点、能源能耗使用对接的设备以及应用前景市场容量,并探讨ARM边缘计算机…

Java使用 Scanner连续输入int, String 异常错误输出原因分析

目录 一、Scanner常用语法 1、sc.nextInt()介绍 2、sc.next()介绍 3、sc.nextLine()介绍 4、sc.hasNext()介绍 二、报错案例 1、使用next()来接收带有空格的字符串会输出异常 2、先输入数字再输入字符串的输出异常 一、Scanner常用语法 Scanner sc new Scanner(System.…

STM32物联网实战开发(2)——回调函数

在第一篇博客中提到了全新的程序框架,我们会大量的使用回调函数,其中包括枚举类型、结构体、函数指针的应用。 回调函数:就是一个通过函数指针调用的函数。如果你把函数的地址传递给中间函数的形参,中间函数通过函数指针调用其所…

【VM服务管家】VM4.0软件使用_1.3全局模块类

目录 1.3.1 通讯管理:通讯管理的心跳管理功能的使用方法1.3.2 全局触发:使用全局触发功能执行流程的方法1.3.3 全局变量:全局变量关联流程中具体模块结果的方法1.3.4 全局脚本:方案加载完成信号发给通信设备的方法1.3.5 全局脚本&…

我做了个GPT3键盘,用了两个月发现它有点傻

自 ChatGPT 出世,各类文本类AI产品层出不穷。甚至接连几日,Producthunt 上新品过半都是AI相关。 这其中部分原因是 OpenAI 公司开放的 GPT3 1API 接口十分易用。只要一个简单的文本请求,就能将现有产品加入AI功能。例如,Notion、…

Docker在Windows系统中的安装方法和使用方法

Docker在Windows系统中的安装方法和使用方法 Docker是一种容器化技术,可以让开发者将应用程序和其依赖项打包成一个可移植的容器,从而实现快速部署和运行。在Windows系统中,Docker可以通过以下步骤进行安装和使用。 优点: Dock…

【VM服务管家】VM4.x算子SDK开发_3.3 模块工具类

目录 3.3.1 位置修正:位置修正算子工具的使用方法3.3.2 模板保存:实现模板自动加载的方法3.3.3 模板匹配: 获取模板匹配框和轮廓点的方法3.3.4 模板训练:模板训练执行完成的判断方法3.3.5 图像相减:算子SDK开发图像相减…

浅谈软件质量与度量

本文从研发角度探讨下高质量软件应具备哪些特点,以及如何度量软件质量。 软件质量的分类 软件质量通常可以分为:内部质量和外部质量。 内部质量 内部质量是指软件的结构和代码质量,以及其是否适合维护、扩展和重构。它关注的是软件本身的…

数据结构 | 常见的数据结构是怎样的?

本文简单总结数据结构的概念及常见的数据结构种类 1’ 2。 更新:2023 / 04 / 05 数据结构 | 常见的数据结构是怎样的? 总览概念分类 常用的数据结构数组链表跳表栈队列树二叉树完全二叉树、满二叉树 平衡二叉树单旋转左旋右旋 红黑树红黑树 V.S 平衡二叉…

2 天:我用文字 AI-ChatGPT 写了绘画 AI-Stable Diffusion 跨平台绘画应用

文本 AI - ChatGPT 和绘画 AI - Stable Diffusion,平地惊雷,突然进入寻常百姓家。 如果时间可以快进,未来的人们对于我们这段时光的历史评价,大概会说: 当时的人们在短时间连续经历了这几种情感。从不信,…

java多线程BlockingDeque的三种线程安全正确退出方法

本文介绍两种BlockingDeque在多线程任务处理时正确结束的方法 一般最开始简单的多线程处理任务过程 把总任务放入BlockingDeque创建多个线程,每个线程内逻辑时,判断BlockingDeque任务是否处理完,处理完退出,还有任务就BlockingDe…

对顶堆模板!!【DS对顶堆】ABC281 E - Least Elements

我想的思路和正解是差不多的 就是滑动窗口,每过去一个用DS维护一下前k个元素和sum 本来想的是用优先队列维护前k个 然后想着multiset维护前k个,但是具体不知道怎么操作 这里用的是multiset维护对顶堆 关于对顶堆,我在寒假的时候总结过 …