LeNet-5(论文复现)

news2025/1/12 0:45:59

LeNet-5(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • LeNet-5(论文复现)
        • 概述
        • LeNet-5网络架构介绍
        • 训练过程
        • 测试过程
        • 使用方式
        • 说明

概述

LeNet是最早的卷积神经网络之一。1998年,Yann LeCun第一次将LeNet卷积神经网络应用到图像分类上,在手写数字识别任务中取得了巨大成功。LeNet通过连续使用卷积和池化层的组合提取图像特征。
出自论文《Gradient-Based Learning Applied to Document Recognition》

LeNet-5网络架构介绍

在这里插入图片描述

  • 输入层

    输入32×32通道数为1的图片

  • C1层(卷积层)

    使用6个5×5大小的卷积核,padding=0,stride=1,得到6个28×28大小的特征图

    激活函数: ReLU

    **可训练参数:**6×(5×5+1)=1566×(5×5+1)=156

  • S2层(池化层)

    最大池化,池化窗大小2×2,stride=2

    **可训练参数:**6×(1+1)=126×(1+1)=12,其中第一个 1 为池化对应的 2*2 感受野中最大的那个数的权重 w,第二个 1 为偏置 b。

  • C3层(卷积层)

    使用16个5×5大小的卷积核,padding=0,stride=1,得到16个10×10大小的特征图

    激活函数: ReLu

    **可训练参数:**6×(5×5×3+1)+6×(5×5×4+1)+3×(5×5×4+1)+1×(5×5×6+1)=15166×(5×5×3+1)+6×(5×5×4+1)+3×(5×5×4+1)+1×(5×5×6+1)=1516

    16 个卷积核并不是都与 S2 的 6 个通道层进行卷积操作,如下图所示,C3 的前六个特征图(0,1,2,3,4,5)由 S2 的相邻三个特征图作为输入,对应的卷积核尺寸为:5x5x3;接下来的 6 个特征图(6,7,8,9,10,11)由 S2 的相邻四个特征图作为输入对应的卷积核尺寸为:5x5x4;接下来的 3 个特征图(12,13,14)号特征图由 S2 间断的四个特征图作为输入对应的卷积核尺寸为:5x5x4;最后的 15 号特征图由 S2 全部(6 个)特征图作为输入,对应的卷积核尺寸为:5x5x6

  • S4层(池化层)

    最大池化,池化窗大小2×2,stride=2

    **可训练参数:**16×(1+1)=3216×(1+1)=32

  • C5层(卷积层/全连接层)

    由于该层卷积核的大小与输入图像相同,故也可认为是全连接层。

    C5 层是卷积层,使用 120 个 5×5x16 大小的卷积核,padding=0,stride=1进行卷积,得到 120 个 1×1 大小的特征图:5-5+1=1。即相当于 120 个神经元的全连接层。

    值得注意的是,与C3层不同,这里120个卷积核都与S4的16个通道层进行卷积操作。

    激活函数: ReLU

    **可训练参数:**120×(5×5×16+1)=48120120×(5×5×16+1)=48120

  • F6层(全连接层)

    F6 是全连接层,共有 84 个神经元,与 C5 层进行全连接,即每个神经元都与 C5 层的 120 个特征图相连。计算输入向量和权重向量之间的点积,再加上一个偏置,结果通过 sigmoid 函数输出。

    **可训练参数:**84×(120+1)84×(120+1)

  • OUTPUT层(全连接层)

    最后的 Output 层也是全连接层,是 Gaussian Connections,采用了 RBF 函数(即径向欧式距离函数),计算输入向量和参数向量之间的欧式距离(目前已经被Softmax 取代)。

    **可训练参数:**84×1084×10

使用 LeNet-5 网络结构创建 MNIST 手写数字识别分类器

MNIST是一个非常有名的手写体数字识别数据集,训练样本:共60000个,其中55000个用于训练,另外5000个用于验证;测试样本:共10000个。MNIST数据集每张图片是单通道的,大小为28x28

在这里插入图片描述

下载并加载数据,并对数据进行预处理

# 下载MNIST数据集
    train_set = datasets.MNIST(root = "./data", train = True, download = True, transform = pipline_train)
    test_set = datasets.MNIST(root = "./data", train = False, download = True, transform = pipline_test)
    # 加载数据集
    train_data = torch.utils.data.DataLoader(train_set, batch_size = opt.batch_size, shuffle = True)
    test_data = torch.utils.data.DataLoader(test_set, batch_size = opt.batch_size, shuffle = False)

    train_data_size = len(train_data)
    test_data_size = len(test_data)
    print("训练数据集长度:{}\n测试数据集长度:{}".format(train_data_size, test_data_size))

若本地无MNIST数据集,会在当前目录下新建一个data文件夹存放数据

在这里插入图片描述

MNIST数据集中的图片数据以ubyte格式存储,ubyte是一种无符号字节类型,取值范围在0~255之间。MNIST数据集的图像数据文件为"train-images-idx3-ubyte.gz"和"t10k-images-idx3-ubyte.gz",其中前者存储了训练数据,后者存储了测试数据。

由于 MNIST 数据集图片尺寸是 28x28 单通道的,而 LeNet-5 网络输入 Input 图片尺寸是 32x32,使用 transforms.Resize 将输入图片尺寸调整为 32x32

pipline_train = transforms.Compose([
    # 随机旋转图片
    transforms.RandomHorizontalFlip(),
    # 将图片尺寸resize到32x32
    transforms.Resize((32, 32)),
    # 将图片转化为Tensor格式
    transforms.ToTensor(),
    # 正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)
    transforms.Normalize((0.1307,), (0.3081,))
])
pipline_test = transforms.Compose([
    # 将图片尺寸resize到32x32
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

搭建 LeNet-5 神经网络结构,并定义前向传播的过程

# 搭建 LeNet-5 神经网络结构,并定义前向传播的过程
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)  # 输入通道的数量;输出通道的数量(也就是卷积核的数量);卷积核的大小
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        output = F.log_softmax(x, dim = 1)
        return output
训练过程
def train_runner(model, device, trainloader, optimizer):
    # 训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为True
    model.train()
    total = 0
    correct = 0.0

    # enumerate迭代已加载的数据集,同时获取数据和数据下标
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        # 把模型部署到device上
        inputs, labels = inputs.to(device), labels.to(device)
        # 保存训练结果
        outputs = model(inputs)
        # 计算损失和
        # 多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmod
        loss = F.cross_entropy(outputs, labels)
        # 初始化梯度
        optimizer.zero_grad()
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()

        # 获取最大概率的预测结果
        # dim=1表示返回每一行的最大值对应的列下标
        predict = outputs.argmax(dim = 1)
        total += labels.size(0)
        correct += (predict == labels).sum().item()

        if i % 1000 == 0:
            # loss.item()表示当前loss的数值
            print("Train Loss: {:.4f}, Accuracy: {:.2f}%".format(loss.item(), 100 * (correct / total)))

    return loss.item(), 100 * (correct / total)

测试过程
def val_runner(model, device, testloader):
    # 模型验证, 必须要写, 否则只要有输入数据, 即使不训练, 它也会改变权值
    # 因为调用eval()将不启用 BatchNormalization 和 Dropout, BatchNormalization和Dropout置为False
    model.eval()
    # 统计模型正确率, 设置初始值
    correct = 0.0
    test_loss = 0.0
    total = 0
    best_acc = 0.0
    # torch.no_grad将不会计算梯度, 也不会进行反向传播
    with torch.no_grad():
        for data, label in testloader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, label).item()
            predict = output.argmax(dim = 1)
            # 计算正确数量
            total += label.size(0)
            correct += (predict == label).sum().item()
        # 计算损失值
        val_acc = correct / total
        print("Test loss: {:.4f}, Accuracy: {:.2f}%".format(test_loss / total, 100 * val_acc))
        # 保存最佳模型
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model, './model-mnist_best.pth')  # 保存模型

    return test_loss / total, 100 * val_acc
使用方式

可直接在IDLE中运行代码,其中train.py文件用于训练网络,model.py文件用于定义网络,test.py文件用来对训练完的模型做一个测试推理。
也可直接调用命令行实现,如

python train.py --epochs 100 --lr 0.001 --batch_size 64

若不指定相关参数,train.py默认为训练10轮,学习率0.001,batch_size为64

说明

本项目的文件夹架构如下

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

代码中还使用了tensorboard可视化工具,以下是tensorboard可视化结果

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

最终在测试样本上,average_loss降到了 0.00129,accuracy 达到了 97.28%。可以说 LeNet-5 的效果非常好!

使用test.py进行测试推理时,由于MNIST数据集中的图片数据以ubyte格式存储,需要转成图片的格式,具体转换脚本参照mnist2jpg.py

# 获取图像数据和标签
    img, label = mnist_train[i]

    # 转换图像数据为numpy数组
    img_np = np.squeeze(img.numpy())

    # 展示图像
    plt.imshow(img_np, cmap = 'gray')
    plt.axis('off')  # 关闭坐标轴显示
    plt.savefig('{}/mnist_image_{}.jpg'.format(save_dir, label), bbox_inches = 'tight', pad_inches = 0)
    plt.close()

测试图片

在这里插入图片描述

文章代码资源点击附件获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2215548.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

站在用户视角审视:以太彩光与PON之争

作者:科技作家-郑凯 园区,是企业数字化转型的“中心战场”。 云计算、大数据、人工智能等数智化技术在园区里“战火交织”;高清视频、协同办公,智慧安防等大量创新应用产生的海量数据在园区内“纵横驰骋”;加上大量的IOT和智能化设备涌入“战场”,让园区网络面对着难以抵御的…

基于YOLOv9的空中飞鸟识别检测系统(附项目源码和数据集下载)

项目完整源码与模型 YOLOv9实现源码:项目完整源码及教程-点我下载YOLOv5实现源码:项目完整源码及教程-点我下载YOLOv7实现源码:项目完整源码及教程-点我下载YOLOv8实现源码:项目完整源码及教程-点我下载数据集:空中飞…

等保测评的技术要求与管理要求详解

等保测评,即网络安全等级保护测评,是根据《中华人民共和国网络安全法》、《信息安全技术网络安全等级保护基本要求》等相关法规和标准,对信息系统的安全性进行评估的过程。等保测评分为技术要求和管理要求两大方面,旨在确保信息系…

外包干了5天,技术明显退步

我是一名本科生,自2019年起,我便在南京某软件公司担任功能测试的工作。这份工作虽然稳定,但日复一日的重复性工作让我逐渐陷入了舒适区,失去了前进的动力。两年的时光匆匆流逝,我却在原地踏步,技术没有丝毫…

PicoQuant GmbH公司Dr. Christian Oelsner到访东隆科技

昨日,德国PicoQuant公司的光谱和显微应用和市场专家Dr.Christian Oelsner莅临武汉东隆科技有限公司。会议上Dr. Christian Oelsner就荧光寿命光谱和显微技术的最新研究和应用进行了深入的交流与探讨。此次访问不仅加强了两家公司在高科技领域的合作关系,…

成都爱尔李晓峰主任讲解“寒”已至,眼需“养”

温度逐渐走低,寒冷空气的到来带走夏季闷热潮湿,也带走了空气中的水分,环境变得干燥,眼睛水分蒸发加快,十分容易造成眼部不适,干眼患者尤其需要注意! 有干眼问题的患者,在这样的天气下…

案例实践 | 以长安链为坚实底层,江海链助力南通民政打造慈善应用标杆

案例名称-江海链 ■ 实施单位 中国移动通信集团江苏有限公司南通分公司、中国移动通信集团江苏有限公司 ■ 业主单位 江苏省南通市民政局 ■ 上线时间 2023年12月 ■ 用户群体 南通市民政局、南通慈善总会等慈善组织及全市民众 ■ 用户规模 全市近30家慈善组织&#…

【网络安全】漏洞案例:提升 Self-XSS 危害

未经许可,不得转载。 文章目录 Self-XSS-1Self-XSS-2Self-XSS-1 目标应用程序为某在线商店,在其注册页面的First Name字段中注入XSS Payload: 注册成功,但当我尝试登录我的帐户时,我得到了403 Forbidden,即无法登录我的帐户。 我很好奇为什么我无法登录我的帐户,所以我…

【unity框架开发起步】一些框架开发思维和工具类封装

文章目录 前言一、Editor操作二、快捷导出unity包三、快捷打开存储目录四、封装transform操作1、localPosition赋值简化2、封装修改transform.localPosition X Y Z3、封装transform.localPosition XY、XZ 和YZ4、Transform 重置 五、封装概率函数六、方法过时七、partial 关键字…

STM32_实验2_printf函数重定向输出

掌握串口通信,并将 printf 函数重定向到串口输出。 USART1 global interrupt 的使能与不使能对系统的影响主要体现在如何处理串口通信事件上,如数据接收和发送的方式。这些不同的配置会直接影响系统的效率、响应时间以及资源的使用。 配置printf函数使用…

递归查找子物体+生命周期函数

递归查找子物体 相关代码&#xff1a; Transform FindChild(string childName, Transform parent){if (childName parent.name) {return parent;}if (parent.childCount < 1){return null;}Transform obj null;for(int i 0; i < parent.childCount; i){Transform t …

Hbuilder如何修改px转rpx的比例如图

mac系统点击hbuilderX图标如图&#xff1a; 打开偏好设置后选择语言服务配置&#xff0c;在px转rpx中设置对应比例&#xff0c;例如设计稿是375那就是0.5&#xff0c;设计稿是750就是1&#xff0c;公式按照设计稿宽度/750 得出比例

Python 自动排班表格(代码分享)

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…

前端reactvue3——实现滚动到底加载数据

文章目录 ⭐前言⭐react 实现滚动加载⭐vue3 实现滚动加载⭐总结⭐结束 ⭐前言 大家好&#xff0c;我是yma16&#xff0c;本文分享 前端react&vue3——实现滚动加载&#xff08;到底部加载&#xff09; scrollTop 属性 一个双精度浮点值&#xff0c;表示元素当前从原点垂直…

全国41G带高度的矢量建筑楼块

建筑数据用于精确描述建筑物的空间位置和范围&#xff0c;支持城市规划、灾害管理、房地产开发及各类空间分析等多领域应用。 数据介绍 带有高度的建筑数据在气候建模、能耗分析及社会经济活动等多种应用中起着至关重要的作用。 尽管这些信息至关重要&#xff0c;但以往的研…

入门必备:什么是鸿蒙系统

鸿蒙系统(HarmonyOS)是华为公司发布的一款基于微内核的面向全场景的分布式操作系统。以下是对它的具体介绍&#xff1a; 1. 核心特点: • 分布式能力&#xff1a;这是鸿蒙系统的核心优势之一。它能够将多种不同类型的智能终端设备连接起来&#xff0c;使这些设备在系统层面相…

MySQL数据的导出

【图书推荐】《MySQL 9从入门到性能优化&#xff08;视频教学版&#xff09;》-CSDN博客 《MySQL 9从入门到性能优化&#xff08;视频教学版&#xff09;&#xff08;数据库技术丛书&#xff09;》(王英英)【摘要 书评 试读】- 京东图书 (jd.com) MySQL9数据库技术_夏天又到了…

MySQL中什么情况下类型转换会导致索引失效

文章目录 1. 问题引入2. 准备工作3. 案例分析3.1 正常情况3.2 发生了隐式类型转换的情况 4. MySQL隐式类型转换的规则4.1 案例引入4.2 MySQL 中隐式类型转换的规则4.3 验证 MySQL 隐式类型转换的规则 5. 总结 如果对 MySQL 索引不了解&#xff0c;可以看一下我的另一篇博文&…

Hadoop集群安装

集群规划 node01node02node03角色主节点从节点从节点NameNode√DataNode√√√ResourceManager√NodeManager√√√SecondaryNameNode√Historyserver√ 上传安装包到node01 解压到指定目录 tar -zxvf /bigdata/soft/hadoop-3.3.3.tar.gz -C /bigdata/server/ 创建软链接 cd…

在线matlab环境

登陆https://ww2.mathworks.cn/ 在线文档https://ww2.mathworks.cn/help/index.html 在线环境[需要先登陆]