《动手学习深度学习》笔记(二)线性神经网络

news2024/9/25 3:20:37

三、线性神经网络

3.1 线性回归

3.1.1 介绍

  1. 回归是为一个或多个自变量与因变量之间的关系建模的一类方法。而线性回归基于几个简单的假设:① 自变量和因变量关系是线性的;② 允许包含噪声但是噪声遵循正态分布。
  2. 训练数据集/训练集,样本/数据点/数据样本,标签/目标(试图预测的目标),特征/协变量(预测所依据的自变量)的概念,用 n n n 来表示数据集中的样本数,对索引为 i i i 的样本,输入表示为 x ( i ) = [ x 1 ( i ) , x 2 ( i ) ] ⊤ \mathbf{x}^{(i)}=[x_1^{(i)}, x_2^{(i)}]^{\top} x(i)=[x1(i),x2(i)],对应的标签为 y ( i ) y^{(i)} y(i)
  3. 线性假设中包含权重和偏置,是对输入特征的一个仿射变换。将所有特征放到向量 w ∈ R d \mathbf{w}\in\mathbb{R}^d wRd 中,得到线性模型的简洁表示: y ^ = w ⊤ x + b \hat{y}=\mathbf{w}^{\top}\mathbf{x}+b y^=wx+b,进而得到整个数据集的模型表示: y ^ = X w + b \hat{y}=\mathbf{Xw}+b y^=Xw+b,要得到最好的模型参数 w \mathbf{w} w b b b ,还需要两个东西:
(1)一种模型质量的度量方式——损失函数
L ( w , b ) = 1 n ∑ i = 1 n l ( i ) ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 (1) L(\mathbf{w},b)=\frac{1}{n}\sum^n_{i=1}l^{(i)}(\mathbf{w},b)=\frac{1}{n}\sum^n_{i=1}\frac{1}{2}(\mathbf{w}^{\top}\mathbf{x}^{(i)}+b-y^{(i)})^2\tag{1} L(w,b)=n1i=1nl(i)(w,b)=n1i=1n21(wx(i)+by(i))2(1)
w ∗ , b ∗ = arg min ⁡ w , b L ( w , b ) (2) \mathbf{w}^*,b^*=\argmin_{\mathbf{w},b}L(\mathbf{w},b)\tag{2} w,b=w,bargminL(w,b)(2)
  线性回归的解可以用一个公式简单表达出来,但这种方法对问题限制很严格,因此无法广泛应用于深度学习,于是就需要下面的——
(2)一种能够更新模型以提高模型预测质量的方法。如梯度下降法(Gradient Descent),可以计算损失函数关于模型参数的导数,而实际执行时通常是在每次需要计算更新的时候随机抽取一小批样本,这种变体称为:小批量随机梯度下降
  随机抽取一个小批量 B \mathcal{B} B,预先确定一个正数 η \eta η,下面是更新过程:
( w , b ) ← ( w , b ) − η ∣ B ∣ ∑ i ∈ B ∂ ( w , b ) l ( i ) ( w , b ) (3) (\mathbf{w},b)\leftarrow(\mathbf{w},b)-\frac{\eta}{\vert\mathcal{B}\vert}\sum_{i\in\mathcal{B}}\partial_{(\mathbf{w},b)}l^{(i)}(\mathbf{w},b)\tag{3} (w,b)(w,b)BηiB(w,b)l(i)(w,b)(3)
  其中 ∣ B ∣ \vert\mathcal{B}\vert B 表示每个小批量中的样本数,也称为批量大小(batch size), η \eta η 表示学习率,它们通过预先手动设定,这些可以调节但是不再训练过程中更新的参数称为超参数
  4. 泛化的挑战;用模型进行预测
  5. 矢量化加速:在训练模型时,对计算进行矢量化,从而利用线性代数库,而非开销高昂的Python For循环,缩减程序运行所需时间。
  6. 正态分布与平方损失:考虑观测信息中的噪声并设定其服从正态分布后,有 y = w ⊤ x + b + ϵ y=\mathbf{w^\top x}+b+\epsilon y=wx+b+ϵ,其中 ϵ ∼ N ( 0 , σ 2 ) \epsilon\sim\mathcal{N}(0,\sigma^2) ϵN(0,σ2)极大似然估计
  7. 从线性回归到深度网络:神经网络涵盖了丰富的模型,可以使用描述神经网络的方式描述线性模型,下面用“层”符号重写这个模型。(注意:下图中隐去了权重和偏置的值)

线性回归是一个单层神经网络

  对于线性回归,每个输入都与输出相连,这种变换被称为全连接层(fully-connected layer)或稠密层(dense layer)。这些算法的想法一定程度上归功于我们对真实生物神经系统的研究。

3.1.2 线性回归的从零开始实现

3.1.3 线性回归的简洁实现

3.2 softmax回归

3.2.1 介绍

  1. [分类问题] 回归可以用于预测多少的问题,而我们也可能对分类问题感兴趣,不是问“多少”,而是问“哪一个”。有两种差别微妙的问题:① 只区分“硬性”类别;② 得到“软性”类别,即属于每个类别的概率。
  2. [独热编码] 统计学家很早以前就发明了一种表示分类数据的简单方法:独热编码,它是一个向量,分量与类别一样多,由此,标签 y y y可以表示为:
y ∈ { ( 1 , 0 , 0 ) , ( 0 , 1 , 0 ) , ( 0 , 0 , 1 ) } y\in\{(1,0,0),(0,1,0),(0,0,1)\} y{(1,0,0),(0,1,0),(0,0,1)}
  3. [多输出模型] 以4像素图片区分三种类别为例,为估计每个类别概率,需要一个多输出模型
o 1 = x 1 w 11 + x 2 w 12 + x 3 w 13 + x 4 w 14 + b 1 o 2 = x 1 w 21 + x 2 w 22 + x 3 w 23 + x 4 w 24 + b 2 o 3 = x 1 w 31 + x 2 w 32 + x 3 w 33 + x 4 w 14 + b 3 (4) o_1=x_1w_{11}+x_2w_{12}+x_3w_{13}+x_4w_{14}+b_1\\ o_2=x_1w_{21}+x_2w_{22}+x_3w_{23}+x_4w_{24}+b_2\\ o_3=x_1w_{31}+x_2w_{32}+x_3w_{33}+x_4w_{14}+b_3\tag{4} o1=x1w11+x2w12+x3w13+x4w14+b1o2=x1w21+x2w22+x3w23+x4w24+b2o3=x1w31+x2w32+x3w33+x4w14+b3(4)
  使用神经网络图描述如下,可以看到softmax回归也是一个单层神经网络,输出层也是全连接层

softmax回归是一个单层神经网络

  4. [参数开销]深度学习中,全连接层无处不在,对于任何具有 d d d 个输入和 q q q 个输出的全连接层,参数开销 O ( d q ) \mathcal{O}(dq) O(dq) ,这在实践中可能是不可接受的,这个成本可以减少到 O ( d q n ) \mathcal{O}(\frac{dq}{n}) O(ndq),超参数 n n n 可以灵活指定,以平衡参数节约与模型有效性(Zhang et al., 2021)。

  5. [softmax函数]如果要将输出视为概率,必须保证任何数据上的输出都非负且总和为1。softmax函数能够将未规范化的预测转换为非负数且总和为1,同时保持模型可导性质,公式如下:
y ^ = softmax ( o )     其中     y ^ j = exp ⁡ ( o j ) ∑ k exp ⁡ ( o k ) (5) \hat{\mathbf{y}}=\text{softmax}(\mathbf{o})\ \ \ \ 其中\ \ \ \ \hat{y}_j=\frac{\exp(o_j)}{{\sum}_k\exp(o_k)}\tag{5} y^=softmax(o)    其中    y^j=kexp(ok)exp(oj)(5)
  这个过程中仍然保持 arg max ⁡ j y ^ j = arg max ⁡ j o j \argmax_j\hat{y}_j=\argmax_jo_j argmaxjy^j=argmaxjoj,虽然softmax非线性,但是softmax回归的输出仍然由输入特征的仿射变换决定,因此softmax回归是一个线性模型
  6. [小批量的矢量化],加快 X \mathbf{X} X W \mathbf{W} W 的矩阵向量乘法。
  7. [损失函数],softmax函数给出了一个向量 y ^ \hat{y} y^,即“对给定输入 x \mathbf{x} x 的每个类的条件概率”,优化目标是最大化 P ( Y ∣ X ) P(\mathbf{Y\vert X}) P(Y∣X),可以转换为最小化负对数似然:
− log ⁡ P ( Y ∣ X ) = ∑ i = 1 n − log ⁡ P ( y ( i ) ∣ x ( i ) ) = ∑ i = 1 n l ( y ( i ) , y ^ ( i ) ) (6) -\log P(\mathbf{Y\vert X})=\sum_{i=1}^n-\log P(\mathbf{y}^{(i)}\vert\mathbf{x}^{(i)})=\sum_{i=1}^nl(\mathbf{y}^{(i)},\hat{\mathbf{y}}^{(i)})\tag{6} logP(Y∣X)=i=1nlogP(y(i)x(i))=i=1nl(y(i),y^(i))(6)
l ( y ( i ) , y ^ ( i ) ) = − ∑ j = 1 q y j log ⁡ y ^ j (7) l(\mathbf{y}^{(i)},\hat{\mathbf{y}}^{(i)})=-\sum^q_{j=1}y_j\log{\hat{y}_j}\tag{7} l(y(i),y^(i))=j=1qyjlogy^j(7)
  式(7)中的损失函数通常被称为交叉熵损失(cross-entropy loss),是分类问题最常用的损失之一。由于幂永远大于零,得到的概率一定大于零,理论上损失函数不能被进一步最小化。
  8. [导数计算]将式(5)带入式(7),可以得到:
l ( y , y ^ ) = − ∑ j = 1 q y j log ⁡ exp ⁡ ( o j ) ∑ k exp ⁡ ( o k ) = log ⁡ ∑ k = 1 q exp ⁡ ( o k ) − ∑ j = 1 q y j o j (8) l(\mathbf{y},\hat{\mathbf{y}})=-\sum^q_{j=1}y_j\log{\frac{\exp(o_j)}{{\sum}_k\exp(o_k)}}=\log\sum_{k=1}^q \exp(o_k)-\sum_{j=1}^qy_jo_j\tag{8} l(y,y^)=j=1qyjlogkexp(ok)exp(oj)=logk=1qexp(ok)j=1qyjoj(8)
得到损失相对于任何未规范化的预测 o j o_j oj 的导数:
∂ o j l ( y , y ^ ) = exp ⁡ ( o j ) ∑ k = 1 q exp ⁡ ( o k ) − y j = softmax ( o ) j − y j (9) \partial_{o_j}l(\mathbf{y},\hat{\mathbf{y}})=\frac{\exp(o_j)}{{\sum}^q_{k=1}\exp(o_k)}-y_j=\text{softmax}(\mathbf{o})_j-y_j\tag{9} ojl(y,y^)=k=1qexp(ok)exp(oj)yj=softmax(o)jyj(9)
  9. [熵的计算], 核心思想是量化数据中的信息内容,这个数值在信息论中称为分布 P P P 的熵(entropy),计算方法为 H [ P ] = ∑ j − P ( j ) log ⁡ P ( j ) H[P]=\sum_j-P(j)\log P(j) H[P]=jP(j)logP(j),纳特、比特…
  [压缩与预测] 想象一个需要压缩的数据流,如果很容易预测下一个数据,则很容易被压缩,而当我们不能完全预测每一个时间,就会感到“惊异”,香农决定用信息量 log ⁡ 1 P ( j ) = − log ⁡ P ( j ) \log\frac{1}{P(j)}=-\log P(j) logP(j)1=logP(j) 来量化这种惊异,我们赋予一个事件较低的概率,我们的惊异越大,包含的信息量也就越大,而上面的熵则是当分配的概率真正匹配数据生成过程的信息量的期望。
  [熵与交叉熵的理解] 把熵 H ( P ) H(P) H(P) 想象为“知道真实概率的人所经历的惊异程度”,交叉熵 H ( P , Q ) H(P,Q) H(P,Q) 则是“主观概率为 Q Q Q 的观察者在看到根据概率 P P P 生成的数据时的预期惊异”, P = Q P=Q P=Q 时交叉熵达到最低
  [交叉熵分类目标] 可以从两个方面考虑交叉熵分类目标:① 最大化观测数据的似然;② 最小化传达标签所需的惊异。

  10. [模型预测评估] 使用预测概率最高的类别作为输出类别,如果预测与实际类别一致,则预测是正确的。使用精度(正确预测数与预测总数之间的比率)来评估模型的性能。

3.2.2 softmax回归的从零开始实现

3.2.3 softmax回归的简洁实现

3.3 图像分类数据集

  MNIST (LeCun et al., 1998) 是图像分类中广泛使用的数据集之一,但是作为基准数据集过于简单,Fashion-MNIST (Xiao et al., 2017)数据集类似但更加复杂,是一个服装分类数据集,它由10个类别的图像组成,每个类别由训练集6000张图像和测试集1000张图像组成。

""" 导入所需库 """
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

""" 读取数据集 """
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

len(mnist_train), len(mnist_test) # (60000, 10000)
mnist_train[0][0].shape # torch.Size([1, 28, 28])

""" 在数字标签索引及其文本名称之间转换 """
def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

""" 可视化样本图像 """
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

""" 展示图像 """
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

为了使训练集、测试集的读取更容易,一般使用Pytorch内置的数据迭代器而非从零开始创建。数据迭代器在每次迭代中读取一小批量数据(大小为batch_size),并且可以随机打乱样本,从而保证无偏。

batch_size = 256
def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())
""" 统计时间 """
timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

整合组件

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)  # torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
    break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/367523.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法训练营 day53 动态规划 买卖股票的最佳时机系列2

算法训练营 day53 动态规划 买卖股票的最佳时机系列2 买卖股票的最佳时机III 123. 买卖股票的最佳时机 III - 力扣(LeetCode) 给定一个数组,它的第 i 个元素是一支给定的股票在第 i 天的价格。 设计一个算法来计算你所能获取的最大利润。…

软件项目管理知识回顾---网络图

网络图 9.网络图 9.1简介 1.分类 AOA,双代号,ADMAON,PDM,单代号,前导图2.活动的逻辑管理 头到头/尾,尾到头/尾 依赖关系 3.工序 紧前紧后9.2绘制规则 1.两个节点只能一条线。不能是平行线。平行的话就不知道是哪个活动…

LeetCode-93. 复原 IP 地址

目录题目思路回溯法题目来源 93. 复原 IP 地址 题目思路 意识到这是切割问题,切割问题就可以使用回溯搜索法把所有可能性搜出来,和131.分割回文串就十分类似了。 回溯法 1.递归参数 startIndex一定是需要的,因为不能重复分割&#xff0c…

【GeoDjango框架解析】读取矢量数据写入postgis数据库

系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 【GeoDjango框架解析】读取矢量数据写入postgis数据库 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录系列…

关于iframe一些通讯的记录(可适用工作流审批,文中有项目实践,欢迎咨询)

一.知识点(1).我们可以通过postMessage(发送方)和onmessage(接收方)这两个HTML5的方法, 来解决跨页面通信问题&#xff0c;或者通过iframe嵌套的不同页面之间的通信a.父页面代码如下<div v-if"src" class"iframe"><iframeref"iframe"id…

Kafka进阶篇-消费者详解Flume消费Kafka原理

简介 由于挺多时候如果不太熟系kafka消费者详细的话&#xff0c;很容易产生问题&#xff0c;所有剖析一定的原理很重要。 Kafka消费者图解 消费方式 消费者总体工作流程 消费者组初始化流程 消费者详细消费流程 消费者重要参数 bootstrap.servers 向 Kafka 集群建立初…

Jackson使用进阶

实现 注解 属性命名 JsonProperty 定义属性序列化时的名称。 JacksonAnnotation public interface JsonProperty {public final static String USE_DEFAULT_NAME "";public final static int INDEX_UNKNOWN -1;//指定属性的名称。String value() default USE_…

2022IDEA搭建springMvc项目

springmvc项目搭建一. 创建maven项目二. Add Framework Support三. 添加依赖并配置maven四. 配置前端控制器DispatcherServlet五. 配置SpringMVC.XML文件六. 创建controller类七. 创建index.html页面八. 查看jar包是否添加九. 配置tomcat&#xff08;重&#xff09;十. springm…

Kafka(7):生产者详解

1 消息发送 1.1 Kafka Java客户端数据生产流程解析 1 首先要构造一个 ProducerRecord 对象,该对象可以声明主题Topic、分区Partition、键 Key以及值 Value,主题和值是必须要声明的,分区和键可以不用指定。 2 调用send() 方法进行消息发送。 3 因为消息要到网络上进行传输…

国产蓝牙耳机什么便宜又好用?学生党平价蓝牙耳机推荐

蓝牙耳机凭借近几年的快速发展&#xff0c;越来越多的品牌、款式出现在人们的日常生活当中。最近看到很多人问&#xff0c;国产蓝牙耳机什么便宜又好用&#xff1f;针对这个问题&#xff0c;我来给大家推荐几款平价蓝牙耳机&#xff0c;很适合学生党&#xff0c;一起来看看吧。…

推荐系统从入门到入门(3)——基于MapReuduce与Spark的分布式推荐系统构建

本系列博客总结了不同框架、不同算法、不同界面的推荐系统&#xff0c;完整阅读需要大量时间&#xff08;又臭又长&#xff09;&#xff0c;建议根据目录选择需要的内容查看&#xff0c;欢迎讨论与指出问题。 目录 系列文章梗概 系列文章目录 三、MapReduce 1.MapReduce详…

赶紧收藏:如何使用Telegram客户支持

想要使用Telegram需要客户支持&#xff1f;您需要了解的有关使用Telegram作为客户服务渠道的所有信息&#xff0c;本文章都会介绍。我们将首先讨论提供Telegram支持以及入门所需了解的内容。然后&#xff0c;我们将向您展示如何用智能客服工具ss可以帮助您提供一流的服务Telegr…

Oracle监听详解

本文摘自《ORACLE数据库技术实用详解》和《成功之路&#xff1a;ORACLE 11g学习笔记》 配置网络环境 本文将介绍和Oracle相关的网络问题&#xff0c;Oracle网络建立在操作系统之上。配置操作系统网络是配置Oracle网络的第一步。在配置Oracle网络之前&#xff0c;我们需要确保操…

数学基础--均值、方差、标准差、协方差

1. 简介 统计学中最核心的概念之一是&#xff1a;标准差及其与其他统计量&#xff08;如方差和均值&#xff09;之间的关系&#xff0c;本文将对标准差这一概念提供直观的视觉解释&#xff0c;在文章的最后我们将会介绍协方差的概念。 2. 概念介绍 均值 均值&#xff1a; 均值…

(一)Spring-Cloud源码分析之核心流程关系及springcloud与springboot包区别(新)

文章目录1. 前言2. springcloud简介3. Springcloud包简介4. Springcloud和Springboot流程关系5. Springcloud启动流程新增的功能和接口5.1 新增接口5.2 新增功能类5.2.1 spring-cloud-context包5.2.2 spring-cloud-commons包6. Springcloud实现机制带来的问题7. Springcloud和S…

【MyBatis】映射器配置|注解完成CRUD(三)

&#x1f697;MyBatis学习第三站~ &#x1f6a9;起始站&#xff1a;MyBatis概述&环境搭建(一) &#x1f6a9;本文已收录至专栏&#xff1a;数据库学习之旅 &#x1f44d;希望您能有所收获 上一篇我们学习了如何使用Mapper代理开发&#xff0c;核心配置文件&#xff0c;但却…

OnlyOffice验证(一)DocumentServer编译验证

OnlyOffice验证&#xff08;一&#xff09;DocumentServer编译验证 资源准备 Ubuntu16.04桌面版 验证用的版本[ubuntu-16机接上传ubuntu.04.7-desktop-amd67131.iso&#xff0c;&#xff08;别用高版本&#xff01;试过20.04耽误两三天&#xff0c;差点放弃了&#xff09;&am…

javaee之node.js与es6

问题1&#xff1a;在IDEA控制台为什么node显示不会出来命令 修改完之后记得重新启动电脑 问题2&#xff1a;response.end()作用 在Web开发中&#xff0c;浏览器端的请求到达服务器进行处理的时候&#xff0c;Response.End的作用就是让request执行到此结束&#xff0c;输出到客户…

移掉K位数字-力扣402-java贪心策略

一、题目描述给你一个以字符串表示的非负整数 num 和一个整数 k &#xff0c;移除这个数中的 k 位数字&#xff0c;使得剩下的数字最小。请你以字符串形式返回这个最小的数字。示例 1 &#xff1a;输入&#xff1a;num "1432219", k 3输出&#xff1a;"1219&q…

Vue实战第5章:发布Vue工程到github静态页面

前言 本篇在讲什么 简单讲解关于Vue发布github静态页面相关的内容 本篇适合什么 适合初学Vue的小白 适合想要自己搭建网站的新手 本篇需要什么 对Html和css语法有简单认知 对Vue有简单认知 Node.js(博主v18.13.0)的开发环境 Npm(博主v8.19.3)的开发环境 Vue(博主v5.…