onnxruntime 中的 Gather 算子

news2024/10/6 16:27:39

上一篇文章中介绍了 Division by Invariant Integers using Multiplication 的原理,很多框架均才用该算法优化除法运算。onnxruntime 是已知实现中最为简洁的,因此本文结合 onnxruntime 的 Gather 实现进行介绍。 Gather 算子是一个索引类算子,kernel 中每个线程计算偏移时使用 fast_divmod 避免除法运算。

注意:ONNX 中的 Gather 功能与 numpy.take 相同,torch.index_select 是其简化版。而 ONNX 中的 GatherElements 与 torch.gather 和 paddle. take_along_axis 相对应。

Gather

Gather
CudaKernel
OpKernel
GatherBase

会话运行时,ExecuteKernel 函数会调用 OpKernel。
CudaKernel 是 CUDA kernel 的基类,提供了 CudaKernel::Compute 函数。
OpKernelInfo 是一个非常轻量级的类,它作为构建 Kernel 实例所需的所有数据的聚合视图。 注意:它不拥有/持有任何对象。

class Gather : public CudaKernel, public GatherBase {
 public:
  Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

Gather::ComputeInternal

Gather::ComputeInternal
GatherBase::PrepareForCompute
GatherImpl

创建一个 GatherBase::Prepare 结构体,包含了两个输入和一个输出张量的指针。
GatherBase::PrepareForCompute 准备输入输出。输出张量的秩为input_rank - 1 + indices_rank,即将axis参数指定的轴替换为indices张量的形状。
ORT_RETURN_IF_ERROR 在表达式失败时返回错误。
TensorShape::SizeFromDimension 计算从指定维度开始的乘积大小。
axis参数会将输入张量划分为3部分:batch 维度、索引维度、分块维度。
block_size为每个索引对应的分块大小。
N为索引数量。
input_block_size为在输入上的分块大小。
indices_max即索引上限。

Status Gather::ComputeInternal(OpKernelContext* context) const {
  Prepare p;
  ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));

  const TensorShape& input_shape = p.input_tensor->Shape();

  const int64_t block_size = input_shape.SizeFromDimension(p.axis + 1);
  size_t N = p.indices_tensor->Shape().Size();
  const int64_t input_block_size = input_shape.SizeFromDimension(p.axis);
  const int64_t output_block_size = N * block_size;
  const int64_t indices_max = input_shape[p.axis];
  const void* input_data = p.input_tensor->DataRaw();
  const void* indices_data = p.indices_tensor->DataRaw();
  void* output_data = p.output_tensor->MutableDataRaw();

  if (p.output_tensor->Shape().Size() == 0) {
    return Status::OK();
  }

gsl::narrow可确保无损失转换,并在无法转换时引发gsl::narrowing_error
fast_divmod 即 DivMod,用于快速计算除法。

  const fast_divmod divmod_output_block_size(gsl::narrow_cast<int>(output_block_size));
  const fast_divmod divmod_block_size(gsl::narrow_cast<int>(block_size));

  const size_t element_size = p.input_tensor->DataType()->Size();
  const size_t index_element_size = p.indices_tensor->DataType()->Size();

GatherImpl 函数索仅支持int32_tint64_t引类型。
传入的p.output_tensor->Shape().Size()即输出元素总数。

  // CUDA Kernel implementation supports element sizes of:
  // int8_t, int16_t, int32_t and int64_t which covers all supported
  // types since there is no computations necessary just data movement
  if (p.indices_tensor->IsDataType<int32_t>() ||
      p.indices_tensor->IsDataType<int64_t>()) {
    GatherImpl(
        Stream(context),
        input_block_size,
        indices_max,
        divmod_output_block_size,
        divmod_block_size,
        indices_data,
        index_element_size,
        input_data,
        element_size,
        output_data,
        p.output_tensor->Shape().Size());
    return Status::OK();
  }

ORT_MAKE_STATUS 创建一个 Status 对象。

  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather.");
}

GatherImpl

GatherImpl
_GatherKernel

GridDim 结构体中定义了美剧值。
N为输出元素数量。直接求出所需 threadblock 的数量,没有太多策略。

void GatherImpl(
    cudaStream_t stream,
    const int64_t input_block_size,
    const int64_t indices_max,
    const fast_divmod& output_block_size,
    const fast_divmod& block_size,
    const void* indices_data,
    size_t index_element_size,
    const void* input_data,
    size_t element_size,
    void* output_data,
    const size_t N) {

  int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));

ToCudaType 模板类将类型枚举转换为数据类型。
根据元素大小调用 _GatherKernel 模板函数,这样减少了实例化类型。

  switch (element_size) {
    case sizeof(int8_t): {
      using CudaType = typename ToCudaType<int8_t>::MappedType;
      _GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
          input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
          reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);

    } break;
    case sizeof(int16_t): {
      using CudaType = typename ToCudaType<int16_t>::MappedType;
      _GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
          input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
          reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);

    } break;
    case sizeof(int32_t): {
      using CudaType = typename ToCudaType<int32_t>::MappedType;
      _GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
          input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
          reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);

    } break;
    case sizeof(int64_t): {
      using CudaType = typename ToCudaType<int64_t>::MappedType;
      _GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
          input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
          reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);

    } break;

    default:
      ORT_THROW("Unsupported element size by the Gather CUDA kernel");
  }
}

_GatherKernel

_GatherKernel
GetIndexValue

CALCULATE_ELEMENTWISE_INDEX_OR_EXIT 计算元素索引,并在超出范围时返回。

template <typename T>
__global__ void _GatherKernel(
    const int64_t input_block_size,
    const int64_t indices_max,
    const fast_divmod output_block_size,
    const fast_divmod block_size,
    const void* indices_data,
    const size_t index_element_size,
    const T* input_data,
    T* output_data,
    const CUDA_LONG N) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
  CUDA_LONG input_index = 0;

线程号id除以output_block_size得到输出元素所对应的输入分块索引input_block_index和输入分块内的偏移block_offset
根据block_offset计算对应的indices张量的索引以及分块内元素偏移。
GetIndexValue 取出indices张量的值。相比 TensorFlow 的 gather_functor_gpu.cu.h 没有进行向量化访存优化。
idx支持负数。索引值超出范围时赋零。

  int input_block_index, block_offset;
  output_block_size.divmod(id, input_block_index, block_offset);
  int indices_index, offset;
  block_size.divmod(block_offset, indices_index, offset);
  int64_t idx = GetIndexValue(indices_data, index_element_size, indices_index);
  idx = idx < 0 ? idx + indices_max : idx;
  if (idx < 0 || idx >= indices_max) {
    output_data[id] = 0;
    return;
  }

三部分相加得到输入张量索引。

  input_index = input_block_index * input_block_size + idx * block_size.d_ + offset;
  output_data[id] = input_data[input_index];
}

GetIndexValue

index_data指针转为相应类型,然后返回偏移位置的值。

__host__ __device__ inline int64_t GetIndexValue(const void* index_data, size_t index_element_size, size_t offset) {
  switch (index_element_size) {
    case sizeof(int32_t):
      return *(reinterpret_cast<const int32_t*>(index_data) + offset);
      break;
    case sizeof(int64_t):
      return *(reinterpret_cast<const int64_t*>(index_data) + offset);
      break;
    default:
      break;
  }
  // What is a sensible thing to do here?
  assert(false);
  return std::numeric_limits<int64_t>::max();
}

DivMod

除法取余实现基于 Division by Invariant Integers using Multiplication 中的 Figure 4.1。

在这里插入图片描述

// The code below is based on section 4 Unsigned division of paper https://gmplib.org/~tege/divcnst-pldi94.pdf
// In current ORT, fast_divmod is used for calculating the position of a element in tensor,
// so unsigned integer division from the paper is good enough for ORT. The advantage is that div is very simple,
// then GPU compiler can do loop unroll easilly when divmod is called in a loop.
template <>
struct DivMod<int> {
  DivMod(int d = 1) {
    d_ = d == 0 ? 1 : d;
    ORT_ENFORCE(d_ >= 1 && d_ <= static_cast<uint32_t>(std::numeric_limits<int>::max()));

l_ ℓ = ⌈ log ⁡ 2 x ⌉ \ell = \lceil \log_2 x \rceil =log2x

    for (l_ = 0; l_ < 32; l_++)
      if ((1U << l_) >= d_) break;

m m ′ = ⌊ 2 N ∗ ( 2 ℓ − d ) / d ⌋ + 1 m' = \lfloor 2^N ∗ (2^\ell − d)/d\rfloor + 1 m=2N(2d)/d+1

    uint64_t one = 1;
    uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1;
    M_ = static_cast<uint32_t>(m);
    // according to paper, the value of m' should fit in a unsigned integer.
    ORT_ENFORCE(M_ > 0 && M_ == m);
  }

DivMod::div

t t 1 = M U L U H ( m ′ , n ) t_1 = \mathrm{MULUH}(m', n) t1=MULUH(m,n),使用uint64_t计算避免溢出。
对于 q q q

  • 如果 d = 1 d = 1 d=1,那么 ℓ = 0 \ell = 0 =0,所以 m ′ = 1 m' = 1 m=1 s h 1 = s h 2 = 0 sh_1 = sh_2 = 0 sh1=sh2=0。代码计算 t 1 = ⌊ 1 ∗ n / 2 N ⌋ = 0 t_1 = \lfloor 1 ∗ n/2^N \rfloor = 0 t1=1n/2N=0 q = n q = n q=n
  • d > 1 d > 1 d>1,则 ℓ ≥ 1 \ell≥1 1,故 s h 1 = 1 sh_1 = 1 sh1=1 s h 2 = ℓ − 1 sh_2 =\ell −1 sh2=1
    q = S R L ( t 1 + S R L ( n − t 1 , s h 1 ) , s h 2 ) = S R L ( t 1 + S R L ( n − t 1 , 1 ) , ℓ − 1 ) = ⌊ t 1 + ⌊ ( n − t 1 ) 2 ⌋ 2 ℓ − 1 ⌋ = ⌊ ⌊ 2 ∗ t 1 2 + ( n − t 1 ) 2 ⌋ 2 ℓ − 1 ⌋ (4.5) = ⌊ ⌊ ( t 1 + n ) / 2 ⌋ 2 ℓ − 1 ⌋ = ⌊ t 1 + n 2 ℓ ⌋ \begin{aligned} q &= \mathrm{SRL}(t_1 + \mathrm{SRL}(n − t_1, sh_1), sh_2)\\ &= \mathrm{SRL}(t_1 + \mathrm{SRL}(n − t_1, 1), \ell− 1)\\ &=\lfloor \frac{t_1 + \lfloor \frac{(n − t_1)}{2} \rfloor}{2^{\ell− 1}}\rfloor\\ &=\lfloor \frac{\lfloor \frac{2*t_1}{2} + \frac{(n − t_1)}{2} \rfloor}{2^{\ell− 1}}\rfloor \qquad\text{(4.5)}\\ &=\lfloor \frac{\lfloor(t_1 + n)/2\rfloor}{2^{\ell− 1}} \rfloor\\ &=\lfloor \frac{t_1 + n}{2^{\ell}} \rfloor \end{aligned} q=SRL(t1+SRL(nt1,sh1),sh2)=SRL(t1+SRL(nt1,1),1)=21t1+2(nt1)=2122t1+2(nt1)(4.5)=21⌊(t1+n)/2=2t1+n

__umulhi 计算两个 32 位无符号整数的乘积的最高有效 32 位。

  __host__ __device__ inline int div(int n) const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    uint32_t t = __umulhi(M_, n);
    return (t + n) >> l_;
#else
    // Using uint64_t for t, then t + n won't overflow.
    uint64_t t = ((uint64_t)M_ * n) >> 32;
    return static_cast<int>((t + n) >> l_);
#endif
  }

DivMod::mod

n m o d d = n − d ∗ ⌊ n / d ⌋ n \enspace \mathrm{mod} \enspace d = n − d ∗ \lfloor n/d \rfloor nmodd=ndn/d

  __host__ __device__ inline int mod(int n) const {
    return n - div(n) * d_;
  }

DivMod::divmod

  __host__ __device__ inline void divmod(int n, int& q, int& r) const {
    q = div(n);
    r = n - q * d_;
  }
  uint32_t d_;  // divisor
  uint32_t M_;  // m' in the paper.
  uint32_t l_;  // l_ = ceil(log2(d_))
};
  • Gather
  • ONNXRuntime整体概览
  • ONNXRuntime源码之OpKernel注册
  • Ways to specify [[nodiscard]] before C++17
  • microsoft/GSL
  • How to use gsl narrow cast
  • 警告 C26472
  • GSL and C++ Core Guidelines
  • Gather
  • Gather
  • tf.gather
  • torch.gather
  • paddle.gather
  • [菁英计划] 索引取值及gather函数 #36815
  • paddle. take_along_axis
  • torch.gather in pytorch.onnx and onnxruntime #31464
  • Replace torch.gather by other operator?
  • Problem compiling onnx model using GLOW compiler: constant not found
  • pytorch导出onnx的原则-以SwinTransformer和DETR在trt8.0.3.4部署为例
  • GatherElements
  • tf2onnx Gather
  • OpenVINO Gather
  • Pytorch equivalent of numpy.take()
  • torch.index_select
  • wkentaro/pytorch-for-numpy-users
  • Similar operation like numpy.take
  • numpy.take
  • tensorflow/tensorflow/core/kernels/gather_functor.h
  • tensorflow/core/kernels/gather_functor_batched.h
  • abseil中的微操

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1549180.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

unity学习(72)——编译游戏发生错误4——GAME_STATE状态

1.经过一天的冷静&#xff0c;我感觉问题出在mapHandler的update中。 如果还没有初始化对象&#xff0c;就开始读取对象的内容&#xff0c;一定会有异常的。 2.之前已有GameState结构体&#xff0c;我一直没当回事&#xff0c;这次用到了 3.从user切换到map场景的过程中会触发如…

号码采集协议讲解

仅供学习研究交流使用 需要的进去拿源码或者成品

【区块链】C语言编程实现三叉Merkle树

目录 1. Merkle树简介2. 构建Merkle树3. 生成SPV路径4. 验证SPV路径5. 三叉Merkle树创建、SPV生成及验证总程序6. 程序运行结果 1. Merkle树简介 如上图所示&#xff0c;Merkle 树的叶子节点为交易序列&#xff0c;对每一笔交易进行 Hash&#xff08;SHA 256算法&#xff09; 之…

vivado 在远程主机上启动作业、ISE命令图、实施类别,战略描述和指令映射

在远程主机上启动作业 一旦配置了远程主机&#xff0c;使用它们启动Vivado作业就很容易了。下图显示了启动运行对话框。启动跑步时&#xff0c;选择“在远程上启动跑步”hosts或Launch在群集上运行&#xff0c;然后选择特定的群集。这些作业将使用您的要执行的预配置设置。 作业…

针对COT控制模式下低ESR电容造成次谐波振荡问题的片内斜波补偿方案

COT模式&#xff1a;MOS管固定导通时间控制模式&#xff0c;关断时间由输出反馈电压与内部基准源的相较值决定。 RBCOT控制模式&#xff1a;Ripple-Based COT基于纹波的固定导通时间控制方法&#xff0c;特别的是环路控制部分主要有固定导通时间发生装置及比较器组成。RBCOT控…

DreamPolisher、InternLM2 、AniArtAvatar、PlainMamba、AniPortrait

本文首发于公众号&#xff1a;机器感知 DreamPolisher、InternLM2 、AniArtAvatar、PlainMamba、AniPortrait DreamPolisher: Towards High-Quality Text-to-3D Generation via Geometric Diffusion We present DreamPolisher, a novel Gaussian Splatting based method wit…

PPP实验

一、实验拓扑图 二、实验要求 1、R1和R2使用PPP链路直连&#xff0c;R2和R3把2条PPP链路捆绑为PPP MP直连 2、按照图示配置IP地址 3、R2对R1的PPP进行单向chap验证 4、R2和R3的PPP进行双向chap验证 三、实验步骤 1、PPP MP&#xff1a; &#xff08;1&#xff09;R2配置&#x…

C语言从入门到实战----数据在内存中的存储

1. 整数在内存中的存储 在讲解操作符的时候&#xff0c;我们就讲过了下⾯的内容&#xff1a; 整数的2进制表⽰⽅法有三种&#xff0c;即 原码、反码和补码 有符号的整数&#xff0c;三种表⽰⽅法均有符号位和数值位两部分&#xff0c;符号位都是⽤0表⽰“正”&#xff0c;⽤…

dji esdk开发(4)SDK互联互通(与云端进行小数据通信)

Edge SDK 提供接口可以通过上云 API 与和机场建立连接的云端服务器进行小数据交互,即向云端服务器发送自定义小数据与接收来自云端服务器的自定义小数据。 注意: 使用该接口发送和接收数据上下行通道最大带宽不应超过 0.5Mb/S。 1、云端低速通道介绍 使用自定义小数据通道需…

C++类和对象、面向对象编程 (OOP)

文章目录 一、封装1.抽象、封装2.类和对象(0)学习视频(1)类的构成(2)三种访问权限(3)struct和class的区别(4)私有的成员变量、共有的成员函数(5)类内可以直接访问私有成员&#xff0c;不需要经过对象 二、继承三、多态1.概念2.多态的满足条件3.多态的使用条件4.多态原理剖析5.纯…

详细描述红黑树如何左旋、右旋(图文结合)

红黑树 首先要理解二叉查找树 二叉查找树&#xff08;BST&#xff09;具备什么特性呢&#xff1f; 左子树上所有结点的值均小于或等于它的根结点的值。 右子树上所有结点的值均大于或等于它的根结点的值。 左、右子树也分别为二叉排序树。 二叉查找树是二分查找的思想&…

vben admin路由跳转拿不到param参数问题

vben admin路由跳转拿不到param参数问题 问题原因&#xff1a; 也就是说&#xff0c;从Vue Router的2022-8-22 这次更新后&#xff0c;我们使用上面的方式在新页面无法获取&#xff1a; vue也给我们提出了解决方案&#xff1a; ​ 1.使用 query 的方式传参 ​ 2.将参数放…

Linux项目自动化构建工具make和makefile

前言 前面我们对yum、vim、gcc/g做了介绍&#xff0c;本期我们再来介绍一个好用的工具&#xff0c;就是make和makefile! 本期内容介绍 什么是make和makefile makefile文件内容的解释 make执行makefile的原理 我们想要的makefile 一、什么是make 和 makefile ? make是一条指令…

DVB-S系统仿真学习

DVB-S系统用于卫星电视信号传输&#xff0c;发送端框图如下所示 扰码 实际数字通信中&#xff0c;载荷数据的码元会出现长连0或长连1的情况&#xff0c;不利于接收端提取时钟信号&#xff0c;同时会使得数据流中含有大量的低频分量&#xff0c;使得QPSK调制器的相位长时间不变…

Python算法100例-4.6 歌星大奖赛

完整源代码项目地址&#xff0c;关注博主私信源代码后可获取 1.问题描述2.问题分析3.算法设计4.确定程序框架5.完整的程序6.问题拓展7.知识点补充 1&#xff0e;问题描述 在歌星大奖赛中&#xff0c;有10个评委为参赛的选手打分&#xff0c;分数为1&#xff5e;100分。选手最…

金融投贷通(金融投资+贷款通)项目准备

金融投贷通&#xff08;金融投资贷款通&#xff09;项目准备 专业术语投资专业术语本息专业术语还款专业术语项目介绍三个子系统技术架构核心流程发布借款标投资业务 项目实施测试流程测试步骤 专业术语 投资专业术语 案例&#xff1a;张三借给李四5W&#xff0c;约定期满1年后…

MySQL 高级语句(二)

一、子查询 1.1 相同表子查询 1.2 不同表/多表子查询 1.3 子查询的应用 1.3.1 语法 1.3.2 insert 子查询 1.3.3 update 子查询 1.3.4 delete 子查询 1.4 exists 关键字 1.4.1 true 1.4.2 false 1.5 as别名 二、视图 2.1 视图和表的区别和联系 2.1.1 区别 2.1.2 …

【面试必备】针对一个案例,怎么测试

思考角度 测试用例设计万能公式功能测试&#xff08;最重要&#xff09;界面测试易用性测试性能测试安全性测试兼容性测试容错性测试 常见案例物品类水杯笔 软件类微信发送朋友圈功能 测试用例设计万能公式 在面试中经常会遇到的一类题是&#xff0c;给你一个具体的产品&#…

《Attention Is All You Need》

参考&#xff1a; Attention Is All You Need 论文解读:Attention is All you need Transformer模型中的attention结构作用是什么&#xff1f; 如何最简单、通俗地理解Transformer&#xff1f; Transformer 新型神经网络&#xff0c;基于注意力机制 的 编码器-解码器 的序列处…

数据分析web可视化神器---streamlit框架,无需懂前端也能搭建出精美的web网站页面

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属的专栏&#xff1a;数据分析系统化教学&#xff0c;零基础到进阶实战 景天的主页&#xff1a;景天科技苑 文章目录 Streamlit什么是streamli…