(二)Pytorch快速搭建神经网络模型实现气温预测回归(代码+详细注解)

news2024/11/17 19:52:01

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 一、数据集
  • 二、导入数据以及展示部分
    • 1.导入数据集以及对数据集进行处理
    • 2.展示数据(看看就好)
  • 三(1)、搭建网络进行预测(理解版)
  • 三(2)、搭建网络进行预测(应用版)
  • 四、 对预测结果进行一个展示,蓝色真实值,红色预测值
  • 总结


前言

深度学习pytorch系列第二篇,第一篇实现的是分类任务,这篇是回归任务,大差不差,重在理解,具体的理解内容我都以注释的形式放在了代码中,方便大家阅读


一、数据集

想要复现的可以下载
链接:网盘链接
提取码:k6a4

二、导入数据以及展示部分

1.导入数据集以及对数据集进行处理

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
# 过滤警告
import warnings
warnings.filterwarnings("ignore")
# 读取数据
features = pd.read_csv('data/temps.csv')
#
#看看数据长什么样子
# print(features.head())
# print('数据维度:', features.shape)
# 数据维度:(348, 9),348条数据,每条8个特征x,1个标签y
# 处理时间数据
import datetime
# 分别得到年,月,日
years = features['year']
months = features['month']
days = features['day']
#
# # datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]
# 在打印的结果中,每个datetime.datetime对象的后面两个0表示小时和分钟,没有时默认为0
# print(dates[:5])
# 独热编码
# # 将字符串进行onehot
# # 周一 周二 周三 周四 周五 周六 周天
# # 如果是周一,编码就是
# # 1000000
# Pandas库中的get_dummies函数,是一种独热编码(One-Hot Encoding)的方法
features = pd.get_dummies(features)

# print(features.head(5))
# print(features.shape)
# 此时的数据维度:(348, 15),多的7个是日期的七天
# 取标签
labels = np.array(features['actual'])
# 在特征中去掉标签,features.drop,去掉标签列
features= features.drop('actual', axis = 1)
# 名字单独保存一下,以备后患
feature_list = list(features.columns)
# 转换成合适的格式
features = np.array(features)
# print(features.shape)
# print(features)
"""
数据标准化
由于神经网络在训练的过程中具有倾向性,数值越大,认为越重要
# 但是在月份这种重要程度与数值无关的特征上,这种倾向性就会出错
# 因此进行标准化,使数据以零点为中心均匀分布
# (x-u)/σ
# x-u  去均值
# /σ  除以标准差:让离散数据更加收敛
标准化通常是针对特征而不是标签的。
标准化的目的是使特征具有相同的尺度,以便模型能够更好地学习权重并提高模型的性能。
标签(也称为目标变量)通常不需要标准化,因为它们是模型试图预测的值,而不是用于学习权重的输入。
"""
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
"""
[ 0.         -1.5678393  -1.65682171 -1.48452388 -1.49443549 -1.3470703
 -1.98891668  2.44131112 -0.40482045 -0.40961596 -0.40482045 -0.40482045
 -0.41913682 -0.40482045]
 标准化处理后的数据以零点为中心,均匀分布
"""

上述代码中的初始数据集为:
在这里插入图片描述
处理完成后的数据样貌:
在这里插入图片描述

2.展示数据(看看就好)

代码如下(示例):

# 该段是展示一下数据的样貌
plt.style.use('fivethirtyeight')
# 设置布局
# 4个子图,两行两列
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize = (10,10))
# 坐标倾斜45度
fig.autofmt_xdate(rotation = 45)

# 标签值
ax1.plot(dates, features['actual'])
ax1.set_xlabel(''); ax1.set_ylabel('Temperature'); ax1.set_title('Max Temp')
# 昨天
ax2.plot(dates, features['temp_1'])
ax2.set_xlabel(''); ax2.set_ylabel('Temperature'); ax2.set_title('Previous Max Temp')
#
# 前天
ax3.plot(dates, features['temp_2'])
ax3.set_xlabel('Date'); ax3.set_ylabel('Temperature'); ax3.set_title('Two Days Prior Max Temp')
#
# 朋友感觉的值
ax4.plot(dates, features['friend'])
ax4.set_xlabel('Date'); ax4.set_ylabel('Temperature'); ax4.set_title('Friend Estimate')
# 子图之间间隔多少
plt.tight_layout(pad=2)
plt.show()

展示图如下:
在这里插入图片描述


三(1)、搭建网络进行预测(理解版)

该过程是一步一步构建网络,促进理解,后边会附上更为简单的网络结构


x = torch.tensor(input_features, dtype=float)
y = torch.tensor(labels, dtype=float)
# # 权重参数初始化
# (14, 128),将14个特征转成128个神经元,可以理解为转成128个特征
# requires_grad = True,是否求导,也就是是否记录梯度
weights = torch.randn((14, 128), dtype=float, requires_grad=True)
biases = torch.randn(128, dtype=float, requires_grad=True)
weights2 = torch.randn((128, 1), dtype=float, requires_grad=True)
biases2 = torch.randn(1, dtype=float, requires_grad=True)
# 学习率  :决定梯度更新幅度的大小,计算出来的梯度只能确定方向
# 这个幅度不能太大
learning_rate = 0.001
losses = []
# 迭代次数,每次算梯度,然后更新
for i in range(1000):
    # 计算隐层
    hidden = x.mm(weights) + biases
    # 加入激活函数,非线性映射
    hidden = torch.relu(hidden)
    # 预测结果  :h1*w2+b2=预测值
    predictions = hidden.mm(weights2) + biases2
    # 通计算损失
    loss = torch.mean((predictions - y) ** 2)
    losses.append(loss.data.numpy())

    # 打印损失值
    if i % 100 == 0:
        print('loss:', loss)
    # 返向传播计算
    loss.backward()

    # 更新参数
    #     grad.data  取梯度,然后乘以学习率,应该沿着梯度的反方向更新
    weights.data.add_(- learning_rate * weights.grad.data)
    biases.data.add_(- learning_rate * biases.grad.data)
    weights2.data.add_(- learning_rate * weights2.grad.data)
    biases2.data.add_(- learning_rate * biases2.grad.data)

    # 每次迭代都得记得清空
    #     每次迭代过程都是独立的,之前计算的梯度要清零
    # 在torch中,如果不清零,梯度就会累加
    weights.grad.data.zero_()
    biases.grad.data.zero_()
    weights2.grad.data.zero_()
    biases2.grad.data.zero_()
print(predictions.shape)
print(predictions)

三(2)、搭建网络进行预测(应用版)

实际应用中,往往会这样实现

# 更简单的构建网络模型
# 取特征个数
# 0是样本数;1是特征数
input_size = input_features.shape[1]
# print(input_size)  14 有14个特征
# 隐层个数
hidden_size = 128
output_size = 1
batch_size = 16
# Sequential序列模块,按顺序执行
my_nn = torch.nn.Sequential(
    # 计算隐层,相当于wx+b,参数是自动更新的
    torch.nn.Linear(input_size, hidden_size),
#     激活函数
    torch.nn.Sigmoid(),
#     预测结果  :h1*w2+b2=预测值
    torch.nn.Linear(hidden_size, output_size),
)
# 计算损失
# reduction='mean  平均损失
cost = torch.nn.MSELoss(reduction='mean')
# 优化器
# my_nn.parameters() 更新nn中所有参数
optimizer = torch.optim.Adam(my_nn.parameters(), lr = 0.001)
# ADM优化器,比SGD(梯度下降)效果好,效率高
# 训练网络
losses = []
# 迭代1000次
for i in range(1000):
    #     每次取一个batch的数据,每次只取一批数据
    batch_loss = []
    # MINI-Batch方法来进行训练
    #   for start in range(0, len(input_features), batch_size):
    # 从0开始,到整个数据结束,取batch,间隔是一个batch_size大小
    for start in range(0, len(input_features), batch_size):
        end = start + batch_size if start + batch_size < len(input_features) else len(input_features)  # 判断索引越界
        xx = torch.tensor(input_features[start:end], dtype=torch.float, requires_grad=True)
        yy = torch.tensor(labels[start:end], dtype=torch.float, requires_grad=True)
        prediction = my_nn(xx)
        loss = cost(prediction, yy)
        #         通过优化器进行梯度清零
        optimizer.zero_grad()
        #     反向传播
        loss.backward(retain_graph=True)
        #     更新参数
        optimizer.step()
        #     将每一个batch的损失相加
        batch_loss.append(loss.data.numpy())

    # 打印损失
    if i % 100 == 0:
        losses.append(np.mean(batch_loss))
        print(i, np.mean(batch_loss))
x = torch.tensor(input_features, dtype = torch.float)
# 所有的数据进行预测,得到结果,进行画图
predict = my_nn(x).data.numpy()

四、 对预测结果进行一个展示,蓝色真实值,红色预测值

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data = {'date': dates, 'actual': labels})

# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]

test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]

test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]

predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)})
# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label = 'actual')

# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label = 'prediction')
plt.xticks(rotation = '60');
plt.legend()
plt.show()
# 图名
plt.xlabel('Date'); plt.ylabel('Maximum Temperature (F)'); plt.title('Actual and Predicted Values');
# 层数越来越对,就会过拟合
# 什么是过拟合?过拟合(Overfitting)是指机器学习模型在训练数据上表现得很好,但在未见过的新数据上表现较差的现象。

在这里插入图片描述

总结

pytorch学习的第二篇啦,慢慢更新ing

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1222385.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

初学编程学习,计算机编程怎么自学,中文编程工具下载

初学编程学习&#xff0c;计算机编程怎么自学&#xff0c;中文编程工具下载 给大家分享一款中文编程工具&#xff0c;零基础轻松学编程&#xff0c;不需英语基础&#xff0c;编程工具可下载。 这款工具不但可以连接部分硬件&#xff0c;而且可以开发大型的软件&#xff0c;象如…

Hangfire.Pro 3.0 Crack

Hangfire.Pro 有限的存储支持 Hangfire Pro 是一组扩展包&#xff0c;允许使用批处理创建复杂的后台作业工作流程&#xff0c;并提供对超快速Redis作为作业存储的支持 请注意&#xff0c;仅在使用Hangfire.SqlServer、Hangfire.Pro.Redis或Hangfire.InMemory包作为作业存储时才…

SpringCloud 2022有哪些变化

目录 前提条件 AOT支持 Spring Native支持 前提条件 Spring Cloud 2022.0.0是构建在Spring Framework 6.0和Spring Boot 3.0 之上的一S个主要版本。 JDK要求最低需要是Java 17J2EE要求最低需要Jakarta EE 9 AOT支持 Spring cloud 2022支持AOT编译&#xff0c;它是将程序源…

【mysql】1153 - Got a packet bigger than ‘max_allowed_packet‘ bytes

执行mysql 语句出现&#xff1a;1153 - Got a packet bigger than max_allowed_packet bytes&#xff1b; 1153-得到一个大于“max_allowed_packet”字节的数据包。 数据包太大了怎么办。该配置吧。 查看max_allowed_packet show global variables like max_allowed_packet;…

用Java实现贪吃蛇小游戏

一、创建新项目 首先创建一个新的项目&#xff0c;并命名为贪吃蛇。 其次在贪吃蛇项目下创建一个名为images的文件夹用来存放游戏相关图片。 然后再在项目的src文件下创建一个com.xxx.view的包用来存放所有的图形界面类&#xff0c;创建一个com.xxx.controller的包用来存放启…

js-webApi 笔记2之DOM事件

目录 1、表单事件 2、键盘事件 3、事件对象 4、排他思想 5、事件流 6、捕获和冒泡 7、阻止默认和冒泡 8、事件委托 9、事件解绑 10、窗口加载事件 11、窗口尺寸事件 12、元素尺寸和位置 13、窗口滚动事件 14、日期对象 15、节点 16、鼠标移入事件 1、表单事件 获取…

SourceTree提示128错误

错误&#xff1a; SourceTree打开报错&#xff1a;git log 失败&#xff0c;错误代码128 错误截图&#xff1a; 解决方法&#xff1a; 第一种&#xff1a; 打开电脑路径 C:\Users\Administrator &#xff0c;删除下面的【.gitconifg】文件 第二种&#xff1a; 如果上述方法…

场景交互与场景漫游-场景漫游器(6)

场景漫游 在浏览整个三维场景时&#xff0c;矩阵变换是非常关键的&#xff0c;通过适当的矩阵变换可以获得各种移动或者渲染效果。因此&#xff0c;在编写自己的场景漫游操作器时&#xff0c;如何作出符合逻辑的矩阵操作器是非常重要的&#xff0c;但这对初学者来说还是有一定难…

黑马React18: 基础Part 1

黑马React: 基础1 Date: November 15, 2023 Sum: React介绍、JSX、事件绑定、组件、useState、B站评论 React介绍 概念: React由Meta公司研发&#xff0c;是一个用于 构建Web和原生交互界面的库 优势: 1-组件化的开发方式 2-优秀的性能 3-丰富的生态 4-跨平台开发 开发环境搭…

鸿蒙ToastDialog内嵌一个xml页面会弹跳到一个新页面《解决》

ToastDialog 土司组件 1.问题展示2.代码展示3.问题分析 1.问题展示 0.理想效果 错误效果: 1.首页展示页面 (未点击按钮前) 2.点击按钮之后&#xff0c;弹窗不在同一个位置 2.代码展示 1.点击按钮的 <?xml version"1.0" encoding"utf-8"?> <…

Jmeter 如何监控目标服务的系统资源

下载Jmeter插件管理下载 perfmon 将这个插件管理放到Jmeter的\lib\ext目录下 然后重启Jmeter jmeter-plugins-manager-1.10.jar 下载 perfmon插件 添加 io 内存 磁盘的监听 并且添加监听 在宿主机中安装代理监听程序 并启动 ServerAgent.tar.gz

Linux常用命令——bzcat命令

在线Linux命令查询工具 bzcat 解压缩指定的.bz2文件 补充说明 bzcat命令解压缩指定的.bz2文件&#xff0c;并显示解压缩后的文件内容。保留原压缩文件&#xff0c;并且不生成解压缩后的文件。 语法 bzcat(参数)参数 .bz2压缩文件&#xff1a;指定要显示内容的.bz2压缩文…

任正非说:公司要逐步实行分灶吃饭,我们在管理上不能过于整齐划一,否则缺少战斗力。

你好&#xff01;这是华研荟【任正非说】系列的第42篇文章&#xff0c;让我们聆听任正非先生的真知灼见&#xff0c;学习华为的管理思想和管理理念。 一、我们必须在混沌中寻找战略方向。规划就是要抓住机会点&#xff0c;委员会是火花荟萃的地方&#xff0c;它预研的方向是可做…

贝加莱MQTT功能

贝加莱实现MQTT Client端的功能库和例程 导入库和例程&#xff0c;AS Logical View中分别通过Add Object—Library&#xff0c;Add—Program插入MQTT库和例程。 将例程Sample放置于CPU循环周期中 定义证书存放路径&#xff0c;在AS Physical View 中&#xff0c;右击PLC—Con…

聚观早报 |零跑C10亮相广州车展;小鹏X9亮相广州车展

【聚观365】11月18日消息 零跑C10亮相广州车展 小鹏X9亮相广州车展 坦克700 Hi4-T开启预售 超A级家轿五菱星光正式预售 哪吒汽车发布山海平台2.0 零跑C10亮相广州车展 零跑汽车首款全球车型C10在广州车展首次亮相&#xff0c;同时该车也是零跑LEAP 3.0技术架构下的首款全…

C++菜鸟日记2

关于getline()函数&#xff0c;在char和string输入的区别 参考博客 1.在char中的使用&#xff1a; 2.在string中的使用&#xff1a; 关于char字符数组拼接和string字符串拼接方法 参考博客 字符串拼接方法&#xff1a; 1.直接用 号 2.利用append&#xff08;&#xff0…

【草料】uni-app ts vue 小程序 如何如何通过草料生成对应的模块化二维码

一、查看uni-app项目 1、找到路径 可以看到项目从 src-race-pages-group 这个使我们目标的查询页面 下面我们将这个路径copy到草料内 2、找到进入页面入参 一般我们都会选择 onload() 函数下的入参 这里我们参数的是 id 二、草料 建议看完这里的教程文档 十分清晰&#xff01…

详解自动化测试之 Selenium

目录 1. 什么是自动化 2.自动化测试的分类 3. selenium&#xff08;web 自动化测试工具&#xff09; 1&#xff09;选择 selenium 的原因 2&#xff09;环境部署 3&#xff09;什么是驱动&#xff1f; 4. 一个简单的自动化例子 5.selenium 常用方法 5.1 查找页面元素&…

【STM32】RTC(实时时钟)

1.RTC简介 本质&#xff1a;计数器 RTC中断是外部中断&#xff08;EXTI&#xff09; 当VDD掉电的时候&#xff0c;Vbat可以通过电源--->实时计时 STM32的RTC外设&#xff08;Real Time Clock&#xff09;&#xff0c;实质是一个 掉电 后还继续运行的定时器。从定时器的角度…

三十分钟学会zookeeper

zookeeper 一、前提知识 集群与分布式 ​ 集群&#xff1a;将一个任务部署在多个服务器&#xff0c;每个服务器都能独立完成该任务。 ​ 分布式&#xff1a;将一个任务拆分成若干个子任务&#xff0c;由若干个服务器分别完成这些子任务&#xff0c;每个服务器只能完成某个特…