【深度学习】7-0 自制框架实现DeZero - 自动微分

news2024/11/15 21:05:31

介绍下处理深度学习的框架DeZero,通过这个框架来了解自动微分是如何实现的
自动微分指的是自动求出导数的做法(技术)。“自动求出导数”是指由计算机(而非人)求出导数。具体来说,它是指在对某个计算(函数)编码后计算机会自动求出该计算的导数的系统。

自动微分。这是一种采用链式法则求导的方法。我们对某个函数编码后,可以通过自动微分高效地求出高精度的导数。反向传播也是自动微分的一种。反向传播相当于反向模式的自动微分。

自动微分是用计算机求导的一种方法。深度学习框架中实的是反向模式的自动微分。

实现Variable类

在DeZero中,变量都是通过Variable类来实现的,也就是让Variable类具有箱子的作用,看下面Variable的实现

class Variable:
    def __init__(self, data):
        self.data = data

实现Function类

Function类的实现如下:

class Function:
	# __call__ 重写调用方法
	def __call__(self, input):
		x = input.data # 取出数据
		y = self.forward(x) # 具体计算在forward中进行
		output = Variable(y) # 作为Variable返回
		return output

	def forward(self, x):
		# 暂时不实现
		raise NotImplementedError()

在DeZero框架中,将Function类作为基类,实现所有函数通用的功能;具体函数时在继承了Function类的类中实现

在具体的使用中,继承自Function类并对输入值进行平方的类。这个类的名字是Square,代码如下所示:

class Square(Function):
	def forward(self, x):
		return x ** 2

具体实现如下:

x = Variable(np.array(10))
f = Square()
y = f(x)
print(type(y)  # <class '__main__.Variable'>
print(y.data) # 100

手动进行反向传播

下面实现支持反向传播的Variable类。为此,要扩展Variable类,除普通值(data)之外,增加与之对应的导数值(grad)。

class Variable:
	def __init__(self, data):
		self.data = data
		self.grad = None # 要在通过反向传播实际计算导数时将其设置为求出的值。

然后要扩展Function类
在当前的Function类上还要新增下面两个功能

  1. 计算导数的反向传播(backward)功能
  2. 调用forward方法时,保有被输入的Variable实例的功能
class Function:
	# __call__ 重写调用方法
	def __call__(self, input):
		x = input.data 
		y = self.forward(x) 
		output = Variable(y) 
		self.input = input # 保存输入的变量
		return output

	def forward(self, x):
		# 暂时不实现
		raise NotImplementedError()

	def backward(self, gy):
		# 暂时不实现
		raise NotImplementedError()

看下面实际的例子
要实现具体函数的反向传播,首先看之前实现的Square类

class Square(Function):
	def forward(self, x):
		y = x ** 2
		return y

	def backward(self, gy):
		x = self.input.data
		gx = 2 * x * gy
		return gx

接下来看Exp类, y = ex ,这个类可以按下面的方式实现

class Exp(Function):
	def forward(self, x):
		y = np.exp(x)
		return y

	def backward(self, gy):
		x = self.input.data
		gx = np.exp(x) * gy
		return gx

反向传播的例子
首先看正向传播的代码

A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

再通过反向传播计算y的导数

y.grad = np.array(1.0)
b.grad = C.backward(y.grad)
a.grad = B.backward(b.grad)
x.grad = A.backward(a.grad)
print(x.grad)

反向传播的自动化

下面就要让反向传播自动化,也就是要建立这样的机制:无论普通的计算流程(正向传播)中是什么样的计算,反向传播都可以自动进行。
之前做的流水线式的计算,只要以列表的形式记录函数的顺序,就可以通过反向回溯自动进行反向传播。不过,对于有分支的计算图或多次使用同一个变量的复杂计算图,只借助简单的列表就不能奏效了。接下来的目标是建立一个不管计算图多么复杂,都能自动进行反向传播的机制。

其实只要在列表的数据结构上想想办法,将所做的计算添加到列表中,或许可以对任意的计算图准确地进行反向传播。

要实现自动化就要在函数和变量之间建立联系,要让这个“连接”在执行普通计算(正向传播)的那一刻创建,因此要在Variable类中添加以下代码:

class Variable:
	def __init__(self,data):
		self.data = data
		self.grad = None
		self.creator = None

	def set_creator(self, func):
		self.creator = func

在Function中添加代码

class Function:
	def __call__(self, input):
		x = input.data
		y = self.forward(x)
		output = Variable(y)
		output.set_creator(self) # 让输出变量保存创造者信息
		self.input = input
		self.output = output # 也保存输出变量
		return output

变量和函数连接的这个特征就是Define-by-Run。换言之是通过数据的流转建立起来的。这种带有“连接”的数据结构叫作连接节点

下面利用变量和函数之间的连接,尝试实现反向传播。
下面实现从变量y到b的反向传播

y.grad = np.array(1.0)
C = y.creator # 获取函数
b = C.input # 获取函数的输入
b.grad = C.backward(y.grad) # 调用函数的backward方法

在这里插入图片描述
下面实现从变量b到变量a反向传播

B = b.creator  # 获取函数
a = B.input  # 获取函数的输入
a.grad = B.backward(b.grad)  # 调用函数的backward方法

具体来说
流程如下:

  1. 获取函数
  2. 获取函数的输入
  3. 调用函数的backward方法

为Variable增加backward方法
从前面这些反向传播的代码可以看出。它们有着相同的处理方式。为了自动完成这些重复的处理。可以在Variable类中添加一个新的方法 —— backward

class Variable:
	def __init__(self,data):
		self.data = data
		self.grad = None
		self.creator = None

	def set_creator(self, func):
		self.creator = func

    def backward(self):
        f = self.creator  # 1. Get a function
        if f is not None:
            x = f.input  # 2. Get the function's input
            # 递归调用
            x.grad = f.backward(self.grad)  # 3. Call the function's backward
            x.backward()
            
            

上面使用这个新的Variable自动进行反向传播

A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

# backward
y.grad = np.array(1.0)
y.backward()
print(x.grad) # 输出结果 3.297442541400256

循环实现

在之前Variable的实现中
backward方法内调用backward方法,被调用backward方法内再次调用backward方法的处理会不断延续下去直到某个self.creator函数为None的Variable变量,所以这是个递归结构

下面要使用循环实现,代码如下:

class Variable:
	def __init__(self,data):
		self.data = data
		self.grad = None
		self.creator = None

	def set_creator(self, func):
		self.creator = func

    def backward(self):
    	# 按顺序向funcs列表里添加应该处理的函数。
        funcs = [self.creator]
        while funcs:
        	f = funcs.pop()  # 获取函数 列表的pop方法会删除列表末尾的元素,并取出这个元素的值。
        	x, y = f.input, f.output  # 获取函数的输入
        	x.grad = f.backward(y.grad)  # backward调用backward方法
        	if x.creator is not None:
        		funcs.append(x.creator)  # 将前一个函数添加到列表中
                

之所以要把递归变成循环,主要是为了处理复杂的计算图,使用循环代码实现很容易扩展到复杂的计算图处理,而且执行效率会变高

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/697649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

flexible.js适配pc端、移动端并自动将px转换rem

首先在assets中创建一个flexible.js文件 ;(function(win, lib) {let doc win.document;let docEl doc.documentElement;let metaEl doc.querySelector(meta[name"viewport"]);let flexibleEl doc.querySelector(meta[name"flexible"]);let dpr 0;let…

POI及EasyExcel操作xls,xlsx文件

Apache POI 是基于 Office Open XML 标准&#xff08;OOXML&#xff09;和 Microsoft 的 OLE 2 复合文档格式&#xff08;OLE2&#xff09;处理各种文件格式的开源项目。 可以使用 Java 读写 MS Excel 文件&#xff0c;可以使用 Java 读写 MS Word 和 MS PowerPoint 文件。 模…

C# 标注图片

画矩形 画四边形 保存标注图片 保存标注信息 代码 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Windows.Forms; using System.Drawing.Ima…

【UE5 Cesium】06-Cesium for Unreal 从一个地点飞行到另一个地点(上)

UE版本&#xff1a;5.1 介绍 本文以在墨尔本和悉尼这两个城市间为例&#xff0c;介绍如何使用虚幻5引擎和Cesium for Unreal插件在这两个城市间进行飞行移动&#xff0c;其中墨尔本和悉尼城市的倾斜摄影是Cesium官方仓库中自带的资产&#xff0c;我们引入到自己的Cesium账号…

蓝桥杯专题-试题版-【地宫取宝】【斐波那契】【波动数列】【小朋友排队】

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例点击跳转>软考全系列点击跳转>蓝桥系列 &#x1f449;关于作者 专注于Android/Unity和各种游…

MySQL相关知识点

这里写目录标题 MySQL简介概述配置安装连接&#xff08;企业级&#xff09;数据模型sql语句简介语法分类 数据库设计DDL&#xff08;SQL语句&#xff09;数据库操作idea集成mysql开发图形化工具&#xff08;直接在空java项目里打开mysql数据库&#xff09; 表&#xff08;对表的…

ASEMI代理ST可控硅BTA41封装,BTA41图片

编辑-Z BTA41参数描述&#xff1a; 型号&#xff1a;BTA41 封装&#xff1a;TO-3P RMS导通电流IT(RMS)&#xff1a;40A 非重复浪涌峰值导通电流ITSM&#xff1a;420A 峰值栅极电流IGM&#xff1a;8A 平均栅极功耗PG&#xff1a;1W 存储接点温度范围Tstg&#xff1a;-40…

kubelete源码阅读

kubelet 是运行在每个节点上的主要的“节点代理”&#xff0c;每个节点都会启动 kubelet进程&#xff0c;用来处理 Master 节点下发到本节点的任务&#xff0c;按照 PodSpec 描述来管理Pod 和其中的容器&#xff08;PodSpec 是用来描述一个 pod 的 YAML 或者 JSON 对象&#xf…

ATTCK(四)之ATTCK矩阵战术技术(TTP)逻辑和使用

ATT&CK矩阵战术&技术&#xff08;TTP&#xff09;逻辑和使用 ATT&CK的战术与技术组织结构 ATT&CK矩阵中的所有战术和技术&#xff0c;都可以通过以下链接进行目录结构式的浏览https://attack.mitre.org/techniques/enterprise/&#xff0c;也可以在矩阵里直接…

arcgis栅格影像--镶嵌

1、打开软件导入数据&#xff0c;如下&#xff1a; 2、在搜索栏中搜索“镶嵌至新栅格”&#xff0c;如下&#xff1a; 3、双击打开镶嵌对话框&#xff0c;如下&#xff1a; 4、点击确定按钮&#xff0c;进行栅格镶嵌&#xff0c;镶嵌结果如下&#xff1a; 5、去除黑边&#xff…

若依框架-前端使用教程

1 使用 npm run dev 命令执行本机开发测试时&#xff0c;提出错误信息如下&#xff1a; opensslErrorStack: [ error:03000086:digital envelope routines::initialization error ], library: digital envelope routines, reason: unsupported, code: ERR_OSSL_EVP_UNS…

Web安全——PHP基础

PHP基础 一、PHP简述二、基本语法格式三、数据类型、常量以及字符串四、运算符五、控制语句1、条件控制语句2、循环控制语句 六、php数组1、数组的声明2、数组的操作2.1 数组的合拼2.2 填加数组元素2.3 添加到指定位置2.4 删除某一个元素2.5 unset 销毁指定的元素2.6 修改数组中…

Tune-A-Video:用于文本到视频生成的图像扩散模型的One-shot Tuning

Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation Project&#xff1a;https://tuneavideo.github.io 原文链接&#xff1a;Tnue-A-Video:用于文本到视频生成的图像扩散模型的One-shot Tuning &#xff08;by 小样本视觉与智能前沿&…

基于matlab使用校准相机测量平面物体(附源码)

一、前言 此示例演示如何使用单个校准相机以世界单位测量硬币的直径。 此示例演示如何校准相机&#xff0c;然后使用它来测量平面对象&#xff08;如硬币&#xff09;的大小。这种方法的一个示例应用是测量传送带上的零件以进行质量控制。 二、校准相机 相机校准是估计镜头…

基于多站点集中汇聚需求的远程调用直播视频汇聚平台解决方案

一、行业背景 随着视频汇聚需求的不断提升&#xff0c;智慧校园、智慧园区等项目中需要将各分支机构的视频统一汇聚到总部&#xff0c;进行统一管控&#xff0c;要满足在监控内部局域网、互联网、VPN网络等TCP/IP环境下&#xff0c;为用户提供低成本、高扩展、强兼容、高性能的…

ModaHub魔搭社区:如何基于向量数据库+LLM(大语言模型),打造更懂你的企业专属Chatbot?

目录 1、为什么Chatbot需要大语言模型向量数据库? 2、什么是向量数据库? 3、LLM大语言模型ADB-PG:打造企业专属Chatbot 4、ADB-PG:内置向量检索全文检索的一站式企业知识数据库 5、总结 1、为什么Chatbot需要大语言模型向量数据库? 这个春天,最让人震感的科技产品莫过…

6.28作业

作业1 结构体不能被继承&#xff0c;类可以被继承结构体默认的都是公共&#xff0c;类默认是私有的 转载【结构体和类的区别】 结构体是值类型&#xff0c;类是引用类型 结构体存在栈中&#xff0c;类存在堆中 结构体成员不能使用protected访问修饰符&#xff0c;而类可以 结…

vsCode 运行 报错信息 yarn : 无法加载文件 C:\Program Files\nodejs\yarn.ps1

检索说是 PowerShell 执行策略&#xff0c;默认设置是Restricted不去加载配置文件或运行脚本。需要去做相应的变更&#xff0c; 修改配置为 RemoteSigned 管理员身份打开 PowerShell&#xff0c;执行命令&#xff0c;修改PowerShell 执行策略 set-ExecutionPolicy RemoteSigne…

2023.6.28

类和结构体区别&#xff1a; 1&#xff0c;类可以进行封装&#xff08;有访问权限等&#xff09;&#xff0c;结构体无&#xff1b; 2&#xff0c;类有&#xff1a;封装&#xff0c;继承&#xff0c;多态三大特征&#xff0c;结构体只有变量和函数。 #include <iostream&g…

面试题小计(1)

Https加密过程、与三次握手 三次握手是传输层的概念&#xff0c;HTTPS通常是 SSL HTTP 的简称&#xff0c;目前使用的 HTTP/HTTPS 协议是基于 TCP 协议之上的&#xff0c;因此也需要三次握手。要在 TCP 三次握手建立链接之后&#xff0c;才会进行 SSL 握手的过程&#xff08;…