实现线性回归笔记 # 自用

news2024/12/17 7:43:06

线性模型可以看作是一个单层的神经网络。

对于n个输入[x1, x2, ...., xn],由n个权重[w1, w2, ......, wn]以及一个偏置常数b得到的输出y,则称y = x1w1+x2w2+......+xnwn+b称为线性模型

即 线性模型是对n维输入的加权外加偏差。

要利用线性模型进行预测,要先衡量真实值与估计值之间的损失

这里我们采用的是平方损失,用来衡量单个样本的预测误差大小。

定义好模型与损失后,我们就可以利用数据集进行训练了。

假设有n个训练样本,x = [x1, x2, ......, xn]T, y = [y1, y2, ......, yn]T,每一个x是一个列向量,排列好后进行转置,得到的x每一行就是一个特征,每一个y是一个常数,即一个真实值。(T表示转置)

对于模型在每一个数据上的损失求均值,就能得到损失函数均方误差)。

函数中的1/2来自平方损失函数,1/n说明求平均。

对于第i个样本,第i个真实值减去第i个特征x与权重w的内积减去偏差b再平方求和,最后乘以1/2n。

写成向量,即一个向量y减去矩阵X乘向量w减去标量b,再将其求平方和。

最后求损失的最小值,将最小值中的w和b作为模型的解。

以上为求解的过程,但线性回归模型是有显式解的,故我们可以简化过程。

首先在X的最右边加入一列1,再将偏差放置在w的末端,损失函数即可写成

求导后即

当一个模型没有显示解时,可以先随机生成参数的初始值w0,接下来利用梯度下降不断更新w。

更新规则为

新的w等于 上一次的w 减去损失函数关于上一次w的梯度 再乘学习率。

学习率可看作梯度下降过程中的下降步长,太大则下降效果过于粗糙,太小则运算次数需求过高。

学习率超参数,即人为指定的参数

每一次梯度下降都要计算一次损失函数,运算量过大。可以随机采样b个样本来近似计算损失。

这里的b是批量大小,同样是超参数。

b的选择太小,不适合并行计算;选择太大则消耗内存太大。

总结:梯度下降通过不断沿着反梯度方向更新参数求解。

以下为pytorch实现代码:

import numpy as np
import torch
from torch.utils import data
from torch import nn  # 神经网络模块
def synthetic_data(w, b, num_examples):
    """
    生成合成数据,模拟线性回归问题。
    
    参数:
    w (torch.Tensor): 线性模型的权重。
    b (float): 线性模型的偏置。
    num_examples (int): 生成的样本数量。
    device (str): 指定运行设备,默认为 'cuda'。
    
    返回:
    X (torch.Tensor): 生成的特征数据,形状为 (num_examples, len(w))。
    y (torch.Tensor): 生成的标签数据,形状为 (num_examples, 1)。
    """
    # 生成均值为0,标准差为1的正态分布随机数,形状为 (num_examples, len(w))
    X = torch.normal(0, 1, (num_examples, len(w)))
    # 使用线性模型计算预测值
    y = torch.matmul(X, w) + b
    # 向预测值中添加均值为0,标准差为0.01的正态分布噪声
    y += torch.normal(0, 0.01, y.shape)
    # 将 y 调整为列向量
    return X, y.reshape((-1, 1))

def load_array(data_arrays, batch_size, is_train=True):
    """
    构造一个PyTorch数据迭代器,用于加载和批量处理数据。

    参数:
        data_arrays (tuple): 包含特征张量和标签张量的元组。例如: (features, labels)。
        batch_size (int): 每个批次的数据大小,即每次迭代时返回的样本数量。
        is_train (bool): 是否在加载数据时打乱顺序 (shuffle)。
            - 如果为 True:数据会在每个 epoch 开始前打乱顺序,常用于训练。
            - 如果为 False:数据按照顺序加载,常用于验证或测试。

    返回:
        torch.utils.data.DataLoader: PyTorch 的数据加载器,支持按批次迭代数据。
    """
    # 使用 TensorDataset 将特征张量和标签张量打包在一起
    # TensorDataset 会将特征和标签一一对应,形成一个可以索引的数据集
    dataset = data.TensorDataset(*data_arrays)

    # 创建数据加载器 DataLoader:
    # - dataset: 传入封装好的 TensorDataset。
    # - batch_size: 每个批次返回的数据量。
    # - shuffle: 决定是否在每个 epoch 前随机打乱数据顺序。
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

def train(net, true_w, true_b, features, labels, batch_size, num_epochs, lr):
    """
    训练一个简单的线性回归模型,使用均方误差作为损失函数,并通过随机梯度下降优化。

    参数:
        net (torch.nn.Sequential): 神经网络模型,包含一个线性层。
        true_w (torch.Tensor): 真实的权重向量,用于与模型学习到的参数进行比较。
        true_b (float): 真实的偏置值,用于与模型学习到的参数进行比较。
        features (torch.Tensor): 输入特征数据,形状为 (样本数, 特征数)。
        labels (torch.Tensor): 输出标签数据,形状为 (样本数, ) 或 (样本数, 标签数)。
        batch_size (int): 每个批次的样本数量。
        num_epochs (int): 训练的轮数(数据集完整遍历的次数)。
        lr (float): 学习率,控制梯度下降时的步长。

    输出:
        - 每轮训练的损失值。
        - 训练完成后,打印学习到的权重和偏置与真实值的误差。
    """
    # 构造数据迭代器,支持小批量随机梯度下降
    data_iter = load_array((features, labels), batch_size)
    # 初始化模型参数
    net[0].weight.data.normal_(0, 0.01)  # 将权重初始化为均值为 0,标准差为 0.01 的正态分布
    net[0].bias.data.fill_(0)            # 将偏置初始化为 0
    # 定义损失函数:均方误差 (MSE)
    loss = nn.MSELoss()
    # 定义优化器:随机梯度下降 (SGD)
    # net.parameters() 返回模型中所有需要优化的参数(权重和偏置)
    trainer = torch.optim.SGD(net.parameters(), lr=lr)
    # 开始训练
    for epoch in range(num_epochs):  # 外层循环:遍历每个 epoch
        for X, y in data_iter:       # 内层循环:遍历每个小批量数据
            # 前向传播:计算模型预测值和损失
            l = loss(net(X), y)  # l 是当前小批量的损失值
            # 清空上一批次的梯度
            trainer.zero_grad()
            # 反向传播:计算当前损失对模型参数的梯度
            l.backward()
            # 更新参数:使用计算出的梯度对模型参数进行优化
            trainer.step()
        # 每个 epoch 结束后,计算整个数据集上的损失值
        l = loss(net(features), labels)
        # 打印当前 epoch 的训练损失
        print(f'epoch {epoch + 1}, loss {l:f}')
    # 打印模型参数的估计误差
    # 将学习到的参数与真实值对比,衡量模型的学习效果
    print(f'w的估计误差: {true_w - net[0].weight.reshape(true_w.shape)}')
    print(f'b的估计误差: {true_b - net[0].bias}')


net = nn.Sequential(nn.Linear(2, 1))
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
batch_size = 10
num_epochs = 3
lr = 0.03

train(net, true_w, true_b, features, labels, batch_size, num_epochs, lr)

若想将权重保存到本地或加载本地权重文件,代码如下

# 保存文件
torch.save({
    'weights': net[0].weight.data,
    'bias': net[0].bias.data
}, 'linear_model_weights.pth')

checkpoint = torch.load('linear_model_weights.pth')  # 加载保存的文件
loaded_weights = checkpoint['weights']
loaded_bias = checkpoint['bias']

# 将权重和偏置赋值回模型
net[0].weight.data = loaded_weights.clone()
net[0].bias.data = loaded_bias.clone()

以上是一个简易的线性回归模型训练方法。

懒得注释代码所以让ai生成了注释(目移)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2260935.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实景视频与模型叠加融合?

[视频GIS系列]无人机视频与与实景模型进行实时融合_无人机视频融合-CSDN博客文章浏览阅读1.5k次,点赞28次,收藏14次。将无人机视频与实景模型进行实时融合是一个涉及多个技术领域的复杂过程,主要包括无人机视频采集、实景模型构建、视频与模型…

c语言——数据结构【链表:单向链表】

上篇→快速掌握C语言——数据结构【创建顺序表】多文件编译-CSDN博客 一、链表 二、单向链表 2.1 概念 2.2 单向链表的组成 2.3 单向链表节点的结构体原型 //类型重定义,表示存放的数据类型 typedef int DataType;//定义节点的结构体类型 typedef struct node {union{int l…

【LC】876. 链表的中间结点

题目描述: 给你单链表的头结点 head ,请你找出并返回链表的中间结点。 如果有两个中间结点,则返回第二个中间结点。 示例 1: 输入:head [1,2,3,4,5] 输出:[3,4,5] 解释:链表只有一个中间结点…

Bugku---misc---隐写2

题目出处:首页 - Bugku CTF平台 ✨打开发现是一张图片,于是查看属性,放在010查看,这都是基本步骤了,发现里面有一个flag.rar!!!拿binwalk分析也确实存在 ✨于是按照压缩包的起始位置…

无需公网IP,本地可访问TightVNC 服务端

TightVNC 是一款免费而且开源的远程桌面软件,它允许用户在不同的操作系统之间实现无缝连接,TightVNC支持 Windows、macOS 和 Linux 等多个操作系统,为用户提供高效便捷的远程控制体验。在 Windows 系统电脑端安装使用 TightVNC 服务端和客户端…

【Unity基础】Unity中如何实现图形倒计时

为了在Unity中实现一个图形倒计时,除了代码部分,还需要一些UI元素的创建和设置。本文以环形倒计时为例,以下是完整的步骤,涵盖了如何创建UI元素、设置它们,以及如何编写控制环形倒计时进度的脚本。 1. 创建UI元素 创建…

Excel/VBA 正则表达式归纳汇总

1.with结构。以下语句用来提取A列中的“成品”两个字前面的部分的中文,不含成品两个字,结果存放在第2列。使用了On Error Resume Next,表示错误时继续下一条。 Sub 提取口味() Set regx CreateObject("vbscript.regexp") On Err…

xshell连接虚拟机,更换网络模式:NAT->桥接模式

NAT模式:虚拟机通过宿主机的网络访问外网。优点在于不需要手动配置IP地址和子网掩码,只要宿主机能够访问网络,虚拟机也能够访问。对外部网络而言,它看到的是宿主机的IP地址,而不是虚拟机的IP。但是,宿主机可…

优选算法《双指针》

在学习了C/C的基础知识之后接下来我们就可以来系统的学习相关的算法了,这在之后的笔试、面试或竞赛都是必须要掌握的;在这些算法中我们先来了解的是一些非常经典且较为常用的算法,在此也就是优选出来的算法,接下来在每一篇章中我们…

SQL server学习06-查询数据表中的数据(中)

目录 一,聚合函数 1,常用聚合函数 2,具体使用 二,GROP BY子句分组 1,基础语法 2,具体使用 3,加上HAVING对组进行筛选 4,使WHERE记录查询条件 汇总查询:在对数…

上传文件时获取音视频文件时长和文本文件字数

获取音视频文件时长和文本文件字数 一、获取音视频文件时长二、计算文本文件字数 最近有个需求,要求上传文件时获取音视频文件时长和文本文件字数🐶。 发现这样的冷门资料不多,特做个记录。本文忽略文件上传功能,只封装核心的工具…

C语言学习day22:进程ID获取工具/GetWindowThreadProcessId函数

简言: 每个人都有身份证号,这个身份证号就是个人的唯一标识符 进程也是如此,每个进程也有唯一的标识符,来标记自身是独一无二的 如下图:其中PID :Process ID,即进程ID 但是我们怎么去在编程中去获取某个…

使用Localstorage(Mapty)

使用Localstorage(Mapty) 首先,我们创建一个函数名,先在app中去调用它 // 为所有的锻炼创建本地存储this._setLocalStorage();之后我们就开始编写这个函数的功能 _setLocalStorage() {localStorage.setItem(workouts, JSON.stringify(this.#workouts));…

如何用细节提升用户体验?

前端给用户反馈是提升用户体验的重要部分,根据场景选择不同的方式可以有效地提升产品的易用性和用户满意度。以下是常见的方法: 1. 视觉反馈 用户执行了某些操作后,需要即时确认操作结果。例如:按钮点击、数据提交、页面加载等。…

OpenHarmony-3.HDF input子系统(5)

HDF input 子系统OpenHarmony-4.0-Release 1.Input 概述 输入设备是用户与计算机系统进行人机交互的主要装置之一,是用户与计算机或者其他设备通信的桥梁。常见的输入设备有键盘、鼠标、游戏杆、触摸屏等。本文档将介绍基于 HDF_Input 模型的触摸屏器件 IC 为 GT91…

旅游资源系统|Java|SSM|VUE| 前后端分离

【技术栈】 1⃣️:架构: B/S、MVC 2⃣️:系统环境:Windowsh/Mac 3⃣️:开发环境:IDEA、JDK1.8、Maven、Mysql5.7 4⃣️:技术栈:Java、Mysql、SSM、Mybatis-Plus、VUE、jquery,html 5⃣️数据库可…

Docker网络与数据管理

Docker网络与数据管理 1. Docker网络基础:桥接网络、主机网络和自定义网络 Docker提供了多种网络模式,以满足不同应用场景的需求。理解Docker的网络模式对于容器间通信、网络安全性及性能优化至关重要。在Docker中,每个容器都可以连接到不同…

X.game解析柚子币提升速效双向利好和年中历史新低原因

柚子币最新消息,币安宣布将于2024年9月25日21:00左右暂停柚子币网络上的代币存取业务,以全力支持即将到来的柚子币网络升级和硬分叉,这一消息为柚子币的未来发展增添了新的期待和变数。 除了速度的提升,Spring1.0还带来了诸多技术…

数据结构之线性表1

2.1 线性表的定义和基本操作 1.线性结构的特点是:在数据元素的非空有限集中, (1)存在惟一的一个被称做“第一个”的数据元素; (2) 存在惟一的一个被称做“最后一个”的数据元素; &a…

Tomcat原理(5)——tomcat最终实现

目录 一、什么是Servlet容器 二、ServletConfigMapping构建实现容器 ServletConfigMapping MyTomcat 三、优化server Server MyTomcat 四、匹配 代码如下: 测试如下: 上一篇博客已经为介绍了servelet的实现 ,这篇对上一篇博客进行补…