《动手深度学习》线性回归简洁实现实例

news2024/9/23 17:15:14

🎈 作者:Linux猿

🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

🎈 欢迎小伙伴们点赞👍、收藏⭐、留言💬


本文是《动手深度学习》线性回归简洁实现实例的实现和分析,主要对代码进行详细讲解,有问题欢迎在评论区讨论交流。

一、代码实现

实现代码如下所示。

import torch
from torch.utils import data
# d2l包是李沐老师等人开发的动手深度学习配套的包,
# 里面封装了很多有关与数据集定义,数据预处理,优化损失函数的包
from d2l import torch as d2l
# nn 是神经网络 Neural Network 的缩写,提供了一系列的模块和类,实现创建、训练、保存、恢复神经网络
from torch import nn

'''
1. 生成数据集,共 1000 条
true_w 和 true_b 是临时变量用于生成数据集
生成 X, y :满足关系 y = Xw + b + noise
'''
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

'''
2. 构造循环读取数据集的迭代器
'''
def load_array(data_arrays, batch_size, is_train=True):  #@save
    # 构造一个 PyTorch 数据迭代器,对 tensor 进行打包,包装成 dataset。
    dataset = data.TensorDataset(*data_arrays)
    # 根据数据集构造一个迭代器
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

# 小批量数据
batch_size = 10
# 设置了一个数据读取的迭代器,每次读取 batch_size(10) 条
data_iter = load_array((features, labels), batch_size)

'''
3. 设置全连接层
'''
'''
# nn.Linear(in_features, out_features, bias=True)
# in_features : 输入向量的列数
# out_features : 输出向量的列数
# bias = True 是否包含偏置
执行线性变换:Yn*o = Xn*i Wi*o + b
其中:W 和 b 模型需要学习的参数
在本例中:n = 10,i = 2, o = 1
'''
net = nn.Sequential(nn.Linear(2, 1))
# 设置权重 w 和 偏置 b
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

'''
4. 定义损失函数
'''
# 均方误差,是预测值与真实值之差的平方和的平均值
loss = nn.MSELoss()
# lr 学习率 learning rate
trainer = torch.optim.SGD(net.parameters(), lr=0.03)

'''
4. 训练数据
'''
# 超参数 设置批次
num_epochs = 3
for epoch in range(num_epochs): # 进行 num_epochs 个迭代周期
    for X, y in data_iter:
        l = loss(net(X) ,y) # 计算损失,net(X) 计算预测值 y1,loss(y1, y) 计算预测值和真实值之间的差距
        trainer.zero_grad() # 将所有模型参数的梯度置为 0
        l.backward() # 求梯度,不使用从零实现中 l.sum.backward 的原因是损失计算中使用了平均的 gard
        trainer.step() # 优化参数 w 和 b
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

二、实现解析

针对实例中重要的函数解析如下。

2.1 Linear 函数

nn.Linear(in_features, out_features, bias=True)

神经网络的线性层,也成为全连接层,进行 Y = XW + b 的线性变换。

参数:

in_features : 输入向量的列数

out_features : 输出向量的列数

bias = True 是否包含偏置

in_features 和 out_features 是 W 的行和列。

执行线性变换:Yn*o = Xn*i Wi*o + b

其中:W 和 b 模型需要学习的参数

在本例中:n = 10,i = 2, o = 1。

2.2 Sequential 函数

一个序列容器,用于搭建神经网络的模块,按照传入构造器的顺序添加到 nn.Sequential() 容器中。按照内部模块的顺序自动依次计算并输出结果。

2.3 MSELoss 函数

均方误差,是预测值与真实值之差的平方和的平均值,即:

2.4 TensorDataset 函数

用来对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等. 另外:TensorDataset 中的参数必须是 tensor。可以参考如下例子:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# len = 12
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
# len = 12
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
# 将 tensor a 和 b 压缩在一起
train_ids = TensorDataset(a, b)
# 输出
for x, y in train_ids:
    print(x, y)

输出如下:

tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)

2.5 DataLoader 函数

DataLoader 是用来包装所使用的数据,每次抛出一批数据,下面来看一个例子。

import torch
from torch.utils import data

# len = 12
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
# len = 12
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
# 将 tensor a 和 b 压缩在一起
train_ids = data.TensorDataset(a, b)
# 输出
#for x, y in train_ids:
#    print(x, y)

BATCH_SIZE = 4
loader = data.DataLoader(dataset=train_ids,
                         batch_size=BATCH_SIZE, # 每次取 BATCH_SIZE=4 个数据
                         shuffle=False, # 不打乱顺序,便于查看
                         num_workers=0)

for x, y in loader:
    print(x, y)
    break

输出如下:

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
        [1, 2, 3]]) tensor([44, 55, 66, 44])

 如上所示,输出第一个 BATCH_SIZE=4。

2.6 zero_grad 函数

trainer.zero_grad() 是用来清空模型参数梯度的函数,它将模型参数的梯度缓存设置为 0。在进行反向传播时,梯度会累加,如果不清空梯度,会影响后续的梯度计算。

2.7 backward 函数

对计算图进行梯度计算,求解计算图中所有节点的梯度。

2.8 step 函数

根据 backward 函数计算出的梯度进行参数更新。

参考链接:

线性回归的实现学习_data.tensordataset_带刺的厚崽的博客-CSDN博客

nn.Sequential()_一颗磐石的博客-CSDN博客

【Pytorch基础】torch.nn.MSELoss损失函数_一穷二白到年薪百万的博客-CSDN博客

pytorch之trainer.zero_grad()_FibonacciCode的博客-CSDN博客

清空模型参数梯度的函数 - 知乎

pytorch中backward()函数详解_backward函数_Camlin_Z的博客-CSDN博客

理解Pytorch的loss.backward()和optimizer.step() - 知乎


🎈 感觉有帮助记得「一键三连支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1152553.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

百度富文本上传图片后样式崩塌

🔥博客主页: 破浪前进 🔖系列专栏: Vue、React、PHP ❤️感谢大家点赞👍收藏⭐评论✍️ 问题描述:上传图片后,图片会变得很大,当点击的时候更是会顶开整个的容器的高跟宽 原因&#…

【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割7(数据预处理)

在上一节:【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割6(数据预处理) 中,我们已经得到了与mhd图像同seriesUID名称的mask nrrd数据文件了,可以说是一一对应了。 并且,mask的文件,还根据结…

【错误解决方案】ModuleNotFoundError: No module named ‘ngboost‘

1. 错误提示 在python程序,尝试导入一个名为ngboost的模块,但Python提示找不到这个模块。 错误提示:ModuleNotFoundError: No module named ‘ngboost‘ 2. 解决方案 出现上述问题,可能是因为你还没有安装这个模块,…

CXL技术交流群问题讨论记录(1)

🔥点击查看精选 CXL 系列文章🔥 📢 声明: 🥭 作者主页:【MangoPapa的CSDN主页】。⚠️ 本文首发于CSDN,转载或引用请注明出处【https://mangopapa.blog.csdn.net/article/details/134131924】。…

Python 学习1 基础

文章目录 基础字符串字面量常用的值类型注释变量print语句数据类型数据类型转换标识符运算符 字符串拓展小结 2023.10.28 周六 最近打算学一下Python,毕竟确实简单方便,而且那个编程语言排名还是在第一。不过不打算靠它吃饭,深不深入暂且不说…

防数据泄密的解决方案

防数据泄密的解决方案 安企神数据防泄密系统下载使用 现代化企业离不开信息数据,数据对企业的经营至关重要,也是企业发展的命脉。为了保护公司数据不被泄露,尤其是在防止数据泄密方面,公司面临着巨大的挑战,需要采取…

Python爬虫实战(六)——使用代理IP批量下载高清小姐姐图片(附上完整源码)

文章目录 一、爬取目标二、实现效果三、准备工作四、代理IP4.1 代理IP是什么?4.2 代理IP的好处?4.3 获取代理IP4.4 Python获取代理IP 五、代理实战5.1 导入模块5.2 设置翻页5.3 获取图片链接5.4 下载图片5.5 调用主函数5.6 完整源码5.7 免费代理不够用怎…

EasyFlash移植使用- 关于单片机 BootLoader和APP均使用的情况

目前,我的STM32单片机,需要在BootLoader和APP均移植使用EasyFlash,用于参数管理和IAP升级使用。 但是由于Flash和RAM限制,减少Flash占用,我规划如下: BootLoader中移植EasyFlash使用旧版本,因为…

机器学习-基本知识

 任务类型 ◼ 有监督学习(Supervised Learning) 每个训练样本x有人为标注的目标t,学习的目标是发现x到t的映射,如分类、回归。 ◼ 无监督学习(Unsupervised Learning) 学习样本没有人为标注,学习的目的是发现数据x本身的分布规律&#xf…

ROS自学笔记二十: Gazebo里面仿真环境搭建

Gazebo 中创建仿真实现方式有两种:1直接添加内置组件创建仿真环境2: 手动绘制仿真环境 1.添加内置组件创建仿真环境 1.1启动 Gazebo 并添加组件 1.2保存仿真环境 添加完毕后,选择 file ---> Save World as 选择保存路径(功能包下: worlds 目录),文…

二维数组如何更快地遍历

二维数组如何更快地遍历 有时候,我们会发现,自己的代码和别人的代码几乎一模一样,但运行时间差了很多,别人是 AC \text{AC} AC,你是 TLE \text{TLE} TLE,这是为什么呢? 一个可能的原因是数组的…

延迟队列实现方案总结

日常开发中,可能会遇到一些延迟处理的消息任务,例如以下场景 ①订单支付超时未支付 ②考试时间结束试卷自动提交 ③身份证或其他验证信息超时未提交等场景。 ④用户申请退款,一天内没有响应默认自动退款等等。 如何处理这类任务,最…

http1,https,http2,http3总结

1.HTTP 当我们浏览网页时,地址栏中使用最多的多是https://开头的url,它与我们所学的http协议有什么区别? http协议又叫超文本传输协议,它是应用层中使用最多的协议, http与我们常说的socket有什么区别吗? …

2000-2021年上市公司产融结合度量数据

2000-2021年上市公司产融结合度量数据 1、时间:2000-2021年 2、指标:股票代码、年份、是否持有银行股份、持有银行股份比例、是否持有其他金融机构股份、产融结合 3、来源:上市公司年报 4、范围:上市公司 5、样本量&#xff…

4种类型WMS的简要说明

仓库管理系统(WMS)主要有四种类型:独立仓库管理系统、供应链管理系统中的仓库管理模块、ERP 系统中的仓库管理模块和基于云的仓库管理系统。 独立仓库管理系统 独立仓库管理系统提供的功能可实现日常仓库运营。公司可以使用WMS系统来监管和…

【MATLAB源码-第62期】基于matlab的DCSK(差分混沌移位键控调制)系统误码率仿真。

MATLAB 2022a 1、算法描述 DCSK(Differential Chaos Shift Keying)是一种差分混沌移位键控调制方式,常用于无线通信系统。其调制和解调的基本流程如下: 1. DCSK调制 1.1 生成混沌序列 - 初始条件:选择一个混沌映射&a…

『K8S 入门』一:基础概念与初步搭建

『K8S 入门』一:基础概念与初步搭建 一、kubernetes 组件 官方示图 抽象示图 Master 控制面板 Api-Server:接口服务,基于 REST 风格开放 k8s 接口的服务ControllerManager cloud-controller-manager:云控制管理器。第三方平…

Android图片加载框架库源码解析 - Coil

文章目录 一、什么是Coil二、引入Coil1、ImageView加载图片1.1、普通加载1.2、crossfade(淡入淡出)加载1.3、crossfade的动画时间1.4、placeholder1.5、error1.6、高斯模糊1.7、灰度变换1.8、圆形1.9、圆角 2、Gif加载3、SVG加载(不存在)4、视频帧加载5、监听下载过程6、取消下…

想翻译pdf文档,试了几个工具对比:有阿里(完全免费,快,好用,质量高,不用注册登录)道最好(有限免费) 百度(有限免费)和谷歌完全免费(网不好)

文档翻释作为基础设施,工作必备。 阿里 (完全免费,快,好用,质量高,不用注册登录,无广告)我给满分 https://translate.alibaba.com/#core-translation 先选好语言。 Google(完全免…

PDManer生成Postgis对应Schema数据库设计文档

项目开发数据库选择postGis,由于需要编写数据库设计说明书,因此选择工具PDManer生成数据库设计文档,但是postGis一个数据库,可能对应多个Schema。如下图所示: 1.编写数据库设计文档时,仅需编写hly这个Sche…