使用Pytorch写简单线性回归

news2025/1/15 13:44:49

文章目录

  • Pytorch
    • 一、Pytorch 介绍
    • 二、概念
    • 三、应用于简单线性回归
  • 1.代码框架
  • 2.引用
  • 3.继续模型
    • (1)要定义一个模型,需要继承`nn.Module`:
    • (2)如果函数的参数不具体指定,那么就需要在`__init__`函数中添加未指定的变量:
  • 2.定义数据
  • 3.实例化模型
  • 4.损失函数
  • 5.优化器
  • 6.模型训练
  • 7.绘制数据

Pytorch

一、Pytorch 介绍

  PyTorch 是一个开源的深度学习框架,由 Facebook 的人工智能研究团队开发。它主要用于构建和训练深度学习模型,具有以下特点:
  动态计算图:PyTorch 使用动态计算图,这意味着可以在运行时动态地构建、修改和执行计算图,使得开发和调试更加灵活。
  易于使用:提供了简洁直观的 API,使得开发者可以快速上手,专注于模型的设计和实现。
  强大的 GPU 加速:支持在 GPU 上进行高效的并行计算,大大加快了训练和推理的速度。
  广泛的社区支持:拥有庞大的开发者社区,提供了丰富的教程、示例和第三方扩展。

二、概念

  张量(Tensor):是 PyTorch 中的基本数据结构,类似于多维数组,可以在 CPUGPU 上存储和操作数据。
  自动求导(Autograd):PyTorch 能够自动计算张量的梯度,这对于训练深度学习模型非常重要,因为它可以通过反向传播算法自动更新模型参数。
  模块(Module):是 PyTorch 中构建模型的基本单元,可以包含多个子模块和参数。
  优化器(Optimizer):用于优化模型参数,常见的优化算法如随机梯度下降(SGD)、Adam 等。
  损失函数(Loss Function):用于衡量模型预测与真实值之间的差异,常见的损失函数有均方误差(MSE)、交叉熵损失等。

三、应用于简单线性回归

  线性回归是一种简单的机器学习算法,用于预测一个连续的数值。下面是使用 PyTorch 实现简单线性回归的步骤:
  准备数据:
  生成一些随机的输入数据和对应的输出数据。例如,假设我们要拟合一个线性函数 y = 2x + 1,可以生成一些随机的 x 值,并计算出对应的 y 值。
  定义模型:
  使用 PyTorch 的模块类定义一个简单的线性回归模型。这个模型通常包含一个线性层,即一个全连接层,它将输入特征映射到输出。
  定义损失函数和优化器:
  选择一个合适的损失函数,如均方误差(MSE)损失。
  选择一个优化器,如随机梯度下降(SGD)优化器,并设置学习率等参数。
  训练模型:
  将数据分成小批次,每次输入一个批次的数据到模型中进行前向传播,计算损失。
  然后进行反向传播,计算梯度,并使用优化器更新模型参数。
  重复这个过程直到达到预定的训练次数或损失收敛。
  测试模型:
  使用训练好的模型对新的数据进行预测,评估模型的性能。

1.代码框架

在这里插入图片描述

2.引用

import torch        
from torch import nn
from torch import optim

3.继续模型

  继承模型主要都是在nn.Module类

(1)要定义一个模型,需要继承nn.Module

class EIModel(nn.Module):
    def __init__(self):
        super(EIModel,self).__init__()   #等价于super().__init__()  
        self.linear=nn.Linear(in_features=1,out_features=1)   #创建线性层

    def forward(self,inputs):
        logits=self.linear(inputs)
        return logits   

  注:forward()return切记要写上

(2)如果函数的参数不具体指定,那么就需要在__init__函数中添加未指定的变量:

class EIModel(nn.Module):
    def __init__(self,in_features,out_features):
        super(EIModel,self).__init__()
        self.linear=nn.Linear(in_features,out_features)  

    def forward(self,inputs):
        logits=self.linear(inputs)
        return logits

  注:这时在实例化模型时,函数内要指定参数:

model = EIModel(in_features=1,out_features=1)

2.定义数据

x_list=[0,1,2,3,4]
y_list=[2,3,4,5,8]

x_numpy=np.array(x_list,dtype=np.float32)
x=torch.from_numpy(x_numpy.reshape(-1,1))
y_numpy=np.array(y_list,dtype=np.float32)
y=torch.from_numpy(y_numpy.reshape(-1,1))

3.实例化模型

model = EIModel()

  直接调用模型

import torchvision.models as models
models.resnet50()

  测试模型预测结果

outputs=model(x)
print(outputs)

  结果:

tensor([[-0.9462],
        [-1.4654],
        [-1.9846],
        [-2.5038],
        [-3.0230]], grad_fn=<AddmmBackward>)

4.损失函数

  nn.MSELoss()定义均方误差损失计算函数
(1)loss_f=nn.MSELoss()
(2)loss_f=nn.CrossEntropyLoss()

5.优化器

  torch.optim.SGD()是一个内置的优化器
  它的第一个参数是需要优化的变量,可以通过model.parameters()方法获取模型中所有变量
lr=0.0001定义学习率
  (1)opt=torch.optim.SGD(model.parameters(),lr=0.0001)
  (2)optimizer_ft=optim.Adam(params_to_update,lr=1e-2)
  Adam优点:可以自动调整学习效率

6.模型训练

  (1)因为pytorch会累积每次计算的梯度,所以需要将上一循环中的计算的梯度归零
将全部数据训练一遍称为一个epoch,这里训练了500epoch

for epoch in range(500):
    for x_index,y_index in zip(x,y): #同时对x和y迭代
        y_pred=model(x_index)        #等价于model.forward(inputs)
        loss=loss_f(y_pred,y_index)  #根据模型预测输出与实际值y_index计算损失
        opt.zero_grad()              #将累计的梯度清0
        loss.backward()              #反向传播损失,计算损失与模型参数之间的梯度
        opt.step()                   #根据计算得到梯度优化模型参数

  (2)将损失误差打印出来

for epoch in range(500):
    for x_index,y_index in zip(x,y):   
        y_pred=model(x_index)
        loss=loss_f(y_pred,y_index)
        opt.zero_grad()     #将累计的梯度清0
        loss.backward()     #反向传播损失,计算损失与模型参数之间的梯度
        opt.step()          #根据计算得到梯度优化模型参数

    if (epoch + 1) % 50 == 0:
        print(f'epoch:{epoch + 1}, loss = {loss.item():.4f}')

  结果:

epoch:50, loss = 12.1212
epoch:100, loss = 7.1772
epoch:150, loss = 4.4344
epoch:200, loss = 2.8781
epoch:250, loss = 1.9724
epoch:300, loss = 1.4308
epoch:350, loss = 1.0978
epoch:400, loss = 0.8877
epoch:450, loss = 0.7521
epoch:500, loss = 0.6629

  参数名称和值:
model.named_parameters()可以以生成器的形式返回模型参数的名称和值

print(list(model.named_parameters()))

  结果:

[('linear.weight', Parameter containing:tensor([[1.4773]], requires_grad=True)), 
('linear.bias', Parameter containing:tensor([1.2792], requires_grad=True))]

  单独查看权重/偏置:

print(model.linear.weight)
print(model.linear.bias)

7.绘制数据

  使用tensor.detach()方法获得具有相同内容但不需要跟踪运算的新张量,可以认为是获取张量的值

plt.scatter(x_list,y_list,label='scatter plot')
plt.plot(x,model(x).detach().numpy(),c='r',label='line plot')
plt.legend()
plt.show()

  结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2204327.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

掌握未来技能:亚马逊云科技推出生成式 AI 认证计划!

目录 前言 生成式 AI 的力量 1. 内容创造的无限可能 2. 数据增强和个性化 3. 提高生产力 4. 教育和研究的辅助工具 5. 突破语言障碍 关于亚马逊云科技生成式 AI 认证 1. 认证目标 2. 认证内容 3. 认证优势 如何获得认证 1. 在线学习 2. 实践考试 3.AWS Certifie…

连肝了多天学习MySQL索引与性能优化,详细总结一下索引的使用与数据库优化

文章目录 索引是什么&#xff1f;索引的作用初步认识索引索引的类型按照数据结构分类BTREE索引 哈希索引 按功能逻辑进行分类唯一索引普通索引主键索引全文索引 按照字段的个数进行划分单列索引多列&#xff08;组合&#xff0c;联合&#xff09;索引 小结索引的设计原则数据准…

FreeRTOS——TCB任务控制块、任务句柄、任务栈详解

任务控制块结构体 任务控制块是 FreeRTOS 中用于描述和管理任务的数据结构&#xff0c;包含了任务的状态、优先级、堆栈等信息。 TCB_t的全称为Task Control Block&#xff0c;也就是任务控制块&#xff0c;这个结构体包含了一个任务所有的信息&#xff0c;它的定义以及相关变…

UE5蓝图学习笔记玩家碰撞触发死亡加一秒黑屏

UE5蓝图学习笔记玩家碰撞触发死亡加一秒黑屏 1.代表检测自身是否到和其他Actor碰撞。 2.判断Actor是否等于Player Pawn 3.摄像机在一秒钟褪色0-1。 4.Delay延时一秒执行。 5.获取当前关卡的名字。 6.重新加载当前的关卡 。 7.获取Get Plyer Pawn。 8.获取玩家相机控制器…

一次性语音芯片:重塑语音识别技术,引领智能化生活新时代

随着一次性语音芯片的突破性进展&#xff0c;语音识别技术正融入我们生活的方方面面&#xff0c;引领着智能化生活迈向一个全新的时代。这些芯片不仅体积小巧、成本低廉&#xff0c;更在性能上实现了质的飞跃&#xff0c;能够更精确地捕捉并理解人类语音。本文将解读关于一次性…

Scrapy网络爬虫基础

使用Spider提取数据 Scarpy网络爬虫编程的核心就是爬虫Spider组件&#xff0c;它其实是一个继承与Spider的类&#xff0c;主要功能设计封装一个发送给网站服务器的HTTP请求&#xff0c;解析网站返回的网页及提取数据 执行步骤 1、Spider生成初始页面请求&#xff08;封装于R…

基于SpringBoot智能垃圾分类系统【附源码】

基于SpringBoot智能垃圾分类系统 效果如下&#xff1a; 系统首页界面 用户注册界面 垃圾站点页面 商品兑换页面 管理员登录界面 垃圾投放界面 物业登录界面 物业功能界图 研究背景 随着城市化进程的加速&#xff0c;生活垃圾的产量急剧增加&#xff0c;传统的垃圾分类方式已…

Java 集合 Collection常考面试题

理解集合体系图 collection中 list 是有序的,set 是无序的 什么是迭代器 主要遍历 Collection 集合中的元素,所有实现了 Collection 的集合类都有一个iterator()方法,可以返回一个 iterator 的迭代器。 ArrayList 和 Vector 的区别? ArrayList 可以存放 null,底层是由数…

Oracle RAC IPC Send timeout detected问题分析处理

一、报错信息 今天在进行数据库巡检时&#xff0c;在集群节点1发现了IPC相关报错信息&#xff1a; 2024-10-10T10:22:06.84631708:00 IPC Receiver dump detected. Sender instance 2 Receiver pnum 277 ospid 377527 [oraclezxsszpt-sjkfwq1 (PPA6)], pser 124403 2024-10-1…

飞行机器人专栏(十六)-- 双臂机器人体感交互式控制

目录 1. 概要 2. 整体架构流程 3. 控制系统设计 3.1 Vision-based Human-Robot Interaction Control 3.2 Human Motion Estimation Approach 4. 实现方法及实验验证 4.1 System Implementation 4.2 Experimental Setup 4.3 Experimental Results 5. 小结 ​​​​​​​ 1. 概…

Qt Creator 通过python解释器调用*.py

全是看了大佬们的帖子&#xff0c;结合chatGPT才揉出来。在此做个记录。 安装python在Qt Creator *.pro 文件中配置好环境来个简单的example.py调用代码安装pip添加opencv等库调用包含了opencv库的py代码成功 *.pro配置&#xff1a; INCLUDEPATH C:\Users\xuanm\AppData\Lo…

接口测试-day3-jmeter-2组件和元件

组件和元件&#xff1a; 组件&#xff1a;组件指的是jmeter里面任意一个可以使用的功能。比如说查看结果树或者是http请求 元件&#xff1a;元件指是提对组件的分类 组件的作用域&#xff1a;组件放的位置不一样生效也不一样。 作用域取决于组件的的层级结构并不取决于组件的…

论文阅读:OpenSTL: A Comprehensive Benchmark of Spatio-Temporal Predictive Learning

论文地址&#xff1a;arxiv 摘要 由于时空预测没有标准化的比较&#xff0c;所以为了解决这个问题&#xff0c;作者提出了 OpenSTL&#xff0c;这是一个全面的时空预测学习基准。它将流行的方法分为基于循环和非循环模型两类。OpenSTL提供了一个模块化且可扩展的框架&#xff…

算法: 前缀和题目练习

文章目录 前缀和题目练习前缀和二维前缀和寻找数组的中心下标除自身以外数组的乘积和为 K 的子数组和可被 K 整除的子数组连续数组矩阵区域和 前缀和题目练习 前缀和 自己写出来了~ 坑: 数据太大,要用long. import java.util.Scanner;public class Main {public static voi…

“国货户外TOP1”凯乐石签约实在智能,RPA助力全域电商运营自动化提效

近日&#xff0c;国货第一户外品牌KAILAS凯乐石与实在智能携手合作&#xff0c;基于实在智能“取数宝”自动化能力&#xff0c;打通运营数据获取全链路&#xff0c;全面提升淘宝、天猫、抖音等平台的运营效率与消费者体验&#xff0c;以自动化能力驱动企业增长。 KAILAS凯乐石…

雨晨 24H2 正式版 Windows 11 iot ltsc 2024 适度 26100.2033 VIP2IN1

雨晨 24H2 正式版 Windows 11 iot ltsc 2024 适度 26100.2033 VIP2IN1 install.wim 索引: 1 名称: Windows 11 IoT 企业版 LTSC 2024 x64 适度 (生产力环境推荐) 描述: Windows 11 IoT 企业版 LTSC 2024 x64 适度 By YCDISM 2024-10-09 大小: 15,699,006,618 个字节 索引: 2 …

Jenkins常见问题处理

Jenkins操作手册 读者对象&#xff1a;生产环境管理及运维人员 Jenkins作用&#xff1a;项目自动化构建部署。 一、登陆 二、新增用户及设置权限 2.1&#xff1a;新增用户 点击Manager Jenkins → Manager Users → Create User 2.2&#xff1a;权限 点击Manager Jenkins…

互联网线上融合上门洗衣洗鞋小程序,让洗衣洗鞋像点外卖一样简单

随着服务创新的风潮&#xff0c;众多商家已巧妙融入预约上门洗鞋新风尚&#xff0c;并携手洗鞋小程序&#xff0c;开辟线上蓝海。那么&#xff0c;这不仅仅是一个小程序&#xff0c;它究竟蕴含着哪些诱人好处呢&#xff1f; 1. 无缝融合&#xff0c;双线共赢&#xff1a;小程序…

Corel VideoStudio Ultimate 会声会影2025旗舰版震憾来袭,会声会影2025旗舰版最低系统要求

软件介绍 会声会影2025旗舰版全名&#xff1a;Corel VideoStudio Ultimate 2025&#xff0c;相信做视频剪辑的朋友都认识它&#xff0c;会声会影是一款强大的视频剪辑编辑软件&#xff0c;运用数百种拖放滤镜、效果、图形、标题和过渡&#xff0c;探索新奇好玩的新增面部追踪贴…

彩族相机内存卡恢复多种攻略:告别数据丢失

在数字时代&#xff0c;相机内存卡作为我们存储珍贵照片和视频的重要媒介&#xff0c;其数据安全性显得尤为重要。然而&#xff0c;意外删除、错误格式化、存储卡损坏等情况时有发生&#xff0c;导致数据丢失&#xff0c;给用户带来不小的困扰。本文将详细介绍彩族相机内存卡数…