Pytorch解决 多元回归 问题的算法

news2024/11/19 16:45:19

Pytorch解决 多元回归 问题的算法

回归是一种基本的统计建模技术,用于建立因变量与一个或多个自变量之间的关系。
我们将使用 PyTorch(一种流行的深度学习框架)来开发和训练线性回归模型。

在这里插入图片描述

二元回归的简单示例

训练数据集(可获取)

对于此分析,我们将使用scikit-learn 库中的 make regression() 函数生成的合成数据集。数据集由输入特征和目标变量组成。输入特征代表自变量,而目标变量代表我们想要预测的因变量

import seaborn as sns
import numpy as sns
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn
from sklearn import datasets
import pandas as pd

data=datasets.make_regression()    # from sklearn we are going to select one dataset
df = pd.DataFrame(data[0], columns=[f"feature_{i+1}" for i in range(data[0].shape[1])])
df["target"] = data[1]

在这里插入图片描述

数据的结构,100 rows × 101 columns,最后 1 column为目标值

准备训练集与测试集

PyTorch 是一个功能强大的开源深度学习框架,提供了一种灵活的方式来构建和训练神经网络。它提供了一系列张量运算、自动微分和优化算法的功能。

使用 sklearn Train-Test-split 准备数据以开发模型

x=df.iloc[: , :-1]   # 除目标数据身下所以的
y=df.iloc[: , -1]    # target

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=42)
print(type(X_train))
# X_train=torch.tensor(X_train,dtype=torch.float32)

X_train = torch.tensor(X_train.values, dtype=torch.float32)  # 转化为 tensor
X_test = torch.tensor(X_test.values, dtype=torch.float32)
y_train = torch.tensor(y_train.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.float32)

模型架构

数据准备好了,可以准备模型了

我们的线性回归模型是作为PyTorch 中nn.Module类的子类实现的。该模型由多个按顺序连接的完全连接(线性)层组成。

class linearRegression(nn.Module): 
# 所有来自torch的依赖项将被传递给这个类[父类] 
# nn.Module 包含了神经网络的所有构建模块:
  def __init__(self,input_dim):
    super(linearRegression,self).__init__()   # building connection with parent and child classes
    self.fc1=nn.Linear(input_dim,10)          # hidden layer 1
    self.fc2=nn.Linear(10,5)                  # hidden layer 2
    self.fc3=nn.Linear(5,3)                   # hidden layer 3
    self.fc4=nn.Linear(3,1)                   # last layer

  def forward(self,d):
    out=torch.relu(self.fc1(d))              # input * weights + bias for layer 1
    out=torch.relu(self.fc2(out))            # input * weights + bias for layer 2
    out=torch.relu(self.fc3(out))            # input * weights + bias for layer 3
    out=self.fc4(out)                        # input * weights + bias for last layer
    return out                               # final outcome

input_dim=X_train.shape[1]     # 获取 input_dim 变量的数量
torch.manual_seed(42)          # to make initilized weights stable:
model=linearRegression(input_dim)
# select loss and optimizers

loss=nn.MSELoss() # loss function
optimizers=optim.Adam(params=model.parameters(),lr=0.01)

loss_values_all = []  # 创建一个列表来存储每个迭代的loss值

# training the model:

num_of_epochs=1000
for i in range(num_of_epochs):
  # give the input data to the architecure
  y_train_prediction=model(X_train)  # model initilizing
  loss_value=loss(y_train_prediction.squeeze(),y_train)   # find the loss function:
  optimizers.zero_grad() # make gradients zero for every iteration so next iteration it will be clear
  loss_value.backward()  # back propagation
  optimizers.step()      # update weights in NN
  
  loss_values_all.append(loss_value.item())  # 将当前的loss值添加到列表中

  # print the loss in training part:
  if i % 10 == 0:
    print(f'[epoch:{i}]: The loss value for training part={loss_value}')

绘制 loss 曲线图
在这里插入图片描述
在测试数据集上的效果(test data)

with torch.no_grad():
  model.eval()   # make model in evaluation stage
  y_test_prediction=model(X_test)
  test_loss=loss(y_test_prediction.squeeze(),y_test)
  print(f'Test loss value : {test_loss.item():.4f}')

测试自己随机生成的数据

# Inference with own data:
pr = torch.tensor(torch.arange(1, 101).unsqueeze(dim=0), dtype=torch.float32).clone().detach()
print(pr)

保存训练好的模型

# save the torch model:

from pathlib import Path

filename=Path('models')
filename.mkdir(parents=True,exist_ok=True)

model_name='linear_regression.pth' # model name

# saving path

saving_path=filename/model_name
print(saving_path)
torch.save(obj=model.state_dict(),f=saving_path)

# we can load the saved model and do the inference again:

load_model=linearRegression(input_dim) # creating an instance again for loaded model
load_model.load_state_dict(torch.load('./models/linear_regression.pth'))

load_model.eval()   # make model in evaluation stage
with torch.no_grad():
  pred = load_model(torch.tensor([[  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,  12.,
          13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,
          25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,  36.,
          37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,  48.,
          49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,  60.,
          61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,  70.,  71.,  72.,
          73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,
          85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,
          97.,  98.,  99., 100.]]))

  print(f'prediction value : {pred.item()}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1811538.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分离式光电液位传感器有哪些特点?

分离式光电液位传感器是一种先进的液位检测技术,在科学技术的不断推进下得到了广泛应用。相比传统的液位传感器,分离式光电液位传感器具有许多独特的特点。 传感器采用了先进的光学技术,将传感器装在需要检测液位的位置,并采用了…

非GIS专业,真的不适合WebGIS开发吗?

到底是哪些人在新中地特训营学GIS开发? 很多同学对GIS开发的认知还停留在GIS专业的学生才能学GIS开发,然而,在新中地教育,非GIS专业的学生几乎占一半。 除了GIS专业,还有测绘、遥感等跟GIS差别不大的专业学生选择来学…

面试官:你讲下接口防重放如何处理?

前言 我们的API接口都是提供给第三方服务/客户端调用,所有请求地址以及请求参数都是暴露给用户的。 我们每次请求一个HTTP请求,用户都可以通过F12,或者抓包工具fd看到请求的URL链接,然后copy出来。这样是非常不安全的,有人可能会…

鸿蒙轻内核A核源码分析系列二 数据结构-位图操作

在进一步分析之前,本文我们先来熟悉下OpenHarmony鸿蒙轻内核提供的位操作模块,在互斥锁等模块对位操作有使用。位操作是指对二进制数的bit位进行操作。程序可以设置某一变量为状态字,状态字中的每一bit位(标志位)可以具…

Kubernetes——HPA自动伸缩机制

目录 前言 一、概念 1.定义 2.核心概念 3.工作原理 4.HPA的配置关键参数 5.关键组件 5.1HPA控制器(HPA Controller) 5.2Metrics Server 5.3自定义指标适配器(Custom Metrics Adapter) 5.4Deployment/ReplicaSet 5.5Po…

文件二维码能快速生成吗?多种类型文件生成二维码的方法

现在将文件做成二维码是一种很常用形式,通过二维码来存储多个文件,在手机上扫码查看内容,对于文件的安全性和用户体验都有很好的提升。用户无需下载文件,扫码就可以快速在线阅读或者下载文件内容,有利于文件的快速分享…

2024年智能制造行业CRM研究(附需求清单、市场格局、选型建议)

在国家大力鼓励智能制造行业与数字化转型这个大背景下,我们选择了2024年智能制造行业数字化的几个关键趋势做深入解读,并对智能制造行业核心的数字化系统CRM进行了全面评估与排名。本文不仅提供了详尽的需求清单,帮助企业明确自身对CRM系统的…

不定时更新 解决无法访问GitHub github.com 打不开 访问加速

1 修改hosts Windows 10为例,‪文件C:\Windows\System32\drivers\etc\hosts 管理员打开记事本来修改 文件-打开-“C:\Windows\System32\drivers\etc\hosts” 20.205.243.168 api.github.com 185.199.108.154 github.githubassets.com 185.199.108.133 raw.githubusercontent.…

1.ei论文会被scopus检索吗文被其检索吗?

ei论文会被scopus检索吗 scopus数据库能检索的专业范围是比较广泛的,涵盖了医学,地球环境科学,化学,数学,工程学,物理,生物科学等领域,也收录了很多会议论文,那么ei论文…

JwtAccessConverterJwtTokenStorejdbc建表结构

文章目录 JWT实现MacTestMacSigner Rsa生成jks证书需要先安装opensslkeytool生成jks (Java Key Store) 文件测试密钥 JwtTokenStoreInMemoryTokenStore&RedisTokenStore&JdbcTokenStore&JwtTokenStore图解JwtTokenStore详解 jdbc实现表结构说明1oauth_client_detai…

Win10系统自带输入法英文变大的问题

现在习惯使用Windows自带的五笔输入法了,但一直以来总会遇到输入时突然英文字母变大了,相隔空间也变大了的情况。 asdfghkl asdfghkl 后来知道这是输入法变…

如何在Linux虚拟机服务器上配置和部署Java项目?

在Linux虚拟机上配置和部署Java项目,通常涉及以下步骤: 1. 准备Linux虚拟机 选择合适的Linux发行版 :根据项目需求和个人熟悉程度,选择如Ubuntu LTS、CentOS Stream或Debian等发行版。 安装虚拟机软件 :在宿主机&#…

css图片适配,不随屏幕的大小变化

.carimg {width: 100%;height: 100%;max-width: 100%;max-height: 100%;object-fit: cover; } <img class"carimg" :src"item.imageUrl" alt"" /> 效果&#xff1a; 全屏时 屏幕变小时

985找工作都这么难了吗

经历如下 也不知道自己适合干啥 好想找份实习入门啊

无人机RTMP推流EasyDSS直播平台推流成功,不显示直播按钮是什么原因?

互联网视频云平台/视频点播直播/视频推拉流EasyDSS支持HTTP、HLS、RTMP等播出协议&#xff0c;并且兼容多终端&#xff0c;如Windows、Android、iOS、Mac等。为了便于用户集成与二次开发&#xff0c;我们也提供了API接口供用户调用和集成。在无人机场景上&#xff0c;可以通过E…

Allegro铺铜以及分割操作

Allegro铺铜以及分割操作 一、铺铜全局设置 点击Shape–>Global Dynamic Shape Parameters&#xff0c;在Shape fill中选择Smooth&#xff0c;其他不用管&#xff0c;这个是在铺铜的时候自动避让不同网络&#xff0c;在Void controls中一般填写如下参数&#xff0c;即避让…

cocos creator3.7版本拖拽事件处理

前言&#xff1a;网上能找到的资料都太落后了&#xff0c;导致哥们用AI去写&#xff0c;全是瞎B写&#xff0c;版本都不对。贴点实际有用的。别老捣鼓你那破convertToNodeSpaceAR或者convertToNodeSpace了。 核心代码 touch.getDeltaX() touch.getDeltaY() 在cocoscreator3…

windows系统下安装fnm

由于最近做项目要切换多个node版本&#xff0c;查询了一下常用的有nvm和fnm这两种&#xff0c;对比了一下选择了fnm。 下载fnm 有两种方式&#xff0c;目前最新版本是1.37.0&#xff1a; 1.windows下打开powershell&#xff0c;执行以下命令下载fnm winget install Schniz.f…

Linux上手实验七:网络配置与管理

lab 2.3 网络配置与管理 1、实验背景&#xff1a; 做为服务器操作系统的企业版 CentOS 7 的网络性能及网络管理是非常重要&#xff0c;往往在部署重要应用的先决条件就是&#xff0c;配置好网络参数及测试好网络的联通性。 作为 Linux 管理员&#xff0c;对于系统的网络配置和…

PNAS | 工作记忆中大脑节律的因果功能图

摘要 工作记忆是一个涉及大脑中多个功能解剖节点的关键认知过程。尽管有大量与工作记忆结构相关的神经影像学证据&#xff0c;但我们对控制整体表现的关键中枢的理解并不完整。因果解释需要在对特定功能解剖节点进行安全、暂时和可控的神经调节后进行认知测试。随着经颅交流电…