CSDN个人主页:清风莫追
欢迎关注本专栏:《一起撸个DL框架》
文章目录
- 3 前向传播🥝
- 3.1 前情提要
- 3.2 前向传播:递归的forward方法
- 3.3 再添乘法节点:搭建函数y=2x+1
- 3.4 小结
3 前向传播🥝
3.1 前情提要
上一篇:【一起撸个DL框架】2 节点与计算图的搭建
在上一节中,我们定义了加法节点和变量节点类,搭建计算图并实现了加法功能。但还有一个小问题,那就是节点类的定义中,只有父节点有值时,才能调用compute()
方法计算本节点的值。而当存在多个节点串联时,就无法直接调用结果节点的compute()
方法。因此,这一节我们将采用递归来解决这个问题。
图1:串联的两个加法节点 |
3.2 前向传播:递归的forward方法
我们只需要修改Node类,在其中添加一个forword()
方法。当父节点的值为空时,递归地调用forward()
计算节点的值,然后再调用compute()
计算本节点的值。
class Node:
def __init__(self, parent1=None, parent2=None) -> None:
self.parent1 = parent1
self.parent2 = parent2
self.value = None
def set_value(self, value):
self.value = value
def compute(self):
pass
def forward(self):
for parent in [self.parent1, self.parent2]:
if parent.value is None:
parent.forward()
self.compute()
return self.value
然后,我们就可以使用修改过的节点类,搭建出图1中的计算图,并计算节点add2的值。
if __name__ == '__main__':
# 搭建计算图
x1 = Varrible()
x2 = Varrible()
add1 = Add(x1, x2)
x3 = Varrible()
add2 = Add(add1, x3)
# 输入
x1.set_value(int(input('请输入x1:')))
x2.set_value(int(input('请输入x2:')))
x3.set_value(int(input('请输入x3:')))
# 前向传播
y = add2.forward()
print(y)
运行代码效果如下:
请输入x1:1
请输入x2:2
请输入x3:3
6
3.3 再添乘法节点:搭建函数y=2x+1
函数 y = 2 x + 1 y=2x+1 y=2x+1的计算图如图2所示,与图1很相似,只是其中一个加法节点换成了乘法节点。但不同之处是,在函数 y = 2 x + 1 y=2x+1 y=2x+1的计算图中,只有x一个自变量,其余变量节点称为参数。
图2:函数y=2x+1的计算图 |
乘法节点类的实现与加法节点差不多,如下所示:
class Mul(Node):
def __init__(self, parent1=None, parent2=None) -> None:
super().__init__(parent1, parent2)
def compute(self):
self.value = self.parent1.value * self.parent2.value
下面是图2中计算图的搭建:
if __name__ == '__main__':
# 搭建计算图
w = Varrible()
x = Varrible()
mul = Mul(w, x)
b = Varrible()
add = Add(mul, b)
# 输入
w.set_value(2)
b.set_value(1)
x.set_value(int(input('请输入x:')))
# 前向传播
y = add.forward()
print(y)
请输入x:2
5
3.4 小结
这一节的内容比较简单,我们用递归实现了前向传播,并搭建了一个一次函数: y = 2 x + 1 y=2x+1 y=2x+1。