【gridsample】地平线如何支持gridsample算子

news2024/10/6 8:24:04

文章目录

  • 1. grid_sample算子功能解析
    • 1.1 理论介绍
    • 1.2 代码分析
      • 1.2.1 x,y取值范围[-1,1]
      • 1.2.2 x,y取值范围超出[-1,1]
  • 2. 使用grid_sample算子构建一个网络
  • 3. 走PTQ进行模型转换与编译

实操以J5 OE1.1.60对应的docker为例

1. grid_sample算子功能解析

该段主要参考:https://blog.csdn.net/jameschen9051/article/details/124714759,不想看理论可直接跳至第2节

1.1 理论介绍

在图像处理领域,grid_sample 是一个常用的操作,通常用于对图像进行仿射变换或透视变换。它可以在给定输入图像和一个变换矩阵的情况下,对输入图像进行采样,生成一个新的输出图像。
pytorch中调用接口:

torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)
  • input:输入特征图,可以是四维或者五维张量,本文主要以四维为例进行介绍,表示为 (N,C,Hin,Win) 。
  • grid:采样网格,包含输出特征图的shape大小(Hout、Wout),每个 网格值 通过变换对应到输入特征图的采样点位,当对应四维input时,其张量形式为(N,Hout,Wout,2),其中最后一维大小必须为2,如果输入input为五维张量,那么最后一维大小必须为3。

为什么最后一维必须为2或者3?因为grid的最后一个维度实际上代表一个坐标(x,y)或者(x,y,z),对应到输入特征图的二维或三维特征图的坐标维度,x,y取值范围一般为[-1,1],该范围映射到输入特征图的全图,一通操作变换后对应于输出图像上的一个像素点。

  • mode:采样模式,可以是 ‘bilinear’(双线性插值)、 ‘nearest’(最近邻插值)、‘bicubic’ 双三次插值。。
  • padding_mode:填充模式,用于处理采样时超出输入图像边界的情况,可以是 ‘zeros’ 、 ‘border’、 ‘reflection’。
  • align_corners:一个布尔值,用于指定特征图坐标与特征值对应方式,设定为TRUE时,特征值位于像素中心。
    总的说来,grid_sample 算子会根据给定的网格(grid)在输入图像上进行采样,然后根据选择的插值方法在采样点周围的像素上进行插值,最终生成输出图像。

画一个在BEV方案中grid_sample原理图来帮助理解grid_sample怎么回事:
在这里插入图片描述

1.2 代码分析

对照代码进行下一步解读。
假设输入shape为(N,C,H_in,W_in),grid的shape设定为(N,H_out,W_out,2),使用双线性差值,填充模式为zeros,align_corners需要设置为True。

首先根据input和grid设定,输出特征图tensor的shape为(N,C,H_out,W_out),输出特征图上每一个cell上的值与grid最后一维(x,y)息息相关,那么如何计算输出tensor上每一个点的值?

首先,通过(x,y)找到输入特征图上的采样位置:由于x,y取值范围为[-1,1],为了便于计算,先将x,y取值范围调整为[0,1],方法是(x+1)/2,(y+1)/2。因此,将x,y映射为输入特征图的具体坐标位置:(w-1)(x+1)/2、(h-1)(y+1)/2。
将x,y映射到输入特征图实际坐标后,取该坐标附近四个角点特征值,通过四个特征值坐标与采样点坐标相对关系进行双线性插值,得到采样点的值。

注意:x,y映射后的坐标可能是输入特征图上任意位置。

基于上面的思路,可以进行一个简单的自定义实现。根据指定shape生成input和grid,之后取grid中的第一个位置中的x,y,根据x,y从input中通过双线性插值计算出output第一个位置的值。类比使用pytorch中的grid_sample算子生成output。

其它的看代码注释即可。

1.2.1 x,y取值范围[-1,1]

import torch
import numpy as np

def grid_sample(input, grid):
    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape
    output = np.random.random((N,C,H_out,W_out))
    for i in range(N):
        for j in range(C):
            for k in range(H_out):
                for l in range(W_out):
                    param = [0.0, 0.0]
                    # 通过(w-1)*(x+1)/2、(h-1)*(y+1)/2将x,y映射为输入特征图的具体坐标位置。
                    param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2
                    param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2
                    x0 = int(param[0])  # int取整规则:将小数部分截断去掉。
                    x1 = x0 + 1
                    y0 = int(param[1])
                    y1 = y0 + 1
                    param[0] -= x0  # 此时param里装的是小数部分
                    param[1] -= y0
                    # 双线性插值
                    left_top = input[i][j][y0][x0] * (1 - param[0]) * (1 - param[1])
                    left_bottom = input[i][j][y1][x0] * (1 - param[0]) * param[1]
                    right_top = input[i][j][y0][x1] * param[0] * (1 - param[1])
                    right_bottom = input[i][j][y1][x1] * param[0] * param[1]
                    result = left_bottom + left_top + right_bottom + right_top
                    output[i][j][k][l] = result
    return output

if __name__=='__main__':
    N, C, H_in, W_in, H_out, W_out = 1, 1, 4, 4, 2, 2

    input = np.random.random((N,C,H_in,W_in))
    # np.random.random()范围是[0,1),想要[a,b)的数据,需要(b-a)*np.random.random() + a
    grid = -1 + 2*np.random.random((N,H_out,W_out,2))  # 最后一维2,生成了坐标
    
    out = grid_sample(input, grid)
    print(f'自定义实现输出结果:\n{out}')

    input = torch.from_numpy(input)
    grid = torch.from_numpy(grid)
    # 注意:这儿align_corners=True
    output = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)
    print(f'grid_sample输出结果:\n{output}')

输出
在这里插入图片描述
从输出结果上看,与pytorch基本一致。

注意:这里没有对超出[-1,1]范围的x,y值做处理,只能处理四维input,五维input的实现思路与这里基本一致:再加一层循环,内插算法改为3维。。

1.2.2 x,y取值范围超出[-1,1]

考虑到(x,y)取值范围可能越界,pytorch中的padding_mode设置就是对(x,y)落在输入特征图外边缘情况进行处理,一般设置’zero’,也就是对靠近输入特征图范围以外的采样点进行0填充,如果不进行处理显然会造成索引越界。要解决(x,y)越界问题,可以进行如下修改:

import torch
import numpy as np

def grid_sample(input, grid):
    N, C, H_in, W_in = input.shape
    N, H_out, W_out, _ = grid.shape
    output = np.random.random((N,C,H_out,W_out))
    for i in range(N):
        for j in range(C):
            for k in range(H_out):
                for l in range(W_out):
                    param = [0.0, 0.0]
                    # 通过(w-1)*(x+1)/2、(h-1)*(y+1)/2将x,y映射为输入特征图的具体坐标位置。
                    param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2
                    param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2
                    x1 = int(param[0] + 1)  # int取整规则:将小数部分截断去掉。
                    x0 = x1 - 1 
                    y1 = int(param[1] + 1)
                    y0 = y1 - 1
                    param[0] = abs(param[0] - x0)  # 此时param里装的是离x0,y0的距离
                    param[1] = abs(param[1] - y0)

                    # 填充
                    left_top_value, left_bottom_value, right_top_value, right_bottom_value = 0, 0, 0, 0
                    if 0 <= x0 < W_in and 0 <= y0 < H_in:
                        left_top_value = input[i][j][y0][x0]
                    if 0 <= x1 < W_in and 0 <= y0 < H_in:
                        right_top_value = input[i][j][y0][x1]
                    if 0 <= x0 < W_in and 0 <= y1 < H_in:
                        left_bottom_value = input[i][j][y1][x0]
                    if 0 <= x1 < W_in and 0 <= y1 < H_in:
                        right_bottom_value = input[i][j][y1][x1]

                    # 双线性插值
                    left_top = left_top_value * (1 - param[0]) * (1 - param[1])
                    left_bottom = left_bottom_value * (1 - param[0]) * param[1]
                    right_top = right_top_value * param[0] * (1 - param[1])
                    right_bottom = right_bottom_value * param[0] * param[1]
                    result = left_bottom + left_top + right_bottom + right_top
                    output[i][j][k][l] = result
    return output

if __name__=='__main__':
    N, C, H_in, W_in, H_out, W_out = 1, 1, 4, 4, 2, 2

    input = np.random.random((N,C,H_in,W_in))
    # np.random.random()范围是[0,1),想要[a,b)的数据,需要(b-a)*np.random.random() + a
    grid = -1 + 2*np.random.random((N,H_out,W_out,2))  # 最后一维2,生成了坐标
    grid[0][0][0] = [-1.2, 1.3]     # 超出[-1,1]的范围
    
    out = grid_sample(input, grid)
    print(f'自定义实现输出结果:\n{out}')

    input = torch.from_numpy(input)
    grid = torch.from_numpy(grid)
    # 注意:这儿align_corners=True
    output = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)
    print(f'grid_sample输出结果:\n{output}')

输出
在这里插入图片描述

2. 使用grid_sample算子构建一个网络

先看一下地平线提供的算子支持与约束列表:
在这里插入图片描述

据此,构建一个简单的网络,test.py代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

from horizon_nn.torch import export_onnx

class GridSampleModel(nn.Module):
    def __init__(self):
        super(GridSampleModel, self).__init__()
        
        self.unitconv = nn.Conv2d(24, 24, (1, 1), groups=3)
        nn.init.constant_(self.unitconv.weight, 1)
        nn.init.constant_(self.unitconv.bias, 0)

    
    def forward(self, x1, x2):
        x1 = self.unitconv(x1)
        x = F.grid_sample(x1,
                          grid=x2,
                          mode='bilinear',
                          padding_mode='zeros',
                          align_corners=True)
        x = self.unitconv(x)
        return x

if __name__ == "__main__":
    model = GridSampleModel()
    model.eval()

    input_names = ['x1', 'x2']
    output_names = ['output']
    x1 = torch.randn((1, 24, 600, 800))
    x2 = torch.randn((1, 48, 64, 2))

    export_onnx(model, (x1, x2), 'gridsample.onnx', 
                verbose=True, opset_version=11,
                input_names=input_names, output_names=output_names)
    
    print('convert to gridsampe onnx finish!!!')

运行test.py,生成onnx模型,可视化结构如下图:
在这里插入图片描述

3. 走PTQ进行模型转换与编译

对应config.yaml文件:

# 模型转化相关的参数
model_parameters:
  onnx_model: './gridsample.onnx'
  march: "bayes"
  working_dir: 'model_output'
  output_model_file_prefix: 'gridsample'

# 模型输入相关参数, 若输入多个节点, 则应使用';'进行分隔, 使用默认缺省设置则写None
input_parameters:
  input_name: "x1;x2"
  input_type_rt: 'featuremap;featuremap'
  input_layout_rt: 'NCHW;NCHW'
  input_type_train: 'featuremap;featuremap'
  input_layout_train: 'NCHW;NCHW'
  input_shape: '1x24x600x800;1x48x64x2'
  norm_type: 'no_preprocess;no_preprocess'

# 模型量化相关参数
calibration_parameters:
  calibration_type: 'skip'

# 编译器相关参数
compiler_parameters:
  compile_mode: 'latency'
  optimize_level: 'O3'

使用的是OE1.1.60对应的docker

hb_mapper makertbin --config config.yaml --model-type onnx

在这里插入图片描述
全一段,且都在BPU上

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/859952.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

最大子数组和——力扣53

文章目录 题目描述解法一 动态规划题目描述 解法一 动态规划 int maxSubArray(vector<int>& nums){int pre=0, res=nums

spring boot策略模式实用: 告警模块为例

spring boot策略模式实用: 告警模块 0 涉及知识点 策略模式, 模板方法, 代理, 多态, 反射 1 需求概括 场景: 每隔一段时间, 会获取设备运行数据, 如通过温湿度计获取到当前环境温湿度;需求: 对获取回来的进行分析, 超过配置的阈值需要产生对应的告警 2 方案设计 告警的类…

详解双端队列单调队列

1. 双端队列 双端队列&#xff08;Double-ended Queue&#xff09;&#xff0c;简称Deque&#xff0c;是一种具有特殊功能的线性数据结构。它支持从两端进行元素的插入和删除操作&#xff0c;因此可以在队列和栈之间灵活地切换操作。双端队列在编程中经常用于需要在队列和栈之间…

MySQL多表连接查询2

目录 1 所有有门派的人员信息 2 列出所有用户&#xff0c;并显示其机构信息 3 列出不入派的人员 4 所有没人入的门派 5 列出所有人员和门派的对照关系 6 列出所有没入派的人员和没人入的门派 7 求各个门派对应的掌门人名称: ​8 求所有当上掌门人的平均年龄: 9 求所…

6.4 (通俗易懂)可视化详解多通道 多通道输入输出卷积代码实现

以前对多通道和多通道输入输出的卷积操作不理解&#xff0c;今天自己在草稿纸上画图推理了一遍&#xff0c;终于弄懂了。希望能帮助到大家。 多通道可视化 一通道的2x2矩阵 torch.Size([2,2]) 相当于 torch.Size([1,2,2])&#xff0c;是一通道的2x2矩阵 二通道的 2x2矩阵 …

go-zero 是如何实现令牌桶限流的?

原文链接&#xff1a; 上一篇文章介绍了 如何实现计数器限流&#xff1f;主要有两种实现方式&#xff0c;分别是固定窗口和滑动窗口&#xff0c;并且分析了 go-zero 采用固定窗口方式实现的源码。 但是采用固定窗口实现的限流器会有两个问题&#xff1a; 会出现请求量超出限…

断续模式(DCM)与连续模式(CCM)

断续模式&#xff08;DCM&#xff09;与连续模式&#xff08;CCM)是开关电源最常用的两种工作模式。当初级开关管导通前&#xff0c;初级绕组还存在能量&#xff0c;不完全传递到次级&#xff0c;这种情况就叫连续模式。若初级绕组能量完全传递到次级&#xff0c;则为断续模式。…

Linux与安卓安全对抗

导读大家都知道安卓是基于Linux内核&#xff0c;而且大家也知道Linux的安全性是公认的&#xff0c;那为什么和Linux有着类似嫡系关系的安卓却一直被人诟病不安全呢&#xff1f;要想说清楚这个问题&#xff0c;我们需要了解一下安卓和Linux到底是什么关系&#xff0c;而且这两个…

中国信通院高质量数字化转型产品及服务全景图发布,合合信息多项AI产品入选

随着5G、人工智能、大数据等新一代技术的发展&#xff0c;企业在商业竞争中正面临更多不确定性。中国信通院高度关注企业数字化转型中遇到的痛点&#xff0c;发起“铸基计划-高质量数字化转型行动”&#xff0c;链接企业数字化转型供、需两侧的发展需求&#xff0c;以期推动国家…

MySQL—缓存

目录标题 为什么要有Buffer Poolbuffer pool有多大buffer pool缓存什么 如何管理Buffer Pool如何管理空闲页如何管理脏页如何提高缓存命中率预读失效buffer pool污染 脏页什么时候会被刷入到磁盘 为什么要有Buffer Pool 虽然说MySQL的数据是存储在磁盘中&#xff0c;但是也不能…

C++——缺省参数

缺省参数的定义 缺省参数是声明或定义函数时为函数的参数指定一个缺省值。在调用该函数的时候&#xff0c;如果没有指定实参&#xff0c;则采用该形参的缺省值&#xff0c;否则使用指定的实参。 void Func(int a 0) {cout << a << endl; } int main() { Func()…

【C++学习手札】new和delete看这一篇就够了!

​ 食用指南&#xff1a;本文在有C基础的情况下食用更佳 &#x1f340;本文前置知识&#xff1a; C类 ♈️今日夜电波&#xff1a; Prover—milet 1:21 ━━━━━━️&#x1f49f;──────── 4:01 …

OI易问卷协助企业服务好员工,收集员工反馈与信息

OI易问卷——企业问卷调查工具 OI易问卷&#xff0c;是群硕专为企业打造&#xff0c;对内服务员工的调查问卷。 集成于办公联合创新平台&#xff0c;并进一步帮助客户实现与微信或企业微信等其他平台的对接。 可以有效促进员工服务数字化&#xff0c;提高各部门工作效率&…

mysql的相关指令

mysql的相关指令 DML 数据操作语言DQL数据查询 mysql -uroot -p //启动数据库 show databases; //查看有哪些数据库 use 数据库名; //使用某个数据库 show tables; //查看数据库内有哪些表 exit; //退出mysql的命令环境 create database 数据库名称 charset utf8; //创建数据…

四项代表厂商,Kyligence 入选 Gartner 数据及人工智能相关领域多项报告

近日&#xff0c;全球权威的技术研究与咨询公司 Gartner 发布了《2023 年中国数据、分析及人工智能技术成熟度曲线》、《2023 年分析与商业智能技术成熟度曲线报告》、《2023 年数据管理技术成熟度曲线报告》&#xff0c;Kyligence 分别入选这三项报告的指标平台 Metrics Store…

【Git】 git push origin master Everything up-to-date报错

hello&#xff0c;我是索奇&#xff0c;可以叫我小奇 git push 出错&#xff1f;显示 Everything up-to-date 那么看看你是否提交了message 下面是提交的简单流程 git add . git commit -m "message" git push origin master 大多数伙伴是没写git commit -m "…

二维码查分系统制作方法大公开:用这个方法,你也可以快速拥有

自从“双减”政策颁布以来&#xff0c;学校对成绩的公布变得更加重视。尤其是小学年级&#xff0c;将成绩信息从分数制改为等级制进行发布。同时&#xff0c;成绩公布的方式也有了新的规定&#xff1a;禁止公开公布&#xff0c;不允许为学生成绩进行排名&#xff0c;并需要以特…

Java课题笔记~ ServletConfig

概念&#xff1a;代表整个web应用&#xff0c;可以和程序的容器(服务器)来通信 <?xml version"1.0" encoding"UTF-8"?> <web-app xmlns"http://java.sun.com/xml/ns/javaee"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instan…

突破笔试:力扣129. 求根节点到叶节点数字之和

1. 题目链接&#xff1a;129. 求根节点到叶节点数字之和 给你一个二叉树的根节点 root &#xff0c;树中每个节点都存放有一个 0 到 9 之间的数字。每条从根节点到叶节点的路径都代表一个数字&#xff1a;例如&#xff0c;从根节点到叶节点的路径 1 -> 2 -> 3 表示数字 …

操作系统 -- 进程间通信

一、概述 进程经常需要与其他进程通信。例如&#xff0c;在一个shell管道中&#xff0c;第一个进程的输出必须传送给第二个进程&#xff0c;这样沿着管道传递下去。因此在进程之间需要通信&#xff0c;而且最好使用一种结构良好的方式&#xff0c;不要使用中断。在下面几节中&…