Pytorch气温预测实战

news2025/1/16 13:54:58

数据集

数据有8个特征,一个标签值

自变量因变量
yearactual:当天的真实最高温度
month
day
week:星期几
temp_1:昨天的最高温度
temp_2:前天的最高温度值
average:在历史中,每年这一天的平均最高温度
friend:朋友猜测的温度
year,month,day,week,temp_2,temp_1,average,actual,friend
2016,1,1,Fri,45,45,45.6,45,29
2016,1,2,Sat,44,45,45.7,44,61
2016,1,3,Sun,45,44,45.8,41,56
2016,1,4,Mon,44,41,45.9,40,53
2016,1,5,Tues,41,40,46,44,41
2016,1,6,Wed,40,44,46.1,51,40
2016,1,7,Thurs,44,51,46.2,45,38
2016,1,8,Fri,51,45,46.3,48,34
2016,1,9,Sat,45,48,46.4,50,47
2016,1,10,Sun,48,50,46.5,52,49
2016,1,11,Mon,50,52,46.7,45,39
2016,1,12,Tues,52,45,46.8,49,61
2016,1,13,Wed,45,49,46.9,55,33
2016,1,14,Thurs,49,55,47,49,58
2016,1,15,Fri,55,49,47.1,48,65
2016,1,16,Sat,49,48,47.3,54,28
2016,1,17,Sun,48,54,47.4,50,47
2016,1,18,Mon,54,50,47.5,54,58
2016,1,19,Tues,50,54,47.6,48,53
2016,1,20,Wed,54,48,47.7,52,61
2016,1,21,Thurs,48,52,47.8,52,57
2016,1,22,Fri,52,52,47.9,57,60
2016,1,23,Sat,52,57,48,48,37
2016,1,24,Sun,57,48,48.1,51,54
2016,1,25,Mon,48,51,48.2,54,63
2016,1,26,Tues,51,54,48.3,56,61
2016,1,27,Wed,54,56,48.4,57,54
2016,1,28,Thurs,56,57,48.4,56,34
2016,1,29,Fri,57,56,48.5,52,49
2016,1,30,Sat,56,52,48.6,48,47
2016,1,31,Sun,52,48,48.7,47,61
2016,2,1,Mon,48,47,48.8,46,51
2016,2,2,Tues,47,46,48.8,51,56
2016,2,3,Wed,46,51,48.9,49,40
2016,2,4,Thurs,51,49,49,49,44
2016,2,5,Fri,49,49,49.1,53,45
2016,2,6,Sat,49,53,49.1,49,56
2016,2,7,Sun,53,49,49.2,51,63
2016,2,8,Mon,49,51,49.3,57,34
2016,2,9,Tues,51,57,49.4,62,57
2016,2,10,Wed,57,62,49.4,56,30
2016,2,11,Thurs,62,56,49.5,55,37
2016,2,12,Fri,56,55,49.6,58,33
2016,2,15,Mon,55,58,49.9,55,53
2016,2,16,Tues,58,55,49.9,56,55
2016,2,17,Wed,55,56,50,57,46
2016,2,18,Thurs,56,57,50.1,53,34
2016,2,19,Fri,57,53,50.2,51,42
2016,2,20,Sat,53,51,50.4,53,43
2016,2,21,Sun,51,53,50.5,51,46
2016,2,22,Mon,53,51,50.6,51,59
2016,2,23,Tues,51,51,50.7,60,43
2016,2,24,Wed,51,60,50.8,59,46
2016,2,25,Thurs,60,59,50.9,61,35
2016,2,26,Fri,59,61,51.1,60,65
2016,2,27,Sat,61,60,51.2,57,61
2016,2,28,Sun,60,57,51.3,53,66
2016,3,1,Tues,53,54,51.5,58,53
2016,3,2,Wed,54,58,51.6,55,37
2016,3,3,Thurs,58,55,51.8,59,71
2016,3,4,Fri,55,59,51.9,57,45
2016,3,5,Sat,59,57,52.1,64,46
2016,3,6,Sun,57,64,52.2,60,49
2016,3,7,Mon,64,60,52.4,53,71
2016,3,8,Tues,60,53,52.5,54,70
2016,3,9,Wed,53,54,52.7,55,57
2016,3,10,Thurs,54,55,52.8,56,50
2016,3,11,Fri,55,56,53,55,36
2016,3,12,Sat,56,55,53.1,52,65
2016,3,13,Sun,55,52,53.3,54,54
2016,3,14,Mon,52,54,53.4,49,44
2016,3,15,Tues,54,49,53.6,51,70
2016,3,16,Wed,49,51,53.7,53,65
2016,3,17,Thurs,51,53,53.9,58,62
2016,3,18,Fri,53,58,54,63,56
2016,3,19,Sat,58,63,54.2,61,62
2016,3,20,Sun,63,61,54.3,55,50
2016,3,21,Mon,61,55,54.5,56,52
2016,3,22,Tues,55,56,54.6,57,64
2016,3,23,Wed,56,57,54.7,53,70
2016,3,24,Thurs,57,53,54.9,54,72
2016,3,25,Fri,53,54,55,57,42
2016,3,26,Sat,54,57,55.2,59,54
2016,3,27,Sun,57,59,55.3,51,39
2016,3,28,Mon,59,51,55.5,56,47
2016,3,29,Tues,51,56,55.6,64,45
2016,3,30,Wed,56,64,55.7,68,57
2016,3,31,Thurs,64,68,55.9,73,56
2016,4,1,Fri,68,73,56,71,41
2016,4,2,Sat,73,71,56.2,63,45
2016,4,3,Sun,71,63,56.3,69,64
2016,4,4,Mon,63,69,56.5,60,45
2016,4,5,Tues,69,60,56.6,57,72
2016,4,6,Wed,60,57,56.8,68,64
2016,4,7,Thurs,57,68,56.9,77,38
2016,4,8,Fri,68,77,57.1,76,41
2016,4,9,Sat,77,76,57.2,66,74
2016,4,10,Sun,76,66,57.4,59,60
2016,4,11,Mon,66,59,57.6,58,40
2016,4,12,Tues,59,58,57.7,60,61
2016,4,13,Wed,58,60,57.9,59,77
2016,4,14,Thurs,60,59,58.1,59,66
2016,4,15,Fri,59,59,58.3,60,40
2016,4,16,Sat,59,60,58.5,68,59
2016,4,17,Sun,60,68,58.6,77,54
2016,4,18,Mon,68,77,58.8,89,39
2016,4,19,Tues,77,89,59,81,61
2016,4,20,Wed,89,81,59.2,81,66
2016,4,21,Thurs,81,81,59.4,73,55
2016,4,22,Fri,81,73,59.7,64,59
2016,4,23,Sat,73,64,59.9,65,57
2016,4,24,Sun,64,65,60.1,55,41
2016,4,25,Mon,65,55,60.3,59,77
2016,4,26,Tues,55,59,60.5,60,75
2016,4,27,Wed,59,60,60.7,61,50
2016,4,28,Thurs,60,61,61,64,73
2016,4,29,Fri,61,64,61.2,61,49
2016,4,30,Sat,64,61,61.4,68,78
2016,5,1,Sun,61,68,61.6,77,75
2016,5,2,Mon,68,77,61.9,87,59
2016,5,3,Tues,77,87,62.1,74,69
2016,5,4,Wed,87,74,62.3,60,61
2016,5,5,Thurs,74,60,62.5,68,56
2016,5,6,Fri,60,68,62.8,77,64
2016,5,7,Sat,68,77,63,82,83
2016,5,8,Sun,77,82,63.2,63,83
2016,5,9,Mon,82,63,63.4,67,64
2016,5,10,Tues,63,67,63.6,75,68
2016,5,11,Wed,67,75,63.8,81,60
2016,5,12,Thurs,75,81,64.1,77,81
2016,5,13,Fri,81,77,64.3,82,67
2016,5,14,Sat,77,82,64.5,65,65
2016,5,15,Sun,82,65,64.7,57,58
2016,5,16,Mon,65,57,64.8,60,53
2016,5,17,Tues,57,60,65,71,55
2016,5,18,Wed,60,71,65.2,64,56
2016,5,19,Thurs,71,64,65.4,63,56
2016,5,20,Fri,64,63,65.6,66,73
2016,5,21,Sat,63,66,65.7,59,49
2016,5,22,Sun,66,59,65.9,66,80
2016,5,23,Mon,59,66,66.1,65,66
2016,5,24,Tues,66,65,66.2,66,67
2016,5,25,Wed,65,66,66.4,66,60
2016,5,26,Thurs,66,66,66.5,65,85
2016,5,27,Fri,66,65,66.7,64,73
2016,5,28,Sat,65,64,66.8,64,64
2016,5,29,Sun,64,64,67,64,76
2016,5,30,Mon,64,64,67.1,71,69
2016,5,31,Tues,64,71,67.3,79,85
2016,6,1,Wed,71,79,67.4,75,58
2016,6,2,Thurs,79,75,67.6,71,77
2016,6,3,Fri,75,71,67.7,80,55
2016,6,4,Sat,71,80,67.9,81,76
2016,6,5,Sun,80,81,68,92,54
2016,6,6,Mon,81,92,68.2,86,71
2016,6,7,Tues,92,86,68.3,85,58
2016,6,8,Wed,86,85,68.5,67,81
2016,6,9,Thurs,85,67,68.6,65,80
2016,6,10,Fri,67,65,68.8,67,73
2016,6,11,Sat,65,67,69,65,87
2016,6,12,Sun,67,65,69.1,70,83
2016,6,13,Mon,65,70,69.3,66,79
2016,6,14,Tues,70,66,69.5,60,85
2016,6,15,Wed,66,60,69.7,67,69
2016,6,16,Thurs,60,67,69.8,71,87
2016,6,17,Fri,67,71,70,67,54
2016,6,18,Sat,71,67,70.2,65,77
2016,6,19,Sun,67,65,70.4,70,58
2016,6,20,Mon,65,70,70.6,76,79
2016,6,21,Tues,70,76,70.8,73,57
2016,6,22,Wed,76,73,71,75,78
2016,6,23,Thurs,73,75,71.3,68,56
2016,6,24,Fri,75,68,71.5,69,65
2016,6,25,Sat,68,69,71.7,71,89
2016,6,26,Sun,69,71,71.9,78,70
2016,6,27,Mon,71,78,72.2,85,84
2016,6,28,Tues,78,85,72.4,79,67
2016,6,29,Wed,85,79,72.6,74,81
2016,6,30,Thurs,79,74,72.8,73,87
2016,7,1,Fri,74,73,73.1,76,93
2016,7,2,Sat,73,76,73.3,76,84
2016,7,3,Sun,76,76,73.5,71,85
2016,7,4,Mon,76,71,73.8,68,86
2016,7,5,Tues,71,68,74,69,62
2016,7,6,Wed,68,69,74.2,76,86
2016,7,7,Thurs,69,76,74.4,68,72
2016,7,8,Fri,76,68,74.6,74,77
2016,7,9,Sat,68,74,74.9,71,60
2016,7,10,Sun,74,71,75.1,74,95
2016,7,11,Mon,71,74,75.3,74,71
2016,7,12,Tues,74,74,75.4,77,71
2016,7,13,Wed,74,77,75.6,75,56
2016,7,14,Thurs,77,75,75.8,77,77
2016,7,15,Fri,75,77,76,76,75
2016,7,16,Sat,77,76,76.1,72,61
2016,7,17,Sun,76,72,76.3,80,88
2016,7,18,Mon,72,80,76.4,73,66
2016,7,19,Tues,80,73,76.6,78,90
2016,7,20,Wed,73,78,76.7,82,66
2016,7,21,Thurs,78,82,76.8,81,84
2016,7,22,Fri,82,81,76.9,71,70
2016,7,23,Sat,81,71,77,75,86
2016,7,24,Sun,71,75,77.1,80,75
2016,7,25,Mon,75,80,77.1,85,81
2016,7,26,Tues,80,85,77.2,79,74
2016,7,27,Wed,85,79,77.3,83,79
2016,7,28,Thurs,79,83,77.3,85,76
2016,7,29,Fri,83,85,77.3,88,77
2016,7,30,Sat,85,88,77.3,76,70
2016,7,31,Sun,88,76,77.4,73,95
2016,8,1,Mon,76,73,77.4,77,65
2016,8,2,Tues,73,77,77.4,73,62
2016,8,3,Wed,77,73,77.3,75,93
2016,8,4,Thurs,73,75,77.3,80,66
2016,8,5,Fri,75,80,77.3,79,71
2016,8,6,Sat,80,79,77.2,72,60
2016,8,7,Sun,79,72,77.2,72,95
2016,8,8,Mon,72,72,77.1,73,65
2016,8,9,Tues,72,73,77.1,72,94
2016,8,10,Wed,73,72,77,76,68
2016,8,11,Thurs,72,76,76.9,80,80
2016,8,12,Fri,76,80,76.9,87,81
2016,8,13,Sat,80,87,76.8,90,73
2016,8,14,Sun,87,90,76.7,83,65
2016,8,15,Mon,90,83,76.6,84,70
2016,8,16,Tues,83,84,76.5,81,90
2016,8,23,Tues,84,81,75.7,79,89
2016,8,28,Sun,81,79,75,75,85
2016,8,30,Tues,79,75,74.6,70,63
2016,9,3,Sat,75,70,73.9,67,68
2016,9,4,Sun,70,67,73.7,68,64
2016,9,5,Mon,67,68,73.5,68,54
2016,9,6,Tues,68,68,73.3,68,79
2016,9,7,Wed,68,68,73,67,70
2016,9,8,Thurs,68,67,72.8,72,56
2016,9,9,Fri,67,72,72.6,74,78
2016,9,10,Sat,72,74,72.3,77,91
2016,9,11,Sun,74,77,72.1,70,70
2016,9,12,Mon,77,70,71.8,74,90
2016,9,13,Tues,70,74,71.5,75,82
2016,9,14,Wed,74,75,71.2,79,77
2016,9,15,Thurs,75,79,71,71,64
2016,9,16,Fri,79,71,70.7,75,52
2016,9,17,Sat,71,75,70.3,68,84
2016,9,18,Sun,75,68,70,69,90
2016,9,19,Mon,68,69,69.7,71,88
2016,9,20,Tues,69,71,69.4,67,81
2016,9,21,Wed,71,67,69,68,76
2016,9,22,Thurs,67,68,68.7,67,56
2016,9,23,Fri,68,67,68.3,64,61
2016,9,24,Sat,67,64,68,67,64
2016,9,25,Sun,64,67,67.6,76,62
2016,9,26,Mon,67,76,67.2,77,74
2016,9,27,Tues,76,77,66.8,69,64
2016,9,28,Wed,77,69,66.5,68,62
2016,9,29,Thurs,69,68,66.1,66,57
2016,9,30,Fri,68,66,65.7,67,74
2016,10,1,Sat,66,67,65.3,63,54
2016,10,2,Sun,67,63,64.9,65,82
2016,10,3,Mon,63,65,64.5,61,49
2016,10,4,Tues,65,61,64.1,63,60
2016,10,5,Wed,61,63,63.7,66,48
2016,10,6,Thurs,63,66,63.3,63,55
2016,10,7,Fri,66,63,62.9,64,78
2016,10,8,Sat,63,64,62.5,68,73
2016,10,9,Sun,64,68,62.1,57,55
2016,10,10,Mon,68,57,61.8,60,62
2016,10,11,Tues,57,60,61.4,62,58
2016,10,12,Wed,60,62,61,66,52
2016,10,13,Thurs,62,66,60.6,60,57
2016,10,14,Fri,66,60,60.2,60,78
2016,10,15,Sat,60,60,59.9,62,46
2016,10,16,Sun,60,62,59.5,60,40
2016,10,17,Mon,62,60,59.1,60,62
2016,10,18,Tues,60,60,58.8,61,53
2016,10,19,Wed,60,61,58.4,58,41
2016,10,20,Thurs,61,58,58.1,62,43
2016,10,21,Fri,58,62,57.8,59,44
2016,10,22,Sat,62,59,57.4,62,44
2016,10,23,Sun,59,62,57.1,62,67
2016,10,24,Mon,62,62,56.8,61,70
2016,10,25,Tues,62,61,56.5,65,70
2016,10,26,Wed,61,65,56.2,58,41
2016,10,27,Thurs,65,58,55.9,60,39
2016,10,28,Fri,58,60,55.6,65,52
2016,10,29,Sat,60,65,55.3,68,65
2016,10,31,Mon,65,68,54.8,59,62
2016,11,1,Tues,68,59,54.5,57,61
2016,11,2,Wed,59,57,54.2,57,70
2016,11,3,Thurs,57,57,53.9,65,35
2016,11,4,Fri,57,65,53.7,65,38
2016,11,5,Sat,65,65,53.4,58,41
2016,11,6,Sun,65,58,53.2,61,71
2016,11,7,Mon,58,61,52.9,63,35
2016,11,8,Tues,61,63,52.7,71,49
2016,11,9,Wed,63,71,52.4,65,42
2016,11,10,Thurs,71,65,52.2,64,38
2016,11,11,Fri,65,64,51.9,63,55
2016,11,12,Sat,64,63,51.7,59,63
2016,11,13,Sun,63,59,51.4,55,64
2016,11,14,Mon,59,55,51.2,57,42
2016,11,15,Tues,55,57,51,55,46
2016,11,16,Wed,57,55,50.7,50,34
2016,11,17,Thurs,55,50,50.5,52,57
2016,11,18,Fri,50,52,50.3,55,35
2016,11,19,Sat,52,55,50,57,56
2016,11,20,Sun,55,57,49.8,55,30
2016,11,21,Mon,57,55,49.5,54,67
2016,11,22,Tues,55,54,49.3,54,58
2016,11,23,Wed,54,54,49.1,49,38
2016,11,24,Thurs,54,49,48.9,52,29
2016,11,25,Fri,49,52,48.6,52,41
2016,11,26,Sat,52,52,48.4,53,58
2016,11,27,Sun,52,53,48.2,48,53
2016,11,28,Mon,53,48,48,52,44
2016,11,29,Tues,48,52,47.8,52,50
2016,11,30,Wed,52,52,47.6,52,44
2016,12,1,Thurs,52,52,47.4,46,39
2016,12,2,Fri,52,46,47.2,50,41
2016,12,3,Sat,46,50,47,49,58
2016,12,4,Sun,50,49,46.8,46,53
2016,12,5,Mon,49,46,46.6,40,65
2016,12,6,Tues,46,40,46.4,42,56
2016,12,7,Wed,40,42,46.3,40,62
2016,12,8,Thurs,42,40,46.1,41,36
2016,12,9,Fri,40,41,46,36,54
2016,12,10,Sat,41,36,45.9,44,65
2016,12,11,Sun,36,44,45.7,44,35
2016,12,12,Mon,44,44,45.6,43,42
2016,12,13,Tues,44,43,45.5,40,46
2016,12,14,Wed,43,40,45.4,39,49
2016,12,15,Thurs,40,39,45.3,39,46
2016,12,16,Fri,39,39,45.3,35,39
2016,12,17,Sat,39,35,45.2,35,38
2016,12,18,Sun,35,35,45.2,39,36
2016,12,19,Mon,35,39,45.1,46,51
2016,12,20,Tues,39,46,45.1,51,62
2016,12,21,Wed,46,51,45.1,49,39
2016,12,22,Thurs,51,49,45.1,45,38
2016,12,23,Fri,49,45,45.1,40,35
2016,12,24,Sat,45,40,45.1,41,39
2016,12,25,Sun,40,41,45.1,42,31
2016,12,26,Mon,41,42,45.2,42,58
2016,12,27,Tues,42,42,45.2,47,47
2016,12,28,Wed,42,47,45.3,48,58
2016,12,29,Thurs,47,48,45.3,48,65
2016,12,30,Fri,48,48,45.4,57,42
2016,12,31,Sat,48,57,45.5,40,57

数据集处理

读取数据

# -*-coding:utf-8-*-
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import warnings
warnings.filterwarnings("ignore")
#   year  month  day  week  temp_2  temp_1  average  actual(y)  friend
features=pd.read_csv("temps.csv")

# print(features.head())
#数据维度 (348, 9)
print("数据维度",features.shape)
print(features.head(5))

数据维度 (348, 9)
   year  month  day  week  temp_2  temp_1  average  actual  friend
0  2016      1    1   Fri      45      45     45.6      45      29
1  2016      1    2   Sat      44      45     45.7      44      61
2  2016      1    3   Sun      45      44     45.8      41      56
3  2016      1    4   Mon      44      41     45.9      40      53
4  2016      1    5  Tues      41      40     46.0      44      41

 把年月日转换为datetime格式

#处理事件数据
import datetime
#分别得到年、月、日
years=features["year"]
months=features["month"]
days=features["day"]

#datetime格式
"""
['2016-1-1', '2016-1-2', '2016-1-3', '2016-1-4', '2016-1-5']
"""
dates=[str(int(year))+"-"+str(int(month))+"-"+str(int(day)) for year,month,day in zip(years,months,days)]
print(dates[:5])


"""
[datetime.datetime(2016, 1, 1, 0, 0), 
datetime.datetime(2016, 1, 2, 0, 0), 
datetime.datetime(2016, 1, 3, 0, 0), 
datetime.datetime(2016, 1, 4, 0, 0), 
datetime.datetime(2016, 1, 5, 0, 0)]
"""
dates=[datetime.datetime.strptime(date, "%Y-%m-%d") for date in dates]
print(dates[:5])

['2016-1-1', '2016-1-2', '2016-1-3', '2016-1-4', '2016-1-5']
[datetime.datetime(2016, 1, 1, 0, 0),

datetime.datetime(2016, 1, 2, 0, 0),

datetime.datetime(2016, 1, 3, 0, 0),

datetime.datetime(2016, 1, 4, 0, 0),

datetime.datetime(2016, 1, 5, 0, 0)]

画图显示数据

plt.subplots(nrows=2, ncols=2, figsize = (20,20)) 2行,2列,图像大小20*20
fig.autofmt_xdate(rotation = 45) X轴的字体旋转角度,也就是dates所对应的年月日信息显示旋转角度,这里是45°,如:2016-01等,倾斜显示

ax1.plot(dates, dataset['actual']) X轴为dates(年月日),Y轴为actual(实际的真实温度值)
ax1.set_xlabel(''); ax1.set_ylabel('Temperature'); ax1.set_title('Actual Max Temp') X轴标签为空,Y轴标签为Temperature,整体标题为Actual Max Temp
其他的同理

plt.tight_layout(pad=1, h_pad=1, w_pad=1) #子图间隔有多大
pad:图形边和子图的边之间进行填充
h_pad,w_pad:相邻子图的边之间的填充(高度/宽度)

def drawData(dates,features):
    #准备画图
    #指定默认风格
    plt.style.use("fivethirtyeight")

    #设置布局
    fig, ((ax1,ax2),(ax3,ax4)) = plt.subplots(nrows=2,ncols=2,figsize=(10,10))
    #X轴标签倾斜显示
    fig.autofmt_xdate(rotation=45)

    #标签值
    ax1.plot(dates,features["actual"])
    ax1.set_xlabel(" "); ax1.set_ylabel("Actual Temperature"); ax1.set_title("Max Temp")

    #昨天
    ax2.plot(dates,features["temp_1"])
    ax2.set_xlabel(" "); ax2.set_ylabel("temp_1 Temperature"); ax2.set_title("Previous Max Temp")

    #前天
    ax3.plot(dates,features["temp_2"])
    ax3.set_xlabel("Date"); ax3.set_ylabel("temp_2 Temperatiure"); ax3.set_title("Two Days Prior Previous Max Temp")

    ax4.plot(dates,features["friend"])
    ax4.set_xlabel("Date"); ax4.set_ylabel("friend Temperatiure"); ax4.set_title("Friend Estimate")

    #子图间隔多大
    plt.tight_layout(pad=2)
    plt.show()

drawData(dates,features)


 把星期几转为独热编码

#读热编码
#year  month  day  temp_2  ... week_Fri week_Mon week_Sat week_Sun  week_Thurs  week_Tues  week_Wed
features=pd.get_dummies(features)
print( features.head(8) )

   year  month  day  temp_2  ...  week_Sun  week_Thurs  week_Tues  week_Wed
0  2016      1    1      45  ...         0           0          0         0
1  2016      1    2      44  ...         0           0          0         0
2  2016      1    3      45  ...         1           0          0         0
3  2016      1    4      44  ...         0           0          0         0
4  2016      1    5      41  ...         0           0          1         0
5  2016      1    6      40  ...         0           0          0         1
6  2016      1    7      44  ...         0           1          0         0
7  2016      1    8      51  ...         0           0          0         0

 提取actual的实际值

#标签,把标签值actual提取出来
#labels=[45 44 41 40 44 51 45 48 50 ....]
labels=np.array(features["actual"])

#在特征中去掉标签, 去掉actual,因为actual是y
features=features.drop("actual",axis=1)

#名字单独保存一下,以备后患
feature_list=list(features.columns)
print(feature_list)

['year', 'month', 'day', 'temp_2', 'temp_1', 'average', 'friend', 'week_Fri', 'week_Mon', 'week_Sat', 'week_Sun', 'week_Thurs', 'week_Tues', 'week_Wed']

标准化数据

(1)把features变成了数值,用科学计数法,(348,14)

features=np.array(features)
print(features)

[[2.016e+03 1.000e+00 1.000e+00 ... 0.000e+00 0.000e+00 0.000e+00]
 [2.016e+03 1.000e+00 2.000e+00 ... 0.000e+00 0.000e+00 0.000e+00]
 [2.016e+03 1.000e+00 3.000e+00 ... 0.000e+00 0.000e+00 0.000e+00]
 ...
 [2.016e+03 1.200e+01 2.900e+01 ... 1.000e+00 0.000e+00 0.000e+00]
 [2.016e+03 1.200e+01 3.000e+01 ... 0.000e+00 0.000e+00 0.000e+00]
 [2.016e+03 1.200e+01 3.100e+01 ... 0.000e+00 0.000e+00 0.000e+00]]

(2)标准化 (x-平均值)/标准差

#标准化 (x-平均值)/标准差
from sklearn import  preprocessing
input_features=preprocessing.StandardScaler().fit_transform(features)

print( input_features )

[[ 0.         -1.5678393  -1.65682171 ... -0.40482045 -0.41913682
  -0.40482045]
 [ 0.         -1.5678393  -1.54267126 ... -0.40482045 -0.41913682
  -0.40482045]
 [ 0.         -1.5678393  -1.4285208  ... -0.40482045 -0.41913682
  -0.40482045]
 ...
 [ 0.          1.5810006   1.53939107 ...  2.47023092 -0.41913682
  -0.40482045]
 [ 0.          1.5810006   1.65354153 ... -0.40482045 -0.41913682
  -0.40482045]
 [ 0.          1.5810006   1.76769198 ... -0.40482045 -0.41913682
  -0.40482045]]

第一种方式构建网络模型

(1)构建x和y矩阵。

构建tensor类型的x和y变量

#x:torch.Size([348, 14])
x=torch.tensor(input_features,dtype=float)
y=torch.tensor(labels,dtype=float)

(2)权重参数初始化

#权重参数初始化
weights=torch.randn((14,128),dtype=float,requires_grad=True)
biases=torch.randn(128,dtype=float ,requires_grad=True)
weights2=torch.randn((128,1),dtype=float,requires_grad=True)
biases2=torch.randn(1,dtype=float ,requires_grad=True)
#用于梯度下降的学习率
learning_rate=0.001
#记录损失值
losses=[]

 (3)搭建网络模型

#搭建网络模型
for i in range(1000):
    #计算隐层
    hidden=x.mm(weights)+biases
    #激活函数
    hidden=torch.relu(hidden)
    #预测结果
    predictions=hidden.mm(weights2)+biases2

    #计算损失
    #cost/w 1/2n
    loss=torch.mean((predictions - y)**2)
    losses.append(loss.data.numpy())
    #打印损失值
    if i%100==0:
        print("loss:",loss)

    #反向传播计算
    loss.backward()

    #更新参数
    weights.data.add_(- learning_rate*weights.grad.data)
    biases.data.add_(- learning_rate*biases.grad.data)
    weights2.data.add_(- learning_rate*weights2.grad.data)
    biases2.data.add_(- learning_rate*biases2.grad.data)

    #每次迭代都得清空,梯度不清零会累加的
    weights.grad.data.zero_()
    biases.grad.data.zero_()
    weights2.grad.data.zero_()
    biases2.grad.data.zero_()

print(predictions.shape)

loss: tensor(1214.2573, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(155.9543, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(147.3351, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(144.7511, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(143.4206, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(142.5542, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(141.9567, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(141.5120, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(141.1755, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(140.9053, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([348, 1])

(4) 把actual的实际值和prediction画在一张图

loss是140.9,所以拟合效果非常差

 第二种方式构建神经网络

 (1)数据预处理

# -*-coding:utf-8-*-
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import warnings
warnings.filterwarnings("ignore")
#   year  month  day  week  temp_2  temp_1  average  actual(y)  friend
features=pd.read_csv("temps.csv")

# print(features.head())
#数据维度 (348, 9)
print("数据维度",features.shape)


#处理事件数据
import datetime
#分别得到年、月、日
years=features["year"]
months=features["month"]
days=features["day"]

#datetime格式
dates=[str(int(year))+"-"+str(int(month))+"-"+str(int(day)) for year,month,day in zip(years,months,days)]
dates=[datetime.datetime.strptime(date, "%Y-%m-%d") for date in dates]

print(dates[:5])


#hotcode
#year  month  day  temp_2  ... week_Fri week_Mon week_Sat week_Sun  week_Thurs  week_Tues  week_Wed
features=pd.get_dummies(features)
# print( features.head(5) )

#标签
labels=np.array(features["actual"])

#在特征中去掉标签, 去掉actual,因为actual是y
features=features.drop("actual",axis=1)

#名字单独保存一下,以备后患
feature_list=list(features.columns)

#变成了数值,用科学计数法,(348,14)
features=np.array(features)
# print(features)

#标准化 (x-平均值)/标准差
from sklearn import  preprocessing
input_features=preprocessing.StandardScaler().fit_transform(features)

# print( input_features )
# print( input_features[0] )

#构建网络模型
#x:torch.Size([348, 14])
x=torch.tensor(input_features,dtype=float)
y=torch.tensor(labels,dtype=float)

(2)构建模型,计算预测值

#更简单的构建网络模型
#input_size:特征数:14
input_size=input_features.shape[1]

#隐藏层有128个神经元
hidden_size=128
#神经网络的最终输出1个数,即predictions_actual
output_size=1
#每次读取16行数据
batch_size=16

#Sequential:按顺序执行
my_nn=torch.nn.Sequential(
    #第一层神经元(14,128)
    torch.nn.Linear(input_size,hidden_size),
    torch.nn.Sigmoid(),
    #第二层神经元(128,1)
    torch.nn.Linear(hidden_size,output_size)
)
#损失函数
cost=torch.nn.MSELoss(reduction="mean")

#梯度下降,更新权重参数
#Adam:很强的工具,
optimizer=torch.optim.Adam(my_nn.parameters(),lr=0.001)

#训练网络
losses=[]
for i in range (2000):
    batch_loss=[]
    #MINI-Batch方法进行训练,0~348,每次训练选batch_size(16)个样本
    for start in range (0,len(input_features),batch_size):
        end=start+batch_size if start+batch_size <len(input_features) else len(input_features)
        xx=torch.tensor(input_features[start:end],dtype=torch.float,requires_grad=True)
        yy=torch.tensor(labels[start:end],dtype=torch.float,requires_grad=True)
        prediction=my_nn(xx)
        loss=cost(prediction,yy)
        #梯度清零
        optimizer.zero_grad()
        #反向传播
        loss.backward(retain_graph=True)
        #梯度下降:step
        optimizer.step()
        batch_loss.append(loss.data.numpy())

    #计算损失
    if i%100==0:
        losses.append(np.mean(batch_loss))
        print(i,np.mean(batch_loss))


#预测训练结果
x=torch.tensor(input_features,dtype=torch.float)
predict=my_nn(x).data.numpy()

0 3938.2195
100 37.869488
200 35.65198
300 35.276585
400 35.106327
500 34.972412
600 34.85649
700 34.742622
800 34.62497
900 34.503555

(3) 画图

#转换日期格式
#datetime格式
dates=[str(int(year))+"-"+str(int(month))+"-"+str(int(day)) for year,month,day in zip(years,months,days)]
dates=[datetime.datetime.strptime(date, "%Y-%m-%d") for date in dates]

#true创建
true_data=pd.DataFrame(data={"date":dates,"actual":labels})

#同理,再创建一个来存日期和其对应的模型预测值
months=features[:,feature_list.index("month")]
days=features[:,feature_list.index("day")]
years=features[:,feature_list.index("year")]

test_dates=[str(int(year))+"-"+str(int(month))+"-"+str(int(day)) for year,month,day in zip(years,months,days)]
#          [datetime.datetime.strptime(date, "%Y-%m-%d") for date in dates]
test_dates=[datetime.datetime.strptime(date, "%Y-%m-%d") for date in test_dates]

predictions_data=pd.DataFrame(data={"date":test_dates,"prediction":predict.reshape(-1)})


#真实值
plt.plot(true_data["date"],true_data["actual"],"b-",label="actual")

#预测值
plt.plot(predictions_data["date"],predictions_data["prediction"],"ro",label="prediction")
plt.xticks(rotation = 60);
# plt.xticks(rotation="60")
plt.legend()
#图名
plt.xlabel("Date");plt.ylabel("Maximum Temperature(F)");plt.title("Actual and Predicted Values");
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/810169.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux】HTTPS协议是如何保证数据安全的

​&#x1f320; 作者&#xff1a;阿亮joy. &#x1f386;专栏&#xff1a;《学会Linux》 &#x1f387; 座右铭&#xff1a;每个优秀的人都有一段沉默的时光&#xff0c;那段时光是付出了很多努力却得不到结果的日子&#xff0c;我们把它叫做扎根 目录 &#x1f449;基础概念…

【Linux】进程篇Ⅰ:进程信息、进程状态、环境变量、进程地址空间

文章目录 一、概述二、查看进程信息1. 系统文件夹 /proc2. 用户级工具 ps3. getpid() 函数&#xff1a;查看进程 PID4. 用 kill 杀进程5. 进程优先级 二、进程状态分析0. 1. R (running) 运行状态2. S (sleeping) 休眠状态3. D (disk sleep) 不可中断的休眠状态4. T (stopped) …

【数据结构】顺序表(SeqList)(增、删、查、改)详解

一、顺序表的概念和结构 1、顺序表的概念&#xff1a; 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性结构&#xff0c;一般情况下采用数组存储。在数组上完成数据的增删查改。 2、顺序表的结构&#xff1a; &#xff08;1&#xff09;静态顺序表&#xff1a;使…

Redis Cluster 在Spring中遇到的问题

Redis集群配置可能会在运行时更改。可以添加新节点&#xff0c;可以更改特定插槽的主节点。还有可能因为master宕机或网络抖动等原因&#xff0c;引起了主从切换。 无法感知集群槽位变化 SpringBoot2.x 开始默认使用的 Redis 客户端由 Jedis 变成了 Lettuce&#xff0c;但是当…

忽略nan值,沿指定轴计算标准(偏)差numpy.nanstd()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 沿指定轴方向 计算标准(偏)差 numpy.nanstd() [太阳]选择题 import numpy as np a np.array([[1,2],[np.nan,3]]) print("【显示】a ") print(a) print("【执行】np.std(a)&qu…

QT项目代码去UI界面常用开发步骤

QT项目代码去UI界面常用开发步骤 因项目开发需求&#xff0c;领导要求整个QT项目中不要用UI方式来实现界面&#xff0c;这样能保障程序运行稳定性以及代码的逻辑和可读性,先记录具体操作步骤如下&#xff1a; 1、首先我们通过拖控件的方式来实现界面的设计效果&#xff0c…

ARM汇编基本变量的定义和使用

一、ARM汇编中基本变量是什么? 数字变量: GBLA LCLA SETA 逻辑变量:GBLL LCLL SETL 字符串:GBLS LCLS SETLS 注意需要TAB键定义变量和行首改变值 二、使用步骤 1.引入库 代码如下(示例): GBLA led_num Reset_Handler PROCEXPORT Reset_Handler [WEA…

HCIP BGP综合实验

题目 1、AS1存在两个环回&#xff0c;一个地址为192.168.1.0/24该地址不能在任何协议中宣告&#xff1b; 2、AS3中存在两个环回&#xff0c;一个地址为192.168.2.0/24该地址不能在任何协议中宣告&#xff0c;最终要求这两个环回可以互相通讯&#xff1b; 3、AS间的骨干链路I…

Vue3搭建启动

Vue3搭建&启动 一、创建项目二、启动项目三、配置项目1、添加编辑器配置文件2、配置别名3、处理sass/scss4、处理tsx 四、添加Eslint 一、创建项目 npm create vite 1.project-name 输入项目名vue3-vite 2.select a framework 选择框架 3.select a variant 选择语言 二、启…

idea 安装 插件jrebel 报错LS client not configured.

这个报错找了好久&#xff0c;有博主说版本不对&#xff0c;我脑子没反应过来以为是随便换一个低版本的就行&#xff0c;没想到只能是2022.4.1 这个版本才行 一定要用jrebel 2022.4.1的插件版本&#xff01;&#xff01;&#xff01;&#xff01;&#xff01; 插件下载地址&…

网络面试合集

传输层的数据结构是什么&#xff1f; 就是在问他的协议格式&#xff1a;UDP&TCP 2.1.1三次握手 通信前&#xff0c;要先建立连接&#xff0c;确保双方都是在线&#xff0c;具有数据收发的能力。 2.1.2四次挥手 通信结束后&#xff0c;会有一个断开连接的过程&#xff0…

❤️创意网页:绚丽粒子雨动画

✨博主&#xff1a;命运之光 &#x1f338;专栏&#xff1a;Python星辰秘典 &#x1f433;专栏&#xff1a;web开发&#xff08;简单好用又好看&#xff09; ❤️专栏&#xff1a;Java经典程序设计 ☀️博主的其他文章&#xff1a;点击进入博主的主页 前言&#xff1a;欢迎踏入…

Codeforces Round 889 (Div. 2) 题解

晚上睡不着就来总结一下叭~&#xff08;OoO&#xff09; 赛后榜(希望不要被Hack...Orz) 终榜&#xff01;&#xff01;&#xff01; 瞬间的辉煌(呜呜呜~) 先不放图了。。怕被dalaoHack...呜呜呜~ 总结 7.29半夜比赛&#xff0c;本来是不想打的&#xff0c;感觉最近做的题太多…

Manjaro KDE 22.1.3vmware无法复制文件

Wayland 是 X11 的现代替代品&#xff0c;几十年来 X11 一直是 Linux 上的默认窗口系统。 Wayland 是一种通信协议&#xff0c;定义 X Window 显示服务器和客户端应用程序之间的消息传递。 软件还不兼容 使用X11即可

JavaScript中的switch语句

switch语句和if语句一样&#xff0c;同样是运用于条件循环中&#xff1b; 下面例子我们用switch实现 例如如果今天是周一就学习HTML&#xff0c;周二学习CSS和JavaScript&#xff0c;周三学习vue&#xff0c;周四&#xff0c;周五学习node.js&#xff0c;周六周日快乐玩耍&…

微服务项目,maven无法加载其他服务依赖

微服务项目&#xff0c;导入了工具类工程&#xff0c;但是一直报错&#xff0c;没有该类&#xff0c; 检查maven 这里的Maven的版本与idea版本不匹配可能是导致依赖加载失败的最重要原因 检查maven配置&#xff0c;我这是原来的maven&#xff0c;home 修改之后,就不报错了

39.密码长度改变图片模糊

密码长度改变图片模糊 html部分 <div class"bg"></div> <div class"container"><h1>Image Password Strength</h1><h3>Change the password to see the effect</h3><div class"email" style&quo…

Mybatis-Flex 比 MyBatis-Plus更轻量,高性能

一、Mybatis-Flex是什么&#xff1f; Mybatis-Flex 是一个优雅的 Mybatis 增强框架&#xff0c;它非常轻量、同时拥有极高的性能与灵活性。我们可以轻松的使用 Mybaits-Flex 链接任何数据库&#xff0c;其内置的 QueryWrapper^亮点 帮助我们极大的减少了 SQL 编写的工作的同时…

MQTT服务器详细介绍:连接物联网的通信枢纽

随着物联网技术的不断发展&#xff0c;MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;协议作为一种轻量级、可靠、灵活的通信协议&#xff0c;被广泛应用于物联网领域。在MQTT系统中&#xff0c;MQTT服务器扮演着重要的角色&#xff0c;作为连接物联网设备和…

MDK5__配色方案的修改

一、必要的知识 与MDK主题相关的文件有两个&#xff0c;在X:\Keil_v5\UV4路径下&#xff1a; global.propglobal.prop.def其中global.prop.def是系统默认的主题配置 如果修改过字体等&#xff0c;系统会生成一个global.prop。 二、修改的步骤 1、打开工程 菜单 Edit 下 Con…