【动手学深度学习】(十一)池化层+LeNet

news2024/11/19 22:34:33

文章目录

  • 一、池化层
    • 1.理论知识
    • 2.代码
  • 二、LeNet
    • 1.理论知识
    • 2.代码实现
  • 【相关总结】
    • nn.MaxPool2d()

卷积层对位置比较敏感

一、池化层

1.理论知识

二维最大池化
在这里插入图片描述
在这里插入图片描述
填充、步幅和多个通道

  • 池化层与卷积层类似,都具有填充和步幅
  • 没有可学习的参数
  • 在每个输入通道应用池化层以获得相应的输出通道
  • 输出通道数=输入通道数

平均池化层

  • 最大池化层:每个窗口中最强的模式信号
  • 平均池化层:将最大池化层中的“最大”操作替换为“平均”

2.代码

实现池化层的正向传播

import torch
from torch import nn
from d2l import torch as d2l

def pool2d(X, pool_size, mode='max'):
    p_h, p_w = pool_size
    Y = torch.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            if mode == 'max':
                Y[i, j] = X[i:i + p_h, j:j+p_w].max()
            elif mode == 'avg':
                Y[i, j] = X[i:i + p_h, j:j + p_w].mean()
                
    return Y
# 验证二维最大池化层的输出
X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
pool2d(X, (2,2))
# print(Y)

tensor([[4., 5.],
[7., 8.]])

# 验证平均池化层
pool2d(X,(2,2),'avg')

tensor([[2., 3.],
[5., 6.]])

X = torch.arange(16, dtype=torch.float32).reshape((1,1,4,4))
# X

# 深度学习框架中的步幅与池化窗口的大小相同
pool2d = nn.MaxPool2d(3)
pool2d(X)

tensor([[[[10.]]]])

# 手动指定步幅和填充
pool2d = nn.MaxPool2d(3, padding=1, stride=2)
pool2d(X)

tensor([[[[ 5., 7.],
[13., 15.]]]])

# 设定一个任意大小的矩形池化窗口
pool2d = nn.MaxPool2d((2,3), padding=(1,1), stride=(2,3))
pool2d(X)

tensor([[[[ 1., 3.],
[ 9., 11.],
[13., 15.]]]])

X = torch.cat((X, X + 1), 1)
# Y2 = torch.stack((X,X+1))
# print(Y)
# print(Y2)
X

tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]],
[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]])

# print(X.shape)
pool2d = nn.MaxPool2d(3, padding=1, stride=2)
pool2d(X)

tensor([[[[ 5., 7.],
[13., 15.]],
[[ 6., 8.],
[14., 16.]]]])

二、LeNet

1.理论知识

在这里插入图片描述
LeNet-5的典型结构:

  • 输入层:输入图像大小为32*32
  • 第一层:卷积核大小为5*5,输出通道数为6
  • 第二层:大小为2*2的平均池化层,步幅为2
  • 第三层:卷积核大小为5,输出通道为16
  • 第四层:大小为2*2的平均池化层,步幅为2
  • 第五层:120个神经元的全连接层
  • 第六层:84个神经元的全连接层
  • 输出层:10个神经元,对应于10个类别
    在这里插入图片描述

2.代码实现

LeNet由两个部分组成:卷积编码器和全连接层密集块

import torch
from torch import nn
from d2l import torch as d2l

class Reshape(torch.nn.Module):
    def forward(self, x):
        return x.view(-1, 1, 28, 28)
    
net = torch.nn.Sequential(
    Reshape(), nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10)
)
# print(net)

检查模型

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Reshape output shape:	 torch.Size([1, 1, 28, 28])
Conv2d output shape:	 torch.Size([1, 6, 28, 28])
Sigmoid output shape:	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape:	 torch.Size([1, 6, 14, 14])
Conv2d output shape:	 torch.Size([1, 16, 10, 10])
Sigmoid output shape:	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape:	 torch.Size([1, 16, 5, 5])
Flatten output shape:	 torch.Size([1, 400])
Linear output shape:	 torch.Size([1, 120])
Sigmoid output shape:	 torch.Size([1, 120])
Linear output shape:	 torch.Size([1, 84])
Sigmoid output shape:	 torch.Size([1, 84])
Linear output shape:	 torch.Size([1, 10])

LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
train_iter.num_workers = 0
test_iter.num_workers = 0

对evaluate_accuracy函数进行改进

def evaluate_accuracy_gpu(net, data_iter, device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, torch.nn.Module):
        net.eval()
        if not device:
            device = next(iter(net.parameters())).device
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        if isinstance(X,list):
            X = [x.to(device) for x in X]
        else:
            X = X.to(device)
        y = y.to(device)
#       将当前批次的正确预测数量和总样本数
        metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

为了使用GPU,我们还需要修改

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型"""
#   初始化神经网络的权重
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1,num_epochs],
                           legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
#       训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
#     将神经网络设置为训练模式
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
#           计算平均训练损失和平均训练准确率
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
#         控制输出频率,确保训练信息在每个 epoch的五分之一处和最后一个迭代时被输出
            if(i+1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                            (train_l, train_acc, None))
#       在测试数据集上评估模型的准确率  
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')
# 训练和评估LeNet-5模型
torch.cuda.set_device(0)
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

loss 0.461, train acc 0.827, test acc 0.818
23915.6 examples/sec on cuda:0
在这里插入图片描述

【相关总结】

nn.MaxPool2d()

torch.nn.MaxPool2d(kernel_size, [stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False])

  • kernel_size:池化窗口大小,当为一个整数时,表示为一个方形,否则需要输入一个包含长宽的元组。
  • stride:窗口移动的步长,!!!默认是kernel_size
import torch
import torch.nn as nn

# 定义一个最大池化层,窗口大小为 3x3
max_pool_layer = nn.MaxPool2d(3)

# 创建一个输入张量(假设是一张图像)
input_data = torch.rand(1, 1, 5, 5)  # (batch_size, channels, height, width)

# 使用最大池化层进行池化操作
output_data = max_pool_layer(input_data)

print("Input data:")
print(input_data)

print("\nOutput data after max pooling:")
print(output_data)

Input data:
tensor([[[[0.0636, 0.8813, 0.3543, 0.8072, 0.7034],
[0.0906, 0.2161, 0.3276, 0.7605, 0.5871],
[0.3102, 0.9458, 0.7694, 0.7519, 0.5355],
[0.0510, 0.6437, 0.4188, 0.0824, 0.0427],
[0.5253, 0.1354, 0.7783, 0.6787, 0.4483]]]])

Output data after max pooling:
tensor([[[[0.9458]]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1298111.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【出现模块node_modules里面包找不到】

#pic_center R 1 R_1 R1​ R 2 R^2 R2 目录 一、出现的问题二、解决办法三、其它可供参考 一、出现的问题 在本地运行 npm run docs:dev之后,出现 Error [ERR_MODULE_NOT_FOUND]: Cannot find package Z:\Blog\docs\node_modules\htmlparser2\ imported from Z:\Blo…

CCF计算机软件能力认证202309-1坐标变换(其一)(C语言)

ccf-csp计算机软件能力认证202309-1坐标变换(其一)(C语言版) 题目内容: 问题描述 输入格式 输出格式 样例输入 3 2 10 10 0 0 10 -20 1 -1 0 0样例输出 21 -11 20 -10样例解释 评测用例规模与约定 解题思路 1.第一步分析问题&…

Redux Toolkit(RTK)在React tsx中的使用

一个需求: header组建中有一个搜索框,然后这个搜索框在其他页面路由上都可以使用:例如这两个图共用顶部的搜索框; 我之前的做法就是组建传值, 在他们header 组建和 PageA ,B 的父级组件上定一个值,然后顶部变化传到父级组件,在从父级组件传到page组件,有点繁琐,现在说一下利用…

Javaweb之 依赖管理的详细解析

04. 依赖管理 4.1 依赖配置 依赖:指当前项目运行所需要的jar包。一个项目中可以引入多个依赖: 例如:在当前工程中,我们需要用到logback来记录日志,此时就可以在maven工程的pom.xml文件中,引入logback的依…

16.Java程序设计-基于SSM框架的android餐厅在线点单系统App设计与实现

摘要: 本研究旨在设计并实现一款基于SSM框架的Android餐厅在线点单系统,致力于提升餐厅点餐流程的效率和用户体验。通过整合Android移动应用和SSM框架的优势,该系统涵盖了用户管理、菜单浏览与点单、订单管理、支付与结算等多个功能模块&…

解决 Cannot read properties of undefined (reading ‘getUserMedia‘) 报错

[TOC](解决 Cannot read properties of undefined (reading ‘getUserMedia’) 报错) 0. 背景 使用浏览器输入语音时,浏览器的控制台里面有下面错误信息。 Cannot read properties of undefined (reading getUserMedia)1. 解决方法 在浏览器中访问 chrome://fla…

AVFormatContext编解码层:理论与实战

文章目录 前言一、FFmpeg 解码流程二、FFmpeg 转码流程三、编解码 API 详解1、解码 API 使用详解2、编码 API 使用详解 四、编码案例实战1、示例源码2、运行结果 五、解码案例实战1、示例源码2、运行结果 前言 AVFormatContext 是一个贯穿始终的数据结构,很多函数都…

Java集合框架定义以及整体结构

目录 一、Java集合框架1.1 什么是java集合框架1.2 集合与数组 二、集合框架具体内容2.1 整体框架2.2 遗留类和遗留接口1.3 集合框架设计特点 参考资料 一、Java集合框架 1.1 什么是java集合框架 Java集合框架(Java Collections Framework)是Java平台提…

二叉树的遍历之迭代遍历

前言:在学习二叉树的时候我们基本上已经了解过二叉树的三种遍历,对于这三种遍历,我们采用递归的思路,很简单的就能实现,那么如何用迭代的方法去解决问题? 我们首先来看第一个: 前序遍历 144.…

代码随想录二刷 |二叉树 |二叉树的层平均值

代码随想录二刷 |二叉树 |二叉树的层平均值 题目描述解题思路代码实现 题目描述 637.二叉树的层平均值 给定一个非空二叉树的根节点 root , 以数组的形式返回每一层节点的平均值。与实际答案相差 10-5 以内的答案可以被接受。 示例 1: 输…

Avaya Aura Device Services 任意文件上传漏洞复现

0x01 产品简介 Avaya Aura Device Services是美国Avaya公司的一个应用软件。提供一个管理 Avaya 端点功能。 0x02 漏洞概述 Avaya Aura Device Services 系统PhoneBackup接口处存在任意文件上传漏洞,攻击者可绕过验证上传任意文件获取服务器权限。 0x03 影响范围…

结构体和位段

结构体: C语言中,我们之前使用的都是C语言中内置的类型,比如整形(int)、字符型(char)、单精度浮点型(float)等。但是我们知道,我们现实世界中,还…

用Rust刷LeetCode之27 移除元素

27. 移除元素 难度: 简单 原描述: 新描述: func removeElement(nums []int, val int) int { for i : 0; i < len(nums); i { if nums[i] val { nums append(nums[:i], nums[i1:]...) i-- } } return len(nums)} Rust 版本 下面这种写法编译无法通过: pub fn remove_…

b样条原理与测试

为了保留贝塞尔曲线的优点&#xff0c;同时克服贝塞尔曲线的缺点&#xff0c;b样条在贝塞尔曲线上发展而来&#xff0c;首先来看贝塞尔曲线的定义&#xff1a; 对于贝塞尔中的基函数而言&#xff0c;是确定的&#xff0c;全局唯一的&#xff0c;这导致了如果控制点发生变换将会…

Linux基本指令(超详版)

Linux基本指令&#xff08;超详版&#xff09; 1. ls指令2.pwd指令3. cd 指令4.touch指令5mkdir指令6.rmdir指令&&rm指令7.man指令7.cp指令8.mv指令9.echo指令10.cat指令11.more指令12.less指令13.head指令14.tail指令15.date指令16.find指令17.grep指令zip(打包压缩) …

使用cmake构建Qt6.6的qt quick项目,添加应用程序图标的方法

最近&#xff0c;在学习qt的过程中&#xff0c;遇到了一个难题&#xff0c;不知道如何给应用程序添加图标&#xff0c;按照网上的方法也没有成功&#xff0c;后来终于自己摸索出了一个方法。 1、准备一张图片作为图标&#xff0c;保存到工程目录下面&#xff0c;如logo.ico。 …

二维码智慧门牌管理系统:引领未来的城市管理

文章目录 前言一、主要特点二、升级带来的优势与意义 前言 随着科技的快速发展&#xff0c;智能化管理已经成为我们生活和工作的重要方面。门牌管理系统是城市管理的基础设施之一&#xff0c;其智能化程度直接影响着城市管理的效率和质量。为了适应这一需求&#xff0c;二维码…

Helio 升级为 LISTA DAO,开启多链时代新篇章并宣布积分空投计划

Helio Protocol 是 BNB Chain 上排名第一的去中心化稳定币协议&#xff0c;其推出的超额抵押和清算机制支持的去中心化稳定币 HAY&#xff0c;在 BNB Chain 有非常广泛的应用&#xff0c;包括流动性挖掘、质押、交易、储值等&#xff01; 2023 年 7 月&#xff0c;Helio Protoc…

【小沐学Python】Python实现语音识别(SpeechRecognition)

文章目录 1、简介2、安装和测试2.1 安装python2.2 安装SpeechRecognition2.3 安装pyaudio2.4 安装pocketsphinx&#xff08;offline&#xff09;2.5 安装Vosk &#xff08;offline&#xff09;2.6 安装Whisper&#xff08;offline&#xff09; 3 测试3.1 命令3.2 fastapi3.3 go…

C#注册表技术及操作

目录 一、注册表基础 1.Registry和RegistryKey类 &#xff08;1&#xff09;Registry类 &#xff08;2&#xff09;RegistryKey类 二、在C#中操作注册表 1.读取注册表中的信息 &#xff08;1&#xff09;OpenSubKey()方法 &#xff08;2&#xff09;GetSubKeyNames()…