Pytorch实战(一):LeNet神经网络

news2025/1/14 18:29:45

文章目录

  • 一、模型实现
    • 1.1数据集的下载
    • 1.2加载数据集
    • 1.3模型训练
    • 1.4模型预测


  LeNet神经网络是第一个卷积神经网络(CNN),首次采用了卷积层、池化层这两个全新的神经网络组件,接收灰度图像,并输出其中包含的手写数字,在手写字符识别任务上取得了瞩目的准确率。LeNet网络的一系列的版本,以LeNet-5版本最为著名,也是LeNet系列中效果最佳的版本。LeNet神经网络输入图像大小必须为32x32,且所用卷积核大小固定为5x5,模型结构如下:
在这里插入图片描述

模型参数:

  • INPUT(输入层):输入图像尺寸为32x32,且是单通道灰色图像。
  • C1(卷积层):使用6个5x5大小的卷积核,步长为1,卷积后得到6张28×28的特征图。
  • S2(池化层):使用了6个2×2 的平均池化,池化后得到6张14×14的特征图。
  • C3(卷积层):使用了16个大小为5×5的卷积核,步长为1,得到 16 张10×10的特征图。
  • S4(池化层):使用16个2×2的平均池化,池化后得到16张5×5 的特征图。
  • C5(卷积层):使用120个大小为5×5的卷积核,步长为1,卷积后得到120张1×1的特征图。
  • F6(全连接层):输入维度120,输出维度是84(对应7x12 的比特图)。
  • OUTPUT(输出层):使用高斯核函数,输入维度84,输出维度是10(对应数字 0 到 9)。

该模型有如下特点:

  • 1.首次提出卷积神经网络基本框架: 卷积层,池化层,全连接层。
  • 2.卷积层的权重共享,相较于全连接层使用更少参数,节省了计算量与内存空间。
  • 3.卷积层的局部连接,保证图像的空间相关性。
  • 4.使用映射到空间均值下采样,减少特征数量。
  • 5.使用双曲线(tanh)或S型(sigmoid)形式的非线性激活函数。

一、模型实现

1.1数据集的下载

  使用torchversion内置的MNIST数据集,训练集大小60000,测试集大小10000,图像大小是1×28×28,包括数字0~9共10个类。

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torchvision
# 下载训练、测试数据集
mnist_train = torchvision.datasets.MNIST(root='./dataset/',
                                         train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='./dataset/',
                                        train=False, download=True, transform=transforms.ToTensor())
print('mnist_train基本信息为:',mnist_train)
print('-----------------------------------------')
print('mnist_test基本信息为:',mnist_test)
print('-----------------------------------------')
img,label=mnist_train[0]
print('mnist_train[0]图像大小及标签为:',img.shape,label)

在这里插入图片描述

1.2加载数据集

trainDataLoader = DataLoader(mnist_train, batch_size=64, num_workers=5, shuffle=True)
testDataLoader = DataLoader(mnist_test, batch_size=64, num_workers=0, shuffle=True)
write = SummaryWriter('./log')
step = 0
for images, labels in testDataLoader:
    write.add_images(tag='train', images, global_step=step)
    step += 1
write.close()

  注意不能使用for images, labels in testDataLoader.datasettestDataLoader.dataset[0]是保存图像(28
,28)和对应标签的元组,而Tensorboardadd_images只能输入NCHW格式对象,使用该代码会报错:

size of input tensor and input format are different. tensor shape: (1, 28, 28), input_format: NCHW

数据加载器按batch_size对数据及标签进行封装名,可直接作为输入。查看封装的元组:

for data in testDataLoader:
    print('type(data):',type(data))
    img,label=data
    print('type(img):',type(img),'img.shape:',img.shape)
    print('type(label):',type(label),'label.shape:',label.shape)

在这里插入图片描述

1.3模型训练

  LeNet模型的输入为(32,32)的图片,而MNIST数据集为(28,28)的图片,故需对原图片进行填充。搭建模型:

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.model = nn.Sequential(  #MNIST数据集图像大小为28x28,而LeNet输入为32x32,故需填充
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),  #C1层共六个卷积核,故out_channels=6
            nn.AvgPool2d(kernel_size=2, stride=2),  #C2层使用平均池化
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Conv2d(in_channels=16 * 5 * 5, out_channels=120),
            nn.Linear(in_features=120, out_features=84),
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self, x):
        return self.model(x)

# 初始化模型对象
myLeNet = LeNet()

  设置损失函数、优化器并训练模型:

# 设置损失函数为交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 设置优化器,使用Adam优化算法
learning_rate = 1e-2
optimizer = torch.optim.Adam(myLeNet.parameters(), lr=learning_rate)
total_train_step = 0  # 总训练次数
epoch = 10  # 训练轮数
writer = SummaryWriter(log_dir='./runs/LeNet/')
for i in range(epoch):
    print("-----第{}轮训练开始-----".format(i + 1))
    myLeNet.train()  # 训练模式
    train_loss = 0
    for data in trainDataLoader:
        imgs, labels = data
        imgs = imgs.to(device)  # 适配GPU/CPU
        labels = labels.to(device)
        outputs = myLeNet(imgs)
        loss = loss_fn(outputs, labels)#计算损失函数
        optimizer.zero_grad()  # 清空之前梯度
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        total_train_step += 1  # 更新步数
        train_loss += loss.item()
        writer.add_scalar("train_loss_detail", loss.item(), total_train_step)
    writer.add_scalar("train_loss_total", train_loss, i + 1)
    
writer.close()

1.4模型预测

myLeNet.eval() 
total_test_loss = 0  # 当前轮次模型测试所得损失
total_accuracy = 0  # 当前轮次精确率
with torch.no_grad():  # 关闭梯度反向传播
    for data in testDataLoader:
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = myLeNet(imgs)
        loss = loss_fn(outputs, targets)
        total_test_loss = total_test_loss + loss.item()
        accuracy = (outputs.argmax(1) == targets).sum()
        total_accuracy = total_accuracy + accuracy
writer.add_scalar("test_loss", total_test_loss, i+1)
writer.add_scalar("test_accuracy", total_accuracy/len(mnist_test), i+1)

https://blog.csdn.net/qq_43307074/article/details/126022041?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171938503416800186515588%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=171938503416800186515588&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_click~default-2-126022041-null-null.142v100pc_search_result_base3&utm_term=LeNet&spm=1018.2226.3001.4187

https://blog.csdn.net/hellocsz/article/details/80764804?ops_request_misc=&request_id=&biz_id=102&utm_term=LeNet&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-1-80764804.142v100pc_search_result_base3&spm=1018.2226.3001.4187

https://blog.csdn.net/qq_45034708/article/details/128319241?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171936257316800222847105%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=171936257316800222847105&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_positive~default-1-128319241-null-null.142v100pc_search_result_base3&utm_term=LeNet&spm=1018.2226.3001.4187

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1879010.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【吊打面试官系列-MyBatis面试题】#{}和${}的区别是什么?

大家好,我是锋哥。今天分享关于 【#{}和${}的区别是什么?】面试题,希望对大家有帮助; #{}和${}的区别是什么? #{} 是预编译处理,${}是字符串替换。 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网…

对话贾扬清:我创业这一年所看到的 AI

引言 在这次对话中,前阿里巴巴人工智能专家、现LIBRINAI创始人贾扬清分享了他在AI领域创业一年的见解和经历。作为一位从科学家转型为CEO的创业者,他探讨了AI计算、异构计算和云原生软件的结合带来的革命性变化,并讨论了LIBRINAI如何在激烈的…

Redis 集群模式

一、集群模式概述 Redis 中哨兵模式虽然提高了系统的可用性,但是真正存储数据的还是主节点和从节点,并且每个节点都存储了全量的数据,此时,如果数据量过大,接近或超出了 主节点 / 从节点机器的物理内存,就…

无人机远程控制:北斗短报文技术详解

无人机(UAV)技术的快速发展和应用,使得远程控制成为了一项关键技术。无人机远程控制涉及无线通信、数据处理等多个方面,其中北斗短报文技术以其独特的优势,在无人机远程控制领域发挥着重要作用。本文将详细解析无人机远…

【SQL】已解决:MySQL 服务无法启动

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决:MySQL 服务无法启动 一、分析问题背景 MySQL是一种流行的开源关系型数据库管理系统,在许多应用中被广泛使用。有时在启动MySQL服务时,可…

Spring Boot集成jasypt快速入门Demo

1.什么是Jasypt? Jasypt(Java Simplified Encryption)是一个专注于简化Java加密操作的工具。 它提供了一种简单而强大的方式来处理数据的加密和解密,使开发者能够轻松地保护应用程序中的敏感信息,如数据库密码、API密…

PHP校园论坛-计算机毕业设计源码08586

摘 要 本项目旨在基于PHP技术设计与实现一个校园论坛系统,以提供一个功能丰富、用户友好的交流平台。该论坛系统将包括用户注册与登录、帖子发布与回复、个人信息管理等基本功能,并结合社交化特点,增强用户之间的互动性。通过利用PHP语言及其…

【D3.js in Action 3 精译】1.2.2 可缩放矢量图形(二)

当前内容所在位置 第一部分 D3.js 基础知识 第一章 D3.js 简介 1.1 何为 D3.js?1.2 D3 生态系统——入门须知 1.2.1 HTML 与 DOM1.2.2 SVG - 可缩放矢量图形 ✔️ 第一部分【第二部分】✔️第三部分(精译中 ⏳) 1.2.3 Canvas 与 WebGL&#x…

Linux多进程和多线程(一)

进程 进程的概念 进程(Process)是操作系统对一个正在运行的程序的一种抽象。它是系统运行程序的最小单位,是资源分配和调度的基本单位。 进程的特点如下 进程是⼀个独⽴的可调度的活动, 由操作系统进⾏统⼀调度, 相应的任务会被调度到cpu …

【鸿蒙学习笔记】尺寸设置

官方文档:尺寸设置 目录标题 width:设置组件自身的宽度,缺省时自适应height:设置组件自身的高度,缺省时自适应size:设置高宽尺寸。margin:设置组件的外边距padding:设置组件的内边距…

数据库-数据完整性-用户自定义完整性实验

NULL/NOT NULL 约束: 在每个字段后面可以加上 NULL 修饰符来指定该字段是否可以为空;或者加上 NOT NULL 修饰符来指定该字段必须填上数据。 DEFAULT约束说明 DEFAULT 约束用于向列中插入默认值。如果列中没有规定其他的值,那么会将默认值添加…

electron线上跨域问题

一、配置background.js win new BrowserWindow({webPreferences: {nodeIntegration: true, // 使渲染进程拥有node环境//关闭web权限检查,允许跨域webSecurity: false,// Use pluginOptions.nodeIntegration, leave this alone// See nklayman.github.io/vue-cli-p…

【计算机网络】HTTP——基于HTTP的功能追加协议(个人笔记)

学习日期:2024.6.29 内容摘要:基于HTTP的功能追加协议和HTTP/2.0 HTTP的瓶颈与各功能追加协议 需求的产生 在Facebook、推特、微博等平台,每分每秒都会有人更新内容,我们作为用户当然希望时刻都能收到最新的消息,为…

常用字符串方法<python>

导言 在python中内置了许多的字符串方法,使用字符串方法可以方便快捷解决很多问题,所以本文将要介绍一些常用的字符串方法。 目录 导言 string.center(width[,fillchar]) string.capitalize() string.count(sub[,start[,end]]) string.join(iterabl…

收银系统源码-千呼新零售【全场景收银】

千呼新零售2.0系统是零售行业连锁店一体化收银系统,包括线下收银线上商城连锁店管理ERP管理商品管理供应商管理会员营销等功能为一体,线上线下数据全部打通。 适用于商超、便利店、水果、生鲜、母婴、服装、零食、百货、宠物等连锁店使用。 详细介绍请…

基于星火大模型的群聊对话分角色要素提取挑战赛-Lora微调与prompt构造

赛题连接 https://challenge.xfyun.cn/topic/info?typerole-element-extraction&optionphb 数据集预处理 由于赛题官方限定使用了星火大模型,所以只能调用星火大模型的API或者使用零代码微调 首先训练数据很少是有129条,其中只有chat_text和info…

模版方法模式详解:使用和实现的指南

目录 模版方法模式模版方法模式结构模版方法模式适合应用场景模版方法模式优缺点练手题目题目描述输入描述输出描述题解 模版方法模式 模板方法模式是一种行为设计模式, 它在超类中定义了一个算法的框架, 允许子类在不修改结构的情况下重写算法的特定步…

游戏推荐: 植物大战僵尸杂交版

下载地址网上一搜就有. 安装就能玩. 2是显血. 4显示植物血, 5是加速. 都是左手主键盘的按钮, 再按是取消. 比较刺激: ps: 设置里面还能打开自动收集阳光和金币.

Elasticsearch (1):ES基本概念和原理简单介绍

Elasticsearch(简称 ES)是一款基于 Apache Lucene 的分布式搜索和分析引擎。随着业务的发展,系统中的数据量不断增长,传统的关系型数据库在处理大量模糊查询时效率低下。因此,ES 作为一种高效、灵活和可扩展的全文检索…

斜率优化DP——AcWing 303. 运输小猫

斜率优化DP 定义 斜率优化DP(Slope Optimization Dynamic Programming)是一种高级动态规划技巧,用于优化具有特定形式的状态转移方程。它主要应用于那些状态转移涉及求极值(如最小值或最大值)的问题中,通…