文章目录
- 线性回归
- 1. 单变量的线性回归
- 1.1 数据读取
- 1.2 训练数据的准备
- 1.3 假设函数定义--假设函数是为了去预测
- 1.4 损失函数的定义
- 1.5 利用梯度下降算法来优化参数w
- 1.6 可视化误差曲线
- 1.7 可视化回归线/回归平面
- 1.2 单变量的线性回归--基于sklearn试试?
- 1.3 多变量线性回归
- 实验要求1 准备训练数据
- 实验要求2 调用前面的梯度下降算法
- 实验要求3 绘制误差曲线
- 1.4 最小二乘法求参数
- 1.5 来点正则化?
- 1.5.1 普通的线性回归
- 1.5.2 岭回归
- 1.5.3 Lasso回归
- 实验要求4 手写代码实现单变量的L2正则化
线性回归
1. 单变量的线性回归
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei'] # use the SimHei font so Chinese labels render correctly
plt.rcParams['axes.unicode_minus']=False # render the minus sign correctly with a CJK font
1.1 数据读取
# Load the single-variable data set (columns 人口 / 收益) and preview the first rows.
data=pd.read_csv("data/regress_data1.csv")
data.head()
| | 人口 | 收益 |
|---|---|---|
0 | 6.1101 | 17.5920 |
1 | 5.5277 | 9.1302 |
2 | 8.5186 | 13.6620 |
3 | 7.0032 | 11.8540 |
4 | 5.8598 | 6.8233 |
# Visualize the relationship between population (人口) and profit (收益).
data.plot(kind="scatter",x="人口",y="收益")
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
Text(0.5, 1.0, '人口与收益之间的关系')
1.2 训练数据的准备
# Prepend a constant column of ones so that w[0] acts as the intercept.
data.insert(0,"ones",1)
data
| | ones | 人口 | 收益 |
|---|---|---|---|
0 | 1 | 6.1101 | 17.59200 |
1 | 1 | 5.5277 | 9.13020 |
2 | 1 | 8.5186 | 13.66200 |
3 | 1 | 7.0032 | 11.85400 |
4 | 1 | 5.8598 | 6.82330 |
... | ... | ... | ... |
92 | 1 | 5.8707 | 7.20290 |
93 | 1 | 5.3054 | 1.98690 |
94 | 1 | 8.2934 | 0.14454 |
95 | 1 | 13.3940 | 9.05510 |
96 | 1 | 5.4369 | 0.61705 |
97 rows × 3 columns
# Sample count (rows) and column count of the prepared table.
m, col_num = data.shape
# Training features: every column except the last (ones + 人口), as a NumPy array.
X = data.iloc[:, :col_num - 1].values
# Training label: the last column (收益), as a 1-D NumPy array.
y = data.iloc[:, col_num - 1].values
X.shape, y.shape
((97, 2), (97,))
# Turn the label into an (m, 1) column vector so it lines up with X @ w.
y=y.reshape((m,1))
y.shape
(97, 1)
# Initialize the weight vector to zeros: one weight per feature column.
w=np.zeros((col_num-1,1))
w.shape
(2, 1)
1.3 假设函数定义–假设函数是为了去预测
# Hypothesis function: produces the predictions y_hat.
def h(X, w):
    """Return the linear-model predictions ``X @ w``.

    X has shape (m, col_num-1) and w has shape (col_num-1, 1), so the
    result is an (m, 1) column of predicted values.
    """
    return X @ w
1.4 损失函数的定义
# Define the MSE (mean squared error) loss.
def cost(X, y, w):
    """Return the halved mean-squared-error loss J(w) = sum((Xw - y)^2) / (2m).

    Parameters
    ----------
    X : ndarray, shape (m, k)
        Design matrix (in this notebook its first column is the constant 1).
    y : ndarray, shape (m, 1)
        Target column vector.
    w : ndarray, shape (k, 1)
        Weight column vector.

    The sample count m is read from ``X`` itself rather than from a
    notebook-global, so the same function works for any data set.
    """
    residual = X @ w - y  # same as h(X, w) - y
    n_samples = X.shape[0]
    return np.sum(np.square(residual)) / (2 * n_samples)
def computeCost(X, y, w):
    """Return the halved MSE loss sum((Xw - y)^2) / (2m), computed with np.power.

    m is the number of rows of ``X`` (taken from the argument, not from a
    notebook-global).
    """
    inner = np.power((X @ w) - y, 2)  # (m, n) @ (n, 1) -> (m, 1)
    return np.sum(inner) / (2 * len(X))
# Loss of the all-zero initial weights.
cost(X,y,w)
32.072733877455676
# Residuals of the current predictions; shape (m, 1).
error=h(X,w)-y
error.shape
(97, 1)
# Quick check: np.multiply is element-wise, not a matrix product.
x1=np.array([1,2]).reshape(2,1)
x2=np.array([3,4]).reshape(2,1)
np.multiply(x1,x2)
array([[3],
[8]])
# A single column sliced out of X is 1-D, shape (m,).
X[:,1].shape
(97,)
# Sanity-check the shapes used by h(X, w) - y.
X.shape,w.shape,y.shape
((97, 2), (2, 1), (97, 1))
# Residuals for the all-zero weights (i.e. simply -y).
h(X,w)-y
array([[-17.592 ],
[ -9.1302 ],
[-13.662 ],
[-11.854 ],
[ -6.8233 ],
[-11.886 ],
[ -4.3483 ],
[-12. ],
[ -6.5987 ],
[ -3.8166 ],
[ -3.2522 ],
[-15.505 ],
[ -3.1551 ],
[ -7.2258 ],
[ -0.71618],
[ -3.5129 ],
[ -5.3048 ],
[ -0.56077],
[ -3.6518 ],
[ -5.3893 ],
[ -3.1386 ],
[-21.767 ],
[ -4.263 ],
[ -5.1875 ],
[ -3.0825 ],
[-22.638 ],
[-13.501 ],
[ -7.0467 ],
[-14.692 ],
[-24.147 ],
[ 1.22 ],
[ -5.9966 ],
[-12.134 ],
[ -1.8495 ],
[ -6.5426 ],
[ -4.5623 ],
[ -4.1164 ],
[ -3.3928 ],
[-10.117 ],
[ -5.4974 ],
[ -0.55657],
[ -3.9115 ],
[ -5.3854 ],
[ -2.4406 ],
[ -6.7318 ],
[ -1.0463 ],
[ -5.1337 ],
[ -1.844 ],
[ -8.0043 ],
[ -1.0179 ],
[ -6.7504 ],
[ -1.8396 ],
[ -4.2885 ],
[ -4.9981 ],
[ -1.4233 ],
[ 1.4211 ],
[ -2.4756 ],
[ -4.6042 ],
[ -3.9624 ],
[ -5.4141 ],
[ -5.1694 ],
[ 0.74279],
[-17.929 ],
[-12.054 ],
[-17.054 ],
[ -4.8852 ],
[ -5.7442 ],
[ -7.7754 ],
[ -1.0173 ],
[-20.992 ],
[ -6.6799 ],
[ -4.0259 ],
[ -1.2784 ],
[ -3.3411 ],
[ 2.6807 ],
[ -0.29678],
[ -3.8845 ],
[ -5.7014 ],
[ -6.7526 ],
[ -2.0576 ],
[ -0.47953],
[ -0.20421],
[ -0.67861],
[ -7.5435 ],
[ -5.3436 ],
[ -4.2415 ],
[ -6.7981 ],
[ -0.92695],
[ -0.152 ],
[ -2.8214 ],
[ -1.8451 ],
[ -4.2959 ],
[ -7.2029 ],
[ -1.9869 ],
[ -0.14454],
[ -9.0551 ],
[ -0.61705]])
# Flatten the (m, 1) residuals first: multiplying (m, 1) by a 1-D (m,) column
# would broadcast to (m, m) instead of the intended element-wise (m,) product.
np.multiply((h(X,w)-y).ravel(),X[:,1]).shape
(97,)
1.5 利用梯度下降算法来优化参数w
# Batch gradient descent. Hyper-parameters: iteration count and learning rate
# alpha; every step uses all samples.
def gradient_descent(X, y, w, iter_num, alpha):
    """Fit linear-regression weights with batch gradient descent.

    Each iteration performs the simultaneous update

        w <- w - (alpha / m) * X^T (X w - y)

    i.e. one gradient step on J(w) = sum((Xw - y)^2) / (2m).

    Parameters
    ----------
    X : ndarray, shape (m, k)
        Design matrix (k = number of weights, including the intercept column).
    y : ndarray, shape (m, 1)
        Targets.
    w : ndarray, shape (k, 1)
        Initial weights; the caller's array is not modified.
    iter_num : int
        Number of iterations.
    alpha : float
        Learning rate.

    Returns
    -------
    (w, cost_lst)
        The final weights and the loss recorded after every iteration.

    All sizes are taken from the arguments, so no notebook-globals
    (``m``, ``col_num``) are needed and the function works for any data set.
    """
    n_samples = X.shape[0]
    cost_lst = []
    for _ in range(iter_num):
        error = X @ w - y  # residuals, shape (m, 1)
        # Vectorized simultaneous update of all weights; a fresh array is
        # created each step, so no buffer is shared between iterations.
        w = w - (alpha / n_samples) * (X.T @ error)
        cost_lst.append(np.sum(np.square(X @ w - y)) / (2 * n_samples))
    return w, cost_lst
# Single-variable run: 200 iterations with learning rate 0.003, starting from zeros.
# NOTE: after only 200 steps the weights have not yet converged to the
# least-squares solution computed later with lsm(X, y).
iter_num=200
alpha=0.003
w=np.zeros((col_num-1,1))
w,cost_lst=gradient_descent(X,y,w,iter_num,alpha)
w
array([[-0.32791203],
[ 0.83460252]])
# Display the loss reached by the fitted weights (evaluating the bare name
# `cost` would only echo the function object).
cost(X,y,w)
<function __main__.cost(X, y, w)>
1.6 可视化误差曲线
# Plot the loss recorded after each iteration (the error curve).
plt.plot(range(iter_num),cost_lst,"r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()
1.7 可视化回归线/回归平面
# Plot the fitted regression line over the training data.
x=np.linspace(data["人口"].min(),data["人口"].max(),50)
y1=w[0,0]*1+w[1,0]*x  # intercept w0 (times the constant 1) + slope w1 * population
plt.plot(x,y1,"r-+",label="预测线")
plt.scatter(data["人口"],data["收益"], label='训练数据')
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
plt.legend()  # the label= arguments above only appear once a legend is drawn
plt.show()
w
array([[-0.32791203],
[ 0.83460252]])
总结:
- 数据准备
- 初始化w
- 定义了假设函数
- 定义了损失函数或者代价函数
- 定义梯度下降算法
- 可视化分析
1.2 单变量的线性回归–基于sklearn试试?
# Shapes passed to scikit-learn (y is already an (m, 1) column).
X.shape,y.shape
((97, 2), (97, 1))
# Same single-variable fit with scikit-learn, for comparison.
# NOTE: X still contains the ones column while fit_intercept=True, so the
# intercept is fitted separately and the weight of the constant column
# comes out as 0 in reg.coef_.
import sklearn
from sklearn import linear_model
reg=linear_model.LinearRegression()
reg.fit(X,y)
reg.coef_
array([[0. , 1.19303364]])
# Gradient-descent weights, for comparison with reg.coef_ / reg.intercept_.
w
array([[-0.32791203],
[ 0.83460252]])
# Intercept learned by scikit-learn.
reg.intercept_
array([-3.89578088])
# Hyper-parameters of the fitted estimator.
reg.get_params()
{'copy_X': True,
'fit_intercept': True,
'n_jobs': None,
'normalize': 'deprecated',
'positive': False}
# Residuals of the scikit-learn fit on the training data.
reg.predict(X)-y
array([[-14.19822601],
[ -6.4312488 ],
[ -7.39480448],
[ -7.39472766],
[ -3.72814233],
[ -5.78069914],
[ 0.67551586],
[ -5.66181898],
[ -2.75622606],
[ -1.68207302],
[ -0.33492365],
[ -2.50265234],
[ -0.21002596],
[ -1.09007678],
[ 2.117584 ],
[ -0.99087569],
[ -1.60644452],
[ 1.66383102],
[ 0.12314824],
[ -0.84937859],
[ 0.34942365],
[ -1.47998891],
[ -1.60890687],
[ -1.53603074],
[ -0.33916795],
[ -3.93175849],
[ -2.09254529],
[ 2.12958876],
[ -2.86836958],
[ -1.55385488],
[ 3.59050903],
[ -2.03100498],
[ -4.99636713],
[ 1.28383475],
[ -0.64226232],
[ 1.00673223],
[ 1.6465002 ],
[ -0.60007636],
[ 1.30099898],
[ -1.81336092],
[ 1.99826273],
[ 0.40377318],
[ 4.68685703],
[ 0.55183747],
[ -1.29245052],
[ 3.52022606],
[ -2.9805617 ],
[ 1.18148451],
[ 2.05841276],
[ 1.69763436],
[ -1.65046859],
[ 0.59688379],
[ 0.67268159],
[ 0.17687322],
[ 2.23616258],
[ 5.11170076],
[ 1.11395081],
[ -1.77162904],
[ 3.24920096],
[ 1.96858198],
[ 1.46381825],
[ 3.02608828],
[ 3.56178204],
[ 1.83596469],
[ 1.66894398],
[ -0.16942543],
[ 0.2563525 ],
[ 0.5407115 ],
[ 1.64788834],
[ -0.62028352],
[ 1.51690814],
[ 0.82862438],
[ 1.9914178 ],
[ 1.38386093],
[ 4.78217995],
[ 3.61930412],
[ 1.21352255],
[ -3.58846693],
[ 1.60884678],
[ 0.14027707],
[ 2.45981748],
[ 2.08994488],
[ 3.00817305],
[ 0.21510688],
[ -1.46569296],
[ 2.02402528],
[ 0.25840658],
[ 2.33785705],
[ 2.53824205],
[ -0.68114646],
[ 1.06859725],
[ 0.91903985],
[ -4.09473826],
[ 0.44683982],
[ 5.85398435],
[ 3.02861175],
[ 1.97357374]])
# Coefficient of determination (R^2) on the training data.
reg.score(X,y)
0.7020315537841397
1.3 多变量线性回归
# Load the multivariate data set (面积, 房间数 -> 价格) and preview it.
path = 'data/regress_data2.csv'
data2 = pd.read_csv(path)
data2.head()
| | 面积 | 房间数 | 价格 |
|---|---|---|---|
0 | 2104 | 3 | 399900 |
1 | 1600 | 3 | 329900 |
2 | 2400 | 3 | 369000 |
3 | 1416 | 2 | 232000 |
4 | 3000 | 4 | 539900 |
# Z-score standardize every column (features and the target) with its own
# mean and standard deviation, so the features are on comparable scales.
data2=(data2-data2.mean())/data2.std()
data2.head()
| | 面积 | 房间数 | 价格 |
|---|---|---|---|
0 | 0.130010 | -0.223675 | 0.475747 |
1 | -0.504190 | -0.223675 | -0.084074 |
2 | 0.502476 | -0.223675 | 0.228626 |
3 | -0.735723 | -1.537767 | -0.867025 |
4 | 1.257476 | 1.090417 | 1.595389 |
实验要求1 准备训练数据
# Prepend the constant column for the intercept, then build X2 / y2 / w2.
data2.insert(0,"ones",1)
m2, col_num2 = data2.shape
X2 = data2.iloc[:, :-1].values                  # features: ones, 面积, 房间数
y2 = data2.iloc[:, -1].values.reshape((m2, 1))  # label 价格 as an (m2, 1) column
w2 = np.zeros((X2.shape[1], 1))                 # one zero weight per feature column
X2.shape, y2.shape, w2.shape
((47, 3), (47, 1), (3, 1))
实验要求2 调用前面的梯度下降算法
# Define the MSE (mean squared error) loss for the multivariate data set.
def cost2(X, y, w):
    """Return the halved mean-squared-error loss sum((Xw - y)^2) / (2m).

    X is (m, k), y is (m, 1) and w is (k, 1). The sample count m is read
    from ``X`` instead of the notebook-global ``m2``, so the function does
    not silently go wrong when called with a different data set.
    """
    residual = X @ w - y  # same as h(X, w) - y
    return np.sum(np.square(residual)) / (2 * len(X))
# Batch gradient descent (redefined for the multivariate data set).
# Hyper-parameters: iteration count and learning rate alpha; all samples per step.
def gradient_descent(X, y, w, iter_num, alpha):
    """Fit linear-regression weights with batch gradient descent.

    Each iteration performs the simultaneous update

        w <- w - (alpha / m) * X^T (X w - y)

    i.e. one gradient step on J(w) = sum((Xw - y)^2) / (2m).

    Parameters
    ----------
    X : ndarray, shape (m, k)
        Design matrix (k = number of weights, including the intercept column).
    y : ndarray, shape (m, 1)
        Targets.
    w : ndarray, shape (k, 1)
        Initial weights; the caller's array is not modified.
    iter_num : int
        Number of iterations.
    alpha : float
        Learning rate.

    Returns
    -------
    (w, cost_lst)
        The final weights and the loss recorded after every iteration.

    Sizes come from the arguments instead of the globals ``m2`` / ``col_num2``,
    so the function works for both data sets in this notebook.
    """
    n_samples = X.shape[0]
    cost_lst = []
    for _ in range(iter_num):
        error = X @ w - y  # residuals, shape (m, 1)
        # Vectorized simultaneous update; a fresh array each step (no shared buffer).
        w = w - (alpha / n_samples) * (X.T @ error)
        cost_lst.append(np.sum(np.square(X @ w - y)) / (2 * n_samples))
    return w, cost_lst
# Multivariate run: 1000 iterations with learning rate 0.01, from zero weights.
iter_num2=1000
alpha2=0.01
w2,cost_lst2=gradient_descent(X2,y2,w2,iter_num2,alpha2)
w2
array([[-1.03191687e-16],
[ 8.78503652e-01],
[-4.69166570e-02]])
# Loss recorded after each iteration of the multivariate run.
cost_lst2
[0.4805491041076719,
0.47198587701203876,
0.46366461618706284,
0.4555781400525299,
0.44771948335326117,
0.4400818906150644,
0.43265880979889004,
0.42544388614718714,
0.41843095621663473,
0.4116140420916035,
0.4049873457728717,
0.39854524373628347,
0.3922822816562035,
0.38619316928877434,
0.3802727755101314,
0.3745161235048873,
0.36891838610032585,
0.36347488124189714,
0.3581810676057273,
0.353032540343996,
0.34802502695915444,
0.3431543833030803,
0.33841658969738386,
0.3338077471711977,
0.3293240738128865,
0.32496190123222957,
0.32071767112972566,
0.3165879319697778,
0.3125693357546089,
0.3086586348958572,
0.3048526791808924,
0.301148412830983,
0.29754287164853055,
0.29403318025067643,
0.290616549386659,
0.2872902733363892,
0.2840517273877804,
0.2808983653904495,
0.2778277173834725,
0.2748373872949541,
0.27192505071123485,
0.2690884527136251,
0.26632540578062175,
0.26363378775362334,
0.2610115398642199,
0.2584566648211922,
0.25596722495541263,
0.2535413404208921,
0.2511771874502738,
0.24887299666312288,
0.2466270514254147,
0.2444376862586688,
0.2423032852972262,
0.24022228079221122,
0.23819315166076374,
0.23621442207916982,
0.23428466011856308,
0.23240247642190492,
0.2305665229209955,
0.22877549159230046,
0.22702811325042083,
0.2253231563780625,
0.2236594259914031,
0.22203576253978147,
0.22045104083867165,
0.21890416903493287,
0.2173940876033578,
0.2159197683735719,
0.21448021358636368,
0.21307445497855435,
0.21170155289554488,
0.21036059543069827,
0.20905069759074849,
0.2077710004864442,
0.20652067054766643,
0.20529889876227617,
0.20410489993797507,
0.20293791198648126,
0.20179719522934592,
0.20068203172475318,
0.1995917246146703,
0.19852559749172968,
0.1974829937852473,
0.19646327616579626,
0.19546582596777473,
0.1944900426294234,
0.1935353431497636,
0.19260116156194368,
0.1916869484224977,
0.1907921703160338,
0.189916309374886,
0.18905886281327441,
0.18821934247553765,
0.1873972743980089,
0.1865921983841233,
0.185803667592357,
0.1850312481366081,
0.1842745186986435,
0.18353307015224665,
0.18280650519871142,
0.18209443801333897,
0.18139649390260434,
0.1807123089716706,
0.18004152980193594,
0.17938381313831125,
0.17873882558593354,
0.17810624331602853,
0.17748575178064738,
0.1768770454360068,
0.17627982747417442,
0.17569380956284517,
0.17511871159296452,
0.17455426143396124,
0.17400019469635858,
0.17345625450154267,
0.17292219125846853,
0.17239776244709656,
0.1718827324083545,
0.1713768721404272,
0.170879959101184,
0.17039177701655678,
0.16991211569468956,
0.16944077084568412,
0.16897754390677372,
0.16852224187275908,
0.16807467713154944,
0.16763466730465196,
0.16720203509246181,
0.16677660812420697,
0.16635821881240656,
0.1659467042117072,
0.165541905881964,
0.16514366975543857,
0.16475184600798926,
0.16436628893413294,
0.16398685682586134,
0.1636134118550983,
0.16324581995968843,
0.1628839507328093,
0.16252767731570503,
0.16217687629363958,
0.16183142759497396,
0.16149121439327116,
0.1611561230123392,
0.1608260428341213,
0.1605008662093497,
0.16018048837087778,
0.15986480734960967,
0.15955372389294997,
0.1592471413856969,
0.1589449657733043,
0.1586471054874422,
0.1583534713737859,
0.1580639766219659,
0.15777853669761477,
0.15749706927644566,
0.15721949418030304,
0.1569457333151249,
0.15667571061075897,
0.15640935196257785,
0.15614658517483707,
0.15588733990572493,
0.1556315476140529,
0.15537914150753623,
0.15513005649261752,
0.15488422912578662,
0.15464159756635176,
0.15440210153061742,
0.15416568224742783,
0.15393228241503384,
0.15370184615924348,
0.15347431899281805,
0.15324964777607533,
0.15302778067866443,
0.15280866714247615,
0.1525922578456555,
0.15237850466768188,
0.15216736065548658,
0.15195877999057447,
0.1517527179571211,
0.1515491309110149,
0.15134797624981636,
0.15114921238360698,
0.1509527987067004,
0.15075869557019017,
0.15056686425530957,
0.15037726694757766,
0.15018986671170936,
0.15000462746726564,
0.1498215139650219,
0.1496404917640332,
0.14946152720937447,
0.14928458741053677,
0.14910964022045875,
0.14893665421517455,
0.14876559867406025,
0.14859644356065968,
0.1484291595040733,
0.1482637177808928,
0.14810009029766483,
0.1479382495738683,
0.14777816872538996,
0.1476198214484827,
0.1474631820041929,
0.1473082252032421,
0.14715492639134978,
0.14700326143498368,
0.14685320670752527,
0.14670473907583773,
0.14655783588722407,
0.14641247495676463,
0.1462686345550211,
0.14612629339609778,
0.14598543062604827,
0.14584602581161768,
0.14570805892930994,
0.14557151035477142,
0.14543636085247996,
0.1453025915657318,
0.14517018400691611,
0.1450391200480697,
0.14490938191170247,
0.14478095216188633,
0.14465381369559951,
0.14452794973431848,
0.1444033438158502,
0.14427997978639776,
0.14415784179285188,
0.14403691427530232,
0.14391718195976197,
0.14379862985109793,
0.14368124322616269,
0.14356500762711993,
0.14344990885495965,
0.14333593296319558,
0.1432230662517408,
0.1431112952609561,
0.14300060676586493,
0.14289098777053136,
0.14278242550259543,
0.14267490740796127,
0.14256842114563387,
0.14246295458269945,
0.142358495789446,
0.14225503303461925,
0.14215255478081018,
0.1420510496799703,
0.14195050656905087,
0.14185091446576267,
0.1417522625644519,
0.14165454023209023,
0.14155773700437435,
0.141461842581932,
0.14136684682663245,
0.1412727397579967,
0.14117951154970584,
0.14108715252620416,
0.14099565315939372,
0.1409050040654189,
0.14081519600153714,
0.14072621986307407,
0.1406380666804604,
0.14055072761634763,
0.1404641939628014,
0.14037845713856925,
0.1402935086864208,
0.14020934027055865,
0.14012594367409767,
0.14004331079661067,
0.1399614336517384,
0.1398803043648625,
0.13979991517083906,
0.13972025841179114,
0.13964132653495887,
0.1395631120906052,
0.13948560772997556,
0.1394088062033102,
0.13933270035790743,
0.1392572831362367,
0.13918254757409942,
0.13910848679883667,
0.13903509402758246,
0.138962362565561,
0.13889028580442658,
0.13881885722064558,
0.1387480703739184,
0.138677918905641,
0.1386083965374045,
0.13853949706953195,
0.1384712143796508,
0.13840354242130104,
0.13833647522257653,
0.13827000688480018,
0.1382041315812306,
0.1381388435558005,
0.13807413712188513,
0.13801000666110028,
0.1379464466221289,
0.13788345151957584,
0.13782101593284923,
0.1377591345050687,
0.13769780194199868,
0.13763701301100714,
0.13757676254004808,
0.13751704541666765,
0.13745785658703336,
0.1373991910549853,
0.1373410438811091,
0.13728341018182993,
0.1372262851285268,
0.13716966394666721,
0.13711354191496053,
0.13705791436453063,
0.13700277667810684,
0.1369481242892326,
0.13689395268149127,
0.13684025738774938,
0.13678703398941564,
0.13673427811571623,
0.1366819854429858,
0.13663015169397344,
0.13657877263716303,
0.1365278440861087,
0.1364773618987835,
0.13642732197694216,
0.13637772026549663,
0.13632855275190497,
0.13627981546557258,
0.13623150447726529,
0.13618361589853506,
0.1361361458811566,
0.13608909061657562,
0.1360424463353681,
0.13599620930670997,
0.13595037583785746,
0.13590494227363753,
0.13585990499594844,
0.1358152604232695,
0.1357710050101804,
0.13572713524689023,
0.1356836476587745,
0.1356405388059217,
0.13559780528268783,
0.1355554437172595,
0.13551345077122479,
0.13547182313915254,
0.13543055754817865,
0.1353896507576004,
0.13534909955847768,
0.13530890077324167,
0.13526905125531047,
0.1352295478887108,
0.13519038758770774,
0.13515156729643937,
0.13511308398855892,
0.13507493466688233,
0.13503711636304225,
0.1349996261371476,
0.13496246107744925,
0.1349256183000107,
0.13488909494838502,
0.13485288819329633,
0.1348169952323271,
0.13478141328961019,
0.1347461396155259,
0.1347111714864043,
0.1346765062042316,
0.13464214109636177,
0.13460807351523252,
0.13457430083808555,
0.13454082046669139,
0.13450762982707826,
0.13447472636926538,
0.1344421075670001,
0.1344097709174989,
0.13437771394119274,
0.1343459341814757,
0.1343144292044577,
0.13428319659872076,
0.1342522339750784,
0.1342215389663395,
0.1341911092270747,
0.13416094243338644,
0.1341310362826823,
0.13410138849345168,
0.13407199680504514,
0.13404285897745766,
0.1340139727911139,
0.13398533604665733,
0.13395694656474152,
0.13392880218582504,
0.13390090076996808,
0.1338732401966332,
0.1338458183644872,
0.13381863319120707,
0.1337916826132873,
0.1337649645858507,
0.1337384770824609,
0.13371221809493786,
0.13368618563317541,
0.1336603777249612,
0.13363479241579912,
0.1336094277687338,
0.13358428186417728,
0.13355935279973807,
0.13353463869005203,
0.13351013766661574,
0.1334858478776214,
0.1334617674877945,
0.13343789467823247,
0.13341422764624633,
0.13339076460520344,
0.1333675037843725,
0.13334444342877066,
0.13332158179901155,
0.13329891717115624,
0.13327644783656503,
0.13325417210175164,
0.13323208828823854,
0.13321019473241433,
0.13318848978539244,
0.13316697181287218,
0.13314563919500016,
0.1331244903262345,
0.13310352361520966,
0.13308273748460353,
0.13306213037100528,
0.13304170072478536,
0.13302144700996643,
0.13300136770409596,
0.13298146129812038,
0.13296172629625996,
0.13294216121588615,
0.132922764587399,
0.1329035349541069,
0.13288447087210706,
0.13286557091016762,
0.13284683364961072,
0.1328282576841967,
0.13280984162001042,
0.1327915840753473,
0.13277348368060227,
0.13275553907815818,
0.13273774892227672,
0.13272011187898983,
0.13270262662599233,
0.1326852918525356,
0.13266810625932268,
0.13265106855840397,
0.1326341774730744,
0.13261743173777144,
0.13260083009797413,
0.13258437131010334,
0.13256805414142278,
0.13255187736994117,
0.13253583978431552,
0.1325199401837546,
0.1325041773779249,
0.13248855018685587,
0.1324730574408471,
0.13245769798037618,
0.13244247065600748,
0.13242737432830168,
0.13241240786772626,
0.13239757015456718,
0.13238286007884084,
0.1323682765402074,
0.13235381844788482,
0.1323394847205633,
0.13232527428632143,
0.13231118608254225,
0.1322972190558306,
0.13228337216193134,
0.13226964436564798,
0.13225603464076227,
0.1322425419699548,
0.13222916534472598,
0.13221590376531783,
0.13220275624063701,
0.13218972178817762,
0.13217679943394567,
0.13216398821238384,
0.13215128716629693,
0.13213869534677797,
0.13212621181313536,
0.1321138356328202,
0.13210156588135483,
0.13208940164226146,
0.132077342006992,
0.1320653860748583,
0.1320535329529629,
0.1320417817561306,
0.1320301316068407,
0.13201858163516003,
0.1320071309786758,
0.13199577878243007,
0.13198452419885442,
0.1319733663877049,
0.131962304515998,
0.13195133775794732,
0.1319404652949001,
0.1319296863152752,
0.1319190000145011,
0.1319084055949546,
0.1318979022659001,
0.13188748924342955,
0.1318771657504026,
0.1318669310163877,
0.13185678427760358,
0.13184672477686088,
0.13183675176350523,
0.13182686449335979,
0.13181706222866899,
0.1318073442380425,
0.13179770979639993,
0.1317881581849157,
0.1317786886909647,
0.13176930060806835,
0.13175999323584098,
0.1317507658799369,
0.13174161785199792,
0.13173254846960125,
0.13172355705620795,
0.13171464294111176,
0.13170580545938826,
0.13169704395184512,
0.1316883577649718,
0.13167974625089038,
0.13167120876730698,
0.13166274467746283,
0.13165435335008663,
0.1316460341593466,
0.13163778648480368,
0.13162960971136442,
0.13162150322923485,
0.13161346643387453,
0.13160549872595084,
0.13159759951129413,
0.13158976820085277,
0.13158200421064908,
0.13157430696173483,
0.13156667588014856,
0.13155911039687143,
0.1315516099477853,
0.13154417397362975,
0.13153680191996023,
0.13152949323710658,
0.1315222473801313,
0.13151506380878897,
0.13150794198748567,
0.13150088138523836,
0.1314938814756356,
0.13148694173679754,
0.131480061651337,
0.13147324070632052,
0.1314664783932301,
0.1314597742079246,
0.1314531276506024,
0.13144653822576377,
0.13144000544217335,
0.13143352881282383,
0.13142710785489936,
0.13142074208973897,
0.131414431042801,
0.13140817424362766,
0.13140197122580938,
0.13139582152695017,
0.131389724688633,
0.13138368025638517,
0.13137768777964465,
0.13137174681172598,
0.13136585690978717,
0.1313600176347961,
0.13135422855149823,
0.1313484892283835,
0.13134279923765432,
0.1313371581551934,
0.13133156556053213,
0.13132602103681912,
0.13132052417078882,
0.13131507455273095,
0.13130967177645936,
0.13130431543928223,
0.13129900514197135,
0.13129374048873282,
0.13128852108717703,
0.1312833465482897,
0.13127821648640217,
0.13127313051916345,
0.13126808826751082,
0.13126308935564196,
0.1312581334109867,
0.13125322006417925,
0.13124834894903042,
0.13124351970250062,
0.13123873196467223,
0.13123398537872316,
0.1312292795908998,
0.13122461425049098,
0.13121998900980153,
0.13121540352412622,
0.13121085745172428,
0.13120635045379372,
0.13120188219444595,
0.13119745234068078,
0.1311930605623617,
0.1311887065321909,
0.1311843899256851,
0.1311801104211511,
0.13117586769966202,
0.1311716614450333,
0.13116749134379915,
0.13116335708518903,
0.13115925836110465,
0.13115519486609695,
0.13115116629734297,
0.13114717235462375,
0.13114321274030158,
0.13113928715929773,
0.1311353953190708,
0.1311315369295945,
0.13112771170333615,
0.1311239193552353,
0.13112015960268236,
0.13111643216549748,
0.1311127367659097,
0.1311090731285363,
0.13110544098036211,
0.13110184005071918,
0.13109827007126654,
0.13109473077597045,
0.1310912219010841,
0.13108774318512836,
0.13108429436887195,
0.13108087519531228,
0.13107748540965625,
0.13107412475930125,
0.1310707929938162,
0.131067489864923,
0.1310642151264781,
0.13106096853445376,
0.1310577498469203,
0.1310545588240278,
0.13105139522798825,
0.1310482588230578,
0.13104514937551928,
0.13104206665366452,
0.13103901042777755,
0.13103598047011686,
0.13103297655489882,
0.13102999845828087,
0.1310270459583445,
0.13102411883507903,
0.13102121687036491,
0.13101833984795783,
0.1310154875534722,
0.13101265977436552,
0.13100985629992212,
0.13100707692123786,
0.1310043214312044,
0.13100158962449351,
0.13099888129754225,
0.1309961962485374,
0.13099353427740043,
0.13099089518577295,
0.13098827877700148,
0.13098568485612289,
0.1309831132298502,
0.13098056370655764,
0.13097803609626688,
0.13097553021063246,
0.13097304586292796,
0.130970582868032,
0.13096814104241458,
0.13096572020412311,
0.13096332017276915,
0.13096094076951487,
0.13095858181705952,
0.13095624313962653,
0.13095392456295019,
0.13095162591426274,
0.13094934702228142,
0.13094708771719588,
0.13094484783065521,
0.1309426271957558,
0.13094042564702846,
0.13093824302042653,
0.1309360791533131,
0.1309339338844496,
0.13093180705398308,
0.1309296985034348,
0.13092760807568815,
0.13092553561497697,
0.1309234809668741,
0.13092144397827965,
0.13091942449740973,
0.1309174223737851,
0.1309154374582199,
0.13091346960281078,
0.13091151866092543,
0.13090958448719198,
0.13090766693748815,
0.13090576586893035,
0.13090388113986332,
0.1309020126098491,
0.130900160139657,
0.1308983235912531,
0.13089650282778978,
0.13089469771359577,
0.13089290811416598,
0.13089113389615128,
0.13088937492734887,
0.13088763107669207,
0.130885902214241,
0.13088418821117243,
0.13088248893977056,
0.13088080427341722,
0.13087913408658264,
0.1308774782548159,
0.1308758366547358,
0.13087420916402173,
0.1308725956614043,
0.13087099602665653,
0.13086941014058484,
0.13086783788502007,
0.13086627914280877,
0.13086473379780447,
0.1308632017348589,
0.13086168283981364,
0.13086017699949143,
0.13085868410168772,
0.13085720403516238,
0.13085573668963163,
0.13085428195575916,
0.13085283972514883,
0.13085140989033592,
0.1308499923447795,
0.13084858698285434,
0.13084719369984307,
0.13084581239192838,
0.13084444295618522,
0.13084308529057329,
0.1308417392939292,
0.13084040486595921,
0.13083908190723154,
0.13083777031916902,
0.13083647000404175,
0.13083518086495988,
0.1308339028058663,
0.1308326357315295,
0.13083137954753649,
0.1308301341602858,
0.13082889947698037,
0.1308276754056209,
0.13082646185499866,
0.13082525873468898,
0.13082406595504426,
0.13082288342718762,
0.1308217110630058,
0.13082054877514324,
0.13081939647699484,
0.13081825408270004,
0.13081712150713629,
0.13081599866591254,
0.13081488547536319,
0.13081378185254178,
0.1308126877152146,
0.13081160298185487,
0.1308105275716365,
0.13080946140442812,
0.13080840440078706,
0.1308073564819534,
0.1308063175698443,
0.13080528758704793,
0.13080426645681778,
0.13080325410306715,
0.13080225045036314,
0.13080125542392118,
0.13080026894959965,
0.1307992909538939,
0.13079832136393135,
0.13079736010746548,
0.13079640711287094,
0.1307954623091379,
0.13079452562586683,
0.13079359699326334,
0.13079267634213287,
0.13079176360387565,
0.1307908587104814,
0.13078996159452447,
0.1307890721891588,
0.1307881904281127,
0.13078731624568427,
0.13078644957673607,
0.13078559035669074,
0.13078473852152586,
0.1307838940077693,
0.1307830567524944,
0.13078222669331543,
0.13078140376838285,
0.13078058791637864,
0.13077977907651198,
0.13077897718851425,
0.13077818219263512,
0.1307773940296377,
0.13077661264079418,
0.13077583796788161,
0.13077506995317736,
0.130774308539455,
0.13077355366997984,
0.13077280528850505,
0.13077206333926703,
0.13077132776698147,
0.1307705985168394,
0.13076987553450267,
0.1307691587661004,
0.13076844815822458,
0.13076774365792623,
0.13076704521271165,
0.13076635277053805,
0.1307656662798101,
0.13076498568937595,
0.13076431094852323,
0.13076364200697563,
0.1307629788148889,
0.13076232132284715,
0.1307616694818592,
0.13076102324335503,
0.13076038255918201,
0.1307597473816014,
0.13075911766328474,
0.1307584933573104,
0.13075787441716,
0.13075726079671504,
0.13075665245025325,
0.13075604933244553,
0.13075545139835226,
0.13075485860342015,
0.13075427090347866,
0.13075368825473718,
0.13075311061378125,
0.13075253793756964,
0.13075197018343104,
0.13075140730906076,
0.13075084927251807,
0.1307502960322223,
0.13074974754695046,
0.13074920377583368,
0.13074866467835453,
0.13074813021434364,
0.130747600343977,
0.13074707502777286,
0.13074655422658876,
0.1307460379016188,
0.1307455260143904,
0.13074501852676193,
0.13074451540091928,
0.1307440165993736,
0.13074352208495807,
0.13074303182082542,
0.130742545770445,
0.1307420638976003,
0.1307415861663858,
0.13074111254120485,
0.13074064298676666,
0.1307401774680837,
0.13073971595046924,
0.1307392583995346,
0.13073880478118677,
0.13073835506162565,
0.13073790920734168,
0.13073746718511345,
0.13073702896200484,
0.13073659450536299,
0.13073616378281558,
0.13073573676226868,
0.130735313411904,
0.13073489370017688,
0.13073447759581372,
0.1307340650678097,
0.13073365608542659,
0.13073325061819008,
0.1307328486358882,
0.13073245010856818,
0.13073205500653512,
0.13073166330034908,
0.1307312749608232,
0.13073088995902146,
0.1307305082662567,
0.13073012985408813,
0.1307297546943194,
0.13072938275899676,
0.1307290140204064,
0.13072864845107288,
0.13072828602375694,
0.13072792671145325,
0.1307275704873888,
0.1307272173250206,
0.1307268671980337,
0.13072652008033953,
0.13072617594607358,
0.13072583476959368,
0.13072549652547816,
0.1307251611885236,
0.1307248287337435,
0.13072449913636588,
0.13072417237183181,
0.13072384841579335,
0.13072352724411188,
0.1307232088328562,
0.130722893158301,
0.13072258019692454,
0.13072226992540745,
0.1307219623206308,
0.13072165735967436,
0.13072135501981477,
0.13072105527852415,
0.1307207581134681,
0.13072046350250424,
0.13072017142368056,
0.13071988185523364,
0.13071959477558712,
0.1307193101633501,
0.13071902799731558,
0.13071874825645874,
0.1307184709199356,
0.13071819596708112,
0.1307179233774081,
0.13071765313060527,
0.130717385206536,
0.1307171195852367,
0.13071685624691548,
0.13071659517195028,
0.13071633634088803,
0.13071607973444263,
0.13071582533349385,
0.1307155731190857,
0.13071532307242514,
0.1307150751748808,
0.13071482940798118,
0.13071458575341377,
0.1307143441930234,
0.13071410470881087,
0.13071386728293166,
0.13071363189769475,
0.13071339853556113,
0.13071316717914244,
0.13071293781119986,
0.13071271041464275,
0.1307124849725273,
0.1307122614680553,
0.13071203988457306,
0.13071182020556996,
0.13071160241467716,
0.13071138649566671,
0.13071117243244995,
0.1307109602090767,
0.13071074980973368,
0.13071054121874362,
0.13071033442056404,
0.13071012939978588,
0.1307099261411327,
0.13070972462945923,
0.13070952484975046,
0.1307093267871204,
0.13070913042681098,
0.130708935754191,
0.13070874275475497,
0.1307085514141222,
0.13070836171803538,
0.13070817365235993,
0.1307079872030827,
0.13070780235631096,
0.1307076190982714,
0.1307074374153091,
0.13070725729388652,
0.13070707872058232,
0.1307069016820908,
0.13070672616522036,
0.13070655215689286,
0.13070637964414267,
0.13070620861411547,
0.1307060390540674,
0.13070587095136435,
0.1307057042934805,
0.1307055390679979,
0.13070537526260517,
0.13070521286509698,
0.13070505186337264,
0.13070489224543566,
0.1307047339993925,
0.13070457711345204,
0.13070442157592427,
0.13070426737521984,
0.1307041144998489,
0.1307039629384204,
0.13070381267964126,
0.13070366371231526,
0.13070351602534264,
0.13070336960771892]
实验要求3 绘制误差曲线
# Plot the loss per iteration for the multivariate run (error curve).
plt.plot(range(iter_num2),cost_lst2,"r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()
1.4 最小二乘法求参数
最小二乘法需要求解最优参数 $w^{*}$:
已知:目标函数
$$J(w)=\frac{1}{2m}\sum_{i=1}^{m}\left(h(x^{(i)})-y^{(i)}\right)^{2}$$
其中:$h(x)=w^{T}x=w_{0}x_{0}+w_{1}x_{1}+w_{2}x_{2}+\dots+w_{n}x_{n}$(其中 $x_{0}=1$,对应数据中插入的 ones 列)
将向量表达形式转为矩阵表达形式,则有 $J(w)=\frac{1}{2m}(Xw-y)^{T}(Xw-y)$,其中 $X$ 为 $m$ 行 $n+1$ 列的矩阵($m$ 为样本个数,$n$ 为特征个数),$w$ 为 $n+1$ 行 1 列的矩阵(包含了 $w_0$),$y$ 为 $m$ 行 1 列的矩阵。令梯度 $\nabla_{w}J(w)=\frac{1}{m}X^{T}(Xw-y)=0$,即 $X^{T}Xw=X^{T}y$,当 $X^{T}X$ 可逆时可以求得最优参数 $w^{*}=(X^{T}X)^{-1}X^{T}y$
梯度下降与最小二乘法的比较:
梯度下降:需要选择学习率 $\alpha$,需要多次迭代,当特征数量 $n$ 较大时也能较好适用,适用于各种类型的模型
最小二乘法:不需要选择学习率 $\alpha$,一次计算得出,但需要计算 $(X^{T}X)^{-1}$。如果特征数量 $n$ 较大则运算代价大,因为矩阵求逆的时间复杂度为 $O(n^{3})$,通常来说当 $n$ 小于 10000 时还是可以接受的。它只适用于线性模型,不适合逻辑回归等其他模型
def lsm(X, y):
    """Least-squares weights from the normal equations (X^T X) w = X^T y.

    Solves the linear system with ``np.linalg.solve`` instead of forming
    the explicit inverse ``(X^T X)^{-1}``. This is cheaper and numerically
    more stable, and it still raises ``np.linalg.LinAlgError`` when X^T X
    is singular.

    X is (m, k) and y is (m, 1); the result is the (k, 1) weight vector.
    """
    return np.linalg.solve(X.T @ X, X.T @ y)
def lsm_v(X, y):
    """Least-squares weights via the normal equations, written with np.dot.

    Solves (X^T X) w = X^T y with ``np.linalg.solve`` rather than
    multiplying by an explicit inverse. This is more accurate and cheaper,
    and it raises ``np.linalg.LinAlgError`` when X^T X is singular, just as
    ``np.linalg.inv`` would.
    """
    gram = np.dot(X.T, X)    # X^T X, shape (k, k)
    moment = np.dot(X.T, y)  # X^T y, shape (k, 1)
    return np.linalg.solve(gram, moment)
# Closed-form solution; matches reg.intercept_ and reg.coef_[0, 1] from scikit-learn.
lsm(X,y)
array([[-3.89578088],
[ 1.19303364]])
# Same solution via the np.dot variant.
lsm_v(X,y)
array([[-3.89578088],
[ 1.19303364]])
1.5 来点正则化?
1.5.1 普通的线性回归
# Plain (unregularized) linear regression with scikit-learn.
from sklearn import linear_model
reg=linear_model.LinearRegression()
reg.fit(X,y)
LinearRegression()
# Back to single-variable regression: plot the LinearRegression fit.
# X has two columns (the constant 1 and 人口). Plot against the 人口 column
# only; passing the whole 2-column X as x also draws the constant column as
# a vertical line at x=1. The old xlim(4.7,10) hid that line, but it also
# clipped every sample above 10.
x=X[:,1]
y_1=reg.predict(X)
plt.plot(x,y_1,"r-+",label="预测线")
plt.scatter(data["人口"],data["收益"], label='训练数据')
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
plt.legend()
plt.show()
reg.coef_,reg.intercept_,reg.score(X,y)
(array([[0. , 1.19303364]]), array([-3.89578088]), 0.7020315537841397)
1.5.2 岭回归
$$J(w)=\frac{1}{2}\sum_{i=1}^{m}\left(h_{w}(x^{(i)})-y^{(i)}\right)^{2}+\lambda\sum_{j=1}^{n}w_{j}^{2}$$
此时称作Ridge Regression(岭回归):
# Ridge (L2-regularized) regression with scikit-learn's default settings.
# (The variable name "reg_rigde" is a misspelling of "ridge"; it is kept
# because later cells refer to it.)
from sklearn import linear_model
reg_rigde=linear_model.Ridge()
reg_rigde.fit(X,y)
Ridge()
# Back to single-variable regression: plot the Ridge fit.
# Plot against the 人口 column X[:,1] only; passing the 2-column X as x also
# draws the constant column at x=1, which the old xlim(4.7,10) hid while
# clipping every sample above 10.
x=X[:,1]
y_1=reg_rigde.predict(X)
plt.plot(x,y_1,"r-+",label="预测线")
plt.scatter(data["人口"],data["收益"], label='训练数据')
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
plt.legend()
plt.show()
reg_rigde.coef_,reg_rigde.intercept_,reg_rigde.score(X,y)
(array([[0. , 1.1922044]]), array([-3.88901439]), 0.7020312146131912)
1.5.3 Lasso回归
$$J(w)=\frac{1}{2}\sum_{i=1}^{m}\left(h_{w}(x^{(i)})-y^{(i)}\right)^{2}+\lambda\sum_{j=1}^{n}\left|w_{j}\right|$$
此时称作Lasso Regression
# Lasso (L1-regularized) regression with scikit-learn's default settings.
from sklearn import linear_model
reg_lasso=linear_model.Lasso()
reg_lasso.fit(X,y)
Lasso()
# Back to single-variable regression: plot the Lasso fit.
# Plot against the 人口 column X[:,1] only; passing the 2-column X as x also
# draws the constant column at x=1, which the old xlim(4.7,10) hid while
# clipping every sample above 10.
x=X[:,1]
y_1=reg_lasso.predict(X)
plt.plot(x,y_1,"r-+",label="预测线")
plt.scatter(data["人口"],data["收益"], label='训练数据')
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
plt.legend()
plt.show()
reg_lasso.coef_,reg_lasso.intercept_,reg_lasso.score(X,y)
(array([0. , 1.12556458]), array([-3.34524677]), 0.6997863246152711)
实验要求4 手写代码实现单变量的L2正则化
$$J(w)=\frac{1}{2}\sum_{i=1}^{m}\left(h_{w}(x^{(i)})-y^{(i)}\right)^{2}+\lambda\sum_{j=1}^{n}w_{j}^{2}$$
此时称作Ridge Regression(岭回归),注意正则项从 $j=1$ 开始,不惩罚截距 $w_0$:
# Batch gradient descent with L2 (ridge) regularization.
# Hyper-parameters: iteration count, learning rate alpha, strength lambd.
def gradient_descent_l2(X, y, w, iter_num, alpha, lambd):
    """Fit ridge-regularized linear-regression weights by gradient descent.

    Objective (the update below is scaled by 1/m, like the plain version):

        J(w) = 1/2 * sum((Xw - y)^2) + lambd * sum_{j>=1} w_j^2

    Each iteration performs the simultaneous update

        w <- w - (alpha / m) * (X^T (Xw - y) + 2 * lambd * w_reg)

    where w_reg is w with the intercept entry w_0 set to 0. As in the
    objective (whose penalty sum starts at j = 1), the bias is not penalized.

    Parameters
    ----------
    X : ndarray, shape (m, k); y : ndarray, shape (m, 1);
    w : ndarray, shape (k, 1), initial weights (not modified);
    iter_num : int; alpha : float; lambd : float (regularization strength).

    Returns
    -------
    (w, cost_lst)
        The final weights, and the unregularized halved MSE
        sum((Xw - y)^2) / (2m) after each iteration. That curve is directly
        comparable with the plain gradient-descent curve.
    """
    n_samples = X.shape[0]
    cost_lst = []
    for _ in range(iter_num):
        error = X @ w - y            # residuals, shape (m, 1)
        penalty = 2 * lambd * w      # gradient of lambd * ||w||^2 (new array)
        penalty[:1] = 0              # do not regularize the intercept w_0
        w = w - (alpha / n_samples) * (X.T @ error + penalty)
        cost_lst.append(np.sum(np.square(X @ w - y)) / (2 * n_samples))
    return w, cost_lst
# L2-regularized single-variable run: 200 iterations, learning rate 0.001,
# regularization strength lambd = 2, starting from zero weights; then plot
# the loss recorded after each iteration.
iter_num=200
alpha=0.001
lambd=2
w=np.zeros((col_num-1,1))
w,cost_lst=gradient_descent_l2(X,y,w,iter_num,alpha,lambd)
plt.plot(range(iter_num),cost_lst,"r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()