Pytorch之LeNet-5图像分类

news2025/1/1 23:24:47
  • 💂 个人主页:风间琉璃
  • 🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主
  • 💬 如果文章对你有帮助、欢迎关注、点赞、收藏(一键三连)订阅专栏

目录

前言 

一、LeNet-5

二、LeNet-5网络实现

1.定义LeNet-5模型

2.加载数据集

3.训练模型

4.测试模型

三、实现图像分类


前言 

 LeNet-5是一个经典的深度卷积神经网络,由Yann LeCun在1998年提出,旨在解决手写数字识别问题,被认为是卷积神经网络的开创性工作之一。该网络是第一个被广泛应用于数字图像识别的神经网络之一,也是深度学习领域的里程碑之一。

一、LeNet-5

下图是 LeNet-5 的网络结构图,它 接受32×32大小的数字、字符图片,经过第一个卷积层得到[b,6, 28,28]形状的张量,经过一个向下采样层,张量尺寸缩小到[b,6,14,14],经过第二个卷积层,得到[b,16,10,10]形状的张量,同样经过下采样层,张量尺寸缩小到[b,16, 5,5],在进入全连接层之前,先将张量 打成[b,16*5*5 ]的张量,送入输出节点数分别为 120、84 的 2 个全连接层,得到[b,84]的张量,最后通过Gaussian connections层,最终输出[b,10]

LeNet-5的基本结构包括7层网络结构(不含输入层),其中包括2个卷积层、2个降采样层(池化层)、2个全连接层和输出层。LeNet-5 网络层数较少(2 个卷积层和 2 个全连接层),参数量较少,计算代价较低,尤其在现代GPU的加持下,数分钟即可训练好 LeNet-5 网络。 

这里网络结构只给了进行卷积核池化前后的特征图的大小,那么如果确定卷积核的尺寸和通道数呢?

1.输入特征层的channel与卷积核的channel相同

2.输出特征层的channel与卷积核个数相同

经过卷积后的矩阵尺寸大小计算公式为:

N = (W - F + 2P) /  S  +1

①输入图片大小WxW

②卷积核Filter大小FxF

③步长S

④panding填充值P

比如输入层接收大小为 32×32 的手写数字图像,卷积层C1包括6个卷积核,每个卷积核的大小为 5×5 ,步长为1,填充为0。因此,每个卷积核会产生一个大小为 28×28 的特征图(输出通道数为6)。

N(28) = (32-5+0)/1 + 1 =27 + 1 = 28

采样层S2采用最大池化(max-pooling)操作,每个窗口的大小为 2×2 ,步长为2。因此,每个池化操作会从4个相邻的特征图中选择最大值,产生一个大小为 14×14 的特征图(输出通道数为6)。这样可以减少特征图的大小,提高计算效率,并且对于轻微的位置变化可以保持一定的不变性。其他的网络层也是一样的,可以相互推算。

二、LeNet-5网络实现

1.定义LeNet-5模型

根据上面网络模型使用Pytorch实现LeNet-5网络模型的搭建

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))     # input(3, 32, 32) output(6, 28, 28)
        x = self.pool1(x)             # output(6, 14, 14)
        x = F.relu(self.conv2(x))     # output(16, 10, 10)
        x = self.pool2(x)             # output(16, 5, 5)
        x = x.view(-1, 16*5*5)        # output(16*5*5)
        x = F.relu(self.fc1(x))       # output(120)
        x = F.relu(self.fc2(x))       # output(84)
        x = self.fc3(x)                # output(10)
        return x

if __name__ == '__main__':
    net = LeNet()
    print(net)

2.加载数据集

使用CIFAR10数据集,加载数据集后还需要对数据集进行预处理,如图像格式转换(Tensor)、归一化、标准化等处理。然后使用DataLoader分批次加载数据集,用于训练和测试。

# 预处理
    transform = transforms.Compose(
        [transforms.ToTensor(),  # 将图像转化为tensor,并做归一化:[0,1] 数据类型转换 + 标准化
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 输入数据的数值范围标准化为特定的均值和标准差
         ]
    )

    # 加载训练集
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36, shuffle=True, num_workers=0)
    # 加载测试集
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000, num_workers=0)

    # 使用next函数从val_data_iter迭代器中获取下一个批次的数据
    val_data_iter = iter(val_loader)
    val_image, val_label = next(val_data_iter)

3.训练模型

实例化网络模型,并进行网络模型的训练。

    net = LeNet()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for epoch in range(10):  # 训练次数
        # 每次训练的损失值
        running_loss = 0.0
        # 获取批次的索引 step 和数据 data
        for step, data in enumerate(train_loader, start=0):
            # 获取images,labels; data是一个列表[images, labels]
            images, labels = data

            # 将优化器的梯度缓冲区清零
            optimizer.zero_grad()
            # forward + backward + optimize
            # 前向传播,得到模型的输出
            outputs = net(images)
            # 计算模型的输出和真实标签 labels 之间的损失(误差)
            loss = loss_function(outputs, labels)
            # 通过反向传播算法计算损失对模型参数的梯度
            loss.backward()
            # 根据梯度更新模型参数,这是优化器的一次参数更新步骤
            optimizer.step()

4.测试模型

在每训练到500次时,进行一次测试。

            # 测试
            running_loss += loss.item()
            if step % 500 == 499:
                # 关闭梯度计算。因为在验证或测试时不需要计算梯度,所以可以提高运行效率
                with torch.no_grad():
                    outputs = net(val_image)  # [batch, 10]
                    # 选择输出中概率最高的类别作为预测结果,并且是在第一个维度[batch,10]
                    # max 返回找到最大的值以及该值所在的位置(索引),是一个元组(val ,index)
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)

                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0

在网络训练完成后,记得保存网络模型,用于后续的部署和使用。

save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

三、实现图像分类

将上面保存的模型用来测试其他图片,检验模型训练的效果。

import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
from model import LeNet

def main():
    # 图片预处理
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # 分类标签
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # s实例化网络
    net = LeNet()
    # 加载网络模型
    net.load_state_dict(torch.load('Lenet.pth'))

    img = Image.open('dog.jpg')
    # [H, W, C] --> [C, H, W]
    image = transform(img)
    # 增加维度:[N, C, H, W],使满足网络的输入维度要求
    image = torch.unsqueeze(image, dim=0)

    with torch.no_grad():
        # 得到预测结果
        outputs = net(image)
        # 得到分类标签
        predict = torch.max(outputs, dim=1)[1].numpy()
    print(classes[int(predict)])
    draw = ImageDraw.Draw(img)
    text = classes[int(predict)]
    # 文本的左上角位置
    position = (10, 10)
    # fill 指定文本颜色
    draw.text(position, text, fill='red')
    img.show()

if __name__ == '__main__':
    main()

预测结果:

结束语
感谢你观看我的文章呐~本次航班到这里就结束啦 🛬

希望本篇文章有对你带来帮助 🎉,有学习到一点知识~

躲起来的星星🍥也在努力发光,你也要努力加油(让我们一起努力叭)。

最后,博主要一下你们的三连呀(点赞、评论、收藏),不要钱的还是可以搞一搞的嘛~

不知道评论啥的,即使扣个666也是对博主的鼓舞吖 💞 感谢 💐

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1037976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决apk签名时 no conscrypt_openjdk_jni in java.library.path 方法

使用下面命令时若出现no conscrypt_openjdk_jni in java.library.path java -jar signapk.jar platform.x509.pem platform.pk8 app-debug.apk app-debug_sign.apk 缺少相关库,从以下位置下载,只在 android11下测试通过。 https://download.csdn.net…

2023 年前端 UI 组件库概述,百花齐放!

UI组件库提供了各种常见的 UI 元素,比如按钮、输入框、菜单等,只需要调用相应的组件并按照需求进行配置,就能够快速构建出一个功能完善的 UI。 虽然市面上有许多不同的UI组件库可供选择,但在2023年底也并没有出现一两个明确的解决…

java面试题-常见技术场景

常见技术场景 1.单点登录这块怎么实现的 1.1 概述 单点登录的英文名叫做:Single Sign On(简称SSO),只需要登录一次,就可以访问所有信任的应用系统 在以前的时候,一般我们就单系统,所有的功能都在同一个…

EtherCAT转Modbus网关做为 MODBUS 从站配置案例

兴达易控EtherCAT转Modbus网关可以用作MODBUS从站的配置。这种网关允许将Modbus协议与EtherCAT协议进行转换,从而实现不同通信系统之间的互操作性。通过将Modbus从站配置到网关中,可以实现对Modbus设备的访问和控制。同时,该网关还可以扩展Mo…

mysql基本语句学习(基本)

1.本地登录 mysql -u root -p 密码 mysql开启远程 1.查看数据库 show databases; 2.查看当前所示数据库 select database(); 3.创建数据库 create database 数据库名字; 4.查看创建数据库语句 show create database 数据库名字; 2.…

(十一)VBA常用基础知识:worksheet的各种操作之sheet删除

当前sheet确认 2.Sheets(1).Delete Sub Hello()8 Sheets(1).DeleteSheets(1).Delete End Sub实验得知, Sheets(1).Delete删除的是最左边的sheet 另外,因为有弹出提示信息的确认框,这个在代码执行时,会导致还需要手动点击一下&a…

仿制 Google Chrome 的恐龙小游戏

通过仿制 Google Chrome 的恐龙小游戏,我们可以掌握如下知识点: 灵活使用视口单位掌握绝对定位JavaScript 来操作 CSS 变量requestAnimationFrame 函数的使用无缝动画实现 页面结构 实现页面结构 通过上述的页面结构我们可以知道,此游戏中…

【多态】虚函数表存储在哪个区域?

A:栈 B:堆 C:代码段&#xff08;常量区&#xff09; D:数据段&#xff08;静态区&#xff09; 答案 &#xff1a; 代码段&#xff08;常量区&#xff09; 验证如下&#xff1a; class Person { public:virtual void BuyTicket() { cout << "Person::BuyTicket()&q…

【Hash表】判断有没有重复元素-力扣 217

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

认识HTTP和HTTPS协议

HTTPS 是什么 HTTPS 也是一个应用层协议. 是在 HTTP 协议的基础上引入了一个加密层. 为什么要引入加密层呢&#xff1f; HTTP 协议内容都是按照文本的方式明文传输的. 这就导致在传输过程中出现一些被篡改的情况. HTTPS就是在HTTP的基础上进行了加密&#xff0c;进一步的保…

群体遗传学-选择消除分析

一、选择消除分析 所谓选择性清除&#xff1a;当一个有利突变发生后&#xff0c;这个突变基因的适合度越高&#xff0c;就越容易被选择固定。当这个基因被快速固定之后&#xff0c;与此基因座连锁的染色体区域&#xff0c;由于搭车效应也被固定下来&#xff0c;大片紧密连锁的染…

【跟小嘉学习区块链】二、Hyperledger Fabric 架构详解

系列文章目录 【跟小嘉学习区块链】一、区块链基础知识与关键技术解析 【跟小嘉学习区块链】一、区块链基础知识与关键技术解析 文章目录 系列文章目录[TOC](文章目录) 前言一、Hyperledger 社区1.1、Hyperledger(面向企业的分布式账本)1.2、Hyperledger社区组织结构 二、Hype…

UDS 28服务

28服务主要是用来控制报文接收和发送。 具体的服务控制格式&#xff1a; controlType 通信控制类型 tips&#xff1a;Bit7 用于是否抑制积极响应。 communication 报文类型 例子

Mysql 数据类型、运算符

数据类型 数据类型的选择不是越大越好&#xff0c;因为我们业务层一般都是在内存上工作的&#xff0c;效率以及速度是比较快的&#xff0c;但是我们的数据库涉及磁盘的IO操作磁盘的IO操作相对来说是要慢很多的&#xff0c;所以我们在定义表结构的时候每一个字段的数据类型还是比…

API网关是如何提升API接口安全管控能力的

API安全的重要性 近几年&#xff0c;越来越多的企业开始数字化转型之路。数字化转型的核心是将企业的服务、资产和能力打包成服务&#xff08;服务的形式通常为API&#xff0c;API又称接口&#xff0c;下文中提到的API和接口意思相同&#xff09;&#xff0c;从而让资源之间形…

计算机组成原理课程设计

操作控制和顺序控制 操作控制就是由各种微命令来构成的顺序控制就是由P测试和后续微地址构成的 这就构成了整个微指令的三个部分 访存指令就是实现对主存中的数据进行访问或存储 一、 操作控制字段是由各种微命令来构成的&#xff0c;这些微命令怎么来设计&#xff1f; 一个萝卜…

全新贝锐蒲公英客户端6.0:如何实现快速部署、高效异地组网?

贝锐蒲公英客户端6.0版本进行了全新的升级&#xff0c;此次升级对原有企业版、个人版和个人管理端进行了深度整合&#xff0c;不同身份的用户现在可以统一登录&#xff0c;大大简化了异地组网的流程&#xff0c;同时提升了效率。那么贝锐蒲公英客户端6.0&#xff0c;做了哪些深…

Cortex-M3/M4之SVC和PendSV异常

一、SVC异常 SVC(系统服务调用&#xff0c;亦简称系统调用)用于产生系统函数的调用请求。例如&#xff0c;操作系统不让用户程序直接访问硬件&#xff0c;而是通过提供一些系统服务函数&#xff0c;用户程序使用 SVC 发出对系统服务函数的呼叫请求&#xff0c;以这种方法调用它…

更新至2022年上市公司ESG评级评分数据合集(含华证、盟浪、wind、彭博、润灵环球、商道融绿、和讯网、富时罗素数据)

更新至2022年ESG评级评分数据合集&#xff08;含华证、盟浪、wind、彭博、润灵环球、商道融绿、和讯网、富时罗素及世界各国ESG数据&#xff09; 1、来源&#xff1a;整理自wind和csmar 2、具体时间&#xff1a; 华证&#xff1a;2009-2022年、盟浪&#xff1a;2018-2022年、…

Python实现猎人猎物优化算法(HPO)优化LightGBM分类模型(LGBMClassifier算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 猎人猎物优化搜索算法(Hunter–prey optimizer, HPO)是由Naruei& Keynia于2022年提出的一种最新的…