【pytorch】使用pytorch构建线性回归模型-了解计算图和自动梯度

news2025/1/11 6:24:49

使用pytorch构建线性回归模型

线性方程的一般形式

请添加图片描述

衡量线性损失的一般形式-均方误差

请添加图片描述

pytorch中计算图的作用和优势

在 PyTorch 中,计算图(Computational Graph)是一种用于表示神经网络运算的数据结构。每个节点代表一个操作,例如加法、乘法或激活函数,而边则代表这些操作之间的数据流动。

计算图的主要优点是可以自动进行微分计算。当你在计算图上调用 .backward() 方法时,PyTorch 会自动计算出所有变量的梯度。这是通过使用反向传播算法来实现的,该算法从最后的输出开始,然后根据链式法则回溯到输入。

以下是一个简单的计算图示例:

import torch

# 定义两个张量
x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)

# 定义计算图
z = x * y
out = z.mean()

# 计算梯度
out.backward()

print(x.grad) # tensor([0.5])
print(y.grad) # tensor([0.5])

在这个例子中,我们首先定义了两个需要求导的张量 xy。然后,我们定义了一个计算图,其中 zxy 的乘积,outz 的平均值。当我们调用 out.backward() 时,PyTorch 会自动计算出 xy 的梯度。

注意,只有那些设置了 requires_grad=True 的张量才会被跟踪并存储在计算图中。这样,我们就可以在需要时计算这些张量的梯度。

import torch

x_data = [1.0, 2.0, 3.0] #x输入,表示特征
y_data = [2.0, 4.0, 6.0] #y输入,表示标签

w = torch.tensor([1.0], requires_grad=True) #创建权重张量,启用自动计算梯度

def forward(x): #前向传播
    return x * w #特征和权重的点积,构建乘法计算图

def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2 #均方误差,构建损失计算图

线性模型的计算图的一般形式

请添加图片描述

print("predict before training is {}".format(forward(4).item()))

for epoch in range(100):
    for x,y in zip(x_data, y_data):#组合特征和标签
        l = loss(x,y) #定义计算图,包括前向传播和计算损失
        l.backward() #反向传播,计算梯度
        print("\tgrad:", x,y,w.grad.item())#梯度的标量
        w.data = w.data - 0.01 * w.grad.data#使用“.data”表示是更新数据,而不是创建计算图
        w.grad.data.zero_()#梯度清零,准备创建下一个计算图
    print("progress:", epoch, l.item())
print("predict after training:{}".format(forward(4).item()))

使用pytorch API

pytorch的张量计算
请添加图片描述

准备数据集

x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])

请添加图片描述

class LinearModel(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(LinearModel, self).__init__(*args, **kwargs)
        self.linear = torch.nn.Linear(in_features=1, out_features=1)
        
    def forward(self, x):
        y_pred = self.linear(x)
                return y_pred
    
model = LinearModel()

定义损失函数和损失优化函数

关于小批量随机梯度下降

小批量随机梯度下降(Mini-batch Stochastic Gradient Descent)是批量梯度下降的一种变体。与批量梯度下降相比,小批量随机梯度下降在每次迭代时只使用一小部分数据(称为小批量)来计算梯度,然后根据这个梯度来更新模型的参数。

小批量随机梯度下降的目标函数为:

J ( θ ) = 1 m ∑ i = 1 m L ( y ( i ) , f θ ( x ( i ) ) ) J(\theta) = \frac{1}{m} \sum_{i=1}^{m} L(y^{(i)}, f_{\theta}(x^{(i)})) J(θ)=m1i=1mL(y(i),fθ(x(i)))

其中, J ( θ ) J(\theta) J(θ) 是目标函数, m m m 是数据集的大小, L ( y ( i ) , f θ ( x ( i ) ) ) L(y^{(i)}, f_{\theta}(x^{(i)})) L(y(i),fθ(x(i))) 是第 i i i 个样本的损失函数, f θ ( x ( i ) ) f_{\theta}(x^{(i)}) fθ(x(i)) 是模型对第 i i i 个样本的预测。

小批量随机梯度下降的更新规则为:

θ = θ − α ∇ J ( θ ) \theta = \theta - \alpha \nabla J(\theta) θ=θαJ(θ)

其中, α \alpha α 是学习率, ∇ J ( θ ) \nabla J(\theta) J(θ) 是目标函数关于 θ \theta θ 的梯度。

小批量随机梯度下降的优点是它结合了批量梯度下降的优点(即可以利用整个数据集的信息来更新参数)和随机梯度下降的优点(即可以在每次迭代时使用新的数据)。这使得它在处理大规模数据集时具有更好的计算效率,同时也能避免随机梯度下降的问题(即可能会陷入局部最优)。

criteria = torch.nn.MSELoss()#使用均方误差做损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#使用随机梯度下降做损失优化函数
for epoch in range(1000):
    y_pred = model(x_data)
    loss = criteria(y_pred, y_data)
    print(epoch, loss)
    optimizer.zero_grad()#梯度清零
    loss.backward()#反向传播
    optimizer.step()#梯度下降更新参数

预测

print("w=", model.linear.weight.item())
print("b=", model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print("y_pred=", y_test.data)
print("w=", model.linear.weight.item())
print("b=", model.linear.bias.item())

实践

使用pytorch创建线性模型进行波士顿房价预测(数据集可以自行下载)

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt

data_file = "J:\\MachineLearning\\数据集\\housing.data"

pd_data = pd.read_csv(data_file, sep="\s+")

def prepare_data(data, normalize_data=True):    
    # 标准化特征矩阵(可选)    
    if normalize_data:    
        features_mean = np.mean(data, axis=0)    #特征的平均值
        features_dev = np.std(data, axis=0)      #特征的标准偏差
        features_ret = (data - features_mean) / features_dev    #标准化数据
    else:    
        features_mean = None    
        features_dev = None   
    return features_ret

np_data = pd_data.sample(frac=1).reset_index(drop=True).values
#bias = np.ones(len(np_data)).reshape(-1,1)
#np_data = np.concatenate((bias, np_data), axis=1)
train_data = np_data[:int(len(np_data)*0.8), :]

test_data = np_data[int(len(np_data)*0.8):, :]
train_dataset = train_data[:, :-1]
test_dataset = test_data[:, :-1]
train_labels = train_data[:, -1]
test_labels = test_data[:, -1]
train_dataset = prepare_data(train_dataset)
# Save the mean and standard deviation of the target variable before normalization
target_mean = np.mean(train_labels)
target_dev = np.std(train_labels)

print(np_data.shape)
print(train_data.shape)
print(train_dataset.shape)
print(train_labels.shape)

features = torch.tensor(train_dataset, dtype=torch.float32)
print(features[:10])
feature_num = features.shape[1]
labels = torch.tensor(train_labels.reshape(-1,1), dtype=torch.float32)

print(features.shape)
print(labels.shape)

class LinearReg(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(LinearReg, self).__init__(*args, **kwargs)
        self.linear_reg = torch.nn.Linear(in_features=feature_num, out_features=1)
        
    def forward(self, x):
        pred_y = self.linear_reg(x)
        return pred_y

epoch = 100000
lr = 0.001
linear_model = LinearReg()
loss_function = torch.nn.MSELoss(size_average=True)
optimizer = torch.optim.SGD(linear_model.parameters(), lr)

loss_history = []
last_loss = 0.01
for epoch_step in range(epoch):
    predict = linear_model(features)
    loss = loss_function(predict, labels)
    if (epoch_step % 100 == 0):
        print(epoch_step, loss)
        loss_history.append(loss.item())
        if (abs(float(loss.item()) - last_loss)/last_loss < 0.00001):
            break
        last_loss = float(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
plt.plot(loss_history)
plt.show()
test_dataset = prepare_data(test_dataset)
test_dataset = torch.tensor(test_dataset, dtype=torch.float32)
result = linear_model(test_dataset).detach().numpy()
predicted_values = np.round(result, 2)
print(predicted_values)
show_result = np.concatenate((predicted_values.reshape(-1,1), test_labels.reshape(-1,1)), axis=1)
print(show_result)
print(predicted_values)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1376859.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AWS】使用亚马逊云服务器创建EC2实例

目录 前言为什么选择 Amazon EC2 云服务器搭建 Amazon EC2 云服务器注册亚马逊账号登录控制台服务器配置免费套餐预览使用 Amazon EC2 云服务器打开服务器管理界面设置服务器区域填写实例名称选择服务器系统镜像选择实例类型创建密钥对网络设置配置存储启动实例查看实例 总结 前…

【天龙怀旧服】攻略day5

关键字&#xff1a; 天鉴扫荡、举贤、燕子水路 1】85天鉴任务可以扫荡 在流派选择npc那里&#xff0c;花费40交子即可扫荡100点&#xff0c;可以兑换10个灵武打造图&#xff1b; 此外打造图绑定不影响做出来的灵武绑定&#xff0c;只要对应的玉不绑灵武就不绑定 2】冠绝师门…

C#使用CryptoStream类加密和解密字符串

目录 一、CrytoStream的加密方法 二、CrytoStream的解密方法 三、实例 1.源码Form1.cs 2.类库Encrypt.cs 3.生成效果 在使用CryptoStream前要先引用命名空间using System.Security.Cryptography。 一、CrytoStream的加密方法 记住&#xff0c;不能再使用DESCryptoServi…

宏集案例丨宏集PC Runtime软件助推食品行业生产线数字化革新

来源&#xff1a;宏集科技 工业物联网 宏集案例丨宏集PC Runtime软件助推食品行业生产线数字化革新 原文链接&#xff1a;https://mp.weixin.qq.com/s/DwzVzifUiidNr-FT3Zfzpg 欢迎关注虹科&#xff0c;为您提供最新资讯&#xff01; 01 前言 近年来&#xff0c;中国食品行业…

深入浅出Android dmabuf_dump工具

目录 dmabuf是什么&#xff1f; dmabuf_dump工具介绍(基于Android 14) Android.bp dmabuf_dump.cpp 整体架构结构如下 dmabuf_dump主要包含以下功能 前置背景知识 fdinfo 思考 bufinfo Dump整个手机系统的dmabuf Dump某个进程的dmabuf​​​​​​​ 以Table[buff…

Hive 的 安装与使用

目录 1 安装 MySql2 安装 Hive3 Hive 元数据配置到 MySql4 启动 Hive5 Hive 常用交互命令6 Hive 常见属性配置 Hive 官网 1 安装 MySql 为什么需要安装 MySql? 原因在于Hive 默认使用的元数据库为 derby&#xff0c;开启 Hive 之后就会占用元数据库&#xff0c;且不与其他客户…

Windows 远程控制之 PsExec

1、介绍&#xff1a; PsExec 是一种轻量级 telnet 替代品&#xff0c;可让你在其他系统上执行进程&#xff0c;并为控制台应用程序提供完整交互性&#xff0c;而无需手动安装客户端软件。 PsExec 最强大的用途包括在远程系统上启动交互式命令提示符&#xff0c;以及 IpConfig …

一篇文章彻底搞懂TiDB集群各种容量计算方式

背景 TiDB 集群的监控面板里面有两个非常重要、且非常常用的指标&#xff0c;相信用了 TiDB 的都见过&#xff1a; Storage capacity&#xff1a;集群的总容量 Current storage size&#xff1a;集群当前已经使用的空间大小 当你准备了一堆服务器&#xff0c;经过各种思考设计…

【JaveWeb教程】(21) MySQL数据库开发之多表设计:一对多、一对一、多对多的表关系 详细代码示例讲解

目录 2. 多表设计2.1 一对多2.1.1 表设计2.1.2 外键约束 2.2 一对一2.3 多对多2.4 案例 2. 多表设计 关于单表的操作(单表的设计、单表的增删改查)我们就已经学习完了。接下来我们就要来学习多表的操作&#xff0c;首先来学习多表的设计。 项目开发中&#xff0c;在进行数据库…

OCR字符识别:开始批量识别身份证信息

身份证信息批量识别OCR是一项解决方案&#xff0c;它能够将身份证照片打包成zip格式或通过URL地址进行提交&#xff0c;并能够识别照片中的文本信息。最终&#xff0c;用户可以将识别结果生成为excel文件进行下载。 API接口功能&#xff1a; 1. 批量识别&#xff1a;支持将多…

SPDK中常用的性能测试工具

本文主要介绍磁盘性能评估的方法&#xff0c;针对用户态驱动Kernel与SPDK中各种IO测试工具的使用方法做出总结。其中fio是一个常用的IO测试工具&#xff0c;可以运行在Linux、Windows等多种系统之上&#xff0c;可以用来测试本地磁盘、网络存储等的性能。为了和SPDK的fio工具相…

大模型学习与实践笔记(四)

一、大模型开发范式 RAG&#xff08;Retrieval Augmented Generation&#xff09;检索增强生成&#xff0c;即大模型LLM在回答问题或生成文本时&#xff0c;会先从大量的文档中检索出相关信息&#xff0c;然后基于这些检索出的信息进行回答或生成文本&#xff0c;从而可以提高回…

Prepar3D设置全屏显示设置方法

一、 基础设置 当视景软件显示的屏幕超过一个的时候&#xff0c;需要将多个显示屏幕在设置->屏幕设置->多显示器这里设置为扩展这些显示器。 二、全屏方法说明 一般情况只需要设置了多屏显示扩展并设置了P3D软件全屏设置&#xff08;即下面的步骤一&#xff09;保存后…

D2576 DC-DC降压芯片用于直流充电桩,具备3A的输出电流能力,输入电压6~40VDC

随着新能源汽车的不断普及&#xff0c;如何解决新能源车充电的问题也成为大热话题&#xff0c;充电桩的数量与质量也是目前急需提升的热门方面&#xff0c;现阶段人们需要的充电桩主要有交流充电桩和直流充电桩&#xff0c;直流充电桩因其节能效率高、功率因数高、充电快、逐渐…

【OJ比赛日历】快周末了,不来一场比赛吗? #01.13-01.19 #11场

CompHub[1] 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…&#xff09;比赛。本账号会推送最新的比赛消息&#xff0c;欢迎关注&#xff01; 以下信息仅供参考&#xff0c;以比赛官网为准 目录 2024-01-13&#xff08;周六&#xff09; #4场比赛2024-01-14…

【一、测试基础】Java基础语法

Java 的用法及注意事项有很多&#xff0c;今天的目标是了解Java基础语法&#xff0c;且能够输出"hello world" 几个基础的概念 对象&#xff1a;对象是类的一个实例&#xff0c;有状态和行为。一只猫是一个对象&#xff0c;猫的状态有&#xff1a;颜色、名字、品种&…

通信原理 | 累积谱的概念 | Python案例代码介讲解

文章目录 累积谱的概念Python案例代码讲解结果展示拓展累积幅值什么意思累积谱的概念 累积谱(Cumulative Spectrum)是信号处理领域的一个概念,它用于描述信号的频谱分布。具体来说,累积谱是一个函数或者图表,用来表示信号频谱中的能量分布,通常是从最低频率开始累积到某…

AI智能分析网关V4:太阳能+4G智慧水库远程可视化智能监管方案

一、背景需求分析 由于水库位置分散的原因&#xff0c;水库视频监控建设在立杆、布线等方面都存在一定的难度&#xff0c;且需要人力、物力的前期投入和后期维护。目前水库的监管存在一定的问题&#xff0c;管理人员工作强度大但管理质量并不高&#xff0c;人为巡检无法实时发…

C# OpenCvSharp DNN 部署yoloX

目录 效果 模型信息 项目 代码 下载 C# OpenCvSharp DNN 部署yoloX 效果 模型信息 Inputs ------------------------- name&#xff1a;images tensor&#xff1a;Float[1, 3, 640, 640] --------------------------------------------------------------- Outputs ---…

Java异常处理--异常处理的方式1

文章目录 一、异常处理概述二、方式1&#xff1a;捕获异常&#xff08;try-catch-finally&#xff09;&#xff08;1&#xff09;抓抛模型&#xff08;2&#xff09;try-catch-finally基本格式1、基本语法2、整体执行过程3、try和catch3.1 try3.2 catch (Exceptiontype e) &…