pytorch实现单层线性回归模型

news2024/12/23 18:45:54

文章目录

    • 简述
      • 代码重构要点
    • 数学模型、运行结果
    • 数据构建与分批
    • 模型封装
    • 运行测试

简述

python使用 数值微分法 求梯度,实现单层线性回归-CSDN博客
python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客
数值微分求梯度、计算图求梯度,实现单层线性回归 模型速度差异及损失率比对-CSDN博客

上述文章都是使用python来实现求梯度的,是为了学习原理,实际使用上,pytorch实现了自动求导,原理也是(基于计算图的)链式求导,本文还就 “单层线性回归” 问题用pytorch实现。

代码重构要点

1.nn.Moudle

torch.nn.Module的继承、nn.Sequentialnn.Linear
torch.nn — PyTorch 2.4 documentation

对于nn.Sequential的理解可以看python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客一文代码的模型初始化与计算部分,如图:

在这里插入图片描述

nn.Sequential可以说是把图中标注的代码封装起来了,并且可以放多层。

2.torch.optim优化器

本例中使用随机梯度下降torch.optim.SGD()
torch.optim — PyTorch 2.4 documentation
SGD — PyTorch 2.4 documentation

3.数据构建与数据加载

data.TensorDatasetdata.DataLoader,之前为了实现数据分批,手动实现了data_iter,现在可以直接调用pytorch的data.DataLoader

对于data.DataLoader的参数num_workers,默认值为0,即在主线程中处理,但设置其它值时存在反而速度变慢的情况,以后再讨论。

数学模型、运行结果

y = X W + b y = XW + b y=XW+b

y为标量,X列数为2. 损失函数使用均方误差。

运行结果:

在这里插入图片描述

在这里插入图片描述

数据构建与分批

def build_data(weights, bias, num_examples):  
    x = torch.randn(num_examples, len(weights))  
    y = x.matmul(weights) + bias  
    # 给y加个噪声  
    y += torch.randn(1)  
    return x, y  
  
  
def load_array(data_arrays, batch_size, num_workers=0, is_train=True):  
    """构造一个PyTorch数据迭代器"""  
    dataset = data.TensorDataset(*data_arrays)  
    return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=is_train)

模型封装

class TorchLinearNet(torch.nn.Module):  
    def __init__(self):  
        super(TorchLinearNet, self).__init__()  
        model = nn.Sequential(Linear(in_features=2, out_features=1))  
        self.model = model  
        self.criterion = nn.MSELoss()  
  
    def predict(self, x):  
        return self.model(x)  
  
    def loss(self, y_predict, y):  
        return self.criterion(y_predict, y)

运行测试

if __name__ == '__main__':  
    start = time.perf_counter()  
  
    true_w1 = torch.rand(2, 1)  
    true_b1 = torch.rand(1)  
    x_train, y_train = build_data(true_w1, true_b1, 5000)  
  
    net = TorchLinearNet()  
    print(net)  
  
    init_loss = net.loss(net.predict(x_train), y_train)  
    loss_history = list()  
    loss_history.append(init_loss.item())  
  
    num_epochs = 3  
    batch_size = 50  
    learning_rate = 0.01  
    dataloader_workers = 6  
  
    data_loader = load_array((x_train, y_train), batch_size=batch_size, is_train=True)  
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)  
  
    for epoch in range(num_epochs):  
        # running_loss = 0.0  
        for x, y in data_loader:  
            y_pred = net.predict(x)  
            loss = net.loss(y_pred, y)  
            optimizer.zero_grad()  
            loss.backward()  
            optimizer.step()  
            # running_loss = running_loss + loss.item()  
            loss_history.append(loss.item())  

    end = time.perf_counter()  
    print(f"运行时间(不含绘图时间):{(end - start) * 1000}毫秒\n")  
  
    plt.title("pytorch实现单层线性回归模型", fontproperties="STSong")  
    plt.xlabel("epoch")  
    plt.ylabel("loss")  
    plt.plot(loss_history, linestyle='dotted')  
    plt.show()  
  
    print(f'初始损失值:{init_loss}')  
    print(f'最后一次损失值:{loss_history[-1]}\n')  
  
    print(f'正确参数: true_w1={true_w1}, true_b1={true_b1}')  
    print(f'预测参数:{net.model.state_dict()}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2048245.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

会议中控系统有多少种编程方法

会议中控系统的编程方法并不局限于某一种固定的方式,而是根据系统的具体需求、开发团队的技能以及所选用的编程语言和技术栈等多种因素来决定的。以下是一些常见的会议中控系统编程方法和考虑因素: 1. 编程语言选择 会议中控系统的开发通常会选择以下几…

Kubernetes拉取阿里云的私人镜像

前提条件 登录到阿里云控制台 拥有阿里云的ACR服务 创建一个命名空间 获取仓库的访问凭证(可以设置固定密码) 例如 sudo docker login --usernameyourAliyunAccount registry.cn-guangzhou.aliyuncs.com 在K8s集群中创建一个secret 使用kubectl命令行…

qt生成一幅纯马赛克图像

由于项目需要&#xff0c;需生成一幅纯马赛克的图像作为背景&#xff0c;经过多次测试成功&#xff0c;记录下来。 方法一&#xff1a;未优化方法 1、代码&#xff1a; #include <QImage> #include <QDebug> #include <QElapsedTimer>QImage generateMosa…

AI跟踪报道第52期-新加坡内哥谈技术-本周AI新闻: X推出的惊人逼真的但不受约束的图像生成器和 GooglePixel 9

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

PostgreSQL-02-入门篇-查询数据

文章目录 1 简单查询SELECT 语句简介SELECT 语句语法SELECT 示例1) 使用 SELECT 语句查询一列数据的示例2) 使用 SELECT 语句查询多列数据的示例3) 使用 SELECT 语句查询表所有列数据的示例4) 使用带有表达式的 SELECT 语句的示例5) 使用带有表达式的 SELECT 语句的示例 2 列别…

大公报发表欧科云链署名文章:发行港元稳定币,建Web3.0新生态

欧科云链研究院资深研究员蒋照生近日与香港科技大学副校长兼香港Web3.0协会首席科学顾问汪扬、零壹智库创始人兼CEO柏亮&#xff0c;在大公报发布联合署名文章 ——《Web3.0洞察 / 发行港元稳定币&#xff0c;建Web3.0新生态》&#xff0c;引发市场广泛讨论。 文章就香港稳定币…

江科大/江协科技 STM32学习笔记P24

文章目录 DMA数据转运验证存储器映像的内容什么时候需要定义常量 验证外设寄存器的地址理解ADC1->DR main.c初始化DMADMA库函数MyDMA.cmain.c DMAAD多通道AD.cmain.c DMA数据转运 验证存储器映像的内容 #include "stm32f10x.h" // Device heade…

视频号分销系统搭建教程,源代码+部署上线指南

目录 一、视频号分销是什么&#xff1f; 二、视频号分销系统怎么搭建&#xff1f; 1.系统架构设计 2.部署与上线 3.持续迭代与升级 三、部分代码展示 一、视频号分销是什么&#xff1f; 视频号分销系统是合集了视频号商家的产品&#xff0c;推广达人推广商家的产品可赚取…

【算法 04】汉诺塔递归求解和通式求解

汉诺塔问题&#xff1a;一个经典的递归问题 汉诺塔&#xff08;Tower of Hanoi&#xff09;问题是一个源自古印度传说的经典益智游戏&#xff0c;也是心理学实验研究和计算机科学中常用的任务之一。该游戏通过三根高度相同的柱子和一系列大小及颜色不同的圆盘来构成&#xff0c…

[Python学习日记-7] 初识基本数据类型(下)

简介 我们在基本数据类型&#xff08;上&#xff09;当中介绍了数据类型中的数据类型&#xff08;整数、浮点数&#xff09;、字符串和布尔值&#xff0c;那么我们还剩下列表和数组还没有介绍了&#xff0c;在 Python 中&#xff0c;列表&#xff08;List&#xff09;是一种有序…

力扣Hot100-final关键字,常量,抽象类(模板方法设计模式),接口

&#xff08;一&#xff09;final关键字 &#xff08;2&#xff09;常量 使用static final 修饰的成员变量被称为常量 作用&#xff1a;&#xff1b;通常用于记录系统的配置信息 注意&#xff1a;产量命名要求&#xff1a;单词大写&#xff0c;下划线连接多个单词 产量优势…

windows下使用vcpkg编译libcurl库并使用C++实现ftp上传下载功能

1、下载安装vcpkg git clone https://github.com/microsoft/vcpkg2、编译vcpkg 使用cmd命令 D:\Code\ThirdParty>cd vcpkg D:\Code\ThirdParty\vcpkg>bootstrap-vcpkg.bat3、使用vcpkg编译所需的库 进入vckpkg目录&#xff0c;使用vckpkg install 命令进行安装。在安…

OJ题——二叉树(最大深度/平衡二叉树/前序遍历构建)

&#x1f36c;个人主页&#xff1a;Yanni.— &#x1f308;数据结构&#xff1a;Data Structure.​​​​​​ &#x1f382;C语言笔记&#xff1a;C Language Notes &#x1f3c0;OJ题分享&#xff1a; Topic Sharing 题目一&#xff08;最大深度&#xff09; 利用分治的思想&…

饿了么新财年开门见喜:亏损减负,收入增肌

撰稿 | 行星 来源 | 贝多财经 8月15日&#xff0c;阿里巴巴对外发布2025财年一季度&#xff08;即自然年2024年二季度&#xff09;业绩。不难看出&#xff0c;受益于饿了么和高德订单的显著增长&#xff0c;以及市场营销服务收入的明显拉升&#xff0c;该季度本地生活集团成绩…

10.DMA

理论 12个通道&#xff1a;DMA1&#xff08;7&#xff09;DMA2&#xff08;5&#xff09; 方向&#xff1a;存储器和存储器间(DMA_MEMORY_TO_MEMORY)、外设到存储器(DMA_PERIPH_TO_MEMORY)、存储器到外设(DMA_MEMORY_TO_PERIPH) 闪存、 SRAM、外设的SRAM、 APB1、 APB2和AHB外…

Simple RPC - 05 从零开始设计一个客户端(下)_ 依赖倒置和SPI

文章目录 Pre概述依赖倒置原则与解耦设计与实现1. 定义接口来隔离调用方与实现类2. 实现类DynamicStubFactory3. 调用方与实现类的解耦 依赖注入与SPI的解耦依赖注入SPI&#xff08;Service Provider Interface&#xff09; 总结 Pre Simple RPC - 01 框架原理及总体架构初探 …

一个模型,多种作物:迁移学习如何提升设施农业AI模型效能

&#xff08; 于景鑫 国家农业信息化工程技术研究中心&#xff09;设施农业是现代农业的"压舱石",但传统的经验式管理模式已难以为继。在数字经济时代,设施农业亟需向数字化、网络化、智能化转型升级。以人工智能为代表的信息技术,正在为设施农业插上腾飞的翅膀。作为…

Kafka主题(Topic/文件夹)的操作

Kafka主题&#xff08;Topic/文件夹&#xff09;的操作 1、Kafka主题&#xff08;Topic/文件夹&#xff09;2、Kafka主题&#xff08;Topic/文件夹&#xff09;的一些操作2.1、创建主题&#xff08;Topic/文件夹&#xff09;2.2、列出所有主题&#xff08;Topic/文件夹&#xf…

8路VBO转HDMI2.0支持4K60频率ITE6265芯片方案心得分享

在此之前&#xff0c;有人找到我这边询问能不能将智能电视主板改成机顶盒&#xff0c;将VBO信号转换输出位HDMI进行投屏&#xff0c;具体应用奇奇怪怪&#xff01;但是奈何是甲方大佬。认命照做。从网上也有搜索了解过这类芯片&#xff0c;发现资料很少&#xff0c;所以有了这篇…

基于免疫算法的最优物流仓储点选址方案MATLAB仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于免疫算法的最优物流仓储点选址方案MATLAB仿真。 2.测试软件版本以及运行结果展示 MATLAB2022A版本运行 &#xff08;完整程序运行后无水印&#xff09; 3…