【从零开始学习深度学习】25.卷积神经网络之LeNet模型介绍及其Pytorch实现【含完整代码】

news2025/1/20 4:36:08

目录

    • 1. LeNet模型介绍与实现
    • 2. 输入为Fashion-MNIST时各层输出形状
    • 3. 获取Fashion-MNIST数据和并使用LeNet模型进行训练
    • 4.完整代码

之前我们对Fashion-MNIST数据集中的图像进行分类时,是将28*28图像中的像素逐行展开,得到长度为784的向量,并输入进全连接层中进行计算,这种分类方法有一定的局限性。

  1. 图像在同一列邻近的像素在这个向量中可能相距较远。它们构成的模式可能难以被模型识别。
  2. 对于大尺寸的输入图像,使用全连接层容易造成模型过大。假设输入是高和宽均为1000像素的彩色照片(含3个通道)。即使全连接层输出个数仍是256,该层权重参数的形状是 3 , 000 , 000 × 256 3,000,000\times 256 3,000,000×256:它占用了大约3 GB的内存或显存。这带来过复杂的模型和过高的存储开销。

卷积层尝试解决这两个问题:

一方面,卷积层保留输入形状,使图像的像素在高和宽两个方向上的相关性均可能被有效识别;

另一方面,卷积层通过滑动窗口将同一卷积核与不同位置的输入重复计算,从而避免参数尺寸过大。

卷积神经网络就是含卷积层的网络。本文我们将介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet 。

Lenet 是一系列网络的合称,包括 Lenet1 - Lenet5,由 Yann LeCun 等人在 1990 年《Handwritten Digit Recognition with a Back-Propagation Network》中提出,是卷积神经网络的 HelloWorld。LeNet展示了通过梯度下降训练卷积神经网络可以达到手写数字识别在当时最先进的结果。这个奠基性的工作第一次将卷积神经网络推上舞台,为世人所知。LeNet5的网络结构如下图所示。

在这里插入图片描述

1. LeNet模型介绍与实现

LeNet分为卷积层块全连接层块两个部分。下面我们分别介绍这两个模块。

卷积层块里的基本单位是卷积层后接最大池化层:卷积层用来识别图像里的空间模式,如线条和物体局部,之后的最大池化层则用来降低卷积层对位置的敏感性。卷积层块由两个这样的基本单位重复堆叠构成。在卷积层块中,每个卷积层都使用 5 × 5 5\times 5 5×5的窗口,并在输出上使用sigmoid激活函数。第一个卷积层输出通道数为6,第二个卷积层输出通道数则增加到16。这是因为第二个卷积层比第一个卷积层的输入的高和宽要小,所以增加输出通道使两个卷积层的参数尺寸类似。卷积层块的两个最大池化层的窗口形状均为 2 × 2 2\times 2 2×2,且步幅为2。由于池化窗口与步幅形状相同,池化窗口在输入上每次滑动所覆盖的区域互不重叠。

卷积层块的输出形状为(批量大小, 通道, 高, 宽)。当卷积层块的输出传入全连接层块时,全连接层块会将小批量中每个样本变平(flatten)。也就是说,全连接层的输入形状将变成二维,其中第一维是小批量中的样本,第二维是每个样本变平后的向量表示,且向量长度为通道、高和宽的乘积。全连接层块含3个全连接层。它们的输出个数分别是120、84和10,其中10为输出的类别个数。

下面我们通过Sequential类来实现LeNet模型。

import time
import torch
from torch import nn, optim

import sys
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 卷积神经网络
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2), # kernel_size, stride
            nn.Conv2d(6, 16, 5),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2)
        )
        # 分类器
        self.fc = nn.Sequential(
            nn.Linear(16*4*4, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, img):
        feature = self.conv(img)
        # 将feature展平,传入分类器fc
        output = self.fc(feature.view(img.shape[0], -1))   
        return output

接下来查看每个层的形状。

net = LeNet()
print(net)

输出:

LeNet(
  (conv): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Sigmoid()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Sigmoid()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): Sigmoid()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): Sigmoid()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)

可以看到,在卷积层块中输入的高和宽在逐层减小。卷积层由于使用高和宽均为5的卷积核,从而将高和宽分别减小4,而池化层则将高和宽减半,但通道数则从1增加到16。全连接层则逐层减少输出个数,直到变成图像的类别数10。

2. 输入为Fashion-MNIST时各层输出形状

如果输入为Fashion-MNIST数据集,那么各层的形状的变化过程如下:

self.conv = nn.Sequential(
			# 输入:1*28*28
            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
            # 输出:6 * 24 * 24    【24=28-5+1】
            nn.Sigmoid(),
    		# 输出:6 * 24 * 24
            nn.MaxPool2d(2, 2), # kernel_size, stride
    		# 输出:6 * 12 * 12    【12=(24-2+2)/2】
            nn.Conv2d(6, 16, 5),
    		# 输出:16 * 8 * 8     【8=12-5+1】
            nn.Sigmoid(),
    		# 输出:16 * 8 * 8
            nn.MaxPool2d(2, 2)
    		# 输出:16 * 4 * 4     【4=(8-2+2)/2】
        )
        # 分类器
        self.fc = nn.Sequential(
            # 输入:16*4*4
            nn.Linear(16*4*4, 120),
            # 输出:120
            nn.Sigmoid(),
            nn.Linear(120, 84),
            # 输出:84
            nn.Sigmoid(),
            nn.Linear(84, 10)
            # 输出:10
        )

3. 获取Fashion-MNIST数据和并使用LeNet模型进行训练

下面我们运用LeNet模型对Fashion-MNIST数据集进行训练。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

因为卷积神经网络计算比多层感知机要复杂,建议使用GPU来加速计算。定义评价函数evaluate_accuracy,能同时支持GPU与CPU计算。

def evaluate_accuracy(data_iter, net, device=None):
    if device is None and isinstance(net, torch.nn.Module):
        # 如果没指定device就使用net的device
        device = list(net.parameters())[0].device
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(net, torch.nn.Module):
                net.eval() # 评估模式, 这会关闭dropout
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train() # 改回训练模式
            else: 
                if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                    # 将is_training设置成False
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
            n += y.shape[0]
    return acc_sum / n

定义train_ch3训练函数,确保计算使用的数据和模型同在内存或显存上。

def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

学习率采用0.001,训练算法使用Adam算法,损失函数使用交叉熵损失函数。

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cpu
epoch 1, loss 1.7832, train acc 0.341, test acc 0.595, time 15.3 sec
epoch 2, loss 0.9300, train acc 0.649, test acc 0.705, time 15.5 sec
epoch 3, loss 0.7574, train acc 0.722, test acc 0.731, time 15.6 sec
epoch 4, loss 0.6708, train acc 0.745, test acc 0.743, time 15.6 sec
epoch 5, loss 0.6165, train acc 0.762, test acc 0.764, time 15.8 sec

4.完整代码

import time
import torch
from torch import nn, optim

import sys
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义模型
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 卷积神经网络
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2), # kernel_size, stride
            nn.Conv2d(6, 16, 5),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2)
        )
        # 分类器
        self.fc = nn.Sequential(
            nn.Linear(16*4*4, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, img):
        feature = self.conv(img)
        # 将feature展平,传入分类器fc
        output = self.fc(feature.view(img.shape[0], -1))   
        return output

# 定义评价函数
def evaluate_accuracy(data_iter, net, device=None):
    if device is None and isinstance(net, torch.nn.Module):
        # 如果没指定device就使用net的device
        device = list(net.parameters())[0].device
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(net, torch.nn.Module):
                net.eval() # 评估模式, 这会关闭dropout
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train() # 改回训练模式
            else: 
                if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                    # 将is_training设置成False
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
            n += y.shape[0]
    return acc_sum / n

# 定义训练函数
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

# 使用模型进行训练
net = LeNet()
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/93040.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Cloud基于JWT创建统一的认证服务

认证服务肯定要有用户信息,不然怎么认证是否为合法用户?因为是内部的调用认证,可以简单一点,用数据库管理就是一种方式。或者可以配置用户信息,然后集成分布式配置管理就完美了。 表结构 本教程中的案例把查数据库这…

2022-年终总结

2022年已经到了尾声,后半年度过的太漫长了,也是自己这两年来成长速度最快的一次了(后文揭晓) 今年的年中总结链接 上半年我沉浸在读各类技术书籍中,但是后半年的我几乎放弃了读书,转而投身到另外一个学习渠…

Linux Phy 驱动解析

文章目录1. 简介2. phy_device2.1 mdio bus2.2 mdio device2.3 mdio driver2.4 poll task2.4.1 自协商配置2.4.2 link 状态读取2.4.3 link 状态通知3. phylink3.1 phylink_create()3.2 phylink_connect_phy()3.3 phylink_start()3.3 poll task参考资料1. 简介 在调试网口驱动的…

从另外一个角度解释AUC

AUC到底代表什么呢,我们从另外一个角度解释AUC,我们先看看一个auc曲线 蓝色曲线下的面积(我的模型的AUC)比红线下的面积(理论随机模型的AUC)大得多,所以我的模型一定更好。 我的模型比随机模型好多少呢?理论随机模型只是对角线,…

加密与认证技术

加密与认证技术密码技术概述密码算法与密码体制的基本概念加密算法与解密算法秘钥的作用什么是密码密钥长度对称密码体系对称加密的基本概念典型的对称加密算法DES加密算法3DES加密算法非对称密码体系非对称加密基本概念密码技术概述 密码技术是保证网络安全的核心技术之一&am…

【windows Server 2019系列】 构建IIS服务器

个人名片: 对人间的热爱与歌颂,可抵岁月冗长🌞 Github👨🏻‍💻:念舒_C.ying CSDN主页✏️:念舒_C.ying 个人博客🌏 :念舒_C.ying Web服务器也称为WWW(World W…

电子厂测试题——难倒众多主播——大司马也才90分

一、选择题 1、1-2 ( ) A.1 B.3 C.-1 D.-3 2、|1-2|( ) A.1 B.3 C. -1 D.-3 3、1x2x3( ) A.5 B.6 C.7 D.8 4、3643( ) A.29 B.16 C.8 D.3 5、55x5( ) A.15 B.30 C.50 D.125 二、填空题(请填写阿拉伯数字) 6、110100 1000_______ 7、一个三角形砍去1个角&#…

Feign的两种最佳实践方式介绍

何谓最佳实践呢?就是企业中各种踩坑,最后总结出来的相对比较好的使用方式; 下面给大家介绍两种比较好的实践方案: 方式一(继承):给消费者的FeignClient和提供着的Controller定义一个统一的父接…

在逆变器中驱动和保护IGBT

在逆变器中驱动和保护IGBT 介绍 ACPL-339J是一款先进的1.0 A双输出,易于使用,智能的手机IGBT门驱动光耦合器接口。专为支持而设计MOSFET制造商的各种电流评级,ACPL-339J使它更容易为系统工程师支持不同的系统额定功率使用一个硬件平台通过…

全面解析若依框架(springboot-vue前后分离--后端部分)

1、 若依框架分解 - 启动配置 前端启动 # 进入项目目录 cd ruoyi-ui# 安装依赖 npm install# 强烈建议不要用直接使用 cnpm 安装,会有各种诡异的 bug,可以通过重新指定 registry 来解决 npm 安装速度慢的问题。 npm install --registryhttps://regist…

算法刷题打卡第47天:排序数组---归并排序

排序数组 难度:中等 给你一个整数数组 nums,请你将该数组升序排列。 示例 1: 输入:nums [5,2,3,1] 输出:[1,2,3,5]示例 2: 输入:nums [5,1,1,2,0,0] 输出:[0,0,1,1,2,5]归并排…

用CSS给健身的侣朋友做一个喝水记录本

前言 事情是这样的,由于七八月份的晚上时不时就坐在地摊上开始了喝酒撸串的一系列放肆的长肉肉项目。 这不,前段时间女朋友痛下决心(心血来潮)地就去报了一个健身的私教班,按照教练给的饮食计划中,其中有一…

卵巢早衰与微生物群,营养治疗新进展

卵巢早衰 卵巢早衰(premature ovarian insufficiency,简称POI)在生殖系统疾病中位居首位,这些疾病可能会损害多个功能系统,降低生活质量,最终剥夺女性患者的生育能力。 目前的激素替代疗法不能改善受孕或降…

NR PDSCH(七) DL SPS

非动态调度,除了PUSCH configured grant type 1和2的传输,还有PDSCH SPS 传输,两者的流程基本类似,也有些小区别。在实网并没有见过配置DL SPS PDSCH传输的log,但还是按顺序理一遍相关内容。 RRC/MAC 先看下MAC 38.32…

文件上传,还存储在应用服务器?

一般项目开发中都会有文件、图片、视频等文件上传并能够访问的场景。要实现这样的场景,要么把文件存储在应用服务器上,要么搭建文件服务来存储。但是这两种方式也有不少的缺点,增加运维的成本。 因此,追求用户体验的项目可能会考…

Tomcat安装配置全解

👌 棒棒有言:也许我一直照着别人的方向飞,可是这次,我想要用我的方式飞翔一次!人生,既要淡,又要有味。凡事不必太在意,一切随缘,缘深多聚聚,缘浅随它去。凡事…

数据库分库分表

文章目录为什么要分库分表?数据切分垂直切分水平切分(每个表的结构相同)范围拆分取模拆分(一般为业务主键)分库分表带来的问题数据倾斜问题热点问题事务问题聚合查询问题分页问题非分区业务查询分库分表实现或工具hash…

DSP篇--C6701功能调试系列之 UART串口测试

目录 1、原理 2、测试 调试的前期准备可以参考前面的博文:DSP篇--C6701功能调试系列之前期准备_nanke_yh的博客-CSDN博客 UART串口收发数据存在两种模式:通常的串口模式(McBSP in Serial Port Mode)和GPIO模式(McBS…

哈希表及其与Java类集的关系

目录 1.哈希表的概念 2.哈希冲突 3.如何避免哈希冲突? 3.1哈希函数设计 3.2 负载因子的调节 4.解决哈希冲突 4.1闭散列 4.1.1线性探测 4.1.2二次探测 4.2开散列(哈希桶) 5.HashMap 6.HashSet 1.哈希表的概念 假设有一组数据,要让你去搜索其中的一个关键码,这种场…

JWT快速入门及所需依赖

目录 1.JWT 1.1什么是JWT 1.2JWT的构成 jwt的头部 payload signature 1.3JWT快速入门案例 2Jwt认证(微服务) 2.1微服务下统一权限认证 2.2应用认证 3.无状态的JWT令牌如何实现续签功能? 3.1不允许改变Token令牌实现续签 3.2允许改…