Pytorch深度学习快速入门—LeNet简单介绍(附代码)

news2024/11/16 9:19:24

一、网络模型结构

        LeNet是具有代表性的CNN,在1998年被提出,是进行手写数字识别的网络,是其他深度学习网络模型的基础。如下图所示,它具有连狙的卷积层和池化层,最后经全连接层输出结果。

二、各层参数详解

2.1 INPUT层-输入层

        数据input层,输入图像的尺寸为:32*32大小的一维一通道图片。

        注意:①灰度图像是单通道图像,其中每个像素只携带有关光强度的信息;

                   ②RGB图像是彩色图像,为三通道图像;

                   ③传统上输入层不被视为网络层次结构之一,因此输入层不算LeNet的网络结构。

2.2 C1层-卷积层

       输入数据(输入特征图input feature map):32*32

       卷积核大小:5*5

计算公式:

height_{out}=\frac{height_{in}-height_{kernel}+2*padding}{stride}+1width_{out}=\frac{width_{out}-widtht_{kernel}+2*padding}{stride}+1

其中,height_{in}是指输入图片的高度;width_{in}是指输入图片的宽度;height_{kernel}是指卷积核的大小;padding是指向图片外面补边,默认为0;S是指步长,卷积核遍历图片的步长,默认为1。

       卷积核种类(通道数):6

       输出数据(输出特征图output feature map):28*28

2.3 S2层-池化层(下采样层)

       池化是缩小高、长方向上的空间的运算。

       输入数据:28*28

       采样区域:2*2

       采样种类(通道数):6

       输出数据:14*14

注意:①经过池化运算,输入数据和输出数据的通道数不会发生变化。

②此时,S2中每个特征图的大小是C1中每个特征图大小的1/4.

2.4 C3层-卷积层

       输入数据:S2中所有6个或者几个特征map组合

       卷积核大小:5*5

       卷积核种类(通道数):16

       输出数据(输出特征图output feature map):10*10

注意:C3中的每个特征map是连接到S2中的所有6个或者几个特征map的,表示本层的特征map是上一层提取到的特征map的不同组合。

2.5 S4层-池化层(下采样层)

       输入数据:10*10

       采样区域:2*2

       采样种类(通道数):16

       输出数据:5*5

2.6 C5层-卷积层

       输入数据:S4层的全部16个单元特征map(与s4全相连)

       卷积核大小:5*5

       卷积核种类(通道数):120

       输出数据(输出特征图output feature map):1*1

2.7 F6层-全连接层

       输入数据:120维向量

       输出数据:84维向量

2.2 Output层-全连接层

       输入数据:84维向量

       输出数据:10维向量

三、代码实现(采用的激活函数为relu函数)

3.1 搭建网络框架

(1)导包:

import torch
import torch.nn as nn
import torch.nn.functional as F

 (2)定义卷积神经网络:由于训练数据采用的是彩色图片(三通道),因此与上面介绍的通道数有出入。

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)
    def forward(self,x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x,(2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(-1,x.size()[1:].numel())
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

(3)测试网络效果:相当于打印初始化部分

net = Net()
print(net)

3.2 定义数据集

(1)导包:

import torchvision
import torchvision.transforms as transforms

(2)下载数据集:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=0)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=0)

(3)定义元组:进行类别名的中文转换

classes = ('airplane','automobile','bird','car','deer','dog','frog','horse','ship','truck')

 (4)运行数据加载器:使用绘图函数查看数据加载效果

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()

dataiter = iter(trainloader)
images,labels = dataiter.next()

imshow(torchvision.utils.make_grid(images))

print(labels)
print(labels[0],classes[labels[0]])
print(' '.join(classes[labels[j]] for j in range(4)))

3.3 定义损失函数与优化器

(1)定义损失函数:交叉熵损失函数

criterion = nn.CrossEntropyLoss()

(2)定义优化器:让网络进行更新,不断更新好的参数,达到更好的效果

import torch.optim as optim
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

3.4 训练网络

for epoch in range(2):
    
    running_loss = 0.0
    
    for i,data in enumerate(trainloader,0):
        inputs,labels = data
        optimizer.zero_grad()
        
        outputs = net(inputs)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        if i % 2000 == 1999:
            print('[%d,%5d] loss:%.3f' % (epoch + 1,i+1,running_loss/2000))
            running_loss = 0.0

print("Finish")

3.5 测试网络

(1)保存学习好的网络参数:将权重文件保存到本地

PATH='./cifar_net.pth'
torch.save(net.state_dict(),PATH)

(2) 测试一组图片的训练效果

dataiter = iter(testloader)
images,labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print('GroundTruth:',' '.join('%5s'% classes[labels[j]] for j in range(4)))

(3)观察整个训练集的测试效果

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images,labels = data
        outputs = net(images)
        _,predicted = torch.max(outputs,1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

correctGailv = 100*(correct / total)
print(correctGailv)

四、小结

        与“目前的CNN”相比,LeNet有以下几个不同点:

        ①激活函数不同:LeNet使用sigmoid函数,而目前的CNN中主要使用ReLU函数;

        ②原始的LeNet中使用子采样(subsampling)缩小中间数据的大小,而目前的CNN中Max池化是主流。

参考:LeNet详解-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1125712.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

上门预约洗鞋小程序开发;

上门洗鞋小程序服务小程序是一款方便用户与服务提供者进行交流和预约的平台,覆盖多个行业,包括家政清洁、洗衣洗鞋,维修服务等,满足用户在生活中各种需求的上门服务。用户可以在小程序中选择服务项目、预约时间,服务人…

JDK的配置及运行过程

文章目录 介绍JDK编译运行过程为什么要配置环境变量配置环境变量的作用 配置JDK验证ps: 介绍JDK 【面试题】JDK、JRE、JVM之间的关系? JDK(Java Development Kit):Java开发工具包,提供给Java程序员使用,包含了JRE,同时还包含了编译…

Mysql如何理解Sql语句?MySql分析器

1. 什么是 MySQL 分析器? MySQL 分析器是 MySQL 数据库系统中的一个关键组件,它负责解析 SQL 查询语句,确定如何执行这些查询,并生成查询执行计划。分析器将 SQL 语句转换为内部数据结构,以便 MySQL 可以理解和执行查询请求。 …

数据集特征预处理

1、什么是特征预处理 1.1、什么是特征预处理 scikit-learn的解释 provides several common utility functions and transformer classes to change raw feature vectors into a representation that is more suitable for the downstream estimators. 翻译过来:通…

VRPTW(MATLAB):斑马优化算法ZOA求解带时间窗的车辆路径问题VRPTW(提供参考文献及MATLAB代码)

一、VRPTW简介 带时间窗的车辆路径问题(Vehicle Routing Problem with Time Windows, VRPTW)是车辆路径问题(VRP)的一种拓展类型。VRPTW一般指具有容量约束的车辆在客户指定的时间内提供配送或取货服务,在物流领域应用广泛,具有重要的实际意义。VRPTW常…

嵌入式系统设计中时钟抖动的基础

嵌入式开发时钟抖动是时钟沿偏离其理想位置的偏差。了解时钟抖动在应用中非常重要,因为它在系统的时序预算中起着关键作用。它有助于嵌入式开发工程师了解系统时序裕度。 随着系统数据速率的提高,时序抖动已成为系统设计中的关键,因为在某些…

[资源推荐]看到一篇关于agent的好文章

链接在此:Chat 向左,Agent 向右 - 李博杰的文章 - 知乎 https://zhuanlan.zhihu.com/p/662704254当时在电脑知乎上看了一半,打开手机微信公众号,就给我推了同样的,这推荐算法😥今年关于大模型的想法经历了几…

短视频矩阵系统源码搭建/技术应用开发/源头独立搭建

短视频剪辑矩阵系统开发源码----源头搭建 矩阵系统源码主要有三种框架:Spring、Struts和Hibernate。Spring框架是一个全栈式的Java应用程序开发框架,提供了IOC容器、AOP、事务管理等功能。Struts框架是一个MVC架构的Web应用程序框架,用于将数…

功能基础篇8——图形用户界面

图形用户界面 Graphics User Interface,GUI,图形用户界面 Ubuntu GUI Command Line Interface,CLI,命令行界面 Centos CLI tkinter GUI,Python标准库 from tkinter import ttk, Tkroot Tk() frm ttk.Frame(…

网络工程师最强入职指南

大家好,我是老杨。 秋招即将进入尾声,各位都找到心仪的工作了吗? 今年的春秋招的热度好像不是很高,而且很多网工都是在“全年找工作”的状态里持续着,字里行间无不透露出对行业和自身的焦虑。 毕竟“今年是未来10年…

vue3 elementPlus 表格实现行列拖拽及列检索功能

1、安装vuedraggable npm i -S vuedraggablenext 2、完整代码 <template> <div classcontainer><div class"dragbox"><el-table row-key"id" :data"tableData" :border"true"><el-table-columnv-for"…

Qt 资源系统(Qt Resource System)

1. Qt Resource System是什么&#xff1f; Qt 资源系统&#xff08;Qt Resource System&#xff09;是一种将图片、数据存储在二进制文件中的一套系统。构建应用程序需要的不仅仅是代码。通常你的界面会需要图标来做动作&#xff0c;你可能想要添加插图或品牌标识&#xff0c;或…

Spring Boot中RedisTemplate的使用

当前Spring Boot的版本为2.7.6&#xff0c;在使用RedisTemplate之前我们需要在pom.xml中引入下述依赖&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId><vers…

insert overwrite table:数据仓库和数据分析中的常用技术

一、介绍&#xff1a; INSERT OVERWRITE TABLE 是用于覆盖&#xff08;即替换&#xff09;目标表中的数据的操作。它将新的数据写入表中&#xff0c;并删除原有的数据。这个操作适用于非分区表和分区表。 二、使用场景&#xff1a; 1、数据更新&#xff1a;当您需要更新表中…

软考系列(系统架构师)- 2021年系统架构师软考案例分析考点

试题一 软件架构&#xff08;架构风格、质量属性&#xff09; 【问题1】&#xff08;9分&#xff09; 在架构评估过程中&#xff0c;质量属性效用树(utility tree)是对系统质量属性进行识别和优先级排序的重要工具。 请将合适的质量属性名称填入图1-1中(1)、(2)空白处&#xf…

教您2个方法,轻松学会如何克隆硬盘或分区!

为什么需要克隆硬盘或分区&#xff1f; 在现在&#xff0c;学会如何克隆硬盘或分区是很重要的&#xff0c;因为这项技能本身是很简单的&#xff0c;并且也能够为我们带来足够多的好处与便利。 备份恢复&#xff1a;通过克隆硬盘驱动器或分区&#xff0c;您可以创建…

企业文件加密软件!哪个好用?

天锐绿盾是一款专业的企业文件加密软件&#xff0c;提供了多种功能来保护企业文件的安全。它的主要功能包括文件加密、文件外发控制、打印内容监控、内网行为管理、外网安全管理、文件管理控制、邮件白名单管理和U盘认证管理等功能。 PC访问地址&#xff1a; https://isite.ba…

节奏达人疯狂猜歌双端流量主小程序开发

节奏达人疯狂猜歌双端流量主小程序开发 流量主小程序千千万&#xff0c;可以长期运营且留存高的&#xff0c;猜歌小程序必有一席之地。 好运营&#xff1a;依靠社交属性&#xff0c;可以快速短时间裂变。依靠短视频可以快速吸引玩家。 活跃度高&#xff0c;粘性高&#xff0…

0基础学习PyFlink——使用PyFlink的SQL进行字数统计

在《0基础学习PyFlink——Map和Reduce函数处理单词统计》和《0基础学习PyFlink——模拟Hadoop流程》这两篇文章中&#xff0c;我们使用了Python基础函数实现了字&#xff08;符&#xff09;统计的功能。这篇我们将切入PyFlink&#xff0c;使用这个框架实现字数统计功能。 PyFl…

深入解析i++和++i的区别及性能影响

在我们编写代码时&#xff0c;经常需要对变量进行自增操作。这种情况下&#xff0c;我们通常会用到两种常见的操作符&#xff1a;i和i。最近在阅读博客时&#xff0c;我偶然看到了有关i和i性能的讨论。之前我一直在使用它们&#xff0c;但从未从性能的角度考虑过&#xff0c;这…