使用pytorch实现一个线性回归训练函数

news2025/1/16 5:10:06

使用sklearn.dataset 的make_regression创建用于线性回归的数据集

def create_dataset():
    x, y, coef = make_regression(n_samples=100, noise=10, coef=True, bias=14.5, n_features=1, random_state=0)
    return torch.tensor(x), torch.tensor(y), coef

加载数据集,并拆分batchs训练集

def load_dataset(x, y, batch_size):
    data_len = len(y)
    batch_num = data_len // batch_size
    for idx in range(batch_num):
        start = idx * batch_num
        end = idx * batch_num + batch_num
        train_x = x[start : end]
        train_y = y[start : end]
        yield train_x, train_y

定义初始权重和定义计算函数

w = torch.tensor(0.1, requires_grad=True, dtype=torch.float64)
b = torch.tensor(0, requires_grad=True, dtype=torch.float64)
def linear_regression(x):
    return x * w + b

损失函数使用平方差

def linear_loss(y_pred, y_true):
    return (y_pred - y_true) ** 2

优化参数使用梯度下降方法

def sgd(linear_rate, batch_size):
    w.data = w.data - linear_rate * w.grad / batch_size
    b.data = b.data - linear_rate * b.grad / batch_size

训练代码

def train():
    # 加载数据
    x, y, coef = create_dataset()
    data_len = len(y)

    # 定义参数
    batch_size = 10
    epochs = 100
    linear_rate = 0.01

    # 记录损失值
    epochs_loss = []

    # 迭代
    for eid in range(epochs):
        total_loss = 0.0
        for train_x, train_y in load_dataset(x, y, batch_size):
            # 输入模型
            y_pred = linear_regression(train_x)

            # 计算损失
            loss_num = linear_loss(y_pred, train_y.reshape(-1,1)).sum()

            # 梯度清理
            if w.grad is not None:
                w.grad.zero_()
            if b.grad is not None:
                b.grad.zero_()

            # 反向传播
            loss_num.backward()

            # 更新权重
            sgd(linear_rate, batch_size)

            # 统计损失数值
            total_loss = total_loss + loss_num.item()

        # 记录本次迭代的平均损失
        b_loss = total_loss / data_len
        epochs_loss.append(b_loss)
        print("epoch={},b_loss={}".format(eid, b_loss))

    # 显示预测线核真实线的拟合关系
    print(w, b)
    print(coef, 14.5)

    plt.scatter(x, y)

    test_x = torch.linspace(x.min(), x.max(), 1000)
    y1 = torch.tensor([v * w + b for v in test_x])
    y2 = torch.tensor([v * coef + 14.5 for v in test_x])
    plt.plot(test_x, y1, label='train')
    plt.plot(test_x, y2, label='true')
    plt.grid()
    plt.show()

    # 显示损失值变化曲线
    plt.plot(range(epochs), epochs_loss)
    plt.show()

拟合显示还不错

损失值在低5次迭代后基本就很小了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1476240.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蓝桥杯_定时器的基本原理与应用

一 什么是定时器 定时器/计数器是一种能够对内部时钟信号或外部输入信号进行计数,当计数值达到设定要求时,向cpu提出中断处理请求,从而实现,定时或者计数功能的外设。 二 51单片机的定时/计数器 单片机外部晶振12MHZ,…

leetcode 重复的子字符串

前要推理 以abababab为例,这里最主要的就是根据相等前后缀进行推导 s [ 0123 ] 如 t【 0123 】 f 【01 23 】 后两个分别是前后缀,第一个是总的字符串,然后可以推导 //首先还是算出…

【JavaEE】网络原理: HTTP协议相关内容

目录 HTTP 是什么 理解HTTP 协议的工作过程 HTTP 协议格式 抓包工具的使用 抓包工具的原理 抓包结果 HTTP请求 HTTP响应 协议格式总结 HTTP 请求 (Request) 认识 URL 关于 URL encode 认识 "方法" (method) get方法 post方法 其他方法 认识请求 &q…

Linux小项目:在线词典开发

在线词典介绍 流程图如下: 项目的功能介绍 在线英英词典项目功能描述用户注册和登录验证服务器端将用户信息和历史记录保存在数据中。客户端输入用户和密码,服务器端在数据库中查找、匹配,返回结果单词在线翻译根据客户端输入输入的单词在字…

李宏毅机器学习入门笔记——第四节

自注意力机制(常见的神经网络结构) 上节课我们已经讲述过 CNN 卷积神经网络 和 spatial transformer 网络。这次讲述一个其他的常用神经网络自注意力机制神经网络。 对于输入的变量长度不一的时候,采用frame的形式,进行裁剪设计。…

Astra Pro点云代码

github上找到的python读取点云的代码 import timeimport cv2 as cv import numpy as np import open3d from openni import _openni2 from openni import openni2SAVE_POINTCLOUDS True # 是否保存点云数据def get_rgbd(color_capture, depth_stream, depth_scale1000, dept…

猫头虎分享已解决Bug || 依赖问题:DependencyNotFoundException: Module ‘xyz‘ was not found

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …

VirtualBox虚拟机安装 Linux 系统

要想学习各种计算机技术,自然离不开Linux系统。并且目前大多数生产系统都是安装在Linux系统上。常用的Linux系统有 Redhat,Centos,OracleLinux 三种。 三者的区别简单说明如下: Red Hat Enterprise Linux (RHEL): RHEL 是由美国…

QEMU之内存虚拟化

内存虚拟化方案 最直观的方案,将QEMU进程的虚拟地址空间的一部分作为虚拟机的物理地址。但该方案有一个问题: 在物理机上,CPU对内存的访问在保护模式下是通过分段分页实现的,在该模式下,CPU访问时使用的是虚拟地址&am…

算法--时空复杂度分析以及各个数据量对应的可使用的算法(C++;1s内)

这里写目录标题 由数据范围反推算法时间复杂度以及算法内容一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二级目录 由数据范围反推算法时间复杂度以及算法内容 一级目录 二级目录 二级目录…

龙蜥 Anolis OS8.4 设置IP

1、配置文件路径 /etc/sysconfig/network-scripts/ [rootlocalhost ~]# cd /etc/sysconfig/network-scripts/ [rootlocalhost network-scripts]# ls ifcfg-ens32 进入配置文件路径后,展示。ifcfg-ens32这个不同的服务器不一样,本次虚拟机所对应的是ens3…

labelme 使用笔记

下载和安装 labelme官网地址 在Anaconda环境下 conda create -n labelme python3.6 conda activate labelme # go https://anaconda.org/ to find pkg conda install conda-forge/label/cf202003::labelme安装好了,查看版本和使用帮助 labelme -V labelme -h用l…

【机器人最短路径规划问题(栅格地图)】基于蚁群算法求解

基于蚁群算法求解机器人最短路径规划问题的仿真结果 仿真结果 收敛曲线变化趋势 蚁群算法求解最优解的机器人运动路径 各代蚂蚁求解机器人最短路径的运动轨迹

pandas/geopandas 笔记:逐record的轨迹dataFrame转成逐traj_id的轨迹dataFrame

我们现在有这样的一个dataframe,名字为dart 我们需要这样一个DataFrame,每一行有两列,一列是new_installation_id,表示这个轨迹的id;另一列就是这个new_installation_id的轨迹 dart_new dart[[new_installation_id]]…

使用HTML5画布(Canvas)模拟图层(Layers)效果

使用HTML5画布(Canvas)模拟图层(Layers)效果 在图形处理和计算机图形学中,图层(Layers)是指将图像分成不同的可独立编辑、组合和控制的部分的技术或概念。每个图层都可以包含不同的图形元素、效…

你真的了解C语言的枚举和联合吗~

目录 1. 枚举1.1 枚举类型的定义1.2 枚举的优点1.3 枚举的使用 2. 联合(共用体)2.1 联合类型的定义2.2 联合的特点2.3 使用联合体判断当前机器的大小端2.4 联合大小的计算 1. 枚举 枚举顾名思义就是一一列举。 把可能的取值一一列举。 比如我们现实生活…

华为云磁盘挂载

华为云磁盘挂载 磁盘挂载情况 fdisk -l 2. 查看当前分区情况 df -h 3.给新硬盘添加新分区 fdisk /dev/vdb 4.分区完成,查询所有设备的文件系统类型 blkid 发现新分区并没有文件系统类型(type为文件系统具体类型,有ext3,ext4,xfs,iso9660等…

如何一步一步地优化LVGL的丝滑度

经过一番周折将LVGL移植到了STM32F407单片机上,底层驱动的LCD是st7789,移植时的条件和环境如下: ●LVGL用的是单缓冲,一次刷新10行; ●刷新函数用的是最原始的一个一个打点的方式; ●ST7789底层发送数据用的…

寒假开学在即,怎么寄行李才能便宜省钱呢?

在度过了一个充实愉快的假期之后,小伙伴们就要踏上新的征程了,来面对新学期的到来,可是,面对这么多不知道怎么安排的行李可就把人给愁死了,如果通过驿站寄行李的话,又要花费一大笔快递费了,可是…

IAudioManager.cpp源码解读

IAudioManager.cpp源码如下: 源码路径:https://cs.android.com/android/platform/superproject/main//main:frameworks/native/services/audiomanager/IAudioManager.cpp;drc84410fbd18148d422d3581201c67f1a72a6658c4;l147?hlzh-cn /** Copyright (C)…