PyTorch实验—回归任务

news2025/1/23 13:10:25

PyTorch回归任务

回归任务概述:通过pytorch搭建神经网络,进行气温的预测
回归任务可以看作 y = kx + b
y为需要进行回归预测的值

下面对实验步骤进行整理

导入相关的库

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline

主要包括了相关的科学计算库,和torch中的优化器

查看并说明数据集

features = pd.read_csv('temps.csv')

#看看数据长什么样子
features.head()

在这里插入图片描述

数据表中

  • year,moth,day,week分别表示的具体的时间
  • temp_2:前天的最高温度值
  • temp_1:昨天的最高温度值
  • average:在历史中,每年这一天的平均最高温度值
  • actual:这就是我们的标签值了,当天的真实最高温度
  • friend:这一列可能是凑热闹的,你的朋友猜测的可能值,咱们不管它就好了

查看并输出数据的维度
数据维度: (348, 9)

print('数据维度:', features.shape)

将时间进行格式化并输出格式化的结果

# 处理时间数据
import datetime

# 分别得到年,月,日
years = features['year']
months = features['month']
days = features['day']

# datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]
dates[:5]

在这里插入图片描述

绘图将多个数据进行显示便于对比

设置子图布局完成子图的绘制 4个子图
时间设置倾斜

# 准备画图
# 指定默认风格
plt.style.use('fivethirtyeight')

# 设置布局
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize = (10,10))
fig.autofmt_xdate(rotation = 45)

# 标签值
ax1.plot(dates, features['actual'])
ax1.set_xlabel(''); ax1.set_ylabel('Temperature'); ax1.set_title('Max Temp')

# 昨天
ax2.plot(dates, features['temp_1'])
ax2.set_xlabel(''); ax2.set_ylabel('Temperature'); ax2.set_title('Previous Max Temp')

# 前天
ax3.plot(dates, features['temp_2'])
ax3.set_xlabel('Date'); ax3.set_ylabel('Temperature'); ax3.set_title('Two Days Prior Max Temp')

# 我的逗逼朋友
ax4.plot(dates, features['friend'])
ax4.set_xlabel('Date'); ax4.set_ylabel('Temperature'); ax4.set_title('Friend Estimate')

plt.tight_layout(pad=2)

在这里插入图片描述

使用one-hot编码对星期字符串进行处理

# 独热编码
features = pd.get_dummies(features)
features.head(5)

在这里插入图片描述

在训练之前对数据特征进行预处理

去除特征中的标签项

# 标签
labels = np.array(features['actual'])

# 在特征中去掉标签
features= features.drop('actual', axis = 1)

# 名字单独保存一下,以备后患
feature_list = list(features.columns)

# 转换成合适的格式
features = np.array(features)

查看新的数据特征的维度

features.shape

在这里插入图片描述

使用sklearn库结合标准正态分布的相关理论对整体的数据进行标准化的处理

from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)

input_features[0]

在这里插入图片描述

构建网络模型

将narrdry格式转化为tensor(张量)进行输入

x = torch.tensor(input_features, dtype = float)

y = torch.tensor(labels, dtype = float)

在神经网络的隐层中设置128个神经元,因为每一条数据拥有14个特征,因此设置一个14x128的特征矩阵,这时的b1需要更新的此时为128x1

在进行结果输出是,只需要输出一个结果所以定义一个128*1的特征矩阵,此时的b2只需要更新1次

# 权重参数初始化
weights = torch.randn((14, 128), dtype = float, requires_grad = True) 
biases = torch.randn(128, dtype = float, requires_grad = True) 
weights2 = torch.randn((128, 1), dtype = float, requires_grad = True) 
biases2 = torch.randn(1, dtype = float, requires_grad = True) 

设置学习率和损失值

learning_rate = 0.001 
losses = []

通过矩阵的乘法计算w1x+b1
计算隐层
hidden = x.mm(weights) + biases

加入激活函数进行非线性处理
加入激活函数
hidden = torch.relu(hidden)

之后通过w2和b2来计算与预测的结果
预测结果
predictions = hidden.mm(weights2) + biases2

通过与真实结果的比较来计算损失函数,结合反向传播的算法更新参数信息,沿着梯度值的方向和学习率进行更新操作

#返向传播计算
    loss.backward()
    
    #更新参数
    weights.data.add_(- learning_rate * weights.grad.data)  
    biases.data.add_(- learning_rate * biases.grad.data)
    weights2.data.add_(- learning_rate * weights2.grad.data)
    biases2.data.add_(- learning_rate * biases2.grad.data)

梯度的反方向进行更新
迭代完成之后进行清0处理

每次迭代都得记得清空
weights.grad.data.zero_()
biases.grad.data.zero_()
weights2.grad.data.zero_()
biases2.grad.data.zero_()

x = torch.tensor(input_features, dtype = float)

y = torch.tensor(labels, dtype = float)

# 权重参数初始化
weights = torch.randn((14, 128), dtype = float, requires_grad = True) 
biases = torch.randn(128, dtype = float, requires_grad = True) 
weights2 = torch.randn((128, 1), dtype = float, requires_grad = True) 
biases2 = torch.randn(1, dtype = float, requires_grad = True) 

learning_rate = 0.001 
losses = []

for i in range(1000):
    # 计算隐层
    hidden = x.mm(weights) + biases
    # 加入激活函数
    hidden = torch.relu(hidden)
    # 预测结果
    predictions = hidden.mm(weights2) + biases2
    # 通计算损失
    loss = torch.mean((predictions - y) ** 2) 
    losses.append(loss.data.numpy())
    
    # 打印损失值
    if i % 100 == 0:
        print('loss:', loss)
    #返向传播计算
    loss.backward()
    
    #更新参数
    weights.data.add_(- learning_rate * weights.grad.data)  
    biases.data.add_(- learning_rate * biases.grad.data)
    weights2.data.add_(- learning_rate * weights2.grad.data)
    biases2.data.add_(- learning_rate * biases2.grad.data)
    
    # 每次迭代都得记得清空
    weights.grad.data.zero_()
    biases.grad.data.zero_()
    weights2.grad.data.zero_()
    biases2.grad.data.zero_()

在这里插入图片描述

更简单的构建网络模型

input_size = input_features.shape[1]
hidden_size = 128
output_size = 1
batch_size = 16 #16个批次进行训练
my_nn = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_size, output_size),
)
cost = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(my_nn.parameters(), lr = 0.001)

训练网络得到对应的结果

# 训练网络
losses = []
for i in range(1000):
    batch_loss = []
    # MINI-Batch方法来进行训练
    for start in range(0, len(input_features), batch_size):
        end = start + batch_size if start + batch_size < len(input_features) else len(input_features)
        xx = torch.tensor(input_features[start:end], dtype = torch.float, requires_grad = True)
        yy = torch.tensor(labels[start:end], dtype = torch.float, requires_grad = True)
        prediction = my_nn(xx)
        loss = cost(prediction, yy)
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        batch_loss.append(loss.data.numpy())
    
    # 打印损失
    if i % 100==0:
        losses.append(np.mean(batch_loss))
        print(i, np.mean(batch_loss))

在这里插入图片描述

得到预测结果并绘制图像

其中红色代表的是与预测值

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data = {'date': dates, 'actual': labels})

# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]

test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]

test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]

predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)}) 
# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label = 'actual')

# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label = 'prediction')
plt.xticks(rotation = '60'); 
plt.legend()

# 图名
plt.xlabel('Date'); plt.ylabel('Maximum Temperature (F)'); plt.title('Actual and Predicted Values');

在这里插入图片描述
此时并没有出现预测值过拟合的问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/606301.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

张小飞的Java之路——第四十四章——其他流对象

写在前面&#xff1a; 视频是什么东西&#xff0c;有看文档精彩吗&#xff1f; 视频是什么东西&#xff0c;有看文档速度快吗&#xff1f; 视频是什么东西&#xff0c;有看文档效率高吗&#xff1f; 诸小亮&#xff1a;这一节&#xff0c;我们介绍一下其他不常用的流对象 …

SAP-MM-分割评估-评估类型-评估类别

同一物料的使用&#xff0c;既有“自制品”&#xff0c;又有“外购品”&#xff0c;并且其来源不同&#xff0c;如同一外购品由不同的供应商提供&#xff0c;价格也不相同。也就是说:同一物料有不同的价值指派&#xff0c;即在不同的条件下&#xff0c;同一物料可能有不同的价值…

智能型数字档案馆构建设想

档案作为企业正式权威的数据资源&#xff0c;具有其历史传承和凭证唯一性等特点&#xff0c;随着企业的数字化转型&#xff0c;档案工作更需要数字化转型&#xff0c;档案管理与利用急需借助信息技术手段来管理好和记录好&#xff0c;急需挖掘档案资源&#xff0c;发挥其价值&a…

01.硬盘启动盘,加载操作系统

硬盘启动盘&#xff0c;加载操作系统 模拟硬盘加载操作系统 环境&#xff1a; VMware16 Ubuntu16.04 qemu bochs 2.7 参考: 启动&#xff0c;BIOS&#xff0c;MBR 硬盘控制器主要端口寄存器 《操作系统真相还原》 1.系统开机流程 暂不构建中断向量表&#xff0c;直接加载MBR …

Knowledge Distillation: A Survey

本文是蒸馏学习综述系列的第一篇文章&#xff0c;主要是针对2021年 IJCV Knowledge Distillation: A Survey的一个翻译。 知识蒸馏&#xff1a;综述 摘要1 引言2 知识2.1 基于响应的知识2.2 基于特征的知识2.3 基于关系的知识 3 蒸馏方案3.1 离线蒸馏3.2 在线蒸馏3.3 自蒸馏 4…

你真的了解epoll吗?深入epoll的五个问题

由于epoll用的比较多&#xff0c;最近看到一些网友关于epoll的问答&#xff0c;所以我就想整理成一篇文章&#xff0c;这样看起来和理解起来都方便一些。 问题1&#xff1a;什么是epoll的ET/LT模式&#xff0c;select/poll支持吗&#xff1f; ET是edge trigger&#xff0c;也…

K8s in Action 阅读笔记——【9】Deployments: updating applications declaratively

K8s in Action 阅读笔记——【9】Deployments: updating applications declaratively 集群配置&#xff1a; 本章介绍如何更新运行在Kubernetes集群中的应用&#xff0c;以及Kubernetes如何帮助你实现真正的零停机更新过程。虽然这可以仅使用ReplicationControllers或ReplicaSe…

【Spring】javaBean、依赖注入、面向切面AOP、使用注解开发

文章目录 JavaBeanIoC理论基础使用IoC容器使用Spring 生命周期与继承生命周期继承 依赖注入 Dependency Injection基本类型注入非基本类型注入集合注入自动装配注入 面向切面AOP使用SpringAOP环绕方法 使用接口实现AOP 使用注解开发注解实现配置文件注解实现AOP操作其他注解配置…

MongoDB6.0.6 副本集搭建(CentOs8)

本文只说如何操作配置副本集&#xff0c;历程艰难&#xff0c;官网文档看了半天也只说了怎么添加单个&#xff0c;没有给出来一个完整的流程。 1.第一步安装&#xff0c;参考前一篇安装即可。 配置三台虚拟机&#xff1a; 192.168.182.142 192.168.182.143 192.168.182.14…

钉钉小程序页面之间传递数据如何传递对象

今天写代码的时候&#xff0c;发现了一个问题&#xff0c;在钉钉小程序页面之间传递对象数据的时候&#xff0c;如果直接传递一个对象那个的话&#xff0c;接收的那个页面得到的是一个【object Object】&#xff0c;而并非里面的数据&#xff0c;所以针对上述问题&#xff0c;下…

基于Spring Boot的学生志愿者管理系统的设计与实现

摘 要 信息化社会内需要与之针对性的信息获取途径&#xff0c;但是途径的扩展基本上为人们所努力的方向&#xff0c;由于站在的角度存在偏差&#xff0c;人们经常能够获得不同类型信息&#xff0c;这也是技术最为难以攻克的课题。针对学生志愿者管理等问题&#xff0c;对学生…

项目计划软件 project安装包的下载和安装教程

目录 简介 安装配置过程 总结&#xff1a; 简介 Project是由微软公司开发的项目管理软件&#xff0c;旨在帮助个人和团队有效地管理项目进度、资源分配、协作和报告等工作&#xff0c;从而提高项目的质量和效率。Project维护项目的进程表、资源清单、成本预算、工作表和报告…

CSS-HTML知识点与高频考题解析

知识点梳理 选择器的权重和优先级 盒模型 盒子大小计算margin 的重叠计算 浮动 float浮动布局概念清理浮动 定位 position文档流概念定位分类fixed 定位特点绝对定位计算方式 flex布局 如何实现居中对齐&#xff1f; 理解语义化 CSS3 动画 重绘和回流 选择器的权重和优…

VTK 开发中遇到问题整理

1 Generic Warning VTK 开发 中是到 vtkOutputWindow 弹窗并提示Generic Warning&#xff1a;… vtkOutputWindow 弹窗 解决方法&#xff1a; 添加&#xff1a; #include <vtkOutputWindow.h> 在 main.cpp函数中添加&#xff1a; vtkOutputWindow::SetGlobalWarningD…

petalinux2022.2在ubantu20.04下的安装

1.Petalinux的下载路径 Downloads 这个是下载petalinux的官网路径。默认是2022.2版本&#xff0c;后期更新的均是以petalinux2022.2版本做的更新。 2.安装流程 在官网下载完成之后&#xff0c;会得到一个名为petalinux-v2022.2-10141622-installer.run的文件&#xff0c;这个文…

linux|磁盘管理工作|lvm逻辑管理卷的创建和使用总结(包括扩容,根目录扩容演示)

前言&#xff1a; 对于运维工作来说&#xff0c;磁盘管理是一个非常重要的工作。当然了&#xff0c;此类工作也是比较偏向底层的一项工作。 一个合理的磁盘分区设置&#xff0c;文件系统格式&#xff0c;以及准确的lvm逻辑管理会对我们的后期的扩展工作&#xff0c;管理工作带…

深入理解设计原则之单一职责原则(SRP)【软件架构设计】

系列文章目录 C高性能优化编程系列 深入理解软件架构设计系列 深入理解设计模式系列 高级C并发线程编程 SRP&#xff1a;单一职责原则 系列文章目录1、单一职责原则的定义和解读2、单一职责原则案例解读2.1、违背单一职责原则反面案例2.2、违背单一职责原则反面案例 - 解决方…

Openwrt_XiaoMiR3G路由器_刷入Breed固件

当我刷完Breed后&#xff0c;重启没有进入原来的小米路由器固件&#xff0c;但可以进入breed控制台。目前不清楚那个环节出错了。所以本过程会导致路由器无法再直接使用&#xff01;&#xff01;&#xff01;。 本过程只刷入Breed&#xff0c;接着编译OpenWrt和刷入OpenWrt请参…

git命令的使用

1. 查看文件 git cat-file -p 仓库路径下右键 Git Bash Here 打开git命令窗口&#xff1a; 复制某个文件的版本号&#xff1a; 粘贴到git命令窗口&#xff0c;会显示文件的提交信息&#xff1a; 查看 tree后面的版本号&#xff0c;则会看到详细提交信息&#xff1a; 查看hell…

第8章 泛型程序设计

文章目录 为什么要使用泛型程序设计类型参数的好处谁想成为泛型程序员 定义简单泛型类泛型方法类型变量的限定泛型代码和虚拟机类型擦除转换泛型表达式转换泛型方法类型擦除与多态会发生冲突桥方法实现多态桥方法与可协变的返回类型 调用遗留代码 限制与局限性泛型类型的继承规…