PyTorch实现基本的线性回归

news2025/1/21 4:58:04

线性回归理论知识参考文章:线性回归

下面我们将从零开始实现整个线性回归方法, 包括数据集生成、模型、损失函数和小批量随机梯度下降优化器。

1.导入

%matplotlib inline
import random
import torch
from d2l import torch as d2l

2.生成数据集

我们将生成一个包含1000个样本的数据集, 每个样本包含从标准正态分布中采样的2个特征。我们的合成数据集是一个矩阵 X ∈ R 1000 × 2 \mathbf{X}\in \mathbb{R}^{1000 \times 2} XR1000×2
使用线性模型参数 w = [ 2 , − 3.4 ] ⊤ \mathbf{w} = [2, -3.4]^\top w=[2,3.4] b = 4.2 b = 4.2 b=4.2和噪声项 ϵ \epsilon ϵ生成数据集及其标签: y = X w + b + ϵ . \mathbf{y}= \mathbf{X} \mathbf{w} + b + \mathbf\epsilon. y=Xw+b+ϵ.
ϵ \epsilon ϵ可以视为模型预测和标签时的潜在观测误差。在这里我们认为标准假设成立,即 ϵ \epsilon ϵ服从均值为0的正态分布。为了简化问题,我们将标准差设为0.01。

def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    X = torch.normal(0, 1, (num_examples, len(w)))  #该函数返回从单独的正态分布中提取的随机数的张量 normal(mean, std, size)
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

features中的每一行都包含一个二维数据样本, labels中的每一行都包含一维标签值(一个标量)
通过生成第二个特征features[:, 1]和labels的散点图, 可以直观观察到两者之间的线性关系。

d2l.set_figsize()
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1);

在这里插入图片描述

3.读取数据集

训练模型时要对数据集进行遍历,每次抽取一小批量样本,并使用它们来更新我们的模型。 由于这个过程是训练机器学习算法的基础,所以有必要定义一个函数, 该函数能打乱数据集中的样本并以小批量方式获取数据。
定义一个data_iter函数, 该函数接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量。 每个小批量包含一组特征和标签。

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))  # 生成下标
    # 这些样本是随机读取的,没有特定的顺序
    random.shuffle(indices)  # 把下标随机打乱,用随机的顺序访问样本 
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)]) # 得到batch_size大小个的随机下标
        yield features[batch_indices], labels[batch_indices]   # torch.Tensor 张量的下标可以是一个数组

当我们运行迭代时,我们会连续地获得不同的小批量,直至遍历完整个数据集。

4.初始化模型参数

通过从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重, 并将偏置初始化为0。

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True) # pytorch自动计算梯度
b = torch.zeros(1, requires_grad=True)

在初始化参数之后,我们的任务是更新这些参数,直到这些参数足够拟合我们的数据。 每次更新都需要计算损失函数关于模型参数的梯度。 有了这个梯度,我们就可以向减小损失的方向更新每个参数。 因为手动计算梯度很枯燥而且容易出错,所以没有人会手动计算梯度。

5.定义模型

def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b # 广播机制

6.损失函数

def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

7.优化算法

在每一步中,使用从数据集中随机抽取的一个小批量,然后根据参数计算损失的梯度。 接下来,朝着减少损失的方向更新我们的参数。 下面的函数实现小批量随机梯度下降更新。 该函数接受模型参数集合、学习速率和批量大小作为输入。每 一步更新的大小由学习速率lr决定。 因为我们计算的损失是一个批量样本的总和,所以我们用批量大小(batch_size) 来规范化步长,这样步长大小就不会取决于我们对批量大小的选择。

def sgd(params, lr, batch_size):  #@save
    """小批量随机梯度下降"""
    with torch.no_grad():
        #torch.no_grad上一个上下文管理器,在你确定不需要调用Tensor.backward()时
        #可以用torch.no_grad来屏蔽梯度计算
        #在被torch.no_grad管控下计算得到的tensor,它的requires_grad就是False
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()  # 梯度清零

8.训练

在每个迭代周期(epoch)中,我们使用data_iter函数遍历整个数据集, 并将训练数据集中所有样本都使用一次(假设样本数能够被批量大小整除)。 这里的迭代周期个数num_epochs和学习率lr都是超参数,分别设为3和0.03。 设置超参数很棘手,需要通过反复试验进行调整。

lr = 0.03  # 可以尝试不同的学习率
num_epochs = 3
net = linreg   # 定义模型
loss = squared_loss   # 定义损失

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # X和y的小批量损失
        # 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
        # 并以此计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)  # 用更新过的参数计算损失
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/161216.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

js垃圾回收(引用计数算法、标记清除算法、v8垃圾回收机制、浏览器性能监控、任务管理器、内存分析、JSBench)

目录 垃圾 可达对象 GC算法(垃圾回收机制) 引用计数算法 优点 缺点 标记清除算法 优点 缺点 标记整理算法 优点 缺点 V8 V8垃圾回收 新生代对象回收 晋升条件 老生代对象回收 性能监控Performance 浏览器任务管理器 内存分析 ​编…

Apache Doris 系列: 基础篇-BitMap索引

1. 测试数据准备 本文使用SSB(Star-Schema-Benchmark)的测试数据,读者也可以自行准备测试数据 1.1 编译ssb-dbgen 数据生成工具 ## 拉取Apache Doris源代码 git clone https://github.com/apache/doris.git## 编译ssb-dbgen cd doris/tool…

计算机网络复习之应用层

统一资源定位系统(uniform resource locator;URL)是因特网的万维网服务程序上用于指定信息位置的表示方法。它最初是由蒂姆伯纳斯李发明用来作为万维网的地址。现在它已经被万维网联盟编制为互联网标准RFC1738。邮局协议(Post Office Protoco…

TDemo 备注文本的二种存贮方式

TDemo 备注纯文本的二种存贮方式 数据库使用过程中,对于TDeme控件,对应数据库的分为nvarchar(n)类型字段。 一、通常使用二种格式的文本: (1)单纯文本 (2)带换行符的文本 这二种格式&#xff0c…

Pdf 转换成Word如何在线转换?职场公认好用软件推荐

Pdf 转换成Word如何在线转换?生活中很多时候我们需要接触大量的办公文件,特别是利用office的三种常见的文件格式编辑各类文件,最常见的便是Word文件操作。为了更方便我们进行文件传输,大部分情况下我们会把格式排版完好的Word文档…

UDS诊断系列介绍08-19服务

本文框架1. 系列介绍1.1 19服务概述1.2 DTC故障码定义1.3 DTC状态位2. 19服务常用子服务2.1 19 01服务2.2 19 02服务2.3 19 04服务2.4 19 06服务2.5 19 0A服务2.6 否定响应3. Autosar系列文章快速链接1. 系列介绍 UDS(Unified Diagnostic Services)协议…

Android 深入系统完全讲解(15)

4 权限相关的知识 1 安卓权限 上层 APK 权限获取方式,配置 AndroidManifest.xml,系统会对应的给 gid,在创建进程的时候就带下去,这样子就可以访问对应的设备。 而系统相关的,会限制必须是 uidsystem 这类&#xff0c…

一年融资三轮,一文读懂亿格云这家公司

数字办公时代,网络安全是企业经营的底线工作。如何构建一个安全、稳定、高效的网络安全体系,是企业谋求发展的基础条件之一。近年,倡导“永不信任,始终验证”的零信任网络安全服务理念开始兴起。而国内致力于基于零信任理念构建办…

MySQL 行级锁(行锁、临键锁、间隙锁)

行级锁 行级锁,每次操作锁住对应的行数据。锁定粒度最小,发生锁冲突的概率最低,并发度最高。应用在InnoDB存储引擎中。 InnoDB的数据是基于索引组织的,行锁是通过对索引上的索引项加锁来实现的,而不是对记录加的锁。 1…

类和对象(上)

文章目录引用autoNULL&nullptr&0类和对象类的实例化默认成员函数构造函数析构函数拷贝构造函数运算符的重载赋值运算符的重载拷贝构造次数编译器优化前置后置> < ! - -const成员operator>>&&operator<<再谈构造函数初始化列表初始化expli…

使用Hi3861开发环境搭建

安装ubuntu ​ 文件夹的位置尽量选一个空间比较大的 内存也尽量分配大一点&#xff0c;不要到红色区域就行 固定分配&#xff0c;如果给它100G空间&#xff0c;他就会把这100G空间全部使用掉&#xff0c;动态分配&#xff0c;即使你给他100G内存&#xff0c;但实际使用的空间…

【自学Docker】Docker入门

Docker入门 Docker简介 Docker 是 Docker.Lnc 公司开源的一个基于 LXC 技术之上搭建的 Container 容器引擎&#xff0c;Docker 源代码托管在 Github上&#xff0c;Docker 是基于 Go 语言开发的并遵从 Apache2.0 协议开源。 Docker 属于 Linux 容器的一种封装&#xff0c;提供…

(十九)包装类

前言: 在我们讨论其他变量类型之间的相互转换时&#xff0c;我们需要了解一下Java的包装类&#xff0c;所谓包装类&#xff0c;就是能够直接将简单类型的变量表示为一个类&#xff0c;在执行变量类型的相互转换时&#xff0c;我们会大量使用这些包装类。Java共有六个包装类&…

pandas数据结构

文章目录Series创建series对象Series对象的属性DataFrame创建DataFrame对象Python 在数据处理上独步天下&#xff1a;代码灵活、开发快速&#xff1b;尤其是 Python 的 Pandas 包&#xff0c;无论是在数据分析领域、还是大数据开发场景&#xff0c;都具有显著的优势。Series S…

CesiumLab实例模型切片 CesiumLab系列教程

先解释下实例模型&#xff0c;实例模型使用 GPU instance 技术来渲染的模型&#xff0c;通常用来绘制大量几何体一致&#xff0c;但是位置姿态不同的对象&#xff0c;比如说森林场景&#xff0c;大量路灯&#xff0c;井盖等&#xff0c;如下图&#xff1a; 1.输入文件 目前输入…

【图文教程】Centos单机安装Redis

1.1.安装Redis依赖 Redis是基于C语言编写的&#xff0c;因此首先需要安装Redis所需要的gcc依赖&#xff1a; yum install -y gcc tcl1.2.上传安装包并解压 ​ 例如&#xff0c;凯哥将其放到了/usr/local/src 目录&#xff1a; 解压缩&#xff1a; tar -xzf redis-6.2.6.tar…

搞清clientHeight、offsetHeight、scrollHeight、offsetTop、scrollTop

网页可见区域高:document.body.clientHeight 网页正文全文高:document.body.scrollHeight 网页可见区域高&#xff08;包括边线的高&#xff09;&#xff1a;document.body.offsetHeight 网页被卷去的高&#xff1a;document.body.scrollTop 屏幕分辨率高&#xff1a;window.sc…

SpringBoot实践(三十九):如何使用AOP

目录 直接使用Aspect 定义切面逻辑 模拟业务代码 测试输出 自定义注解方式 自定义切面注解 定义切入点逻辑 模拟业务代码 测试输出 面向切面&#xff08;AOP) 是spring重要特性&#xff0c;在功能上切面编程是面向对象编程的很好的补充&#xff0c;面向对象强调封装和开…

BAT 名企大厂做接口自动化如何高效使用 Requests ?

1080428 28.9 KBRequests是一个优雅而简单的python HTTP库&#xff0c;其实python内置了用于访问网络的资源模块&#xff0c;比如urllib&#xff0c;但是它不如requests简单&#xff0c;优雅&#xff0c;而且缺少许多实用功能。接下来的接口测试的学习和实战&#xff0c;都与re…

语音识别系列之基于CTC的VAD

语音活动性检测&#xff08;Voice Activity Dection, VAD&#xff09;常作为语音识别系统的前端模块过滤非语音段&#xff0c;为后续增强模块提供语音/非语音判据&#xff0c;从而更好的掌握背景噪声特性&#xff0c;进而提升降噪量&#xff0c;保证识别性能&#xff0c;且能降…