SPP-学习笔记

news2024/11/19 16:24:18

Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition

SPP提出的原因

1、现有的深度卷积神经网络(spp出现之前的)需要固定大小的输入图像(例如224×224)。往往需要对图片裁剪或者resize,导致图片信息损失或者产生几何畸变。这样可能会损害任意大小比例的图像或子图像的识别精度。
在这里插入图片描述
2、使用SPP-net,只从整个图像中计算一次特征映射,然后将特征集中到任意区域(子图像)中,生成固定长度的表示,用于训练检测器。解决了卷积神经网络对图相关重复特征提取的问题,大大提高了产生候选框的速度,且节省了计算成本。

SPP 实现

在这里插入图片描述
黑色图片代表卷积层之后的特征图,随后我们以不同大小的块来提取特征,分别是4*4,2*2,1*1,就可以得到16+4+1=21种不同的块(Spatial bins).我们从这21个块中,每个块提取出一个特征,这样刚好就是我们要提取的21维特征向量。这种以不同的大小格子的组合方式来池化的过程就是空间金字塔池化(SPP)。 比如,要进行空间金字塔最大池化,其实就是从这21个图片块中,分别计算每个块的最大值,从而得到一个输出单元,最终得到一个21维特征的输出。 输出向量大小为Mk,M=#bins(块数), k=#filters(卷积核个数),作为全连接层的输入。 例如上图,feature map是任意大小的,经过SPP之后,变成固定大小的输出了,以上图为例,共输出(16+4+1)*256的特征。(有256个卷积核)

pooling 参数的计算

[W,H]输入尺寸,
level 可以看做是金字塔的层级,在上面的示意图中产生了三个层级(1,2,4)
k e r n e l _ s i z e = c e i l ( H l e v e l , W l e v e l ) kernel\_ size=ceil(\frac{H}{level},\frac{W}{level}) kernel_size=ceil(levelH,levelW)

s t r i d e = c e i l ( H l e v e l , W l e v e l ) stride=ceil(\frac{H}{level},\frac{W}{level}) stride=ceil(levelH,levelW)

p a d d i n g = ( f l o o r ( ( k e r n e l _ s i z e ∗ l e v e l − H + 1 ) 2 ) , f l o o r ( ( k e r n e l _ s i z e ∗ l e v e l − W + 1 ) 2 ) padding = (floor(\frac{(kernel\_size * level - H + 1)}{2}), floor(\frac{(kernel\_size * level - W + 1)}{2}) padding=(floor(2(kernel_sizelevelH+1)),floor(2(kernel_sizelevelW+1))

参考代码

from math import floor, ceil
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialPyramidPooling2d(nn.Module):
    r"""apply spatial pyramid pooling over a 4d input(a mini-batch of 2d inputs
    with additional channel dimension) as described in the paper
    'Spatial Pyramid Pooling in deep convolutional Networks for visual recognition'
    Args:
        num_level:
        pool_type: max_pool, avg_pool, Default:max_pool
    By the way, the target output size is num_grid:
        num_grid = 0
        for i in range num_level:
            num_grid += (i + 1) * (i + 1)
        num_grid = num_grid * channels # channels is the channel dimension of input data
    examples:
        >>> input = torch.randn((1,3,32,32), dtype=torch.float32)
        >>> net = torch.nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,stride=1),\
                                    nn.ReLU(),\
                                    SpatialPyramidPooling2d(num_level=2,pool_type='avg_pool'),\
                                    nn.Linear(32 * (1*1 + 2*2), 10))
        >>> output = net(input)
    """

    def __init__(self, num_level, pool_type='max_pool'):
        super(SpatialPyramidPooling2d, self).__init__()
        self.num_level = num_level
        self.pool_type = pool_type

    def forward(self, x):
        N, C, H, W = x.size()
        for i in range(self.num_level):
            level = i + 1
            kernel_size = (ceil(H / level), ceil(W / level))
            stride = (ceil(H / level), ceil(W / level))
            padding = (floor((kernel_size[0] * level - H + 1) / 2), floor((kernel_size[1] * level - W + 1) / 2))

            if self.pool_type == 'max_pool':
                tensor = (F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)).view(N, -1)
            else:
                tensor = (F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)).view(N, -1)

            if i == 0:
                res = tensor
            else:
                res = torch.cat((res, tensor), 1)
        return res
    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'num_level = ' + str(self.num_level) \
            + ', pool_type = ' + str(self.pool_type) + ')'

class SPPNet(nn.Module):
    def __init__(self, num_level=3, pool_type='max_pool'):
        super(SPPNet,self).__init__()
        self.num_level = num_level
        self.pool_type = pool_type
        self.feature = nn.Sequential(nn.Conv2d(3,64,3),\
                                    nn.ReLU(),\
                                    nn.MaxPool2d(2),\
                                    nn.Conv2d(64,64,3),\
                                    nn.ReLU())
        self.num_grid = self._cal_num_grids(num_level)
        self.spp_layer = SpatialPyramidPooling2d(num_level)
        self.linear = nn.Sequential(nn.Linear(self.num_grid * 64, 512),\
                                    nn.Linear(512, 10))
    def _cal_num_grids(self, level):
        count = 0
        for i in range(level):
            count += (i + 1) * (i + 1)
        return count

    def forward(self, x):
        x = self.feature(x)
        x = self.spp_layer(x)
        print(x.size())
        x = self.linear(x)
        return x

if __name__ == '__main__':
    a = torch.rand((1,3,128,128))
    net = SPPNet()
    output = net(a)
    print(output)

我们注意到SPP对于深度cnn有几个spatial:

  • SPP能够产生固定长度的输出,而不管输入大小,而在以前的深度网络中使用的滑动窗口池化不能;
  • SPP使用多级spatial bin,而滑动窗口池只使用单一窗口大小。多级池已被证明更具有鲁棒性;
  • 由于输入尺度的灵活性,SPP可以将不同尺度提取的特征集合在一起。实验表明,这些因素都提高了深度网络的识别精度

其它特点

  1. 由于对输入图像的不同纵横比和不同尺寸,SPP同样可以处理,所以提高了图像的尺度不变(scale-invariance)和降低了过拟合(over-fitting)
  2. 实验表明训练图像尺寸的多样性比单一尺寸的训练图像更容易使得网络收敛(convergence)
  3. SPP 对于特定的CNN网络设计和结构是独立的。(也就是说,只要把SPP放在最后一层卷积层后面,对网络的结构是没有影响的, 它只是替换了原来的pooling层)
  4. 不仅可以用于图像分类而且可以用来目标检测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/29404.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

奥比中光亮相全球1024开发者节,与科大讯飞达成战略合作

作者 | 奥比中光 编辑 | 3D视觉开发者社区 11月17日-23日,第五届世界声博会暨2022科大讯飞全球1024开发者节在安徽合肥举办,奥比中光作为3D视觉感知头部企业参展,并与科大讯飞达成战略合作,共同赋能3D视觉行业应用开发。 本次参…

如何利用现代工具来管理多项目

多项目管理是如今现代企业管理时常常遇到的一个难题。不同于单项目管理,多个项目同时进行管理要复杂得很多。而单纯的手工管理方式已经满足不了多管理的复杂需求,项目负责人想要保障在预定的时间内,又快又好地完成整体项目,便需要…

工厂模式解耦-交由spring来完成

上面两个小节一直在谈论解耦,从入门的多例到升级的单例BeanFactory工厂类是我们自己手工写的。 BeanFactory主要做了3件事: 1.读取配置文件(可以是properties或xml类型的文件,示例中用的是properties文件) 2.获取类…

OC RSA加密解密

好久好久没有更新了。。。你们等的急不急。。这不,我就姗姗来迟了。。。本文重点讲解一下iOS系统下的RSA加密解密问题。 一般为了安全,私钥是不会给前端暴露出来 的,只会通过私钥生成一个公开的公钥提供给外部对数据进行加密。将加密后的数据…

残差网络ResNet解读

一、残差网络的定义 残差网络的核心是解决增加深度带来的退化问题,这样能够通过单纯增加网络深度来提高网络性能。 残差单元以短连接的形式,将单元的输入直接与单元输出加在一起,然后再进行激活。 Weight为抽取特征的网络层 Addition时xl和…

RK3568平台开发系列讲解(视频篇)摄像头采集视频的相关配置

🚀返回专栏总目录 文章目录 一、权限配置二、配置摄像头2.1、打开摄像头2.2、预览格式2.3、预览尺寸沉淀、分享、成长,让自己和他人都能有所收获!😄 📢Android 平台的摄像头的采集核心部分都是在 Native 层构建的,所以这就会涉及 JNI 层的一些转换操作。 一、权限配置…

Linux | 进程间通信 | 匿名管道 | 命名管道 | 模拟代码实现进程通信 |

文章目录进程通信的意义匿名管道通信原理管道的访问控制进程控制管道的特点命名管道进程通信的意义 之前聊进程时,讲过一个性质,即进程具有独立性,两个进程之间的交互频率是比较少的。就连父子进程也只是共享代码,修改父子进程中…

MODBUS通信系列之数据处理

MODBUS通信专栏有详细文章讲解,这里不再赘述,大家可以自行查看。链接如下: SMART S7-200PLC MODBUS通信_RXXW_Dor的博客-CSDN博客_smart200做modbus通讯MODBUS 是 OSI 模型第 7 层上的应用层报文传输协议,它在连接至不同类型总线或网络的设备之间提供客户机/服务器通信。自…

化工机械基础期末复习题及答案

化工设备机械基础复习题 一 选择题 1、材料的刚度条件是指构件抵抗( B )的能力。 A.破坏 B.变形 C.稳定性 D.韧性 2、一梁截面上剪力左上右下,弯矩左顺右逆,描述正确的是&#xff08…

上班总结测试报告

出版社智能智造 测试报告 项目名称 出版社智能智造 测试版本 二期版本20221103 级别 用户使用 编写人 罗胜杰 日期 2022.11.15 目 录 1. 测试概述 1.1. 编写目的 1.2. 产品需求介绍 1.3. 参考资料 2. 测试计划执行情况 2.1. 测试范围及策略 2.2. 本…

[附源码]SSM计算机毕业设计基于的花店后台管理系统JAVA

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

【Python百日进阶-WEB开发-冲進Flask】Day181 - Flask简单流程

文章目录一、day01项目环境和结构搭建1.1 新建虚拟环境1.2 安装Flask1.3 配置Python解释器二、后端知识要点2.1 Flask 文档2.2 实例化flask对象2.2.1 新建独立的配置文件settings.py2.2.2 实例化flask对象时加载配置文件2.3 基本路由2.3.1 常用路由及唯一性2.3.2 路由底层调用2…

中央空调系统运行原理以及相关设备介绍

目录前言一、中央空调系统工作原理1-1、工作原理1-2、中央空调系统构成二、室内空调三、制冷机组3-1、概述3-2、原理3-3、蒸发器3-4、冷凝器3-5、压缩机3-6、总结四、冷却塔总结前言 今天也是为了30岁开始养老而奋斗的一天。 一、中央空调系统工作原理 1-1、工作原理 中央空…

FFmpeg入门 - rtmp推流

FFmpeg入门 - 视频播放_音视频开发老马的博客-CSDN博客介绍了怎样用ffmpeg去播放视频. 里面用于打开视频流的avformat_open_input函数除了打开本地视频之外,实际上也能打开rtmp协议的远程视频,实现拉流: ./demo -p 本地视频路径 ​ ./demo -p rtmp://服务器ip/视频流路径 这篇…

JVM垃圾回收总结

常见面试题 如何判断对象是否死亡 简单介绍一下强引用、软引用、弱引用、虚引用 如何判断常量是一个废弃常量 如何判断类是一个无用类 垃圾收集有哪些算法、各自的特点? 常见的垃圾回收器有哪些? 介绍一下CMS,G1收集器? minor gc和…

[附源码]计算机毕业设计JAVA课后作业提交系统关键技术研究与系统实现

[附源码]计算机毕业设计JAVA课后作业提交系统关键技术研究与系统实现 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术&am…

[附源码]计算机毕业设计JAVA课堂点名系统

[附源码]计算机毕业设计JAVA课堂点名系统 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis M…

【2】Anaconda基本命令以及相关工具:jupyter、numpy、Matplotilb

上一篇请移步【1】Anaconda基本命令以及相关工具:jupyter、numpy、Matplotilb_水w的博客-CSDN博客 目录 3 Numpy数组基础索引:索引和切片 ◼ 基础索引 4 Numpy非常重要的数组合并与拆分操作 ◼ 数组的合并-concatenate、vstack、hstack numpy.vstac…

生产制造管理:供应商管理系统

随着经济全球化和信息技术的快速推进发展,传统的管理模式早已不再适应现代市场竞争与生产制造的需要,以顾客需求为中心的供应链管理显得更为重要。供应链是围绕核心企业,通过对信息流、物流、资金流等关键部分的控制连成一个整体的功能网链结…

期末前端web大作业——我的家乡陕西介绍网页制作源码HTML+CSS+JavaScript

家乡旅游景点网页作业制作 网页代码运用了DIV盒子的使用方法,如盒子的嵌套、浮动、margin、border、background等属性的使用,外部大盒子设定居中,内部左中右布局,下方横向浮动排列,大学学习的前端知识点和布局方式都有…