详解torch.nn.functional.grid_sample函数(通俗易懂):可实现对特征图的水平/垂直翻转

news2024/9/25 11:15:16

一、函数介绍

Pytorch中grid_sample函数的接口声明如下,具体网址可以点这里

torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zeros’, align_corners=None)

为了简单起见,以下讨论都是基于如下参数进行实验及讲解的:

torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘border’, align_corners=True)

给定 维度为(N,C,Hin,Win) 的input,维度为(N,Hout,Wout,2) 的grid,则该函数output的维度为(N,C,Hout,Wout)

  • 其实,input(N,C,Hin,Win)可以理解为一批特征图。其中,N可以理解为批大小(batch size),C可以理解为特征图的通道数,Hin可以理解为特征图的高,Win可以理解为特征图的宽

  • gird(N,Hout,Wout,2)的作用在于提供一批用于在输入特征图上进行元素采样的位置坐标grid的元素值通常在[-1,1]之间(-1,-1) 表示取输入特征图左上角的元素,(1,1) 表示取输入特征图右下角的元素。

  • output(N,C,Hout,Wout)表示函数输出的一批特征图,其批大小依然为N,特征图的通道数依然为C,但特征图的高已经变成了Hout,宽变成了Wout,并且输出特征图中的元素值是从根据grid所提供的位置坐标在输入特征图中采样得到的

因此,一般来说,我们 首先需要根据Hin和Win的大小,对输入特征图元素坐标位置进行规范化

假设我们此时有一个1 × 2 × 5 × 9的特征图,即N=1,C=2,Hin=5,Win=9。如下:

在这里插入图片描述

那么对输入特征图根据其高(Hin=5)和宽(Win=9)进行元素坐标位置规范化如下:

在这里插入图片描述

如果我们想实现对 输入特征图 input(维度大小为1 × 2 × 5 × 9)的 水平翻转 ,则 grid (维度大小应该为1 × 5 × 9 × 2)应该设定为对上述 坐标位置规范化结果的水平翻转 形式,如下:【注意,这里的两个2表示的含义完全不同,input中的2表示的是通道数为2,而grid中的2表示的是坐标,众所周知,二维坐标是2个数。这里之所以举例比较巧合,就是想通过这里的解释,让大家深刻理解上述参数的具体含义。】

在这里插入图片描述

二、代码验证(输入特征图和输出特征图大小相同)

根据代码运行结果可知,当grid设定为对输入特征图元素坐标位置规范化结果的水平翻转形式时,也就实现了对输入特征图的水平翻转

import torch
import torch.nn.functional as F

input_data = torch.tensor([[[[1,2,3,4,5,-4,-3,-2,-1],
                        [-1,-2,-3,-4,-5,4,3,2,1],
                        [1,3,5,7,9,11,13,15,17],
                        [0,2,4,6,8,10,12,14,16],
                        [3,6,9,12,15,16,19,21,24]],

                       [[9,8,7,6,5,4,3,2,1],
                       [1,2,3,4,5,6,7,8,9],
                       [-9,-8,-7,-6,-5,-4,-3,-2,-1],
                       [-1,-2,-3,-4,-5,-6,-7,-8,-9],
                       [0,2,4,6,8,1,3,5,7]]]]).float()
print(input_data.shape) # torch.Size([1, 2, 5, 9])

grid = torch.tensor([[[[1,-1],
                      [0.75,-1],
                      [0.5,-1],
                      [0.25,-1],
                      [0,-1],
                      [-0.25,-1],
                      [-0.5,-1],
                      [-0.75,-1],
                      [-1,-1]],

                      [[1,-0.5],
                      [0.75,-0.5],
                      [0.5,-0.5],
                      [0.25,-0.5],
                      [0,-0.5],
                      [0.25,-0.5],
                      [0.5,-0.5],
                      [0.75,-0.5],
                      [1,-0.5]],

                      [[1,0],
                      [0.75,0],
                      [0.5,0],
                      [0.25,0],
                      [0,0],
                      [0.25,0],
                      [0.5,0],
                      [0.75,0],
                      [1,0]],

                      [[1,0.5],
                      [0.75,0.5],
                      [0.5,0.5],
                      [0.25,0.5],
                      [0,0.5],
                      [0.25,0.5],
                      [0.5,0.5],
                      [0.75,0.5],
                      [1,0.5]],

                      [[1,1],
                      [0.75,1],
                      [0.5,1],
                      [0.25,1],
                      [0,1],
                      [0.25,1],
                      [0.5,1],
                      [0.75,1],
                      [1,1]]]])
print(grid.shape) # torch.Size([1, 5, 9, 2])

output = F.grid_sample(input_data, grid, mode='bilinear', padding_mode='border', align_corners=True)
print(output)

在这里插入图片描述

在上述例子中,批大小为1。如过批大小为N,则grid应该为N个特征图提供相应的N种采样方式,比如对某些特征图进行水平翻转,对某些特征图进行上下翻转…,当然也可以对N个特征图提供N种相同的采样方式。注意:虽然可以对N个特征图提供N种相同的采样方式,但是对于每个特征图中的所有通道,采样方式都是一致的

另外,在本例中,输入特征图和输出特征图大小相同。如果我们想输出和输入特征图不同大小的特征图,也是可以的,只需要对grid进行改变即可,参见第三部分。

三、代码验证(输入特征图和输出特征图大小不同)

假定,输入特征图与上述保持一致,即N=1,C=2,Hin=5,Win=9,如下:

在这里插入图片描述

然而我们 只想采样黄色区域的元素,则相应地, grid应该只选择对输入特征图元素坐标位置规范化结果的对应坐标位置,如下:

在这里插入图片描述
代码及结果如下:

import torch
import torch.nn.functional as F

input_data = torch.tensor([[[[1,2,3,4,5,-4,-3,-2,-1],
                        [-1,-2,-3,-4,-5,4,3,2,1],
                        [1,3,5,7,9,11,13,15,17],
                        [0,2,4,6,8,10,12,14,16],
                        [3,6,9,12,15,16,19,21,24]],

                       [[9,8,7,6,5,4,3,2,1],
                       [1,2,3,4,5,6,7,8,9],
                       [-9,-8,-7,-6,-5,-4,-3,-2,-1],
                       [-1,-2,-3,-4,-5,-6,-7,-8,-9],
                       [0,2,4,6,8,1,3,5,7]]]]).float()
print(input_data.shape) # torch.Size([1, 2, 5, 9])

grid = torch.tensor([[[[-0.75,-1],
                       [-0.25,-1],
                       [0.25,-1],
                       [0.75,-1]],

                      [[-0.75,0],
                       [-0.25,0],
                       [0.25,0],
                       [0.75,0]],

                       [[-0.75,1],
                       [-0.25,1],
                       [0.25,1],
                       [0.75,1]]]])
print(grid.shape) # torch.Size([1, 3, 4, 2])

output = F.grid_sample(input_data, grid, mode='bilinear', padding_mode='border', align_corners=True)
print(output)

在这里插入图片描述

四、自动对输入特征图中的元素坐标位置进行规范化操作

看到这里,相信大家已经基本知道 torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zeros’, align_corners=None)函数是做什么的了。

总之一句话,该函数可以根据grid中的坐标顺序对input进行重新采样,从而生成新的ouput

在上述分析中,我们对输入特征图元素坐标位置的规范化是手动计算的,那么能不能让程序自动对输入特征图中的元素坐标位置进行规范化操作呢?

当然是可以的,具体分析可以参考如下代码:首先,定义一个函数 generate_flip_grid(w, h)自动对输入特征图中的元素坐标位置进行规范化操作,然后,对规范化后的结果进行相关变化,比如水平翻转或上下翻转,即可实现对输入特征图的水平翻转或上下翻转。在下例中,我们对给定的一批输入特征图均执行了水平翻转的操作。

# 参考链接:
# https://cloud.tencent.com/developer/article/1781060
# https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html


import torch
import torch.nn.functional as F


# w和h分别为torch.nn.functional.grid_sample函数中input参数的宽和高
def generate_flip_grid(w, h):
    x_ = torch.arange(w).unsqueeze(0).expand(h, -1) # torch.Size([h, w])
    # expand(*size)函数可以实现对张量中单维度上数据的复制操作。
    # 其中,*size分别指定了每个维度上复制的倍数。
    # 对于不需要(或非单维度)进行复制的维度,对应位置上可以写上原始维度的大小或者直接写-1。

    # 单维度怎么理解呢?
    # 将张量中大小为1的维度称为单维度。例如,shape为[2,3]的张量就没有单维度,
    # shape为[1,3]的张量,其第0个维度上的大小为1,因此第0个维度为张量的单维度。

    # 例如,torch.arange(7)结果的shape为[7],没有单维度,因此需要先通过unsqueeze()进行维度增加,
    # 参数为0表示在第0个维度进行维度增加操作,即在张量最外层加一个中括号变成第一维。

    y_ = torch.arange(h).unsqueeze(1).expand(-1, w) # torch.Size([h, w])
    grid = torch.stack([x_, y_], dim=0).float() # torch.Size([2, h, w])
    # 将x_和y_沿维度0进行堆叠

    grid = grid.unsqueeze(0) # torch.Size([1,2, h, w])
    grid[:, 0, :, :] = 2 * grid[:, 0, :, :] / (w - 1) - 1 # 相当于对x轴坐标进行规范化操作 torch.Size([1, 2, h, w])
    grid[:, 1, :, :] = 2 * grid[:, 1, :, :] / (h - 1) - 1 # 相当于对y轴坐标进行规范化操作 torch.Size([1, 2, h, w])
    grid = grid.permute(0,2,3,1) # 交换维度 转换为 torch.nn.functional.grid_sample函数中grid规定的形式[1,h,w,2]

    return grid # torch.Size([1,h,w,2])




# w和h分别为torch.nn.functional.grid_sample函数中input参数的宽和高
w = 9
h = 5
N = 2 # 这里的N相当于batch size

grid = generate_flip_grid(w,h) # 获取输入特征图中元素位置的规范化结果

grid = grid.expand(N, -1, -1, -1).clone() # torch.Size([N, h, w, 2])
# expand()函数并不会重新分配内存,返回的结果仅仅是原始张量上的一个视图,无法对原始张量进行修改。
# 因此,如果expand之后直接在下面对grid张量进行元素改变,就会发生错误。
# clone()函数为复制函数, 可以返回一个完全相同的张量,与原张量不共享内存,从而可以实现下面对张量的修改。

grid[:, :, :, 0] = -grid[:, :, :, 0] # 对x轴坐标取反,相当于实现了水平/左右翻转
# grid[:, :, :, 1] = -grid[:, :, :, 1] # 对y轴坐标取反,相当于实现了上下翻转

input = torch.tensor([[[[1,2,3,4,5,-4,-3,-2,-1],
                        [-1,-2,-3,-4,-5,4,3,2,1],
                        [1,3,5,7,9,11,13,15,17],
                        [0,2,4,6,8,10,12,14,16],
                        [3,6,9,12,15,16,19,21,24]],

                       [[9,8,7,6,5,4,3,2,1],
                        [1,2,3,4,5,6,7,8,9],
                        [-9,-8,-7,-6,-5,-4,-3,-2,-1],
                        [-1,-2,-3,-4,-5,-6,-7,-8,-9],
                        [0,2,4,6,8,1,3,5,7]]],

                       [[[9,8,7,6,5,4,3,2,1],
                         [1,2,3,4,5,6,7,8,9],
                         [-9,-8,-7,-6,-5,-4,-3,-2,-1],
                         [-1,-2,-3,-4,-5,-6,-7,-8,-9],
                         [0,2,4,6,8,1,3,5,7]],

                        [[1,2,3,4,5,-4,-3,-2,-1],
                         [-1,-2,-3,-4,-5,4,3,2,1],
                         [1,3,5,7,9,11,13,15,17],
                         [0,2,4,6,8,10,12,14,16],
                         [3,6,9,12,15,16,19,21,24]]]]).float()
# print(input.shape) # torch.Size([2, 2, 5, 9])

output = F.grid_sample(input, grid, mode='bilinear', padding_mode='border', align_corners=True)
print(output)

在这里插入图片描述

五、关于参数

torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zeros’, align_corners=None)

通过上面的介绍,相信大家对input和grid以及函数的输出都已经了解得差不多了。

这里,主要说一下其它的三个参数。

  • padding_mode表示当grid中的坐标位置超出边界时像素值的填充方式,如果为zeros,则表示一旦grid坐标超出边界,则用0去填充输出特征图的相应位置元素,如果为border,则表示利用输入特征图对应的边缘元素去填充输出特征图的相应位置元素。想了解更多选择,可以去官网进一步了解。笔者目前只研究了zeros和border两种情况。

  • mode表示插值方式,对于四维数据的话,大家一般选择bilinear即可。想了解更多选择,可以去官网进一步了解。这里说明一下什么时候会用到插值,如果grid中的某个坐标直接对应于输入特征图元素位置的规范化结果中的某个坐标,则直接把对应的值取过来就行。但如果grid中的某个坐标不能直接对应于输入特征图元素位置的规范化结果中的所有坐标,则需要根据不同的插值方式(比如bilinear)在输入特征图中进行插值。

  • 至于align_corners这个参数,一般和插值方式mode搭配使用,表示在插值时像素的对齐方式,有两种选择,分别是True和False。如果把一个像素点看做一个正方形的话,True表示角像素点位于对应正方形的中心。False表示角像素点位于对应正方形的角点坐标。笔者目前只研究明白了align_corner=True的含义。

六、参考文献

  • grid_sample()函数及双线性采样

  • TORCH.NN.FUNCTIONAL.GRID_SAMPLE

  • Pytorch中的grid_sample算子功能解析

  • PyTorch中grid_sample的使用方法

  • torch.nn.function.grid_sample的介绍及使用方法

  • torch.nn.functional.grid_sample

  • align_corners参数介绍

  • 一文看懂align_corners

  • RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation

  • expand()函数的局限性,需要搭配clone()函数使用

  • PyTorch入门笔记-复制数据expand函数

  • TORCH.TENSOR.EXPAND

总之,这个函数,就是根据grid所提供的坐标在input中进行重新采样,然后生成新的output。用这个函数来实现特征图的水平翻转或垂直翻转,特别容易理解,因为直接把input的元素坐标都水平翻转或垂直翻转一下就行了。这种情况下,grid中的坐标不会超出界限,因此就不用考虑padding_mode参数。另外,这种情况下,grid中的坐标在input中均能找到映射,因此也不用考虑详细的插值情况,只需要注意将align_corners设为True,因为我们grid边缘的点的位置坐标在相应的轴上都是等距的,与align_corners为True一致(align_corners为False则不等距)。

这篇博客被我断断续续写了三四天,如果大家觉得有所帮助的话,麻烦点个赞鼓励一下吧😭。大家有任何问题,欢迎评论区留言,我看到都会尽量回复的~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/56628.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BSN开放联盟链“中移链”浏览器2.0正式发布!

由中国移动信息技术中心自主研发的中移链EOS区块链浏览器2.0版本,已在区块链服务网络(BSN)官网和BSN-DDC网络官网正式发布。 中移链浏览器2.0 无论是从政策导向还是从业务需求方面来说,区块链技术的发展已经是一种不可逆的趋势&a…

查找-二叉排序树

问题引入 【问题描述】 输入若干个整数建立二叉排序树,以0结束输入,在二叉排序树上查找关键字,删除指定关键字结点。 【输入形式】 (1)第一行,输入若干个整数,输入0结束输入; 如输入关键字 45 24 53 12 28 90 0 可建立如下二叉排序树 (2)第二行,输入两个整数,一…

GameOff2022参与有感

GameOff2022参与有感以及年度总结 厚颜无耻的用我们美术的立绘 GameOff— Redemption 很高兴在一个月的时间里面和大家一起完成了《Redemption》 比赛链接:Itch.io 百度云盘链接: 链接:https://pan.baidu.com/s/1ylK0QRr2lmkqi4JF1wsXtA 提…

【servelt原理_6_servlet核心接口和类】

servlet核心接口和类 在Servlet体系中,除了实现servlet接口,还可以通过继承GenericServlet或HttpServlet类实现编写1.Servlet接口 servlet接口是整个servlet的核心。它是所有Servlet类必须直接或者间接实现的一个接口,其内部需要实现的5个方法分别关乎…

基于flv.js的视频自动播放

1: html <video class"video-content" id"video">您的浏览器不支持 HTML5 video&#xff01; </video> 2: 创建flv实例并播放 let videoPlayer document.getElementById(video); //获取html if (flvJs.isSupported()) {//创建flv实例this.P…

音视频开发——FFmpeg技术点 【进阶一览】

概述 Fmpeg是一套领先的音视频多媒体处理开源框架&#xff0c;采用LGPL或GPL许可证。它提供了对音视频的采集、编码、解码、转码、音视频分离、合并、流化、过滤器等丰富的功能&#xff0c;包含了非常先进的音频/视频编解码库libavcodec&#xff0c;具有非常高的可移植性和编解…

[附源码]计算机毕业设计中小学课后延时服务管理系统Springboot程序

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Git 之 已有项目创建 git 仓库

Git 之 已有项目创建 git 仓库前言一、现在 github/gitee 中创建仓库二、在项目的文件夹当中 git bash here1.git init2. git remote add origin 仓库地址3. git pull origin master4. git add . git commit -m git push -u origin master前言 项目已经开始写了,但是还没有…

干货 | 数字经济创新创业——如何造就成功的职业生涯

下文整理自清华大学大数据能力提升项目能力提升模块课程“Innovation & Entrepreneurship for Digital Economy”&#xff08;数字经济创新创业课程)的精彩内容。主讲嘉宾&#xff1a;Kris Singh: CEO at SRII, Palo Alto, CaliforniaVisiting Professor of Tsinghua Unive…

第一天:Python元学习——通用人工智能的实现

文章目录0 封面1 第一章&#xff1a;元学习简介1.1 元学习与少样本学习1.2 元学习的类型——学习度量空间1.3 学习初始化1.4 学习优化器1.5 通过梯度下降来学习如何通过梯度下降来学习2 第二章&#xff1a;使用孪生网络进行人脸识别与音频视频2.1 什么是孪生神经网络孪生神经网…

机器学习与数据挖掘——数据预处理

如果有兴趣了解更多相关内容&#xff0c;欢迎来我的个人网站看看&#xff1a;瞳孔空间 一&#xff1a;关于数据预处理 在工程实践中&#xff0c;我们得到的数据会存在有缺失值、重复值等&#xff0c;在使用之前需要进行数据预处理。数据预处理没有标准的流程&#xff0c;通常…

Kaldi的简单介绍和基本使用说明

Kaldi的简单介绍和基本使用说明前言一、ASR简介1.语音识别系统特征提取&#xff1a;声学模型发音词典语言模型语音解码2. ASR项目二、Kaldi简介三、Kaldi项目的结构四、Kaldi的安装1. 安装依赖的几个系统开发库2. 安装依赖的第三方工具库3. 编译Kaldi代码配置Kaldi编译Kaldi五、…

Python-进程和线程

张钊*&#xff0c;沈啸彬*, 王旭* 李月&#xff0c;曹海艳&#xff0c; (淮北师范大学计算机科学与技术学院&#xff0c;淮北师范大学经济与管理学院&#xff0c;安徽 淮北) *These authors contributed to the work equllly and should be regarded as co-first authors. &a…

智能电网中需求响应研究(Matlab代码实现)

目录 1 概述 2 运行结果 ​编辑 ​编辑 3 参考文献 4 Matlab代码实现 1 概述 智能电网需求响应可以降低电网高峰用电需求、提高电网运行稳定性和可靠性&#xff0c;尤其是通过需求响应实现电网接纳间歇性可再生能源发电的能力。 需求响应的全球进展及产生的效益等情况在…

SDUT—Python程序设计实验1011(面向对象)

7-1 sdut-oop-2 Shift Dot(类和对象&#xff09; 给出平面直角坐标系中的一点&#xff0c;并顺序给出n个向量&#xff0c;求该点根据给定的n个向量位移后的位置。 设计点类Point&#xff0c;内含&#xff1a; &#xff08;1&#xff09;整型属性x和y&#xff0c;表示点的横坐标…

数据可视化之交通可视化

一 前言 智慧城市的概念自 2008年提出以来&#xff0c;在国际上引起广泛关注&#xff0c;并持续引发了全球智慧城市的发展热潮。智慧城市已经成为推进全球城镇化、提升城市治理水平、破解大城市病、提高公共服务质量、发展数字经济的战略选择。近年来&#xff0c;我国智慧城市…

rxjs pipeable operators(上)

rxjs pipeable operators&#xff08;上&#xff09; A Pipeable Operator is a function that takes an Observable as its input and returns another Observable. It is a pure operation: the previous Observable stays unmodified. 一个 Pipeable Operator 是一个接受一个…

Ubuntu空间不足,如何扩容

目录 1、硬盘操作步骤 2、Ubuntu命令操作&#xff1a;安装分区管理工具 3、分区结果展示 1、硬盘操作步骤 最近发现Ubuntu空间不足&#xff0c;怎么去扩容呢&#xff1f;第一步&#xff1a;点击【硬盘】 第二步&#xff1a;点击【扩展】 第三步&#xff1a;修改【最大磁盘…

创新洞察丨消费品牌D2C生存发展的3大差异化策略

在过去六年中&#xff0c;DTC 品牌的销售额增长了两倍&#xff0c;但另一个事实是&#xff0c;他们花费了数十亿美元于营销投入&#xff0c;品牌知名度却不见增长。Lego 创意总监James Gregson认为&#xff0c;在同质化的DTC品牌崛起之下&#xff0c;打造品牌差异成为生存的关键…

Jsp 学习笔记

代码可参考: Demo地址 1 入门 1.1 环境搭建 创建moven项目目录结构如下 1.2 依赖配置 <!-- 依赖 --> <dependencies><dependency><groupId>javax.servlet</groupId><artifactId>javax.servlet-api</artifactId><version>…