Preface

The goal is to use PyTorch's built-in automatic differentiation to update the parameters iteratively, so we don't have to write the gradient computation by hand.
Table of Contents
- Preface
- 1. Function to Optimize
- 1.1 Explanation
- 2. Code
- 3. Results
1. Function to Optimize
$$y = 10\times(x_1+x_2-5)^2+(x_1-x_2)^2$$
1.1 Explanation
Here we treat [10, 5] as the inputs and the whole function as the model; [x1, x2] are the parameters to be optimized iteratively. We want the parameters that make y = 0. From prior knowledge, we expect the result to be $\frac{5}{2},\ \frac{5}{2}$.
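To see why (a quick derivation added here, not in the original post): y is a sum of two non-negative squares, so it can only vanish when both terms are zero.

$$
y = 0 \;\Longleftrightarrow\;
\begin{cases}
x_1 + x_2 - 5 = 0\\
x_1 - x_2 = 0
\end{cases}
\;\Longrightarrow\; x_1 = x_2 = \frac{5}{2}
$$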
2. Code
```python
import torch
from torch import nn

# y = 10*(x1+x2-5)^2 + (x1-x2)^2
class Func(nn.Module):
    def __init__(self, size=2):
        super(Func, self).__init__()
        # [x1, x2] are the trainable parameters
        params = torch.rand((size, 1), requires_grad=True)
        self.params = nn.Parameter(params)

    def forward(self, inputs):
        # inputs = [10, 5]: the fixed coefficients of the function
        y = inputs[0] * torch.pow(self.params[0] + self.params[1] - inputs[1], 2) \
            + torch.pow(self.params[0] - self.params[1], 2)
        return y

class cusLoss(nn.Module):
    def __init__(self):
        super(cusLoss, self).__init__()

    def forward(self, y_pred, y_true):
        # drive y_pred towards the target value y_true (here 0)
        return torch.abs(y_true - y_pred)

model = Func(size=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1, weight_decay=1e-5)
loss_func = cusLoss()
x = torch.tensor([10, 5])

for i in range(400):
    y_pred = model(x)
    loss = loss_func(y_pred, 0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("loss: ", loss.item())

for item in model.parameters():
    print(item)
```
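As a side check (my own addition, not part of the original post), we can confirm that autograd's gradient matches the hand-derived one, $\partial y/\partial x_1 = 20(x_1+x_2-5)+2(x_1-x_2)$ and $\partial y/\partial x_2 = 20(x_1+x_2-5)-2(x_1-x_2)$:

```python
# Sketch (my own addition): compare autograd's gradient with the analytic formula.
import torch

x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(3.0, requires_grad=True)
y = 10 * (x1 + x2 - 5) ** 2 + (x1 - x2) ** 2
g1, g2 = torch.autograd.grad(y, (x1, x2))  # gradients from autograd

a1 = 20 * (x1 + x2 - 5) + 2 * (x1 - x2)    # analytic dy/dx1
a2 = 20 * (x1 + x2 - 5) - 2 * (x1 - x2)    # analytic dy/dx2

print(g1.item(), a1.item())  # both -24.0 at (1, 3)
print(g2.item(), a2.item())  # both -16.0 at (1, 3)
```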
After lowering the learning rate from 10 to 1, the loss behaved much better.
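If the oscillation still matters, one option (my own sketch, not part of the original code; the StepLR settings are chosen arbitrarily) is to attach a learning-rate scheduler to the same optimizer and decay the rate during training:

```python
# Sketch (my own addition): reuse model, x, loss_func, optimizer from the code above
# and halve the learning rate every 100 iterations to damp the oscillation.
from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=100, gamma=0.5)
for i in range(400):
    y_pred = model(x)
    loss = loss_func(y_pred, 0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # update the learning rate after each optimizer step
```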
3. Results
```
loss: 156.83274841308594
loss: 38.49453353881836
loss: 0.23079237341880798
loss: 18.87002944946289
loss: 47.266353607177734
loss: 53.96223449707031
loss: 40.30459976196289
loss: 19.8029842376709
loss: 4.426110744476318
loss: 0.13327637314796448
loss: 5.848392009735107
loss: 15.253518104553223
loss: 21.365623474121094
loss: 20.824350357055664
loss: 14.78664779663086
loss: 7.028133392333984
loss: 1.4270392656326294
loss: 0.08896948397159576
loss: 2.5925991535186768
loss: 6.561254978179932
loss: 9.249794960021973
loss: 9.138370513916016
loss: 6.541881561279297
loss: 3.0825929641723633
loss: 0.5860095620155334
loss: 0.07377026975154877
loss: 1.3241856098175049
loss: 3.1676881313323975
loss: 4.294349193572998
loss: 4.038957118988037
loss: 2.664410352706909
loss: 1.0545320510864258
loss: 0.0981285348534584
loss: 0.16388599574565887
loss: 0.9541118741035461
loss: 1.7817351818084717
loss: 2.053565502166748
loss: 1.6271475553512573
loss: 0.8293525576591492
loss: 0.17544522881507874
loss: 0.011017587967216969
loss: 0.313961386680603
loss: 0.7632603049278259
loss: 0.9963911771774292
loss: 0.8582224249839783
loss: 0.4731455445289612
loss: 0.116444431245327
loss: 0.0019159330986440182
loss: 0.14314040541648865
loss: 0.3744982182979584
loss: 0.49575677514076233
loss: 0.4202014207839966
loss: 0.2189445197582245
loss: 0.044234275817871094
loss: 0.005240241996943951
loss: 0.09423969686031342
loss: 0.21066127717494965
loss: 0.25115442276000977
loss: 0.18814465403556824
loss: 0.07860895991325378
loss: 0.0066448356956243515
loss: 0.01395349856466055
loss: 0.07405033707618713
loss: 0.1240086629986763
loss: 0.12016353011131287
loss: 0.0697774738073349
loss: 0.01685400679707527
loss: 0.0004727207124233246
loss: 0.02368409000337124
loss: 0.05685136467218399
loss: 0.06748046725988388
loss: 0.047708846628665924
loss: 0.01680355705320835
loss: 0.0003963433555327356
loss: 0.007644111756235361
loss: 0.02610907331109047
loss: 0.0360269770026207
loss: 0.028617529198527336
loss: 0.011885224841535091
loss: 0.0007842279155738652
loss: 0.0028550983406603336
loss: 0.012763937003910542
loss: 0.019207362085580826
loss: 0.016054006293416023
loss: 0.006997825112193823
loss: 0.0005394043400883675
loss: 0.0014044318813830614
loss: 0.006836464628577232
loss: 0.010403113439679146
loss: 0.008586933836340904
loss: 0.003567643463611603
loss: 0.00020178437989670783
loss: 0.0009794053621590137
loss: 0.004060074221342802
loss: 0.005762426648288965
loss: 0.004401565529406071
loss: 0.001568805892020464
loss: 2.635764940350782e-05
loss: 0.0008179567521438003
loss: 0.002542160451412201
loss: 0.0031480903271585703
loss: 0.0020700229797512293
loss: 0.000531529716681689
loss: 1.5086681742104702e-05
loss: 0.0007311901426874101
loss: 0.0016244511352851987
loss: 0.0016573555767536163
loss: 0.0008591370424255729
loss: 0.00011040566459996626
loss: 9.262182720704004e-05
loss: 0.0006213163724169135
loss: 0.0009869090281426907
loss: 0.0007818497833795846
loss: 0.0002726784732658416
loss: 2.5415104119019816e-06
loss: 0.00017136213136836886
loss: 0.00048359768697991967
loss: 0.0005475623183883727
loss: 0.00030850598705001175
loss: 4.818948218598962e-05
loss: 2.336039688088931e-05
loss: 0.00019348246860317886
loss: 0.0003168497933074832
loss: 0.00024830293841660023
loss: 8.038982196012512e-05
loss: 2.252596686957986e-08
loss: 6.616349855903536e-05
loss: 0.00016678131942171603
loss: 0.00017271166143473238
loss: 8.399917714996263e-05
loss: 7.142824415495852e-06
loss: 1.6916283129830845e-05
loss: 7.81318376539275e-05
loss: 0.00010478033073013648
loss: 6.641951040364802e-05
loss: 1.2883243471151218e-05
loss: 3.119921984762186e-06
loss: 3.588973049772903e-05
loss: 6.165451486594975e-05
loss: 4.818914021598175e-05
loss: 1.456955124012893e-05
loss: 7.381692057606415e-08
loss: 1.4879115951771382e-05
loss: 3.3393916964996606e-05
loss: 3.068727164645679e-05
loss: 1.1735791304090526e-05
loss: 1.3357407624425832e-07
loss: 6.764242698409362e-06
loss: 1.9035425793845206e-05
loss: 1.9952953152824193e-05
loss: 9.178495929518249e-06
loss: 4.976278091817221e-07
loss: 2.7312034944770858e-06
loss: 1.0013219252869021e-05
loss: 1.1602534868870862e-05
loss: 5.797338872071123e-06
loss: 4.113908858016657e-07
loss: 1.4342235772346612e-06
loss: 5.989348665025318e-06
loss: 7.375299446721328e-06
loss: 3.998892680101562e-06
loss: 4.2068927541549783e-07
loss: 6.325563504105958e-07
loss: 3.202781499567209e-06
loss: 4.100205842405558e-06
loss: 2.1982234557071934e-06
loss: 1.9405547391215805e-07
loss: 4.73600323402934e-07
loss: 2.1346897938201437e-06
loss: 2.6749878543341765e-06
loss: 1.4465717868006323e-06
loss: 1.4582843732569017e-07
loss: 2.3943954374772147e-07
loss: 1.1561919563973788e-06
loss: 1.4047566310182447e-06
loss: 6.778419106012734e-07
loss: 2.8631802706513554e-08
loss: 2.509316345822299e-07
loss: 8.661324955028249e-07
loss: 9.58465079747839e-07
loss: 4.4112675823271275e-07
loss: 2.068622961814981e-08
loss: 1.3861813386029098e-07
loss: 4.55244389740983e-07
loss: 4.510482085606782e-07
loss: 1.5433130329256528e-07
loss: 1.1869474292325322e-09
loss: 1.6850668771439814e-07
loss: 3.73103375750361e-07
loss: 3.215137098777632e-07
loss: 1.0316830412193667e-07
loss: 1.460307430534158e-10
loss: 8.736839163248078e-08
loss: 1.7326237866654992e-07
loss: 1.1830303492388339e-07
loss: 1.6191336271731416e-08
loss: 1.859552867244929e-08
loss: 1.1124279808427673e-07
loss: 1.5508044270973187e-07
loss: 9.294950586991035e-08
loss: 1.3275212040753104e-08
loss: 7.75889930082485e-09
loss: 4.7226023980329046e-08
loss: 5.3971746183378855e-08
loss: 1.7632885374041507e-08
loss: 5.916831469221506e-10
loss: 3.007107807206921e-08
loss: 6.265560159590677e-08
loss: 5.392854518504464e-08
loss: 1.8428409020998515e-08
loss: 5.690026227966882e-11
loss: 9.340737960883416e-09
loss: 1.8043010641122237e-08
loss: 9.34875288294279e-09
loss: 7.190692485892214e-11
loss: 8.770314252615208e-09
loss: 2.557277412051917e-08
loss: 2.8033127819071524e-08
loss: 1.384620418320992e-08
loss: 1.545231498312205e-09
loss: 1.1041265679523349e-09
loss: 4.812136467080563e-09
loss: 3.2833327168191317e-09
loss: 2.0691004465334117e-11
loss: 3.639399892563233e-09
loss: 1.146554495790042e-08
loss: 1.4196075426298194e-08
loss: 8.747122137719998e-09
loss: 2.0532411326712463e-09
loss: 1.5973000699887052e-11
loss: 9.15179043659009e-10
loss: 5.866809260623995e-10
loss: 4.001776687800884e-11
loss: 2.049148406513268e-09
loss: 6.149605269456515e-09
loss: 7.649362032680074e-09
loss: 5.2387463256309275e-09
loss: 1.7826664588938002e-09
loss: 1.1164047464262694e-10
loss: 3.68913788406644e-11
loss: 9.606537787476555e-12
loss: 2.2788526621297933e-10
loss: 1.7835191101767123e-09
loss: 4.012292720290134e-09
loss: 4.403375442052493e-09
loss: 3.284696958871791e-09
loss: 1.3110934560245369e-09
loss: 3.283275873400271e-10
loss: 8.236611392931081e-11
loss: 1.1164047464262694e-10
loss: 5.821334525535349e-10
loss: 1.5370460459962487e-09
loss: 2.6284396881237626e-09
loss: 2.9468196771631483e-09
loss: 2.1852883946849033e-09
loss: 1.1010001799149904e-09
loss: 5.125002644490451e-10
loss: 3.2883917810977437e-10
loss: 3.851710062008351e-10
loss: 7.372022992058191e-10
loss: 1.537557636765996e-09
loss: 2.0465904526645318e-09
loss: 2.0465904526645318e-09
loss: 1.537273419671692e-09
loss: 9.09551545191789e-10
loss: 5.821334525535349e-10
loss: 5.821334525535349e-10
loss: 6.571099220309407e-10
loss: 9.09551545191789e-10
loss: 1.3097292139718775e-09
loss: 1.657781467656605e-09
loss: 1.537557636765996e-09
loss: 1.310183961322764e-09
loss: 9.100062925426755e-10
loss: 5.825881999044213e-10
loss: 5.825881999044213e-10
loss: 8.21046342025511e-10
loss: 1.100545432564104e-09
loss: 1.4210854715202004e-09
loss: 1.5371028894151095e-09
loss: 1.3097292139718775e-09
loss: 9.09551545191789e-10
loss: 7.369180821115151e-10
loss: 6.573372957063839e-10
loss: 7.369180821115151e-10
loss: 9.09551545191789e-10
loss: 1.3097292139718775e-09
loss: 1.3097292139718775e-09
loss: 1.3097292139718775e-09
loss: 1.0027179087046534e-09
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 7.369180821115151e-10
loss: 9.097220754483715e-10
loss: 1.1007159628206864e-09
loss: 1.2030341167701408e-09
loss: 1.100545432564104e-09
loss: 1.0027179087046534e-09
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 1.100545432564104e-09
loss: 1.100545432564104e-09
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 1.100545432564104e-09
loss: 1.0027179087046534e-09
loss: 9.094947017729282e-10
loss: 8.208189683500677e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 1.100545432564104e-09
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 8.208189683500677e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 8.208189683500677e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 8.208189683500677e-10
loss: 8.208189683500677e-10
loss: 8.208189683500677e-10
loss: 8.208189683500677e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 5.821334525535349e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 9.09551545191789e-10
loss: 9.09551545191789e-10
loss: 7.366907084360719e-10
loss: 6.571099220309407e-10
loss: 5.821334525535349e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 9.09551545191789e-10
loss: 7.367475518549327e-10
loss: 7.366907084360719e-10
loss: 7.367475518549327e-10
loss: 5.821334525535349e-10
loss: 7.367475518549327e-10
loss: 7.366907084360719e-10
loss: 7.366907084360719e-10
loss: 7.366907084360719e-10
loss: 7.366907084360719e-10
loss: 7.366907084360719e-10
loss: 7.366907084360719e-10
loss: 7.367475518549327e-10
loss: 5.821334525535349e-10
loss: 5.821334525535349e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 8.208189683500677e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 5.821334525535349e-10
loss: 5.821334525535349e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 7.367475518549327e-10
loss: 6.571099220309407e-10
loss: 6.571099220309407e-10
loss: 6.571099220309407e-10
loss: 6.571099220309407e-10
loss: 6.571099220309407e-10
loss: 6.571099220309407e-10
loss: 6.571099220309407e-10
Parameter containing:
tensor([[2.5000],
        [2.5000]], requires_grad=True)
```