上一篇文章中介绍了 Division by Invariant Integers using Multiplication 的原理,很多框架均才用该算法优化除法运算。onnxruntime 是已知实现中最为简洁的,因此本文结合 onnxruntime 的 Gather 实现进行介绍。 Gather 算子是一个索引类算子,kernel 中每个线程计算偏移时使用 fast_divmod 避免除法运算。
注意:ONNX 中的 Gather 功能与 numpy.take 相同,torch.index_select 是其简化版。而 ONNX 中的 GatherElements 与 torch.gather 和 paddle. take_along_axis 相对应。
Gather
会话运行时,ExecuteKernel 函数会调用 OpKernel。
CudaKernel 是 CUDA kernel 的基类,提供了 CudaKernel::Compute 函数。
OpKernelInfo 是一个非常轻量级的类,它作为构建 Kernel 实例所需的所有数据的聚合视图。 注意:它不拥有/持有任何对象。
class Gather : public CudaKernel, public GatherBase {
public:
Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
Gather::ComputeInternal
创建一个 GatherBase::Prepare 结构体,包含了两个输入和一个输出张量的指针。
GatherBase::PrepareForCompute 准备输入输出。输出张量的秩为input_rank - 1 + indices_rank
,即将axis
参数指定的轴替换为indices
张量的形状。
ORT_RETURN_IF_ERROR 在表达式失败时返回错误。
TensorShape::SizeFromDimension 计算从指定维度开始的乘积大小。
axis
参数会将输入张量划分为3部分:batch 维度、索引维度、分块维度。
block_size
为每个索引对应的分块大小。
N
为索引数量。
input_block_size
为在输入上的分块大小。
indices_max
即索引上限。
Status Gather::ComputeInternal(OpKernelContext* context) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
const TensorShape& input_shape = p.input_tensor->Shape();
const int64_t block_size = input_shape.SizeFromDimension(p.axis + 1);
size_t N = p.indices_tensor->Shape().Size();
const int64_t input_block_size = input_shape.SizeFromDimension(p.axis);
const int64_t output_block_size = N * block_size;
const int64_t indices_max = input_shape[p.axis];
const void* input_data = p.input_tensor->DataRaw();
const void* indices_data = p.indices_tensor->DataRaw();
void* output_data = p.output_tensor->MutableDataRaw();
if (p.output_tensor->Shape().Size() == 0) {
return Status::OK();
}
gsl::narrow
可确保无损失转换,并在无法转换时引发gsl::narrowing_error
。
fast_divmod 即 DivMod,用于快速计算除法。
const fast_divmod divmod_output_block_size(gsl::narrow_cast<int>(output_block_size));
const fast_divmod divmod_block_size(gsl::narrow_cast<int>(block_size));
const size_t element_size = p.input_tensor->DataType()->Size();
const size_t index_element_size = p.indices_tensor->DataType()->Size();
GatherImpl 函数索仅支持int32_t
和int64_t
引类型。
传入的p.output_tensor->Shape().Size()
即输出元素总数。
// CUDA Kernel implementation supports element sizes of:
// int8_t, int16_t, int32_t and int64_t which covers all supported
// types since there is no computations necessary just data movement
if (p.indices_tensor->IsDataType<int32_t>() ||
p.indices_tensor->IsDataType<int64_t>()) {
GatherImpl(
Stream(context),
input_block_size,
indices_max,
divmod_output_block_size,
divmod_block_size,
indices_data,
index_element_size,
input_data,
element_size,
output_data,
p.output_tensor->Shape().Size());
return Status::OK();
}
ORT_MAKE_STATUS 创建一个 Status 对象。
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather.");
}
GatherImpl
GridDim 结构体中定义了美剧值。
N
为输出元素数量。直接求出所需 threadblock 的数量,没有太多策略。
void GatherImpl(
cudaStream_t stream,
const int64_t input_block_size,
const int64_t indices_max,
const fast_divmod& output_block_size,
const fast_divmod& block_size,
const void* indices_data,
size_t index_element_size,
const void* input_data,
size_t element_size,
void* output_data,
const size_t N) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
ToCudaType 模板类将类型枚举转换为数据类型。
根据元素大小调用 _GatherKernel 模板函数,这样减少了实例化类型。
switch (element_size) {
case sizeof(int8_t): {
using CudaType = typename ToCudaType<int8_t>::MappedType;
_GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);
} break;
case sizeof(int16_t): {
using CudaType = typename ToCudaType<int16_t>::MappedType;
_GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);
} break;
case sizeof(int32_t): {
using CudaType = typename ToCudaType<int32_t>::MappedType;
_GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);
} break;
case sizeof(int64_t): {
using CudaType = typename ToCudaType<int64_t>::MappedType;
_GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);
} break;
default:
ORT_THROW("Unsupported element size by the Gather CUDA kernel");
}
}
_GatherKernel
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT 计算元素索引,并在超出范围时返回。
template <typename T>
__global__ void _GatherKernel(
const int64_t input_block_size,
const int64_t indices_max,
const fast_divmod output_block_size,
const fast_divmod block_size,
const void* indices_data,
const size_t index_element_size,
const T* input_data,
T* output_data,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG input_index = 0;
线程号id
除以output_block_size
得到输出元素所对应的输入分块索引input_block_index
和输入分块内的偏移block_offset
。
根据block_offset
计算对应的indices
张量的索引以及分块内元素偏移。
GetIndexValue 取出indices
张量的值。相比 TensorFlow 的 gather_functor_gpu.cu.h 没有进行向量化访存优化。
idx
支持负数。索引值超出范围时赋零。
int input_block_index, block_offset;
output_block_size.divmod(id, input_block_index, block_offset);
int indices_index, offset;
block_size.divmod(block_offset, indices_index, offset);
int64_t idx = GetIndexValue(indices_data, index_element_size, indices_index);
idx = idx < 0 ? idx + indices_max : idx;
if (idx < 0 || idx >= indices_max) {
output_data[id] = 0;
return;
}
三部分相加得到输入张量索引。
input_index = input_block_index * input_block_size + idx * block_size.d_ + offset;
output_data[id] = input_data[input_index];
}
GetIndexValue
将index_data
指针转为相应类型,然后返回偏移位置的值。
__host__ __device__ inline int64_t GetIndexValue(const void* index_data, size_t index_element_size, size_t offset) {
switch (index_element_size) {
case sizeof(int32_t):
return *(reinterpret_cast<const int32_t*>(index_data) + offset);
break;
case sizeof(int64_t):
return *(reinterpret_cast<const int64_t*>(index_data) + offset);
break;
default:
break;
}
// What is a sensible thing to do here?
assert(false);
return std::numeric_limits<int64_t>::max();
}
DivMod
除法取余实现基于 Division by Invariant Integers using Multiplication 中的 Figure 4.1。
// The code below is based on section 4 Unsigned division of paper https://gmplib.org/~tege/divcnst-pldi94.pdf
// In current ORT, fast_divmod is used for calculating the position of a element in tensor,
// so unsigned integer division from the paper is good enough for ORT. The advantage is that div is very simple,
// then GPU compiler can do loop unroll easilly when divmod is called in a loop.
template <>
struct DivMod<int> {
DivMod(int d = 1) {
d_ = d == 0 ? 1 : d;
ORT_ENFORCE(d_ >= 1 && d_ <= static_cast<uint32_t>(std::numeric_limits<int>::max()));
l_
是
ℓ
=
⌈
log
2
x
⌉
\ell = \lceil \log_2 x \rceil
ℓ=⌈log2x⌉。
for (l_ = 0; l_ < 32; l_++)
if ((1U << l_) >= d_) break;
m
是
m
′
=
⌊
2
N
∗
(
2
ℓ
−
d
)
/
d
⌋
+
1
m' = \lfloor 2^N ∗ (2^\ell − d)/d\rfloor + 1
m′=⌊2N∗(2ℓ−d)/d⌋+1。
uint64_t one = 1;
uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1;
M_ = static_cast<uint32_t>(m);
// according to paper, the value of m' should fit in a unsigned integer.
ORT_ENFORCE(M_ > 0 && M_ == m);
}
DivMod::div
t
是
t
1
=
M
U
L
U
H
(
m
′
,
n
)
t_1 = \mathrm{MULUH}(m', n)
t1=MULUH(m′,n),使用uint64_t
计算避免溢出。
对于
q
q
q,
- 如果 d = 1 d = 1 d=1,那么 ℓ = 0 \ell = 0 ℓ=0,所以 m ′ = 1 m' = 1 m′=1 且 s h 1 = s h 2 = 0 sh_1 = sh_2 = 0 sh1=sh2=0。代码计算 t 1 = ⌊ 1 ∗ n / 2 N ⌋ = 0 t_1 = \lfloor 1 ∗ n/2^N \rfloor = 0 t1=⌊1∗n/2N⌋=0 和 q = n q = n q=n。
- 若
d
>
1
d > 1
d>1,则
ℓ
≥
1
\ell≥1
ℓ≥1,故
s
h
1
=
1
sh_1 = 1
sh1=1 且
s
h
2
=
ℓ
−
1
sh_2 =\ell −1
sh2=ℓ−1。
q = S R L ( t 1 + S R L ( n − t 1 , s h 1 ) , s h 2 ) = S R L ( t 1 + S R L ( n − t 1 , 1 ) , ℓ − 1 ) = ⌊ t 1 + ⌊ ( n − t 1 ) 2 ⌋ 2 ℓ − 1 ⌋ = ⌊ ⌊ 2 ∗ t 1 2 + ( n − t 1 ) 2 ⌋ 2 ℓ − 1 ⌋ (4.5) = ⌊ ⌊ ( t 1 + n ) / 2 ⌋ 2 ℓ − 1 ⌋ = ⌊ t 1 + n 2 ℓ ⌋ \begin{aligned} q &= \mathrm{SRL}(t_1 + \mathrm{SRL}(n − t_1, sh_1), sh_2)\\ &= \mathrm{SRL}(t_1 + \mathrm{SRL}(n − t_1, 1), \ell− 1)\\ &=\lfloor \frac{t_1 + \lfloor \frac{(n − t_1)}{2} \rfloor}{2^{\ell− 1}}\rfloor\\ &=\lfloor \frac{\lfloor \frac{2*t_1}{2} + \frac{(n − t_1)}{2} \rfloor}{2^{\ell− 1}}\rfloor \qquad\text{(4.5)}\\ &=\lfloor \frac{\lfloor(t_1 + n)/2\rfloor}{2^{\ell− 1}} \rfloor\\ &=\lfloor \frac{t_1 + n}{2^{\ell}} \rfloor \end{aligned} q=SRL(t1+SRL(n−t1,sh1),sh2)=SRL(t1+SRL(n−t1,1),ℓ−1)=⌊2ℓ−1t1+⌊2(n−t1)⌋⌋=⌊2ℓ−1⌊22∗t1+2(n−t1)⌋⌋(4.5)=⌊2ℓ−1⌊(t1+n)/2⌋⌋=⌊2ℓt1+n⌋
__umulhi 计算两个 32 位无符号整数的乘积的最高有效 32 位。
__host__ __device__ inline int div(int n) const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
uint32_t t = __umulhi(M_, n);
return (t + n) >> l_;
#else
// Using uint64_t for t, then t + n won't overflow.
uint64_t t = ((uint64_t)M_ * n) >> 32;
return static_cast<int>((t + n) >> l_);
#endif
}
DivMod::mod
n m o d d = n − d ∗ ⌊ n / d ⌋ n \enspace \mathrm{mod} \enspace d = n − d ∗ \lfloor n/d \rfloor nmodd=n−d∗⌊n/d⌋
__host__ __device__ inline int mod(int n) const {
return n - div(n) * d_;
}
DivMod::divmod
__host__ __device__ inline void divmod(int n, int& q, int& r) const {
q = div(n);
r = n - q * d_;
}
uint32_t d_; // divisor
uint32_t M_; // m' in the paper.
uint32_t l_; // l_ = ceil(log2(d_))
};
- Gather
- ONNXRuntime整体概览
- ONNXRuntime源码之OpKernel注册
- Ways to specify [[nodiscard]] before C++17
- microsoft/GSL
- How to use gsl narrow cast
- 警告 C26472
- GSL and C++ Core Guidelines
- Gather
- Gather
- tf.gather
- torch.gather
- paddle.gather
- [菁英计划] 索引取值及gather函数 #36815
- paddle. take_along_axis
- torch.gather in pytorch.onnx and onnxruntime #31464
- Replace torch.gather by other operator?
- Problem compiling onnx model using GLOW compiler: constant not found
- pytorch导出onnx的原则-以SwinTransformer和DETR在trt8.0.3.4部署为例
- GatherElements
- tf2onnx Gather
- OpenVINO Gather
- Pytorch equivalent of numpy.take()
- torch.index_select
- wkentaro/pytorch-for-numpy-users
- Similar operation like
numpy.take
- numpy.take
- tensorflow/tensorflow/core/kernels/gather_functor.h
- tensorflow/core/kernels/gather_functor_batched.h
- abseil中的微操