Pytorch的grid_sample是如何实现对grid求导的?(源码解读)

news2024/10/6 10:34:50

Pytorch的grid_sample是如何实现对grid求导的?(源码解读)

这里本人的参考源码是grid_sample的CPU内核的CPP实现:https://github.com/pytorch/pytorch/blob/b039a715ce4e9cca82ae3bf72cb84652957b2844/aten/src/ATen/native/cpu/GridSamplerKernel.cpp。

grid_sample功能简述

在这里插入图片描述
给定一个input(4D或5D,一般指原图像)和一个流场grid(4D或5D,一般指变形流),基于来自grid的像素位置和input的像素值计算output(输出图像)。
例如,对于4D的情况,input的形状为(N,C,Hin,Win),grid的形状为(N,Hout,Wout,2),那么输出结果即为(N,C,Hout,Wout)。对于output的每个位置(i, j),将根据grid在(i, j)上的值(x, y),从input的(x, y)处采样像素值(采样过程要考虑x、y非整值和越界的情况),用作output在(i, j)处的像素值。
(grid的值应当是根据input的空间维度(H,W)归一化到[-1,1]后的像素点坐标。例如x=-1,y=1代表输入的左上角像素。)
如果grid具有超出[-1,1]范围的值,相应的输出由padding_mode来处理:

  1. padding_mode=“zeros”:超出范围的grid值用0替代。
  2. padding_mode=“border”:超出范围的grid值用边界值替代。
  3. padding_mode=“reflection”:超出范围的grid值用通过边界反射后的值替代。

基本流程

  • 定义结构体“ComputeLocation”:作用是基于padding模式计算插入位置
  • 定义结构体“ApplyGridSample”:作用有两个:(1)提供N(即空间维度)个“ComputeLocation”结构体,然后利用他们去计算对应维度的插入位置;(2)插入值并且写入到output。
  • 定义方法“grid_sample_2d_grid_slice_iterator”函数:作用有两个:(1)迭代grid张量的每一个值即(x,y)对;(2)在每次迭代时应用一个给定的操作器(可以视为是上述的ApplyGridSample中前向和反向传播方法),使得在前向和反向传播时可以使用相同的模式。

实现细节

  • “ComputeLocation”结构体
    apply()函数:输入grid值in,返回去标准化和应用padding机制(逐像素)后的插入位置
    apply_get_grad()函数:输入grid值in,类似于apply,但也会返回apply(in)关于in的偏导数(返回值是一个vec对)(逐像素),通常用于梯度计算中。(这里并没有计算全部的梯度,仅仅是算了根据grid获得去标准化后插入坐标这个过程所得到的梯度
    比如说,采用zeros填充的实现如下:
    [外链图片转存失败,源站可能有防盗在这里插入!链机制,建描述]议将图片上https://传(imblog.csgmg.cn/dfB7aK12c96be39f4fd89ca78ce50f35.png78209)(https://img-blog.csdnimg.cn/df12c96be39f4fd89526c8a09ce50f35.png)]
    这里的apply_get_grad()函数就仅仅对输入grid做了一个去归一化(本质上是通过(in+1)*half_max_val实现的,in的大部分值应位于[-1,1]范围内)作为输出,因此对应的偏导数仅仅就是half_max_val。
  • 对于padding=“zeros”:返回的插入位置值未必全部落在[0, w]和[0,h]范围内,偏导为half_max_val。
  • 对于padding=“border”:返回的插入位置值必定全部落在[0,w]和[0,h]范围内,偏导对于原始grid输入in落在[-1,1]范围上的值为half_max_val,否则为0。
  • 对于padding=“reflection”:返回的插入位置值必定落在[0,w]和[0,h]范围内,偏导求法可以自己看看。
  1. “ApplyGridSample”结构体
    具有N个“ComputeLocation”结构体,其中N是空间维度的数量(对于二维图像H*W,N即为2)。给定N个输入grid向量(每个空间维度一个)和空间偏移,其从“ComputeLocation”中获得对应的插入位置,应用插入过程,并将结果写入输出(对于反向传播过程即是写入grad_input和grad_grid)
    forward()函数:应用网格采样(前向)过程(上述),输出out_slice
    backward()函数:应用反向传播过程,参数与机制和前向过程类似,输出grad_input,grad_grid
    比如说,采用双线性插值的forward()函数实现如下:
inline void forward(TensorAccessor<scalar_t, 3>& out_slice,
                      const TensorAccessor<scalar_t, 3>& inp_slice,
                      int64_t offset, const Vec& grid_x, const Vec& grid_y,
                      int64_t len) const {
    auto x = compute_W.apply(grid_x);
    auto y = compute_H.apply(grid_y);	// 首先根据grid算出反归一化后的插入位置

	//基于双线性插值,对每个位置(小数)首先获得四个方向(到最近的整数位置)上的距离作为插值的权重
	//会返回权重和mask(考虑是否需要处理超出边界的部分)
    auto interp_params = compute_interp_params(x, y); 

	//以下皆为上一个函数的返回值
    auto nw = std::get<4>(interp_params);
    auto ne = std::get<5>(interp_params);
    auto sw = std::get<6>(interp_params);
    auto se = std::get<7>(interp_params);

    auto nw_mask = std::get<8>(interp_params);
    auto ne_mask = std::get<9>(interp_params);
    auto sw_mask = std::get<10>(interp_params);
    auto se_mask = std::get<11>(interp_params);

    auto i_y_n = std::get<12>(interp_params);
    auto i_x_w = std::get<13>(interp_params);
	
	//获得原图input上grid所指示的位置附近四个整数像素点的位置
    auto i_nw_offset = i_y_n * iVec(inp_sH) + i_x_w * iVec(inp_sW);
    auto i_ne_offset = i_nw_offset + iVec(inp_sW);
    auto i_sw_offset = i_nw_offset + iVec(inp_sH);
    auto i_se_offset = i_sw_offset + iVec(inp_sW);

    #ifndef _MSC_VER  
    # pragma unroll  
    #endif
    for (int64_t c = 0; c < C; ++c) {	//C为batch_size
      auto inp_slice_C_ptr = inp_slice[c].data();

      // mask_gather zeros out the mask, so we need to make copies
      Vec nw_mask_copy = nw_mask;
      Vec ne_mask_copy = ne_mask;
      Vec sw_mask_copy = sw_mask;
      Vec se_mask_copy = se_mask;
      //获得原图中四个方向位置中的像素值
      //这里其实是通过对输入图像的底层指针进行偏移量计算来实现根据索引进行插入的效果的
      auto nw_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_nw_offset, nw_mask_copy);
      auto ne_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_ne_offset, ne_mask_copy);
      auto sw_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_sw_offset, sw_mask_copy);
      auto se_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_se_offset, se_mask_copy);
	  //根据各方向权重计算出最终插值结果
      auto interpolated = (nw_val * nw) + (ne_val * ne) + (sw_val * sw) + (se_val * se);
      interpolated.store(out_slice[c].data() + offset, len);
    }

backward()计算关于grid的梯度关键:

gx = gx + ((ne_val - nw_val) * s + (se_val - sw_val) * n) * gOut;
gy = gy + ((sw_val - nw_val) * e + (se_val - ne_val) * w) * gOut;

这里gOut应指来自下一层传回的梯度,ne_val,nw_val,se_val,sw_val指四个方向位置上的原图像像素值,这四个值都是通过以grid值作为索引查找原图像相邻位置去获取到的;s、n、e、w分别指该grid值到这四个方向整数位置上的一个距离(用作双线性插值的权重)。
3. “grid_sample_2d_grid_slice_iterator”函数
提供一个抽象来有效地迭代一个“grid”分片(不带batch维度)。实质上是遍历了每个实例,然后对每个实例应用上述前向和反向处理。

总结

(在双线性插值下),grid_sample()对grid求导采取了类似图像梯度的方式,直接用每个grid值关联到的周围四个位置上的像素值,将两两的差值乘上一个权重(双线性插值的距离),用作本函数的梯度,然后传回给前一层。

思考

考虑这么一个图像矫正的问题,如果有一张输入的变形图像inp,一个参考恢复网格grid_gt,一个预测恢复网格grid_pred,要衡量grid_pred网格的正确性,有两种做法:

  • 直接计算grid_pred和grid_gt的距离(比如说L1距离),作为衡量grid_pred的指标。
  • 分别使用grid_pred和grid_gt对原图像inp应用grid_sample()函数,将矫正结果的距离用作衡量grid_pred的指标。

由于任务关心的实际上是最终的矫正结果的效果,而不是grid绝对值的差距,因此本人认为后一种方法更加准确。再结合上述对grid_sample()求导的分析,这两种方法传递的梯度信息是大不相同的(也即后一种做法在实现上是具有意义的),grid_sample()对grid某点值的偏导会考虑该点在原图像上所有相邻点的像素值。

由于本人水平有限,对代码的理解上可能不够深入,如果存在错误之处,请大神在评论区指正!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/14263.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Detectron2】代码库学习-4. LazyConfig 配置文件

目录1. 配置文件2. LazyConfig 导入导出3. 递归实例化4. 基于LazyConfig的训练步骤4.1 导入依赖库4.2 日志初始化4.3 训练4.4 评估4.5 训练流程4.6 主函数入口5. TipsDetectron2是Facebook AI Research(FAIR)推出的基于Pytorch的视觉算法开源框架&#xff0c;主要聚焦于目标检测…

力扣160 - 相交链表【双指针妙解】

链表也能相交~一、题目描述二、思路分析与罗列三、整体代码展示四、总结与提炼一、题目描述 原题传送门 示例 1&#xff1a; 输入&#xff1a;intersectVal 8, listA [4,1,8,4,5], listB [5,6,1,8,4,5], skipA 2, skipB 3 输出&#xff1a;Intersected at ‘8’ 解释&…

MySQL索引

索引索引的相关概念索引分类索引的底层数据结构及其原理主键索引&二级索引聚集和非聚集索引哈西索引&&自适应哈西索引索引和慢查询日志索引优化索引的相关概念 什么是索引&#xff1f;索引其实就是一个数据结构。当表中的数据量到达几十万甚至上百万的时候&#x…

每个 Flutter 开发者都应该知道的一些原则

“仅仅让代码起作用是不够的。有效的代码经常被严重破坏。仅满足于工作代码的程序员表现得不专业。他们可能担心没有时间改进代码的结构和设计,但我不同意。没有什么比糟糕的代码对开发项目产生更深远、更长期的影响了。” ― Robert C. Martin,Clean Code:敏捷软件工艺手册…

fpga nvme 寄存器

图1所示的NVMe多队列&#xff0c;每个队列支持64K命令&#xff0c;最多支持64K队列。这些队列的设计使得IO命令和对命令的处理不仅可以在同一处理器内核上运行&#xff0c;也可以充分利用多核处理器的并行处理能力。每个应用程序或线程可以有自己的独立队列&#xff0c;因此不需…

基于Nacos的注册中心与配置中心

基于Nacos的注册中心与配置中心 Nacos简介 概述 Nacos全称是动态命名和配置服务&#xff0c;Nacos是一个更易于构建云原生应用的动态服务发现、配置管理和服务管理平台。Nacos主要用于发现、配置和管理微服务。 什么是Nacos Nacos支持几乎所有主流类型的服务的发现、配置和…

同花顺_代码解析_技术指标_A

本文通过对同花顺中现成代码进行解析&#xff0c;用以了解同花顺相关策略设计的思想 目录 ABI AD ADL ADR ADTM ADVOL AMV ARBR ARMS ASI ATR ABI 绝对幅度指标 算法&#xff1a;上涨家数减去下跌家数所得的差的绝对值。 该指标只适用于大盘日线。 行号 1 aa…

题目7飞机票订票系统

题目7飞机票订票系统问题描述:某公司每天有10航班(航班号、价格)&#xff0c;每个航班的飞机&#xff0c;共有80个座位&#xff0c; 20排&#xff0c;每排4个位子。编号为A&#xff0c;BCD。如座位号:10D表示10排D座。 运行界面如下&#xff1a; 1)能从键盘录入订票信息:乘客的…

[Games 101] Lecture 13-16 Ray Tracing

Ray Tracing Why Ray Tracing 光栅化不能得到很好的全局光照效果 软阴影光线弹射超过一次&#xff08;间接光照&#xff09; 光栅化是一个快速的近似&#xff0c;但是质量较低 光线追踪是准确的&#xff0c;但是较慢 Rasterization: real-time, ray tracing: offline生成一帧…

狗屎一样的面试官,你遇到过几个?

做了几年软件开发&#xff0c;我们都或多或少面试过别人&#xff0c;或者被别人面试过。大家最常吐槽的就是面试造火箭&#xff0c;进厂拧螺丝。今天就来吐槽一下那些奇葩&#xff08;gou&#xff09;一样的面试官 A 那是在我刚工作1年的时候&#xff0c;出去面试前端开发。 那…

分布式开源存储架构Ceph概述

概述 k8s的后端存储中ceph应用较为广泛&#xff0c;当前的存储市场仍然是由一些行业巨头垄断&#xff0c;但在开源市场还是有一些不错的分布式存储&#xff0c;其中包括了Ceph、Swift、sheepdog、glusterfs等 什么是ceph&#xff1f; Ceph需要具有可靠性&#xff08;reliab…

C++11标准模板(STL)- 算法(std::partition_point)

定义于头文件 <algorithm> 算法库提供大量用途的函数&#xff08;例如查找、排序、计数、操作&#xff09;&#xff0c;它们在元素范围上操作。注意范围定义为 [first, last) &#xff0c;其中 last 指代要查询或修改的最后元素的后一个元素。 定位已划分范围的划分点 …

线上崩了?一招教你快速定位问题。

&#x1f44f; 背景 正浏览着下班后去哪家店撸串&#xff0c;结果隔壁组同事囧着脸过来问我&#xff1a;大哥&#xff0c;赶紧过去帮忙看个问题&#xff01;客户反馈很多次了&#xff0c;一直找不出问题出在哪里&#xff01;&#xff01;&#xff01; 我&#xff1a;能不能有…

利用WPS功能破解及本地恢复密码

利用WPS功能破解及本地恢复密码 认识WPS功能 ​ WPS&#xff08;Wi-Fi Protected Setup&#xff09;是Wi-Fi保护设置的英文缩写。WPS是由Wi-Fi联盟组织实施的认证项目&#xff0c;主要致力于简化无线局域网安装及安全性能的配置工作。WPS并不是一项新增的安全性能&#xff0c;它…

数据结构之链表(单链表)

文章目录前言一、链表二、链表的八种结构1.单向或者双向2.带头或者不带头&#xff08;头&#xff1a;哨兵位&#xff09;3.循环或者不循环三、单链表1.接口2.接口的实现1.开辟一个新的节点1.打印单链表2.头插3.尾插4.头删5.尾删6.单链表的查找7.在pos位置之前插入数据8.在pos位…

MySQL8.0概述及新特性

文章目录学习资料常见的数据库管理系统排名&#xff08;DBMS&#xff09;SQL的分类DDL&#xff1a;数据定义语言DML&#xff1a;数据操作语言DCL&#xff1a;数据控制语言MySQL8.0新特性性能优化默认字符集DDL的原子化计算列宽度属性窗口函数公用表表达式索引新特性支持降序索引…

面试了20+前端大厂,整理出的面试题

事件是什么&#xff1f;事件模型&#xff1f; 事件是用户操作网页时发生的交互动作&#xff0c;比如 click/move&#xff0c; 事件除了用户触发的动作外&#xff0c;还可以是文档加载&#xff0c;窗口滚动和大小调整。事件被封装成一个 event 对象&#xff0c;包含了该事件发生…

RabbitMQ Windows 安装、配置、使用 - 小白教程

1、配套文件 下载erlang&#xff1a;http://www.erlang.org/downloads/ 下载RabbitMQ&#xff1a;http://www.rabbitmq.com/download.html 2、RabbitMQ服务端代码是使用并发式语言Erlang编写的&#xff0c;安装Rabbit MQ的前提是安装Erlang&#xff0c;双击otp_win64_21.1.ex…

计算机毕业设计springboot+vue+elementUI汽车车辆充电桩管理系统

项目介绍 随着我国汽车行业的不断发展&#xff0c;电动汽车已经开始逐步的领导整个汽车行业&#xff0c;越来越多的人在追求环保和经济实惠的同时开始使用电动汽车&#xff0c;电动汽车和燃油汽车最大的而不同就是 需要充电&#xff0c;同时我国的基础充电桩也开始遍及了大多数…

Java 异常处理

目录 一、异常的基本概念 二 、为何需要异常处理 三 、异常的处理 四 、异常类的继承架构 五 、抛出异常 5.1、程序中抛出异常 5.2、指定方法抛出异常 六 、自定义异常 不管使用的那种语言进行程序设计&#xff0c;都会产生各种各样的错误。 Java 提供有强大的异常处理…