LeNet-5上手敲代码

news2025/1/16 21:02:33

LeNet-5

LeNet-5Yann LeCun在1998年提出,旨在解决手写数字识别问题,被认为是卷积神经网络的开创性工作之一。该网络是第一个被广泛应用于数字图像识别的神经网络之一,也是深度学习领域的里程碑之一。

LeNet-5的整体架构:

在这里插入图片描述

总体来看LeNet-5由两个部分组成:

  • 卷积编码器:由两个卷积层和两个下采样层组成;
  • 全连接层密集块:由三个全连接层组成

特点:

1.相比MLPLeNet使用了相对更少的参数,获得了更好的结果。

2.设计了MaxPool来提取特征

代码实现

1. 模型文件的实现

通过观察模型的整体架构,可以知到LeNet-5只用了三个基本的层——卷积层、下采样层、全连接层,因此我们很容易写出模型的基本框架。

其中Gaussian connections也是一个全连接层。Gaussian Connections利用的是RBF函数(径向欧式距离函数),计算输入向量和参数向量之间的欧式距离。目前该方式基本已淘汰,取而代之的是Softmax

为了提高模型的性能,我们会在卷积层与下采样层之间添加一个Relu激活函数,因此模型的整体流程架构为:

Convolutions -> Relu->Subsampling -> Convolutions -> Relu-> Subsampling -> Full connection -> Full connection -> Full connection

pytorch中,卷积层对应的是nn.Conv2d()方法, 下采样层可以使用pytorch中的最大池化下采样nn.MaxPool2d()方法来实现,全连接层可以使用nn.Linear()方法来实现。

确定参数:

卷积层:对于LeNet-5论文中输入的图片是 32 × 32 32 \times 32 32×32大小的图片(图片通道个数为3)。因此第一个卷积层的输入的通道个数为3,输出的通道个数为16,也就是说一共有16个卷积核。卷积核的个数等于通过卷积后图片的通道个数

我们可以根据如下公式来计算出卷积核的大小。

计算卷积后图像宽和高的公式

  • I n p u t : ( N , C i n , H i n , W i n ) Input:(N, C_{in},H_{in},W_{in}) Input(NCinHinWin)

  • O u t p u t : ( N , C o u t , H o u t , W o u t ) Output:(N,C_{out},H_{out},W_{out}) Output(NCoutHoutWout)

H o u t = [ H i n + 2 × p a d d i n g [ 0 ] − d i l a t i o n [ 0 ] × ( k e r n e l _ s i z e [ 0 ] − 1 ) − 1 s t r i d e [ 0 ] + 1 ] H_{out} = [\frac{H_{in} + 2 \times padding[0] - dilation[0] \times (kernel\_size[0] - 1) - 1}{stride[0]} + 1] Hout=[stride[0]Hin+2×padding[0]dilation[0]×(kernel_size[0]1)1+1]

W o u t = [ W i n + 2 × p a d d i n g [ 1 ] − d i l a t i o n [ 1 ] × ( k e r n e l _ s i z e [ 1 ] − 1 ) − 1 s t r i d e [ 1 ] + 1 ] W_{out} = [\frac{W_{in} + 2 \times padding[1] - dilation[1] \times (kernel\_size[1] - 1) - 1}{stride[1]} + 1] Wout=[stride[1]Win+2×padding[1]dilation[1]×(kernel_size[1]1)1+1]

公式中dilation我们没有使用,默认情况为1,输入的图片为 32 × 32 × 3 32 \times 32 \times 3 32×32×3输出为 28 × 28 × 6 28 \times 28 \times 6 28×28×6,通过公式,我们很容易算出 k e r n e l s i z e = ( 5 , 5 ) kernel_{size} = (5, 5) kernelsize=(5,5)【通常情况下如果通过卷积层后的图片的大小没有很明显的缩小(成倍数缩小),那么stride一般为默认值1,通过以上公式,我们可以求得每一个卷积核的大小 。

最大池化下采样:由于特征图通过最大池化下采样层之后,图片的大小变为原来的一半,因此我们知道在长度方向上每两个像素之间取一个最大值,这样才能将长度变为原来的一半,宽度方向上每两个像素之间取一个最大值,这样才能将宽度变为原来的一半。结合起来得到池化层的每一个滑动窗口的大小为 2 × 2 2 \times 2 2×2,也就是说,每四个像素取一个最大值。

在这里插入图片描述

全连接层:输入为上一个层的输出数据大小,输出为自定义大小,对于第一个全连接层,输入为下采样层的输出,即: 5 × 5 × 16 5 \times 5 \times 16 5×5×16 个矩阵值。输出为下一个全连接层单元的个数(第二个全连接层的单元个数为84个),可以推出所有全连接层的单元个数。

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, (5, 5))
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 16 * 5 * 5)   # 改变张量形状为一个二维张量,第一个维度是自动推断的,第二个维度设定为16 * 5 * 5
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


if __name__ == '__main__':
    model =  LeNet()
    x = torch.randn((3, 32, 32))
    output = model(x)
    print(x)

2. 训练程序

写训练程序的基本步骤为:

  1. 加载训练数据
  2. 初始化模型
  3. 设定损失函数
  4. 设定优化器
  5. 设定迭代次数
  6. 根据情况保存模型权重文件

训练数据我们使用的是CIFAR10中的训练数据,验证集的数据也使用的是CIFAR10中的数据,同时将训练集和验证集的数据进行转换(转换为tensor类型,进行归一化)。设置dataloader,训练集的batch_size64,并且进行随机打乱,设置num_workers2,验证集的batch_size5000,进行随机打乱,设置num_workers2

num_workers:用于设置是否使用多线程读取数据,开启后会加快数据读取速度,但是会占用更多内存,内存较小的电脑可以设置为2或者0

训练数据时,我们在每次的500步之后进行一次验证,验证的方式为,加载验证集,然后输入到网络中进行预测,得到输出的最大值的索引,然后再与真实标签进行比较,统计为True的个数,然后除以所有的标签的个数,得到最后的模型的正确率。

predict_y = torch.max(outputs, dim=1)[1]
accuracy = torch.eq(predict_y, test_label).sum().item() / test_label.size(0)  # .item() 方法将结果转换为标量,即 Python 中的普通数字类型。

在迭代完所有的步数之后进行保存模型的权重文件。

train.py

import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader

from model import LeNet


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # 训练集
    train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)
    train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True, num_workers=2)

    # 验证集
    test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
    test_loader = DataLoader(dataset=test_set, batch_size=5000, shuffle=True, num_workers=0)

    # 实例化网络,损失函数,优化器
    net = LeNet().to(device)
    net.load_state_dict(torch.load('LeNet_200.pth'))  # 加载权重
    loss_function = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    epochs = 200
    epoch = 0
    # 开始训练
    print("training...")
    while epoch <= epochs:
        epoch += 1
        running_loss = 0.0
        for step, data in enumerate(train_loader):
            print(f"epoc: {epoch}, step: {step}")
            inputs, lables = data
            inputs, lables = inputs.to(device), lables.to(device)   # 将数据移动到GPU上
            optimizer.zero_grad()
            output = net(inputs)
            loss = loss_function(output, lables)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % 500 == 499:   # 每500个batch_size之后进行验证一次
                with torch.no_grad():
                    test_image, test_label = next(iter(test_loader))  # iter(test_loader)作用是设定一个迭代器,这行代码的作用是取出验证集中的一个batch_size的图片和对应的标签。
                    test_image, test_label = test_image.to(device), test_label.to(device)  # 将数据移动到 GPU 上
                    outputs = net(test_image)
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, test_label).sum().item() / test_label.size(0)  # .item() 方法将结果转换为标量,即 Python 中的普通数字类型。
                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0
        print(f"The epoc is {epoch}")
    print("Finish Training")
    save_path = "./LeNet.pth"
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()

3. 验证程序

验证程序,首先需要加载图片,然后进行转换(包括裁剪为模型的输入形状大小【这里为 32 × 32 32 \times 32 32×32】,然后转换为tensor类型,最后进行归一化),将预处理后的图片送入到模型中,模型输出的是一个batch_size个一维向量,每一个一维向量有10个数,表示输出的类别一共有10个,取10个中值最大的数的索引作为预测的类别,可以使用以下代码:predict = torch.max(outputs, dim=1)[1].numpy(),这表示在模型输出的结果中,取第一个维度上的10个数取最大值的索引,并将其转换为numpy类型的数据。然后将这个数对照标签的映射关系,可以得到最终预测的类别。

varify.py

import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LeNet


def main():
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    net.load_state_dict(torch.load('LeNet_250.pth'))

    im = Image.open('2.jpg')  # 加载图片
    im = transform(im)  # [C, H, W]
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

    with torch.no_grad():  # 用于设置在该上下文中不进行梯度计算,因为推断时不需要计算梯度,可以提高计算效率。
        outputs = net(im)
        predict = torch.max(outputs, dim=1)[1].numpy()
    print(classes[int(predict)])


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1658991.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3分钟快速了解VR全景编辑器

说到VR全景&#xff0c;想必大多数人都见过那种可以360旋转拖动观看的图片。虽然这种技术已经不算新鲜&#xff0c;如果你以为这就是VR全景的全部&#xff0c;那就大错特错了&#xff01; 上面看到的这种形式&#xff0c;只能算VR全景的第一层形态。现在的VR全景已经发展成为了…

vue+canvas实现根据数据展示不同高度,不同j渐变颜色的长方体效果

文章目录 不一样的长方体1. 实现效果预览2.实现思路2.1效果难点2.2 实现思路 3.实现3.1 测试数据编写3.2 编写canvas绘制函数 不一样的长方体 1. 实现效果预览 俗话说的好&#xff0c;没有实现不了的页面效果&#xff0c;只有禁锢的思想&#xff0c; 这不ui又给整了个新奇的页…

模型查询器在使用别名后不能使用tp6

在我们定义了模型的查询器时&#xff0c;再通过模型进行连表加别名的时候&#xff0c;使用查询器&#xff0c;查询器会没办法使用&#xff1b; 那我们可以将查询器前缀增加表名或者__TABLE__ 以上两种方式都可以&#xff0c;个人建议使用__TABLE__&#xff0c;因为这个查询器可…

单单单单单の刁队列

在数据结构的学习中&#xff0c;队列是一种常用的线性数据结构&#xff0c;它遵循先进先出&#xff08;FIFO&#xff09;的原则。而单调队列是队列的一种变体&#xff0c;它在特定条件下保证了队列中的元素具有某种单调性质&#xff0c;例如单调递增或单调递减。单调队列在处理…

Linux -- > vim

vi和vim是什么 vi和vim是两款流行的文本编辑器&#xff0c;广泛用于Unix和类Unix系统中。它们以其强大的功能和灵活的编辑能力而闻名&#xff0c;特别是在编程和系统管理中非常受欢迎。 vi&#xff08;Visual Interface&#xff09; vi是最初的文本编辑器之一&#xff0c;由…

AI赋能EasyCVR视频汇聚/视频监控平台加快医院安防体系数字化转型升级

近来&#xff0c;云南镇雄一医院发生持刀伤人事件持续发酵&#xff0c;目前已造成2人死亡21人受伤。此类事件在医院层出不穷&#xff0c;有的是因为医患纠纷、有的是因为打架斗殴。而且在每日大量流动的人口中&#xff0c;一些不法分子也将罪恶的手伸到了医院&#xff0c;实行扒…

不要错过!实景三维倾斜摄影在3D引擎的丝滑用法

在3D领域&#xff0c;倾斜摄影模型的应用是一个常见的瓶颈。工程建设、工业制造、科学分析、古建遗产&#xff0c;倾斜摄影是占主导地位的处理对象&#xff0c;但模型数据量大、精度要求高以及线上线下同步困难等&#xff0c;会导致生成的三维项目出现瑕疵。 所以在行业内&…

Electron学习笔记(二)

文章目录 相关笔记笔记说明 三、引入现代前端框架1、配置 webpack&#xff08;1&#xff09;安装 webpack 和 electron-webpack&#xff1a;&#xff08;2&#xff09;自定义入口页面 2、引入 Vue&#xff08;1&#xff09;安装 Vue CLI &#xff08;2&#xff09;调试配置 -- …

【解决】Android APK文件安装时 已包含数字签名相同APP问题

引言 在开发Android程序过程中&#xff0c;编译好的APK文件&#xff0c;安装至Android手机时&#xff0c;有时会报 包含数字签名相同的APP 然后无法安装的问题&#xff0c;这可能是之前安装过同签名的APP&#xff0c;但是如果不知道哪个是&#xff0c;无法有效卸载&#xff0c;…

KaiwuDB 参编的《分析型数据库技术要求》标准正式发布

近期&#xff0c;中国电子工业标准化技术协会正式发布团体标准《分析型数据库技术要求》&#xff08;项目号&#xff1a;T-CESA 2023-006&#xff09;。该标准由中国电子技术标准化研究院、KaiwuDB&#xff08;上海沄熹科技有限公司&#xff09; 等国内 16 家企业联合起草&…

婚恋程序_婚恋系统_交友程序_ 婚恋相亲交友系统-一站式搭建婚恋平台-社交婚恋系统-相亲交友APP小程序H5系统婚恋交友社交软件开发语音视频聊天平台定制开发

快速搭建线上平台 赋予十大线上盈利 快速精准牵线匹配 会员资料管理跟进 精美多样海报系统 红娘独立办公系统 丰富拓客引流工具 合伙红娘拓展客源 可多区域连锁运营 外呼电销到店邀约 线下约见服务管理 1对1技术服务支持 无感自动更新升级 行业领先的研发技术与服…

武汉凯迪正大—钢管焊缝裂纹探伤仪

产品概述 武汉凯迪正大无损探伤仪是一种便携式工业无损探伤仪器&#xff0c; 能够快速便捷、无损伤、精确地进行工件内部多种缺陷&#xff08;裂纹、夹杂、气孔等&#xff09;的检测、定位、评估和诊断。既可以用于实验室&#xff0c;也可以用于工程现场。 设置简单&#xff0c…

Swift 集合类型

集合类型 一、集合的可变性二、数组&#xff08;Arrays&#xff09;1、数组的简单语法2、创建一个空数组3、创建一个带有默认值的数组4、通过两个数组相加创建一个数组5、用数组字面量构造数组6、访问和修改数组7、数组的遍历 三、集合&#xff08;Sets&#xff09;1、集合类型…

IDEA 使用maven编译,控制台出现乱码问题的解决方式

前言 使用idea进行maven项目的编译时&#xff0c;控制台输出中文的时候出现乱码的情况。 通常出现这样的问题&#xff0c;都是因为编码格式不一样导致的。既然是maven出的问题&#xff0c;我们在idea中查找下看可以如何设置文件编码。 第一种方式 在pom.xml文件中&#xff…

LeetCode-2079. 给植物浇水【数组 模拟】

LeetCode-2079. 给植物浇水【数组 模拟】 题目描述&#xff1a;解题思路一&#xff1a;简单的模拟题&#xff0c;初始化为0&#xff0c;考虑先不浇灌每一个植物解题思路二&#xff1a;初始化为n&#xff0c;考虑每一个植物需要浇灌解题思路三&#xff1a;0 题目描述&#xff1a…

2024车载测试还有发展吗?

2024年已过接近1/4了,你是不是还在围观车载测试行业的发展? 现在入车载测试还来得及吗? 如何高效学习车载测试呢? 首先我们看一下车载测试行情发展,通过某大平台,我们后去数据如下: 这样的数据可以预估一下未来车载测试还是会持续发展. 随着科技的发展和汽车行业的不断创新,…

Python import 必看技巧:打造干净利落的代码结构

大家好,学习Python你肯定绕不过一个概念import,它是连接不同模块的桥梁,是实现代码复用和模块化的关键。本文将带你深入探索Python中import的原理,并分享一些实用的导入技巧。 1. import 原理 导入机制概述 在Python中,模块(module)是一种封装Python代码的方式,它允许…

华为eNSP学习—IP编址

IP编址 IP编址子网划分例题展示第一步:机房1的子网划分第二步:机房2的子网划分第三步:机房3的子网划分IP编址 明确:IPv4地址长度32bit,点分十进制的形式 ip地址构成=网络位+主机位 子网掩码区分网络位和主机位 学此篇基础: ①学会十进制与二进制转换 ②学会区分网络位和…

宋仕强论道之新质生产力

宋仕强论道之新质生产力&#xff0c;宋仕强说当前5G通信、人工智能、万物互联、工业互联网、数字经济、新能源技术和产业等领域正蓬勃发展&#xff0c;成为未来经济增长的重要推动力&#xff0c;也是目前提倡的新质生产力的重要组成部分。而这些领域的发展都离不开数据的采集、…

每日一题7:Pandas-重命名列

一、每日一题 编写一个解决方案&#xff0c;按以下方式重命名列&#xff1a; id 重命名为 student_idfirst 重命名为 first_namelast 重命名为 last_nameage 重命名为 age_in_years 返回结果格式如下示例所示。 解答&#xff1a; import pandas as pddef renameColumns(studen…