Boston Housing Price Prediction
1. Import the required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import random
from sklearn.model_selection import train_test_split
2. Load the data
feature = pd.read_csv("../data/boston.csv")
feature
|   | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | MEDV |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.00632 | 18.0 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296 | 15.3 | 396.90 | 4.98 | 24.0 |
1 | 0.02731 | 0.0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.90 | 9.14 | 21.6 |
2 | 0.02729 | 0.0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 | 34.7 |
3 | 0.03237 | 0.0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222 | 18.7 | 394.63 | 2.94 | 33.4 |
4 | 0.06905 | 0.0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222 | 18.7 | 396.90 | 5.33 | 36.2 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
501 | 0.06263 | 0.0 | 11.93 | 0 | 0.573 | 6.593 | 69.1 | 2.4786 | 1 | 273 | 21.0 | 391.99 | 9.67 | 22.4 |
502 | 0.04527 | 0.0 | 11.93 | 0 | 0.573 | 6.120 | 76.7 | 2.2875 | 1 | 273 | 21.0 | 396.90 | 9.08 | 20.6 |
503 | 0.06076 | 0.0 | 11.93 | 0 | 0.573 | 6.976 | 91.0 | 2.1675 | 1 | 273 | 21.0 | 396.90 | 5.64 | 23.9 |
504 | 0.10959 | 0.0 | 11.93 | 0 | 0.573 | 6.794 | 89.3 | 2.3889 | 1 | 273 | 21.0 | 393.45 | 6.48 | 22.0 |
505 | 0.04741 | 0.0 | 11.93 | 0 | 0.573 | 6.030 | 80.8 | 2.5050 | 1 | 273 | 21.0 | 396.90 | 7.88 | 11.9 |
506 rows × 14 columns
3. Data preprocessing
(1) Separate the features and the labels
Extract the label column and convert it to a NumPy ndarray.
label = feature['MEDV']
# label
label = np.array(label)
label,label.dtype,len(label)
(array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
        ...,
23.1, 19.7, 18.3, 21.2, 17.5, 16.8, 22.4, 20.6, 23.9, 22. , 11.9]),
dtype('float64'),
506)
Extract the feature columns and convert them to an ndarray as well.
feature = feature.drop('MEDV',axis=1)
data = np.array(feature)
data
array([[6.3200e-03, 1.8000e+01, 2.3100e+00, ..., 1.5300e+01, 3.9690e+02,
4.9800e+00],
[2.7310e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9690e+02,
9.1400e+00],
[2.7290e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9283e+02,
4.0300e+00],
...,
[6.0760e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
5.6400e+00],
[1.0959e-01, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9345e+02,
6.4800e+00],
[4.7410e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
7.8800e+00]])
(2) Standardize the data
The input features differ greatly in numeric scale, so features with larger magnitudes would dominate the weighted sums inside the network and distort the predictions.
We therefore standardize the data so that every feature sits on a comparable scale.
# !pip install scikit-learn
from sklearn import preprocessing
data = preprocessing.StandardScaler().fit_transform(data)
data
array([[-0.41978194, 0.28482986, -1.2879095 , ..., -1.45900038,
0.44105193, -1.0755623 ],
[-0.41733926, -0.48772236, -0.59338101, ..., -0.30309415,
0.44105193, -0.49243937],
[-0.41734159, -0.48772236, -0.59338101, ..., -0.30309415,
0.39642699, -1.2087274 ],
...,
[-0.41344658, -0.48772236, 0.11573841, ..., 1.17646583,
0.44105193, -0.98304761],
[-0.40776407, -0.48772236, 0.11573841, ..., 1.17646583,
0.4032249 , -0.86530163],
[-0.41500016, -0.48772236, 0.11573841, ..., 1.17646583,
0.44105193, -0.66905833]])
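As a quick check that the transform behaved as expected, each standardized column should now have mean ≈ 0 and standard deviation ≈ 1 (a minimal sketch, reusing the data array from the cell above):
# StandardScaler should leave every column with mean ~0 and std ~1
print(np.allclose(data.mean(axis=0), 0.0))  # True: column means are ~0
print(np.allclose(data.std(axis=0), 1.0))   # True: column stds are ~1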
(3) Split into training and test sets
train_data, test_data, train_label, test_label = train_test_split(
    data, label, test_size=0.2, random_state=42)  # fixing random_state makes the split reproducible across runs
# train_data
test_data,test_label
(array([[-0.40301721, -0.48772236, 2.11761463, ..., 0.29797709,
0.14950964, 1.78818809],
[ 2.59827406, -0.48772236, 1.01599907, ..., 0.80657583,
-2.51428122, 1.96060184],
[ 1.32780421, -0.48772236, 1.01599907, ..., 0.80657583,
-0.07887794, 1.7181012 ],
...,
[-0.39825173, 0.45650813, -0.76993132, ..., 0.29797709,
0.35585887, 0.8139803 ],
[-0.41613247, 2.94584308, -1.40317788, ..., -2.70737911,
0.38951945, -0.8456773 ],
[-0.32270106, -0.48772236, -0.43725801, ..., 1.17646583,
-0.58389629, 0.54064142]]),
array([17.3, 10.4, 12. , 26.4, 24.4, 20.8, 28.7, 18.3, 14.5, 22.6, 21.2,
15.2, 29.1, 22.9, 13.1, 21.4, 21.4, 24.1, 11.8, 21.7, 17.8, 23.2,
19.5, 23.3, 33.1, 14.1, 26.7, 36.5, 14.1, 16.8, 7. , 17.7, 24.8,
22.2, 21. , 29.1, 8.7, 27.1, 9.7, 20.4, 15.2, 16.5, 33.2, 19.8,
35.4, 22. , 21.2, 20.7, 10.9, 20.9, 8.8, 10.5, 31.5, 21.9, 28.7,
25. , 17.9, 7.2, 19.4, 29. , 14. , 42.3, 22.8, 29.9, 14.5, 17.5,
22.6, 7. , 36. , 21.2, 20.6, 15. , 9.5, 10.8, 24.8, 50. , 16.4,
19.6, 23.7, 24.5, 15.6, 17.8, 16.2, 20.6, 19.9, 19.6, 19. , 8.5,
22.8, 21. , 21.7, 34.9, 48.5, 15.6, 11.7, 12.5, 17.1, 17.5, 28.4,
18.5, 34.6, 13.9]))
(4) Convert to PyTorch tensors
# requires_grad is only needed on the model's parameters, not on the input
# data or the labels, so plain float32 tensors are enough here.
train_data = torch.tensor(train_data, dtype=torch.float32)
train_label = torch.tensor(train_label, dtype=torch.float32).reshape(-1, 1)  # (N,) -> (N, 1) to match the model output
test_data = torch.tensor(test_data, dtype=torch.float32)
test_label = torch.tensor(test_label, dtype=torch.float32).reshape(-1, 1)
train_data.size(),train_label.shape,test_data.shape,test_label.shape,train_data.type(),train_label
(torch.Size([404, 13]),
torch.Size([404, 1]),
torch.Size([102, 13]),
torch.Size([102, 1]),
'torch.FloatTensor',
tensor([[19.1000],
        [33.3000],
        [15.1000],
        ...,
        [23.1000],
        [23.3000]]))
4. Define the model
class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a single linear layer: 13 input features -> 1 predicted price
        self.linear = torch.nn.Linear(13, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
model1 = LinearModel()
model1
LinearModel(
(linear): Linear(in_features=13, out_features=1, bias=True)
)
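Before wiring up the loss and optimizer, a quick shape sanity check: the untrained model should map a batch of 13 features to one output per sample (a minimal sketch, reusing the train_data tensor from step 3):
with torch.no_grad():                 # no gradients needed for a shape check
    out = model1(train_data[:2])      # first two training samples, shape (2, 13)
print(out.shape)                      # torch.Size([2, 1])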
criterion = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(model1.parameters(), lr=0.005)
5. Training
epoch_loss = []
for epoch in range(300):
    y_pred = model1(train_data)            # forward pass over the full training set
    loss = criterion(y_pred, train_label)
    print(epoch, loss.item())
    optimizer.zero_grad()                  # clear gradients from the previous step
    loss.backward()                        # backpropagation
    optimizer.step()                       # update the weights
    epoch_loss.append(loss.item())
0 624.7080688476562
1 608.4896240234375
2 593.03759765625
3 578.2877197265625
...
297 24.61171531677246
298 24.581388473510742
299 24.551593780517578
Plot the loss against the number of training iterations.
from pylab import mpl
# Use a font that can render the Chinese axis labels
mpl.rcParams["font.sans-serif"] = ["SimHei"]
# SimHei lacks the Unicode minus glyph (U+2212); fall back to ASCII '-' on the axes
mpl.rcParams["axes.unicode_minus"] = False
plt.plot(range(len(epoch_loss)), epoch_loss)
plt.xlabel('训练次数')
plt.ylabel('损失值')
plt.grid(True, linestyle='--', alpha=0.5)
plt.title("训练次数和损失值关系图", fontsize=12)
plt.show()
6. Prediction
y_pred = model1(test_data).detach().numpy()
test_label_numpy = test_label.detach().numpy()
plt.plot(y_pred ,color='b',label="预测值")
plt.plot(test_label_numpy,color='r',label="真实值")
plt.xlabel('预测样本点')
plt.ylabel('预测值')
plt.title("预测值和真实值拟合检验图", fontsize=12)
plt.grid(True, linestyle='--', alpha=0.5)
plt.legend(loc="best")
plt.show()
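The plot is only a visual check; to put a number on the fit, we can also evaluate the same MSE criterion on the held-out test set (a minimal sketch, reusing criterion and the test tensors from above):
with torch.no_grad():
    test_mse = criterion(model1(test_data), test_label).item()
print('test MSE :', test_mse)
print('test RMSE:', test_mse ** 0.5)  # same unit as MEDV (thousands of dollars)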
7. Summary
During training, the loss plateaued around 80 and would not decrease. After a morning of debugging, the cause turned out to be the shape of the labels: they were one-dimensional, shape (N,), while the model outputs shape (N, 1). MSELoss broadcasts these two shapes into an (N, N) matrix of pairwise differences, so the loss was being computed over mismatched prediction/label pairs. After reshaping the labels to (N, 1), training proceeded normally, and the predicted values track the true values closely.
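A minimal sketch of that bug, with toy tensors purely for illustration: MSELoss silently broadcasts a (N,) target against an (N, 1) prediction instead of raising an error, which is why the symptom was a stuck loss (plus a UserWarning) rather than a crash:
import torch

y = torch.arange(4, dtype=torch.float32)       # target, shape (4,)
pred = y.reshape(-1, 1) + 0.1                  # prediction, shape (4, 1)
loss_fn = torch.nn.MSELoss()
print(loss_fn(pred, y).item())                 # ~2.51: shapes broadcast to (4, 4), wrong
print(loss_fn(pred, y.reshape(-1, 1)).item())  # ~0.01: shapes match, correct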