【深度学习】经典的深度学习模型-01 开山之作:CNN卷积神经网络LeNet-5

news2024/11/25 5:16:27

【深度学习】经典的深度学习模型-01 开山之作:CNN卷积神经网络LeNet-5

Note: 草稿状态,持续更新中,如果有感兴趣,欢迎关注。。。

0. 论文信息

@article{lecun1998gradient,
title={Gradient-based learning applied to document recognition},
author={LeCun, Yann and Bottou, L{'e}on and Bengio, Yoshua and Haffner, Patrick},
journal={Proceedings of the IEEE},
volume={86},
number={11},
pages={2278–2324},
year={1998},
publisher={Ieee}
}

基于梯度的学习在文档识别中的应用
在这里插入图片描述
LeNet-5 是一个经典的卷积神经网络(CNN)架构,由 Yann LeCun 等人在 1998 年提出,主要用于手写数字识别任务,特别是在 MNIST 数据集上。
在这里插入图片描述
LeNet-5 的设计对后来的卷积神经网络研究产生了深远影响,该模型具有以下几个特点:

  1. 卷积层:LeNet-5 包含多个卷积层,每个卷积层后面通常会跟一个池化层(Pooling Layer),用于提取图像特征并降低特征图的空间维度。

  2. 池化层:在卷积层之后,LeNet-5 使用池化层来降低特征图的空间分辨率,减少计算量,并增加模型的抽象能力。

  3. 全连接层:在卷积和池化层之后,LeNet-5 包含几个全连接层,用于学习特征之间的复杂关系。

  4. 激活函数:LeNet-5 使用了 Sigmoid 激活函数,这是一种早期的非线性激活函数,用于引入非线性,使得网络可以学习复杂的模式。

  5. Dropout:尽管原始的 LeNet-5 并没有使用 Dropout,但后来的研究者在改进模型时加入了 Dropout 技术,以减少过拟合。

  6. 输出层:LeNet-5 的输出层通常使用 Softmax 激活函数,用于进行多分类任务,输出每个类别的概率。

虽然站在2024年看LeNet-5 的模型结构相对简单,但是时间回拨到1998年,彼时SVM这类算法为主的时代,LeNet-5的出现,不仅证明了卷积神经网络在图像识别任务中的有效性,而且为后续深度神经网络研究的发展带来重要启迪作用,使得我们有幸看到诸如 AlexNet、VGGNet、ResNet 等模型的不断推成出新。

2. 论文摘要

3. 研究背景

4. 算法模型

5. 实验效果

6. 代码实现

以MNIST手写字图像识别问题为例子,采用LeNet5模型进行分类,代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Define the LeNet-5 model
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)  # 1 input image channel, 6 output channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)  # pool with window 2x2, stride 2
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16*4*4 = 256
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)  # flatten the tensor
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# # Initialize the network
# net = LeNet5()

# Initialize the network on GPU
net = LeNet5().to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

# Train the network
for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # for cpu
        # inputs, labels = data
        # for gpu
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

# Test the network on the test data
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        # # for cpu
        # images, labels = data
        # for gpu
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')

注意:这里使用GPU做简单加速。如果没有GPU,可以关闭对应代码,替换为相应的CPU代码即可。
程序运行后结果如下:
在这里插入图片描述
可以看到,在测试数据上的准确率为98.33%!

7. 问题及优化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2210676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【智能算法应用】长鼻浣熊优化算法求解二维路径规划问题

摘要 本文采用长鼻浣熊优化算法 (Coati Optimization Algorithm, COA) 求解二维路径规划问题。COA 是一种基于长鼻浣熊的觅食和社群行为的智能优化算法,具有快速收敛性和较强的全局搜索能力。通过仿真实验,本文验证了 COA 在复杂环境下的路径规划性能&a…

【微服务】springboot3 集成 Flink CDC 1.17 实现mysql数据同步

目录 一、前言 二、常用的数据同步解决方案 2.1 为什么需要数据同步 2.2 常用的数据同步方案 2.2.1 Debezium 2.2.2 DataX 2.2.3 Canal 2.2.4 Sqoop 2.2.5 Kettle 2.2.6 Flink CDC 三、Flink CDC介绍 3.1 Flink CDC 概述 3.1.1 Flink CDC 工作原理 3.2 Flink CDC…

数据结构:栈的创建、使用以及销毁

这里写目录标题 栈的结构与概念栈底层结构的选取栈的代码实现(stack)头文件(stack.h)栈的初始化栈的销毁入栈出栈获取栈顶数据获取栈大小代码的测试 栈的结构与概念 栈:⼀种特殊的线性表,其只允许在固定的…

【算法篇】动态规划类(1)(笔记)

目录 一、理论基础 1. 大纲 2. 动态规划的解题步骤 二、LeetCode 题目 1. 斐波那契数 2. 爬楼梯 3. 使用最小花费爬楼梯 4. 不同路径 5. 不同路径 II 6. 整数拆分 7. 不同的二叉搜索树 一、理论基础 1. 大纲 动态规划,英文:Dynamic Programm…

企业水、电、气、热等能耗数据采集系统

介绍 通过物联网技术,采集企业水、电、气、热等能耗数据,帮企业建立能源管理体系,找到跑冒滴漏,从而为企业节能提供依据。 进一步为企业实现碳跟踪、碳盘查、碳交易、谈汇报的全生命过程。 为中国碳达峰-碳中和做出贡献。 针对客…

【C++进阶】set的使用

1. 序列式容器和关联式容器 前面,我们已经接触过STL中的部分容器如:string、vector、list、deque、array、forward_list等,这些容器统称为序列式容器,因为逻辑结构为线性序列的数据结构,两个位置存储的值之间⼀般没有紧…

【工具箱】Flash基础及“SD NAND Flash”的测试例程

目录 一、“FLASH闪存”是什么? 1. 简介 2. 分类 3. 性能 4.可靠性 5.易用性 二、SD NAND Flash 1. 概述 2. 特点 3. 引脚分配 4. 数据传输模式 5. SD NAND寄存器 6. 通电图 7. 参考设计 三、STM32测试例程 1. 初始化 2. 单数据块测试 3. 多数据块…

场景题 - 画三角形并只点击三角形触发事件

简介 画一个三角形并仅点击三角形区域才会触发点击事件。 可以拆解成: 画个三角形绑定点击事件(涉及点击区域) 这里提供更多更好用的方法,svg polygon绘制三角形、canvas、css clip-path:polygon( ) 裁剪可视区域,并…

文件和目录的权限管理

定义: 文件和目录的权限管理在操作系统中至关重要,特别是在多用户环境下,它决定了不同用户对文件和目录的访问和操作权限。 一、基本权限类型及表示方法 在Linux系统中,文件和目录的权限分为三类:读取权限(…

谷歌-BERT-第一步:模型下载

1 需求 需求1:基于transformers库实现自动从Hugging Face下载模型 需求2:基于huggingface-hub库实现自动从Hugging Face下载模型 需求3:手动从Hugging Face下载模型 2 接口 3.1 需求1 示例一:下载到默认目录 from transform…

南邮-软件安全--第一次实验报告-非爆破计算校验值

软件安全第一次实验报告,切勿直接搬运(改改再交) 实验要求 1、逆向分析目标程序运行过程,找到程序的关键校验点; 2、以非爆破的方式正确计算crackme的校验值; 内容 使用x32dbg对文件进行分析 打开文件…

思迈特:在AI时代韧性增长的流量密码

作者 | 曾响铃 文 | 响铃说 “超级人工智能将在‘几千天内’降临。” 最近,OpenAI 公司 CEO 山姆奥特曼在社交媒体罕见发表长文,预言了这一点。之前,很多专家预测超级人工智能将在五年内到来,奥特曼的预期,可能让这…

构建可扩展的高校学科竞赛平台:SpringBoot案例分析

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常…

高效管理学科竞赛:SpringBoot平台的创新应用

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理高校学科竞赛平台的相关信息成为必然。开发…

Zookeeper快速入门:部署服务、基本概念与操作

文章目录 一、部署服务1.下载与安装2.查看并修改配置文件3.启动 二、基本概念与操作1.节点类型特性总结使用场景示例查看节点查看节点数据 2.文件系统层次结构3.watcher 一、部署服务 1.下载与安装 下载: 一定要下载编译后的文件,后缀为bin.tar.gz w…

算法:525.连续数组

题目 链接:leetcode 思路分析(前缀和) 首先介绍一个小技巧 在处理二进制数组的时候,因为数组里面只有0和1,我们可以将所有的0变成-1 这个时候1和-1之间就可以产生很多抵消,有利于处理数组。 在该题中&am…

2.2.1 绘制Canvas路径 - 绘制线条

文章目录 1. 绘制线条2. 绘制具有不同结束线帽的线条3. 绘制向阳花图形 今天我们要一起探讨的是如何使用HTML5的Canvas元素来绘制各种图形。Canvas提供了一个强大的图形绘制API,使我们能够在网页上绘制出各种复杂的图形和动画。接下来,我将通过几个实战示…

从Naive RAG到Agentic RAG:基于Milvus构建Agentic RAG

检索增强生成(Retrieval-Augmented Generation, RAG)作为应用大模型落地的方案之一,通过让 LLM 获取上下文最新数据来解决 LLM 的局限性。典型的应用案例是基于公司特定的文档和知识库开发的聊天机器人,为公司内部人员快速检索内部…

如何在数仓中处理缓慢变化维度(SCD)

在数据仓库中,处理缓慢变化维度(SCD,Slowly Changing Dimension)是一个非常常见且重要的问题。为了保证数据的完整性和准确性,我们通常会采取不同的策略来处理维度表中的数据变化。SCD的核心解决方案是通过不同类型的历…

Run the FPGA VI 选项的作用

Run the FPGA VI 选项的作用是决定当主机 VI 运行时,FPGA VI 是否会自动运行。 具体作用: 勾选 “Run the FPGA VI”: 当主机 VI 执行时,如果 FPGA VI 没有正在运行,系统将自动启动并运行该 FPGA VI。 这可以确保 FPG…