动手学深度学习——线性回归从零实现

news2025/1/10 11:42:57

1. 数据集

1.1 生成数据集

要训练模型首先要准备训练数据集,对于线性模型 y=Xw+b,定义生成数据集的函数如下:

def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    # 从均值为0,标准差为1的正态分布中随机采样
    # 第3个参数表示张量的形状,num_examples行,len(w)列
    X = torch.normal(0, 1, (num_examples, len(w)))  
    y = torch.matmul(X, w) + b     # 矩阵乘法
    y += torch.normal(0, 0.01, y.shape) # 添加噪声,从均值为0标准差为0.01的正态分布中随机采样
    return X, y.reshape((-1, 1)) # 将y的形状变换为列向量

torch.normal:从正态分布中随机采样。标准正态分布是对称的,虽然值的分布理论上是正负无穷大,但遵循概率分布。
对于均值为0,标准差为1的正态分布,大约有68%的数据落在均值0附近的一个标准差范围内(-1, 1),约95%的数据落在两个标准差范围内(-2, 2),约99.7%的数据落在三个标准差范围内(-3, 3)。

添加噪声目的:更好地模拟现实世界中的数据。由于测量误差、不完全的观测或其他随机因素的存在,许多现实世界的数据往往是包含噪声的。

这里人为构造包含1000个样本的数据,使用指定参数w=[2,-3.4]和指定偏移b=4.2来生成数据及标签。每个样本数据包含两个特征(与参数w向量长度相同)。

true_w = torch.tensor([2, -3.4])   # 指定参数
true_b = 4.2                       # 指定偏差
features, labels = synthetic_data(true_w, true_b, 1000)

生成的数据集示例:

print('features:', features[:3],'\nlabel:', labels[:3])
> features: tensor([[ 0.9426,  0.4816],
        [ 3.7041, -0.3572],
        [ 0.2075,  0.5264]]) 
   label: tensor([[ 4.4473],
        [12.8080],
        [ 2.8209]])

可以使用matplotlib的散点图来查看标签与特征之间的线性关系。

  • plt.scatter 函数用于绘制散点图,3个参数分别为特征、标签、点的大小。
  • 下面绘制标签与第一个特征的线性关系,横坐标表示特征0,纵坐标表示标签:
d2l.plt.scatter(features[:, 0].numpy(), labels.numpy(), 1);

在这里插入图片描述

  • 下面绘制标签与第2个特征的线性关系,横坐标表示特征1,纵坐标表示标签
d2l.plt.scatter(features[:, 1].numpy(), labels.numpy(), 1);

在这里插入图片描述

1.2 读取数据集

前文讲过,训练模型采用的是小批量随机梯度下降,通过对小批量样本计算梯度来更新我们的模型,所以我们要定义一个读取小批量数据的函数。

这里定义一个data_iter函数

  • 该函数接收三个参数:批量大小、特征矩阵和标签向量。
  • 采用python中的yield生成器语法支持多次调用,每次调用返回大小为batch_size的小批量特征和标签。
  • random.shuffle: 通过将序列打乱,而达到随机读取样本的目的,小批量读出的数据没有特定的顺序。
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples)) # 构造一个从0开始的序列 
    random.shuffle(indices)  # 将序列的顺序打乱
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

这里以batch_size=5示例下小批量样本读取:

batch_size = 5

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break
tensor([[ 0.4795, -0.1943],
        [ 1.1495,  0.6922],
        [-1.1642, -1.7341],
        [ 0.1394,  0.5820],
        [ 0.9647, -0.7035]]) 
 tensor([[5.8164],
        [4.1540],
        [7.7648],
        [2.4904],
        [8.5224]])

2. 模型

2.1 定义模型

本质:将模型的输入特征、参数权重和模型的输出关联起来。

在我们这个场景下,模型的输出 = 输入特征X和模型权重w的矩阵向量积,再加上偏置。

def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b

上面的Xw是一个向量,而b是一个标量, 两者进行加法运算时,遵循广播机制,标量b会被加到向量Xw的每个分量上。

2.2 定义损失函数

计算梯度需要先定义损失函数,这里使用上一篇文章中提到的平方误差函数,y表示真实值,y_hat表示预测值。

def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

2.3 定义优化算法

定义一个函数来实现参数优化,主要在于用计算好的梯度来更新参数,函数接受3个参数:

  • params:一个包含需要被更新的参数张量列表。
  • lr:学习率,用于控制参数更新的速度,即每一步更新的大小。
  • batch_size:批量大小,用于计算参数的梯度和更新。
def sgd(params, lr, batch_size):  #@save
    """小批量随机梯度下降"""
    with torch.no_grad():
        for param in params: # 遍历需要更新的参数
        	# 将参数的梯度grad乘以学习率lr,再除以批量大小 batch_size,得到参数的更新量
        	# 将参数的值减去更新量,实现对参数的原地更新
            param -= lr * param.grad / batch_size
            # 每次参数更新完毕后,梯度要清零,以免影响下次计算
            param.grad.zero_()

with torch.no_grad(): 这个语句块内的计算将不会被计算图追踪。原因在于:

  1. 避免把对参数的更新操作记录为计算图的一部分,对梯度计算造成影响。
  2. 避免无用数据的保存,减少内存消耗。

3. 训练

上面已经准备好了模型训练需要的要素,下面实现训练部分。

3.1 定义模型超参

lr = 0.03
num_epochs = 3 

net = linreg        
loss = squared_loss  
  • num_epochs表示迭代周期次数,也就是训练多少次后终止。
  • lr表示学习率,学习率设的太大会震荡,太小则可能收敛慢。
  • net:网络模型,用于预测结果,使用上面定义的线性回归模型。
  • loss:损失函数,用于计算梯度。

3.2 初始化参数

采用小批量随机梯度下降优化我们的模型参数之前, 先定义参数的初始值。

这里从均值为0、标准差为1.0的正态分布中采样随机数来初始化权重, 并将偏置初始化为0。

w = torch.normal(0, 1.0, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

> (tensor([[-1.1961],
         [-0.3512]], requires_grad=True),
 tensor([0.], requires_grad=True))

requires_grad=True 表示对张量的梯度进行跟踪和计算,打开此设置项就能够在训练过程中对张量所表示的参数进行多轮迭代优化。

3.3 运行训练

# 迭代训练次数
for epoch in range(num_epochs):
	# 每次对小批量数据集作迭代
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # net做预测,loss计算损失
        l.sum().backward()  # 使用反向累积计算梯度
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    with torch.no_grad():   # 开始一个没有梯度跟踪的上下文环境
    	# 计算当前训练的参数w、b在验证集上的损失
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

运行结果:

epoch 1, loss 0.000208
epoch 2, loss 0.000051
epoch 3, loss 0.000051

对比真实参数与学习得到的参数:

print("true_w: ", true_w, ", w:", w.reshape(true_w.shape))
print("true_b: ", true_b, ",  b:", b)

> true_w:  tensor([ 2.0000, -3.4000]) , w: tensor([ 2.0004, -3.4002], grad_fn=<ViewBackward0>)
true_b:  4.2 ,  b: tensor([4.1995], requires_grad=True)

结论: 可以看到,真实参数和通过训练学到的参数确实非常接近。

3.4 不同学习率的表现

学习率lr设为1时的表现:

epoch 1, loss 0.000057
epoch 2, loss 0.000193
epoch 3, loss 0.000090

可以看到,损失函数并未收敛,结果出现了从小到大再到小的震荡。

学习率lr设为0.01时的表现:

epoch 1, loss 0.440745
epoch 2, loss 0.009315
epoch 3, loss 0.000247

可以看到,结果收敛的比0.03时要慢。

可见,学习率的设置是比较棘手的,需要反复试验进行调整。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1629706.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git重修系列 ------ Git的使用和常用命令总结

一、Git的安装和配置 git安装&#xff1a; Git - Downloads git首次配置用户信息&#xff1a; $ git config --global user.name "kequan" $ git config --global user.email kequanchanqq.com $ git config --global credential store 配置 Git 以使用本地存储机…

基于自注意力机制的长短期记忆神经网络(LSTM-SelfAttention)的回归预测

提示&#xff1a;MATLAB版本需要R2023a以上 基于自注意力机制的长短期记忆神经网络&#xff08;LSTM-SelfAttention&#xff09;是一种用于时序数据预测的模型。这个模型结合了两个不同的结构&#xff0c;即长短期记忆网络&#xff08;LSTM&#xff09;和自注意力机制&#xff…

解决HttpServletRequest中的InputStream/getReader只能被读取一次的问题

一、事由 由于我们业务接口需要做签名校验&#xff0c;但因为是老系统了签名规则被放在了Body里而不是Header里面&#xff0c;但是我们不能在每个Controller层都手动去做签名校验&#xff0c;这样不是优雅的做法&#xff0c;然后我就写了一个AOP&#xff0c;在AOP中实现签名校…

Cesium.js(1):Cesium.js简介

1 前言 现有的gis开发方向较流行的是webgis开发&#xff0c;其中Cesium是一款开源的WebGIS库&#xff0c;主要用于实时地球和空间数据的可视化和分析。它提供了丰富的地图显示和数据可视化功能&#xff0c;并能实现三维可视化开发。Cesium适用于地球科学研究、军事情报分析、航…

Java编程练习之final关键字

1.final类&#xff1a;不允许任何类继承&#xff0c;并且不允许其他人对这个类进行任何改动&#xff1b; 当被某个类设置为final类时&#xff0c;类中的所有方法都被隐式的设置为final形式&#xff0c;但是final类中的成员变量既可以被定义为final形式&#xff0c;又可以被定义…

【区块链】椭圆曲线数字签名算法(ECDSA)

本文主要参考&#xff1a; 一文读懂ECDSA算法如何保护数据 椭圆曲线数字签名算法 1. ECDSA算法简介 ECDSA 是 Elliptic Curve Digital Signature Algorithm 的简称&#xff0c;主要用于对数据&#xff08;比如一个文件&#xff09;创建数字签名&#xff0c;以便于你在不破坏它…

Maven的仓库、周期和插件

优质博文&#xff1a;IT-BLOG-CN 一、简介 随着各公司的Java项目入库方式由老的Ant改为Maven后&#xff0c;相信大家对Maven已经有了个基本的熟悉。但是在实际的使用、入库过程中&#xff0c;笔者发现挺多人对Maven的一些基本知识还缺乏了解&#xff0c;因此在此处跟大家简单地…

SpringCloud系列(19)--将服务消费者Consumer注册进Consul

前言&#xff1a;在上一章节中我们把服务提供者Provider注册进了Consul&#xff0c;而本章节则是关于如何将服务消费者Consumer注册进Consul 1、再次创建一个服务提供者模块&#xff0c;命名为consumerconsul-order80 (1)在父工程下新建模块 (2)选择模块的项目类型为Maven并选…

使用CubeMx配置GD32F303系列单片机进行DMA ADC

原理图查看 查原理图可以看到GD32F103C8T6的官方开发板GD32303C-START-V1.0的PA1没有接任何东西 使用PA1作为ADC端口 CubeMX配置ADC和时钟 配置ADC通道 启用循环模式 配置此通道ADC分频 配置ADC DMA为循环模式 配置时钟 生成项目 Keil里面的配置 选择对应的GD32型号 编译…

2024全新瀚海跑道:矢量图片迅速养号游戏玩法,每天一小时,日转现200

最初我注意到这种玩法&#xff0c;是因为最近在浏览各大平台的视频时&#xff0c;我发现了一种特殊类型的账号&#xff0c;其养号成功率高达90%。这些账号发布的视频内容和数据非常夸张&#xff0c;而且制作起来非常简单&#xff0c;任何人都可以轻松上手。这些账号主要发布矢量…

堆与优先队列——练习题

1. 数据流中的第 K 大元素 代码实现&#xff1a; 思路&#xff1a;创建一个大小为 k 的小顶堆&#xff0c;堆顶元素就是第 K 大元素 typedef struct {int *__data, *data;int size;int n; } KthLargest;#define swap(a, b) { \__typeof(a) __c (a); \(a) (b); \(b) __c; \ }…

C++ 笔试练习笔记【1】:字符串中找出连续最长的数字串 OR59

文章目录 OR59 字符串中找出连续最长的数字串题目思路分析实现代码 注&#xff1a;本次练习题目出自牛客网 OR59 字符串中找出连续最长的数字串 题目思路分析 首先想到的是用双指针模拟&#xff0c;进行检索比较输出 以示例1为例&#xff1a; 1.首先i遍历str直到遍历到数字&a…

字符串类型漏洞之updatexml函数盲注

UPDATEXML 是 MySQL 数据库中的一个函数&#xff0c;它用于对 XML 文档数据进行修改和查询。然而&#xff0c;当它被不当地使用或与恶意输入结合时&#xff0c;它可能成为 SQL 注入攻击的一部分&#xff0c;从而暴露敏感信息或导致其他安全漏洞。 在 SQL 注入攻击中&#xff0…

CentOS 9 (stream) 安装 nginx

1.我们直接使用安装命令 dnf install nginx 2.安装完成后启动nginx服务 # 启动 systemctl start nginx # 设置开机自启动 systemctl enable nginx# 重启 systemctl restart nginx# 查看状态 systemctl status nginx# 停止服务 systemctl stop nginx 3.查看版本确认安装成功…

Pytorch实现线性回归模型

在机器学习和深度学习的世界中&#xff0c;线性回归模型是一种基础且广泛使用的算法&#xff0c;简单易于理解&#xff0c;但功能强大&#xff0c;可以作为更复杂模型的基础。使用PyTorch实现线性回归模型不仅可以帮助初学者理解模型的基本概念&#xff0c;还可以为进一步探索更…

深信服超融合虚拟机备份报错显示准备备分镜像失败

问题&#xff1a;最近一段时间深信服超融合虚拟机在执行备份策略时总是报错&#xff0c;备份空间又还很富余。 解决办法&#xff1a; 1 删除备份失败虚拟机的所有备份 2 解绑该虚拟机的备份策略 可靠服务>>备份与CDP>> 找到备份策略>>点【编辑】>>…

刷机维修进阶教程---开机定屏 红字感叹号报错 写字库保资料 救砖 刷官方包保资料的步骤方法解析

在维修各种机型 中经常会遇到开机定屏 进不去系统,正常使用无故定屏进不去系统或者更新降级开机红色感叹号的一些故障机。但顾客需要报资料救砖的要求,遇到这种情况。我们首先要确定故障机型的缘由。是摔 还是更新降级 还是无故使用重启定屏等等。根据原因来对症解决。 通过…

springboot3整合redis

redis在我们的日常开发中是必不可少的&#xff0c;本次来介绍使用spring boot整合redis实现一些基本的操作&#xff1b; 1、新建一个spring boot项目&#xff0c;并导入相应的依赖&#xff1b; <dependency><groupId>org.springframework.boot</groupId><…

基于YOLOV8+Pyqt5无人机航拍太阳能电池板检测系统

1.YOLOv8的基本原理 YOLOv8是一种前沿的目标检测技术&#xff0c;它基于先前YOLO版本在目标检测任务上的成功&#xff0c;进一步提升了性能和灵活性&#xff0c;在精度和速度方面都具有尖端性能。在之前YOLO 版本的基础上&#xff0c;YOLOv8 引入了新的功能和优化&#xff0c;…

PDF 正确指定页码挂载书签后,书签页码对不上

这个问题与我的另一篇中方法一样 如何让一个大几千页的打开巨慢的 PDF 秒开-CSDN博客 https://blog.csdn.net/u013669912/article/details/138166922 另做一篇原因是一篇文章附带一个与该文章主题不相关的问题时&#xff0c;不利于被遇到该问题的人快速搜索发现以解决其遇到的…