PyTorch - Conv2d 和 MaxPool2d

news2025/1/11 10:01:10

文章目录

    • Conv2d
      • 计算
      • Conv2d 函数解析
      • 代码示例
    • MaxPool2d
      • 计算
      • 函数说明
    • 卷积过程动画
        • Transposed convolution animations
        • Transposed convolution animations


参考视频:土堆说 卷积计算
https://www.bilibili.com/video/BV1hE411t7RN


关于 torch.nn 和 torch.nn.function
torch.nn 是对 torch.nn.function 的封装,前者更方便实用。


Conv2d

卷积过程可见文末动画


计算

卷积层输入特征图(input feature map)的尺寸为:H_i × W_i × C_i

  • H_i :输入特征图的高
  • W_i :输入特征图的宽
  • C_i :输入特征图的通道数
    (如果是第一个卷积层则是输入图像的通道数,如果是中间的卷积层,则是上一层的输出通道数

卷积层的参数如下:

  • P:padding,补零的行数和列数
  • F:正方形卷积核的边长
  • S:stride,步幅
  • K:输出通道数

输出特征图(output feature map)的尺寸为 H_o × W_o × C_o ,其中每一个变量的计算方式如下:

  • H_o = (H_i + 2P − F)/S + 1
  • W_o = (W_i + 2P − F)/S + 1
  • C_o = K

  • 卷积时,超出边界的不计算。

参数量大小的计算,分为weights和biases:

首先,计算weights的参数量:F × F × C_i × K
接着计算biases的参数量:K
所以总参数量为:F × F × C_i × K + K


计算示例

输入卷积核步长padding输出计算
5x52x2104x44 = (5-2)/1 + 1
5x53x3103x33 = (5-3)/1 + 1
5x52x2202x22 = (5-2)/2 + 1
6x62x2203x33 = (6-2)/2 + 1
5x52x2116x64 = (5 + 1*2 - 2)/1 + 1
5x53x3115x53 = (5 + 1*2 - 3)/1 + 1
5x53x3224x43 = (5 + 2*2 - 3)/2 + 1

Conv2d 函数解析

  • torch.nn.functional.conv2d 官方说明
    https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html#torch.nn.functional.conv2d

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’, device=None, dtype=None)

  • in_channels
  • out_channels
  • kernel_size,卷积核大小;可以是一个数(n*n矩阵),也可以是一个元组。这个值在训练过程中,会不断被调整。
  • stride=1
  • padding=0
  • dilation=1,卷积核对应位的距离
  • groups=1,分组卷积;一般为1,很少改动。
  • bias=True,偏置,一般为True
  • padding_mode=‘zeros’,如果设置了 padding,填充模式。默认为 zeros,即填充0。
  • device=None
  • dtype=None)

一般只设置前五个参数



代码示例

import torch
import torch.nn.functional as F

t1 = torch.Tensor([[1, 2, 0, 3, 1], 
                   [0, 1, 2, 3, 1],
                   [1, 2, 1, 0, 0],
                   [5, 2, 3, 1, 1],
                   [2, 1, 0, 1, 1], ])

kernel = torch.Tensor([
    [1, 2, 1],
    [0, 1, 0],
    [2, 1, 0]
])
t1.shape, kernel.shape
# (torch.Size([5, 5]), torch.Size([3, 3]))

# channel 和 batch_size 为 1
ip = torch.reshape(t1, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))
ip.shape, kernel.shape
# (torch.Size([1, 1, 5, 5]), torch.Size([1, 1, 3, 3]))


op = F.conv2d(ip, kernel, stride=1) 
op, op.shape 
'''
(tensor([[[[10., 12., 12.],
           [18., 16., 16.],
           [13.,  9.,  3.]]]]),
           
 torch.Size([1, 1, 3, 3]))
'''

# 不同 stride
op = F.conv2d(ip, kernel, stride=2) 
op, op.shape 
'''
(tensor([[[[10., 12.],
           [13.,  3.]]]]),
 
torch.Size([1, 1, 2, 2]))
'''

# 增加 padding
op = F.conv2d(ip, kernel, stride=2, padding=1) 
op, op.shape 
'''
(tensor([[[[ 1.,  4.,  8.],
           [ 7., 16.,  8.],
           [14.,  9.,  4.]]]]),
           
 torch.Size([1, 1, 3, 3]))
'''


MaxPool2d

池化的目的是保留特征,减少数据量;
最大池化也被称为 下采样;
另外池化操作是分别应用到每一个深度切片层。输出深度 与 输入的深度 相同。


计算

  • 输入宽高深:H_i,W_i, D_i
  • 滤波器宽高:f_w, f_h
  • S: stride,步长

输出为:
H_o = (H_i - f_h)/S + 1
W_o = (W_i - f_w)/S + 1
D_o = D_i


输入维度是 4x4x5 (HxWxD)
滤波器大小 2x2 (HxW)
stride 的高和宽都是 2 (S)


在这里插入图片描述


函数说明

  • 官方说明
    https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

  • ceil_mode,超出范围时是否计算

代码实现

import torch
import torch.nn as nn

# MaxPool2d 函数 input 需要是 4维
ip = torch.reshape(t1, (-1, 1, 5, 5))


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=3, ceil_mode=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, ceil_mode=False)
        
    def forward(self, input):
        output = self.maxpool(input)
        return output
        
net = Net()
ret = net(ip)
ret

# tensor([[[[2., 3.], [5., 1.]]]])  # ceil_mode=True
# tensor([[[[2.]]]])  # ceil_mode=False

数据集中调用

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_path = '/xxxx/cifar10'
datasets = torchvision.datasets.CIFAR10(data_path, train=False, download=True, 
                                    transform=torchvision.transforms.ToTensor())

data_loader = DataLoader(datasets, batch_size=64)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, ceil_mode=False)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, ceil_mode=False)
        
    def forward(self, input):
        output = self.maxpool1(input)
        return output
        

writer = SummaryWriter('logs_maxpool1')
step = 0
net = Net()
for data in data_loader:
    imgs, targets = data
    writer.add_images('input', imgs, step)
    output = net(imgs) 
    writer.add_images('output', output, step)
    step = step + 1
    
    
writer.close() 
  • 启动 tensorboard:
tensorboard --logdir=logs_maxpool1 


卷积过程动画

图片来自:https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

Transposed convolution animations


No padding, no strides
请添加图片描述


Arbitrary padding, no strides

请添加图片描述


Half padding, no strides
请添加图片描述


Full padding, no strides

请添加图片描述


No padding, strides

请添加图片描述


Padding, strides

请添加图片描述


Padding, strides (odd)

请添加图片描述


Transposed convolution animations


No padding, no strides, transposed

请添加图片描述


Arbitrary padding, no strides, transposed

请添加图片描述


Half padding, no strides, transposed

请添加图片描述


Full padding, no strides, transposed

请添加图片描述


No padding, strides, transposed
请添加图片描述


Padding, strides, transposed
请添加图片描述


Padding, strides, transposed (odd)

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/355322.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Reverse入门[不断记录]

文章目录前言一、[SWPUCTF 2021 新生赛]re1二、[SWPUCTF 2021 新生赛]re2三、[GFCTF 2021]wordy[花指令]四、[NSSRound#3 Team]jump_by_jump[花指令]五、[NSSRound#3 Team]jump_by_jump_revenge[花指令]前言 心血来潮,想接触点Reverse,感受下Reverse&am…

网络编程(一)

网络编程 文章目录网络编程前置概念1- 字节序高低地址与高低字节高低地址:高低字节字节序大端小端例子代码判断当前机器是大端还是小端为何要有字节序字节序转换函数需要字节序转换的时机例子一例子二2- IP地址转换函数早期(不用管)举例现在与字节序转换函数相比:**…

模块化热更思路

title: 模块化热更思路 categories: Others tags: [热更, 模块化, 分包] date: 2023-02-18 01:04:57 comments: false mathjax: true toc: true 模块化热更 浅浅的记录一下访问破 200w (But, I don’t care about this.) 前篇 只谈思路, 不贴实现代码. 需求 游戏类型属于合集…

Linux(十三)设计模式——单例模式

设计模式——针对典型场景所设计出来的特别的处理方案 单例模式:一个类只能实例化一个对象(所以叫单例) 场景: 1、资源角度:资源在内存中只占有一份 2、数据角度:如果只有一个对象,那么该对象在…

2019蓝桥杯真题质数(填空题) C语言/C++

题目描述 本题为填空题,只需要算出结果后,在代码中使用输出语句将所填结果输出即可。 我们知道第一个质数是 2、第二个质数是 3、第三个质数是 5…… 请你计算第 2019 个质数是多少? 运行限制 最大运行时间:1s 最大运行内存: 128M…

Mac下安装Tomcat以及IDEA中的配置

安装brew 打开终端输入以下命令: /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 搜索tomcat版本,输入以下命令: brew search tomcat 安装自己想要的版本,例…

JDK版本区别

1. 泛型 ArrayList listnew ArrayList()------>ArrayList<Integer>listnew ArrayList<Integer>(); 2 自动装箱/拆箱 nt ilist.get(0).parseInt();-------->int ilist.get(0);原始类型与对应的包装类不用显式转换 3 for-each i0;i<a.length;i------------&…

解析从Linux零拷贝深入了解Linux-I/O(上)

本文将从文件传输场景以及零拷贝技术深究 Linux I/O 的发展过程、优化手段以及实际应用。前言 存储器是计算机的核心部件之一&#xff0c;在完全理想的状态下&#xff0c;存储器应该要同时具备以下三种特性&#xff1a; 速度足够快&#xff1a;存储器的存取速度应当快于 CPU …

JWT安全漏洞以及常见攻击方式

前言 随着web应用的日渐复杂化&#xff0c;某些场景下&#xff0c;仅使用Cookie、Session等常见的身份鉴别方式无法满足业务的需要&#xff0c;JWT也就应运而生&#xff0c;JWT可以有效的解决分布式场景下的身份鉴别问题&#xff0c;并且会规避掉一些安全问题&#xff0c;如CO…

python+vue微信小程序的线上服装店系统

服装行业是一个传统的行业。根据当前发展现状,网络信息时代的全面普及,服装行业也在发生着变化,单就服饰这一方面,利用手机购物正在逐步进入人们的生活。传统的购物方式,不仅会耗费大量的人力、时间,有时候还会出错。小程序系统伴随智能手机为我们提供了新的方向。手机线上服装…

JavaEE|套接字编程之UDP数据报

文章目录一、DatagramSocket API构造方法常用方法二、DatagramPacket API构造方法常用方法E1:回显服务器的实现E2:带有业务逻辑的请求发送一、DatagramSocket API 在操作系统中&#xff0c;把socket对象当成了一个文件处理。等价于是文件描述符表上的一项。 普通的文件&#xf…

vbs简单语法及简单案例

文章目录一、简单语法1、变量2、输入3、输出4、选择语句5、循环二、用记事本编译中文乱码问题三、制作一个简单vbs脚本表白一、简单语法 1、变量 语法&#xff1a; dim 变量名例&#xff1a; dim a,b a1 b2 msgbox ab运行&#xff1a; 2、输入 语法&#xff1a;InputBox(…

【ip neigh】管理IP邻居( 添加ARP\NDP静态记录、删除记录、查看记录)

一、邻居管理存在状态 1、NUD_NONE&#xff1a; 初始状态。当一个新的路由缓存条目被创建时&#xff0c;arp_bind_neighbour()函数被调用.如果找不到相匹配的ARP缓存条目, neigh_alloc()将创建一个新的ARP缓存条目并设置状态为NUD_NONE. 2、NUD_INCOMPLETE&#xff1a;未完成状…

设计模式之适配器模式与桥接模式详解和应用

目录1 适配器模式1.1 定义1.2 应用场景1.3 适配器角色1.4 类适配器1.5 对象适配器1.5 接口适配器1.6 实战1.7 源码1.8 适配器与装饰器的对比1.9 适配器模式的优缺点1.10 总结2 桥接模式2.1 原理解析2.2 角色2.3 通用写法2.4 应用场景2.5 业务场景中的运用2.6 源码2.7 桥接模式优…

指针笔记(指针数组和指向数组的指针,数组中a和a的区别等)

指针数组和指向数组的指针 int *p[4]和int (*p)[4]有何区别&#xff1f; 前者是一个指针数组&#xff0c;数组大小为4&#xff0c;每一个元素都是一个指向int的指针 后者是指向int[4]类型数组的指针 以上代码若运行会报如下错误 main函数中定义的a数组本质是一个指向int[2]的…

内网渗透(三十八)之横向移动篇-pass the key 密钥传递攻击(PTK)横向攻击

系列文章第一章节之基础知识篇 内网渗透(一)之基础知识-内网渗透介绍和概述 内网渗透(二)之基础知识-工作组介绍 内网渗透(三)之基础知识-域环境的介绍和优点 内网渗透(四)之基础知识-搭建域环境 内网渗透(五)之基础知识-Active Directory活动目录介绍和使用 内网渗透(六)之基…

从0到1一步一步玩转openEuler--18 openEuler 管理服务-简介

文章目录18 管理服务简介18.1 概念介绍18 管理服务简介 systemd是在Linux下&#xff0c;与SysV和LSB初始化脚本兼容的系统和服务管理器。systemd使用socket和D-Bus来开启服务&#xff0c;提供基于守护进程的按需启动策略&#xff0c;支持快照和系统状态恢复&#xff0c;维护挂…

java基础学习 day41(继承中成员变量和成员方法的访问特点,方法的重写)

继承中&#xff0c;成员变量的访问特点 a. name前什么都不加&#xff0c;name变量的访问采用就近原则&#xff0c;先在局部变量中查找&#xff0c;若没找到&#xff0c;继续在本类的成员变量中查找&#xff0c;若没找到&#xff0c;继续在直接父类的成员变量中查找&#xff0c…

Mel Frequency Cepstral Coefficients (MFCCs)

wiki里说 在声音处理中&#xff0c;梅尔频率倒谱( MFC ) 是声音的短期功率谱的表示&#xff0c;基于非线性梅尔频率标度上的对数功率谱的线性余弦变换。 倒谱和MFC 之间的区别在于&#xff0c;在 MFC 中&#xff0c;频带在梅尔尺度上等距分布&#xff0c;这比正常频谱中使用的线…

Windows10 安装ElasticStack8.6.1

一、安装ElasticSearch8.6.1 1.官网下载ElasticSearch8.6.1压缩包后解压 2.安装为服务 elasticsearch-service.bat install 3.运行 elasticsearch-service.bat start 4.通过浏览器访问 http://localhost:9200/ 提示需要登录&#xff0c;但不知密码是啥。 5.重置密码 ela…