LeNet基础

news2024/11/24 5:22:47

目录

1.LeNet简介

1.1基本介绍

1.2网络结构 

 2.LetNet在pytorch中的使用

2.1首先定义模型 

2.2初始化数据集,初始化模型,同时训练数据。

2.3 训练结果​编辑

2.4绘制曲线


1.LeNet简介

 1.1基本介绍

 LeNet(LeNet-5)是历史上第一个成功应用于数字识别任务的卷积神经网络模型。由于其优秀的表现和先进的结构被广泛认可,成为深度学习的里程碑之一。LeNet由加拿大籍计算机科学家Yann LeCun在1998年提出,旨在解决手写数字识别问题。它是第一个能够通过卷积层和池化层实现特征提取和降维的卷积神经网络模型。LeNet-5模型的架构包含6层,包括2个卷积层、2个池化层和2个全连接层。LeNet-5的主要优点是它非常高效,具有一定的鲁棒性,并且在处理小尺寸图像时表现出色。今天,LeNet-5已经成为深度学习神经网络的开山鼻祖之一。

LeNet的设计目标是用于识别手写数字,特别是美国支票上的手写数字识别。它由一系列的卷积层和池化层组成,最后连接全连接层进行分类。LeNet的基本结构如下:

1. 输入层(Input Layer):接收输入图像数据。

2. 卷积层(Convolutional Layer):使用卷积核对输入图像进行卷积操作,提取图像特征。同时通过激活函数引入非线性。

3. 池化层(Pooling Layer):对卷积层的输出进行下采样,减小数据的空间维度,减少计算量,并保留重要的特征。

4. 全连接层(Fully Connected Layer):将池化层的输出展平,并连接到一个或多个全连接层,用于图像分类。

5. 输出层(Output Layer):进行最终的分类操作,输出预测结果。

LeNet的创新之处在于引入了卷积层和池化层,使得网络可以自动从原始图像数据中提取和学习特征。这种结构的设计在图像处理任务中非常有效,为后续深度学习模型的发展奠定了基础。

尽管LeNet的规模较小,但它在当时的手写数字识别任务中取得了很好的效果,并为后续更复杂的卷积神经网络的发展提供了启示。LeNet的设计思想和结构对现代深度学习中的卷积神经网络仍然具有重要的影响。

1.2网络结构 

 

每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。请注意,虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用5×5卷积核和一个sigmoid激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有6个输出通道,而第二个卷积层有16个输出通道。每个2×2池操作(步幅2)通过空间下采样将维数减少4倍。卷积的输出形状由批量大小、通道数、高度、宽度决定。

为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入。这里的二维表示的第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

 2.LetNet在pytorch中的使用

 定义了LeNet-5模型,包括特征提取层和分类器层。初始化LeNet-5模型,并定义损失函数和优化器。进行训练循环,包括训练和验证阶段。在每个迭代中计算训练和验证的损失和准确度。在训练循环结束后,进行测试阶段,计算测试的损失和准确度。打印出训练、验证和测试的损失和准确度。

2.1首先定义模型 

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# Define LeNet-5 model
class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
            nn.Linear(84, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

 2.2初始化数据集,初始化模型,同时训练数据。

# Set random seed for reproducibility
torch.manual_seed(42)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

val_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# Initialize LeNet-5 model
model = LeNet5().to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    # Training phase
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total_predictions += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

    train_loss = running_loss / len(train_loader)
    train_accuracy = correct_predictions / total_predictions

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct_predictions = 0
    val_total_predictions = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            val_total_predictions += labels.size(0)
            val_correct_predictions += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    val_accuracy = val_correct_predictions / val_total_predictions

    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

# Test phase
model.eval()

test_loss = 0.0
test_correct_predictions = 0
test_total_predictions = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        test_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        test_total_predictions += labels.size(0)
        test_correct_predictions += (predicted == labels).sum().item()

test_loss /= len(test_loader)
test_accuracy = test_correct_predictions / test_total_predictions

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

2.3 训练结果

 2.4绘制曲线

绘制Training accuracy Curve和Validation accuracy Curve

train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    # Training phase
    # ...

    train_loss = running_loss / len(train_loader)
    train_accuracy = correct_predictions / (total_predictions + 1e-7)  # Add a small epsilon value

    # Validation phase
    # ...

    val_loss = val_loss / len(val_loader)
    val_accuracy = val_correct_predictions / (val_total_predictions + 1e-7)  # Add a small epsilon value

    # Append accuracy and loss values to lists
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    # Print epoch results
    # ...

# Test phase
# ...

test_loss = 0.0
test_correct_predictions = 0
test_total_predictions = 0

# Plotting the curves
plt.figure(figsize=(12, 4))

# Plot training and validation loss curves
plt.subplot(1, 2, 1)
plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
plt.plot(range(1, num_epochs + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

# Plot training and validation accuracy curves
plt.subplot(1, 2, 2)
plt.plot(range(1, num_epochs + 1), train_accuracies, label='Training Accuracy')
plt.plot(range(1, num_epochs + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

# Show the plot
plt.tight_layout()
plt.show()

 运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/703210.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

磁盘阵列(RAID)

什么是磁盘阵列 磁盘阵列(RAID)是一种将多个物理硬盘组合成一个逻辑存储单元的技术。这种技术可以提高数据存储的可靠性、性能或容量,并且可以在某些情况下提供备份和灾难恢复功能。 RAID技术可以通过在多个硬盘之间分配数据来提高性能。例…

事务处理相关

目录 步骤1.创建一个数据表 步骤2:创建项目导入jar包 步骤3:根据表创建模型类 步骤5:创建Service接口和实现类 步骤6:添加jdbc.properties文件 步骤7:创建JdbcConfig配置类 步骤8:创建MybatisConfig配置类 步骤9:创建SpringConfig配置类 步骤10:编写测试类 开启事务 1…

电磁阀原理精髓

一、引用 电磁阀在液/气路系统中,用来实现液路的通断或液流方向的改变,它一般具有一个可以在线圈电磁力驱动下滑动的阀芯,阀芯在不同的位置时,电磁阀的通路也就不同。 阀芯在线圈不通电时处在甲位置,在线圈通电时处在…

算法与数据结构-链表

文章目录 链表和数组的区别常见的链表类型单链表循环链表双向链表 总结 链表和数组的区别 相比数组,链表是一种稍微复杂一点的数据结构。对于初学者来说,掌握起来也要比数组稍难一些。这两个非常基础、非常常用的数据结构,我们常常会放到一块…

Python基础 - global nonlocal

global global作为全局变量的标识符,修饰变量后标识该变量是全局变量 global关键字可以用在任何地方,包括最上层函数中和嵌套函数中 实例1:如下代码,定义了两个x,并且赋值不同 直接调用print(x) 打印的是全局变量x的…

号外!MyEclipse 2023.1.1已发布,更好支持Vue框架

MyEclipse 2023.1.1是之前发布的2023.1.0的一个小错误修复版本,如果您已经安装了MyEclipse 2023,只需检查产品中的更新 (Help > Check for Updates…) 就可以选择这个新版本。或者,下载我们更新的离线安装程序来安装2023.1.1。 MyEclipse…

C# WPF应用使用visual studio的安装程序类的一些坑

重写installer实现自定义安装程序时,项目类型要选择 类库(.NET Framework) 否则会出现命名空间System.Configuration不存在Install的报错 有些可能想实现安装完自动启动应用的功能,就需要获取installer安装路径 var s Context.Parameters["assem…

【Java】网络编程与Socket套接字、UDP编程和TCP编程实现客户端和服务端通信

网络编程客户端和服务器Socket套接字流套接字TCP数据报套接字UDP对比TCP与UDP UDP编程DatagramSocket构造方法:普通方法: DatagramPacket构造方法:普通方法: 实现 TCP编程ServerSocket构造方法普通方法 Socket构造方法普通方法 实现 网络编程 为什么需要…

MyBatis-Plus 实现PostgreSQL数据库jsonb类型的保存

文章目录 在 handle 包下新建Jsonb处理类方式一方式二 PostgreSQL jsonb类型示例新建数据库表含有jsonb类型创建实体类创建Control 发起请求 在 handle 包下新建Jsonb处理类 方式一 import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.serializer.SerializerFea…

低代码开发平台到底省掉了哪些成本?可能大家一直错了

低代码到底是否真正可以降低研发成本?是否每个团队都适合?如果能降低,到底是降低的什么成本?其实我觉得这个是我们每个技术交付团队应该在使用任何产品之前都要考虑的问题。 在我们考虑低代码是否能降低成本的问题前,…

【Python】一文带你学会数据结构中的字典、集合

作者主页:爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?typeblog个…

一步一步学OAK之九:通过OAK相机实现视频帧旋转

目录 Setup 1: 创建文件Setup 2: 安装依赖Setup 3: 导入需要的包Setup 4: 定义变量Setup 5: 定义旋转矩形的四个顶点坐标Setup 6: 创建pipelineSetup 7: 创建节点Setup 8: 设置属性Setup 9: 建立链接Setup 10: 连接设备并启动管道Setup 11: 创建与DepthAI设备通信的输入队列和输…

C#核心知识回顾——2.拓展方法、运算符重载、分部类、里氏替换

1.拓展方法 为现有非静态变量类型添加新方法 1.提升程序拓展性 2.不需要再对象中重新写方法 3.不需要继承来添加方法 4.为别人封装的类型写额外的方法 特点&#xff1a; 1.一定是写在静态类中 2.一定是个静态函数 3.第一个参数为拓展目标 4.第一个参数用this修饰 /// <sum…

element table表格支持添加编辑校验

实现效果&#xff1a; 将table表格与form表单结合使用 &#xff08;用el-form外层包裹el-table结合rules进行校验&#xff09; 代码实现 <template><div><el-card class"box-card" shadow"never"><div><el-buttonsize"m…

E8-事关明细表里的控件事件绑定、日期的计算、明细表的求和等问题的处理办法

起因 下面的讲述的事情是从开发出差申请流程开始的。涉及的知识点偏多&#xff0c;且得容我慢慢梳理出来。以下篇幅可能会有点儿长&#xff0c;但内容我会争取写得精彩的。 图1 发起表单样式如图1&#xff0c;我想实现的是当修改出发日期或结束日期的时候&#xff0c;自动计算…

并发-synchronized详解

JDK1.6之前的synchronized关键字一来就直接给对象加了一把重量级锁&#xff0c;频繁地在用户态和内核态之间切换&#xff0c;导致性能非常低。为了弥补synchronized的不足&#xff0c;大佬doug lee写了一个AQS框架&#xff0c;用Java语言实现了ReentrantLock。然后在JDK1.6之后…

电脑文件夹怎么设置密码?3个方法为文件加密!

我的电脑里存了很多重要的文件夹&#xff0c;为了防止信息的泄露&#xff0c;我想把这些文件夹都设置密码。但是不知道具体应该如何操作。请求大家的帮助&#xff01; 我们平常在使用电脑时&#xff0c;可能会将很多重要的文件保存在电脑中。如果不想让别人看到我们这些重要的文…

JMeter安装RabbitMQ测试插件

目录 前言&#xff1a; 具体实现步骤&#xff1a; 1、ant环境搭建 2、AMQP源码下载 3、拷贝JMeter_core.jar包到JMeter-Rabbit-AMQP插件根目录下 4、修改AMQP插件的配置文件 5、打包 6、RabbitMQ客户端插件下载 7、完成以上&#xff0c;重启JMeter创建线程组就可以看到…

智能小家电如何升级Type-C接口充电?

目前市面上智能小家电充电接口还是USB Micor&#xff0c;AC&#xff0c;DC接口等&#xff0c;今年随着欧盟的一纸令下&#xff0c;22年12月24日&#xff0c;欧洲理事会最终批准了“在欧盟范围内统一充电器接口”的法案。这意味着到2024年&#xff0c;usb type-c接口将成为一系列…

openknx初编译

knx协议出来也是很长时间了&#xff0c;但国内相关开发的文章很少&#xff0c;比起zigbee,lora这些网上一搜一大零的&#xff0c;显得可怜。因为公司以后可能会开发knx产品&#xff0c;所以对国外的openknx自已研究了一下。 https://github.com/thelsing/knx 这个就是openknx项…