PyTorch 池化层详解

news2024/9/20 21:51:03

在深度学习中,池化层(Pooling Layer)是卷积神经网络(CNN)中的关键组成部分。池化层的主要功能是对特征图进行降维和减少计算量,同时增强模型的鲁棒性。本文将详细介绍池化层的作用、种类、实现方法,并对比其与卷积层的异同,以及深入探讨全局池化的应用。

1. 池化层的作用

池化层的核心作用包括以下几个方面:

  1. 降维:通过池化操作,可以减少特征图的空间尺寸(高度和宽度),从而降低计算复杂度。
  2. 特征提取:池化层保留局部区域的显著特征,如边缘、纹理等。
  3. 抑制噪声:池化操作可以抑制输入特征图中的局部噪声,提高模型的鲁棒性。
  4. 防止过拟合:通过减少特征图的尺寸和参数数量,池化层有助于防止模型过拟合。
2. 池化层的类型

池化层主要包括最大池化(Max Pooling)和平均池化(Average Pooling),此外还有全局池化(Global Pooling)。

2.1 最大池化(Max Pooling)

最大池化选取池化窗口内的最大值作为输出。这种方法能够保留特征图中最显著的特征,通常用于提取边缘等强特征。

import torch
import torch.nn as nn

# 创建一个二维最大池化层,池化窗口大小为2x2,步幅为2x2
maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)

# 输入数据为 (batch_size, channels, height, width)
input_tensor = torch.tensor([[[[1, 2, 3, 4],
                               [5, 6, 7, 8],
                               [9, 10, 11, 12],
                               [13, 14, 15, 16]]]], dtype=torch.float32)

# 进行池化操作
output_tensor = maxpool2d(input_tensor)
print(output_tensor)

输出结果为:

tensor([[[[ 6.,  8.],
          [14., 16.]]]])
2.1.1 最大池化的详细计算过程

最大池化(Max Pooling)是一种常见的池化操作,用于对输入特征图进行降维和特征提取。其核心思想是通过池化窗口(也称为滤波器)在特征图上滑动,并在每个窗口内选取最大值作为该窗口的输出,从而形成一个新的、尺寸较小的特征图。

1. 池化窗口(Pooling Window)

池化窗口是一个固定大小的矩形区域,通常用kernel_size参数指定。例如,kernel_size=2表示一个2x2的池化窗口。池化窗口在特征图上滑动,滑动的步幅用stride参数指定。例如,stride=2表示池化窗口每次滑动2个单位。

2. 操作过程

假设我们有一个输入特征图,每个池化窗口覆盖特征图的一部分,最大池化的具体操作步骤如下:

  1. 选择窗口位置:将池化窗口放置在特征图的左上角,覆盖一个kernel_size大小的区域。
  2. 计算最大值:在这个窗口内,找出所有元素的最大值。
  3. 记录结果:将这个最大值记录到输出特征图的对应位置。
  4. 滑动窗口:按照stride参数指定的步幅,滑动池化窗口到新的位置,重复步骤2和步骤3,直到整个特征图都被池化窗口覆盖。
3. 示例

假设我们有一个4x4的特征图,池化窗口大小为2x2,步幅为2。具体操作如下:

输入特征图:

[[1, 3, 2, 4],
 [5, 6, 8, 7],
 [4, 2, 1, 0],
 [9, 7, 3, 2]]

池化过程:

  1. 第一个窗口覆盖位置(左上角2x2):
    [[1, 3],
     [5, 6]]
    
    最大值为6。
  2. 第二个窗口覆盖位置(右上角2x2):
    [[2, 4],
     [8, 7]]
    
    最大值为8。
  3. 第三个窗口覆盖位置(左下角2x2):
    [[4, 2],
     [9, 7]]
    
    最大值为9。
  4. 第四个窗口覆盖位置(右下角2x2):
    [[1, 0],
     [3, 2]]
    
    最大值为3。

输出特征图:

[[6, 8],
 [9, 3]]

请添加图片描述

4. 代码实现

以下是使用PyTorch实现上述最大池化操作的代码示例:

import torch
import torch.nn as nn

# 定义一个2x2的最大池化层,步幅为2
maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)

# 输入数据 (1, 1, 4, 4) 表示 (batch_size, channels, height, width)
input_tensor = torch.tensor([[[[1, 3, 2, 4],
                               [5, 6, 8, 7],
                               [4, 2, 1, 0],
                               [9, 7, 3, 2]]]], dtype=torch.float32)

# 进行池化操作
output_tensor = maxpool2d(input_tensor)
print(output_tensor)

输出结果为:

tensor([[[[6., 8.],
          [9., 3.]]]])
2.2 平均池化(Average Pooling)

平均池化计算池化窗口内的平均值作为输出。它能够平滑特征图,通常用于减少噪声。

import torch
import torch.nn as nn

# 创建一个二维平均池化层,池化窗口大小为2x2,步幅为2x2
avgpool2d = nn.AvgPool2d(kernel_size=2, stride=2)

# 输入数据为 (batch_size, channels, height, width)
input_tensor = torch.tensor([[[[1, 2, 3, 4],
                               [5, 6, 7, 8],
                               [9, 10, 11, 12],
                               [13, 14, 15, 16]]]], dtype=torch.float32)

# 进行池化操作
output_tensor = avgpool2d(input_tensor)
print(output_tensor)

输出结果为:

tensor([[[[ 3.5,  5.5],
          [11.5, 13.5]]]])
3. 全局池化(Global Pooling)

全局池化是一种特殊的池化操作,它将整个特征图缩小为一个单独的值。全局池化通常用于卷积神经网络的最后一个池化层,目的是将特征图的空间维度完全去除,从而得到一个固定大小的输出。这对于连接全连接层(Fully Connected Layer)或进行分类任务非常有用。

3.1 全局平均池化(Global Average Pooling)

全局平均池化计算整个特征图的平均值。

import torch
import torch.nn as nn

# 定义一个全局平均池化层
global_avgpool = nn.AdaptiveAvgPool2d((1, 1))

# 输入数据 (batch_size, channels, height, width)
input_tensor = torch.tensor([[[[1, 3, 2, 4],
                               [5, 6, 8, 7],
                               [4, 2, 1, 0],
                               [9, 7, 3, 2]]]], dtype=torch.float32)

# 进行全局平均池化操作
output_tensor = global_avgpool(input_tensor)
print("全局平均池化后的特征图:", output_tensor)

输出结果为:

全局平均池化后的特征图: tensor([[[[4.2500]]]])
3.2 全局最大池化(Global Max Pooling)

全局最大池化计算整个特征图的最大值。

import torch
import torch.nn as nn

# 定义一个全局最大池化层
global_maxpool = nn.AdaptiveMaxPool2d((1, 1))

# 输入数据 (batch_size, channels, height, width)
input_tensor = torch.tensor([[[[1, 3, 2, 4],
                               [5, 6, 8, 7],
                               [4, 2, 1, 0],
                               [9, 7, 3, 2]]]], dtype=torch.float32)

# 进行全局最大池化操作
output_tensor = global_maxpool(input_tensor)
print("全局最大池化后的特征图:", output_tensor)

输出结果为:

全局最大池化后的特征图: tensor([[[[9.]]]])
3.3 全局池化的应用

全局池化在深度学习模型中有许多应用,特别是在卷积神经网络(CNN)中。以下是一些常见的应用场景:

  1. 简化模型结构:全局池化可以将特征图的空间维度完全去除,从而简化模型结构。这使得模型在处理不同尺寸的输入时更加灵活。
  2. 减少参数:全局池化可以减少全连接层的参数数量,因为它将特征图缩小为一个固定大小的输出。这有助于降低模型的复杂度和过拟合风险。
  3. 提高模型的泛化能力:全局池化通过聚合整个特征图的信息,可以提高模型的泛化能力,使其在不同数据集上表现更好。
3.4 全局池化与传统池化的对比
特性传统池化(如 MaxPool2d, AvgPool2d)全局池化(Global Pooling)
池化窗口大小固定大小(如 2x2, 3x3)覆盖整个特征图
输出尺寸依据池化窗口大小和步幅固定为 1x1
主要用途局部特征提取和降维全局特征聚合和降维
计算复杂度较低较低
参数数量无参数无参数
4. 池化层和卷积层的对比

池化层和卷积层在使用滑动窗口和降维方面有相似之处,但它们的功能和作用不同。

相似之处
  1. 滑动窗口(Kernel):两者都使用固定大小的窗口在特征图上滑动。
  2. 降维:两者都可以通过设置适当的步幅(stride)来减少特征图的空间尺寸。
  3. 步幅(Stride):两者都可以设置步幅来控制滑动窗口的移动步长,从而影响输出特征图的大小。
不同之处
  1. 操作性质

    • 池化层:主要用于降维和特征选择,操作较为简单(如最大值或平均值计算)。池化层无参数更新,不涉及学习过程。
    • 卷积层:用于特征提取,通过卷积运算捕捉局部特征。卷积层包含可学习的参数(卷积核),这些参数通过反向传播进行更新。
  2. 输出特征图的内容

    • 池化层:输出的特征图是输入特征图的一种精简表示,保留了局部区域的显著特征(如最大值或平均值)。
    • 卷积层:输出的特征图是通过卷积核的加权求和得到的,能够捕捉到输入特征图的不同特征(如边缘、纹理等)。
  3. 学习能力

    • 池化层:无学习能力,不含可学习的参数。
    • 卷积层:具有学习能力,卷积核参数通过训练过程进行优化。
5. 计算输出特征图的大小

池化操作后输出特征图的大小可以通过以下公式计算。假设输入特征图的高度和宽度分别为 H in H_{\text{in}} Hin W in W_{\text{in}} Win,池化窗口的大小(即 kernel size)为 k h × k w k_h \times k_w kh×kw,步幅(stride)为 s h s_h sh s w s_w sw,填充(padding)为 p h p_h ph p w p_w pw

无填充情况下的输出大小计算

在无填充(padding = 0)的情况下,输出特征图的高度 H out H_{\text{out}} Hout 和宽度 W out W_{\text{out}} Wout 可以通过以下公式计算:

H out = ⌊ H in − k h s h ⌋ + 1 H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} - k_h}{s_h} \right\rfloor + 1 Hout=shHinkh+1

W out = ⌊ W in − k w s w ⌋ + 1 W_{\text{out}} = \left\lfloor \frac{W_{\text{in}} - k_w}{s_w} \right\rfloor + 1 Wout=swWinkw+1

有填充情况下的输出大小计算

在有填充的情况下,填充的大小分别为 p h p_h ph p w p_w pw,输出特征图的高度 H out H_{\text{out}} Hout 和宽度 W out W_{\text{out}} Wout 可以通过以下公式计算:

H out = ⌊ H in + 2 p h − k h s h ⌋ + 1 H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p_h - k_h}{s_h} \right\rfloor + 1 Hout=shHin+2phkh+1

W out = ⌊ W in + 2 p w − k w s w ⌋ + 1 W_{\text{out}} = \left\lfloor \frac{W_{\text{in}} + 2p_w - k_w}{s_w} \right\rfloor + 1 Wout=swWin+2pwkw+1

总结

池化层在深度学习中扮演着重要角色,通过降维、特征提取和抑制噪声等功能,显著提高了模型的计算效率和鲁棒性。最大池化和平均池化是最常见的池化操作,而全局池化作为一种特殊的池化方法,在简化模型结构和提高泛化能力方面表现突出。了解池化层的工作原理和应用,对于设计和优化高效的深度学习模型至关重要。

参考链接

PyTorch概述
Pytorch :张量(Tensor)详解
PyTorch 卷积层详解
PyTorch 全连接层(Fully Connected Layer)详解
PyTorch 池化层详解
PyTorch 激活函数及非线性变换详解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2142175.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BLE 协议之物理层

目录 一、概述二、Physical Channel1、物理通道2、物理通道的细分 三、调制1、调制方式2、GFSK 四、发射机五、接收机六、收发机 一、概述 物理层(Physical Layer)是 BLE 协议栈最底层,它规定了 BLE 通信的基础射频参数,包括信号频…

Minio环境搭建(单机安装包、docker)(一)

前言: 项目中客户不愿意掏钱买oss,无奈只能给他免费大保健来一套。本篇文章只是记录验证可行性,毕竟minio太少文档了,参考着官网来。后面还会再出一套验证集群部署的文章。 一、资料 MinIO官网: MinIO | S3 Compatib…

Windows 安装 ZooKeeper 以及 IDEA 安装 zoolytic 连接工具

目录 前言 下载 解压 配置 启动服务 zoolytic 前言 在前公司做微服务开发时,使用的都是 Spring Cloud 的生态,服务的注册与发现中心用的 Eureka,也有使用 Nacos 的,远程调用则是用的 OpenFeign,换工作后&#x…

istio中serviceentry结合vs、dr实现多版本路由

假设有一个外部服务,外部服务ip为:10.10.102.90,其中32033为v1版本,32034为v2版本。 现在需要把这个服务引入到istio中,并且需要配置路由规则,使得header中x-version的值为v1的路由到v1版本,x-…

Gitee Pipeline 从入门到实战【详细步骤】

文章目录 Gitee Pipeline 简介Gitee Pipeline 实战案例 1 - 前端部署输入源NPM 构建Docker 镜像构建Shell 命令执行案例 2 - 后端部署全局参数输入源Maven 构建Docker 镜像构建Shell 命令执行参考🚀 本文目标:快速了解 Gitee Pipeline,并实现前端及后端打包部署。 Gitee Pi…

MYSQL数据库——MYSQL管理

MYSQL数据库安装完成后,自带四个数据库,具体作用如下: 常用工具 1.mysql 不是指mysql服务,而是指mysql的客户端工具 例如: 2.mysqladmin 这是一个执行管理操作的客户端程序,可以用它来检查服务器的配置和…

SpringMVC映射请求;SpringMVC返回值类型;SpringMVC参数绑定;

一,SpringMVC映射请求 SpringMVC 使用 RequestMapping 注解为控制器指定可以处理哪些URL请求 1.1RequestMapping修饰类 注解RequestMapping修饰类,提供初步的请求映射信息,相对于WEB应用的跟目录。 注: 如果在类名前&#xff0…

【车载开发系列】ParaSoft单元测试环境配置(三)

【车载开发系列】ParaSoft单元测试环境配置(三) 【车载开发系列】ParaSoft单元测试环境配置(三) 【车载开发系列】ParaSoft单元测试环境配置(三)一. 去插桩设置Step1:静态解析代码Step2:编辑Parasoft文件Step3:确认去插桩二. 新增测试用例Step1:生成测试用例Step2:执…

【网络安全】Node.js初探+同步异步进程

未经许可,不得转载。 文章目录 Node.js 基础介绍NPM 包管理安装同步与异步fs 模块示例child_process 模块Node.js 基础介绍 Node.js 是运行在服务器端的 JavaScript 环境。它基于 Chrome 的 V8 引擎,拥有高效的执行性能。Node.js 采用事件驱动的 I/O 模型,使得它在处理高并…

策略路由与路由策略的区别

🐣个人主页 可惜已不在 🐤这篇在这个专栏 华为_可惜已不在的博客-CSDN博客 🐥有用的话就留下一个三连吧😼 目录 一、主体不同 二、方式不同 三、规则不同 四、定义和基本概念 一、主体不同 1、路由策略:是为了改…

苹果 2024 秋季新品发布会一文汇总:iPhone 16 / Pro 登场、手表耳机齐换代

✌ 作者名字:高峰君主 📫 如果文章知识点有错误的地方,请指正!和大家一起学习,一起进步👀 💬 人生格言:没有我不会的语言,没有你过不去的坎儿。💬 &#x1f5…

跟着DAMA学数据管理--数据管理框架

数据治理框架 数据治理框架是一套全面、系统的结构和方法,用于指导和管理组织内数据的整个生命周期,以确保数据的质量、可用性、安全性和一致性,从而实现数据的价值最大化。 它通常涵盖了一系列的策略、流程、组织架构、技术和标准。策略方面…

汽车电子笔记之-013:旋变硬解码ADI芯片AD2S1210使用记录(从零开始到软件实现)

目录 1、概述 2、技术规格 3、芯片引脚 4、旋变信号格式 5、使用过程只是要点分析 5.1、程序注意点分析 5.1.1、SPI配置时序 5.1.2、问题一:SPI时序问题 5.1.3、问题二:SPI读取时序(配置模式) 5.1.4、问题三&#xff1a…

canal消费binlog异常排查

canal简介 canal是一款优秀的订阅MySQL binlog的中间件,在MySQL异构数据到其它存储平台领域非常的实用好用。而且在数据表的迁移中也可以用canal订阅,然后将更新实时同步到新表。 原理 canal部署后伪装为一个MySQL slave节点向DB发起同步binlog请求&am…

Istio下载及安装

Istio 是一个开源的服务网格,用于连接、管理和保护微服务。以下是下载并安装 Istio 的步骤。 官网文档:https://istio.io/latest/zh/docs/setup/getting-started/ 下载 Istio 前往Istio 发布页面下载适用于您的操作系统的安装文件,或者自动…

系统架构-面向对象

有对象和没对象一样,鉴于今天中秋节 所以明天姐姐我就恢复单身了,忍这几个小时也没关系,一点不重要了

『功能项目』伤害数字UI显示【53】

我们打开上一篇52眩晕图标显示的项目, 本章要做的事情是在Boss受到伤害时显示伤害数字 首先打开Boss01预制体空间在Canvas下创建一个Text文本 设置Text文本 重命名为DamageUI 设置为隐藏 编写脚本:PlayerCtrl.cs 运行项目 本章做了怪物受伤血量的显示UI…

iOS 18 新功能:控制中心大變身!控制項目自由選配

蘋果於 Apple iOS 18 中為控制中心帶來大改變,變得更具有擴充性,而且將支援第三方應用的控制按鈕,中心內的組件大小也可調節。如今 iOS 18 正式上線,我們就可以試試控制中心不同項目自由選配帶來的效果。 組件可在三尺寸之間調整 …

十五、谷粒商城- 报错汇总

🌻🌻目录🌻🌻 一、谷粒商城- 分布式基础&环境搭建(1)1.1 项目构建完clean报错1.2 idea安装插件报错 二、谷粒商城- 快速开发之Spring Cloud Alibaba(2)2.1 配置完renren-fast 启…

园林建筑物类型检测系统源码分享

园林建筑物类型检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comput…