前言背景知识:
梯度下降(Gradient descent,GD)
正文:
自动微分为机器学习、深度学习神经网络的核心知识之一,若想更深一步使用神经网络进行具体问题研究,那么自动微分不得不了解。 “工欲善其事,必先利其器”,事为我们的研究问题,器乃神经网络、自动微分工具,这就不得不提一提现深度学习框架TensorFlow、PyTorch,还有即时编译、自动微分工具JAX。本文将使用Pytorch和Jax实现自动微分的基本案例。方便大家学习,入门自动微分工具,提供案例模板为Pytorch版和Jax版,为什么没有Tensorflow,因为觉得它不好用!
注:这些框架、工具在CPU、GPU,又或是TPU的运行效率可执行查阅,作者认为Jax可能会成为潮流。
已知基础单层线性回归模型如下:
则有J关于theta偏导如下:
原生代码实现计算:
def h_x(theta0,theta1,x):
return theta0+theta1*x
def SG(m,theta0,theta1,X,Y):
sum = 0
for i in range(0,m):
sum += (h_x(theta0,theta1,X[i])-Y[i])
theta0_grad = 1.0/m*sum
sum = 0
for i in range(0,m):
sum += (h_x(theta0,theta1,X[i])-Y[i])*X[i]
theta1_grad = 1.0/m*sum
print("O_SG_grad_caculate : {} , {} ".format(theta0_grad,theta1_grad))
#损失函数
def loss(m,theta0,theta1,X,Y):
result = 0.0
for i in range(0,m):
result += (h_x(theta0,theta1,X[i])-Y[i])**2
return result/(2*m)
X = [1,2,3,4,5,6]
Y = [13,14,20,21,25,30]
theta0 = 0.0
theta1 = 0.0
m = 6
y_pre = h_x(theta0,theta1,X)
loss = loss(m,theta0,theta1,X,Y)
print(loss)
SG(m,theta0,theta1,X,Y)
输出:
loss : 227.58333333333334
O_SG_grad_caculate : -20.5 , -81.66666666666666
Pytorch自动微分:
1)torch.autograd.grad计算微分:
import torch
def h_x(theta0,theta1,x):
return theta0+theta1*x
def SG_Torch(theta0,theta1,loss):
theta0_grad,theta1_grad = torch.autograd.grad(loss,[theta0,theta1])
print("T_SG_grad_caculate : {} , {} ".format(theta0_grad,theta1_grad))
X = torch.tensor([1,2,3,4,5,6])
Y = torch.tensor([13,14,20,21,25,30])
theta0 = torch.tensor(0.0,requires_grad=True)
theta1 = torch.tensor(0.0,requires_grad=True)
y_pre = h_x(theta0,theta1,X)
loss = torch.mean((y_pre - Y)**2/2)
#print loss res
print("loss : {}".format(loss))
#print grad res
SG_Torch(theta0,theta1,loss)
输出:
loss : 227.5833282470703
T_SG_grad_caculate : -20.500001907348633 , -81.66667175292969
2)loss.backward()实现计算微分,回传theta0和theta1两叶子节点(设置要求grad)
import torch
def h_x(theta0,theta1,x):
return theta0+theta1*x
X = torch.tensor([1,2,3,4,5,6])
Y = torch.tensor([13,14,20,21,25,30])
theta0 = torch.tensor(0.0,requires_grad=True)
theta1 = torch.tensor(0.0,requires_grad=True)
y_pre = h_x(theta0,theta1,X)
loss = torch.mean((y_pre - Y)**2/2)
loss.backward()
#print loss res
print("loss : {}".format(loss))
#print grad res
print("T_SG_grad_caculate : {} , {} ".format(theta0.grad,theta1.grad))
输出:
loss : 227.5833282470703
T_SG_grad_caculate : -20.500001907348633 , -81.66667175292969
Jax自动微分:
import jax
import jax.numpy as np
from jax import grad
def h_x(theta0,theta1,x):
return theta0+theta1*x
def loss(theta0,theta1,X,Y):
y_pre = h_x(theta0,theta1,X)
loss = np.mean((y_pre - Y)**2/2)
return loss
def SG_Jax(theta0,theta1,l,X,Y):
g_L_theta0 = grad(loss,argnums = 0)
g_L_theta1 = grad(loss,argnums = 1)
theta0_grad = g_L_theta0(theta0,theta1,X,Y)
theta1_grad = g_L_theta1(theta0,theta1,X,Y)
print("J_SG_grad_caculate : {} , {} ".format(theta0_grad,theta1_grad))
X = np.array([1,2,3,4,5,6])
Y = np.array([13,14,20,21,25,30])
theta0 = 0.0
theta1 = 0.0
l = loss(theta0,theta1,X,Y)
#print loss res
print("loss : {}".format(l))
#print grad res
SG_Jax(theta0,theta1,l,X,Y)
输出:
loss : 227.58334350585938 J_SG_grad_caculate : -20.500001907348633 , -81.66667175292969