使用torch解决线性回归问题

news2024/11/25 19:10:45
  1. 数据处理

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

data=pd.read_csv('./datasets/Income1.csv') #数据准备

data.head(5)#展示数据
#以上所有的代码都是用jupyter notebook写,形成了阶段性的结果展示

查看数据信息
data.info()

绘制数据的散点图
plt.scatter(data.Education,data.Income) #绘制散点图
plt.xlabel('Education')
plt.ylabel('Income')
plt.show()

数据转换
X=torch.from_numpy(data.Education.to_numpy().reshape(-1,1)).type(torch.FloatTensor)
Y=torch.from_numpy(data.Income.to_numpy().reshape(-1,1)).type(torch.FloatTensor)

上面一行代码首先通过data.Education获取受教育年限这一列,使用to_numpy()方法将其转换为ndarrray数组形式,然后使用reshape方法将其形状设置为二维数组,并且将最后一个维度明确为1。

展示数据的形状
print(X.size(),Y.size())

查看数据集的形状如上图所示,目前数据集是二维的,最后一个维度为1,代表单条数据的长度,前面的30代表数据的个数,所以X和Y这两个数据集的形状可以理解为输入X是30个长度为1的数据,输出Y也是30个长度为1的数据。

模型建立与数据训练

编写训练类
from torch import nn #nn.Module是PyTorch的高阶API

class EIModel(nn.Module):
    def __init__(self):
        super(EIModel,self).__init__()   #继承父类的属性 重写init的方法有两种  一.super(子类,self).__init__() 二.父类.__init__(self)
        self.linear=nn.Linear(in_features=1,out_features=1)#创建线性层
    def forward(self,inputs):
        logits=self.linear(inputs)   #在输入上调用初始化的线性层
        return logits

因为当前模型是一个简单的线性回归模型,只有w和b两个参数,在__init__()方法中,使用nn.Linear()方法初始化一个线性连接层,nn.Linear有两个参数,即in_features和out_features,分别代表输入和输出维度的大小,根据X和Y的size()方法,我们可以输入和输出的维度大小都为1.

创建EIModel对象
model=EIModel() #创建一个ETModel对象

这里的模型就是一个最简单的线性层,也就是所谓的一个线性函数,只有w和b两个参数,接下来是我们定义计算损失函数,并且根据损失值进行梯度优化,优化模型参数

构建损失函数
loss_fn=nn.MSELoss()   #定义了均方误差损失来计算损失函数
opt=torch.optim.SGD(model.parameters(),lr=0.0001)  #初始化一个优化器

第一行代码就是用了PyTorch中内置的均方误差来计算损失函数,第二行初始化了一个内置的优化器,第一个参数为需要优化的变量,通过model.parameters()方法可以获取模型中的所有变量,优化器中参数lr为学习率,也就是在计算梯度时用到的alpha。

迭代训练
for epoch in range(5000):         #对全部的数据训练5000次
    for x,y in zip(X,Y):          #同时对X和Y迭代
        y_pred=model(x)           #调用model得到预测输出y_pred
        loss=loss_fn(y_pred,y)    #根据模型预测输出与实际的值y计算损失
        opt.zero_grad()           #将累计的梯度置为0
        loss.backward()           #反向传播损失,计算损失与模型参数之间的梯度
        opt.step()                #根据计算得到梯度优化模型参数
print("Down!")

模型训练代码有一行opt.zero_grad(),是因为PyTorch会累计每次计算的梯度,使用此代码将上一循环中计算的梯度置为0

展示参数
print(list(model.named_parameters()))#以生成器的形式返回模型参数的名称和值

模型效果展示
#绘制原数据分布的散点图
plt.scatter(data.Education,data.Income,label='real data')
#用我们训练出来的参数,来绘制直线
plt.plot(X,model(X).detach().numpy(),c='r',label='predict line')
plt.xlabel('Education')#设置x轴的标签
plt.ylabel('Income')#设置y轴的标签
plt.legend()#可以显示图例
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1304058.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt的坐标系系统 - 3个坐标系,2个变换

参考: https://zhuanlan.zhihu.com/p/584048811https://www.zhihu.com/tardis/zm/art/634951149?source_id1005 小谈Qt的坐标系系统 Qt中有三个坐标系 设备坐标系窗口坐标系逻辑坐标系 设备坐标系: 即Device坐标系。也是物理坐标系。即真实的的物理坐标系。 …

Ubuntu部署EMQX开源版MQTT服务器-Orange Pi部署-服务器部署

一、前言 作为全球最具扩展性的 MQTT 消息服务器,EMQX 提供了高效可靠海量物联网设备连接,能够高性能实时移动与处理消息和事件流数据,本文将介绍如何在Ubuntu 22.04上部署MQTT服务器。我们本次选择开源版,使用离线安装方式部署。…

ArcGIS pro与SuperMap根据属性自动填充颜色步骤

GIS项目经常会接触到控规CAD数据,想要把数据转换成GIS图层并发布,需要进行专题配图。研究了一下ArcGIS pro和SuperMap iDesktop的配图,整理一下用到的一些技术思路。 1、Excel表格根据RGB值添加单元格填充颜色 要实现如上效果图,…

【NR技术】NR NG-RAN整体架构 -网络接口以及无线协议框架(四)

1 引言 本博文介绍NR NG-RAN的网络节点间的接口以及无线协议框架。网络接口介绍包括RAN和NGC之间的NG接口;无线协议框架包括用户面和控制面协议。 2 NG接口 2.1 NG用户面接口 NG-U (user plane interface)是NG-RAN节点与UPF之间的接口。NG接口的用户平面协议栈如图…

UE4 透明物体不渲染显示??

问题描述:半透明特效在背景(半透明材质模型)前,当半透明特效开始移动的时候,随着速度的加快会逐渐不渲染! 解决办法: 1.设置透明度排序 2.如果还没效果,修改半透明背景模型以下材质…

基于KEDA的Kubernetes自动缩放机制

KEDA以事件驱动的方式实现Kubernetes Pod的动态自动扩容机制,以满足不同的负载需求,从而提高应用可伸缩性和弹性。原文: Dynamic Scaling with Kubernetes Event-driven Autoscaling (KEDA) Kubernetes是容器编排平台的事实标准,已经彻底改变…

vp与相机连接

1.网线 2.相机电源线 1.相机电源线 接头 (使用红色和黑色 线头要串联) 1.光源连接口 2.光源控制器 开关 3.光源控制器通电接口 4.光源自动感应接口 (一般用于自动控制光源 开关) 1.两种不同类型的光源 1.光源亮度控制 2.切…

Leetcode—209.长度最小的子数组【中等】

2023每日刷题&#xff08;五十六&#xff09; Leetcode—209.长度最小的子数组 实现代码 class Solution { public:int minSubArrayLen(int target, vector<int>& nums) {int left 0, right 0;int ans nums.size() 1, s 0;for(; right < nums.size(); righ…

基于单片机的自动售货机(论文+源码)

1.系统设计 本设计以这样的工作流程开始自动售货机的自动售货过程&#xff1a; 启动系统&#xff0c;开始待机&#xff1b;顾客通过按键选择商品的种类以及数量并确认&#xff1b; 售货机检查是否有足够的货物并通过LCD提示等待顾客投币&#xff1b;顾客投入货币&#xff0c;…

使用yum/dnf管理软件包

本章主要介绍使用 yum 对软件包进行管理。 yum 的介绍搭建yum源创建私有仓库yum客户端的配置yum的基本使用使用第三方yum源 使用rpm安装包时经常会遇到一个问题就是包依赖&#xff0c;如下所示。 [rootrhel03 ~]# rpm -ivh /mnt/AppStream/Packages/httpd-2.4.37-41.modulee…

【漏洞复现】云时空社会化商业ERP系统gpy文件上传

漏洞描述 用友软件的先进管理理念,汇集各医药企业特色管理需求,通过规范各个流通环节从而提高企业竞争力、降低人员成本,最终实现全面服务于医药批发、零售连锁企业的信息化建设的目标,是一款全面贴合最新GSP要求的医药流通行业一站式管理系统。 时空云社会化商业ERP gpy…

GNSS - PPP软件 - GAMP 在VS2019/2022下完成调试、跑通程序(超详细!)

目录 一、前期准备 二、调试详细步骤 1.VS中新建项目 2.复制源码至项目文件夹 3.将源码中“.c”文件和“.h”文件添加至项目指定位置 4.修改项目属性&#xff1a; (1)【 配置属性 -> C/C ->预处理器 ->预处理器定义】添加如下 (2)【配置属性->链接器->调…

产品固件烧写方案

1、前言 一成熟的量产的嵌入式产品&#xff0c;软件一般分为BootLoader和App&#xff0c;BootLoader用于启动校验、App升级、App版本回滚等功能&#xff0c;BootLoader在cpu上电第一阶段中运行&#xff0c;之后跳转至App地址执行应用程序。 因此&#xff0c;在发布固件的时候&a…

12.11 C++ 作业

完善对话框&#xff0c;点击登录对话框&#xff0c;如果账号和密码匹配&#xff0c;则弹出信息对话框&#xff0c;给出提示”登录成功“&#xff0c;提供一个Ok按钮&#xff0c;用户点击Ok后&#xff0c;关闭登录界面&#xff0c;跳转到其他界面 如果账号和密码不匹配&#xf…

[UNILM]论文实现:Unified Language Model Pre-training for Natural Language.........

文章目录 一、完整代码二、论文解读2.1 介绍2.2 架构2.3 输入端2.4 结果 三、过程实现四、整体总结 论文&#xff1a;Unified Language Model Pre-training for Natural Language Understanding and Generation 作者&#xff1a;Li Dong, Nan Yang, Wenhui Wang, Furu Wei, Xia…

编程实战:自己编写HTTP服务器(系列4:查看文件、下载等一般功能)

系列入口&#xff1a;编程实战&#xff1a;自己编写HTTP服务器&#xff08;系列1&#xff1a;概述和应答&#xff09;-CSDN博客 本文介绍各种功能的实现。大部分是特定内置入口。 目录 一、默认页 二、查看文件 三、关闭服务 四、下载页面 一、默认页 前面在已经介绍过重定…

UE4/UE5 修改/还原场景所有Actor的材质

使用蓝图方法&#xff1a; 1.修改场景所有Actor 材质&#xff1a; Wirframe&#xff1a;一个材质类 MatList&#xff1a;获取到的所有模型的全部材质 的列表 TempAllClass&#xff1a;场景中所有获取的 Actor 的列表 功能方法如下&#xff1a; 蓝图代码可复制在&#xff1a…

MongoDB在Windows系统和Linux系统中实现自动定时备份

本文主要介绍MongoDB在Windows系统和Linux系统中如何实现自动定时备份。 目录 MongoDB在Windows系统中实现自动定时备份MongoDB在Linux系统中实现自动定时备份备份步骤备份恢复 MongoDB在Windows系统中实现自动定时备份 要在Windows系统中实现自动定时备份MongoDB数据库&#…

界面控件DevExpress中文教程 - 如何用Office File API组件填充PDF表单

DevExpress Office File API是一个专为C#, VB.NET 和 ASP.NET等开发人员提供的非可视化.NET库。有了这个库&#xff0c;不用安装Microsoft Office&#xff0c;就可以完全自动处理Excel、Word等文档。开发人员使用一个非常易于操作的API就可以生成XLS, XLSx, DOC, DOCx, RTF, CS…

《Spring Cloud Alibaba 从入门到实战》分布式消息(事件)驱动

分布式消息&#xff08;事件&#xff09;驱动 1、简介 事件驱动架构(Event-driven 架构&#xff0c;简称 EDA)是软件设计领域内的一套程序设计模型。 这套模型的意义是所有的操作通过事件的发送/接收来完成。 传统软件设计 举个例子&#xff0c;比如一个订单的创建在传统软…