目录
- 手写AutoGrad
- 前言
- 1. 基本介绍
- 1.1 计算图
- 1.2 智能指针的引出
- 2. 示例代码
- 2.1 Scale
- 2.2 Multiply
- 2.3 Pow
- 总结
手写AutoGrad
前言
手写AI推出的全新面向AI算法的C++课程 Algo C++,链接。记录下个人学习笔记,仅供自己参考。
本次课程主要是手写 AutoGrad 代码
课程大纲可看下面的思维导图
1. 基本介绍
1.1 计算图
自动微分中的计算图是一种数据结构,用于记录计算过程中变量之间的依赖关系,从而实现自动求导。计算图是由节点和边组成的有向无环图。每个节点代表一个变量或者操作,边代表变量之间的依赖关系。(from chatGPT)
计算图中的节点分为两类:变量节点和操作节点。变量节点代表输入的变量,操作节点代表计算过程中的操作。对于一个变量节点,我们需要记录该变量的取值以及其对应的导数值,对于一个操作节点,我们需要记录该操作的具体实现以及其所依赖的变量节点。下图是一个简单的计算图示例:
计算图中的边代表变量之间的依赖关系。对于一个变量节点,如果它是某个操作节点的输入,那么它与该操作节点之间就有一条边。对于一个操作节点,它的输入通常是多个变量节点,输出则是一个新的变量节点,因此它与多个变量节点之间都有边。
在前向计算中,我们从输入变量开始,按照计算图中的边进行计算,最终得到输出变量的值。在反向传播中,我们从输出变量开始,按照计算图中的边进行反向传播,计算每个变量对应的导数值。这里需要用到链式法则,将每个操作的导数值乘到其输入变量的导数值上。
在实现计算图时,我们通常采用面向对象的方法,将每个变量和操作都封装成一个类。对于变量节点,我们需要记录其取值和导数值;对于操作节点,我们需要记录其具体实现以及所依赖的变量节点。同时,我们需要提供一个接口,用于对变量和操作进行前向计算和反向传播。
综上所述,计算图是自动微分中非常重要的概念,它能够帮助我们自动推导变量之间的依赖关系,并自动计算每个变量的导数值。在实际应用中,我们通常使用现有的自动微分库,如 TensorFlow、PyTorch 等,来帮助我们快速地构建和计算计算图。
1.2 智能指针的引出
我们先来看一个简单的示例代码
#include <iostream>
#include <string.h>
class A{
public:
A(){
this->p = new int[100];
this->p[0] = 555;
}
// 深拷贝的实现,对于实例对象的复制,复制一份p的副本
A(const A& other){
this->a = other.a;
this->b = other.b;
this->p = new int[100];
memcpy(this->p, other.p, sizeof(int) * 100);
}
virtual ~A(){
delete [] this->p;
}
private:
int a = 123;
int b = 234
int* p = nullptr;
}
int main(){
A a;
A b = a;
cout << a.p << endl;
cout << b.p << endl;
return 0;
}
上述示例代码演示了对数据地址 p
的引用过程,我们实现了对 p
的深拷贝,在复制类对象的时候复制一份 p,但存在数据复制造成内存的浪费的问题,我们现在的需求有两个:
- 1.无论
a
被复制多少份,它的数据指针p
永远都是同一个(即数据引用),使得这个数据不至于有太多的副本 - 2.我们希望能够正常释放内存,即
p
指针释放的时机
为此我们考虑二级指针和引用计数的思想,示例如下:
#include <iostream>
#include <string.h>
// 诉求并不是为了实现对p的深度copy,在复制类对象的时候复制一份p
// 是为了演示,对数据地址p的引用过程,即
// 1. 无论a被复制多少份,它的数据指针p永远都是同一个 (数据引用)
// 使得这个数据不至于有太多副本
// 2. 还能够正常释放内存
// p指针的释放问题,也就是释放的时机
// 二级指针 + 引用计数
class Data{
public:
Data(){
this->p = new int[100];
nref = 1;
cout << "构造Data" << endl;
}
virtual ~Data(){
cout << "析构Data" << endl;
delete [] this->p;
}
int *p = nullptr;
int nref = 0; // 引用计数
}
class A{
public:
A(){
this->new Data();
}
A(const A& other){
*this = other;
this->p->nref ++;
}
virtual ~A(){
// 当最后一个引用实例释放时,删除p指针
// 只要存在其它引用,释放都是不合适的
// 问题在于,怎么判断目前执行的析构函数是最后一个引用呢?
// 引用计数
delete [] this->p;
if(this->p->nref == 1){
delete this->p;
}else{
this->p->nref --;
}
}
int a = 123;
int b = 234
Data* p = nullptr;
}
int main(){
A* a = new A();
A b = *a;
delete a;
// cout << a.p << endl;
cout << b.p << endl;
return 0;
}
上述示例代码在原基础上引入了二级指针和引用计数来对数据的引用,避免了深度复制时生成多副本的问题,并解决了对象的内存释放的问题。具体分析如下:
- Data 类:将数据成员
p
单独抽象出来,并添加引用计数nref
作为计数器,用于记录有多少个对象指向同一块数据空间,方便释放内存 - A 类:将
Data
对象的指针作为A
类的成员p
,实现对Data
对象的引用。同时,在 A 类的拷贝构造函数中,对 Data 对象指针进行拷贝并引用计数加一。在 A 类的析构函数中,通过引用计数来判断是否释放 Data 对象的内存空间,如果当前对象是最后一个引用,释放 Data 对象,否则减少引用计数。 - 在 main 函数中,通过 new 创建了一个 A 类型的对象,并将其指针 a 赋值给 b。通过 delete 删除 a 指针时,并不会影响 b 中 Data 对象的内存空间,因为其引用计数不为0.
由此我们引出
shared_ptr
共享智能指针,它可以帮助我们完成引用计数操作避免数据副本拷贝,自动实现内存的引用关系,而不需要我们自己去实现。下面是利用shared_ptr
的示例代码:
#include <iostream>
#include <string.h>
#include <memory>
class Data{
public:
Data(){
this->p = new int[100];
nref = 1;
cout << "构造Data" << endl;
}
virtual ~Data(){
cout << "析构Data" << endl;
delete [] this->p;
}
int *p = nullptr;
}
class A{
public:
A(){
this->pdata.reset(new Data());
}
int a = 123;
int b = 234
shared_ptr<Data> pdata;
}
int main(){
// shared_ptr
A* a = new A();
A b = *a;
delete a;
// cout << a.p << endl;
cout << b.pdata.get() << endl;
return 0;
}
在上面的示例代码中我们使用了智能指针 shared_ptr 对内存的管理。在 A 类里面只需要一个指向 Data 类型对象的 shared_ptr
智能指针,并通过 reset
函数进行初始化。需要注意的是,通过 shared_ptr
管理内存,不再需要显式地调用 delete
操作符,shared_ptr
析构时会自动释放所持有地内存,避免了内存泄漏的问题。关于更多细节请查看共享智能指针
2. 示例代码
自动微分简单版本的示例代码
2.1 Scale
自动微分中标量 Scale 实现的示例代码
#include <iostream>
#include <memory>
#include <cmath>
using namespace std;
// 储存表达式所需要的相关数据
class Container{
public:
// 返回这个表达式具体是什么类型,或者说什么名称
virtual const char* type() = 0;
// 具体的forward的实现过程
virtual float forward() = 0;
// 具体的backward的实现
virtual void backward(float gradient) = 0;
};
// 表达式类
// 1. 其实所有的操作都可以认为是表达式
// a. 标量x,可以认为是标量表达式
// b. 任意的算子,比如加减乘除,都可以抽象为表达式
class Expression{
public:
// 对该表达式进行前向推理,并得到推理后的结果
float forward(){
return container_->forward();
}
// 对该表达式反向推理,计算每一个节点对应的导数
void backward(){
return container_->backward(1.0f);
}
// 为了储存表达式中的数据,所以需要引入二级指针,表示表达式所储存的具体实现
// 具体实现在这里
shared_ptr<Container> container_;
};
// 标量
class ScalarContainer : public Container{
public:
ScalarContainer(float value){
value_ = value;
}
virtual const char* type() override{
return "Scalar";
}
virtual float forward() override{
return value_;
}
virtual void backward(float gradient) override{
gradient_ += gradient;
}
float value_ = 0;
float gradient_ = 0;
};
class Scalar : public Expression{
public:
Scalar(flaot value){
container_.reset(new ScalarContainer(value)); // 智能指针初始化方法
}
};
int main(){
// 1. 实现计算图的统计
// 2. 实现过程,应该跟四则运算类似,跟普通写表达式类似
// 3. 实现forward前向计算和backward反向求导
Scale a(3.0f);
cout << a.forward() << endl;
return 0;
}
上述示例代码实现了自动微分中的标量,其中定义了 Container
和 Expression
两个类,分别用于储存表达式所需的相关数据和定义表达式。在 Container 中,包含了类型定义、前向计算和反向求导的虚函数,其中 ScalarContainer 是用于储存标量的数据。在 Expression 中,定义了前向计算和反向求导的成员函数,并利用二级智能指针 container_ 引用 Container 类的实例。在 Scale 中,重载了 Expression 的构造函数,并初始化了 container_ 为一个 ScalarContainer 类的实例。
2.2 Multiply
自动微分中乘法 Multiply 实现的示例代码
#include <iostream>
#include <memory>
#include <cmath>
using namespace std;
// 储存表达式所需要的相关数据
class Container{
public:
// 返回这个表达式具体是什么类型,或者说什么名称
virtual const char* type() = 0;
// 具体的forward的实现过程
virtual float forward() = 0;
// 具体的backward的实现
virtual void backward(float gradient) = 0;
};
// 表达式类
// 1. 其实所有的操作都可以认为是表达式
// a. 标量x,可以认为是标量表达式
// b. 任意的算子,比如加减乘除,都可以抽象为表达式
class Expression{
public:
// 对该表达式进行前向推理,并得到推理后的结果
float forward(){
return container_->forward();
}
// 对该表达式反向推理,计算每一个节点对应的导数
void backward(){
return container_->backward(1.0f);
}
// 为了储存表达式中的数据,所以需要引入二级指针,表示表达式所储存的具体实现
// 具体实现在这里
shared_ptr<Container> container_;
};
// 乘法
class MultiplyContainer : public Container{
public:
MultiplyContainer(const Expression& left, const Expression& right){
left_value_ = left.container_;
right_value_ = right.container_;
}
virtual const char* type() override{
return "Multiply";
}
virtual float forward() override{
return left_value_->forward() * right_value_->forward();
}
virtual void backward(float gradient) override{
left_value_->backward(gradient * right_value_->forward());
right_value_->backward(gradient * left_value_->forward());
}
shared_ptr<Container> left_value_;
shared_ptr<Container> right_value_;
};
class Multiply : public Expression{
public:
Multiply(const Expression& left, const Expression& right){
container_.reset(new MultiplyContainer(left, right));
}
};
// 重载
Expression operator*(float left, const Expression& right){
return Multiply(Scalar(left), right);
}
Expression operator*(const Expression& left, float right){
return Multiply(left, Scalar(right));
}
Expression operator*(const Expression& left, const Expression& right){
return Multiply(left, right);
}
int main(){
// 1. 实现计算图的统计
// 2. 实现过程,应该跟四则运算类似,跟普通写表达式类似
// 3. 实现forward前向计算和backward反向求导
Scalar a(3.0f);
// auto exp = Multiply(a, Scalar(5.0f));
auto exp = 5.0f * a * 2.0f;
auto f = exp * 10 * a;
cout << exp.forward() << endl;
exp.backward();
cout << a.gradient() << endl;
return 0;
}
上述示例代码实现了自动微分中的乘法,实现了 MultiplyContainer
类,该类继承自 Container 类。它有两个操作数即 left_value_ 和 right_value_,分别是左右操作数的值,构造函数需要传入左右操作数的 Expression 对象,用于后续的前向推理和反向推理。Multiply
类继承自 Expression 类,它包含 MultiplyContainer 类对象的智能指针,用于储存乘法表达式的实现。
此外,示例代码还实现了乘法运算的重载,用于简化表达式的书写。
2.3 Pow
自动微分中平方 Pow 实现的示例代码
#include <iostream>
#include <memory>
#include <cmath>
using namespace std;
// 储存表达式所需要的相关数据
class Container{
public:
// 返回这个表达式具体是什么类型,或者说什么名称
virtual const char* type() = 0;
// 具体的forward的实现过程
virtual float forward() = 0;
// 具体的backward的实现
virtual void backward(float gradient) = 0;
};
// 表达式类
// 1. 其实所有的操作都可以认为是表达式
// a. 标量x,可以认为是标量表达式
// b. 任意的算子,比如加减乘除,都可以抽象为表达式
class Expression{
public:
// 对该表达式进行前向推理,并得到推理后的结果
float forward(){
return container_->forward();
}
// 对该表达式反向推理,计算每一个节点对应的导数
void backward(){
return container_->backward(1.0f);
}
// 为了储存表达式中的数据,所以需要引入二级指针,表示表达式所储存的具体实现
// 具体实现在这里
shared_ptr<Container> container_;
};
// 标量
class ScalarContainer : public Container{
public:
ScalarContainer(float value){
value_ = value;
}
virtual const char* type() override{
return "Scalar";
}
virtual float forward() override{
return value_;
}
virtual void backward(float gradient) override{
gradient_ += gradient;
}
float value_ = 0;
float gradient_ = 0;
};
class Scalar : public Expression{
public:
Scalar(float value){
container_.reset(new ScalarContainer(value)); // 智能指针初始化方法
}
float gradient() const{
return dynamic_pointer_cast<ScalarContainer>(container_)->gradient_;
// return ((ScalarContainer*)container_.get())->gradient_;
}
};
// 乘法
class MultiplyContainer : public Container{
public:
MultiplyContainer(const Expression& left, const Expression& right){
left_value_ = left.container_;
right_value_ = right.container_;
}
virtual const char* type() override{
return "Multiply";
}
virtual float forward() override{
return left_value_->forward() * right_value_->forward();
}
virtual void backward(float gradient) override{
left_value_->backward(gradient * right_value_->forward());
right_value_->backward(gradient * left_value_->forward());
}
shared_ptr<Container> left_value_;
shared_ptr<Container> right_value_;
};
class Multiply : public Expression{
public:
Multiply(const Expression& left, const Expression& right){
container_.reset(new MultiplyContainer(left, right));
}
};
Expression operator*(float left, const Expression& right){
return Multiply(Scalar(left), right);
}
Expression operator*(const Expression& left, float right){
return Multiply(left, Scalar(right));
}
Expression operator*(const Expression& left, const Expression& right){
return Multiply(left, right);
}
// 减法
class SubContainer : public Container{
public:
SubContainer(const Expression& left, const Expression& right){
left_value_ = left.container_;
right_value_ = right.container_;
}
virtual const char* type() override{
return "Sub";
}
virtual float forward() override{
return left_value_->forward() - right_value_->forward();
}
virtual void backward(float gradient) override{
left_value_->backward(gradient);
right_value_->backward(-gradient);
}
shared_ptr<Container> left_value_;
shared_ptr<Container> right_value_;
};
class Sub : public Expression{
public:
Sub(const Expression& left, const Expression& right){
container_.reset(new SubContainer(left, right));
}
};
Expression operator-(const Expression& left, float right){
return Sub(left, Scalar(right));
}
// power
class PowerContainer : public Container{
public:
PowerContainer(const Expression& x, float y){
x_ = x.container_;
y_ = y;
}
virtual const char* type() override{
return "Power";
}
virtual float forward() override{
return std::pow(x_->forward(), y_);
}
virtual void backward(float gradient) override{
x_->backward(gradient * (y_ * std::pow(x_->forward(), y_ - 1)));
}
shared_ptr<Container> x_;
float y_ = 0;
};
class Power : public Expression{
public:
Power(const Expression& x, float y){
container_.reset(new PowerContainer(x, y));
}
};
namespace op{
Expression power(const Expression& x, float y){
return Power(x, y);
}
};
int main(){
// 1. 实现计算图的统计
// 2. 实现过程,应该跟四则运算类似,跟普通写表达式类似
// 3. 实现forward前向计算和backward反向求导
// 需要求解的值
float x = 3.0f;
// 中间变量
float t = x / 2.0f;
// 定义loss
float loss = 0.5 * std::pow(t * t - x, 2.0f);
// 定义最少容忍的误差
// 三种停止条件
// 1. loss低于一定值
// 2. t的改变量低于一定值
// 3. 迭代次数满足条件
float eps = 1e-5;
// 迭代步长,其实就是所谓的学习率
float lr = 0.01;
while(loss > eps){
// float dt = (t * t - x) * 2 - t;
Scalar st(t);
auto sl = 0.5 * op::power(st * st - x, 2.0f);
s1.backward();
float dt = st.gradient();
t = t - lr * dt;
loss = 0.5 * std::pow(t * t - x, 2.0f);
std::printf("Loss: %.5f, t = %.5f, sqrt(x) = %.5f\n", loss, t, std::sqrt(x));
}
return 0;
}
上述示例代码实现了自动微分中的平方。Power
类实现了 Pow
函数的前向计算和反向传播,它继承自 Expression。整个代码实现了基本的自动微分框架,并在此基础上实现了一个求解平方根的例子。
总结
在本次课程中引入了计算图的概念,通过计算图我们可以实现自动微分过程。并实现了一个简单且基础的自动微分示例代码,包括 Scale 标量、Multiply 乘法以及 Power 平方三个部分的内容,并完成了一个求解平方根的例子。实际上自动微分还可以针对向量甚至矩阵,但这样一来复杂度就高了,杜老师有提供相应的示例代码供大家学习。😂