CNN——LeNet

news2024/12/25 9:26:13

1.LeNet概述       

         LeNet是Yann LeCun于1988年提出的用于手写体数字识别的网络结构,它是最早发布的卷积神经网络之一,可以说LeNet是深度CNN网络的基石。

        当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet当时被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。

        下面是整个网络的结构图

        LeNet共有8层,其中包括输入层,3个卷积层,2个子采样层(也就是现在的池化层),1个全连接层和1个高斯连接层。

        上图中用C代表卷积层,用S代表采样层,用F代表全连接层。输入size固定在1*32*32,LeNet图片的输入是二值图像。网络的输出为0~9十个数字的RBF度量,可以理解为输入图像属于0~9数字的可能性大小。

2.详解LeNet

下面对图中每一层做详细的介绍:

  • LeNet使用的卷积核大小都为5*5,步长为1,无填充,只是卷积深度不一样(卷积核个数导致生成的特征图的通道数)
  • 激活函数为Sigmoid
  • 下采样层都是使用最大池化实现,池化的核都为2*2,步长为2,无填充

        input输入层,尺寸为1*32*32的二值图

        C1层是一个卷积层。该层使用6个卷积核,生成特征图尺寸为32-5+1=28,输出为6个大小为28*28的特征图。再经过一个Sigmoid激活函数非线性变换。

        S2层是一个下采样层。生成特征图尺寸为28/2=14,得到6个14*14的特征图。

        C3层是一个卷积层,该层使用16个卷积核,生成特征图尺寸为14-5+1=10,输出为16个10*10的特征图。再经过一个Sigmoid激活函数非线性变换。

        S4层是一个下采样层,生成特征图尺寸为10/2=5,得到16个5*5的特征图

        C5层是一个卷积层,卷积核数量增加至120。生成特征图尺寸为5-5+1=1。得到120个1*1的特征图。这里实际上相当于S4全连接了,但仍将其标为卷积层,原因是如果LeNet-5的输入图片尺寸变大,其他保持不变,那该层特征图的维数也会大于1*1,那就不是全连接了。再经过一个Sigmoid激活函数非线性变换。

        F6层是一个全连接层,该层与C5层全连接,输出84张特征图。再经过一个Sigmoid激活函数非线性变换。

        输出层:输出层由欧式径向基函数(高斯)单元组成,每个类别(0~9数字)对应一个径向基函数单元,每个单元有84个输入。也就是说,每个输出RBF单元计算输入向量和该类别标记向量之间的欧式距离,距离越远,PRF输出越大,同时我们也会将与标记向量欧式距离最近的类别作为数字识别的输出结果。当然现在通常使用的Softmax实现

3.使用LeNet实现Mnist数据集分类

1.导入所需库

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm # 显示训练进度条

2.使用GPU

device = 'cuda' if torch.cuda.is_available() else 'cpu'

3.读取Mnist数据集

# 定义数据转换以进行数据标准化
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为 PyTorch 张量
])

# 下载并加载 MNIST 训练和测试数据集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)

# 创建数据加载器以批量加载数据
batch_size = 256
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

4.搭建LeNet

        需要注意的是torch.nn.CrossEntropyLoss自带了softmax函数,所以最后一层使用全连接即可,在训练时使用torch.nn.CrossEntropyLoss

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2) # Mnist尺寸为28*28,这里设置填充变成32*32
        self.sigmoid = nn.Sigmoid()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(self.sigmoid(self.conv1(x)))
        x = self.pool(self.sigmoid(self.conv2(x)))
        x = self.flatten(x)
        x = self.sigmoid(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        x = self.fc3(x)
        return x
# 实例化模型
model = LeNet().to(device)
summary(model, (1, 28, 28))

5.训练函数

def train(model, lr, epochs):
    # 将模型放入GPU
    model = model.to(device)
    # 使用交叉熵损失函数
    loss_fn = nn.CrossEntropyLoss().to(device)
    # SGD
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # 记录训练与验证数据
    train_losses = []
    train_accuracies = []
    # 开始迭代   
    for epoch in range(epochs):   
        # 切换训练模式
        model.train()  
        # 记录变量
        train_loss = 0.0
        correct_train = 0
        total_train = 0
        # 读取训练数据并使用 tqdm 显示进度条
        for i, (inputs, targets) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}", unit='batch'):
            # 训练数据移入GPU
            inputs = inputs.to(device)
            targets = targets.to(device)
            # 模型预测
            outputs = model(inputs)
            # 计算损失
            loss = loss_fn(outputs, targets)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 使用优化器优化参数
            optimizer.step()
            # 记录损失
            train_loss += loss.item()
            # 计算训练正确个数
            _, predicted = torch.max(outputs, 1)
            total_train += targets.size(0)
            correct_train += (predicted == targets).sum().item()
        # 计算训练正确率并记录
        train_loss /= len(train_dataloader)
        train_accuracy = correct_train / total_train
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # 输出训练信息
        print(f"Epoch [{epoch + 1}/{epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
    # 绘制损失和正确率曲线
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(epochs), train_losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(range(epochs), train_accuracies, label='Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.show()

6.模型训练

model = LeNet()
lr = 0.9 # sigmoid两端容易饱和,gradient比较小,学得比较慢,所以学习率要大一些
epochs = 20
train(model,lr,epochs)

7.模型测试 

def test(model, test_dataloader, device, model_path):
    # 将模型设置为评估模式
    model.eval()
    # 将模型移动到指定设备上
    model.to(device)

    # 从给定路径加载模型的状态字典
    model.load_state_dict(torch.load(model_path))

    correct_test = 0
    total_test = 0
    # 不计算梯度
    with torch.no_grad():
        # 遍历测试数据加载器
        for inputs, targets in test_dataloader:  
            # 将输入数据和标签移动到指定设备上
            inputs = inputs.to(device)
            targets = targets.to(device)
            # 模型进行推理
            outputs = model(inputs)
            # 获取预测结果中的最大值
            _, predicted = torch.max(outputs, 1)
            total_test += targets.size(0)
            # 统计预测正确的数量
            correct_test += (predicted == targets).sum().item()
    
    # 计算并打印测试数据的准确率
    test_accuracy = correct_test / total_test
    print(f"Accuracy on Test: {test_accuracy:.4f}")
    return test_accuracy
model_path = save_path
test(model, test_dataloader, device, save_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1352405.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据库进阶教学——主从复制(Ubuntu22.04主+Win10从)

目录 一、概述 二、原理 三、搭建 1、备份数据 2、主库配置Ubuntu22.04 2.1、设置阿里云服务器安全组 2.2、修改配置文件 /etc/my.cnf 2.3、重启MySQL服务 2.4、登录mysql,创建远程连接的账号,并授予主从复制权限 2.5、通过指令,查…

Python Web框架FastAPI——一个比Flask和Tornada更高性能的API框架

目录 一、FastAPI框架概述 二、FastAPI与Flask和Tornado的性能对比 1、路由性能 2、请求处理性能 3、内存占用 三、FastAPI的优点与特色 四、代码示例 五、注意事项 六、结论 在当今的软件开发领域,快速、高效地构建API成为了许多项目的关键需求。为了满足…

k8s快速搭建

VMware16Pro虚拟机安装教程VMware16.1.2安装及各版本密钥CentOS7.4的安装包:提取码:lp6qVMware搭建Centos7虚拟机教程 搭建完一个镜像 关机 拍摄一个快照,克隆两个作为子节点 0. 环境准备 在开始之前,部署Kubernetes集群机器需要满足以下几个条件&#…

实验笔记之——基于Linux服务器复现Instant-NGP及常用的tmux指令

之前博客实现了基于windows来复现Instant-NGP,本博文在linux服务器上测试 实验笔记之——基于windows复现Instant-NGP-CSDN博客文章浏览阅读444次,点赞15次,收藏7次。之前博客对NeRF-SLAM进行了调研,本博文先复现一下Intant-NGP。…

【Matlab】PSO-BP 基于粒子群算法优化BP神经网络的数据时序预测(附代码)

资源下载: https://download.csdn.net/download/vvoennvv/88689096 一,概述 PSO-BP算法是一种结合了粒子群算法(PSO)和BP神经网络的方法,用于数据时序预测。下面是PSO-BP算法的原理和过程: 1. 数据准备&…

elasticsearch如何操作索引库里面的文档

上节介绍了索引库的CRUD,接下来操作索引库里面的文档 目录 一、添加文档 二、查询文档 三、删除文档 四、修改文档 一、添加文档 新增文档的DSL语法如下 POST /索引库名/_doc/文档id(不加id,es会自动生成) { "字段1":"值1", "字段2&q…

印象笔记02: 笔记本管理系统和空间使用

印象笔记02: 笔记本管理系统和空间使用 印象笔记新建笔记是一件非常容易的事情。笔记多了,就是归纳到笔记本里。 印象笔记一共有三层的笔记结构:最高层级是笔记本组,其次是笔记本,最后是一个个的笔记。合理的分类能够…

HbuilderX中的git的使用

原文链接https://blog.csdn.net/Aom_yt/article/details/119924356

PS 2024全新开挂神器Portraiture v4.1.2升级版,功能强大,一键安装永久使用

关于PS修图插件,相信大家都有安装过使用过,而且还不止安装了一款,比如最为经典的DR5.0人像精修插件,Retouch4me11合1插件,Portraiture磨皮插件,这些都是人像精修插件中的领跑者。 其中 Portraiture 刚刚升…

魔棒无人直播系统有哪些优势?

随着科技的发展,越来越多新鲜事物的出现,它们代替了我们做很多的事情,开始解放着自己的双手,其中,无人直播的出现,就让直播变得更加简单。 因为是无人直播,所以全程不需要真人出镜,…

探索Java的魅力

从本篇文章开始,小编准备写一个关于java基础学习的系列文章,文章涉及到java语言中的基础组件、实现原理、使用场景、代码案例。看完下面一系列文章,希望能加深你对java的理解。 本篇文章作为本系列的第一篇文章,主要介绍一些java…

【Mybatis】Mybatis如何防止sql注入

🍎个人博客:个人主页 🏆个人专栏: Mybatis ⛳️ 功不唐捐,玉汝于成 目录 前言 正文 1、使用参数化的 SQL 语句: 2、使用动态 SQL 标签: 3、禁止拼接 SQL: 4、限制参数类…

dmetl5授权查看与更新

1.查看dmetl5授权到期时间 需要登录管理端&#xff0c;菜单栏选择“管理”-“license管理”即可查看授权到期时间。如下图&#xff1a; 2.dmetl5更新授权的方法 dmetl5的<安装目录>\scheduler\config路径下&#xff0c;默认会有一个trail.key的文件&#xff0c;删除后&am…

Ribbon相关面试及答案(2024)

1、Ribbon是什么&#xff0c;它在微服务架构中扮演什么角色&#xff1f; Ribbon是一个客户端负载均衡器&#xff0c;它在微服务架构中扮演着关键性的角色。Ribbon的设计理念是在客户端进行服务发现和负载均衡&#xff0c;这种方式不同于传统的通过中心化的负载均衡器&#xff…

案例216:基于微信小程序的垃圾分类系统

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder …

【macOS】将macOS安装到U盘,让它在U盘/SSD启动运行

将macOS安装到U盘&#xff0c;让它在U盘/SSD启动运行。 1.从AppStore下载macOS&#xff1b; 2.双击下载的安装包文件&#xff1a; 例如&#xff1a;我下载的是&#xff1a;macOS 14.x 文件名&#xff1a;安装macOS Sonoma 文件位置&#xff1a;/Application/ 【注意】&…

广州求职招聘(找工作)去哪里找比较好

在广州找工作&#xff0c;可以选择“吉鹿力招聘网”这个平台。它是一个号称直接和boss聊的互联网招聘神器&#xff0c;同时&#xff0c;“吉鹿力招聘网”作岗位比较齐全&#xff0c;企业用户也多&#xff0c;比较全面。在“吉鹿力招聘网”历即可投递岗位。 广州找工作上 吉鹿力…

mysql之四大引擎、账号管理以及建库

一.数据库存储引擎 1.1存储引擎的查看 1.2InnoDB 1.3MyISAM 1.4 MEMORY 1.5 Archive 二.数据库管理 2.1元数据库简介 2.2元数据库分类 2.3 相关操作 2.4 MySQL库 三.数据表管理 3.1三大范式 3.2 基本数据类型 3.2.1优化原则 3.3 整形 3.4 实数 3.5 字符串 3.6 text&…

13.Go 异常

1、宕机 Go语言的类型系统会在编译时捕获很多错误&#xff0c;但有些错误只能在运行时检查&#xff0c;如数组访问越界、空指针引用等&#xff0c;这些运行时错误会引起宕机。 一般而言&#xff0c;当宕机发生时&#xff0c;程序会中断运行&#xff0c;并立即执行在该gorouti…

vue和react哪种框架使用范围更广

Vue和React都是非常流行的前端JavaScript框架&#xff0c;它们各自有着广泛的应用场景和支持者。选择使用哪一个框架往往取决于特定的项目需求、开发团队的熟悉程度以及生态系统的偏好。以下是这两个框架的一些主要特点&#xff0c;以帮助比较它们的使用范围&#xff1a; React…