【PyTorch单点知识】深入理解与应用转置卷积ConvTranspose2d模块

news2024/11/21 1:36:04

文章目录

      • 0. 前言
      • 1. 转置卷积概述
      • 2. `nn.ConvTranspose2d` 模块详解
        • 2.1 主要参数
        • 2.2 属性与方法
      • 3. 计算过程(重点)
        • 3.1 基本过程
        • 3.2 调整stride
        • 3.3 调整dilation
        • 3.4 调整padding
        • 3.5 调整output_padding
      • 4. 应用实例
      • 5. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

nn.ConvTranspose2d 模块是用于实现二维转置卷积(又称为反卷积)的核心组件。本文将详细介绍 ConvTranspose2d 的概念、工作原理、参数设置以及实际应用。

本文的说明参考了PyTorch的官方文档

1. 转置卷积概述

转置卷积(Transposed Convolution),有时也被称为“反卷积”(尽管严格来说它并不是真正意义上的卷积的逆运算),是一种特殊的卷积操作,常用于从较低分辨率的特征图上采样到较高分辨率的空间维度。

在诸如深度卷积生成对抗网络(DCGAN)和条件生成对抗网络(CGANs)等任务中,转置卷积被广泛用于将网络内部的紧凑特征(较小的特征)表示恢复为与原始输入尺寸相匹配或接近的(较大的特征)输出。

2. nn.ConvTranspose2d 模块详解

nn.ConvTranspose2d 是 PyTorch 中 torch.nn 模块的一部分,专门用于定义和实例化二维转置卷积层。其构造函数接受一系列参数来配置卷积行为:

2.1 主要参数
  1. in_channels (int) - 输入特征图的通道数,即前一层的输出通道数。

  2. out_channels (int) - 输出特征图的通道数,即本层产生的新特征通道数。

  3. kernel_size (inttuple) - 卷积核大小,通常是一个整数(当使用方形卷积核时)或包含两个整数的元组(分别对应卷积核的高度和宽度)。

  4. stride (inttuple, default=1) - 卷积步长,决定了卷积核在输入特征图上滑动的距离。与 kernel_size 类似,它可以是单个整数(对所有维度相同)或一个包含两个整数的元组。

  5. padding (inttuple, default=0) - 填充量,用于控制输出尺寸和保持边界信息。

  6. output_padding (inttuple, default=0) - 用于调整输出尺寸的额外填充量,仅应用于转置卷积。它在卷积计算后增加到输出边缘的额外像素数量。

  7. groups (int, default=1) - 分组卷积参数,当大于1时,输入和输出通道将被分成若干组,每组内的卷积相互独立。

  8. bias (bool, default=True) - 表示是否为该层添加可学习的偏置项。

  9. dilation (inttuple, default=1) - 卷积核元素之间的间距(膨胀率),控制卷积核中非零元素之间的距离。

  10. padding_mode (str , default=zeros) - 填充数据方式,zeros为全部填充0

  11. device (str , default=cpu) - 处理数据的设备

  12. dtype (str, default=None ) - 数据类型

2.2 属性与方法
  • .weight (Tensor) - 存储转置卷积核的权重,形状为 (out_channels, in_channels, kernel_size[0], kernel_size[1]),是可学习的模型参数。

  • .bias (Tensor) - 若 bias=True,则包含与每个输出通道关联的偏置项,形状为 (out_channels),也是可学习的参数。

  • .forward(input) - 接受输入张量 input,执行转置卷积运算并返回输出特征图。

3. 计算过程(重点)

输入输出图像一般为4维或3维,即[B, C, H, W]或[C, H, W],其中:

  • B:Batch_size,每批的样本数
  • C:channel,通道数
  • H, W:图像的高和宽

以图像高度H为例(宽度W同理),转置卷积的输出尺寸可以通过以下公式计算:

H o u t = ( H i n − 1 ) × stride − 2 × padding + dilation × ( kernel-size − 1 ) + output-padding + 1 H_{out}=(H_{in}-1) \times \text{stride} -2 \times \text{padding} + \text{dilation} \times (\text{kernel-size}-1) + \text{output-padding}+1 Hout=(Hin1)×stride2×padding+dilation×(kernel-size1)+output-padding+1

这个公式看起来比较复杂,下面我们通过实例来理解转置卷积的计算过程。

3.1 基本过程

输入原图size为[1, 2, 2],卷积核也size也为[1, 2, 2],其余参数如下:

in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False

计算过程:

在这里插入图片描述

容易看出,经历转置卷积后特征图会扩大,即上采样。使用代码验算:

import torch

input = torch.tensor([[[[0,1],
                        [2,3]]]],dtype=torch.float32)

ConvTrans = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],
          [ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))

print(ConvTrans(input))

输出为:

tensor([[[[ 0.0000,  1.1000,  2.2000],
          [ 2.2000, 11.0000, 11.0000],
          [ 6.6000, 18.7000, 13.2000]]]], grad_fn=<ConvolutionBackward0>)
3.2 调整stride

把stride调整为2后,计算过程如下:

在这里插入图片描述
如果stride过大,则会在跳过的位置补0。例如上面的计算过程中,如果stride = 3输出则为:

在这里插入图片描述

注意,这里stride可以指定为tuple,即让横向和纵向的stride不一样,例如(1, 2),但其计算思路不变,这里直接用代码计算结果(懒得再画过程图了):

import torch

input = torch.tensor([[[[0,1],
                        [2,3]]]],dtype=torch.float32)

ConvTrans = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=(1,2), padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],
          [ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))

print(ConvTrans(input))

输出为:

tensor([[[[ 0.0000,  0.0000,  1.1000,  2.2000],
          [ 2.2000,  4.4000,  6.6000, 11.0000],
          [ 6.6000,  8.8000,  9.9000, 13.2000]]]],
       grad_fn=<ConvolutionBackward0>)
3.3 调整dilation

这个过程非常简单,可以分为2步:

  1. 把卷积核进行dilation(爆炸)处理
  2. 进行3.1基本过程

即:
在这里插入图片描述
代码验算过程如下:

import torch

input = torch.tensor([[[[0,1],
                        [2,3]]]],dtype=torch.float32)

ConvTrans_dilation2 = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=2,bias=False)
ConvTrans_dilation2.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],
          [ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))

print(ConvTrans_dilation2(input))

ConvTrans_dilation1 = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans_dilation1.weight = torch.nn.Parameter(torch.tensor([[[[1.1, 0, 2.2],
                                                                     [0, 0, 0],
                                                                     [3.3, 0, 4.4]]]], dtype=torch.float32,requires_grad=True))   #对卷积核进行dilation

print(ConvTrans_dilation1(input))
print(ConvTrans_dilation2(input) == ConvTrans_dilation1(input))

输出为:

tensor([[[[ 0.0000,  1.1000,  0.0000,  2.2000],
          [ 2.2000,  3.3000,  4.4000,  6.6000],
          [ 0.0000,  3.3000,  0.0000,  4.4000],
          [ 6.6000,  9.9000,  8.8000, 13.2000]]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[[ 0.0000,  1.1000,  0.0000,  2.2000],
          [ 2.2000,  3.3000,  4.4000,  6.6000],
          [ 0.0000,  3.3000,  0.0000,  4.4000],
          [ 6.6000,  9.9000,  8.8000, 13.2000]]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]]])
3.4 调整padding

这是一个下采样的过程,会减少输出size。具体计算方法也很简单:给输出数据减去padding。基于3.1基本过程举例说明padding = 1的情况如下:

在这里插入图片描述

3.5 调整output_padding

这个参数用于给最终输出补0,output_padding必须要比stride或者dilation小。需要注意的是output_padding补0只能补半圈,如下:

我也想不明白为什么不是补一整圈?

在这里插入图片描述

4. 应用实例

在实际使用中,nn.ConvTranspose2d 可以嵌入到神经网络结构中,用于实现上采样、特征图尺寸放大或生成与输入尺寸相似的输出。以下是一个简单的使用示例:

import torch
import torch.nn as nn

# 定义一个包含转置卷积层的简单模型
class TransposedConvModel(nn.Module):
    def __init__(self, in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1, output_padding=0):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=True
        )

    def forward(self, x):
        return self.conv_transpose(x)

# 实例化模型并应用到输入数据
model = TransposedConvModel()
input_tensor = torch.randn(1, 32, 16, 16)  # (batch_size, in_channels, height, width)
output = model(input_tensor)
print("Output shape:", output.shape)

输出为:

Output shape: torch.Size([1, 64, 32, 32])

5. 总结

nn.ConvTranspose2d 是 PyTorch 中用于实现二维转置卷积的关键模块,它通过逆向的卷积操作实现了特征图的上采样和空间维度的扩大。

正确理解和配置其参数(如 kernel_sizestridepaddingoutput_padding 等),可以帮助开发者构建出适应特定任务需求的神经网络架构,特别是在图像生成、超分辨率、语义分割等需要从低分辨率特征恢复到高分辨率输出的应用场景中发挥关键作用。通过实践和调整这些参数,研究人员和工程师能够灵活地设计和优化基于转置卷积的深度学习模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1661440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是FMEA的分析范围?——FMEA软件

免费试用FMEA软件-免费版-SunFMEA FMEA的分析范围广泛而深入&#xff0c;涵盖了产品设计、制造过程、供应链管理以及使用和维修等多个方面。 产品设计是FMEA分析的重要一环。在设计阶段&#xff0c;FMEA能够帮助工程师识别潜在的设计缺陷&#xff0c;并预测这些缺陷可能对产品…

Bugku Crypto 部分题目简单题解(三)

where is flag 5 下载打开附件 Gx8EAA8SCBIfHQARCxMUHwsAHRwRHh8BEQwaFBQfGwMYCBYRHx4SBRQdGR8HAQ0QFQ 看着像base64解码 尝试后发现&#xff0c;使用在线工具无法解密 编写脚本 import base64enc Gx8EAA8SCBIfHQARCxMUHwsAHRwRHh8BEQwaFBQfGwMYCBYRHx4SBRQdGR8HAQ0QFQ tex…

ArcGIS10.2能用了10.2.2不行了(解决)

前两天我们的推文介绍了 ArcGIS10.2系列许可到期解决方案-CSDN博客文章浏览阅读2次。本文手机码字&#xff0c;不排版了。 昨晚&#xff08;2021\12\17&#xff09;12点后&#xff0c;收到很多学员反馈 ArcGIS10.2系列软件突然崩溃。更有的&#xff0c;今天全单位崩溃。​提示许…

智慧公厕,小民生里的“大智慧”!

公共厕所是城市社会生活的基础设施&#xff0c;而智慧公厕则以其独特的管理模式为城市居民提供更优质的服务。通过智能化的监测和控制系统&#xff0c;智慧公厕实现了厕位智能引导、环境监测、资源消耗监测、安全防范管理、卫生消杀设备、多媒体信息交互、自动化控制、自动化清…

OpenCV 入门(四)—— 车牌号识别

OpenCV 入门系列&#xff1a; OpenCV 入门&#xff08;一&#xff09;—— OpenCV 基础 OpenCV 入门&#xff08;二&#xff09;—— 车牌定位 OpenCV 入门&#xff08;三&#xff09;—— 车牌筛选 OpenCV 入门&#xff08;四&#xff09;—— 车牌号识别 OpenCV 入门&#xf…

DiskCatalogMaker for Mac:高效管理磁盘文件助手

DiskCatalogMaker for Mac&#xff0c;助您高效管理磁盘文件&#xff0c;让文件整理变得轻而易举&#xff01;这款软件以其出色的性能和人性化的设计&#xff0c;赢得了广大Mac用户的喜爱。 DiskCatalogMaker支持多种磁盘格式&#xff0c;让您轻松管理硬盘、U盘、光盘等存储设备…

LaTeX公式学习笔记

\sqrt[3]{100} \frac{2}{3} \sum_{i0}^{n} x^{3} \log_{a}{b} \vec{a} \bar{a} \lim_{x \to \infty} \Delta A B C

基于Springboot的微乐校园管理系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的微乐校园管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构…

Python尝试安装 pyaudio 时遇到的错误信息表示安装过程失败,原因是找不到 Python.h 头文件

环境&#xff1a; Python 3.8.10 WSL2 问题描述&#xff1a; 尝试安装 pyaudio 时遇到的错误信息表示安装过程失败&#xff0c;原因是找不到 Python.h 头文件 error: subprocess-exited-with-error Building wheel for pyaudio (pyproject.toml) did not run successfully…

组合模式(结构型)

目录 一、前言 二、透明组合模式 三、安全组合模式 四、总结 一、前言 组合模式(Composite Pattern)是一种结构型设计模式&#xff0c;将对象组合成树形结构以表示“部分-整体”得层次结构。组合模式使得用户对单个对象和组合对象的使用具有一致性。 组合模式由以下角色组成…

数据库调优-SQL语句优化

2. SQL语句优化 sql 复制代码 # 请问这两条SQL语句有什么区别呢&#xff1f;你来猜一猜那条SQL语句执行查询效果更好&#xff01; select id from sys_goods where goods_name华为 HUAWEI 麦芒7 魅海蓝 6G64G 全网通; ​ select id from sys_goods where goods_id14967325985…

(九)JSP教程——pageContext对象

pageContext对象是由JSP容器创建并初始化的&#xff0c;相当于当前页面的容器&#xff0c;它可以访问当前页面中的所有对象。它的主要作用是为JSP页面包装上下文&#xff0c;并用于管理属于JSP的特殊可见部分中已命名对象的访问。 一般情况下&#xff0c;使用该对象的应用并不多…

netty配置SSL、netty配置https(开发)

netty配置SSL、netty配置https&#xff08;开发&#xff09; 我们在开发下使用ssl&#xff0c;所用的证书将不被客户端信任。 转自&#xff1a;https://lingkang.top/archives/netty-pei-zhi-ssl 方案一 快速。使用netty提供的临时签发证书 private static SslContext sslC…

富士Apeos 2350 NDA复印机报062 360代码故障

故障描述&#xff1a; 富士Apeos 2350 NDA复印机新机器刚拆箱安装&#xff0c;开机正常&#xff0c;自检扫描头一卡一卡的往前动几下就不动了、扫描灯也不亮扫描头也不能正常复位&#xff1b;按机器的复印键直接报062 360代码&#xff1b; 解答&#xff1a; 此代码为扫描故障&a…

unreal engine4 创建动画蒙太奇

UE4系列文章目录 文章目录 UE4系列文章目录前言一、创建动画蒙太奇 前言 动画蒙太奇的官方解释&#xff1a;Animation Montages are animation assets that enable you to combine animations in a single asset and control playback using Blueprints.You can use Animation…

粤嵌—2024/4/26—跳跃游戏 ||

代码实现&#xff1a; 方法一&#xff1a;回溯 历史答案剪枝优化——超时 int *dis;void dfs(int k, int startindex, int *nums, int numsSize) {if (dis[startindex] < k) {return;}dis[startindex] k;for (int i 0; i < nums[startindex]; i) {if (startindex i &…

mac内存不足怎么清理?有哪些免费的软件工具?

当你的mac电脑使用一段时间之后&#xff0c;你可能就会发现&#xff0c;原本非常流畅的运行开始出现卡顿的现象&#xff0c;此时正是mac内存不足的外在表现。可mac内存不足怎么清理呢&#xff0c;别急&#xff0c;清理内存的方式方法有很多&#xff0c;小编将结合实际情况给大家…

【Web】2023浙江大学生省赛初赛 secObj 题解

目录 step 0 step 1 step 2 step 3 题目本身是不难&#xff0c;简单复健一下 step 0 pom依赖就是spring 反序列化入口在./admin/user/readObj 输入流做了黑名单的过滤&#xff0c;TemplatesImpl不能直接打 可以jackson打SignedObject二次反序列化绕过 具体原理看下面这…

选择定制温度快速温变试验箱,为您开启高效试验新时代

在现代工业生产中&#xff0c;温度快速温变试验是评估产品性能和可靠性的重要手段之一。然而&#xff0c;市面上有大多数企业在选择试验设备时常常面临着一些困惑&#xff0c;那就是温度快速温变试验箱定制的性能与需求不匹配、定制服务不足、售后服务不到位等等。针对这些问题…

98、技巧-颜色分类

思路 这道题的思路是什么&#xff0c;首先典型荷兰国旗问题&#xff1a; 该问题的关键在于我们要将所有的0放到数组的前部&#xff0c;所有的1放在中间&#xff0c;所有的2放在后部。这可以通过使用两个指针&#xff0c;一个指向数组开头的“0”的最后一个位置&#xff0c;另…