机器学习深度学习——卷积神经网络(LeNet)

news2025/1/8 0:08:30

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——池化层
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

卷积神经网络(LeNet)

  • 引言
  • LeNet
  • 模型训练
  • 小结

引言

之前的内容中曾经将softmax回归模型和多层感知机应用于Fashion-MNIST数据集中的服装图片。为了能应用他们,我们首先就把图像展平成了一维向量,然后用全连接层对其进行处理。
而现在已经学习过了卷积层的处理方法,我们就可以在图像中保留空间结构。同时,用卷积层代替全连接层的另一个好处是:模型更简单,所需参数更少。
LeNet是最早发布的卷积神经网络之一,之前出来的目的是为了识别图像中的手写数字。

LeNet

总体看,由两个部分组成:
1、卷积编码器:由两个卷积层组成
2、全连接层密集快:由三个全连接层组成
在这里插入图片描述
上图中就是LeNet的数据流图示,其中汇聚层也就是池化层。
最终输出的大小是10,也就是10个可能结果(0-9)。
每个卷积块的基本单元是一个卷积层、一个sigmoid激活函数和平均池化层(当年没有ReLU和最大池化层)。每个卷积层使用5×5卷积核和一个sigmoid激活函数。
这些层的作用就是将输入映射到多个二维特征输出,通常同时增加通道的数量。(从上图容易看出:第一卷积层有6个输出通道,而第二个卷积层有16个输出通道;每个2×2池操作(步幅也为2)通过空间下采样将维数减少4倍)。卷积的输出形状那是由批量大小、通道数、高度、宽度决定。
为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本(也就是把四维的输入转换为全连接层期望的二维输入,第一维索引小批量中的样本,第二维给出给个样本的平面向量表示)。
LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量(代表0-9是个结果)。
深度学习框架实现此类模型非常简单,用一个Sequential块把需要的层连接在一个就可以了,我们对原始模型做一个小改动,去掉最后一层的高斯激活:

import torch
from torch import nn
from d2l import torch as d2l

net = nn.Sequential(
    # 输入图像和输出图像都是28×28,因此我们要先进行填充2格
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10)
)

上面的模型图示就为:
在这里插入图片描述
我们可以先检查模型,在每一层打印输出的形状:

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)

输出结果:

Conv2d output shape: torch.Size([1, 6, 28, 28])
Sigmoid output shape: torch.Size([1, 6, 28, 28])
AvgPool2d output shape: torch.Size([1, 6, 14, 14])
Conv2d output shape: torch.Size([1, 16, 10, 10])
Sigmoid output shape: torch.Size([1, 16, 10, 10])
AvgPool2d output shape: torch.Size([1, 16, 5, 5])
Flatten output shape: torch.Size([1, 400])
Linear output shape: torch.Size([1, 120])
Sigmoid output shape: torch.Size([1, 120])
Linear output shape: torch.Size([1, 84])
Sigmoid output shape: torch.Size([1, 84])
Linear output shape: torch.Size([1, 10])

模型训练

既然已经实现了LeNet,现在可以查看它在Fashion-MNIST数据集上的表现:

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

计算成本较高,因此使用GPU来加快训练。为了进行评估,对之前的evaluate_accuracy进行修改,由于完整的数据集位于内存中,因此在模型使用GPU计算数据集之前,我们需要将其复制到显存中。

def evaluate_accuracy_gpu(net, data_iter, device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
            # BERT微调所需(后面内容)
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

要使用GPU,我们要在正向和反向传播之前,将每一小批量数据移动到我们GPU上。
如下所示的train_ch6类似于之前定义的train_ch3。以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。
使用Xavier来随机初始化模型参数。有关于Xavier的推导和原理可以看下面的文章:
机器学习&&深度学习——数值稳定性和模型化参数(详细数学推导)
与全连接层一样,使用交叉熵损失函数和小批量随机梯度下降,代码如下:

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):  #@save
    """用GPU训练模型"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc =  metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i+1) / num_batches, (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

此时我们可以开始训练和评估LeNet模型:

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

运行输出(这边我没有用远程的GPU,在自己本地跑了,本地只有CPU):

training on cpu
loss 0.477, train acc 0.820, test acc 0.795
8004.2 examples/sec on cpu

运行图片:
在这里插入图片描述

小结

1、卷积神经网络(CNN)是一类使用卷积层的网络
2、在卷积神经网络中,我们组合使用卷积层、非线性激活函数和池化层
3、为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数
4、传统卷积神经网络中,卷积块编码得到的表征在输出之前需要由一个或多个全连接层进行处理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/836232.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信朋友圈会自动点赞?

网友称微信存在bug,朋友圈会自动点赞?腾讯回应了 微信作为国内最大的网络社交平台,目前用户已超过11亿。 令人吃惊的是,拥有这么庞大用户数量的平台,竟然有可能存在Bug。 近日,#微信回应看朋友圈会自动点…

Linux性能分析工具介绍(二)--内存、进程、磁盘、IO分析

目录 一、引言 二、Linux性能分析工具介绍 ------>2.1、进程 ------>2.2、内存 ------>2.3、磁盘 ------>2.4、IO 一、引言 本章从内存、IO、进程的角度,分析linux系统的性能 二、Linux性能分析工具介绍 2.1、进程 2.1.1、top top命令可以动态查看进程…

【pandas百炼成钢】数据预览与预处理

知识目录 前言一、数据查看1 - 查看数据维度2 - 随机查看5条数据3 - 查看数据前后5行4 - 查看数据基本信息5 - 查看数据统计信息|数值6 - 查看数据统计信息|非数值7 - 查看数据统计信息|整体 二、缺失值处理8 - 计算缺失值|总计9 …

【ASP.NET MVC】使用动软(三)(11)

一、问题 上文中提到,动软提供了数据库的基本操作功能,但是往往需要添加新的功能来解决实际问题,比如GetModel,通过id去查对象: 这个功能就需要进行改进:往往程序中获取的是实体的其他属性,比如…

浪潮服务器硬盘指示灯显示黄色的服务器数据恢复案例

服务器数据恢复环境: 宁夏某市某单位的一台浪潮服务器,该服务器中有一组由6块SAS硬盘组建的RAID5阵列。 服务器上存放的是Oracle数据库文件,操作系统层面划分了1个卷。 服务器故障&初检: 服务器在运行过程中有两块磁盘的指示灯…

需要仔细了解公文类型和目的,以便选择合适的写作风格

撰写公文前需要仔细了解公文类型和目的,以便选择合适的写作风格。 不同类型的公文有不同的结构、内容和表达方式,需要根据具体类型和目的来选择合适的写作风格和表达方式。例如,通知、公告等公文需要采用简洁明了、规范严谨的表达方式&#x…

一篇文章教你学会:对Java集合进行并集,交集,差集运算

废话不多,直接上代码: 目录 1:新建一个实体类 2:准备好数据 3:使用stream 流求 3.1 并集 3.2 交集 3.3 差集 3.31(第一种) 3.32(第二种) 4:使用Gool…

《吐血整理》高级系列教程-吃透Fiddler抓包教程(28)-Fiddler如何抓取Android7.0以上的Https包-下篇

1.简介 虽然依旧能抓到大部分Android APP的HTTP/HTTPS包,但是别高兴的太早,有的APP为了防抓包,还做了很多操作: ① 二次加密 有的APP,在涉及到关键数据通信时,会将正文二次加密后才通过HTTPS发送&#xff…

RFID资产管理系统的选择

RFID资产管理是一种有效的资产过程控制方法,可以帮助企业实现高效的资产管理。选择RFID技术,可以高度集成各种资产信息,完成实时跟踪管理。   根据RFID资产管理系统,可以做到资产的实时管理,使企业管理者可以实时了解…

Android优化篇|网络预连接

作者:苍耳叔叔 一个示例 前后分别去请求同一个域名下的接口,通过 Charles 抓包,可以看到 Timing 下面的时间: 第二次请求时,DNS、Connect 和 TLS Handshake 部分都是 -,说明没有这部分的耗时,…

C# 控制台彩色深度打印 工具类

文章目录 前言Nuget 环境安装代码使用打印结果 总结 前言 有时候我们想要靠打印获得程序信息,因为Dubeg模式需要一点一点断点进入进出,但是我们觉得断点运行实在是太慢了,还是直接打印后找结果会好一点。 Nuget 环境安装 想自己写的话可以看…

unity tolua热更新框架教程(1)

git GitHub - topameng/tolua: The fastest unity lua binding solution 拉取到本地 使用unity打开,此处使用环境 打开前几个弹窗(管线和api升级)都点确定 修改项目设置 切换到安卓平台尝试打包编译 设置包名 查看报错 打开 屏蔽接口导出 重新生成 编译通过 …

FineReport常用功能

不分页显示数据 参见:https://help.fanruan.com/finereport/doc-view-328.html?source4 列数多时,所有列不能在一页显示,可在URL后增加如下参数,添加模版时,可以作为模版参数进行设置: 分页预览模式&am…

orangepi 4lts ubuntu安装RabbitMQ

4lts的emmc 系统安装选文件系统格式 ext4 需先安装erlang: sudo apt install erlang 安装RabbitMQ: sudo apt install rabbitmq-server - 添加用户以便远程访问: - 账号密码都是admin: sudo rabbitmqctl add_user admin admin -sudo rabbitmqct…

C 语言高级2-多维数组,结构体,递归操作

1. 多维数组 1.1 一维数组 元素类型角度:数组是相同类型的变量的有序集合内存角度:连续的一大片内存空间 在讨论多维数组之前,我们还需要学习很多关于一维数组的知识。首先让我们学习一个概念。 1.1.1 数组名 考虑下面这些声明&#xff1…

钉钉微应用

钉钉微应用 在做钉钉微应用开发的时候,遇到了一些相关性的问题,特此记录下,有遇到其他问题的,欢迎一起讨论 调试工具 当我们基于钉钉开发微应用时,难免会遇到调用钉钉api后的调试,这个时候可以安装eruda…

笔记本WIFI连接无网络【实测有效解决方案,不用重启电脑】

笔记本Wifi连接无网络实测有效解决方案 问题描述: 笔记本买来一段时间后,WIFI网络连接开机一段时间还正常连接,但是过一段时间显示网络连接不上解决方案: 1.编写网络重启bat脚本,将以下内容写到文本文件,把…

使用 FastGPT 构建高质量 AI 知识库

作者:余金隆。FastGPT 项目作者,Sealos 项目前端负责人,前 Shopee 前端开发工程师 FastGPT 项目地址:https://github.com/labring/FastGPT/ 引言 自从去年 12 月 ChatGPT 发布以来,带动了一轮新的交互应用革命。尤其在…

【分布式系统】聊聊系统监控

对于分布式系统来说,出现故障的是常有的事情,如何在短时间内找到故障的原因,排除故障是非常重要的,而监控系统是就像系统的眼睛可以通过分析相关数据,进一步管理和运维整个分布式系统。 监控系统的的基本功能包含 全…

Java02-迭代器,数据结构,List,Set ,TreeSet集合,Collections工具类

目录 什么是遍历? 一、Collection集合的遍历方式 1.迭代器遍历 方法 流程 案例 2. foreach(增强for循环)遍历 案例 3.Lamdba表达式遍历 案例 二、数据结构 数据结构介绍 常见数据结构 栈(Stack) 队列&a…