【CNN】经典网络LeNet——最早发布的卷积神经网络之一

news2025/3/1 11:16:52

前言

LeNet是Yann LeCun于1988年提出的用于数字识别的网络结构,可以说LeNet是深度CNN网络的基石,AlexNet、VGG、GoogLeNet、ResNet等都是在VGG基础上加入各类激活函数或加深网络演变而来的,所以理解LeNet对于现在主流CNN深度学习架构的理解有很大帮助。

关于LeNet详细的介绍可以阅读,《Gradient-Based Learning Applied to Document Recognition》,对LeNet的架构做了详细的介绍,并对LeNet与其他算法做了详细的对比。并且这是第一篇通过反向传播成功训练卷积神经网络的研究。

一,介绍

LeNet主要的出现契机是手写数字的识别,并在邮政和银行发挥了非常重要的角色。但是,这个网络在当时流行度没那么高,但是知名度最高的还是MNIST数据集。
在这里插入图片描述
所有的都是黑白图。
对于LeNet,总体来看,LeNet(LeNet-5)由两个部分组成:

  • 卷积编码层:由两个卷积层组成
  • 全连接密集块:由是哪个全连接层组成

架构图如下图所示:
在这里插入图片描述
输入的是28 * 28的单通道图片 得到6输出通道的28 * 28feature map, 通过池化层,得到 6通道 14 * 14的特征图,最后通过卷积操作得到16 输出通道的特征图,在通过池化得到16通道的5 * 5特征图,最后使用三个全连接层,拉成10通道输出,得到0 ~ 9的数字识别结果。
还有一些超参数:

  1. c1 是: kernel_size = 5,padding = 2.
  2. s2 是:kernel_size = 5,stride = 2
  3. c3 是:kernel_size = 5
  4. s4 是: kernel_size = 2,stride = 2

卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层,虽然ReLU最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用卷积核和一个sigmoid激活函数去代替。

对于LeNet是早期成功的神经网络,先使用卷积层来学习图片的空间信息,然后使用全连接层来转换到别的空间
这个思想,影响了早期神经网络的训练模式,现在几乎不这样。

二,代码实现

按照卷积的计算公式和上面的超参数,通过卷积的输出计算公式搭建网络:
在这里插入图片描述

2.1 搭建网络

导入所需要的包

import torch
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils import data
import matplotlib.pyplot as plt
import numpy as np

按照上图图示搭建网络结构

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

搭建好网络之后,我们需要通过大小为的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,我们可以检查模型,以确保其操作与我们期望的。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape: \t', X.shape)

在这里插入图片描述

2.2 模型的训练

现在我们已经实现了LeNet,让我们来看看在Fashion-MNIST数据集上的表现。
先读取对应的数据集:

def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory.
    Defined in :numref:`sec_fashion_mnist`"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=4),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=4))

设置训练的小批量大小

batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size = batch_size)

设置超参数和epoches的数量

device = ('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# 损失函数
loss_func = nn.CrossEntropyLoss()
epoches = 120
costs = []

开始训练

for epoch in range(epoches):
    sum_loss = 0
    net.train()
    for step, (batch_x,batch_y) in enumerate(train_iter):
        # 如果是存在gpu的话,直接使用gpu来训练
        if torch.cuda.is_available():
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
        # 梯度清零
        optimizer.zero_grad()
        output = net(batch_x)
        loss = loss_func(output, batch_y)
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            costs.append(loss)
            sum_loss += loss
            print(f'epoch:{epoch + 1},mini_batch:{step + 1},mini_loss:{sum_loss / 100}')
            sum_loss = 0.0
    # 验证
    net.eval()
    correct = 0.0
    total = 0
    for(test_x, test_y) in test_iter:
        if torch.cuda.is_available():
            test_x = test_x.cuda()
            test_y = test_y.cuda()
        test_output = net(test_x)
        # 只返回最大数的那个索引
        predicted = torch.max(test_output, 1)[1]
#         计算总数
        total += test_y.size(0)
#     计算预测的正确数目
        correct += (predicted == test_y).sum()
    print(f'correct:{correct}')
    print(f'total:{total}')
    print(f'Test acc:{(correct / total * 100):.2f}%')

绘制损失图与得到最终结论:

if torch.cuda.is_available():
    costs = [cost.cpu().detach().numpy() for cost in costs]
else:
    costs = [cost.numpy for cost in costs]
plt.plot(costs)
plt.xlabel('number of iteration')
plt.ylabel('loss in train')
plt.title('LeNet')
plt.show()

最后得到的loss图示:
在这里插入图片描述
与一般的验证精度:
在这里插入图片描述
总体来说表现良好。

三,总结

  • 卷积神经网络(CNN)是一类使用卷积层的网络。

  • 在卷积神经网络中,我们组合使用卷积层、非线性激活函数和汇聚层。

  • 为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数。

  • 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

  • LeNet是最早发布的卷积神经网络之一。


参考:
https://zh-v2.d2l.ai/chapter_convolutional-neural-networks/lenet.html
https://blog.csdn.net/qq_43960768/article/details/124652618?spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-124652618-blog-122780852.pc_relevant_multi_platform_whitelistv4&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-124652618-blog-122780852.pc_relevant_multi_platform_whitelistv4&utm_relevant_index=5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/38860.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

制作一个简单HTML电影网页设计(HTML+CSS)

HTML实例网页代码, 本实例适合于初学HTML的同学。该实例里面有设置了css的样式设置,有div的样式格局,这个实例比较全面,有助于同学的学习,本文将介绍如何通过从头开始设计个人网站并将其转换为代码的过程来实践设计。 文章目录一、网页介绍一…

基于蚁群算法的多配送中心的车辆调度问题的研究(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

【图像处理】小波编码图像中伪影和纹理的检测附Matlab代码和报告

✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法 …

如果各位同学还对时间复杂度有疑问?看这一篇就可以啦!

🎇🎇🎇作者: 小鱼不会骑车 🎆🎆🎆专栏: 《java练级之旅》 🎓🎓🎓个人简介: 一名专科大一在读的小比特,努力学习编程是我…

chrome浏览器一键切换搜索引擎,一键切换谷歌和百度搜索

chrome浏览器一键切换搜索引擎,一键切换谷歌和百度搜索 背景 有么有办法在谷歌和百度之间(或其他引擎或非引擎,如Youtube、B站、Bing等)之间切换。我们当然是不想重新输入keyword,甚至点击浏览器插件的图标后再选择引…

Scala010--Scala中的常用集合函数及操作Ⅰ

之前我们已经知道了Scala中的数据结果有哪些,并且能够使用for循环取到该数据中的元素,现在我们再进一步的去了解更加方便及常用的函数操作,使得我们能够对集合更好的利用。 目录 一,foreach函数 1,遍历一维数组 1&…

Pytorch中CrossEntropyLoss()详解

一、损失函数 nn.CrossEntropyLoss() 交叉熵损失函数 nn.CrossEntropyLoss() ,结合了 nn.LogSoftmax() 和 nn.NLLLoss() 两个函数。 它在做分类(具体几类)训练的时候是非常有用的。 二. 什么是交叉熵 交叉熵主要是用来判定实际的输出与期望…

HTML CSS个人网页设计与实现——人物介绍丁真(学生个人网站作业设计)

🎉精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

SpringBoot SpringBoot 原理篇 1 自动配置 1.8 bean 的加载方式【六】

SpringBoot 【黑马程序员SpringBoot2全套视频教程,springboot零基础到项目实战(spring boot2完整版)】 SpringBoot 原理篇 文章目录SpringBootSpringBoot 原理篇1 自动配置1.8 bean 的加载方式【六】1.8.1 ImportSelector1 自动配置 1.8 b…

改进牛顿法潮流计算IEEE33节点潮流计算matlab程序——

IEEE33节点潮流计算matlab程序——改进牛顿法潮流计算 改进牛顿法的基本原理 参考文献:一种新的配电网潮流算法——改进牛顿法-拉夫逊法 牛顿法是改进牛顿法的基础,对牛顿法作科学的近似,即雅可比矩阵做一些更改,使得每次计算得…

stm32项目平衡车详解(stm32F407)下

stm32项目平衡车详解(stm32F407)下 HC-SRO4 超声波测距避障功能开发 TSL1401 CCD摄像头实现小车巡线功能 文章目录stm32项目平衡车详解(stm32F407)下前言一、HC-SRO4 超声波测距避障功能开发HC-SRO4超声波测距模块?超声波测距避障功能开发避障模式开发二、TSL1401 …

【微软】【ICLR 2022】TAPEX:通过学习神经 SQL 执行器进行表预训练

重磅推荐专栏: 《Transformers自然语言处理系列教程》 手把手带你深入实践Transformers,轻松构建属于自己的NLP智能应用! 论文:https://arxiv.org/abs/2107.07653 代码:https://github.com/microsoft/Table-Pretrainin…

数字图像处理(十五)图像旋转

文章目录前言一、图像旋转算法1.算法原理2. 一些需要注意的点3.举例4. 均值插值法二、编程实现1.C代码2.实验结果参考资料前言 图像的旋转是指以图像中的某一点为原点以逆时针或者顺时针方向旋转一定的角度。通常是绕图像的起始点以逆时针进行旋转。 一、图像旋转算法 1.算法原…

JAVA并发之谈谈你对AQS的理解

文章目录一、AQS是什么二、AQS具备哪些特性三、用的哪种设计模式四、AQS与锁二者之间的关系五、如何基于AQS实现一把独占锁六、参考资料一、AQS是什么 AQS的全称是 (AbstractQueuedSynchronizer ),它定义了一套多线程访问共享资源的同步器框架…

【算法基础】(一)基础算法 --- 归并排序

✨个人主页:bit me ✨当前专栏:算法基础 🔥专栏简介:该专栏主要更新一些基础算法题,有参加蓝桥杯等算法题竞赛或者正在刷题的铁汁们可以关注一下🌹 🌹 🌹 归并排序💤一.归…

猴子也能学会的jQuery第十期——jQuery元素操作(上)

📚系列文章—目录🔥 猴子也能学会的jQuery第一期——什么是jQuery 猴子也能学会的jQuery第二期——引用jQuery 猴子也能学会的jQuery第三期——使用jQuery 猴子也能学会的jQuery第四期——jQuery选择器大全 猴子也能学会的jQuery第五期——jQuery样式操作…

基于拟蒙特卡洛模拟法的随机潮流计算matlab程序

电力系统随机潮流计算中常采用模拟法,该方法原理简单、使用方便,能够精确地模拟实际物理过程,但是简单的蒙特卡洛模拟法收敛速度很慢,要得到精确的结果需要以大量的计算时间为代价。本章在此基础上提出了基于拟蒙特卡洛模拟的随机…

【菜菜的sklearn课堂笔记】逻辑回归与评分卡-用逻辑回归制作评分卡-异常值和样本不均衡处理

视频作者:菜菜TsaiTsai 链接:【技术干货】菜菜的机器学习sklearn【全85集】Python进阶_哔哩哔哩_bilibili 描述性统计处理异常值 现实数据永远都会有一些异常值,首先我们要去把他们捕捉出来,然后观察他们的性质。注意&#xff0c…

【雷达检测】基于复杂环境下的雷达目标检测技术附Matlab代码

✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法 …

3.6、媒体接入控制

1、基本概念 有多台主机连接到这根同轴电缆上,共享这跟传输媒体,形成了总线型的局域网。 各主机竞争使用总线,随机的在信道发送数据。 主机 C 与主机 D 同时使用总线来发送数据,这必然会产生所发送信号的碰撞 2、静态划分信道…