Pytorch常用的函数(三)深度学习中常见的卷积操作详细总结

news2024/12/28 19:40:12

Pytorch常用的函数(三)深度学习中常见的卷积操作

1、标准卷积(Standard Convolution)

1.1 标准卷积的理解

我们直接来看二维卷积,这在实际应用中是最常见的。

在这里插入图片描述

上图中Conv 2D其实就是卷积核,也叫做滤波器。滤波器的值决定了输出的情况,模型训练的过程就是调整这些值,使网络的输出更加准确。我们看看卷积算法如何将滤波器和输入结合在一起。

从概念上将,我们每次处理一个滤波器,要计算一个输出值,我们要看一下位于滤波器窗口下的输入区域。
在这个例子中,我们看一个3x3的像素区域,用这个区域和滤波器来计算输出。可以把滤波器看成包含某种模式,而输出值就是输入与这种模式的匹配程度。在网络中,比如一个浅层可能在寻找某种颜色,或者一个边缘;而深层可能在寻找一只狗。

在这里插入图片描述

计算完这个区域的值后,我们就移动到下一个区域,再次进行同样的计算来得到下一个输出值。

在这里插入图片描述

我们持续这个过程,计算出这一行中的其他输出。然后,我们在输入和输出中都换到新的一行,然后对这新的一行重复我们之前的操作。

我们按行地继续这个过程,直到整个输入的空间范围都被覆盖到。

在这里插入图片描述

换到下一个滤波器。用这个新的滤波器重复整个过程,形成输出的下一个特性。我们对每个滤波器都这样做,以形成输出的所有特性。

在这里插入图片描述

1.2 Pytorch中的API

torch.nn.Conv2d(
     in_channels,   # 输入通道数,即卷积核通道数
     out_channels,  # 输出通道数,即卷积核个数
     kernel_size, # 核大小
     stride=1,   # 步幅
     padding=0,  # 填充
     dilation=1, # 控制kernel点之间的空间距离
     groups=1,   # 分组卷积
     bias=True, 
     padding_mode='zeros', # 图像四周默认填充值为0
     device=None, 
     dtype=None
)

  • in_channels :代表输入特征矩阵的深度即channel,比如输入一张RGB彩色图像,那in_channels=3

  • out_channels:代表卷积核的个数,使用n个卷积核输出的特征矩阵深度,即channel就是n

  • kernel size:代表卷积核的尺寸,输入可以是int类型,如3代表卷积核的height=width=3,也可以是tuple类型如(3,5)代表卷积核的height=3,width=5

  • stride:代表卷积核的步距默认为1,和kernel size一样输入可以是int类型,也可以是tuple类型

  • padding:代表在输入特征矩阵四周补零的情况默认为0,同样输入可以为int型如1代表上下方向各补一行0元素、左右方向各补一列0像素(即补一圈0) ,如果输入为tuple型如(2,1)代表在上方补两行,下方补两行,左边补一列,右边补一列。

  • bias参数表示是否使用偏置 (默认使用)

  • dilation、groups是高阶用法

  • CNN的卷积核通道数 = 卷积输入层的通道数

  • CNN的卷积输出层通道数(深度) = 卷积核的个数

标准卷积的参数及计算量的计算

在这里插入图片描述

经卷积后的矩阵尺寸大小计算公式为 :

N = (W - K + 2P) / S + 1

例如:输入的矩阵 H=W=5,卷积核的K=2,S=2,Padding=1。

N = (5 - 2 + 2✖1) / 2 + 1 = 3.5

此时在Pytorch中是如何处理呢?

结论: 在卷积过程中会直接将最后一行以及最后一列给忽略掉,以保证N为整数,此时N= (5 - 2 + 2 * 1 - 1) / 2 + 1 = 3 [即向下取整]

注意:卷积核中的in_channels与需要进行卷积操作的数据x的channels一致

1.3 案例

举个例子:

输入一个12×12×3的一个输入特征图,经过一个5×5×3的卷积核卷积,得到一个8×8×1的输出特征图。如果此时我们有256个卷积核,我们将会得到一个8×8×256的输出特征图。

在这里插入图片描述

import torch
import torch.nn as nn

x = torch.rand(size=(1, 3, 12, 12))


model =  nn.Sequential(
     nn.Conv2d(
               in_channels=3,  # 卷积核中的in_channels与需要进行卷积操作的数据x的channels一致
               out_channels=1, # 输出通道数,即卷积核个数
               kernel_size=(5,5)
              )
)

# torch.Size([1, 1, 8, 8])
# N = (W - K + 2P) / S + 1 = (12 - 5 + 2 * 0 ) / 1 + 1 = 8
print(model(x).shape)  



model =  nn.Sequential(
     nn.Conv2d(
         in_channels=3,    # 输入通道数,即卷积核通道数
         out_channels=256, # 输出通道数,即卷积核个数
         kernel_size=(5,5) # 卷积核大小
     )
)

# torch.Size([1, 256, 8, 8])
# N = (W - K + 2P) / S + 1 = (12 - 5 + 2 * 0 ) / 1 + 1 = 8
print(model(x).shape) 

2、分组卷积(Group Convolution)

2.1 分组卷积的理解

对于需要区分各种视觉场景的大型的深度网络,我们需要大量的特征,尤其是在更深的层次,这就暴漏了卷积的性能扩展问题。

如下图,在更深的层次,每层的输入和输出特征数量都在增加。

在这里插入图片描述

如下图,增加输入特征的channels,会让滤波器更深(数量更多),增加输出特征的channels,就会有更多的滤波器,因此特征数量的倍增会让计算量增加4倍。

原来

在这里插入图片描述

增加输入特征的channels、增加输出特征的channels

在这里插入图片描述

想一想,每个滤波器真的需要查看输入的每个特征吗?肯定不是

因此,我们可以把输入特征分为两组,每个滤波器只需要查看其中一组就行。滤波器的前半部分会查看第一组输入,后半部分会查看另一组。

在这里插入图片描述

我们开始从第一组输入特征出发,使用对应的滤波器。注意每个滤波器的深度只和组的深度相同,而不是和整个输入的深度相同。这就是我们想要的性能提升。

在这里插入图片描述

当我们用完了一半的滤波器,就转向下一组特征,继续使用剩下的滤波器。这就和把输入和滤波器分开,执行单独的卷积,然后把结果拼接起来没有什么区别【设置groups=2】。

在这里插入图片描述

2.2 Pytorch中的API

对输入feature map进行分组,然后每组分别卷积。这种分组只是在深度上【channels】进行划分,即某几个通道编为一组,这个具体的数量由 (C1/g) 决定。例如,输入的feature map的通道数C1=20,我们分为g=5组,那么每一组有4个卷积核。

在这里插入图片描述

torch.nn.Conv2d(
     in_channels,   
     out_channels,  
     kernel_size, 
     stride=1,   
     padding=0,  
     dilation=1, 
     groups=1,   # 分组卷积,默认为1组
     bias=True, 
     padding_mode='zeros' # 图像四周默认填充值为0
)

2.3 案例

x = torch.rand(size=(1, 256, 12, 12))

model =  nn.Sequential(
     nn.Conv2d(
          in_channels=256, 
          out_channels=32,   # 输出通道数,即卷积核的个数为256/8=32
          kernel_size=(3,3),
          padding=1,
          groups=8           # 分为8组
     )
)

# torch.Size([1, 32, 12, 12])
print(model(x).shape)

3、逐点卷积(PW Convolution)

3.1 逐点卷积的理解

在这里插入图片描述

Pointwise Convolution的运算与常规卷积运算非常相似,它的券积核的尺寸为1X1XM,M为上一层的通道数。

所以这里的卷积运算会将上一步的map在深度方向上进行加权组合,目的是: 生成新的Feature map

3.2 Pytorch中的API及案例

torch.nn.Conv2d(
     in_channels,   
     out_channels,  
     kernel_size, # 逐点卷积,将卷积核的大小设置为1
     stride=1,   
     padding=0,  
     dilation=1, 
     groups=1,   
     bias=True, 
     padding_mode='zeros'
)

假设我们得到了8×8×3的特征图,我们用256个1×1×3的卷积核对输入特征图进行卷积操作,输出的特征图变为8×8×256了

x = torch.rand(size=(1, 3, 8, 8))

model =  nn.Sequential(
     nn.Conv2d(
          in_channels=3,
          out_channels=256, # 输出通道数,即卷积核的个数,256个1×1×3的卷积核
          kernel_size=(1,1) # 逐点卷积,卷积核的大小为1
     )
)

# torch.Size([1, 256, 8, 8])
print(model(x).shape)

4、深度卷积(DW Convolution)

4.1 深度卷积的理解

想一想,如果我们把所有的组都分出来,会有什么问题【这就是深度卷积】。

存在的问题就是:通道数太少,特征图的维度太少,不能获取到足够的有效信息,可以通过结合逐点卷积进行解决,即【深度可分离卷积】。

在这里插入图片描述

深度卷积(逐通道卷积)参数量的计算

在这里插入图片描述

4.2 Pytorch中的API及案例

groups就是实现深度卷积的关键,默认为1,意思是将输入分为一组,此时是常规卷积

当将其设为in channels时,意思是将输入的每一个通道作为一组,然后分别对其卷积。

torch.nn.Conv2d(
     in_channels,   
     out_channels=in_channels,  # 深度卷积,out_channels=in_channels
     kernel_size, 
     stride=1,   
     padding=0,  
     dilation=1, 
     groups=in_channels,        #  深度卷积,将输入的每一个通道作为一组,groups=in_channels
     bias=True, 
     padding_mode='zeros'
)

在这里插入图片描述

与标准卷积网络不一样的是,我们将卷积核拆分成为但单通道形式,在不改变输入特征图像的深度的情况下,对每一通道进行卷积操作,这样就得到了和输入特征图通道数一致的输出特征图。如上图:输入12×12×3的特征图,经过5×5×1×3的深度卷积之后,得到了8×8×3的输出特征图。输入个输出的维度是不变的3。

x = torch.rand(size=(1, 3, 12, 12))

model =  nn.Sequential(
     nn.Conv2d(
          in_channels=3,    
          out_channels=3,    # out_channels=in_channels
          kernel_size=(5,5),
          groups=3           # groups=in_channels
     )
)

# torch.Size([1, 3, 8, 8])
print(model(x).shape)

5、深度可分离卷积(PW+DW)

5.1 深度分离卷积的理解

在深度卷积中,我们将所有的组都进行了分离。但是,此时我们发现:第一个输出特征只依赖于第一个输入特征。

在这里插入图片描述

这个模式会在网络的更深层次中持续。这样我们永远都不会得到像只有一组那样的全部表达能力。

在这里插入图片描述

一个滤波器从原始图中得到第一个输出特征,此时表达能力强

在这里插入图片描述

如何解决这个问题呢?我们可以在每个深度卷积后,加上1个标准的1✖1的卷积【逐点卷积】,而不是堆叠深度卷积。

【逐点卷积】在空间上只有一个像素,同时接收所有输入特征。

在这里插入图片描述

这与深度卷积完美地互补,深度卷积在空间上有3x3的接受区域,但只有一个特征。

在这里插入图片描述

当我们把它们结合起来,两层的输出都有3x3的空间接受区域和所有原始特征。这完美地匹了有一组的3x3券积的接收区域。这就是【深度可分离卷积】。你可能已经注意到了,点对点卷积让我们原来特征数量翻倍导致计算量增4倍的问题又回来了,但我们相对标准卷积仍然领先。如果你看看3x3深度可分离卷积执行的总计算量,它只是标准3x3卷积的计算量的大约11%,换句话说,它快了9倍。严格地说,
加速取决干特征数量,但随着特征数量的增加,加速越来越接近理想的9倍速度。

在这里插入图片描述

5.2 深度分离卷积的案例

深度可分离卷积就是将普通卷积拆分成为一个深度卷积和一个逐点卷积

输入一个12×12×3的一个输入特征图,经过5×5×3的卷积核卷积得到一个8×8×1的输出特征图。如果此时我们有256个特征图,我们将会得到一个8×8×256的输出特征图。

标准卷积

# 标准卷积
x = torch.rand(size=(1, 3, 12, 12))

model =  nn.Sequential(
     nn.Conv2d(
          in_channels=3,
          out_channels=256,
          kernel_size=(5,5)
     )
)

# torch.Size([1, 256, 8, 8])
print(model(x).shape)

深度可分离卷积

# 深度可分离卷积

x = torch.rand(size=(1, 3, 12, 12))

model =  nn.Sequential(
     # 深度卷积
     nn.Conv2d(
          in_channels=3,
          out_channels=3,    # out_channels=in_channels
          kernel_size=(5,5),
          groups=3           # groups=in_channels
     ),
     # 逐点卷积
     nn.Conv2d(
          in_channels=3,
          out_channels=256, # 输出通道数,即卷积核的个数,256个1×1×3的卷积核
          kernel_size=(1,1) # 逐点卷积,卷积核的大小为1
     )

)

# torch.Size([1, 256, 8, 8])
print(model(x).shape)

5.3 计算量的对比

标准卷积

在这里插入图片描述

深度可分离卷积

在这里插入图片描述

因此:

在这里插入图片描述

我们通常所使用的是3×3的卷积核,也就是会下降到原来的九分之一到八分之一。

5.4 代码实现

在这里插入图片描述

"""
     深度分离卷积
"""
 
import torch
import torch.nn as nn
 
 
class Depth_Wise_Conv(nn.Module):
    """
        深度可分离卷积 = 深度卷积 + 逐点卷积调整通道
    """
    def __init__(self, in_channel, out_channel):
        super(Depth_Wise_Conv, self).__init__()
        # 深度卷积
        self.conv_group = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1,                                                  groups=in_channel)
        # 逐点卷积调整通道
        self.conv_point = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=1, groups=1)
        # BN
        self.bn = nn.BatchNorm2d(out_channel)
        # activate
        self.act = nn.ReLU()
 
    def forward(self, inputs):
        """
            前向传播
        """
        x = self.conv_group(inputs)
        x = self.conv_point(x)
        x = self.bn(x)
        x = self.act(x)
        return x
 
 
if __name__ == '__main__':
    # 均匀分布产生数据
    x = torch.rand(1, 3, 16, 16)
    model = Depth_Wise_Conv(3, 16)
    model = model(x)
    print(model)

参考博客:
图像部分 https://animatedai.github.io/
常用卷积总结 https://zhuanlan.zhihu.com/p/490761167
轻量级神经网络“巡礼”(二)—— MobileNet,从V1到V3 https://zhuanlan.zhihu.com/p/70703846

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/701925.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java】Java核心 86:Git 教程(9)GIT远程仓库操作

文章目录 14.GIT远程仓库操作-关联、拉取、推送、克隆目标内容小结 Git提供了一系列命令来进行远程仓库的操作。 下面是一些常用的Git远程仓库操作&#xff1a; 克隆远程仓库到本地&#xff1a; git clone <远程仓库URL>查看远程仓库信息&#xff1a; git remote -v添…

功能键F4在Microsoft Excel中有什么用

的确,许多 Excel 用户发现使用键盘快捷键对他们来说更有效。事实上,键盘快捷键可能是使用鼠标的最佳选择,因为使用 Excel 时使用触摸屏可能不是视力障碍者的最佳选择。 使用功能键,如 Excel 中的 F4 以及 F2 可能是非常必要的。在这篇文章中,我们将研究功能键 F4 及其在 …

【Java】Java核心 85:Git 教程(8)GIT远程仓库介绍与码云仓库注册创建

文章目录 13.GIT远程仓库介绍与码云仓库注册创建目标小结 Git是一个分布式版本控制系统&#xff0c;它允许多个开发者协同工作并管理代码的版本。远程仓库是存放在网络上的Git仓库&#xff0c;可以用于团队成员之间的代码共享和协作。 常见的远程仓库托管服务提供商有GitHub、…

XShell、XFtp、Linux上MySQL的远程连接及使用

下载资源包&#xff0c;请于文章顶部下载即可 XShell的使用 1. 打开安装好的XShell 2. 点击左上角新建连接 3. 填写相应连接服务器信息 4. 输入需要连接到Linux操作系统哪个用户的用户名 5. 输入连接到用户的密码 6. 远程登录Linux成功 7. 此时可以正常使用Linux指令操作Linu…

chatgpt赋能python:隐藏鼠标:Python实现隐藏鼠标的应用

隐藏鼠标&#xff1a;Python实现隐藏鼠标的应用 作为一名有10年Python编程经验的工程师&#xff0c;我深知Python在图形用户界面(GUI)开发上的优势&#xff0c;其中一个有趣而且有用的应用就是隐藏鼠标。 在某些情况下&#xff0c;用户可能希望隐藏鼠标&#xff0c;这可以用于…

ARM-异常与中断(四)

文章目录 中断中断请求、中断源中断服务程序保存现场、恢复现场中断仲裁、中断优先级中断嵌套 异常广义上的异常同步异常异步异常精确异步异常&#xff08;Precise Asynchronous Exception&#xff09;非精确异步异常&#xff08;Imprecise Asynchronous Exception&#xff09;…

【DBA专属】MHA高可用数据库集群-----------一主一备两从一管理,一个VIP客户端

MHA高可用数据库集群 目录 环境配置&#xff1a; 所有操作系统均为centos 7.x 64bit 1、关闭防火墙&#xff1a;&#xff08;所有服务器&#xff09; 2.配置所有主机名映射&#xff08;所有服务器&#xff09; 3、同步时区 4.安装MHA node及相关perl依赖包&#xff08;所有…

AutoSAR系列讲解(入门篇)4.1-BSW概述

BSW概述 一、什么是BSW 二、BSW的结构 1、微控制器硬件抽象层&#xff08;MCAL&#xff09; 2、ECU抽象层 3、服务层 4、复杂驱动 三、再将结构细分 一、什么是BSW 中文翻译就是基础软件层&#xff08;Basic Software&#xff09;。这个基础软件层实质上就是将整个ECU分…

【斯坦福】FrugalGPT: 如何使用大型语言模型,同时降低成本并提高性能

FrugalGPT: 如何使用大型语言模型&#xff0c;同时降低成本并提高性能 作者&#xff1a;Lingjiao Chen, Matei Zaharia, James Zou 引言 本文介绍了一种新颖的方法&#xff0c;旨在解决使用大型语言模型&#xff08;LLM&#xff09;时面临的成本和性能挑战。随着GPT-4和Chat…

链路聚合综合实战

拓扑 需求 -PC1和PC3属于vlan 10、PC2和PC4属于vlan 20 -设备之间配置lacp模式的链路聚合&#xff0c;并确保同vlan之间的主机可以互通 配置步骤 1&#xff09;PC配置IP地址 2&#xff09;所有交换机创建vlan10 和vlan20 3&#xff09;交换机和PC互联的接口设置为access &am…

python数据分析之连接MySQL数据库并进行数据可视化

大家好&#xff0c;我是带我去滑雪&#xff01; 本期将熟悉MySQL数据库以及管理和操作MySQL数据库的数据库管理工具Navicat Premium&#xff0c;然后在python中调用MySQL数据库进行数据分析和数据可视化。 目录 1、MySQL数据库与数据库管理工具Navicat Premium 2、调用MySQL…

EasyCVR如何实现国标级联无人机推送的RTMP推流通道?

EasyCVR视频融合平台基于云边端一体化架构&#xff0c;可支持多协议、多类型设备接入&#xff0c;包括&#xff1a;NVR、IPC、视频编码器、无人机、车载设备、智能手持终端、移动执法仪等。平台具有强大的数据接入、处理及分发能力&#xff0c;可在复杂的网络环境中&#xff0c…

el-date-picker禁用指定日期之前或之后的日期

一、elementUI中el-date-picker禁用指定日期之前或之后的日期 通过配置picker-options配置指定禁用日期&#xff08;pickerOptions写到data里面&#xff09; <el-date-pickerv-model"date"type"date"size"small"value-format"yyyy-MM-d…

Linux-passwd或shadow删了怎么办?

passwd或者shadow被删除了怎么办 passwd和shadow被删了&#xff0c;可以启用营救模式进行补救&#xff0c;原因是这两个文件都有备份。 先将光驱的自动启动勾选。 将Boot的引导顺序改变下&#xff0c;将光驱引导顺序放到最前面。 选择Troubleshootin可以从名字来知道&#…

广州华锐互动:机电专业VR模拟实操教学平台提供沉浸式的实践操作和训练机会

虚拟现实(VR)技术是一种先进的技术&#xff0c;可以应用于机电专业的培训中。以下是VR技术应用到机电专业培训的一些好处&#xff1a; 模拟实际操作环境&#xff1a;VR技术可以创建一个虚拟的环境&#xff0c;模拟真实的机械和电气设备的操作环境。这使得学生可以在安全的环境…

Python最基础语法

文章目录 一、简介1、Python安全路径2、Python开发工具(PyCharm) 二、PyCharm使用1、新建项目位置2、Hello World3、查看python版本4、PEP8规范 三、标识符和关键字四、基本数据类型1、数据类型2、多数据赋值&#xff1a;3、标准数据类型4、格式化输出5、输入(input)6、格式转换…

电商系统架构设计系列(四):流量大、数据多的「商品详情页系统」该如何设计?

一个电商的商品系统&#xff0c;主要功能就是增删改查商品信息。 上篇文章中&#xff0c;我给你留了一个思考题&#xff1a;流量大、数据多的商品详情页系统该如何设计&#xff1f; 今天这篇文章&#xff0c;主要聊一下&#xff0c;如何设计一个快速、可靠的存储架构支撑商品系…

七、一百零二类花分类项目实战

一、准备数据集 一百零二类花数据集下载 flower_data包括train和valid文件&#xff0c;分别存放102个文件&#xff0c;对应102种类别的花 cat_to_name.json为类别和花品种键值对 将压缩包进行解压&#xff0c;跟项目放到同级路径下 二、导包 若遇到报错&#xff0c;不存…

网络链路聚合

这里写目录标题 链路聚合什么是链路聚合&#xff1f;为什么要进行链路聚合&#xff1f;Linux网卡bonding的7种模式模式一&#xff1a;balance-rr 轮询均衡模式模式二&#xff1a;active-backup 主备策略模式模式三&#xff1a;balance-xor 平衡策略模式四&#xff1a;broadcast…

【python】枚举的基本使用,及如何实现枚举属性的自增长

▒ 目录 ▒ &#x1f6eb; 问题描述环境 1️⃣ 枚举的基本使用自定义枚举成员的值枚举值唯一&#xff1a;unique枚举成员的别名&#xff1a;property枚举成员的元数据 2️⃣ 实现枚举属性的自增长python3.6python3.5.2python2不支持enum模块 &#x1f6ec; 结论&#x1f4d6; 参…