神经网络-LeNet

news2024/12/25 16:19:30

 LeNet在1990年被提出,是一系列网络的统称,包括了LeNet1~LeNet5,对于神经网络的学习者来说,大家对下面这个图一定很熟悉,该图是对LeNet的简化展示。

 

在LeNet中已经提出了卷积层、Pooling层等概念,只是但是由于缺乏大量数据和计算机硬件资源限制,导致LeNet的表现并不理想。

LeNet网络结构

LeNet的构成很简单,包括了基础的卷积层、池化层和全连接层,原始的LeNet使用的是灰度图像,下面示例中使用彩色图像进行说明,不影响网络的理解。

  • 定义网络层

# 定义网络
class LeNet(nn.Module):                    #继承来着nn.Module的父类
    def __init__(self):  
        # 初始化网络
        #super()继承父类的构造函数,多继承需用到super函数
        super(LeNet, self).__init__()
        
        # 定义卷积层,[深度,卷积核数,卷积核大小]
        self.conv1 = nn.Conv2d(3, 16, 5)
        # 最大池化,[核大小,步长]
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        # 全连接层
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        # 根据训练项目,调整类别数
        self.fc3 = nn.Linear(84, 10)
                                     #图像参数变化
    def forward(self, x):            # input(3, 32, 32)        
        x = F.relu(self.conv1(x))    #output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32*5*5)       # output(32*5*5)
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x

网络结构如下,下面将对每一层做一个介绍:

 网络中feature map的变化大致如下:

 

LeNet实例应用

  • 数据预处理

# 对数据进行预处理
transform = transforms.Compose(
    [
        # 将输入的 numpy.ndarry[h*w*c]转变为[c*h*w],像素点值从[0,255],标准化为[0,1]
        transforms.ToTensor(),
        # 将数据进行标注化
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)
  • 数据读取

如果是初次使用CIFAR,需要将download打开,也可以自行通过其他方式进行下载。

# 读取数据-训练集
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36, shuffle=False, num_workers=0)
  • 定义网络

通过LeNet中的介绍,完成网络的定义。

  • 定义损失函数和优化器

pytorch支持很多损失函数和优化器,可以根据需要进行设定

# 定义损失函数
loss_function = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.Adam(net.parameters(), lr=0.001)
  • 模型训练

# 开始训练,设置迭代轮次 epoch
for epoch in range(3):
    # 损失函数值
    running_loss = 0.0
    
    for step, data in enumerate(train_loader, start=0):
        inputs, labels = data
        
        # 清除梯度累加值
        optimizer.zero_grad()
        
        outputs = net(inputs.to(device))
        # 计算损失值
        loss = loss_function(outputs, labels.to(device))
        # 计算梯度
        loss.backward()
        # 参数更新
        optimizer.step()
        
        # 输出损失值
        running_loss += loss.item()
        if step % 500 == 499:
            with torch.no_grad():
                outputs = net(val_image.to(device))
                # 输出最大概率
                predict_y = torch.max(outputs, dim=1)[1]
                accuracy = (predict_y == val_label.to(device)).sum().item() / val_label.size(0)
                
                print('[%d, %5d] train_Loss:%.3f tese_accuracy: %.3f' % (epoch + 1, step + 1, running_loss/500, accuracy))
                running_loss = 0.0
                
print('train finished')
  • 保存模型

# 保存模型
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

补充

  • Pytorch中tensor的顺序是:[batch, channel, height, width]

  • 卷积层中计算输出大小

 

  • W表示输入图像的Weight,一般Weight=hight

  • F表示核的大小,核大小一般为F * F

  • P表示Padding,Conv2d中默认是0

  • S表示步长

因此对于32*32的输入,在该网络中Output=(32-5+2*0)/1 +1 = 28

  • 池化层只改变特征的高和宽,不改变深度

因此对于16*28*28,经过MaxPooling后变成了16*14*14

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265371.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VMD-SSA-BiLSTM、VMD-BiLSTM、BiLSTM时间序列预测对比

VMD-SSA-BiLSTM、VMD-BiLSTM、BiLSTM时间序列预测对比 目录 VMD-SSA-BiLSTM、VMD-BiLSTM、BiLSTM时间序列预测对比预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.MATLAB实现VMD-SSA-BiLSTM、VMD-BiLSTM、BiLSTM时间序列预测对比; 2.单变量时间序列预测 就是先vmd把变…

联通光猫怎么自己改桥接模式?

环境: 联通光猫 ZXHN F677V9 硬件版本号 V9.0 软件版本号 V9.0.0P1T3 问题描述: 联通光猫怎么自己改桥接模式 家里用的是ZXHN F677V9 光猫,最近又搞了个软路由,想改桥接模式 解决方案: 1.拿到最新超级密码&…

JSON 系列之1:将 JSON 数据存储在 Oracle 数据库中

本文为Oracle数据库JSON学习系列的第一篇,讲述如何将JSON文档存储到数据库中,包括了版本为19c和23ai的情形。 19c中的JSON 先来看一下数据库版本为19c时的情形。 创建表colortab,其中color列的长度设为4000。若color的长度需要设为32767&a…

【从零开始入门unity游戏开发之——unity篇02】unity6基础入门——软件下载安装、Unity Hub配置、安装unity编辑器、许可证管理

文章目录 一、软件下载安装1、Unity官网2、下载Unity Hub 二、修改Unity Hub配置1、设置Unity Hub中文语言2、修改默认存储目录 三、安装unity编辑器1、点击安装编辑器2、版本选择3、关于版本号4、安装模块选择5、等待下载完成自动安装即可6、追加unity和模块 四、许可证管理专…

SAP从入门到放弃系列之委外分包(Subcontracting)-Part1

以前写过一篇委外相关的文章,没有很详细的写。只是一个概念的概述ERP实施-委外业务-委外采购业务 最近看PA教材,遇到了这块内容,就再详细的整理一下SAP关于委外的理论知识。 文章目录 概述分包和物料需求计划 (MRP)委外分包订单分包委外业务…

vue前端报错 ERROR Error The project seems to require yarn but it‘s not installed

当我们项目启动的时候会报错 报错的信息:ERROR Error: The project seems to require yarn but it’s not installed. 解决的办法首先找到右边的文件夹,yarnlock 找打这个文件删除以后进行全局安装 npm install -g yarn

分体空调智能控制系统

空调是建筑中的用能大户,据统计,空调能耗占建筑总能耗的60%,空调节能作为建筑节能减排的重要组成部分,针对空调的监测和控制尤为重要。随着双碳战略的深入推进、数字化技术的快速发展、人们节能意识普遍增强,对空调用电…

Axure RP 8安装(内带安装包)

通过网盘分享的文件:Axure8.0.zip 链接: https://pan.baidu.com/s/195_qy2iiDIcYG4puAudScA 提取码: 6xt8 --来自百度网盘超级会员v1的分享 勾选I Agree 安装完成

如何在centos系统上挂载U盘

在CentOS上挂载NTFS格式的U盘,需要执行一系列步骤,包括识别U盘设备、安装必要的软件、创建挂载点,并最终挂载U盘。以下是在CentOS上挂载NTFS格式U盘的详细步骤: 一、准备工作 确认CentOS版本: 确保你的CentOS系统已经安装并正常运行。不同版本的CentOS在命令和工具方面可能…

闯关leetcode——3158. Find the XOR of Numbers Which Appear Twice

大纲 题目地址内容 解题代码地址 题目 地址 https://leetcode.com/problems/find-the-xor-of-numbers-which-appear-twice/description/ 内容 You are given an array nums, where each number in the array appears either once or twice. Return the bitwise XOR of all …

zabbix“专家坐诊”第270期问答

问题一 Q:请问,zabbix 6.0.26 是否支持在 Monitoring|Latest data页面仅通过Tags筛选主机?我尝试了下,发现无法筛选到。 A:这里是标记 要选主机正在旁边选。 Q:我只想筛选tag为xxx的主机,不限定…

Git使用经历

目录 1、先创建文件夹 2、仓库初始化 3、配置gitee用户名和密码 4、克隆指定仓库的中指定分支到本地仓库 5、查看当前所在分支、切换分支 6、查看状态,判断是否有修改 7、把更新的内容添加到缓存区 8、把缓存区的数据提交 9、把数据推送到远程仓库 10、把…

蓝牙协议——音乐启停控制

手机播放音乐 手机暂停音乐 耳机播放音乐 耳机暂停音乐

【Web】2024“国城杯”网络安全挑战大赛决赛题解(全)

最近在忙联通的安全准入测试,很少有时间看CTF了,今晚抽点时间回顾下上周线下的题(期末还没开始复习😢) 感觉做渗透测试一半的时间在和甲方掰扯&水垃圾洞,没啥惊喜感,还是CTF有意思 目录 Mountain ez_zhuawa 图…

信奥赛四种算法描述

#include <iostream> #include <iomanip> using namespace std;// 使用unsigned long long类型来尽量容纳较大的结果&#xff0c;不过实际上这个数值极其巨大&#xff0c;可能最终仍会溢出 // 更好的方式可以考虑使用高精度计算库&#xff08;如GMP等&#xff09;来…

12.19问答解析

概述 某中小型企业有四个部门&#xff0c;分别是市场部、行政部、研发部和工程部&#xff0c;请合理规划IP地址和VLAN&#xff0c;实现企业内部能够互联互通&#xff0c;同时要求市场部、行政部和工程部能够访问外网环境(要求使用OSPF协议)&#xff0c;研发部不能访问外网环境…

Docker部署Sentinel

一、简介 是什么&#xff1a;面向分布式、多语言异构化服务架构的流量治理组件 能干嘛&#xff1a;从流量路由、流量控制、流量整形、熔断降级、系统自适应过载保护、热点流量防护等多个维度来帮助开发者保障微服务的稳定性 官网地址&#xff1a;https://sentinelguard.io/zh-c…

LabVIEW如何学习FPGA开发

FPGA&#xff08;现场可编程门阵列&#xff09;开发因其高性能、低延迟的特点&#xff0c;在实时控制和高速数据处理领域具有重要地位。LabVIEW FPGA模块为开发者提供了一个图形化编程平台&#xff0c;降低了FPGA开发的门槛。本篇文章将详细介绍LabVIEW FPGA开发的学习路径&…

基于图注意力网络的两阶段图匹配点云配准方法

Two-stage graph matching point cloud registration method based on graph attention network— 基于图注意力网络的两阶段图匹配点云配准方法 从两阶段点云配准方法中找一些图匹配的一些灵感。文章提出了两阶段图匹配点云配准网络&#xff08;TSGM-Net&#xff09; TSGM-Ne…

U9多组织的退货单不能拉单找不到退货单

培训没有估好吧。参数是顾问亲自操刀去做的。当时事情又多&#xff0c;参加会议都是缺席的。财务做应收单拉单时&#xff0c;说有退货单找不到。也就是查询的条件&#xff08;逻辑&#xff09;不对嘛。U9的参数查询条件太多&#xff0c;逻辑复杂&#xff0c;一时真的分析不出来…