【torch.nn.Fold】和【torch.nn.Unfold】

news2025/1/11 22:48:29

文章目录

  • torch.nn.Unfold
    • 直观理解
    • 官方文档
  • toch.nn.Fold
    • 直观理解
    • 官方文档

torch.nn.Unfold

直观理解

torhc.nn.Unfold的功能: 从一个batch的样本中,提取出滑动的局部区域块patch(也就是卷积操作中的提取kernel filter对应的滑动窗口)把它按照顺序展开,得到的特征数就是通道数*卷积核的宽*卷积核的高, 下图中的L就是滑动完成后总的patch的个数
在这里插入图片描述
举个例子:

import torch
input1=torch.randn(1,3,4,6)
print(input1)
unfold1=torch.nn.Unfold(kernel_size=(2,3),stride=(2,3))
patches1=unfold1(input1)
print(patches1.shape)
print(patches1)

下图中的红框、蓝框、黄框、绿框分别是2x3的窗口按照步幅2x3滑动时得到的4个patch。每个patch的特征总数是2*3*3=18 ( 滑动窗口的高 ∗ 滑动窗口的宽 ∗ 通道数 滑动窗口的高*滑动窗口的宽*通道数 滑动窗口的高滑动窗口的宽通道数)
得到的输出patches1就是把每个patch的特征按照顺序展开,输出的大小就是(1,18,4)

在这里插入图片描述

官方文档

CLASS
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
  • 功能: 从批量输入张量中提取滑动局部块。

    假设一个batch的输入张量大小为 ( N , C , ∗ ) (N,C,*) (N,C,),其中 N N N表示batch的维度, C C C表示通道维度, ∗ * 表示任意的空间维度。该操作将输入空间维度内的每个kernel_size大小的滑动块展平到一列中, 输出的大小为 ( N , C × ∏ ( k e r n e l _ s i z e ) , L ) \left(N, C \times \prod( kernel\_size ), L\right) (N,C×(kernel_size),L), 其中 C × ∏ ( k e r n e l _ s i z e ) C \times \prod( kernel\_size) C×(kernel_size)表示每个block中包含的所有值的个数,一个block是kernel_size的面积和通道数的乘积, L L L是这样的block的个数。

    L = ∏ d ⌊  spatial_size  [ d ] + 2 × padding ⁡ [ d ] − dilation ⁡ [ d ] × ( kernel ⁡ _ size  [ d ] − 1 ) − 1 stride ⁡ [ d ] + 1 ] ,  L=\prod_d\left\lfloor\frac{\text { spatial\_size }[d]+2 \times \operatorname{padding}[d]-\operatorname{dilation}[d] \times\left(\operatorname{kernel} \_ \text {size }[d]-1\right)-1}{\operatorname{stride}[d]}+1\right] \text {, } L=dstride[d] spatial_size [d]+2×padding[d]dilation[d]×(kernel_size [d]1)1+1]

    其中 s p a t i a l _ s i z e spatial\_size spatial_size 是输入的空间维度(对应上述的*), d d d是所有的空间维度。

    因此,最后一个维度(列维度)的索引输出给出了某个块内的所有值。

    padding、stride和dilation参数指定如何检索滑动块。

    Stride控制滑块的步幅; Padding控制重塑前每个维度的点的填充数两边隐式零填充的数量。

    dilation 控制kenel 点之间的间距;也被称为à trous算法。

  • 参数

    • kernel_size(int or tuple) : 滑块的尺寸
    • dilation(int or tuple,optional): 控制邻域内元素步幅的参数。默认值:1
    • padding(int or tuple, optional) : 在输入的两侧添加隐式零填充。默认值:0
    • stride(int or tuple, optional) : 滑动块在输入空间维度中的步长。默认值:1

    如果kernel_size、dilation、padding或stride是int或长度为1的元组,它们的值将在所有空间维度上复制。

  • 形状:

    • 输入: ( N , C , ∗ ) (N,C,*) (N,C,)
    • 输出: ( N , C × ∏ ( k e r n e l _ s i z e ) , L ) \left(N, C \times \prod( kernel\_size ), L\right) (N,C×(kernel_size),L)
  • 例子

unfold = nn.Unfold(kernel_size=(2, 3))
input = torch.randn(2, 5, 3, 4)
output = unfold(input)
# each patch contains 30 values (2x3=6 vectors, each of 5 channels)
# 4 blocks (2x3 kernels) in total in the 3x4 input
output.size()

# Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
(torch.nn.functional.conv2d(inp, w) - out).abs().max()

toch.nn.Fold

直观理解

toch.nn.Fold 就是torch.nn.Unfold的逆操作,将提取出的滑动局部区域块还原成batch的张量形式。
在这里插入图片描述
举个例子:我们把上面输出的patches 通过具有相同大小的卷积核以及步幅进行Flod操作,得到的input_restoreinput1 相同,说明Fold和UnFold互为逆操作。

fold1=torch.nn.Fold(output_size=(4,6),kernel_size=(2,3),stride=(2,3))
input_restore=fold1(patches1)
print(input_restore.shape)
print(input_restore==input1)
print(input_restore)

在这里插入图片描述

官方文档

CLASS
torch.nn.Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)
  • 功能:

和Unfold相反,将提取出的滑动局部区域块还原成batch的张量形式。

  • 参数
    • output_size(int or tuple) : 输出的空间维度的形状
    • kernel_size(int or tuple) : 滑块的尺寸
    • dilation(int or tuple,optional): 控制邻域内元素步幅的参数。默认值:1
    • padding(int or tuple, optional) : 在输入的两侧添加隐式零填充。默认值:0
    • stride(int or tuple, optional) : 滑动块在输入空间维度中的步长。默认值:1
  • 形状
    • 输入: ( N , C × ∏ (  kernel_size  ) , L ) \left(N, C \times \prod(\text { kernel\_size }), L\right) (N,C×( kernel_size ),L) 或者 ( C × ∏ (  kernel_size  ) , L ) \left( C \times \prod(\text { kernel\_size }), L\right) (C×( kernel_size ),L)
    • 输出: ( N , C ,  output_size  [ 0 ] ,  output_size  [ 1 ] , … ) (N, C, \text { output\_size }[0], \text { output\_size }[1], \ldots) (N,C, output_size [0], output_size [1],) ( N , C ,  output_size  [ 0 ] ,  output_size  [ 1 ] , … ) (N, C, \text { output\_size }[0], \text { output\_size }[1], \ldots) (N,C, output_size [0], output_size [1],)
  • 例子
>>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
>>> input = torch.randn(1, 3 * 2 * 2, 12)
>>> output = fold(input)
>>> output.size()
torch.Size([1, 3, 4, 5])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/855346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言基础(持续更新)

常用函数 strrchr 描述 C 库函数 char *strrchr(const char *str, int c) 在参数 str 所指向的字符串中搜索最后一次出现字符 c(一个无符号字符)的位置。测试代码 #include "stdio.h" #include "string.h"int main() {printf(&q…

Ubuntu18.04中QGroundControl安装及添加到应用程序

Ubuntu18.04中QGroundControl安装及添加到应用程序 Ubuntu18.04中QGroundControl安装及添加到应用程序教程 目录 Ubuntu18.04中QGroundControl安装及添加到应用程序QGroundControl下载安装一、设置用户权限二、安装必要扩展包三、注销并再次登录以启用对用户权限的更改四、下…

pgsql查询某表所有字段

查询某表所有字段 查询某表所有字段 select * from information_schema.columns where table_schema模式名称 and table_name表名;模式 查询某表字段个数 select count(*) from information_schema.columns where table_schema模式名称 and table_name表名;

4.2 Windows终端数据安全

数据参考:CISP官方 目录 系统备份与还原数据备份数据粉碎数据加密 一、系统备份与还原 为什么需要系统备份 系统越用越慢系统故障导致不稳定系统无法登录 系统备份重新部署 (重装系统、重置系统) 丟失配置,需要重新配置个人数据丢失的风险 系统…

2023-08-09 LeetCode每日一题(整数的各位积和之差)

2023-08-09每日一题 一、题目编号 1281. 整数的各位积和之差二、题目链接 点击跳转到题目位置 三、题目描述 给你一个整数 n,请你帮忙计算并返回该整数「各位数字之积」与「各位数字之和」的差。 示例1: 示例2: 提示: 1 …

记一次空间告警与pg_rman keep-data-days参数研究

一、 背景 收到一个磁盘空间告警,检查发现是本地备份保留比较多导致的,处理过程倒很简单,手动清理掉旧的备份(已自动备到远端服务器),告警就恢复了。 但是检查备份脚本的时候,发现keep-data-day…

WPF实战项目十一(API篇):待办事项功能api接口

1、新建ToDoController.cs继承基础控制器BaseApiController,但是一般业务代码不写在控制器内,业务代码写在Service,先新建统一返回值格式ApiResponse.cs: public class ApiResponse{public ApiResponse(bool status, string mess…

科技云报道:一波未平一波又起?AI大模型再出邪恶攻击工具

AI大模型的快速向前奔跑,让我们见识到了AI的无限可能,但也展示了AI在虚假信息、深度伪造和网络攻击方面的潜在威胁。 据安全分析平台Netenrich报道,近日,一款名为FraudGPT的AI工具近期在暗网上流通,并被犯罪分子用于编…

数据包传输方式:单播、多播、广播、组播、泛播

数据包传输方式 单播、多播、广播、组播、泛播 网络中假设X代表所有的机器,Y代表X中的一部分机器,Z代表一组机器,1代表一台机器,那么 1:1 那就是单播;1:Y 那就是多播;1&#xff1…

mysql数据库如何转移到oracle

mysql数据库转移到oracle 在研发过程中,可能会用到将表数据库中的表结构及数据迁移到另外一种数据库中, 比如说从mysql中迁移到oracle中, 常用的方法有好些,如下 1、使用powerdesigner,先连接mysql然后生成mysql的p…

springboot启动忽略某些类

springboot启动忽略某些类 描述解决方案单拉一个提交,把所有的涉及kafka消费的都不注入容器通过配置ComponentScan的excludeFilters配置了不生效后续处理改之前改之后解释 总结 拆分环境 感触解决实现demo参考 描述 目前我这的开发环境和测试环境数据库是两份&#…

webshell免杀项目-潮影在线免杀平台(六)

平台地址: http://bypass.tidesec.com/web/ 注:需要先注册一个用户登录后才能使用该平台

RK3588平台开发系列讲解(进程篇)Linux进程IPC:管道的使用及原理

文章目录 一、什么是管道二、匿名管道和命名管道 如何进行选择三、管道使用案例四、管道的原理沉淀、分享、成长,让自己和他人都能有所收获!😄 📢 今天介绍Linux进程IPC管道。 一、什么是管道 顾名思义,通常管道就是你家一端连接着水池,另一端连着水龙头的、能流通水的…

算法练习--字符串相关

文章目录 计算字符串最后一个单词的长度计算某字符出现次数明明的随机数回文字符串回文数字无重复字符的最大子串长度有效的括号罗马数字转整数字符串通配符杨辉三角查找两个字符串a,b中的最长公共子串 **找出字符串中第一个只出现一次的字符 计算字符串最后一个单词的长度 pe…

【C++】继承的概念和简单介绍、基类和派生类对象复制转换、继承中的作用域、派生类的默认成员函数

文章目录 继承1.继承的概念和简单介绍1.1继承的概念1.2继承的定义 2.基类和派生类对象复制转换3.继承中的作用域4.派生类的默认成员函数5.继承与友元6.继承与静态成员 继承 1.继承的概念和简单介绍 1.1继承的概念 继承(inheritance)机制是面向对象程序设计使代码可以复用的最…

深入理解Jdk5引入的Java泛型:类型安全与灵活性并存

深入理解Jdk5引入的Java泛型:类型安全与灵活性并存 ​ 在Java的中,有一个强大的工具,它可以让你在编写代码时既保持类型安全,又享受灵活性。**这个工具就是——泛型(Generics)。**本文将引导你深入了解Java…

小白到运维工程师自学之路 第七十集 (Kubernetes集群部署)

一、概述 Kubernetes(简称K8S)是一个开源的容器编排和管理平台,是由Google发起并捐赠给Cloud Native Computing Foundation(CNCF)管理的项目。它的目标是简化容器化应用的部署、扩展、管理和自动化操作。 以下是Kube…

【D3S】集成smart-doc并同步配置到Torna

目录 一、引言二、maven插件三、smart-doc.json配置四、smart-doc-maven-plugin相关命令五、推送文档到Torna六、通过Maven Profile简化构建 一、引言 D3S(DDD with SpringBoot)为本作者使用DDD过程中开发的框架,目前已可公开查看源码&#…

leetcode26-删除有序数组中的重复项

双指针—快慢指针 慢指针 slow 走在后面&#xff0c;快指针 fast 走在前面探路&#xff0c;找到一个不重复的元素的时候就让slow前进一步并赋值给它。 流程&#xff1a; 代码 class Solution { public:int removeDuplicates(vector<int>& nums) {int slow 0, fas…

【2012年专利】基于中继节点的互联网通信系统和通信路径选择方法

基于中继节点的互联网通信系统和通信路径选择方法 CN102594703A地址基于中继节点的互联网通信系统和通信路径选择方法 ,解决:服务器间直接传输时丢包率高及延时长的缺点包括:服务器之间转发数据包的中继节点计算服务器间最优传输路径的选路决策节点, 本发明涉及一种晶于中继…