LeNet卷积神经网络-笔记

news2024/11/18 7:44:42

LeNet卷积神经网络-笔记

在这里插入图片描述
手写分析LeNet网三卷积运算和两池化加两全连接层计算分析
在这里插入图片描述
基于paddle飞桨框架构建测试代码

#输出结果为:
#[validation] accuracy/loss: 0.9530/0.1516
#这里准确率为95.3%
#通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。

详细源代码如下所示:

# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear

## 组网
import paddle.nn.functional as F

# 定义 LeNet 网络结构
#==============================================================================
class LeNet(paddle.nn.Layer):
    def __init__(self, num_classes=1):
        super(LeNet, self).__init__()
        # 创建卷积和池化层
        # 创建第1个卷积层
        self.conv1 = Conv2D(in_channels=1, out_channels=6, kernel_size=5)
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        # 尺寸的逻辑:池化层未改变通道数;当前通道数为6
        # 创建第2个卷积层
        self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5)
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        # 创建第3个卷积层
        self.conv3 = Conv2D(in_channels=16, out_channels=120, kernel_size=4)
        # 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]
        # 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120
        self.fc1 = Linear(in_features=120, out_features=64)
        # 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数
        self.fc2 = Linear(in_features=64, out_features=num_classes)
    # 网络的前向计算过程
    def forward(self, x):
        x = self.conv1(x)
        # 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化
        x = F.sigmoid(x)
        x = self.max_pool1(x)
        x = F.sigmoid(x)
        x = self.conv2(x)
        x = self.max_pool2(x)
        x = self.conv3(x)
        # 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]
        x = paddle.reshape(x, [x.shape[0], -1])
        x = self.fc1(x)
        x = F.sigmoid(x)
        x = self.fc2(x)
        return x
#==========================================================================================
# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')

# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)
# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
print(x.shape)
x = paddle.to_tensor(x)
print(x.shape)
for item in model.sublayers():
    # item是LeNet类中的一个子层
    # 查看经过子层之后的输出数据形状
    try:
        x = item(x)
    except:
        x = paddle.reshape(x, [x.shape[0], -1])
        x = item(x)
    if len(item.parameters())==2:
        # 查看卷积和全连接层的数据和参数的形状,
        # 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数b
        print(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)
    else:
        # 池化层没有参数
        print(item.full_name(), x.shape)  
#
'''
#显示子图层列表model.sublayers()
[
  Conv2D(1, 6, kernel_size=[5, 5], data_format=NCHW), 
  MaxPool2D(kernel_size=2, stride=2, padding=0), 
  Conv2D(6, 16, kernel_size=[5, 5], data_format=NCHW), 
  MaxPool2D(kernel_size=2, stride=2, padding=0), 
  Conv2D(16, 120, kernel_size=[4, 4], data_format=NCHW), 
  Linear(in_features=120, out_features=64, dtype=float32), 
  Linear(in_features=64, out_features=10, dtype=float32)
]
'''    

# -*- coding: utf-8 -*-
# LeNet 识别手写数字
import os
import random
import paddle
import numpy as np
import paddle
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST

# 定义训练过程
def train(model, opt, train_loader, valid_loader):
    # 开启0号GPU训练
    use_gpu = True
    paddle.device.set_device('gpu:0') if use_gpu else paddle.device.set_device('cpu')
    print('start training ... ')
    model.train()
    for epoch in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            img = data[0]
            label = data[1] 
            # 计算模型输出
            logits = model(img)
            # 计算损失函数
            loss_func = paddle.nn.CrossEntropyLoss(reduction='none')
            loss = loss_func(logits, label)
            avg_loss = paddle.mean(loss)

            if batch_id % 2000 == 0:
                print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))
            avg_loss.backward()
            opt.step()
            opt.clear_grad()

        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(valid_loader()):
            img = data[0]
            label = data[1] 
            # 计算模型输出
            logits = model(img)
            pred = F.softmax(logits)
            # 计算损失函数
            loss_func = paddle.nn.CrossEntropyLoss(reduction='none')
            loss = loss_func(logits, label)
            acc = paddle.metric.accuracy(pred, label)
            accuracies.append(acc.numpy())
            losses.append(loss.numpy())
        print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))
        model.train()

    # 保存模型参数
    paddle.save(model.state_dict(), 'mnist_LeNet.pdparams')


# 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)

#输出结果为:
#[validation] accuracy/loss: 0.9530/0.1516
#这里准确率为95.3%
#通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/826726.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何开启一个java微服务工程

安装idea IDEA常用配置和插件(包括导入导出) https://blog.csdn.net/qq_38586496/article/details/109382560安装配置maven 导入source创建项目 修改项目编码utf-8 File->Settings->Editor->File Encodings 修改项目的jdk maven import引入…

【C++】类和对象——拷贝构造函数、运算符重载、日期类实现、const成员、取地址操作符重载

目录 拷贝构造函数运算符重载日期类实现const成员取地址及const取地址操作符重载 拷贝构造函数 拷贝构造函数:只有单个形参,该形参是对本类类型对象的引用(一般常用const修饰),在用已存在的类类型对象创建新对象时由编译器自动调用。 拷贝构…

SOLIDWORKS 钣金零件怎么画?

一、SOLIDWORKS 钣金功能介绍 SOLIDWORKS 是一款广泛应用于机械设计领域的 CAD 软件,其钣金功能可以帮助用户快速创建钣金件的 3D 模型。钣金折弯是一种常见的加工方式,可以将平面材料通过弯曲变形成为所需形状。 二、如何使用 SOLIDWORKS 钣金功能 步骤…

shell清理redis模糊匹配的多个key

#!/bin/bash# 定义Redis服务器地址和端口 REDIS_HOST"localhost" REDIS_PORT6380# 获取匹配键的数量 function get_matching_keys() {local key_pattern"$1"redis-cli -h $REDIS_HOST -p $REDIS_PORT -n 0 KEYS "$key_pattern" }# 删除匹配的键 …

一文带你详细了解Open API设计规范

写在前面: OpenAPI 规范(OAS)定义了一个标准的、语言无关的 RESTful API 接口规范,它可以同时允许开发人员和操作系统查看并理解某个服务的功能,而无需访问源代码,文档或网络流量检查(既方便人…

Atlas200DK A2联网实战

文章目录 1.Atlas原始网络信息2. 开发板联网2.1 使用Type-c 连接开发板2.2 修改本地网络适配器2.3 修改开发板网络信息2.4 测试外网连接 1.Atlas原始网络信息 Type-C 网口 ETH0 网口 ETH1 网口 2. 开发板联网 2.1 使用Type-c 连接开发板 使用xshell 等ssh终端登录开发板&…

【C++从0到王者】第十五站:list源码分析及手把手教你写一个list

文章目录 一、list源码分析1.分析构造函数2.分析尾插等 二、手把手教你写一个list1.结点声明2.list类的成员变量3.list类的默认构造函数4.list类的尾插5.结点的默认构造函数6.list类的迭代器7.设计const迭代器8.list的insert、erase等接口9.size10.list的clear11.list的析构函数…

【java安全】CommonsBeanUtils1

文章目录 【java安全】CommonsBeanUtils1前言Apache Commons BeanutilsBeanComparator如何调用BeanComparator#compare()方法?构造POC完整POC 调用链 【java安全】CommonsBeanUtils1 前言 在之前我们学习了java.util.PriorityQueue,它是java中的一个优…

2.2 身份鉴别与访问控制

数据参考:CISP官方 目录 身份鉴别基础基于实体所知的鉴别基于实体所有的鉴别基于实体特征的鉴别访问控制基础访问控制模型 一、身份鉴别基础 1、身份鉴别的概念 标识 实体身份的一种计算机表达每个实体与计算机内部的一个身份表达绑定信息系统在执行操作时&a…

3、详解桶排序及排序内容总结

堆 满二叉树可以用一个数组中从0开始的连续一段来记录 i i i位置左孩子: 2 ∗ i + 1 2*i+1 2∗i+1,右孩子: 2 ∗ i + 2 2*i+2 2∗i+2,父: ( i − 1 ) / 2 (i-1)/2 (i−1)/2 大根堆 每一棵子树的根为最大值 小根堆 每一棵子树的根为最小值 建大根堆 不断地根据公…

配置HDFS单机版,打造数据存储的强大解决方案

目录 简介:步骤:安装java下载安装hadoop配置hadoop-env.sh配置 core-site.xml配置hdfs-site.xml初始化hdfs文件系统启动hdfs服务验证hdfs 结论: 简介: Hadoop分布式文件系统(HDFS)是Hadoop生态系统中的一个…

【硬件设计】模拟电子基础二--放大电路

模拟电子基础二--放大电路 一、基本放大电路1.1 初始电路1.2 静态工作点1.3 分压偏置电路 二、负反馈放大电路三、直流稳压电路 前言:本章为知识的简单复习,适合于硬件设计学习前的知识回顾,不适合运用于考试。 一、基本放大电路 1.1 初始电…

数学建模-爬虫入门

Python快速入门 简单易懂Python入门 爬虫流程 获取网页内容:HTTP请求解析网页内容:Requst库、HTML结果、Beautiful Soup库储存和分析数据 什么是HTTP请求和响应 如何用Python Requests发送请求 下载pip macos系统下载:pip3 install req…

VactorCast自动化单元测试

VectorCAST软件自动化测试方案 VectorCAST软件自动化测试方案 博客园 软件测试面临的问题 有一句格言是这样说的,“如果没有事先做好准备,就意味着做好了 失败的准备。”如果把这个隐喻应用在软件测试方面,就可以这样说“没有测试到&#xf…

Tomcat虚拟主机

Tomcat虚拟主机 部署 [rootlocalhost webapps]# cd ../conf [rootlocalhost conf]# pwd /usr/local/tomcat/conf [rootlocalhost conf]# vim server.xml #增加虚拟主机配置&#xff0c;添加以下&#xff1a; <Host name"www.a.com" appBase"webapps"u…

react-redux的理解与使用

一、react-redux作用 和redux和flux功能一样都是管理各个组件的状态&#xff0c;是redux的升级版。 二、为什么要用reac-redux&#xff1f; 那么我们既然有了redux&#xff0c;为什么还要用react-redux呢&#xff1f;原因如下&#xff1a; 1&#xff0c;解决了每个组件用数…

怎么才能远程控制笔记本电脑?

为什么选择AnyViewer远程控制软件&#xff1f; 为什么AnyViewer是远程控制笔记本电脑软件的首选&#xff1f;以下是选择AnyViewer成为笔记本电脑远程控制软件的主要因素。 跨平台能力 AnyViewer作为一款跨平台远程控制软件&#xff0c;不仅可以用于从一台Windows电…

数据库监控平台,数据库监控的指标有哪些--PIGOSS BSM

引言 在现代企业的信息化时代&#xff0c;数据库作为关键的数据存储和管理工具&#xff0c;扮演着至关重要的角色。然而&#xff0c;数据库的稳定性和高效性对于企业的正常运营至关重要。为了帮助企业保障数据库的运行状态&#xff0c;我们公司推出了PIGOSS BSM&#xff0c;一款…

MySql006——基本的SELECT查询语句

在《MySql003——结构化查询语言SQL基础知识》中&#xff0c;我们学习了有关SQL的基础知识&#xff0c;也知道SQL中查询语句SELECT使用最为频繁 接下来我们将学习一些基本的SELECT查询语句 一、SELECT语句的通用语法 在MySQL数据库中&#xff0c;使用SELECT语句可以查询数据…

024 - mix()函数

定义&#xff1a;MIN()函数返回一组值中的最小值。NULL 值不包括在计算中。 语法&#xff1a; MIN(expression) 参数值&#xff1a; 参数 描述 expression 必须项。数值&#xff08;可以是字段或公式&#xff09; -- 实际操作&#xff08;查询最小工资数&#xff09;: SE…