L05 Exercise: Boston Housing Price Prediction


Boston Housing Price Prediction

1. Import the required libraries

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import random
from sklearn.model_selection import train_test_split

2. Load the data

feature = pd.read_csv("../data/boston.csv")
feature
        CRIM    ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD  TAX  PTRATIO       B  LSTAT  MEDV
0    0.00632  18.0   2.31     0  0.538  6.575  65.2  4.0900    1  296     15.3  396.90   4.98  24.0
1    0.02731   0.0   7.07     0  0.469  6.421  78.9  4.9671    2  242     17.8  396.90   9.14  21.6
2    0.02729   0.0   7.07     0  0.469  7.185  61.1  4.9671    2  242     17.8  392.83   4.03  34.7
3    0.03237   0.0   2.18     0  0.458  6.998  45.8  6.0622    3  222     18.7  394.63   2.94  33.4
4    0.06905   0.0   2.18     0  0.458  7.147  54.2  6.0622    3  222     18.7  396.90   5.33  36.2
..       ...   ...    ...   ...    ...    ...   ...     ...  ...  ...      ...     ...    ...   ...
501  0.06263   0.0  11.93     0  0.573  6.593  69.1  2.4786    1  273     21.0  391.99   9.67  22.4
502  0.04527   0.0  11.93     0  0.573  6.120  76.7  2.2875    1  273     21.0  396.90   9.08  20.6
503  0.06076   0.0  11.93     0  0.573  6.976  91.0  2.1675    1  273     21.0  396.90   5.64  23.9
504  0.10959   0.0  11.93     0  0.573  6.794  89.3  2.3889    1  273     21.0  393.45   6.48  22.0
505  0.04741   0.0  11.93     0  0.573  6.030  80.8  2.5050    1  273     21.0  396.90   7.88  11.9

506 rows × 14 columns
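
Before preprocessing, it is worth a quick check for missing values (a small sketch, not part of the original run; the standard boston.csv ships complete, but verifying is cheap):

feature.isna().sum()   # expect all zeros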

3. Data preprocessing

(1) Separate the features and labels

Extract the label column and convert it to an ndarray:

label = feature['MEDV']   # MEDV (median home value, in $1000s) is the target
label = np.array(label)
label, label.dtype, len(label)
(array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
        18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
        15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
        13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
        21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
        35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
        19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
        20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
        23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
        33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
        21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
        20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
        23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
        15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21.5, 19.6, 15.3, 19.4,
        17. , 15.6, 13.1, 41.3, 24.3, 23.3, 27. , 50. , 50. , 50. , 22.7,
        25. , 50. , 23.8, 23.8, 22.3, 17.4, 19.1, 23.1, 23.6, 22.6, 29.4,
        23.2, 24.6, 29.9, 37.2, 39.8, 36.2, 37.9, 32.5, 26.4, 29.6, 50. ,
        32. , 29.8, 34.9, 37. , 30.5, 36.4, 31.1, 29.1, 50. , 33.3, 30.3,
        34.6, 34.9, 32.9, 24.1, 42.3, 48.5, 50. , 22.6, 24.4, 22.5, 24.4,
        20. , 21.7, 19.3, 22.4, 28.1, 23.7, 25. , 23.3, 28.7, 21.5, 23. ,
        26.7, 21.7, 27.5, 30.1, 44.8, 50. , 37.6, 31.6, 46.7, 31.5, 24.3,
        31.7, 41.7, 48.3, 29. , 24. , 25.1, 31.5, 23.7, 23.3, 22. , 20.1,
        22.2, 23.7, 17.6, 18.5, 24.3, 20.5, 24.5, 26.2, 24.4, 24.8, 29.6,
        42.8, 21.9, 20.9, 44. , 50. , 36. , 30.1, 33.8, 43.1, 48.8, 31. ,
        36.5, 22.8, 30.7, 50. , 43.5, 20.7, 21.1, 25.2, 24.4, 35.2, 32.4,
        32. , 33.2, 33.1, 29.1, 35.1, 45.4, 35.4, 46. , 50. , 32.2, 22. ,
        20.1, 23.2, 22.3, 24.8, 28.5, 37.3, 27.9, 23.9, 21.7, 28.6, 27.1,
        20.3, 22.5, 29. , 24.8, 22. , 26.4, 33.1, 36.1, 28.4, 33.4, 28.2,
        22.8, 20.3, 16.1, 22.1, 19.4, 21.6, 23.8, 16.2, 17.8, 19.8, 23.1,
        21. , 23.8, 23.1, 20.4, 18.5, 25. , 24.6, 23. , 22.2, 19.3, 22.6,
        19.8, 17.1, 19.4, 22.2, 20.7, 21.1, 19.5, 18.5, 20.6, 19. , 18.7,
        32.7, 16.5, 23.9, 31.2, 17.5, 17.2, 23.1, 24.5, 26.6, 22.9, 24.1,
        18.6, 30.1, 18.2, 20.6, 17.8, 21.7, 22.7, 22.6, 25. , 19.9, 20.8,
        16.8, 21.9, 27.5, 21.9, 23.1, 50. , 50. , 50. , 50. , 50. , 13.8,
        13.8, 15. , 13.9, 13.3, 13.1, 10.2, 10.4, 10.9, 11.3, 12.3,  8.8,
         7.2, 10.5,  7.4, 10.2, 11.5, 15.1, 23.2,  9.7, 13.8, 12.7, 13.1,
        12.5,  8.5,  5. ,  6.3,  5.6,  7.2, 12.1,  8.3,  8.5,  5. , 11.9,
        27.9, 17.2, 27.5, 15. , 17.2, 17.9, 16.3,  7. ,  7.2,  7.5, 10.4,
         8.8,  8.4, 16.7, 14.2, 20.8, 13.4, 11.7,  8.3, 10.2, 10.9, 11. ,
         9.5, 14.5, 14.1, 16.1, 14.3, 11.7, 13.4,  9.6,  8.7,  8.4, 12.8,
        10.5, 17.1, 18.4, 15.4, 10.8, 11.8, 14.9, 12.6, 14.1, 13. , 13.4,
        15.2, 16.1, 17.8, 14.9, 14.1, 12.7, 13.5, 14.9, 20. , 16.4, 17.7,
        19.5, 20.2, 21.4, 19.9, 19. , 19.1, 19.1, 20.1, 19.9, 19.6, 23.2,
        29.8, 13.8, 13.3, 16.7, 12. , 14.6, 21.4, 23. , 23.7, 25. , 21.8,
        20.6, 21.2, 19.1, 20.6, 15.2,  7. ,  8.1, 13.6, 20.1, 21.8, 24.5,
        23.1, 19.7, 18.3, 21.2, 17.5, 16.8, 22.4, 20.6, 23.9, 22. , 11.9]),
 dtype('float64'),
 506)

Extract the feature columns and convert them to an ndarray:

feature = feature.drop('MEDV',axis=1)
data = np.array(feature)
data
array([[6.3200e-03, 1.8000e+01, 2.3100e+00, ..., 1.5300e+01, 3.9690e+02,
        4.9800e+00],
       [2.7310e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9690e+02,
        9.1400e+00],
       [2.7290e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9283e+02,
        4.0300e+00],
       ...,
       [6.0760e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
        5.6400e+00],
       [1.0959e-01, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9345e+02,
        6.4800e+00],
       [4.7410e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
        7.8800e+00]])

(2) Normalize the data

The input features span very different numeric ranges, so the large-valued features dominate the loss and its gradients, which skews training and distorts the predictions.

We therefore standardize the data so that all features are on a comparable scale.
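
Concretely, standardization is just the z-score transform. As a quick illustration (a sketch, not part of the original pipeline), the one-liner below produces the same result as the StandardScaler call that follows, since both use the population standard deviation:

# z-score each column by hand: subtract the column mean, divide by the column std
data_manual = (data - data.mean(axis=0)) / data.std(axis=0)

After the next cell runs, np.allclose(data_manual, data) returns True.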

# !pip install scikit-learn
from sklearn import preprocessing
data = preprocessing.StandardScaler().fit_transform(data)
data
array([[-0.41978194,  0.28482986, -1.2879095 , ..., -1.45900038,
         0.44105193, -1.0755623 ],
       [-0.41733926, -0.48772236, -0.59338101, ..., -0.30309415,
         0.44105193, -0.49243937],
       [-0.41734159, -0.48772236, -0.59338101, ..., -0.30309415,
         0.39642699, -1.2087274 ],
       ...,
       [-0.41344658, -0.48772236,  0.11573841, ...,  1.17646583,
         0.44105193, -0.98304761],
       [-0.40776407, -0.48772236,  0.11573841, ...,  1.17646583,
         0.4032249 , -0.86530163],
       [-0.41500016, -0.48772236,  0.11573841, ...,  1.17646583,
         0.44105193, -0.66905833]])
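
A quick sanity check that each column now has approximately zero mean and unit variance:

print(data.mean(axis=0).round(6))   # all ~0
print(data.std(axis=0).round(6))    # all ~1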

(3) Split into training and test sets

# 80/20 split; without a fixed random_state the split differs on every run
train_data, test_data, train_label, test_label = train_test_split(data, label, test_size=0.2)
test_data, test_label
(array([[-0.40301721, -0.48772236,  2.11761463, ...,  0.29797709,
          0.14950964,  1.78818809],
        [ 2.59827406, -0.48772236,  1.01599907, ...,  0.80657583,
         -2.51428122,  1.96060184],
        [ 1.32780421, -0.48772236,  1.01599907, ...,  0.80657583,
         -0.07887794,  1.7181012 ],
        ...,
        [-0.39825173,  0.45650813, -0.76993132, ...,  0.29797709,
          0.35585887,  0.8139803 ],
        [-0.41613247,  2.94584308, -1.40317788, ..., -2.70737911,
          0.38951945, -0.8456773 ],
        [-0.32270106, -0.48772236, -0.43725801, ...,  1.17646583,
         -0.58389629,  0.54064142]]),
 array([17.3, 10.4, 12. , 26.4, 24.4, 20.8, 28.7, 18.3, 14.5, 22.6, 21.2,
        15.2, 29.1, 22.9, 13.1, 21.4, 21.4, 24.1, 11.8, 21.7, 17.8, 23.2,
        19.5, 23.3, 33.1, 14.1, 26.7, 36.5, 14.1, 16.8,  7. , 17.7, 24.8,
        22.2, 21. , 29.1,  8.7, 27.1,  9.7, 20.4, 15.2, 16.5, 33.2, 19.8,
        35.4, 22. , 21.2, 20.7, 10.9, 20.9,  8.8, 10.5, 31.5, 21.9, 28.7,
        25. , 17.9,  7.2, 19.4, 29. , 14. , 42.3, 22.8, 29.9, 14.5, 17.5,
        22.6,  7. , 36. , 21.2, 20.6, 15. ,  9.5, 10.8, 24.8, 50. , 16.4,
        19.6, 23.7, 24.5, 15.6, 17.8, 16.2, 20.6, 19.9, 19.6, 19. ,  8.5,
        22.8, 21. , 21.7, 34.9, 48.5, 15.6, 11.7, 12.5, 17.1, 17.5, 28.4,
        18.5, 34.6, 13.9]))
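
Two caveats worth noting: the scaler above was fit on the full dataset before splitting, which leaks test-set statistics into the training data, and no random_state was set, so the split is not reproducible. A leakage-free, reproducible variant might look like this (a sketch; raw_features simply re-derives the unscaled matrix from the feature DataFrame):

raw_features = np.array(feature)   # unscaled features from step (1)
tr_x, te_x, tr_y, te_y = train_test_split(raw_features, label, test_size=0.2, random_state=42)
scaler = preprocessing.StandardScaler().fit(tr_x)   # fit on the training split only
tr_x, te_x = scaler.transform(tr_x), scaler.transform(te_x)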

(4) Convert to PyTorch tensors

# requires_grad is unnecessary on inputs and labels (only the model's parameters
# need gradients), but it is harmless here. The labels are reshaped to (N, 1) so
# they match the model's output shape -- see the summary in section 7.
train_data = torch.tensor(train_data, dtype=torch.float32, requires_grad=True)
train_label = torch.tensor(train_label, dtype=torch.float32, requires_grad=True).reshape(-1, 1)
test_data = torch.tensor(test_data, dtype=torch.float32, requires_grad=True)
test_label = torch.tensor(test_label, dtype=torch.float32, requires_grad=True).reshape(-1, 1)
train_data.size(), train_label.shape, test_data.shape, test_label.shape, train_data.type(), train_label

(torch.Size([404, 13]),
 torch.Size([404, 1]),
 torch.Size([102, 13]),
 torch.Size([102, 1]),
 'torch.FloatTensor',
 tensor([[19.1000],
         [33.3000],
         [15.1000],
         [15.4000],
         [22.5000],
         [19.9000],
         [30.7000],
         [19.5000],
         [31.2000],
         [19.3000],
         [13.1000],
         [31.6000],
         [43.8000],
         [15.6000],
         [50.0000],
         [29.6000],
         [19.3000],
         [20.2000],
         [15.4000],
         [19.8000],
         [12.8000],
         [21.7000],
         [18.5000],
         [21.2000],
         [34.9000],
         [31.5000],
         [21.1000],
         [23.6000],
         [19.1000],
         [ 5.0000],
         [10.2000],
         [20.5000],
         [15.0000],
         [13.4000],
         [24.5000],
         [18.7000],
         [21.4000],
         [29.4000],
         [20.4000],
         [27.9000],
         [42.8000],
         [23.1000],
         [33.4000],
         [ 9.6000],
         [22.2000],
         [17.4000],
         [22.3000],
         [14.3000],
         [24.4000],
         [20.0000],
         [18.4000],
         [18.5000],
         [27.0000],
         [11.9000],
         [22.6000],
         [50.0000],
         [13.1000],
         [23.3000],
         [16.1000],
         [17.2000],
         [ 8.5000],
         [17.2000],
         [13.0000],
         [29.6000],
         [25.0000],
         [22.4000],
         [20.1000],
         [16.6000],
         [50.0000],
         [20.5000],
         [22.3000],
         [19.2000],
         [17.4000],
         [30.1000],
         [ 7.2000],
         [13.8000],
         [14.8000],
         [19.4000],
         [19.3000],
         [22.0000],
         [28.6000],
         [18.7000],
         [13.2000],
         [11.9000],
         [32.9000],
         [50.0000],
         [20.3000],
         [22.9000],
         [20.1000],
         [30.1000],
         [36.1000],
         [15.3000],
         [18.9000],
         [17.8000],
         [27.1000],
         [34.7000],
         [28.7000],
         [18.2000],
         [18.5000],
         [13.4000],
         [18.7000],
         [16.3000],
         [28.0000],
         [19.1000],
         [45.4000],
         [18.4000],
         [13.5000],
         [21.9000],
         [18.9000],
         [10.2000],
         [23.7000],
         [19.9000],
         [22.8000],
         [37.9000],
         [24.6000],
         [19.6000],
         [23.7000],
         [18.2000],
         [22.7000],
         [24.5000],
         [19.6000],
         [16.8000],
         [50.0000],
         [13.3000],
         [20.8000],
         [23.3000],
         [24.0000],
         [ 8.3000],
         [50.0000],
         [24.7000],
         [41.7000],
         [17.8000],
         [16.0000],
         [19.9000],
         [26.6000],
         [36.4000],
         [13.8000],
         [29.0000],
         [ 6.3000],
         [19.0000],
         [22.0000],
         [20.4000],
         [25.2000],
         [23.4000],
         [29.8000],
         [21.4000],
         [22.6000],
         [23.1000],
         [14.4000],
         [21.7000],
         [21.6000],
         [28.5000],
         [21.8000],
         [18.3000],
         [ 7.4000],
         [14.1000],
         [17.1000],
         [26.6000],
         [21.8000],
         [19.1000],
         [14.6000],
         [13.8000],
         [16.6000],
         [20.1000],
         [13.3000],
         [48.8000],
         [21.0000],
         [34.9000],
         [11.8000],
         [23.9000],
         [32.5000],
         [20.0000],
         [25.1000],
         [28.4000],
         [23.7000],
         [13.6000],
         [12.7000],
         [21.9000],
         [24.4000],
         [26.5000],
         [18.6000],
         [50.0000],
         [37.2000],
         [31.1000],
         [20.3000],
         [20.4000],
         [22.5000],
         [24.3000],
         [20.1000],
         [14.5000],
         [23.1000],
         [32.0000],
         [23.0000],
         [13.8000],
         [18.1000],
         [22.2000],
         [14.3000],
         [23.1000],
         [19.7000],
         [21.1000],
         [22.5000],
         [ 8.1000],
         [20.7000],
         [22.4000],
         [38.7000],
         [21.7000],
         [25.3000],
         [46.7000],
         [19.5000],
         [19.8000],
         [14.9000],
         [17.6000],
         [26.2000],
         [ 7.5000],
         [15.2000],
         [12.3000],
         [13.4000],
         [23.2000],
         [23.9000],
         [20.0000],
         [20.3000],
         [50.0000],
         [23.0000],
         [20.3000],
         [23.0000],
         [14.4000],
         [29.8000],
         [ 5.6000],
         [20.9000],
         [30.8000],
         [24.0000],
         [24.8000],
         [22.0000],
         [23.4000],
         [16.7000],
         [23.8000],
         [25.0000],
         [22.0000],
         [26.4000],
         [30.3000],
         [19.4000],
         [50.0000],
         [15.0000],
         [19.3000],
         [21.4000],
         [12.7000],
         [18.8000],
         [13.3000],
         [23.9000],
         [25.0000],
         [24.6000],
         [50.0000],
         [14.6000],
         [20.0000],
         [37.3000],
         [ 8.4000],
         [32.2000],
         [16.1000],
         [28.2000],
         [20.6000],
         [31.7000],
         [18.8000],
         [27.9000],
         [21.5000],
         [23.9000],
         [11.0000],
         [20.0000],
         [19.3000],
         [18.2000],
         [13.1000],
         [32.7000],
         [31.0000],
         [18.9000],
         [10.2000],
         [15.7000],
         [17.5000],
         [20.2000],
         [27.5000],
         [20.6000],
         [19.4000],
         [18.0000],
         [20.6000],
         [22.0000],
         [33.1000],
         [25.0000],
         [19.6000],
         [ 8.3000],
         [36.2000],
         [14.9000],
         [17.1000],
         [23.1000],
         [50.0000],
         [27.5000],
         [17.8000],
         [19.4000],
         [19.2000],
         [27.5000],
         [33.2000],
         [24.4000],
         [22.9000],
         [22.0000],
         [31.6000],
         [41.3000],
         [17.4000],
         [22.1000],
         [24.7000],
         [17.2000],
         [48.3000],
         [10.5000],
         [32.4000],
         [10.4000],
         [35.1000],
         [22.8000],
         [19.4000],
         [22.6000],
         [16.2000],
         [20.6000],
         [13.8000],
         [13.6000],
         [23.0000],
         [24.2000],
         [44.8000],
         [25.0000],
         [50.0000],
         [23.8000],
         [16.1000],
         [18.6000],
         [19.5000],
         [18.4000],
         [20.5000],
         [24.7000],
         [25.0000],
         [28.1000],
         [19.7000],
         [39.8000],
         [25.0000],
         [11.3000],
         [26.6000],
         [23.1000],
         [21.5000],
         [13.9000],
         [27.5000],
         [15.6000],
         [36.2000],
         [ 8.4000],
         [13.4000],
         [32.0000],
         [50.0000],
         [23.2000],
         [33.0000],
         [13.5000],
         [16.5000],
         [ 8.8000],
         [18.9000],
         [ 5.0000],
         [23.8000],
         [23.8000],
         [14.2000],
         [43.1000],
         [46.0000],
         [20.1000],
         [24.3000],
         [33.8000],
         [21.7000],
         [15.6000],
         [35.2000],
         [22.7000],
         [24.1000],
         [30.1000],
         [12.1000],
         [12.7000],
         [43.5000],
         [23.9000],
         [22.2000],
         [21.7000],
         [24.1000],
         [11.5000],
         [20.8000],
         [14.9000],
         [16.7000],
         [23.2000],
         [23.5000],
         [21.2000],
         [ 7.2000],
         [10.9000],
         [35.4000],
         [24.3000],
         [21.6000],
         [23.6000],
         [11.7000],
         [17.0000],
         [50.0000],
         [44.0000],
         [33.4000],
         [30.5000],
         [24.8000],
         [37.6000],
         [37.0000],
         [12.6000],
         [22.2000],
         [22.9000],
         [50.0000],
         [23.1000],
         [23.3000]], grad_fn=<ReshapeAliasBackward0>))

4. Define the model

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # one linear layer: 13 input features -> 1 output (predicted MEDV)
        self.linear = torch.nn.Linear(13, 1)
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
model1 = LinearModel()
model1
model1
LinearModel(
  (linear): Linear(in_features=13, out_features=1, bias=True)
)
criterion = torch.nn.MSELoss(reduction='mean')              # mean squared error
optimizer = torch.optim.SGD(model1.parameters(), lr=0.005)  # plain gradient descent
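
A single linear layer trained with MSE is ordinary least squares, so the closed-form solution gives the training-loss floor that the loop below should approach. A sketch for sanity-checking (not part of the original run):

X = train_data.detach().numpy()
y = train_label.detach().numpy()
X1 = np.hstack([X, np.ones((len(X), 1))])      # append a bias column
w, *_ = np.linalg.lstsq(X1, y, rcond=None)     # closed-form least squares
print(((X1 @ w - y) ** 2).mean())              # minimum achievable training MSE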

5. Training

epoch_loss = []
for epoch in range(300):
    y_pred = model1(train_data)        # full-batch forward pass
    loss = criterion(y_pred, train_label)
    print(epoch, loss.item())

    optimizer.zero_grad()              # clear gradients from the previous step
    loss.backward()
    optimizer.step()
    epoch_loss.append(loss.item())
    
0 624.7080688476562
1 608.4896240234375
2 593.03759765625
3 578.2877197265625
4 564.1824951171875
5 550.6704711914062
6 537.7056274414062
7 525.24658203125
8 513.2561645507812
9 501.7012023925781
10 490.5516052246094
11 479.7803955078125
12 469.3631591796875
13 459.2780456542969
14 449.50494384765625
15 440.0260009765625
16 430.8248291015625
17 421.886474609375
18 413.19757080078125
19 404.74554443359375
20 396.51934814453125
21 388.50848388671875
22 380.7034606933594
23 373.09564208984375
24 365.6768493652344
25 358.43975830078125
26 351.37744140625
27 344.48345947265625
28 337.7518615722656
29 331.17706298828125
30 324.7538757324219
31 318.47747802734375
32 312.3431701660156
33 306.3467102050781
34 300.4839782714844
35 294.75115966796875
36 289.1445617675781
37 283.6607666015625
38 278.29638671875
39 273.04827880859375
40 267.91339111328125
41 262.888916015625
42 257.9720458984375
43 253.1600799560547
44 248.45053100585938
45 243.84091186523438
46 239.3287811279297
47 234.91195678710938
48 230.58810424804688
49 226.35516357421875
50 222.2110137939453
51 218.15365600585938
52 214.18113708496094
53 210.29153442382812
54 206.48306274414062
55 202.7538604736328
56 199.10223388671875
57 195.52645874023438
58 192.02490234375
59 188.59597778320312
60 185.23800659179688
61 181.94960021972656
62 178.7291717529297
63 175.57528686523438
64 172.48654174804688
65 169.46153259277344
66 166.49891662597656
67 163.5973663330078
68 160.75558471679688
69 157.97230529785156
70 155.24632263183594
71 152.57640075683594
72 149.9613494873047
73 147.40000915527344
74 144.89129638671875
75 142.43408203125
76 140.02731323242188
77 137.66993713378906
78 135.3608856201172
79 133.09915161132812
80 130.88375854492188
81 128.7137451171875
82 126.58817291259766
83 124.5060806274414
84 122.46662139892578
85 120.4688491821289
86 118.51192474365234
87 116.59502410888672
88 114.71727752685547
89 112.87786865234375
90 111.07605743408203
91 109.31100463867188
92 107.58196258544922
93 105.88819122314453
94 104.2289810180664
95 102.60359191894531
96 101.01134490966797
97 99.45150756835938
98 97.92346954345703
99 96.42654418945312
100 94.9600830078125
101 93.52346801757812
102 92.1160659790039
103 90.73729705810547
104 89.38656616210938
105 88.06329345703125
106 86.76689910888672
107 85.4968490600586
108 84.25257873535156
109 83.0335693359375
110 81.83930969238281
111 80.66928100585938
112 79.5229721069336
113 78.39990997314453
114 77.29962158203125
115 76.22161865234375
116 75.16545867919922
117 74.13068389892578
118 73.11685180664062
119 72.12356567382812
120 71.15036010742188
121 70.19683837890625
122 69.2625961303711
123 68.34725189208984
124 67.45038604736328
125 66.57164001464844
126 65.71065521240234
127 64.86703491210938
128 64.04045104980469
129 63.230552673339844
130 62.436981201171875
131 61.659427642822266
132 60.89752960205078
133 60.15097427368164
134 59.419464111328125
135 58.70269012451172
136 58.00033950805664
137 57.3121223449707
138 56.63774871826172
139 55.97693634033203
140 55.32942199707031
141 54.694915771484375
142 54.07315444946289
143 53.46388626098633
144 52.8668327331543
145 52.28177261352539
146 51.70844268798828
147 51.1466178894043
148 50.59606170654297
149 50.05653381347656
150 49.527801513671875
151 49.009674072265625
152 48.50191879272461
153 48.00431823730469
154 47.51667785644531
155 47.038795471191406
156 46.57045364379883
157 46.11146926879883
158 45.66167068481445
159 45.220829010009766
160 44.7888069152832
161 44.36539840698242
162 43.950435638427734
163 43.54374313354492
164 43.1451530456543
165 42.7545051574707
166 42.37163162231445
167 41.99638366699219
168 41.62860107421875
169 41.268123626708984
170 40.91482162475586
171 40.568538665771484
172 40.22915267944336
173 39.89649200439453
174 39.570430755615234
175 39.25083541870117
176 38.93758010864258
177 38.63054275512695
178 38.329586029052734
179 38.03458023071289
180 37.74541091918945
181 37.461971282958984
182 37.18413162231445
183 36.91177749633789
184 36.644813537597656
185 36.38311767578125
186 36.1265869140625
187 35.8751220703125
188 35.62860107421875
189 35.38695526123047
190 35.150062561035156
191 34.91782760620117
192 34.69016647338867
193 34.466983795166016
194 34.24818801879883
195 34.0337028503418
196 33.823421478271484
197 33.61726760864258
198 33.415157318115234
199 33.21701431274414
200 33.02274703979492
201 32.83229064941406
202 32.64555740356445
203 32.46246337890625
204 32.282962799072266
205 32.106956481933594
206 31.934391021728516
207 31.76519203186035
208 31.599294662475586
209 31.436626434326172
210 31.277130126953125
211 31.12073516845703
212 30.967384338378906
213 30.8170223236084
214 30.669580459594727
215 30.524993896484375
216 30.383211135864258
217 30.244182586669922
218 30.10784339904785
219 29.97414779663086
220 29.843036651611328
221 29.71446418762207
222 29.58837890625
223 29.464717864990234
224 29.34345054626465
225 29.224512100219727
226 29.107872009277344
227 28.99346923828125
228 28.881269454956055
229 28.771228790283203
230 28.66329574584961
231 28.557437896728516
232 28.45360565185547
233 28.351764678955078
234 28.251873016357422
235 28.15389060974121
236 28.05777931213379
237 27.963502883911133
238 27.87102699279785
239 27.780311584472656
240 27.69132423400879
241 27.60402488708496
242 27.51839256286621
243 27.434383392333984
244 27.35196304321289
245 27.27110481262207
246 27.19178581237793
247 27.11396026611328
248 27.03760528564453
249 26.962692260742188
250 26.889190673828125
251 26.817075729370117
252 26.74631690979004
253 26.67688751220703
254 26.608760833740234
255 26.54191017150879
256 26.476316452026367
257 26.411943435668945
258 26.348777770996094
259 26.286794662475586
260 26.22595977783203
261 26.166259765625
262 26.1076717376709
263 26.0501766204834
264 25.993741989135742
265 25.938358306884766
266 25.883996963500977
267 25.830644607543945
268 25.778274536132812
269 25.726869583129883
270 25.67641258239746
271 25.626882553100586
272 25.57826805114746
273 25.53053855895996
274 25.483688354492188
275 25.43768882751465
276 25.39253044128418
277 25.34819984436035
278 25.304670333862305
279 25.26193618774414
280 25.2199764251709
281 25.178773880004883
282 25.138320922851562
283 25.09859848022461
284 25.05959701538086
285 25.02129364013672
286 24.983678817749023
287 24.946739196777344
288 24.910465240478516
289 24.87483787536621
290 24.8398494720459
291 24.805482864379883
292 24.771726608276367
293 24.738569259643555
294 24.706003189086914
295 24.67401123046875
296 24.6425838470459
297 24.61171531677246
298 24.581388473510742
299 24.551593780517578

Plot the loss against the number of training epochs

from pylab import mpl
# use a font that can render the Chinese axis labels
mpl.rcParams["font.sans-serif"] = ["SimHei"]
plt.plot(range(len(epoch_loss)), epoch_loss)
plt.xlabel('训练次数')    # number of training epochs
plt.ylabel('损失值')      # loss
plt.grid(True, linestyle='--', alpha=0.5)
plt.title("训练次数和损失值关系图", fontsize=12)   # loss vs. epochs
plt.show()

[Figure: training loss falling steadily over the 300 epochs]

6. Prediction

y_pred = model1(test_data).detach().numpy()    # predictions on the test set
test_label_numpy = test_label.detach().numpy()
plt.plot(y_pred, color='b', label="预测值")              # predicted values
plt.plot(test_label_numpy, color='r', label="真实值")    # true values
plt.xlabel('预测样本点')   # test sample index
plt.ylabel('预测值')       # value
plt.title("预测值和真实值拟合检验图", fontsize=12)   # predicted vs. true values
plt.grid(True, linestyle='--', alpha=0.5)
plt.legend(loc="best")
plt.show()
F:\SoftWare\ProgramFiles\anaconda3\envs\pytorch\lib\site-packages\IPython\core\pylabtools.py:151: UserWarning: Glyph 8722 (\N{MINUS SIGN}) missing from current font.
  fig.canvas.print_figure(bytes_io, **kw)

[Figure: predicted (blue) and true (red) MEDV values on the test set]
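
Beyond eyeballing the plot, a single number summarizes test-set accuracy. A quick sketch (the exact values depend on the random split):

with torch.no_grad():
    test_mse = criterion(model1(test_data), test_label).item()
print('test MSE:', test_mse, ' test RMSE:', test_mse ** 0.5)   # RMSE is in $1000s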

7. Summary

During training, the loss initially got stuck around 80 and refused to decrease. After a morning of debugging, the cause turned out to be that the labels were one-dimensional: with a (N,) target and (N, 1) predictions, MSELoss broadcasts the difference into an (N, N) matrix, so the quantity being minimized is not the intended per-sample error. After reshaping the labels to (N, 1), training behaved normally,
and the predictions closely track the true values.
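
The pitfall in miniature (a minimal sketch on toy shapes; recent PyTorch versions also emit a UserWarning when the target and input sizes differ like this):

import torch

pred = torch.tensor([[1.], [2.], [3.], [4.]])   # shape (4, 1), like the model output
target = torch.tensor([1., 2., 3., 4.])         # shape (4,), one-dimensional labels
bad = torch.nn.functional.mse_loss(pred, target)                  # broadcasts to (4, 4) -> 2.5
good = torch.nn.functional.mse_loss(pred, target.reshape(-1, 1))  # what we meant -> 0.0
print(bad.item(), good.item())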

