线性模型
课程
代码
import numpy as np
import matplotlib.pyplot as plt
# Training set sampled from y = 2x; the best slope for y = x * w is therefore w = 2.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0 * x for x in x_data]
#前馈函数
def forward(x):
return x*w
# Loss function
def loss(x, y):
    """Return the squared error (forward(x) - y) ** 2 for a single sample."""
    residual = forward(x) - y
    return residual * residual
# Exhaustive search: evaluate the mean squared error for every candidate
# w in [0.0, 4.0] (step 0.1) and plot MSE against w.
w_list = []
mse_list = []
for w in np.arange(0.0, 4.1, 0.1):
    # The loop variable w is the module-level weight that forward() reads.
    print("w=", w)
    l_sum = 0
    for x_v, y_v in zip(x_data, y_data):
        loss_v = loss(x_v, y_v)
        l_sum += loss_v
        print(x_v, y_v, loss_v)
    # Average over the actual number of samples instead of a hard-coded 3,
    # so the MSE stays correct if the data set changes.
    l_sum /= len(x_data)
    print("MSE=", l_sum)
    w_list.append(w)
    mse_list.append(l_sum)
plt.plot(w_list, mse_list)
plt.xlabel('w')
plt.ylabel('loss')
plt.show()
结果
w= 0.0
1.0 2.0 4.0
2.0 4.0 16.0
3.0 6.0 36.0
MSE= 18.666666666666668
w= 0.1
1.0 2.0 3.61
2.0 4.0 14.44
3.0 6.0 32.49
MSE= 16.846666666666668
w= 0.2
1.0 2.0 3.24
2.0 4.0 12.96
3.0 6.0 29.160000000000004
MSE= 15.120000000000003
w= 0.30000000000000004
1.0 2.0 2.8899999999999997
2.0 4.0 11.559999999999999
3.0 6.0 26.009999999999998
MSE= 13.486666666666665
w= 0.4
1.0 2.0 2.5600000000000005
2.0 4.0 10.240000000000002
3.0 6.0 23.04
MSE= 11.946666666666667
w= 0.5
1.0 2.0 2.25
2.0 4.0 9.0
3.0 6.0 20.25
MSE= 10.5
w= 0.6000000000000001
1.0 2.0 1.9599999999999997
2.0 4.0 7.839999999999999
3.0 6.0 17.639999999999993
MSE= 9.146666666666663
w= 0.7000000000000001
1.0 2.0 1.6899999999999995
2.0 4.0 6.759999999999998
3.0 6.0 15.209999999999999
MSE= 7.886666666666666
w= 0.8
1.0 2.0 1.44
2.0 4.0 5.76
3.0 6.0 12.959999999999997
MSE= 6.719999999999999
w= 0.9
1.0 2.0 1.2100000000000002
2.0 4.0 4.840000000000001
3.0 6.0 10.889999999999999
MSE= 5.646666666666666
w= 1.0
1.0 2.0 1.0
2.0 4.0 4.0
3.0 6.0 9.0
MSE= 4.666666666666667
w= 1.1
1.0 2.0 0.8099999999999998
2.0 4.0 3.2399999999999993
3.0 6.0 7.289999999999998
MSE= 3.779999999999999
w= 1.2000000000000002
1.0 2.0 0.6399999999999997
2.0 4.0 2.5599999999999987
3.0 6.0 5.759999999999997
MSE= 2.986666666666665
w= 1.3
1.0 2.0 0.48999999999999994
2.0 4.0 1.9599999999999997
3.0 6.0 4.409999999999998
MSE= 2.2866666666666657
w= 1.4000000000000001
1.0 2.0 0.3599999999999998
2.0 4.0 1.4399999999999993
3.0 6.0 3.2399999999999993
MSE= 1.6799999999999995
w= 1.5
1.0 2.0 0.25
2.0 4.0 1.0
3.0 6.0 2.25
MSE= 1.1666666666666667
w= 1.6
1.0 2.0 0.15999999999999992
2.0 4.0 0.6399999999999997
3.0 6.0 1.4399999999999984
MSE= 0.746666666666666
w= 1.7000000000000002
1.0 2.0 0.0899999999999999
2.0 4.0 0.3599999999999996
3.0 6.0 0.809999999999999
MSE= 0.4199999999999995
w= 1.8
1.0 2.0 0.03999999999999998
2.0 4.0 0.15999999999999992
3.0 6.0 0.3599999999999996
MSE= 0.1866666666666665
w= 1.9000000000000001
1.0 2.0 0.009999999999999974
2.0 4.0 0.0399999999999999
3.0 6.0 0.0899999999999999
MSE= 0.046666666666666586
w= 2.0
1.0 2.0 0.0
2.0 4.0 0.0
3.0 6.0 0.0
MSE= 0.0
w= 2.1
1.0 2.0 0.010000000000000018
2.0 4.0 0.04000000000000007
3.0 6.0 0.09000000000000043
MSE= 0.046666666666666835
w= 2.2
1.0 2.0 0.04000000000000007
2.0 4.0 0.16000000000000028
3.0 6.0 0.36000000000000065
MSE= 0.18666666666666698
w= 2.3000000000000003
1.0 2.0 0.09000000000000016
2.0 4.0 0.36000000000000065
3.0 6.0 0.8100000000000006
MSE= 0.42000000000000054
w= 2.4000000000000004
1.0 2.0 0.16000000000000028
2.0 4.0 0.6400000000000011
3.0 6.0 1.4400000000000026
MSE= 0.7466666666666679
w= 2.5
1.0 2.0 0.25
2.0 4.0 1.0
3.0 6.0 2.25
MSE= 1.1666666666666667
w= 2.6
1.0 2.0 0.3600000000000001
2.0 4.0 1.4400000000000004
3.0 6.0 3.2400000000000024
MSE= 1.6800000000000008
w= 2.7
1.0 2.0 0.49000000000000027
2.0 4.0 1.960000000000001
3.0 6.0 4.410000000000006
MSE= 2.2866666666666693
w= 2.8000000000000003
1.0 2.0 0.6400000000000005
2.0 4.0 2.560000000000002
3.0 6.0 5.760000000000002
MSE= 2.986666666666668
w= 2.9000000000000004
1.0 2.0 0.8100000000000006
2.0 4.0 3.2400000000000024
3.0 6.0 7.290000000000005
MSE= 3.780000000000003
w= 3.0
1.0 2.0 1.0
2.0 4.0 4.0
3.0 6.0 9.0
MSE= 4.666666666666667
w= 3.1
1.0 2.0 1.2100000000000002
2.0 4.0 4.840000000000001
3.0 6.0 10.890000000000004
MSE= 5.646666666666668
w= 3.2
1.0 2.0 1.4400000000000004
2.0 4.0 5.760000000000002
3.0 6.0 12.96000000000001
MSE= 6.720000000000003
w= 3.3000000000000003
1.0 2.0 1.6900000000000006
2.0 4.0 6.7600000000000025
3.0 6.0 15.210000000000003
MSE= 7.886666666666668
w= 3.4000000000000004
1.0 2.0 1.960000000000001
2.0 4.0 7.840000000000004
3.0 6.0 17.640000000000008
MSE= 9.14666666666667
w= 3.5
1.0 2.0 2.25
2.0 4.0 9.0
3.0 6.0 20.25
MSE= 10.5
w= 3.6
1.0 2.0 2.5600000000000005
2.0 4.0 10.240000000000002
3.0 6.0 23.040000000000006
MSE= 11.94666666666667
w= 3.7
1.0 2.0 2.8900000000000006
2.0 4.0 11.560000000000002
3.0 6.0 26.010000000000016
MSE= 13.486666666666673
w= 3.8000000000000003
1.0 2.0 3.240000000000001
2.0 4.0 12.960000000000004
3.0 6.0 29.160000000000004
MSE= 15.120000000000005
w= 3.9000000000000004
1.0 2.0 3.610000000000001
2.0 4.0 14.440000000000005
3.0 6.0 32.49000000000001
MSE= 16.84666666666667
w= 4.0
1.0 2.0 4.0
2.0 4.0 16.0
3.0 6.0 36.0
MSE= 18.666666666666668
作业
代码
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Training set sampled from y = 2x (true w = 2, b = 0); the grid search
# fits the two-parameter model y = x * w + b to it.
x_data = [1.0, 2.0, 3.0]
y_data = [2 * x for x in x_data]
def forward(x):
    """Predict y = x * w + b.

    Slope w and intercept b are module-level globals rebound by the grid loop.
    """
    scaled = x * w
    return scaled + b
def loss(x, y):
    """Return the squared error of the current (w, b) model on one sample."""
    residual = forward(x) - y
    return residual * residual
# Candidate grids: 40 slopes in [0.0, 4.0) and 40 intercepts in [-2.0, 2.0),
# both with step 0.1.
w_list = np.arange(0.0, 4.0, 0.1)
b_list = np.arange(-2.0, 2.0, 0.1)
# MSE for every (w, b) pair, stored w-major: entry i_w * len(b_list) + i_b.
mse_list = []
# Iterate the grids defined above (instead of rebuilding np.arange) so the
# loops can never drift out of sync with w_list / b_list used for plotting.
# The loop variables are the module-level w and b that forward() reads.
for w in w_list:
    for b in b_list:
        print("w=", w)
        print("b=", b)
        l_sum = 0
        for x_v, y_v in zip(x_data, y_data):
            loss_v = loss(x_v, y_v)
            l_sum += loss_v
            print(x_v, y_v, loss_v)
        # Mean over the real sample count rather than a hard-coded 3.
        l_sum /= len(x_data)
        print("MSE=", l_sum)
        mse_list.append(l_sum)
# Turn the two 1-D grids into 2-D coordinate arrays for plot_surface.
# With the default 'xy' indexing both arrays have shape
# (len(b_list), len(w_list)): rows follow b, columns follow w.
# Distinct names avoid overwriting the scalar globals w / b that forward() reads.
w_grid, b_grid = np.meshgrid(w_list, b_list)
# mse_list was filled w-major (outer loop over w), so reshape to
# (len(w_list), len(b_list)) and transpose to match the meshgrid layout.
# Deriving the shape from the grids avoids a hard-coded 40 x 40.
mse = np.array(mse_list).reshape(len(w_list), len(b_list)).T
fig = plt.figure()
# Supported way to create 3-D axes; replaces fig.add_axes(Axes3D(fig)).
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(w_grid, b_grid, mse, cmap='rainbow')
plt.show()
结果
梯度下降算法
课程
初始梯度下降算法
代码
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Training set sampled from y = 2x; the model y = x * w should converge to w = 2.
x_data = [1.0, 2.0, 3.0]
y_data = [2 * x for x in x_data]
# Initial guess for the single weight, updated in place by the training loop.
w = 1.0
def forward(x):
    """Predict y = x * w using the current module-level weight w."""
    prediction = x * w
    return prediction
def cost(xs, ys):
    """Return the mean squared error of forward() over all samples (xs, ys).

    Raises ZeroDivisionError when xs is empty.
    """
    squared_errors = ((forward(x) - y) ** 2 for x, y in zip(xs, ys))
    return sum(squared_errors) / len(xs)
def gradient(xs, ys):
    """Return d(cost)/dw: the mean of 2 * x * (forward(x) - y) over all samples.

    Raises ZeroDivisionError when xs is empty.
    """
    per_sample = (2 * x * (forward(x) - y) for x, y in zip(xs, ys))
    return sum(per_sample) / len(xs)
# Full-batch gradient descent on y = x * w: learning rate 0.01, 100 epochs.
print("Predict (before training)", 4, forward(4))
epoch_list = []
cost_list = []
for epoch in range(100):
    # Cost and gradient are both evaluated at the current w, before the step.
    cost_v = cost(x_data, y_data)
    grad_v = gradient(x_data, y_data)
    w -= 0.01 * grad_v
    epoch_list.append(epoch)
    cost_list.append(cost_v)
    # NOTE: the logged w is the value after this epoch's step, whereas cost_v
    # was computed for the w before it.
    print("Epoch:", epoch, "w=:", w, "cost=", cost_v)
print("Predict (after training)", 4, forward(4))
# Learning curve: cost per epoch.
plt.plot(epoch_list, cost_list)
plt.xlabel("Epoch")
plt.ylabel("Cost")
plt.show()
结果
Predict (before training) 4 4.0
Epoch: 0 w=: 1.0933333333333333 cost= 4.666666666666667
Epoch: 1 w=: 1.1779555555555554 cost= 3.8362074074074086
Epoch: 2 w=: 1.2546797037037036 cost= 3.1535329869958857
Epoch: 3 w=: 1.3242429313580246 cost= 2.592344272332262
Epoch: 4 w=: 1.3873135910979424 cost= 2.1310222071581117
Epoch: 5 w=: 1.4444976559288012 cost= 1.7517949663820642
Epoch: 6 w=: 1.4963445413754464 cost= 1.440053319920117
Epoch: 7 w=: 1.5433523841804047 cost= 1.1837878313441108
Epoch: 8 w=: 1.5859728283235668 cost= 0.9731262101573632
Epoch: 9 w=: 1.6246153643467005 cost= 0.7999529948031382
Epoch: 10 w=: 1.659651263674342 cost= 0.6575969151946154
Epoch: 11 w=: 1.6914171457314033 cost= 0.5405738908195378
Epoch: 12 w=: 1.7202182121298057 cost= 0.44437576375991855
Epoch: 13 w=: 1.7463311789976905 cost= 0.365296627844598
Epoch: 14 w=: 1.7700069356245727 cost= 0.3002900634939416
Epoch: 15 w=: 1.7914729549662791 cost= 0.2468517784170642
Epoch: 16 w=: 1.8109354791694263 cost= 0.2029231330489788
Epoch: 17 w=: 1.8285815011136133 cost= 0.16681183417217407
Epoch: 18 w=: 1.8445805610096762 cost= 0.1371267415488235
Epoch: 19 w=: 1.8590863753154396 cost= 0.11272427607497944
Epoch: 20 w=: 1.872238313619332 cost= 0.09266436490145864
Epoch: 21 w=: 1.8841627376815275 cost= 0.07617422636521683
Epoch: 22 w=: 1.8949742154979183 cost= 0.06261859959338009
Epoch: 23 w=: 1.904776622051446 cost= 0.051475271914629306
Epoch: 24 w=: 1.9136641373266443 cost= 0.04231496130368814
Epoch: 25 w=: 1.9217221511761575 cost= 0.03478477885657844
Epoch: 26 w=: 1.9290280837330496 cost= 0.02859463421027894
Epoch: 27 w=: 1.9356521292512983 cost= 0.023506060193480772
Epoch: 28 w=: 1.9416579305211772 cost= 0.01932302619282764
Epoch: 29 w=: 1.9471031903392007 cost= 0.015884386331668398
Epoch: 30 w=: 1.952040225907542 cost= 0.01305767153735723
Epoch: 31 w=: 1.9565164714895047 cost= 0.010733986344664803
Epoch: 32 w=: 1.9605749341504843 cost= 0.008823813841374291
Epoch: 33 w=: 1.9642546069631057 cost= 0.007253567147113681
Epoch: 34 w=: 1.9675908436465492 cost= 0.005962754575689583
Epoch: 35 w=: 1.970615698239538 cost= 0.004901649272531298
Epoch: 36 w=: 1.9733582330705144 cost= 0.004029373553099482
Epoch: 37 w=: 1.975844797983933 cost= 0.0033123241439168096
Epoch: 38 w=: 1.9780992835054327 cost= 0.0027228776607060357
Epoch: 39 w=: 1.980143350378259 cost= 0.002238326453885249
Epoch: 40 w=: 1.9819966376762883 cost= 0.001840003826269386
Epoch: 41 w=: 1.983676951493168 cost= 0.0015125649231412608
Epoch: 42 w=: 1.9852004360204722 cost= 0.0012433955919298103
Epoch: 43 w=: 1.9865817286585614 cost= 0.0010221264385926248
Epoch: 44 w=: 1.987834100650429 cost= 0.0008402333603648631
Epoch: 45 w=: 1.9889695845897222 cost= 0.0006907091659248264
Epoch: 46 w=: 1.9899990900280147 cost= 0.0005677936325753796
Epoch: 47 w=: 1.9909325082920666 cost= 0.0004667516012495216
Epoch: 48 w=: 1.9917788075181404 cost= 0.000383690560742734
Epoch: 49 w=: 1.9925461188164473 cost= 0.00031541069384432885
Epoch: 50 w=: 1.9932418143935788 cost= 0.0002592816085930997
Epoch: 51 w=: 1.9938725783835114 cost= 0.0002131410058905752
Epoch: 52 w=: 1.994444471067717 cost= 0.00017521137977565514
Epoch: 53 w=: 1.9949629871013967 cost= 0.0001440315413480261
Epoch: 54 w=: 1.9954331083052663 cost= 0.0001184003283899171
Epoch: 55 w=: 1.9958593515301082 cost= 9.733033217332803e-05
Epoch: 56 w=: 1.9962458120539648 cost= 8.000985883901657e-05
Epoch: 57 w=: 1.9965962029289281 cost= 6.57716599593935e-05
Epoch: 58 w=: 1.9969138906555615 cost= 5.406722767150764e-05
Epoch: 59 w=: 1.997201927527709 cost= 4.444566413387458e-05
Epoch: 60 w=: 1.9974630809584561 cost= 3.65363112808981e-05
Epoch: 61 w=: 1.9976998600690001 cost= 3.0034471708953996e-05
Epoch: 62 w=: 1.9979145397958935 cost= 2.4689670610172655e-05
Epoch: 63 w=: 1.9981091827482769 cost= 2.0296006560253656e-05
Epoch: 64 w=: 1.9982856590251044 cost= 1.6684219437262796e-05
Epoch: 65 w=: 1.9984456641827613 cost= 1.3715169898293847e-05
Epoch: 66 w=: 1.9985907355257035 cost= 1.1274479219506377e-05
Epoch: 67 w=: 1.9987222668766378 cost= 9.268123006398985e-06
Epoch: 68 w=: 1.9988415219681517 cost= 7.61880902783969e-06
Epoch: 69 w=: 1.9989496465844576 cost= 6.262999634617916e-06
Epoch: 70 w=: 1.9990476795699081 cost= 5.1484640551938914e-06
Epoch: 71 w=: 1.9991365628100501 cost= 4.232266273994499e-06
Epoch: 72 w=: 1.999217150281112 cost= 3.479110977946351e-06
Epoch: 73 w=: 1.999290216254875 cost= 2.859983851026929e-06
Epoch: 74 w=: 1.9993564627377531 cost= 2.3510338359374262e-06
Epoch: 75 w=: 1.9994165262155628 cost= 1.932654303533636e-06
Epoch: 76 w=: 1.999470983768777 cost= 1.5887277332523938e-06
Epoch: 77 w=: 1.9995203586170245 cost= 1.3060048068548734e-06
Epoch: 78 w=: 1.9995651251461022 cost= 1.0735939958924364e-06
Epoch: 79 w=: 1.9996057134657994 cost= 8.825419799121559e-07
Epoch: 80 w=: 1.9996425135423248 cost= 7.254887315754342e-07
Epoch: 81 w=: 1.999675878945041 cost= 5.963839812987369e-07
Epoch: 82 w=: 1.999706130243504 cost= 4.902541385825727e-07
Epoch: 83 w=: 1.9997335580874436 cost= 4.0301069098738336e-07
Epoch: 84 w=: 1.9997584259992822 cost= 3.312926995781724e-07
Epoch: 85 w=: 1.9997809729060159 cost= 2.723373231729343e-07
Epoch: 86 w=: 1.9998014154347876 cost= 2.2387338352920307e-07
Epoch: 87 w=: 1.9998199499942075 cost= 1.8403387118941732e-07
Epoch: 88 w=: 1.9998367546614149 cost= 1.5128402140063082e-07
Epoch: 89 w=: 1.9998519908930161 cost= 1.2436218932547864e-07
Epoch: 90 w=: 1.9998658050763347 cost= 1.0223124683409346e-07
Epoch: 91 w=: 1.9998783299358769 cost= 8.403862850836479e-08
Epoch: 92 w=: 1.9998896858085284 cost= 6.908348768398496e-08
Epoch: 93 w=: 1.9998999817997325 cost= 5.678969725349543e-08
Epoch: 94 w=: 1.9999093168317574 cost= 4.66836551287917e-08
Epoch: 95 w=: 1.9999177805941268 cost= 3.8376039345125727e-08
Epoch: 96 w=: 1.9999254544053418 cost= 3.154680994333735e-08
Epoch: 97 w=: 1.9999324119941766 cost= 2.593287985380858e-08
Epoch: 98 w=: 1.9999387202080534 cost= 2.131797981222471e-08
Epoch: 99 w=: 1.9999444396553017 cost= 1.752432687141379e-08
Predict (after training) 4 7.999777758621207
随机梯度下降算法
在梯度下降中可能遇到导数为0的鞍点,此时w不再更新;随机梯度下降每次只用单个样本的梯度更新w,单样本梯度带有噪声,有助于越过鞍点,使迭代继续进行。(注:下面的代码按固定顺序遍历样本,并未随机打乱。)
代码
import numpy as np
import matplotlib.pyplot as plt
# Training set sampled from y = 2x; initial weight guess is 1.0.
x_data = [1.0, 2.0, 3.0]
y_data = [2 * x for x in x_data]
w = 1.0
# Forward pass
def forward(x):
    """Predict y = x * w using the current module-level weight w."""
    prediction = x * w
    return prediction
# Loss function
def loss(x, y):
    """Return the squared error (forward(x) - y) ** 2 for one sample."""
    diff = forward(x) - y
    return diff * diff
def gradient(x, y):
    """Return d(loss)/dw for a single sample: 2 * x * (forward(x) - y)."""
    error = forward(x) - y
    return 2 * x * error
# Per-sample (stochastic) gradient descent: w is updated after every sample,
# so each epoch performs len(x_data) updates with learning rate 0.01.
# Samples are visited in their fixed order; nothing is shuffled.
epoch_list = []
cost_list = []
print("Predict (before training)", 4, forward(4))
for epoch in range(100):
    # Sum (not mean) of the per-sample losses seen during this epoch.
    cost_v = 0
    for x, y in zip(x_data, y_data):
        grad = gradient(x, y)
        w -= 0.01 * grad
        print("\tgrad:", x, y, grad)
        # Loss is measured with the freshly updated w.
        cost_v += loss(x, y)
    epoch_list.append(epoch)
    cost_list.append(cost_v)
    print("Epoch:", epoch, "w=:", w, "cost=", cost_v)
print("Predict (after training)", 4, forward(4))
# Learning curve: summed loss per epoch.
plt.plot(epoch_list, cost_list)
plt.xlabel("Epoch")
plt.ylabel("Cost")
plt.show()
结果
Predict (before training) 4 4.0
grad: 1.0 2.0 -2.0
grad: 2.0 4.0 -7.84
grad: 3.0 6.0 -16.2288
Epoch: 0 w=: 1.260688 cost= 9.131170340095998
grad: 1.0 2.0 -1.478624
grad: 2.0 4.0 -5.796206079999999
grad: 3.0 6.0 -11.998146585599997
Epoch: 1 w=: 1.453417766656 cost= 4.990935477534164
grad: 1.0 2.0 -1.093164466688
grad: 2.0 4.0 -4.285204709416961
grad: 3.0 6.0 -8.87037374849311
Epoch: 2 w=: 1.5959051959019805 cost= 2.727956659786429
grad: 1.0 2.0 -0.8081896081960389
grad: 2.0 4.0 -3.1681032641284723
grad: 3.0 6.0 -6.557973756745939
Epoch: 3 w=: 1.701247862192685 cost= 1.4910526435717042
grad: 1.0 2.0 -0.59750427561463
grad: 2.0 4.0 -2.3422167604093502
grad: 3.0 6.0 -4.848388694047353
Epoch: 4 w=: 1.7791289594933983 cost= 0.814982883956898
grad: 1.0 2.0 -0.44174208101320334
grad: 2.0 4.0 -1.7316289575717576
grad: 3.0 6.0 -3.584471942173538
Epoch: 5 w=: 1.836707389300983 cost= 0.4454551648502959
grad: 1.0 2.0 -0.3265852213980338
grad: 2.0 4.0 -1.2802140678802925
grad: 3.0 6.0 -2.650043120512205
Epoch: 6 w=: 1.8792758133988885 cost= 0.24347787885849426
grad: 1.0 2.0 -0.241448373202223
grad: 2.0 4.0 -0.946477622952715
grad: 3.0 6.0 -1.9592086795121197
Epoch: 7 w=: 1.910747160155559 cost= 0.13308068279633564
grad: 1.0 2.0 -0.17850567968888198
grad: 2.0 4.0 -0.6997422643804168
grad: 3.0 6.0 -1.4484664872674653
Epoch: 8 w=: 1.9340143044689266 cost= 0.07273953681776574
grad: 1.0 2.0 -0.13197139106214673
grad: 2.0 4.0 -0.5173278529636143
grad: 3.0 6.0 -1.0708686556346834
Epoch: 9 w=: 1.9512159834655312 cost= 0.03975813848626224
grad: 1.0 2.0 -0.09756803306893769
grad: 2.0 4.0 -0.38246668963023644
grad: 3.0 6.0 -0.7917060475345892
Epoch: 10 w=: 1.9639333911678687 cost= 0.02173109212742138
grad: 1.0 2.0 -0.07213321766426262
grad: 2.0 4.0 -0.2827622132439096
grad: 3.0 6.0 -0.5853177814148953
Epoch: 11 w=: 1.9733355232910992 cost= 0.011877828868010278
grad: 1.0 2.0 -0.05332895341780164
grad: 2.0 4.0 -0.2090494973977819
grad: 3.0 6.0 -0.4327324596134101
Epoch: 12 w=: 1.9802866323953892 cost= 0.006492210229954897
grad: 1.0 2.0 -0.039426735209221686
grad: 2.0 4.0 -0.15455280202014876
grad: 3.0 6.0 -0.3199243001817109
Epoch: 13 w=: 1.9854256707695 cost= 0.003548526766827499
grad: 1.0 2.0 -0.02914865846100012
grad: 2.0 4.0 -0.11426274116712065
grad: 3.0 6.0 -0.2365238742159388
Epoch: 14 w=: 1.9892250235079405 cost= 0.0019395616852935693
grad: 1.0 2.0 -0.021549952984118992
grad: 2.0 4.0 -0.08447581569774698
grad: 3.0 6.0 -0.17486493849433593
Epoch: 15 w=: 1.9920339305797026 cost= 0.0010601299576561939
grad: 1.0 2.0 -0.015932138840594856
grad: 2.0 4.0 -0.062453984255132156
grad: 3.0 6.0 -0.12927974740812687
Epoch: 16 w=: 1.994110589284741 cost= 0.000579448199890608
grad: 1.0 2.0 -0.011778821430517894
grad: 2.0 4.0 -0.046172980007630926
grad: 3.0 6.0 -0.09557806861579543
Epoch: 17 w=: 1.9956458879852805 cost= 0.0003167160912033633
grad: 1.0 2.0 -0.008708224029438938
grad: 2.0 4.0 -0.03413623819540135
grad: 3.0 6.0 -0.07066201306448505
Epoch: 18 w=: 1.9967809527381737 cost= 0.00017311138846592646
grad: 1.0 2.0 -0.006438094523652627
grad: 2.0 4.0 -0.02523733053271826
grad: 3.0 6.0 -0.052241274202728505
Epoch: 19 w=: 1.9976201197307648 cost= 9.461960932498085e-05
grad: 1.0 2.0 -0.004759760538470381
grad: 2.0 4.0 -0.01865826131080439
grad: 3.0 6.0 -0.03862260091336722
Epoch: 20 w=: 1.998240525958391 cost= 5.171739738298014e-05
grad: 1.0 2.0 -0.0035189480832178432
grad: 2.0 4.0 -0.01379427648621423
grad: 3.0 6.0 -0.028554152326460525
Epoch: 21 w=: 1.99869919972735 cost= 2.8267810564328208e-05
grad: 1.0 2.0 -0.002601600545300009
grad: 2.0 4.0 -0.01019827413757568
grad: 3.0 6.0 -0.021110427464781978
Epoch: 22 w=: 1.9990383027488265 cost= 1.5450683029998435e-05
grad: 1.0 2.0 -0.001923394502346909
grad: 2.0 4.0 -0.007539706449199102
grad: 3.0 6.0 -0.01560719234984198
Epoch: 23 w=: 1.9992890056818404 cost= 8.445068837224808e-06
grad: 1.0 2.0 -0.0014219886363191492
grad: 2.0 4.0 -0.005574195454370212
grad: 3.0 6.0 -0.011538584590544687
Epoch: 24 w=: 1.999474353368653 cost= 4.61592458579276e-06
grad: 1.0 2.0 -0.0010512932626940419
grad: 2.0 4.0 -0.004121069589761106
grad: 3.0 6.0 -0.008530614050808794
Epoch: 25 w=: 1.9996113831376856 cost= 2.522982369050795e-06
grad: 1.0 2.0 -0.0007772337246287897
grad: 2.0 4.0 -0.0030467562005451754
grad: 3.0 6.0 -0.006306785335127074
Epoch: 26 w=: 1.9997126908902887 cost= 1.379017337962252e-06
grad: 1.0 2.0 -0.0005746182194226179
grad: 2.0 4.0 -0.002252503420136165
grad: 3.0 6.0 -0.00466268207967957
Epoch: 27 w=: 1.9997875889274812 cost= 7.537463764030389e-07
grad: 1.0 2.0 -0.0004248221450375844
grad: 2.0 4.0 -0.0016653028085471533
grad: 3.0 6.0 -0.0034471768136938863
Epoch: 28 w=: 1.9998429619451539 cost= 4.119843777895263e-07
grad: 1.0 2.0 -0.00031407610969225175
grad: 2.0 4.0 -0.0012311783499932005
grad: 3.0 6.0 -0.0025485391844828342
Epoch: 29 w=: 1.9998838998815958 cost= 2.2518334131447684e-07
grad: 1.0 2.0 -0.00023220023680847746
grad: 2.0 4.0 -0.0009102249282886277
grad: 3.0 6.0 -0.0018841656015560204
Epoch: 30 w=: 1.9999141657892625 cost= 1.230812136072453e-07
grad: 1.0 2.0 -0.00017166842147497974
grad: 2.0 4.0 -0.0006729402121816719
grad: 3.0 6.0 -0.0013929862392156878
Epoch: 31 w=: 1.9999365417379913 cost= 6.72740046159769e-08
grad: 1.0 2.0 -0.0001269165240174175
grad: 2.0 4.0 -0.0004975127741477792
grad: 3.0 6.0 -0.0010298514424817995
Epoch: 32 w=: 1.9999530845453979 cost= 3.677077568883698e-08
grad: 1.0 2.0 -9.383090920422887e-05
grad: 2.0 4.0 -0.00036781716408107457
grad: 3.0 6.0 -0.0007613815296476645
Epoch: 33 w=: 1.9999653148414271 cost= 2.009825269786531e-08
grad: 1.0 2.0 -6.937031714571162e-05
grad: 2.0 4.0 -0.0002719316432120422
grad: 3.0 6.0 -0.0005628985014531906
Epoch: 34 w=: 1.999974356846045 cost= 1.0985347846020229e-08
grad: 1.0 2.0 -5.1286307909848006e-05
grad: 2.0 4.0 -0.00020104232700646207
grad: 3.0 6.0 -0.0004161576169003922
Epoch: 35 w=: 1.9999810417085633 cost= 6.004395959775294e-09
grad: 1.0 2.0 -3.7916582873442906e-05
grad: 2.0 4.0 -0.0001486330048638962
grad: 3.0 6.0 -0.0003076703200690645
Epoch: 36 w=: 1.9999859839076413 cost= 3.2818961535760077e-09
grad: 1.0 2.0 -2.8032184717474706e-05
grad: 2.0 4.0 -0.0001098861640933535
grad: 3.0 6.0 -0.00022746435967313516
Epoch: 37 w=: 1.9999896377347262 cost= 1.7938261292475793e-09
grad: 1.0 2.0 -2.0724530547688857e-05
grad: 2.0 4.0 -8.124015974608767e-05
grad: 3.0 6.0 -0.00016816713067413502
Epoch: 38 w=: 1.999992339052936 cost= 9.804734918933698e-10
grad: 1.0 2.0 -1.5321894128117464e-05
grad: 2.0 4.0 -6.006182498197177e-05
grad: 3.0 6.0 -0.00012432797771566584
Epoch: 39 w=: 1.9999943361699042 cost= 5.359093909486216e-10
grad: 1.0 2.0 -1.1327660191629008e-05
grad: 2.0 4.0 -4.4404427951505454e-05
grad: 3.0 6.0 -9.191716585732479e-05
Epoch: 40 w=: 1.9999958126624442 cost= 2.929185517711759e-10
grad: 1.0 2.0 -8.37467511161094e-06
grad: 2.0 4.0 -3.282872643772805e-05
grad: 3.0 6.0 -6.795546372551087e-05
Epoch: 41 w=: 1.999996904251097 cost= 1.6010407620189493e-10
grad: 1.0 2.0 -6.191497806007362e-06
grad: 2.0 4.0 -2.4270671399762023e-05
grad: 3.0 6.0 -5.0240289795056015e-05
Epoch: 42 w=: 1.999997711275687 cost= 8.751004353779433e-11
grad: 1.0 2.0 -4.5774486259198e-06
grad: 2.0 4.0 -1.794359861406747e-05
grad: 3.0 6.0 -3.714324913239864e-05
Epoch: 43 w=: 1.9999983079186507 cost= 4.7831435036272696e-11
grad: 1.0 2.0 -3.3841626985164908e-06
grad: 2.0 4.0 -1.326591777761621e-05
grad: 3.0 6.0 -2.7460449796734565e-05
Epoch: 44 w=: 1.9999987490239537 cost= 2.614381258124446e-11
grad: 1.0 2.0 -2.5019520926150562e-06
grad: 2.0 4.0 -9.807652203264183e-06
grad: 3.0 6.0 -2.0301840059744336e-05
Epoch: 45 w=: 1.9999990751383971 cost= 1.4289743472838837e-11
grad: 1.0 2.0 -1.8497232057157476e-06
grad: 2.0 4.0 -7.250914967116273e-06
grad: 3.0 6.0 -1.5009393983689279e-05
Epoch: 46 w=: 1.9999993162387186 cost= 7.810519902420936e-12
grad: 1.0 2.0 -1.3675225627451937e-06
grad: 2.0 4.0 -5.3606884460322135e-06
grad: 3.0 6.0 -1.109662508014253e-05
Epoch: 47 w=: 1.9999994944870796 cost= 4.269091410874813e-12
grad: 1.0 2.0 -1.0110258408246864e-06
grad: 2.0 4.0 -3.963221296032771e-06
grad: 3.0 6.0 -8.20386808086937e-06
Epoch: 48 w=: 1.9999996262682318 cost= 2.3334095170724104e-12
grad: 1.0 2.0 -7.474635363990956e-07
grad: 2.0 4.0 -2.930057062755509e-06
grad: 3.0 6.0 -6.065218119744031e-06
Epoch: 49 w=: 1.999999723695619 cost= 1.2754001845445623e-12
grad: 1.0 2.0 -5.526087618612507e-07
grad: 2.0 4.0 -2.166226346744793e-06
grad: 3.0 6.0 -4.484088535150477e-06
Epoch: 50 w=: 1.9999997957248556 cost= 6.971110807809687e-13
grad: 1.0 2.0 -4.08550288710785e-07
grad: 2.0 4.0 -1.6015171322436572e-06
grad: 3.0 6.0 -3.3151404608133817e-06
Epoch: 51 w=: 1.9999998489769344 cost= 3.810285310612767e-13
grad: 1.0 2.0 -3.020461312175371e-07
grad: 2.0 4.0 -1.1840208351543424e-06
grad: 3.0 6.0 -2.4509231284497446e-06
Epoch: 52 w=: 1.9999998883468353 cost= 2.0826342583935095e-13
grad: 1.0 2.0 -2.2330632942768602e-07
grad: 2.0 4.0 -8.753608113920563e-07
grad: 3.0 6.0 -1.811996877876254e-06
Epoch: 53 w=: 1.9999999174534755 cost= 1.1383308862014303e-13
grad: 1.0 2.0 -1.6509304900935717e-07
grad: 2.0 4.0 -6.471647520100987e-07
grad: 3.0 6.0 -1.3396310407642886e-06
Epoch: 54 w=: 1.999999938972364 cost= 6.221914359741133e-14
grad: 1.0 2.0 -1.220552721115098e-07
grad: 2.0 4.0 -4.784566662863199e-07
grad: 3.0 6.0 -9.904052991061008e-07
Epoch: 55 w=: 1.9999999548815364 cost= 3.400787836254813e-14
grad: 1.0 2.0 -9.023692726373156e-08
grad: 2.0 4.0 -3.5372875473171916e-07
grad: 3.0 6.0 -7.322185204827747e-07
Epoch: 56 w=: 1.9999999666433785 cost= 1.8588102038943167e-14
grad: 1.0 2.0 -6.671324292994996e-08
grad: 2.0 4.0 -2.615159129248923e-07
grad: 3.0 6.0 -5.413379398078177e-07
Epoch: 57 w=: 1.9999999753390494 cost= 1.0159926345146978e-14
grad: 1.0 2.0 -4.932190122985958e-08
grad: 2.0 4.0 -1.9334185274999527e-07
grad: 3.0 6.0 -4.002176350326181e-07
Epoch: 58 w=: 1.9999999817678633 cost= 5.553235233228262e-15
grad: 1.0 2.0 -3.6464273378555845e-08
grad: 2.0 4.0 -1.429399514307761e-07
grad: 3.0 6.0 -2.9588569994132286e-07
Epoch: 59 w=: 1.9999999865207625 cost= 3.0352997458740945e-15
grad: 1.0 2.0 -2.6958475007887728e-08
grad: 2.0 4.0 -1.0567722164012139e-07
grad: 3.0 6.0 -2.1875184863517916e-07
Epoch: 60 w=: 1.999999990034638 cost= 1.6590409281392764e-15
grad: 1.0 2.0 -1.993072418216002e-08
grad: 2.0 4.0 -7.812843882959442e-08
grad: 3.0 6.0 -1.617258700292723e-07
Epoch: 61 w=: 1.9999999926324883 cost= 9.068022795449514e-16
grad: 1.0 2.0 -1.473502342363986e-08
grad: 2.0 4.0 -5.7761292637792394e-08
grad: 3.0 6.0 -1.195658771990793e-07
Epoch: 62 w=: 1.99999999455311 cost= 4.956420488952487e-16
grad: 1.0 2.0 -1.0893780100218464e-08
grad: 2.0 4.0 -4.270361841918202e-08
grad: 3.0 6.0 -8.839649012770678e-08
Epoch: 63 w=: 1.9999999959730488 cost= 2.7090914819461653e-16
grad: 1.0 2.0 -8.05390243385773e-09
grad: 2.0 4.0 -3.1571296688071016e-08
grad: 3.0 6.0 -6.53525820126788e-08
Epoch: 64 w=: 1.9999999970228268 cost= 1.480741124443087e-16
grad: 1.0 2.0 -5.9543463493128e-09
grad: 2.0 4.0 -2.334103754719763e-08
grad: 3.0 6.0 -4.8315948575350376e-08
Epoch: 65 w=: 1.9999999977989402 cost= 8.093467319454633e-17
grad: 1.0 2.0 -4.402119557767037e-09
grad: 2.0 4.0 -1.725630838222969e-08
grad: 3.0 6.0 -3.5720557178819945e-08
Epoch: 66 w=: 1.9999999983727301 cost= 4.4237452103093086e-17
grad: 1.0 2.0 -3.254539748809293e-09
grad: 2.0 4.0 -1.2757796596929438e-08
grad: 3.0 6.0 -2.6408640607655798e-08
Epoch: 67 w=: 1.9999999987969397 cost= 2.417940731628089e-17
grad: 1.0 2.0 -2.406120636067044e-09
grad: 2.0 4.0 -9.431992964437086e-09
grad: 3.0 6.0 -1.9524227568012975e-08
Epoch: 68 w=: 1.999999999110563 cost= 1.321603818914294e-17
grad: 1.0 2.0 -1.7788739370416806e-09
grad: 2.0 4.0 -6.97318647269185e-09
grad: 3.0 6.0 -1.4434496264925656e-08
Epoch: 69 w=: 1.9999999993424284 cost= 7.223653526452133e-18
grad: 1.0 2.0 -1.3151431055291596e-09
grad: 2.0 4.0 -5.155360582875801e-09
grad: 3.0 6.0 -1.067159693945996e-08
Epoch: 70 w=: 1.9999999995138495 cost= 3.9483193314682906e-18
grad: 1.0 2.0 -9.72300906454393e-10
grad: 2.0 4.0 -3.811418736177075e-09
grad: 3.0 6.0 -7.88963561149103e-09
Epoch: 71 w=: 1.9999999996405833 cost= 2.1580806069463956e-18
grad: 1.0 2.0 -7.18833437218791e-10
grad: 2.0 4.0 -2.8178277489132597e-09
grad: 3.0 6.0 -5.832902161273523e-09
Epoch: 72 w=: 1.999999999734279 cost= 1.1795680215816471e-18
grad: 1.0 2.0 -5.314420015167798e-10
grad: 2.0 4.0 -2.0832526814729135e-09
grad: 3.0 6.0 -4.31233715403323e-09
Epoch: 73 w=: 1.9999999998035491 cost= 6.4473110851989965e-19
grad: 1.0 2.0 -3.92901711165905e-10
grad: 2.0 4.0 -1.5401742103904326e-09
grad: 3.0 6.0 -3.188159070077745e-09
Epoch: 74 w=: 1.9999999998547615 cost= 3.5239871381328163e-19
grad: 1.0 2.0 -2.9047697580608656e-10
grad: 2.0 4.0 -1.1386696030513122e-09
grad: 3.0 6.0 -2.3570478902001923e-09
Epoch: 75 w=: 1.9999999998926234 cost= 1.9261504048240074e-19
grad: 1.0 2.0 -2.1475310418850313e-10
grad: 2.0 4.0 -8.418314934033333e-10
grad: 3.0 6.0 -1.7425900722400911e-09
Epoch: 76 w=: 1.9999999999206153 cost= 1.0527977803310672e-19
grad: 1.0 2.0 -1.5876944203796484e-10
grad: 2.0 4.0 -6.223768167501476e-10
grad: 3.0 6.0 -1.2883241140571045e-09
Epoch: 77 w=: 1.9999999999413098 cost= 5.754429955693564e-20
grad: 1.0 2.0 -1.17380327679939e-10
grad: 2.0 4.0 -4.601314884666863e-10
grad: 3.0 6.0 -9.524754318590567e-10
Epoch: 78 w=: 1.9999999999566096 cost= 3.1452826708032987e-20
grad: 1.0 2.0 -8.678080476443029e-11
grad: 2.0 4.0 -3.4018121652934497e-10
grad: 3.0 6.0 -7.041780492045291e-10
Epoch: 79 w=: 1.9999999999679208 cost= 1.7191656109647926e-20
grad: 1.0 2.0 -6.415845632545825e-11
grad: 2.0 4.0 -2.5150193039280566e-10
grad: 3.0 6.0 -5.206075570640678e-10
Epoch: 80 w=: 1.9999999999762834 cost= 9.396758983838199e-21
grad: 1.0 2.0 -4.743316850408519e-11
grad: 2.0 4.0 -1.8593837580738182e-10
grad: 3.0 6.0 -3.8489211817704927e-10
Epoch: 81 w=: 1.999999999982466 cost= 5.136032805157672e-21
grad: 1.0 2.0 -3.5067948545020045e-11
grad: 2.0 4.0 -1.3746692673066718e-10
grad: 3.0 6.0 -2.845563784603655e-10
Epoch: 82 w=: 1.9999999999870368 cost= 2.8073414984553174e-21
grad: 1.0 2.0 -2.5926372160256506e-11
grad: 2.0 4.0 -1.0163070385260653e-10
grad: 3.0 6.0 -2.1037571684701106e-10
Epoch: 83 w=: 1.999999999990416 cost= 1.5344657811957726e-21
grad: 1.0 2.0 -1.9167778475548403e-11
grad: 2.0 4.0 -7.51381179497912e-11
grad: 3.0 6.0 -1.5553425214420713e-10
Epoch: 84 w=: 1.9999999999929146 cost= 8.386855712615821e-22
grad: 1.0 2.0 -1.4170886686315498e-11
grad: 2.0 4.0 -5.555023108172463e-11
grad: 3.0 6.0 -1.1499068364173581e-10
Epoch: 85 w=: 1.9999999999947617 cost= 4.584004440146043e-22
grad: 1.0 2.0 -1.0476508549572827e-11
grad: 2.0 4.0 -4.106759377009439e-11
grad: 3.0 6.0 -8.500933290633839e-11
Epoch: 86 w=: 1.9999999999961273 cost= 2.5055063379082294e-22
grad: 1.0 2.0 -7.745359908994942e-12
grad: 2.0 4.0 -3.036149109902908e-11
grad: 3.0 6.0 -6.285105769165966e-11
Epoch: 87 w=: 1.999999999997137 cost= 1.3693547355931303e-22
grad: 1.0 2.0 -5.726086271806707e-12
grad: 2.0 4.0 -2.2446045022661565e-11
grad: 3.0 6.0 -4.646416584819235e-11
Epoch: 88 w=: 1.9999999999978835 cost= 7.484131529060059e-23
grad: 1.0 2.0 -4.233058348290797e-12
grad: 2.0 4.0 -1.659294923683774e-11
grad: 3.0 6.0 -3.4351188560322043e-11
Epoch: 89 w=: 1.9999999999984353 cost= 4.090151225674099e-23
grad: 1.0 2.0 -3.1294966618133913e-12
grad: 2.0 4.0 -1.226752033289813e-11
grad: 3.0 6.0 -2.539835008974478e-11
Epoch: 90 w=: 1.9999999999988431 cost= 2.2361520788848633e-23
grad: 1.0 2.0 -2.3137047833188262e-12
grad: 2.0 4.0 -9.070078021977679e-12
grad: 3.0 6.0 -1.8779644506139448e-11
Epoch: 91 w=: 1.9999999999991447 cost= 1.2222275007963959e-23
grad: 1.0 2.0 -1.7106316363424412e-12
grad: 2.0 4.0 -6.7057470687359455e-12
grad: 3.0 6.0 -1.3882228699912957e-11
Epoch: 92 w=: 1.9999999999993676 cost= 6.680541397586452e-24
grad: 1.0 2.0 -1.2647660696529783e-12
grad: 2.0 4.0 -4.957811938766099e-12
grad: 3.0 6.0 -1.0263789818054647e-11
Epoch: 93 w=: 1.9999999999995324 cost= 3.6539335536669686e-24
grad: 1.0 2.0 -9.352518759442319e-13
grad: 2.0 4.0 -3.666400516522117e-12
grad: 3.0 6.0 -7.58859641791787e-12
Epoch: 94 w=: 1.9999999999996543 cost= 1.997419675061985e-24
grad: 1.0 2.0 -6.914468997365475e-13
grad: 2.0 4.0 -2.7107205369247822e-12
grad: 3.0 6.0 -5.611511255665391e-12
Epoch: 95 w=: 1.9999999999997444 cost= 1.091085548139986e-24
grad: 1.0 2.0 -5.111466805374221e-13
grad: 2.0 4.0 -2.0037305148434825e-12
grad: 3.0 6.0 -4.1460168631601846e-12
Epoch: 96 w=: 1.999999999999811 cost= 5.963228352228142e-25
grad: 1.0 2.0 -3.779199175824033e-13
grad: 2.0 4.0 -1.4814816040598089e-12
grad: 3.0 6.0 -3.064215547965432e-12
Epoch: 97 w=: 1.9999999999998603 cost= 3.26058694663643e-25
grad: 1.0 2.0 -2.793321129956894e-13
grad: 2.0 4.0 -1.0942358130705543e-12
grad: 3.0 6.0 -2.2648549702353193e-12
Epoch: 98 w=: 1.9999999999998967 cost= 1.7819519823469544e-25
grad: 1.0 2.0 -2.0650148258027912e-13
grad: 2.0 4.0 -8.100187187665142e-13
grad: 3.0 6.0 -1.6786572132332367e-12
Epoch: 99 w=: 1.9999999999999236 cost= 9.755053953963032e-26
Predict (after training) 4 7.9999999999996945
注:随机梯度下降的最终结果更精确。两者的Epoch数相同(100),学习率也相同(0.01),但随机梯度下降每个Epoch对每个样本各更新一次w(共3次),总计300次更新;而批量梯度下降每个Epoch只更新1次,共100次。
反向传播
课程
代码
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Training set sampled from y = 2x; the model y = x * w should learn w ≈ 2.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# Single trainable weight as a leaf tensor that records gradients.
# torch.tensor(..., requires_grad=True) is the recommended factory; the legacy
# torch.Tensor constructor plus a separate requires_grad assignment is equivalent.
w = torch.tensor([1.0], requires_grad=True)
# The result is a Tensor (because w is a Tensor)
def forward(x):
    """Predict y = x * w; multiplying by the tensor w yields a Tensor."""
    prediction = x * w
    return prediction
# Builds the computation graph for one sample's loss
def loss(x, y):
    """Return the squared error (forward(x) - y) ** 2 as a Tensor.

    When w requires grad, the returned tensor carries the autograd graph,
    so calling .backward() on it fills w.grad.
    """
    residual = forward(x) - y
    return residual ** 2
# Per-sample gradient descent using autograd: learning rate 0.01, 100 epochs.
print("Predict (before training)", 4, forward(4).item())
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        # Forward pass: compute this sample's loss (builds the graph).
        loss_v = loss(x, y)
        # Backward pass: accumulate d(loss)/dw into w.grad.
        loss_v.backward()
        print("\tgrad:", x, y, w.grad.item())
        # Update outside autograd instead of through the discouraged .data
        # attribute, so the step is not recorded and version checks still apply.
        with torch.no_grad():
            w -= 0.01 * w.grad
            # Gradients accumulate across backward() calls; reset before the next sample.
            w.grad.zero_()
    # Loss of the last sample of the epoch (computed before its update).
    print("progress:", epoch, loss_v.item())
print("Predict (after training)", 4, forward(4).item())
结果
Predict (before training) 4 4.0
grad: 1.0 2.0 -2.0
grad: 2.0 4.0 -7.840000152587891
grad: 3.0 6.0 -16.228801727294922
progress: 0 7.315943717956543
grad: 1.0 2.0 -1.478623867034912
grad: 2.0 4.0 -5.796205520629883
grad: 3.0 6.0 -11.998146057128906
progress: 1 3.9987640380859375
grad: 1.0 2.0 -1.0931644439697266
grad: 2.0 4.0 -4.285204887390137
grad: 3.0 6.0 -8.870372772216797
progress: 2 2.1856532096862793
grad: 1.0 2.0 -0.8081896305084229
grad: 2.0 4.0 -3.1681032180786133
grad: 3.0 6.0 -6.557973861694336
progress: 3 1.1946394443511963
grad: 1.0 2.0 -0.5975041389465332
grad: 2.0 4.0 -2.3422164916992188
grad: 3.0 6.0 -4.848389625549316
progress: 4 0.6529689431190491
grad: 1.0 2.0 -0.4417421817779541
grad: 2.0 4.0 -1.7316293716430664
grad: 3.0 6.0 -3.58447265625
progress: 5 0.35690122842788696
grad: 1.0 2.0 -0.3265852928161621
grad: 2.0 4.0 -1.2802143096923828
grad: 3.0 6.0 -2.650045394897461
progress: 6 0.195076122879982
grad: 1.0 2.0 -0.24144840240478516
grad: 2.0 4.0 -0.9464778900146484
grad: 3.0 6.0 -1.9592113494873047
progress: 7 0.10662525147199631
grad: 1.0 2.0 -0.17850565910339355
grad: 2.0 4.0 -0.699742317199707
grad: 3.0 6.0 -1.4484672546386719
progress: 8 0.0582793727517128
grad: 1.0 2.0 -0.1319713592529297
grad: 2.0 4.0 -0.5173273086547852
grad: 3.0 6.0 -1.070866584777832
progress: 9 0.03185431286692619
grad: 1.0 2.0 -0.09756779670715332
grad: 2.0 4.0 -0.3824653625488281
grad: 3.0 6.0 -0.7917022705078125
progress: 10 0.017410902306437492
grad: 1.0 2.0 -0.07213282585144043
grad: 2.0 4.0 -0.2827606201171875
grad: 3.0 6.0 -0.5853137969970703
progress: 11 0.009516451507806778
grad: 1.0 2.0 -0.053328514099121094
grad: 2.0 4.0 -0.2090473175048828
grad: 3.0 6.0 -0.43272972106933594
progress: 12 0.005201528314501047
grad: 1.0 2.0 -0.039426326751708984
grad: 2.0 4.0 -0.15455150604248047
grad: 3.0 6.0 -0.3199195861816406
progress: 13 0.0028430151287466288
grad: 1.0 2.0 -0.029148340225219727
grad: 2.0 4.0 -0.11426162719726562
grad: 3.0 6.0 -0.23652076721191406
progress: 14 0.0015539465239271522
grad: 1.0 2.0 -0.021549701690673828
grad: 2.0 4.0 -0.08447456359863281
grad: 3.0 6.0 -0.17486286163330078
progress: 15 0.0008493617060594261
grad: 1.0 2.0 -0.01593184471130371
grad: 2.0 4.0 -0.062453269958496094
grad: 3.0 6.0 -0.12927818298339844
progress: 16 0.00046424579340964556
grad: 1.0 2.0 -0.011778593063354492
grad: 2.0 4.0 -0.046172142028808594
grad: 3.0 6.0 -0.09557533264160156
progress: 17 0.0002537401160225272
grad: 1.0 2.0 -0.00870823860168457
grad: 2.0 4.0 -0.03413581848144531
grad: 3.0 6.0 -0.07066154479980469
progress: 18 0.00013869594840798527
grad: 1.0 2.0 -0.006437778472900391
grad: 2.0 4.0 -0.025236129760742188
grad: 3.0 6.0 -0.052239418029785156
progress: 19 7.580435340059921e-05
grad: 1.0 2.0 -0.004759550094604492
grad: 2.0 4.0 -0.018657684326171875
grad: 3.0 6.0 -0.038620948791503906
progress: 20 4.143271507928148e-05
grad: 1.0 2.0 -0.003518819808959961
grad: 2.0 4.0 -0.0137939453125
grad: 3.0 6.0 -0.028553009033203125
progress: 21 2.264650902361609e-05
grad: 1.0 2.0 -0.00260162353515625
grad: 2.0 4.0 -0.010198593139648438
grad: 3.0 6.0 -0.021108627319335938
progress: 22 1.2377059647405986e-05
grad: 1.0 2.0 -0.0019233226776123047
grad: 2.0 4.0 -0.0075397491455078125
grad: 3.0 6.0 -0.0156097412109375
progress: 23 6.768445018678904e-06
grad: 1.0 2.0 -0.0014221668243408203
grad: 2.0 4.0 -0.0055751800537109375
grad: 3.0 6.0 -0.011541366577148438
progress: 24 3.7000872907810844e-06
grad: 1.0 2.0 -0.0010514259338378906
grad: 2.0 4.0 -0.0041217803955078125
grad: 3.0 6.0 -0.008531570434570312
progress: 25 2.021880391112063e-06
grad: 1.0 2.0 -0.0007772445678710938
grad: 2.0 4.0 -0.0030469894409179688
grad: 3.0 6.0 -0.006305694580078125
progress: 26 1.1044940038118511e-06
grad: 1.0 2.0 -0.0005745887756347656
grad: 2.0 4.0 -0.0022525787353515625
grad: 3.0 6.0 -0.0046634674072265625
progress: 27 6.041091182851233e-07
grad: 1.0 2.0 -0.0004248619079589844
grad: 2.0 4.0 -0.0016651153564453125
grad: 3.0 6.0 -0.003444671630859375
progress: 28 3.296045179013163e-07
grad: 1.0 2.0 -0.0003139972686767578
grad: 2.0 4.0 -0.0012311935424804688
grad: 3.0 6.0 -0.0025491714477539062
progress: 29 1.805076408345485e-07
grad: 1.0 2.0 -0.00023221969604492188
grad: 2.0 4.0 -0.0009107589721679688
grad: 3.0 6.0 -0.0018854141235351562
progress: 30 9.874406714516226e-08
grad: 1.0 2.0 -0.00017189979553222656
grad: 2.0 4.0 -0.0006742477416992188
grad: 3.0 6.0 -0.00139617919921875
progress: 31 5.4147676564753056e-08
grad: 1.0 2.0 -0.0001270771026611328
grad: 2.0 4.0 -0.0004978179931640625
grad: 3.0 6.0 -0.00102996826171875
progress: 32 2.9467628337442875e-08
grad: 1.0 2.0 -9.393692016601562e-05
grad: 2.0 4.0 -0.0003681182861328125
grad: 3.0 6.0 -0.0007610321044921875
progress: 33 1.6088051779661328e-08
grad: 1.0 2.0 -6.937980651855469e-05
grad: 2.0 4.0 -0.00027179718017578125
grad: 3.0 6.0 -0.000560760498046875
progress: 34 8.734787115827203e-09
grad: 1.0 2.0 -5.125999450683594e-05
grad: 2.0 4.0 -0.00020122528076171875
grad: 3.0 6.0 -0.0004177093505859375
progress: 35 4.8466972657479346e-09
grad: 1.0 2.0 -3.790855407714844e-05
grad: 2.0 4.0 -0.000148773193359375
grad: 3.0 6.0 -0.000308990478515625
progress: 36 2.6520865503698587e-09
grad: 1.0 2.0 -2.8133392333984375e-05
grad: 2.0 4.0 -0.000110626220703125
grad: 3.0 6.0 -0.0002288818359375
progress: 37 1.4551915228366852e-09
grad: 1.0 2.0 -2.09808349609375e-05
grad: 2.0 4.0 -8.20159912109375e-05
grad: 3.0 6.0 -0.00016880035400390625
progress: 38 7.914877642178908e-10
grad: 1.0 2.0 -1.5497207641601562e-05
grad: 2.0 4.0 -6.103515625e-05
grad: 3.0 6.0 -0.000125885009765625
progress: 39 4.4019543565809727e-10
grad: 1.0 2.0 -1.1444091796875e-05
grad: 2.0 4.0 -4.482269287109375e-05
grad: 3.0 6.0 -9.1552734375e-05
progress: 40 2.3283064365386963e-10
grad: 1.0 2.0 -8.344650268554688e-06
grad: 2.0 4.0 -3.24249267578125e-05
grad: 3.0 6.0 -6.580352783203125e-05
progress: 41 1.2028067430946976e-10
grad: 1.0 2.0 -5.9604644775390625e-06
grad: 2.0 4.0 -2.288818359375e-05
grad: 3.0 6.0 -4.57763671875e-05
progress: 42 5.820766091346741e-11
grad: 1.0 2.0 -4.291534423828125e-06
grad: 2.0 4.0 -1.71661376953125e-05
grad: 3.0 6.0 -3.719329833984375e-05
progress: 43 3.842615114990622e-11
grad: 1.0 2.0 -3.337860107421875e-06
grad: 2.0 4.0 -1.33514404296875e-05
grad: 3.0 6.0 -2.86102294921875e-05
progress: 44 2.2737367544323206e-11
grad: 1.0 2.0 -2.6226043701171875e-06
grad: 2.0 4.0 -1.049041748046875e-05
grad: 3.0 6.0 -2.288818359375e-05
progress: 45 1.4551915228366852e-11
grad: 1.0 2.0 -1.9073486328125e-06
grad: 2.0 4.0 -7.62939453125e-06
grad: 3.0 6.0 -1.430511474609375e-05
progress: 46 5.6843418860808015e-12
grad: 1.0 2.0 -1.430511474609375e-06
grad: 2.0 4.0 -5.7220458984375e-06
grad: 3.0 6.0 -1.1444091796875e-05
progress: 47 3.637978807091713e-12
grad: 1.0 2.0 -1.1920928955078125e-06
grad: 2.0 4.0 -4.76837158203125e-06
grad: 3.0 6.0 -1.1444091796875e-05
progress: 48 3.637978807091713e-12
grad: 1.0 2.0 -9.5367431640625e-07
grad: 2.0 4.0 -3.814697265625e-06
grad: 3.0 6.0 -8.58306884765625e-06
progress: 49 2.0463630789890885e-12
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 50 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 51 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 52 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 53 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 54 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 55 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 56 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 57 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 58 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 59 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 60 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 61 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 62 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 63 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 64 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 65 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 66 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 67 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 68 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 69 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 70 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 71 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 72 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 73 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 74 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 75 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 76 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 77 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 78 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 79 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 80 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 81 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 82 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 83 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 84 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 85 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 86 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 87 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 88 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 89 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 90 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 91 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 92 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 93 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 94 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 95 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 96 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 97 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 98 9.094947017729282e-13
grad: 1.0 2.0 -7.152557373046875e-07
grad: 2.0 4.0 -2.86102294921875e-06
grad: 3.0 6.0 -5.7220458984375e-06
progress: 99 9.094947017729282e-13
Predict (after training) 4 7.999998569488525