机器学习深度学习——线性回归的从零开始实现

news2025/1/9 0:58:16

虽然现在的深度学习框架几乎可以自动化实现下面的工作,但从零开始实现可以更了解工作原理,方便我们自定义模型、自定义层或自定义损失函数。

import random
import torch
from d2l import torch as d2l

线性回归的从零开始实现

  • 生成数据集
  • 读取数据集
  • 初始化模型参数
  • 定义模型
  • 定义损失函数
  • 定义优化算法
  • 训练

生成数据集

根据带有噪声的线性模型构造一个人造数据集。任务是使用这个数据集来恢复模型的参数。我们使用低维数据,可以更容易地进行可视化。
在下面代码中,我们生成一个包含1000个样本的数据集,每个样本包含从标准正态分布中采样的2个特征。我们的数据集是一个1000×2的矩阵X。
使用线性模型参数 w = [ 2 , − 3.4 ] T 、 b = 4.2 和噪声项 δ 生成数据集及标签: y = X w + b + δ 使用线性模型参数w=[2,-3.4]^T、b=4.2和噪声项\delta生成数据集及标签:\\ y=Xw+b+\delta 使用线性模型参数w=[2,3.4]Tb=4.2和噪声项δ生成数据集及标签:y=Xw+b+δ
其中,δ可以视为模型预测和标签时的潜在观测误差。在这里我们认为标准假设成立,即δ服从均值为0的正态分布。为简化问题,将标准差设为0.01。下面的代码生成合成数据集:

def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+δ"""
    # 生成均值为0,标准差为1(标准正态分布)且大小1000*2的数据集
    X = torch.normal(0, 1, (num_examples, len(w)))
    # 生成y函数,生成1000*1的矩阵
    y = torch.matmul(X, w) + b
    # 再加上服从均值为0的正态分布的δ
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

其中,features中每一行都包含一个二维数据样本,labels中每一行都包含一个一维标签值(一个标量)。

print('features:', features[0], '\nlabel:', labels[0])

结果:

features: tensor([-0.5829, -0.2094])
label: tensor([3.7491])

通过生成第二个特征features[:, 1]和labels的散点图,可以直观看出两者之间的线性关系:

d2l.set_figsize()
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1)
d2l.plt.show()

在这里插入图片描述

读取数据集

训练模型时,要对数据集进行遍历,每次抽取一小批量样本,并使用它们来更新模型。因此,需要定义一个函数,该函数能打乱数据集中的样本并以小批量方式获取数据。
在下面代码中,定义一个data_iter函数,接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量。每个小批量包含一组特征和标签。

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))  # 0到999的顺序
    random.shuffle(indices)  # 这些样本是随机读取的,没有特定顺序
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i: min(i+batch_size, num_examples)]  # 随机取样
        )
        yield features[batch_indices], labels[batch_indices]
        # yield返回一个可以用来迭代for循环的生成器,而不是直接return

通常我们会利用CPU并行运算的优势,处理合理大小的“小批量”。每个样本都可以并行进行模型计算,且每个样本损失函数的梯度也可以被并行计算。
可以直观感受一下小批量运算:读取第一个小批量数据样本并打印:

batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

结果:

tensor([[-1.0186, 1.8338],
[ 0.6455, 1.1226],
[-0.5020, 0.2105],
[ 1.3583, 0.6979],
[ 0.3024, -0.8929],
[ 0.4045, -0.4207],
[ 0.5201, -0.3263],
[ 0.6037, -0.1332],
[ 1.6171, 0.2449],
[-0.6540, 1.0338]])
tensor([[-4.0795],
[ 1.6835],
[ 2.5014],
[ 4.5346],
[ 7.8678],
[ 6.4298],
[ 6.3537],
[ 5.8528],
[ 6.6194],
[-0.6216]])

当我们进行迭代时,我们会连续地获得不同的小批量,直到遍历完整个数据集。但上面实现的迭代执行效率很低,可能会出问题。在深度学习框架中实现的内置迭代器效率要高得多,它可以处理存储在文件中的数据和数据流提供的数据

初始化模型参数

通过从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重,并将偏置初始化为0:

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

在初始化参数后,我们的任务是更新这些参数,直到这些参数足够拟合我们的数据。每次更新都需要计算损失函数关于模型参数的梯度,有了这个梯度就可以向减小损失的方向来更新每个参数

定义模型

定义模型,就要将模型的输入和参数同模型的输出关联起来。
而要计算线性模型的输出,只需要计算输入特征X与模型权重w的矩阵-向量乘法后再加上偏置b。(Xw是一个向量,而b是标量)当我们用一个向量加上一个标量时,标量会加到向量的每个分量上(广播机制):

def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b

定义损失函数

要计算损失函数的梯度,自然要先定义损失函数,下面定义了平方损失函数:

def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

定义优化算法

线性回归是有解析解的,但是其他模型基本没有,因此还是用随机梯度下降法来进行优化。
在每一步中,使用数据集中随机抽取的一个小批量,然后根据参数计算损失的梯度。接下来朝着减小损失的方向来更新参数。
下面就是随机梯度下降更新的函数,该函数接受模型参数集合、学习速率和批量大小作为输入。每一步更新的大小由学习率lr决定。因为我们计算的损失是一个批量样本的综合,因此用批量大小batch_size来规范步长,这样步长大小就不会取决于我们对批量大小的选择:

def sgd(params, lr, batch_size):  #@save
    """小批量随机梯度下降"""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

其中,torch.no_grad()是上下文管理器,用来指定在其内部的代码块中不进行梯度计算。当不需要计算梯度时,使用该上下文管理器可以提高代码执行效率。

训练

在每次迭代中,我们读取一小批训练样本,并通过我们的模型来获得一组预测。计算完损失后,我们开始反向传播,存储每个参数的梯度。最后调用优化算法sgd来更新模型参数。
概括一下,就是执行下面的循环:
1、初始化参数
2、重复一下训练,直到完成:
计算梯度 g ← ∂ ( w , b ) 1 ∣ B ∣ ∑ i ∈ B l ( x ( i ) , y ( i ) , w , b ) 更新参数 ( w , b ) ← ( w , b ) − η g 计算梯度g←\partial_{(w,b)}\frac{1}{|B|}\sum_{i∈B}l(x^{(i)},y^{(i)},w,b)\\ 更新参数(w,b)←(w,b)-ηg 计算梯度g(w,b)B1iBl(x(i),y(i),w,b)更新参数(w,b)(w,b)ηg
在每个迭代周期中,我们使用deta_iter函数遍历整个数据集,并将训练数据集中所有样本都使用一次(假设样本数能够被批量大小整除)。这里的迭代周期个数num_epoches和学习率lr都是超参数,分别设为3和0.03。(超参数设置很麻烦,现在忽略细节)

batch_size = 10
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # X和y的小批量损失
        # 因为l形状是(batch_size,1),不是标量
        # l中所有元素加起来再计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

结果:

epoch 1, loss 0.040672
epoch 2, loss 0.000146
epoch 3, loss 0.000047

事实上,真实参数和通过训练得到的参数很接近:

print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

结果:

w的估计误差: tensor([ 0.0006, -0.0002], grad_fn=)
b的估计误差: tensor([0.0004], grad_fn=)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/786371.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【技术】国标GB28181视频监控平台EasyGBS视无法播放,抓包返回ICMP

视频流媒体安防监控国标GB28181平台EasyGBS视频能力丰富,部署灵活,既能作为业务平台使用,也能作为安防监控视频能力层被业务管理平台调用。国标GB28181视频EasyGBS平台可提供流媒体接入、处理、转发等服务,支持内网、公网的安防视…

Ansys Speos | Presets 适合用户的预定义参数集

概述 Speos Presets 参数预置功能允许创建预定义的参数集,并将它们应用于新的或现有的 Speos,从任何 Speos 对象创建预设,例如光源,传感器,材料,仿真等,通过一个*.Preset 的文件定对仿真类型的配…

C++之文件操作

1.C文件操作 C中文件操作头文件:fstream。   文件类型:文件文件和二进制文件。 文件操作三大类:     ofstream 写操作     ifstream 读操作     fstream:读写操作 文件打开方式: 标志说明ios::in只读ios::out只写,文件不存在则…

Spring系列一:spring的安装与使用

文章目录 💞 官方资料🍊Spring5下载🍊文档介绍 💞Spring5🍊内容介绍🍊重要概念 💞快速入门🍊Spring操作演示🍊类加载路径🍊Debug配置🍊Spring容器…

基于Centos 7虚拟机的磁盘操作(添加磁盘、分区、格式分区、挂载)

目录 一、添加硬盘 二、查看新磁盘 三、磁盘分区 3.1新建分区 3.2 格式分区 3.3 挂载分区 3.4 永久挂载新分区 3.5 取消挂载分区 一、添加硬盘 1.在虚拟机处选择编辑虚拟机设置,然后选择添加 2.选择硬盘,然后选择下一步 3.默认即可,下一步…

啤酒节,燃起青岛啤酒们的“热血”

【潮汐商业评论/ 原创】 “这周五晚上我们就出发!三年了,终于可以再去啤酒节畅快淋漓了!”作为啤酒爱好者Joe兴奋道。 随着线下经济的复苏,疫情后的第一个盛夏正在被全国各地的“啤酒狂欢”所点燃。 7月14日晚,随着…

Canal安装部署与测试

文章目录 第一章 Canal概述1.1 简介1.2 工作原理1.2.1 MySQL主备复制原理1.2.2 canal 工作原理 1.3 重要版本更新说明1.4 多语言 第二章 Canal安装部署2.1 准备2.2 canal安装 第三章 Canal和Kafka整合测试注意事项 第一章 Canal概述 Github地址:https://github.com…

脑电信号处理与特征提取——4.脑电信号的预处理及数据分析要点(彭微微)

目录 四、脑电信号的预处理及数据分析要点 4.1 脑电基础知识回顾 4.2 伪迹 4.3 EEG预处理 4.3.1 滤波 4.3.2 重参考 4.3.3 分段和基线校正 4.3.4 坏段剔除 4.3.5 坏导剔除/插值 4.3.6 独立成分分析ICA 4.4 事件相关电位(ERPs) 4.4.1 如何获…

Java子类可以继承父类的所有属性吗

子类可以继承父类的所有属性吗 答案: 不可以,当然是不可以全部属性和方法都继承,那么哪些不可以继承? 最起码私有的就不可以,私有的属性和方法都不可以,其他的那就需要继续去测试了。 我用的是jdk1.8来…

Docker 的数据管理、镜像的创建

Docker 的数据管理、镜像的创建 Docker 的数据管理1.数据卷2.数据卷容器端口映射容器互联(使用centos镜像) Docker 镜像的创建1.基于现有镜像创建(1)首先启动一个镜像,在容器里做修改…

使用docker 部署自己的chatgpt

直接docker部署 docker run --name chatgpt-web -d -p 3002:3002 --env OPENAI_API_KEYyour_api_key chenzhaoyu94/chatgpt-web:latestDocker compose部署 version: 3services:app:image: chenzhaoyu94/chatgpt-web # 总是使用 latest ,更新时重新 pull 该 tag 镜像即可ports…

拆解雪花算法生成规则 | 京东物流技术团队

1 介绍 雪花算法(Snowflake)是一种生成分布式全局唯一 ID 的算法,生成的 ID 称为 Snowflake IDs 或 snowflakes。这种算法由 Twitter 创建,并用于推文的 ID。目前仓储平台生成 ID 是用的雪花算法修改后的版本。 雪花算法几个特性…

体验百度大模型文心千帆有感

大家好,我是雄雄,微信公众号:雄雄的小课堂,欢迎关注。 前言 近段时间来,各种大厂都推出了自己的大模型平台,有讯飞星火大模型、百度文心大模型、阿里通义千问大模型、claude、还有openai公司出产的chatgpt…

基于so-token的前后端分离项目 + uni-app微信小程序

PC端效果: 移动端效果: Sa-Token 是一个轻量级 Java 权限认证框架,主要解决:登录认证、权限认证、单点登录、OAuth2.0、分布式Session会话、微服务网关鉴权 等一系列权限相关问题。 uni-app 是一个使用 Vue.js 开发所有前端应用的…

合并二叉树

给你两棵二叉树: root1 和 root2 。 想象一下,当你将其中一棵覆盖到另一棵之上时,两棵树上的一些节点将会重叠(而另一些不会)。你需要将这两棵树合并成一棵新二叉树。合并的规则是:如果两个节点重叠&#…

荧光粉的发光效率是多少?--光致发光量子效率检测系统

稀土荧光粉广泛应用于生态照明、动态显示、通讯卫星、光学计算机及生物分子探针等高科技领域。三基色荧光粉是目前最有研究价值的荧光粉,其在可见光区具有丰富的谱线,发光谱带狭窄,发光能量更为集中;具有较强的抗紫外辐照能力&…

用户认证模式Cookie-Session、JWT-Token(goland实现)

用户认证 Cookie-Session认证模式简介代码示例优缺点 Token认证模式简介JWT介绍JWT结构标头(Header)负载(Payload)签名(Signature) 代码示例JWT优缺点Access Token和Refresh Token认证模式代码示例 在计算机…

解决QT 编译qmake 无法找到问题

问题: Command qmake not found, but can be installed with: sudo apt install qtchooser 原因: 这个错误提示指出在当前环境中找不到 qmake 命令 解决方法: 其实ubuntu已经给提示了就是要安装qtchooser 安装命令为: sudo…

Springboot中 AOP实现日志信息的记录到数据库

1、导入相关的依赖 <!--spring切面aop依赖--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-aop</artifactId></dependency> 注意&#xff1a;在application.properties文件里加这样一…

串2:云计算架构思考

开始之前&#xff0c;先给出串1&#xff1a;一文将大数据、云计算、物联网、5G&#xff08;移动网&#xff09;、人工智能等最新技术串起来_龙赤子的博客-CSDN博客 承上 事物的复杂性一般有两个方面&#xff0c;一个是本身结构的复杂&#xff0c;一个是运行机制的复杂。因为这…