【深度学习】LeNet网络架构

news2024/11/26 16:34:05

文章目录

  • 什么是LeNet
  • 代码实现网络架构


什么是LeNet

LeNet是一种经典的卷积神经网络,由Yann LeCun等人在1998年提出。它是深度学习中第一个成功应用于手写数字识别的卷积神经网络,并且被认为是现代卷积神经网络的基础。

LeNet模型包含了多个卷积层和池化层,以及最后的全连接层用于分类。其中,每个卷积层都包含了一个卷积操作和一个非线性激活函数,用于提取输入图像的特征。池化层则用于缩小特征图的尺寸,减少模型参数和计算量。全连接层则将特征向量映射到类别概率上。

在这里插入图片描述


代码实现网络架构

如何搭建网络模型参考博客:Pytorch学习笔记(模型训练)
在这里插入图片描述我们采用CIFAR-10数据集进行训练测试,上面网络模型是1个channel的32x32,而我们的数据集是3个channel的32x32,模型结构不变,改变一下输入输出大小。
model.py:

import torch
from torch import nn


# 搭建网络模型
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(32 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        x = self.model(x)
        return x


# 测试
if __name__ == '__main__':
    leNet = LeNet()
    input = torch.ones((64, 3, 32, 32))
    output = leNet(input)
    print(output.shape)

train.py

import torch.optim
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from learning.lenet.model import LeNet

# 1. 数据集
dataset_train = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(),
                                             download=True)
dataset_test = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(),

                                            download=True)
train_data_size = len(dataset_train)
test_data_size = len(dataset_test)
# 2. 加载数据集
dataloader_train = DataLoader(dataset_train, batch_size=64)
dataloader_test = DataLoader(dataset_test, batch_size=64)

# 3. 搭建model
leNet = LeNet()
if torch.cuda.is_available():
    leNet = leNet.cuda()

# 4. 创建损失函数
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()

# 5. 优化器
learning_rate = 0.1
optimizer = torch.optim.SGD(leNet.parameters(), lr=learning_rate)  # 随机梯度下降

# 6. 设置训练网络的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 训练测试次数
epoch = 5  # 训练轮数

# 补充tensorboard
writer = SummaryWriter("../../logs")

# 开始训练
for i in range(epoch):
    print(f"--------第{i+1}轮训练开始--------")
    # 训练
    leNet.train()
    for data in dataloader_train:
        imgs, targets = data
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            targets = targets.cuda()
        outputs = leNet(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step}---loss:{loss.item()}")
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试
    leNet.eval()
    total_test_loss = 0  # 总体的误差
    total_accuracy = 0  # 总体的正确率
    with torch.no_grad():
        for data in dataloader_test:
            imgs, targets = data
            if torch.cuda.is_available():
                imgs = imgs.cuda()
                targets = targets.cuda()
            outputs = leNet(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    print(f"整体测试集上的loss:{total_test_loss}")
    print(f"整体测试集上的准确率:{total_accuracy/test_data_size}")
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("total_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step += 1

    # 保存每一轮训练的模型
    torch.save(leNet, f"leNet_{i+1}.pth")
    print("模式已保存")


writer.close()

在这里插入图片描述

5轮训练中,第5轮的准确率是最高的,采用第5轮的模型进行测试:

test.py

import torch
import torchvision.transforms
from PIL import Image

from learning.lenet.model import LeNet


# 需要测试的图片
image_path = "../../imgs/airplane.png"
image = Image.open(image_path)
image = image.convert('RGB')  # png图片多了一个透明度通道,修改成rgb三个通道
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)
print(image.shape)


# 引入网络架构



# 读取网络模型  如果保存的模型是通过gpu训练出来的,需要添加 map_location=torch.device("cpu")
model_load = torch.load("leNet_5.pth", map_location=torch.device("cpu"))
# 原有的图片是没有bitch-size的,而我们的输入是需要的
image = torch.reshape(image, (1, 3, 32, 32))
model_load.eval()
with torch.no_grad():
    outputs = model_load(image)
print(outputs)

classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

print(classes[outputs.argmax(1)])

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1029121.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JOSEF约瑟 剩余电流继电器PFR-5 PFE-W-20 国产化改造ZLR-G81 ZCT-45

系列型号: PFR-003剩余电流继电器 PFR-03剩余电流继电器 PFR-5剩余电流继电器 PFR-W-105互感器 PFR-W-140互感器 PFR-W-20互感器 PFR-W-210互感器 PFR-W-30互感器 PFR-W-35互感器 PFR-W-70互感器 一、用途 PFR剩余电流继电器(以下简称继电器…

【教学类】小2班学号字帖(A4横版2份)

图片展示: 背景需求: 突然接到通知,明天下午临时去带小2班。 小班刚入园的孩子,能给他们提供什么样的可操作的学具呢? 思来想去,还是让生成一份学号字帖,让幼儿熟悉自己的学号,让我也熟悉幼儿的名字和学…

phpstudy2016 RCE漏洞验证

文章目录 漏洞描述漏洞验证 漏洞描述 PHPStudyRCE(Remote Code Execution),也称为phpstudy_backdoor漏洞,是指PHPStudy软件中存在的一个远程代码执行漏洞。 漏洞验证 打开phpstudy2016,用bp自带的浏览器访问www目录下…

可视化报表设计器的功能内容是什么?

当前,随着社会化发展程度越来越深,传统的表单制作方式已经无法满足需求了,此时,低代码技术平台的出现,可以在一定程度上帮助不同行业的客户实现流程化办公管理,从而实现提质增效的办公效率。 可视化报表设计…

Yolov8小目标检测-添加模块改进-实验记录

简介,本文通过结合了一些先进的算法改进了yolov8小目标检测能力,以下是一些记录。 数据集:足球比赛数据集,里面只有两个类别足球和人。 兄弟姐妹们,如果本文对你有用,点赞收藏一下呗,☺️☺️…

云可观测性:提升云环境中应用程序可靠性

随着云计算的兴起和广泛应用,越来越多的企业将其应用程序和服务迁移到云环境中。在这个高度动态的环境中,确保应用程序的可靠性和可管理性成为了一个迫切的需求。云可观测性作为一种解决方案,针对这一需求提供了有效的方法和工具。本文将介绍…

单臂路由的配置

目录 单臂路由 单臂路由是什么 为什么要用单臂路由 单臂路由的注意事项 单臂路由的原理 单臂路由的优缺点 单臂路由的实验 ensp Cisco H3C 单臂路由是什么 单臂路由是一种特殊的路由器,它的设计目的是实现在一个路由器的一个接口上通过配置子接口&#xf…

森林防火可视化智能监管与风险预警系统解决方案

一、方案背景 森林火灾是世界八大自然灾害之一,具有发生面广、突发性强、破坏性大、危险性高、处置扑救特别困难等特点,严重危及人民生命财产和森林资源安全,甚至引发生态灾难。有效预防和及时控制森林火灾是保护国家生态建设成果、推进生态…

【实战案例】技术转项目经理容易踩的坑,我都踩了

“带团队容易,带好团队难。” 这是身边一位项目经理近期在团队管理方面的深刻感悟。目前,他手上的一个项目被迫暂停了,项目团队也散了。下面给大家简要分享下这个项目案例。 【案例分享】 小李负责的是一个二次开发的项目,所涉及…

新型智慧公厕“1+3+N”架构,平台、系统、应用的创新

近年来,随着人民生活水平的提高,人们对公共设施的要求也越来越高。其中,如厕问题一直是人们关注的焦点,但传统的公厕设施已经不能满足人们对干净、舒适、安全的需求,这促使了新型智慧公厕的诞生与应用,以如…

Puppeter与Electron的结合,使用Electron创建可视化界面

前言 上一篇文章:Puppeteer基础入门、常见应用、利用谷歌插件编写Puppeteer脚本,简单介绍了Puppeteer的基本使用,以及如何编写一个脚本。 但是呢脚本的运行需要在node环境里,开发人员可能没什么问题。但是如果你写的这个脚本要给…

Aspose转pdf乱码问题

一、问题描述 ​ 在centos服务器使用aspose.word转换word文件为pdf的时候显示中文乱码(如图),但是在win服务器上使用可以正常转换 二、问题原因 由于linux服务器缺少对应的字库导致文件转换出现乱码的 三、解决方式 1.将window中字体(c:\windows\fonts)放到linux…

软件过程能力成熟度评估——CSMM认证

CSMM认证又称为“软件过程能力过程成熟度评估”,由中国电子技术标准化研究院联合五十余家产学研用相关方结合我国实际,自主制定的团体标准,于2021年6月8号发布,目的是为了帮助国内软件企业对自身的软件能力进行评估和判断&#xf…

Redis实战(10)-一条命令在Redis是如何执行的?

Redis Server一旦和某客户端建立连接,就会在事件驱动框架中注册可读事件,对应客户端的命令请求。 整个命令处理过程可分阶段: 命令解析,processInputBufferAndReplicate命令执行,processCommand结果返回,…

APEX:开启Android系统新篇章的应用扁平化技术

APEX:开启Android系统新篇章的应用扁平化技术 Android Pony Express (APEX) 是在 Android Q 中引入的一种容器格式,用于安装流程中较低级系统模块的更新。该格式方便了系统组件的更新,这些组件不适合标准的 Android 应用程序模型。一些示例组…

计算机系大学生,可以通过Java做什么副业?这篇文章给你答案!

文章目录 前言发现副业机会提高效率面向人群 如何开启自己的副业价格优势需要课设的人多吗怎么宣传生成器的使用 生成器介绍安装功能介绍文档查询功能生成的JavaWeb系统示例生成的C#生成的Javaswing生成的VueER图 、UML、功能图..生成的C、C系统 前言 计算机系科班出身的学生&a…

轻量服务器2核与1核的区别

​ 1.核心数量 轻量服务器2核与1核最明显的区别在于核心数量。1核服务器只有一个处理器核心,而2核服务器有两个处理器核心。这使得2核服务器在处理数据时能够同时执行更多的任务。 2.并行处理能力 由于只有1个核心,1核服务器不具备并行处理任务的能力。而…

天津专升本文化课考试计算机应用基础考试大纲(2023年9月修订)

天津市高等院校“高职升本科”招生统一考试计算机应用基础考试大纲(2023年9月修订) 一、考试性质 天津市高等院校“高职升本科”招生统一考试是由合格的高职高专毕业生参加的选拔性 考试。高等院校根据考生的成绩,按照已确定的招生计划&am…

低功耗蓝牙物联网:未来连接的无限可能

物联网是连接各种设备和传感器的网络,其目的是实现信息的交换和共享,提高效率并优化生活。在这个领域,低功耗蓝牙(BLE)正在发挥着越来越重要的作用。 低功耗蓝牙是一种无线通信技术,它的主要特点是低功耗和…

HarmonyOS之 组件的使用

一 容器 1.1 容器分类 Column表示沿垂直方向布局的容器。Row表示沿水平方向布局的容器。 1.2 主轴和交叉轴 主轴:在Column容器中的子组件是按照从上到下的垂直方向布局的,其主轴的方向是垂直方向;在Row容器中的组件是按照从左到右的水平方向…