LeNet网络复现

news2024/11/26 21:39:22

文章目录

  • 1. LeNet历史背景
    • 1.1 早期神经网络的挑战
    • 1.2 LeNet的诞生背景
  • 2. LeNet详细结构
    • 2.1 总览
    • 2.2 卷积层与其特点
    • 2.3 子采样层(池化层)
    • 2.4 全连接层
    • 2.5 输出层及激活函数
  • 3. LeNet实战复现
    • 3.1 模型搭建model.py
    • 3.2 训练模型train.py
    • 3.3 测试模型test.py
  • 4. LeNet的变种与实际应用
    • 4.1 LeNet-5及其优化
    • 4.2 从LeNet到现代卷积神经网络

1. LeNet历史背景

1.1 早期神经网络的挑战

早期神经网络面临了许多挑战。首先,它们经常遇到训练难题,例如梯度消失和梯度爆炸,特别是在使用传统激活函数如sigmoid或tanh时。另外,当时缺乏大规模、公开的数据集,导致模型容易过拟合并且泛化性能差。再者,受限于当时的计算资源,模型的大小和训练速度受到了很大的限制。最后,由于深度学习领域还处于萌芽阶段,缺少许多现代技术来优化和提升模型的表现。

1.2 LeNet的诞生背景

LeNet的诞生背景是为了满足20世纪90年代对手写数字识别的实际需求,特别是在邮政和银行系统中。Yann LeCun及其团队意识到,对于图像这种有结构的数据,传统的全连接网络并不是最佳选择。因此,他们引入了卷积的概念,设计出了更适合图像处理任务的网络结构,即LeNet。

2. LeNet详细结构

2.1 总览

在这里插入图片描述

2.2 卷积层与其特点

卷积层是卷积神经网络(CNN)的核心。这一层的主要目的是通过卷积操作检测图像中的局部特征。

特点:

  1. 局部连接性: 卷积层的每个神经元不再与前一层的所有神经元相连接,而只与其局部区域相连接。这使得网络能够专注于图像的小部分,并检测其中的特征。
  2. 权值共享: 在卷积层中,一组权值在整个输入图像上共享。这不仅减少了模型的参数,而且使得模型具有平移不变性。
  3. 多个滤波器: 通常会使用多个卷积核(滤波器),以便在不同的位置检测不同的特征。

2.3 子采样层(池化层)

池化层是卷积神经网络中的另一个关键组件,用于缩减数据的空间尺寸,从而减少计算量和参数数量。

主要类型:

  1. 最大池化(Max pooling): 选择覆盖区域中的最大值作为输出。
  2. 平均池化(Average pooling): 计算覆盖区域的平均值作为输出。

2.4 全连接层

在卷积神经网络的最后,经过若干卷积和池化操作后,全连接层用于将提取的特征进行“拼接”,并输出到最终的分类器。

特点:

  1. 完全连接: 全连接层中的每个神经元都与前一层的所有神经元相连接。
  2. 参数量大: 由于全连接性,此层通常包含网络中的大部分参数。
  3. 连接多个卷积或池化层的特征: 它的主要目的是整合先前层中提取的所有特征。

2.5 输出层及激活函数

输出层:
输出层是神经网络的最后一层,用于输出预测结果。输出的数量和类型取决于特定任务,例如,对于10类分类任务,输出层可能有10个神经元。

激活函数:
激活函数为神经网络提供了非线性,使其能够学习并进行复杂的预测。

  1. Sigmoid: 取值范围为(0, 1)。
  2. Tanh: 取值范围为(-1, 1)。
  3. ReLU (Rectified Linear Unit): 最常用的激活函数,将所有负值置为0。
  4. Softmax: 常用于多类分类的输出层,它返回每个类的概率。

3. LeNet实战复现

3.1 模型搭建model.py

import torch
from torch import nn

# 自定义网络模型
class LeNet(nn.Module):
    # 1. 初始化网络(定义初始化函数)
    def __init__(self):
        super(LeNet, self).__init__()

        # 定义网络层
        self.Sigmoid = nn.Sigmoid()
        self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)

        # 展开
        self.flatten = nn.Flatten()
        self.f6 = nn.Linear(120, 84)
        self.output = nn.Linear(84, 10)

    # 2. 前向传播网络
    def forward(self, x):
        x = self.Sigmoid(self.c1(x))
        x = self.s2(x)
        x = self.Sigmoid(self.c3(x))
        x = self.s4(x)
        x = self.c5(x)

        x = self.flatten(x)
        x = self.f6(x)
        x = self.output(x)

        return x

if __name__ == "__main__":
    x = torch.rand([1, 1, 28, 28])
    model = LeNet()
    y = model(x)

3.2 训练模型train.py

import torch
from torch import nn
from model import LeNet
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os

# 数据转换为tensor格式
data_transformer = transforms.Compose([
    transforms.ToTensor()
])

# 加载训练的数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transformer, download=True)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

# 加载测试的数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transformer, download=True)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)

# 使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"

# 调用搭好的模型,将模型数据转到GPU上
model = LeNet().to(device)

# 定义一个损失函数(交叉熵损失)
loss_fn = nn.CrossEntropyLoss()

# 定义一个优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 学习率每隔10轮, 变换原来的0.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):
    loss, current, n = 0.0, 0.0, 0
    for batch, (x, y) in enumerate(dataloader):
        # 前向传播
        x, y = x.to(device), y.to(device)
        output = model(x)
        cur_loss = loss_fn(output, y)
        _, pred = torch.max(output, axis=1)

        cur_acc = torch.sum(y == pred)/output.shape[0]

        optimizer.zero_grad()
        cur_loss.backward()
        optimizer.step()

        loss += cur_loss.item()
        current += cur_acc.item()
        n = n + 1
    print("train_loss" + str(loss/n))
    print("train_acc" + str(current/n))

# 定义测试函数
def val(dataloader, model, loss_fn):
    model.eval()
    loss, current, n = 0.0, 0.0, 0
    with torch.no_grad():
        for batch, (x, y) in enumerate(dataloader):
            # 前向传播
            x, y = x.to(device), y.to(device)
            output = model(x)
            cur_loss = loss_fn(output, y)
            _, pred = torch.max(output, axis=1)
            cur_acc = torch.sum(y == pred) / output.shape[0]

            loss += cur_loss.item()
            current += cur_acc.item()
            n = n + 1
        print("val_loss" + str(loss / n))
        print("val_acc" + str(current / n))

        return current/n


# 开始训练
epoch = 50
min_acc = 0

for t in range(epoch):
    print(f'epoch{t+1}\n-----------------------------------------------------------')
    train(train_dataloader, model, loss_fn, optimizer)
    a = val(test_dataloader, model, loss_fn)
    # 保存最好的模型权重
    if a > min_acc:
        folder = "sava_model"
        if not os.path.exists(folder):
            os.mkdir("sava_model")
        min_acc = a
        print("sava best model")
        torch.save(model.state_dict(), 'sava_model/best_model.pth')
print('Done!')

3.3 测试模型test.py

import torch
from model import LeNet
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage

# 数据转换为tensor格式
data_transformer = transforms.Compose([
    transforms.ToTensor()
])

# 加载训练的数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transformer, download=True)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

# 加载测试的数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transformer, download=True)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)

# 使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"

# 调用搭好的模型,将模型数据转到GPU上
model = LeNet().to(device)

# 加载模型权重并设置为评估模式
model.load_state_dict(torch.load("./sava_model/best_model.pth"))
model.eval()

# 获取结果
classes = [
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9",
]

# 把tensor转换为图片, 方便可视化
show = ToPILImage()

# 进入验证
for i in range(5):
    x, y = test_dataset[i]
    show(x).show()

    x = torch.unsqueeze(x, dim=0).float().to(device)
    with torch.no_grad():
        pred = model(x)
        predicted, actual = classes[torch.argmax(pred[0])], classes[y]
        print(f'Predicted: {predicted}, Actual: {actual}')

4. LeNet的变种与实际应用

4.1 LeNet-5及其优化

LeNet-5是由Yann LeCun于1998年设计的,并被广泛应用于手写数字识别任务。它是卷积神经网络的早期设计之一,主要包含卷积层、池化层和全连接层。

结构:

  1. 输入层:接收32×32的图像。
  2. 卷积层C1:使用5×5的滤波器,输出6个特征图。
  3. 池化层S2:2x2的平均池化。
  4. 卷积层C3:使用5×5的滤波器,输出16个特征图。
  5. 池化层S4:2x2的平均池化。
  6. 卷积层C5:使用5×5的滤波器,输出120个特征图。
  7. 全连接层F6。
  8. 输出层:10个单元,对应于0-9的手写数字。
    激活函数:Sigmoid或Tanh。

优化:

  1. ReLU激活函数: 原始的LeNet-5使用Sigmoid或Tanh作为激活函数,但现代网络更喜欢使用ReLU,因为它的训练更快,且更少受梯度消失的影响。
  2. 更高效的优化算法: 如Adam或RMSProp,它们通常比传统的SGD更快、更稳定。
  3. 批量归一化: 加速训练,并提高模型的泛化能力。
  4. Dropout: 在全连接层中引入Dropout可以增强模型的正则化效果。

4.2 从LeNet到现代卷积神经网络

从LeNet-5开始,卷积神经网络已经经历了巨大的发展。以下是一些重要的里程碑:

  1. AlexNet (2012): 在ImageNet竞赛中取得突破性的成功。它具有更深的层次,使用ReLU激活函数,以及Dropout来防止过拟合。
  2. VGG (2014): 由于其统一的结构(仅使用3×3的卷积和2x2的池化)而闻名,拥有多达19层的版本。
  3. GoogLeNet/Inception (2014): 引入了Inception模块,可以并行执行多种大小的卷积。
  4. ResNet (2015): 引入了残差块,使得训练非常深的网络变得可能。通过这种方式,网络可以达到上百甚至上千的层数。
  5. DenseNet (2017): 每层都与之前的所有层连接,导致具有非常稠密的特征图。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1053146.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

shopify目录层级释义

└── theme├── assets // assets目录包含主题中使用的所有资源文件,包括图像、CSS和JavaScript文件。├── config // config目录包含主题的配置文件。 配置文件在主题编辑器的主题设置区域中定义设置,并存储它们的值。├── layout // layou…

会议AISTATS(Artificial Intelligence and Statistics) Latex模板参考文献引用问题

前言 在看AISTATS2024模板的时候,发现模板里面根本没有教怎么引用,要被气死了。 如下,引用(Cheesman, 1985)的时候,模板是自己手打上去的?而且模板提供的那三个引用,根本也没有Cheesman这个人&#xff0c…

Mysql各种锁

一.不同存储引擎支持的锁机制 Mysql数据库有多种数据存储引擎,Mysql中不同的存储引擎支持不同的锁机制 MyISAM和MEMORY存储引擎采用的表级锁 InnoDB存储引擎支持行级锁,也支持表级锁,默认情况下采用行级锁 二.锁类型的划分 按照数据操作…

【LeetCode热题100】--21.合并两个有序链表

21.合并两个有序链表 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}* ListNode(int val)…

flutter开发实战-应用更新apk下载、安装apk、启动应用实现

flutter开发实战-应用更新apk下载、安装apk、启动应用实现 在开发过程中,经常遇到需要更新下载新版本的apk文件,之后进行应用更新apk下载、安装apk、启动应用。我们在flutter工程中实现下载apk,判断当前版本与需要更新安装的版本进行比对判断…

FFmpeg 命令:从入门到精通 | ffmpeg 命令分类查询

FFmpeg 命令:从入门到精通 | ffmpeg 命令分类查询 FFmpeg 命令:从入门到精通 | ffmpeg 命令分类查询ffmpeg -versionffmpeg -buildconfffmpeg -formatsffmpeg -muxersffmpeg -demuxersffmpeg -codecsffmpeg -decodersffmpeg -encodersffmpeg -bsfsffmpeg…

LabVIEW学习笔记五:错误,visa关闭超时(错误-1073807339)

写的串口调试工具,其中出现了这个错误 这是串口接收的部分,如果没有在很短的时间内收到外界发进来的数据,这里就会报错。 先在网上查了一下,这个问题很常见,我找到了官方的解答: VISA读取或写入时出现超时…

视频高效剪辑,批量调整视频速度,让视频更加精彩

你是否曾经需要调整多个视频的速度,但却苦于手动操作效率低下?如果你也遇到了这样的问题,那么是时候采取行动,使用一款高效易用的视频处理工具了。 首先,我们要进入好简单批量智剪,并在板块栏里选择“任务…

26358-2022 旅游度假区等级划分 思维导图

声明 本文是学习GB-T 26358-2022 旅游度假区等级划分. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本文件规定了旅游度假区的等级划分和依据、总则、基本条件、省级和国家级旅游度假区条件。 本文件适用于旅游度假区的等级认定与复核依据…

pandas数据处理之构建联邦学习数据

本文以天池比赛《车辆贷款违约预测》的数据为例,通过pandas处理数据,构建联邦学习数据,用于FATE框架联邦学习。 通过pandas处理数据 1. 读取数据 下载car_loan_train.csv数据后,用pandas读取数据。 import pandas as pddatapd…

Transformers.js v2.6 现已发布

🤯 新增了 14 种架构 在这次发布中,我们添加了大量的新架构:BLOOM、MPT、BeiT、CamemBERT、CodeLlama、GPT NeoX、GPT-J、HerBERT、mBART、mBART-50、OPT、ResNet、WavLM 和 XLM。这将支持架构的总数提升到了 46 个!以下是一些示例…

IntelliJ IDEA 控制台中文乱码的四种解决方法

前言 IntelliJ IDEA 如果不进行配置的话,运行程序时控制台有时候会遇到中文乱码,中文乱码问题非常严重,甚至影响我们对信息的获取和程序的跟踪。开发体验非常不好。 本文中我总结出四点用于解决控制台中文乱码问题的方法,希望有助…

分布式事务-TCC异常-空回滚

1、空回滚问题: 因为是全局事务,A服务调用服务C的try时服务出现异常服务B因为网络或其他原因还没执行try方法,TCC因为C的try出现异常让所有的服务执行cancel方法,比如B的try是扣减积分 cancel是增加积分,还没扣减就增…

【动态规划】动态规划经典例题 力扣牛客

文章目录 跳台阶 BM63 简单跳台阶扩展 JZ71 简单打家结舍 LC198 中等打家劫舍2 LC213中等最长连续递增序列 LC674 简单乘积最大子数组LC152 中等最长递增子序列LC300 中等最长重复子数组LC718最长公共子串NC BM66最长公共子序列LC1143 中等完全平方数LC279零钱兑换 LC322 中等单…

【图像分割】图像检测(分割、特征提取)、各种特征(面积等)的测量和过滤(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

基于SSM+Vue的医院住院综合服务管理系统的设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用Vue技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

前端最新支持四级及以下结构仿企查查、天眼查关联投资机构 股权结构 tree树形结构 控股结构

随着技术的发展,开发的复杂度也越来越高,传统开发方式将一个系统做成了整块应用,经常出现的情况就是一个小小的改动或者一个小功能的增加可能会引起整体逻辑的修改,造成牵一发而动全身。通过组件化开发,可以有效实现单…

数据结构——堆(C语言)

本篇会解决一下几个问题: 1.堆是什么? 2.如何形成一个堆? 3.堆的应用场景 堆是什么? 堆总是一颗完全二叉树堆的某个节点总是不大于或不小于父亲节点 如图,在小堆中,父亲节点总是小于孩子节点的。 如图&a…

外包干了3个月,整个人都萎靡不振了。。。。。

先说一下自己的情况,本科生生,19年通过校招进入广州某软件公司,干了接近4年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测…

基于Java的美容院管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序(小蔡coding)有保障的售后福利 代码参考源码获取 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作…