Pytorch笔记之回归

news2025/1/18 11:50:09

文章目录

  • 前言
  • 一、导入库
  • 二、数据处理
  • 三、构建模型
  • 四、迭代训练
  • 五、结果预测
  • 总结


前言

以线性回归为例,记录Pytorch的基本使用方法。


一、导入库

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable # 定义求导变量
from torch import nn, optim # 定义网络模型和优化器

二、数据处理

将数据类型转为tensor,第一维度变为batch_size

# 构建数据
x = np.random.rand(100)
noise = np.random.normal(0, 0.01, x.shape)
y = 0.1 * x + 0.2 + noise
# 数据处理
x_data = torch.FloatTensor(x.reshape(-1, 1))
y_data = torch.FloatTensor(y.reshape(-1, 1))
inputs = Variable(x_data)
target = Variable(y_data)

三、构建模型

1、继承nn.Module,定义一个线性回归模型。在__init__中定义连接层,定义前向传播的方法
2、实例化模型,定义损失函数与优化器

# 继承模型
class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
    def forward(self, x):
        out = self.fc(x)
        return out
# 定义模型
print('模型参数')
model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for name, param in model.named_parameters():
    print('{}:{}'.format(name, param))

四、迭代训练

1、梯度清零:optimizer.zero_grad()
2、反向传播计算梯度值:loss.backward()
3、执行参数更新:optimizer.step()
循环迭代,定期输出损失值

print('损失值')
for i in range(1001):
    out = model.forward(inputs)
    loss = mse_loss(out, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 200 == 0:
        print(i, loss.item())

五、结果预测

绘制样本的散点图与预测值的折线图

print('结果预测')
y_pred = model(x_data)
plt.plot(x, y, 'b.')
plt.plot(x, y_pred.data.numpy(), 'r-')
plt.show()


总结

使用Pytorch进行训练主要的三步:
(1)数据处理:将数据维度转换为(batch, *),数据类型转换为可训练的tensor;
(2)构建模型:继承nn.Module,定义连接层与运算方法,实例化,定义损失函数与优化器;
(3)迭代训练:循环迭代,依次执行梯度清零、梯度计算、参数更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1064628.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ESP32/ESP8266在线刷写Sonoff Tasmota固件以及配置简要

ESP32/ESP8266在线刷写Sonoff Tasmota固件以及配置简要 📍原项目Github地址:https://github.com/arendst/Tasmota/tree/v13.1.0📑官方文档介绍:https://tasmota.github.io/docs/🚩(✨推荐方式✨)在线固件刷写地址&…

strcpy函数详解:字符串复制的利器

目录 一,strcpy函数的简介 二,strcpy函数的实现原理 三,strcpy函数的注意事项 四,strcpy函数的模拟实现 一,strcpy函数的简介 strcpy函数是C语言中的字符串复制函数,其原型如下: char * str…

Linux中的wc命令

2023年10月6月,周五晚上 目录 wc命令的主要功能和用法如下:统计文件行数、字数和字节数只统计行数只统计字数只统计字节数 wc命令在Linux/Unix系统中是word count的缩写,它用来统计文件的行数、字数和字节数。 wc命令的主要功能和用法如下: 统计文件行数、字数和字…

英语四六级高频核心词(故事版)

第一组:" A Century of Community Effort to Improve Quality of Life and Climate" In the early years of the 20th century, a small community found itself facing a decade of challenges. The most pressing issue was the mental quality of life…

VSC-HVDC直流输电matlab仿真模型

微❤关注“电气仔推送”获得资料(专享优惠) VSC-HVDC直流输电仿真,换流站采用两电平结构,全控型器件(IGBT),采用双环控制,包括电压外环,电流内环,分为d、q两…

【论文阅读】An Evaluation of Concurrency Control with One Thousand Cores

An Evaluation of Concurrency Control with One Thousand Cores Staring into the Abyss: An Evaluation of Concurrency Control with One Thousand Cores ABSTRACT 随着多核处理器的发展,一个芯片可能有几十乃至上百个core。在数百个线程并行运行的情况下&…

Springboot+vue的开放性实验室管理系统(有报告)。Javaee项目,springboot vue前后端分离项目。

演示视频: Springbootvue的开放性实验室管理系统(有报告)。Javaee项目,springboot vue前后端分离项目。 项目介绍: 本文设计了一个基于Springbootvue的前后端分离的开放性实验室管理系统,采用M&#xff08…

基于SSM的家庭财务管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

基于SSM的大学生就业信息管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

matlab之zeros函数语法与举例说明(附代码)

一、zeros函数语法与举例说明 (1)X zeros——返回标量0 X zeros 示例: (2)X zeros(n)——返回一个 nn 的全零矩阵 零矩阵: 示例:创建一个由零值组成的 33 矩阵 X zeros(3) (…

《protobuf》基础语法3

文章目录 默认值更新规则保留字段未知字段 默认值 在反序列化时,若被反序列化的二进制序列中不包含某个字段,则在反序列化时,就会设置对应默认值。不同的类型默认值不同: 类型默认值字符串“”布尔型false数值类型0枚举型0设置了…

基于SSM的旅游网站设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

国庆看坚如磐石

坚如磐石上映了,可以在爱奇艺观看。 而博主在使用蓝牙耳机连接电脑的过程中,发现没有蓝牙开启选项,并且在服务的设备管理器中也没有找到,很明显这是缺少驱动导致的,因此便去联想官方网站下载对应的驱动。 这里可以输入…

二分查找:34. 在排序数组中查找元素的第一个和最后一个位置

个人主页 : 个人主页 个人专栏 : 《数据结构》 《C语言》《C》《算法》 文章目录 前言一、题目解析二、解题思路1. 暴力查找2. 一次二分查找 部分遍历3. 两次二分查找分别查找左右端点1.查找区间左端点2. 查找区间右端点 三、代码实现总结 前言 本篇文…

GCN详解

a ⃗ \vec{a} a 向量 a ‾ \overline{a} a 平均值 a ‾ \underline{a} a​下横线 a ^ \widehat{a} a (线性回归,直线方程) y尖 a ~ \widetilde{a} a a ˙ \dot{a} a˙ 一阶导数 a \ddot{a} a 二阶导数 H(l)表示l层的节点的特征 W(l)表示l层的参数 D ~ \widet…

Kafka客户端核心参数详解

这一部分主要是从客户端使用的角度来理解 Kakfa 的重要机制。重点依然是要建立自己脑海中的 Kafka 消费模型。Kafka 的 HighLevel API 使用是非常简单的,所以梳理模型时也要尽量简单化,主线清晰,细节慢慢扩展。 一、从基础的客户端说起 Kaf…

iphone怎么传大量照片到电脑,这四招你要学会

如果你喜欢用iPhone拍照、总会遇到要把大量照片从iPhone传输到电脑的情况,要是你对这方面不熟悉就很容易浪费时间。下面小编就介绍几种方法可以快速高效的传大量照片到电脑上去。 iPhone传输照片到电脑 方法一:使用iMazing传输 推荐度★★★★★ 有了i…

操作系统八股

1、请你介绍一下死锁,产生的必要条件,产生的原因,怎么预防死锁 1、死锁 两个或两个以上的进程在执行过程中,因争夺共享资源而造成的一种互相等待的现象,若无外力作用,它们都将无法推进下去。此时称系统处…

【Spring Boot】日志文件

日志文件 一. 日志文件有什么用二. 日志怎么用三. ⾃定义⽇志打印1. 在程序中得到⽇志对象2. 使⽤⽇志对象打印⽇志3. ⽇志格式说明 四. 日志级别1. ⽇志级别有什么⽤2. ⽇志级别的分类与使⽤ 五. 日志持久化六. 更简单的⽇志输出—lombok1. 添加 lombok 依赖2. 输出⽇志3. lom…