深度学习——线性神经网络(二、线性回归的从零开始实现)

news2024/10/10 6:14:12

目录

    • 2.1 生成数据集
    • 2.2 读取数据集
    • 2.3 初始化模型参数
    • 2.4 定义模型
    • 2.5 定义损失函数
    • 2.6 定义优化算法
    • 2.7 训练

2.1 生成数据集

  为简单展示,将根据带有噪声的线性模型构造一个数据集。生成一个包含1000个样本的数据集。每个样本包含从标准正态分布中的抽样的两个特征,因此我们合成的数据集是一个矩阵 X ∈ R 1000 × 2 X \in \mathbb{R}^{1000 \times 2} XR1000×2
  使用线性模型参数 w = [ 2 , − 3.4 ] , b = 4.2 \bm w=[2,-3.4],b=4.2 w=[2,3.4],b=4.2和噪声项 ϵ \epsilon ϵ来生成数据集及其标签。
y = X w + b + ϵ \bm y=\bm {Xw}+b+\epsilon y=Xw+b+ϵ
   ϵ \epsilon ϵ可以视为模型预测和标签的潜在观测误差,假设 ϵ \epsilon ϵ服从均值为0的正态分布,同时将标准差 σ \sigma σ设为0.01

def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))


true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

  features中的每一行都包含一个二维数据样本,labels中的每一行都包含一个标签值(标量)

print('features:', features[0], '\nlabel:', labels[0])

在这里插入图片描述
  通过生成第二个特征features[:, 1]和labels的散点图,可以直观地观察到两者间的线性关系。

d2l.set_figsize()
d2l.plt.scatter(features[:,1].detach().numpy(),labels.detach().numpy(),1)
d2l.plt.show()

在这里插入图片描述

features[:, 1].detach().numpy() 取出了特征数据的第二列(索引为1),并且转换为 NumPy 数组。
labels.detach().numpy() 将标签数据也转换为 NumPy 数组。
d2l.set_figsize( ) 设置了图形的大小。
d2l.plt.scatter(…, 1) 使用了大小为 1 的点来绘制散点图。

2.2 读取数据集

  训练模型时要对数据集进行遍历,每次抽取小批量样本,并使用它们来更新我们的模型。因此有必要定义一个函数,该函数能打乱数据集中的样本并以小批量方式获取数据。
  定义一个data_iter函数,该函数接收批量大小特征矩阵和标签向量作为输入,生成大小为batch_size的小批量。每个小批量包含一组特征和标签。

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # 随机读取
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i:min(i + batch_size,num_examples)])
        yield features[batch_indices], labels[batch_indices]

  通常,我们利用GPU并行计算的优势,处理大小合理的“小批量”。每个样本都可以被并行地进行模型计算,且每个样本损失函数的梯度也可以被并行计算。GPU可以实现在处理几百个样本时所花费的时间不比处理单个样本时多太多。
  直观感受一下小批量计算:读取第一个小批量数据样本并打印。每个批量的特征维度是显示批量大小和输入特征数。同样,批量的标签形状与batch_size相等。

batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

在这里插入图片描述

当我们执行迭代时,我们会连续地获得不同的小批量,直至遍历完整个数据集。
上面的迭代过程对于展示代码比较适合,但是它的执行效率很低,可能会在实际应用问题中产生比较大的麻烦。

  如果要将所有数据加载到内存中,并执行大量的随机内存访问。在深度学习框架中实现的内置迭代器的效率要高得多,它可以处理存储在文件中的数据和数据流提供的数据。

2.3 初始化模型参数

  我们通过从均值为0,标准差为0.01的正态分布中抽样随机数来初始化权重,并将偏置初始化为0

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

  在初始化这些参数后,我们的任务是更新这些参数,直到这些参数足以拟合我们的数据。
  每次更新都需要计算损失函数关于模型参数的梯度,根据梯度,可以向减少损失的方向更新每个参数。
  通过torch中引入的自动微分来计算梯度。

2.4 定义模型

  要计算线性模型的输出,只需计算输入特征 X \bm X X和模型权重 w \bm w w的矩阵—向量乘法后加上偏置 b b b

   X w \bm {Xw} Xw是一个向量,而 b b b是一个标量。根据前面章节提到的广播机制:当我们用一个向量加上一个标量时,标量会被加到向量的每一个分量上。

def linreg(X, w, b):
    """定义线性回归模型"""
    return torch.matmul(X, w) + b

2.5 定义损失函数

def squared_loss(y_hat, y):
    """均方损失函数"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

  定义平方损失函数,需要将真实值y的形状转换为和预测值y-hat的形状相同。

2.6 定义优化算法

  在每一步中,使用从数据集中随机抽取的一个小批量,然后根据参数计算损失的梯度,接下来,朝着减少损失的方向更新我们的参数。
  下面的函数实现小批量随机梯度下降更新,该函数接受模型的参数集合、学习率和批量大小作为输入,每一步更新的大小由学习率lr决定。
  因为我们计算的损失是一个批量样本的总和,所以我们用批量大小(batch_size)来规范化步长,这样步长大小就不会取决于我们对批量大小的选择。

def sgd(params, lr, batch_size):
    """小批量随机梯度下降"""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

2.7 训练

  在每次迭代中,我们读取小批量训练样本,并通过我们的模型来获得一组预测,计算完损失后,开始反向传播,存储每个参数的梯度。最后,调用优化算法SGD来更新模型参数。
  将执行以下迭代;

  1. 初始化参数

  2. 重复以下训练,直到完成:

    • 计算梯度 g ← ∂ ( w , b ) 1 ∣ B ∣ ∑ i ∈ B l ( x ( i ) , y ( i ) , w , b ) \bm g \leftarrow \partial_{(\bm w,b)} \frac 1 {\vert B \vert}\sum_{i \in B} l(\bm x^{(i)},y^{(i)},\bm w,b) g(w,b)B1iBl(x(i),y(i),w,b)
      -更新参数 ( w , b ) ← ( w , b ) − η g (\bm w,b) \leftarrow (\bm w,b)- \eta\bm g (w,b)(w,b)ηg

  在每轮(epoch)中,我们使用data_iter函数遍历整个数据集,并将训练数据集中的所有样本都使用一次(假设样本数能够被批量大小整除)。
  这里的轮数num_epochs和学习率lr都是超参数,分别设为3和0.03.

  设置超参数比较麻烦,需要通过反复实验进行调整。

lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y) # X和y的小批量损失
        # 因为l的形状是(batch_size, 1),而不是一个标量
        # l中的所有元素被加到一起,并以此计算关于[w, b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size) # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch{epoch + 1}, loss{float(train_l.mean()):f}')

在这里插入图片描述
  因为使用的是自己合成的数据集,所以知道真实的参数是什么。因此,我们可以通过比较真实参数和通过训练学习的参数来评估训练好坏程度。事实上,真实参数和通过训练学习的参数确实非常相近。

print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')

在这里插入图片描述

  在机器学习中,我们通常不太关心恢复真实的参数,而更关心如何高度准确地预测参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2201292.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RocketMq-秒杀应用场景

1、介绍mq 2、秒杀介绍 -----redis配置 3、生产者-消费者搭建 1、介绍mq 消息存储架构图中主要有下面三个跟消息存储相关的文件构成。 (1) CommitLog:消息主体以及元数据的存储主体,存储Producer端写入的消息主体内容,消息内容不是定长的。单个文件大…

微服务实战——登录(普通登录、社交登录、SSO单点登录)

登录 1.1. 用户密码 PostMapping("/login")public String login(UserLoginVo vo, RedirectAttributes redirectAttributes, HttpSession session){R r memberFeignService.login(vo);if(r.getCode() 0){MemberRespVo data r.getData("data", new Type…

价值5000元完整版GOD引擎手机客户端三端引擎源码 编译完整版

5000元完整版GOD引擎手机客户端三端引擎源码 支持三端互通:电脑端,安卓端,苹果端 GOD引擎全套源码及手游客户端源码(苍穹引擎源码及修改教程) 服务端代码为Delphir,手游客户端代码为cocos2dx的(…

DAMA数据管理知识体系(第14章 大数据和数据科学)

课本内容 14.1 引言 概要 从数据中探究、研发预测模型、机器学习模型、规范性模型和分析方法并将研发结果进行部署供相关方分析的人,被称为数据科学家业务驱动 期望抓住从多种流程生成的数据集中发现的商机,是提升一个组织大数据和数据科学能力的最大业务…

论文阅读(十二):Attention is All You Need

文章目录 一、循环神经网络1.1RNN模型1.1.1RNN模型简介1.1.2RNN基本结构1.1.3权重共享机制1.1.4RNN局限性:长期依赖问题与梯度消失 1.2LSTM模型1.2.1LSTM核心思想1.2.2遗忘门1.2.3输入门1.2.4更新细胞状态1.2.5输出门1.2.6参数更新 二、Seq2Seq机制2.1RNN结构的局限…

react 知识点汇总(非常全面)

React 是一个用于构建用户界面的 JavaScript 库,由 Facebook 开发并维护。它的核心理念是“组件化”,即将用户界面拆分为可重用的组件。 React 的组件通常使用 JSX(JavaScript XML)。JSX 是一种 JavaScript 语法扩展,…

house_of_muney

house_of_muney 首先介绍一下house of muney 这个利用原理: 在了解过_dl_runtime_resolve的前提下,当程序保护开了延迟绑定的时候,程序第一次调用相关函数的时候会执行下面的命令 push n push ModuleID jmp _dl_runtime_resolve 这里的n…

OCR+PDF解析配套前端工具开源详解!

目录 一、项目简介 TextIn为相关领域的前端开发提供了优秀的范本。 目前项目已在Github上开源! 二、性能特色 三、安装使用 安装依赖启动项目脚本命令项目结构 四、效果展示 面对日常生活和工作中常见的OCR识别、PDF解析、翻译、校对等场景,配套的…

洛谷P5648

洛谷P5648 这题花了很长时间,是在线段树题单里找到的( )。有线段树做法,但是我感觉可能比倍增做法更难看懂。以后有空再看看吧。感觉线段树现在只会板子题,绿稍微难点可能就不会。 花了很久时间之后,就觉得…

打造直播美颜平台的关键技术:视频美颜SDK的深度解析

本篇文章,小编将深入解析视频美颜SDK的关键技术,探讨其在打造直播美颜平台中的作用。 一、视频美颜SDK的定义与功能 视频美颜SDK是一套专门为实时视频处理而设计的软件开发工具包。其主要功能包括人脸检测、肤色美化、瑕疵修复、虚化背景、实时滤镜等。…

Python对PDF文件的合并操作

在处理 PDF 文件时,合并多个 PDF 文件为一个单一文件或者将某个单一文件插入某个PDF文件是一个常见的需求。Python 提供了多种库来实现这一功能,其中 PyPDF2 是一个非常流行的选择。该库提供了简单易用的接口,包括 merge() 方法,可…

CRE6281B1 (宽VCC:8-45V PWM电源芯片)

CRE6281B1 是一款外驱功率管的高度集成的电流型PWM 控制 IC,为高性能、低待机功率、低成本、高效率的隔离型反激式开关电源控制器。在满载时,CRE6281B1工作在固定频率(65kHz)模式。在负载较低时,采用节能模式,实现较高的功率转换效…

关于Allegro导出Gerber时的槽孔问题

注意点一: 如果设计的板子中有 槽孔和通孔(俗称圆孔),不仅要NC Drill, 还要 NC Route allegro导出的槽孔文件后缀是 .rou 圆型孔后缀 是 .drl ,出gerber时需要看下是否有该文件。 注意点二: 导出钻孔文件时,设置参…

Hi3061M开发板——系统时钟频率

这里写目录标题 前言MCU时钟介绍PLLCRG_ConfigPLL时钟配置另附完整系统时钟结构图 前言 Hi3061M使用过程中,AD和APT输出,都需要考虑到时钟频率,特别是APT,关系到PWM的输出频率。于是就研究了下相关的时钟。 MCU时钟介绍 MCU共有…

22.1 K8S之KubeSphere实现中间件高可用集群

22.1 K8S之KubeSphere实现中间件高可用集群 一. 章节概述二. WordPress1. WordPress 简介---------------------------------------------------------------------------------------------------一. 章节概述 二. WordPress 1. WordPress 简介 创建并部署 WordPress

MySQL 数据库的性能优化方法方法有哪些?

MySQL 数据库的性能优化方法方法有哪些? 从开发角度来看,一般可以从 SQL 和库表设计两部分优化性能。 SQL 优化 根据慢sql日志,找出需要优化的一些sql语句。 常见优化方向: 避免select *,只查询必要的字段&#x…

62 加密算法

62 加密算法 三种加密算法分类: 对称加密:密钥只有一个,解密、解密都是这个密码,加解密速度快,典型的对称加密有DES、AES、RC4等非对称加密:密钥成对出现,分别为公钥和私钥,从公钥…

sass学习笔记(1.0)

1.使用变量 sass可以像声明变量那样进行使用,这样同样的样式,就可以使用相同的变量来提高复用。 语法为:$ 变量名 在界面中也可以正常的显示 当然了,变量之间也可以相互引用,比如下面 div{$_color: #d45387;$BgColo…

kibana 删除es指定数据,不是删除索引

1 查询条件查询出满足条件的数据 GET /order_header_idx_202410/_search {"from":0,"size":10,"query":{"bool":{"filter":[{"term":{"oh_tenantId":{"value":"0211000001",&…

NeuVector部署、使用与原理分析

文章目录 前言1、概述2、安装与使用2.1、安装方法2.1.1、部署NeuVector前的准备工作2.1.1.1 扩容系统交换空间2.1.1.2 Kubernetes单机部署2.1.1.2.1 部署Docker2.1.1.2.2 部署Kubectl2.1.1.2.3 部署Minikube 2.1.1.3 Helm部署 2.1.2、使用Helm部署NeuVector 2.2、使用方法2.2.1…