torch._dynamo 理解(2)——Backend

news2024/11/25 18:25:22

0 概述

TorchDynamo 是一个 Python 级别的即时 (JIT) 编译器,旨在让未修改的 PyTorch 程序运行得更快。它通过 Python Frame Evaluation Hooks(Python 框架评估钩子)来实现这一目标,以便在运行时动态地生成和优化代码。这使得 TorchDynamo 可以有效地处理各种 Python 代码,包括包含控制流(如循环和条件语句)的代码,而无需进行任何修改。

整个 pytorch 的编译栈如下:
在这里插入图片描述

后端整体流程如下:
在这里插入图片描述

以 Triton 为例,首先会做一次 lowering,然后进行调度,最后才会生成 Triton 的 kernel。

1 Loop-level IR

这里的 lowering,使用loop-level IR来表示,其对aten IR的每一句话做解释,并且每次的解析都会与前文联系起来。这一层 IR 的类型有:

  • PointWise
  • Reduction
  • TensorBox
  • MatrixMultiplyAdd
    除此之外,还有一些其他的类型。

这一层处理流程:

  • 对于从前端拿到的aten IR
    在这里插入图片描述

  • 对于上面的每一句运算,都翻译为loop-level IR

    • convert_element_type:
      在这里插入图片描述

    • amax
      在这里插入图片描述

      这里将计算的结果存储到buf0

    • sub:
      在这里插入图片描述

      由于amax将结果存储到buf0中,因此这里才能从buf0中直接 load 进来

    • exp:
      在这里插入图片描述

      如果上一条 IR 是pointwise的话,那么就会和这一次的进行归约,例如这里,只是在sub的 IR 上加上了tmp4 = exp(tmp3)并将 return 改为了tmp4
      这一层的 pass 会对aten IR的每一句话进行解析,并且每次的解析都会与前文联系起来,最终得到一个归约的loop-level IR

2 Schedule

一下面的代码为例:

if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2d(16, 32, 3),
        nn.BatchNorm2d(32),
        nn.ReLU(),
    ).cuda()
    model = torch.compile(model)
    x = torch.randn((2, 16, 8, 8), requires_grad=True, device="cuda")
    model(x)

其在loop-level层构建出 11 个缓冲区。随后,对这些缓冲区进行 schedule,内容包括:
在这里插入图片描述

这里有些缓冲区启用了 Reduction,也就是说这里的归约是对于缓冲区而言的。将这些缓冲区放在一起,生成一个 kernel ,而其他的缓冲区,则单独生成自己的 kernel (注意这里的 kernel 是指 triton 的 kernel,实际上我们可以认为是一个函数)。只有 reduction 的 kernel 中会出现循环语句,若只是 pointwise 的计算,则不会生成循环

3 Triton Kernel

最后就是 triton kernel 的生成,其采取的策略是:

  • 首先生成 load 语句
  • 生成 compute 语句
  • 生成 store 语句
  • 组合三种语句为一个 kernel
  • 组合所有 kernel 与一个 call 函数和 main 模块在一起为一个 .py 文件

上述例子生成的文件如下:

from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


triton__0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[4096], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
        xnumel = 2304
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
    x3 = xindex
    x1 = (xindex // 36) % 32
    tmp0 = tl.load(in_out_ptr0 + (x3), xmask)
    tmp1 = tl.load(in_ptr0 + (x1), xmask)	
    tmp2 = tmp0 + tmp1
    tl.store(in_out_ptr0 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
''')


triton__1 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor

@reduction(
    size_hints=[32, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 32
    rnumel = 72
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex % 36
        r2 = (rindex // 36)
        tmp0 = tl.load(in_ptr0 + (r1 + (36*x0) + (1152*r2)), rmask & xmask, eviction_policy='evict_last', other=0)
        _tmp1 = tl.where(rmask & xmask, _tmp1 + tmp0, _tmp1)
    tmp1 = tl.sum(_tmp1, 1)[:, None]
    tmp6 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = 72.0
    tmp3 = tmp1 / tmp2
    tmp4 = 0.1
    tmp5 = tmp3 * tmp4
    tmp7 = 0.9
    tmp8 = tmp6 * tmp7
    tmp9 = tmp5 + tmp8
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp3, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp9, xmask)
    _tmp13 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex % 36
        r2 = (rindex // 36)
        tmp10 = tl.load(in_ptr0 + (r1 + (36*x0) + (1152*r2)), rmask & xmask, eviction_policy='evict_last', other=0)
        tmp11 = tmp10 - tmp3
        tmp12 = tmp11 * tmp11
        _tmp13 = tl.where(rmask & xmask, _tmp13 + tmp12, _tmp13)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp13, xmask)
    tmp23 = tl.load(in_ptr2 + (x0), xmask)
    tmp14 = 72.0
    tmp15 = tmp13 / tmp14
    tmp16 = 1e-05
    tmp17 = tmp15 + tmp16
    tmp18 = tl.libdevice.rsqrt(tmp17)
    tmp19 = 1.0140845070422535
    tmp20 = tmp15 * tmp19
    tmp21 = 0.1
    tmp22 = tmp20 * tmp21
    tmp24 = 0.9
    tmp25 = tmp23 * tmp24
    tmp26 = tmp22 + tmp25
    tl.store(out_ptr2 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp18, xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp26, xmask)
''')


triton__2 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[4096], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*i1', 7: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2304
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x3 = xindex
    x1 = (xindex // 36) % 32
    tmp0 = tl.load(in_ptr0 + (x3), xmask)
    tmp1 = tl.load(in_ptr1 + (x1), xmask)
    tmp3 = tl.load(in_ptr2 + (x1), xmask)
    tmp10 = tl.load(in_ptr3 + (x1), xmask)
    tmp12 = tl.load(in_ptr4 + (x1), xmask)
    tmp2 = tmp0 - tmp1
    tmp4 = 72.0
    tmp5 = tmp3 / tmp4
    tmp6 = 1e-05
    tmp7 = tmp5 + tmp6
    tmp8 = tl.libdevice.rsqrt(tmp7)
    tmp9 = tmp2 * tmp8
    tmp11 = tmp9 * tmp10
    tmp13 = tmp11 + tmp12
    tmp14 = tl.where(0 != 0, 0, tl.where(0 > tmp13, 0, tmp13))
    tmp15 = 0.0
    tmp16 = tmp14 <= tmp15
    tl.store(out_ptr0 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp14, xmask)
    tl.store(out_ptr1 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp16, xmask)
''')


triton__3 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[1], filename=__file__, meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0_load = tl.load(in_ptr0 + (0))
    tmp0 = tl.broadcast_to(tmp0_load, [XBLOCK])
    tmp1 = 1
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (0 + tl.zeros([XBLOCK], tl.int32)), tmp2, None)
''')


async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8 = args
    args.clear()
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = aten.convolution(primals_8, primals_1, None, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
        assert_size_stride(buf0, (2, 32, 6, 6), (1152, 36, 6, 1))
        buf1 = buf0; del buf0  # reuse
        stream0 = get_cuda_stream(0)
        triton__0.run(buf1, primals_2, 2304, grid=grid(2304), stream=stream0)
        del primals_2
        buf2 = empty_strided((1, 32, 1, 1), (32, 1, 32, 32), device='cuda', dtype=torch.float32)
        buf3 = buf2; del buf2  # reuse
        buf6 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
        buf4 = empty_strided((1, 32, 1, 1), (32, 1, 32, 32), device='cuda', dtype=torch.float32)
        buf5 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
        buf7 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
        triton__1.run(buf3, buf1, primals_5, primals_6, buf6, buf4, buf5, buf7, 32, 72, grid=grid(32), stream=stream0)
        del primals_5
        del primals_6
        buf8 = empty_strided((2, 32, 6, 6), (1152, 36, 6, 1), device='cuda', dtype=torch.float32)
        buf9 = empty_strided((2, 32, 6, 6), (1152, 36, 6, 1), device='cuda', dtype=torch.bool)
        triton__2.run(buf1, buf3, buf4, primals_3, primals_4, buf8, buf9, 2304, grid=grid(2304), stream=stream0)
        del buf4
        del primals_4
        buf10 = empty_strided((), (), device='cuda', dtype=torch.int64)
        triton__3.run(primals_7, buf10, 1, grid=grid(1), stream=stream0)
        del primals_7
        return (buf6, buf7, buf10, buf8, primals_1, primals_3, primals_8, buf1, buf5, buf9, as_strided(buf3, (1, 32, 1, 1), (32, 1, 1, 1)), )


if __name__ == "__main__":
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((32, 16, 3, 3), (144, 9, 3, 1), device='cuda:0', dtype=torch.float32)
    primals_2 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_3 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_4 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_5 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_6 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_7 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
    primals_8 = rand_strided((2, 16, 8, 8), (1024, 64, 8, 1), device='cuda:0', dtype=torch.float32)
    print_performance(lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8]))

此文件生成在/tmp文件夹中,后缀为py ,后续直接运行此文件,可得到 performace 的值,同样,也可在运行中捕获到运算的值。

4 loop-level IR --> triton kernel

通过数据结构GraphLowering的方法run(*example_input)也就是一个Fake Tensor来生成 triton kernel:


Graph ID : 0 

Input : {
    'primals_1': TensorBox(StorageBox(
  InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1]))
)), 
    'primals_2': TensorBox(StorageBox(
  InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_3': TensorBox(StorageBox(
  InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_4': TensorBox(StorageBox(
  InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_5': TensorBox(StorageBox(
  InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_6': TensorBox(StorageBox(
  InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)), 
    'primals_7': TensorBox(StorageBox(
  InputBuffer(name='primals_7', layout=FixedLayout('cuda', torch.int64, size=[], stride=[]))
)), 
    'primals_8': TensorBox(StorageBox(
  InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))
))} 

Origin Input : {
    'primals_1': InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1])), 
    'primals_2': InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
    'primals_3': InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
    'primals_4': InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
    'primals_5': InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
    'primals_6': InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])), 
	'primals_7': InputBuffer(name='primals_7', layout=FixedLayout('cuda', torch.int64, size=[], stride=[])), 
    'primals_8': InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))} 


Output : [
    StorageBox(ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf3, i0)
    tmp1 = constant(0.1, torch.float32)
    tmp2 = tmp0 * tmp1
    tmp3 = load(primals_5, i0)
    tmp4 = constant(0.9, torch.float32)
    tmp5 = tmp3 * tmp4
    tmp6 = tmp2 + tmp5
    return tmp6
    ,
    ranges=[32],
    origins={add_2}
  ))
), 
	StorageBox(ComputedBuffer(name='buf7', layout=FixedLayout('cuda', torch.float32, size=(32,), stride=[1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf4, i0)
    tmp1 = index_expr(72, torch.float32)
    tmp2 = tmp0 / tmp1
    tmp3 = constant(1.0140845070422535, torch.float32)
    tmp4 = tmp2 * tmp3
    tmp5 = constant(0.1, torch.float32)
    tmp6 = tmp4 * tmp5
    tmp7 = load(primals_6, i0)
    tmp8 = constant(0.9, torch.float32)
    tmp9 = tmp7 * tmp8
    tmp10 = tmp6 + tmp9
    return tmp10
    ,
    ranges=(32,),
    origins={add_3}
  ))
), 
	StorageBox(ComputedBuffer(name='buf10', layout=FixedLayout('cuda', torch.int64, size=[], stride=[]), data=Pointwise(
    'cuda',
    torch.int64,
    tmp0 = load(primals_7, 0)
    tmp1 = constant(1, torch.int64)
    tmp2 = tmp0 + tmp1
    return tmp2
    ,
    ranges=[],
    origins={primals_7, clone_2, add}
  ))
), 
	StorageBox(ComputedBuffer(name='buf8', layout=FixedLayout('cuda', torch.float32, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf1, i3 + 6 * i2 + 36 * i1 + 1152 * i0)assembly
    tmp1 = load(buf3, i1)
    tmp2 = tmp0 - tmp1
    tmp3 = load(buf4, i1)
    tmp4 = index_expr(72, torch.float32)
    tmp5 = tmp3 / tmp4
    tmp6 = constant(1e-05, torch.float32)
    tmp7 = tmp5 + tmp6
    tmp8 = rsqrt(tmp7)
    tmp9 = tmp2 * tmp8
    tmp10 = load(primals_3, i1)
    tmp11 = tmp9 * tmp10
    tmp12 = load(primals_4, i1)
    tmp13 = tmp11 + tmp12
    tmp14 = relu(tmp13)
    return tmp14
    ,
    ranges=[2, 32, 6, 6],
    origins={relu}
  ))
), 
	StorageBox(InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1]))
), 
	StorageBox(InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
), 
	StorageBox(InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))
), 
	StorageBox(ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf0, i3 + 6 * i2 + 36 * i1 + 1152 * i0)
    tmp1 = load(primals_2, i1)
    tmp2 = tmp0 + tmp1
    return tmp2
    ,
    ranges=[2, 32, 6, 6],
    origins={convolution}
  ))
), 
	StorageBox(ComputedBuffer(name='buf5', layout=FixedLayout('cuda', torch.float32, size=(32,), stride=[1]), data=Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(buf4, i0)
    tmp1 = index_expr(72, torch.float32)
    tmp2 = tmp0 / tmp1
    tmp3 = constant(1e-05, torch.float32)
    tmp4 = tmp2 + tmp3
    tmp5 = rsqrt(tmp4)
    return tmp5
    ,
    ranges=(32,),
    origins={squeeze_1}
  ))
), 
	StorageBox(ComputedBuffer(name='buf9', layout=FixedLayout('cuda', torch.bool, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
    'cuda',
    torch.bool,
    tmp0 = load(buf8, i3 + 6 * i2 + 36 * i1 + 1152 * i0)
    tmp1 = constant(0, torch.float32)
    tmp2 = tmp0 <= tmp1
    return tmp2
    ,
    ranges=[2, 32, 6, 6],
    origins={le}
  ))
), 
ReinterpretView(StorageBox(ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.float32, size=[1, 32, 1, 1], stride=[32, 1, 32, 32]), data=Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf2, i1)
      tmp1 = index_expr(72, torch.float32)
      tmp2 = tmp0 / tmp1
      return tmp2
      ,
      ranges=[1, 32, 1, 1],
      origins={convolution, var_mean}
    ))
  ),
  FixedLayout('cuda', torch.float32, size=[1, 32, 1, 1], stride=[32, 1, 1, 1]),
  no origins?
)]

5 调度的目的

  • 调度的目的:由于在前面已经进行了 decompose (一般在转为 aten 算子的时候就已经完成了),因此这里的目的是为了调整 buff 的次序,也就是调度内存,以优化内存访问的效率。

6 aten IR --> loop-level IR

aten IRloop-level IRtorch/_inductor/compile_fx.py 中 #179 完成的,其中,输入的 gm 中存储的 code 为:

def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8):
    clone = torch.ops.aten.clone.default(primals_5);  primals_5 = None
    clone_1 = torch.ops.aten.clone.default(primals_6);  primals_6 = None
    clone_2 = torch.ops.aten.clone.default(primals_7);  primals_7 = None
    convolution = torch.ops.aten.convolution.default(primals_8, primals_1, primals_2, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  primals_2 = None
    add = torch.ops.aten.add.Tensor(clone_2, 1);  clone_2 = None
    var_mean = torch.ops.aten.var_mean.correction(convolution, [0, 2, 3], correction = 0, keepdim = True)
    getitem = var_mean[0]
    getitem_1 = var_mean[1];  var_mean = None
    add_1 = torch.ops.aten.add.Tensor(getitem, 1e-05)
    rsqrt = torch.ops.aten.rsqrt.default(add_1);  add_1 = None
    sub = torch.ops.aten.sub.Tensor(convolution, getitem_1)
    mul = torch.ops.aten.mul.Tensor(sub, rsqrt);  sub = None
    squeeze = torch.ops.aten.squeeze.dims(getitem_1, [0, 2, 3]);  getitem_1 = None
    squeeze_1 = torch.ops.aten.squeeze.dims(rsqrt, [0, 2, 3]);  rsqrt = None
    mul_1 = torch.ops.aten.mul.Tensor(squeeze, 0.1)
    mul_2 = torch.ops.aten.mul.Tensor(clone, 0.9);  clone = None
    add_2 = torch.ops.aten.add.Tensor(mul_1, mul_2);  mul_1 = mul_2 = None
    squeeze_2 = torch.ops.aten.squeeze.dims(getitem, [0, 2, 3]);  getitem = None
    mul_3 = torch.ops.aten.mul.Tensor(squeeze_2, 1.0140845070422535);  squeeze_2 = None
    mul_4 = torch.ops.aten.mul.Tensor(mul_3, 0.1);  mul_3 = None
    mul_5 = torch.ops.aten.mul.Tensor(clone_1, 0.9);  clone_1 = None
    add_3 = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
    unsqueeze = torch.ops.aten.unsqueeze.default(primals_3, -1)
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, -1);  unsqueeze = None
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(primals_4, -1);  primals_4 = None
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, -1);  unsqueeze_2 = None
    mul_6 = torch.ops.aten.mul.Tensor(mul, unsqueeze_1);  mul = unsqueeze_1 = None
    add_4 = torch.ops.aten.add.Tensor(mul_6, unsqueeze_3);  mul_6 = unsqueeze_3 = None
    relu = torch.ops.aten.relu.default(add_4);  add_4 = None
    le = torch.ops.aten.le.Scalar(relu, 0)
    unsqueeze_4 = torch.ops.aten.unsqueeze.default(squeeze, 0);  squeeze = None
    unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2);  unsqueeze_4 = None
    unsqueeze_6 = torch.ops.aten.unsqueeze.default(unsqueeze_5, 3);  unsqueeze_5 = None
    return [add_2, add_3, add, relu, primals_1, primals_3, primals_8, convolution, squeeze_1, le, unsqueeze_6]

得到的loop-level IR会通过下一行的compile_to_fn()进行到 triton 的转化,生成的 triron 代码会存储在/tmp/目录下的.py文件中。返回的值是一个函数compiled_fn,其__module__变量存储着上述的文件路径。

7 GraphLowering --> Triton kernel
  • 调用graph.compile_to_fn()
  • 这个函数会先去调用 graph 中的compiler_to_module(),对此返回值,取出其 call 属性并返回
  • 对于compiler_to_module(),首先调用self.codegen()来生成 triton 代码(返回一个 py 文件),随后将此代码重命名后返回
  • 在 codegen 中,首先调用了self.init_wrapper_code(),此函数只是检查是否需要使用 cpp 包装,一般都不需要,于是实例化了一个WrapperCodeGen()的对象并返回
  • 对 graph 中的 scheduler 进行实例化,调度的对象为loop-level IR中构造出的东西,实际上可以视为计算节点
    • 实例化的过程:
      • 声明一个空的 node 列表,用于新的构造
      • 拿到后续计算所依赖的缓冲区名称
      • 遍历传入的参数列表,这里就是在之前传入的列表等,对于列表中的每一个元素,做如下操作:
        • 查看此 node 是否存在入度(也就是数据是从什么地方来的,一般为缓冲区名称)
        • 对 node 的类型进行查看,在这里由于传入的节点均为 buffer ,因此不会进入is_no_op函数。接着,判断是否为 ComputedBufferTemplateBuffer ,其中 TemplateBuffer 给出的解释为 Represents a Triton (in the futurue other type) of template operator that we can fuse an epilogue onto.(显然,对于后续的 ComputedBuffer 都会进入这一条分支,并执行 self.get_backend(node.get_device()).group_fn )。对于卷积而言,在这里定义为 ExternKernel。
        • 将刚才生成的 node 添加到最开始创建的 node 列表中去
        • 做完这部分,接着进行死节点消除与节点融合
      • 完成调度后,接着就开始直接生成内核,注意,如果是特殊的算子(例如卷积)是不会被翻译为 triton 的,而是直接生成 aten ,否则,我们会进入 codegen_kernel 阶段。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1962590.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

借助Aspose.html控件, 将SVG 转PNG 的 C# 图像处理库

Aspose.HTML for .NET 不仅提供超文本标记语言 ( HTML ) 文件处理&#xff0c;还提供流行图像文件格式之间的转换。您可以利用丰富的渲染和转换功能将SVG文件渲染为PNG、JPG或其他广泛使用的文件格式。但是&#xff0c;我们将使用此C# 图像处理库以编程方式在 C# 中将 SVG 转换…

VBA 颜色

1. ColorIndex 1-1. ColorIndex的值是从1到56。 Option ExplicitConst MAX_COL As Long 8 Const MAX_ROW As Long 2 Const START_ROW As Long 2 Const START_COL As Long 2Sub Color()Dim i As IntegerDim intRow As Long, intCol As LongCells.SelectSelection.ClearCon…

redis Ubuntu安装问题

报错1&#xff1a;Package pkg-config is not available, but is referred to by another package /bin/sh: 1: pkg-config: not found&#xff08;没有安装pkg-config&#xff09; sudo apt-get install pkg-config /bin/sh: 1: cc: not found&#xff08;没有安装gcc环境&am…

年过30年程序员,到底要不要考虑搞点副业

一、前言 作为一名年过三十的程序员&#xff0c;我深刻体会到了职场的残酷和不确定性。在这个技术日新月异的时代&#xff0c;我们不仅要在专业领域深耕细作&#xff0c;更要敏锐地捕捉互联网的风口&#xff0c;以确保自己不被时代淘汰。程序员的黄金年龄似乎被限定在35岁之前…

《Milvus Cloud向量数据库指南》——ChatGLM:从GLM-130B到GLM-4

ChatGLM:从GLM-130B到GLM-4的跨越:智谱AI在通用人工智能领域的深度探索与实践 在人工智能的浩瀚星空中,智谱AI如同一颗璀璨的新星,以其独特的技术视角和坚定的创新步伐,在通用人工智能(AGI)的征途上留下了深刻的足迹。技术生态总监贾伟在近期的一次分享中,不仅为我们描…

蓝牙+LoRa+北斗RTK融合定位系统介绍

蓝牙LoRa北斗RTK定位系统是新锐科创自主研发的融合定位系统&#xff0c;该系统利用融合定位技术将当今主流的室内外定位技术有机融合&#xff0c;从而满足不同场景定位需求。 蓝牙LoRa北斗RTK定位系统是一种室内外高精度人员定位管理系统&#xff0c;具有功耗低、部署简单、实时…

【计算机视觉学习之CV2图像操作实战:车牌识别1】

基于Sobel算子的车牌识别 步骤如下 高斯模糊图片灰度化Sobel算子图像二值化闭操作膨胀腐蚀中值滤波查找轮廓判断车牌区域 import cv2 # 读取图片 rawImage cv2.imread("car1.jpg") # 高斯模糊&#xff0c;将图片平滑化&#xff0c;去掉干扰的噪声 image cv2.Gau…

蓄势赋能 数智化转型掌舵人百望云杨正道荣膺“先锋人物”

2024年&#xff0c;在数据与智能的双涡轮驱动下&#xff0c;我们迎来了一个以智能科技为核心的新质生产力大爆发时代。在数智化浪潮的推动下&#xff0c;全球企业正站在转型升级的十字路口。在这个充满变革的时代&#xff0c;企业转型升级的道路充满挑战&#xff0c;但也孕育着…

【OceanBase DBA早下班系列】—— obdiag 收集的 SQL Monitor Report 如何解读

1. 前言 前几天写了一篇博客&#xff0c;告诉大家在遇到慢SQL或者复杂的并行SQL的时候怎么高效的来收集【SQL Monitor Report】&#xff0c;上一篇博客的链接&#xff1a; OceanBase 社区 &#xff1b;发出去后有不少问我这份报告咋解读。今天再出一篇博客给大家介绍下如何解…

【股票价格跨度】python刷题记录

R3-栈和队列-单调栈 有个小思路&#xff1a;如果用栈的话&#xff0c;比如a,b在c前面&#xff0c;然后查找c的跨度的时候&#xff0c;往回搜索&#xff0c;如果b比c小&#xff0c;那就可以把b的跨度加到c上&#xff0c;否则&#xff0c;继续往回查找到a----&#xff08;思路貌似…

AI画笔,你的创意伙伴:6款最佳AI绘画工具推荐

在这个无限可能的时代&#xff0c;一个优秀的人工智能绘画软件不仅可以打破传统绘画方法的束缚&#xff0c;而且可以让每个热爱艺术的人都体验到创作的乐趣。那么&#xff0c;什么样的人工智能绘画软件才是优秀的呢&#xff1f;什么样的人工智能绘画软件才能生成超逼真的AI绘画…

(十)联合概率数据互联原理及应用(JPDA)

目录 前言 一、JPDA原理及算法步骤 &#xff08;一&#xff09;算法步骤 1.确认矩阵计算 2.确认矩阵拆分 3.互联概率计算 4.状态及协方差更新 二、仿真验证 &#xff08;一&#xff09;模型构建 &#xff08;二&#xff09;仿真结果 总结 引用文献 前言 本文主要针…

世界上速度最快的超级计算机推导出超级BC8钻石配方

BC8 超级钻石比任何已知材料都要坚硬&#xff0c;但它们很可能只存在于巨型系外行星的内核中。现在&#xff0c;世界上最强的超级计算机"前沿"已经揭开了它们形成的秘密&#xff0c;这一发现可能会导致在地球上生产它们。 钻石不仅是夺人眼球的珠宝&#xff0c;而且在…

致敬万物的解释者:丹尼尔・丹尼特(1942~2024)

2024 年 4 月 19 日&#xff0c;全世界最受欢迎的哲学家之一丹尼尔・丹尼特在缅因州波特兰去世。他被无数人称作是自己的哲学英雄&#xff0c;他的观点像一把利刃&#xff0c;在心智和意识领域无出其右。 巨星陨落&#xff0c;思想长存。让我们一起回顾他的生平&#xff0c;聆…

AGI|如何用Open WebUI和Ollama在本地运行大型语言模型?

在某AI产品发布会上&#xff0c;我们需要演示在个人PC上运行大模型的能力。为了实现这一目标&#xff0c;我们进行了深入的市场调研和技术评估&#xff0c;最终选择了Open WebUI和Ollama作为演示的核心工具。 目录 一、介绍 二、部署 三、代码展示 一、介绍 什么是Ollama …

前端代码混淆加密(使用Terser、WebpackObfuscator)

零、相关技术及版本号 "vue": "2.6.12", "vue/cli-service": "4.4.6", "javascript-obfuscator": "^4.1.1", "terser-webpack-plugin": "^4.2.3", "vue-template-compiler": &quo…

JAVA后端拉取gitee仓库代码项目并将该工程打包成jar包

公司当前有一个系统用于导出项目&#xff0c;而每次导出的项目并不可以直接使用&#xff0c;需要手动从gitee代码仓库中获取一个模板代码然后将他们整合到一起它才是一个完整的项目&#xff0c;所以目前我的任务就是编写一个java程序可以自动地从gitee仓库拉取下来那个模板代码…

纯前端实现导出pdf文件(服务端不参与)

大致查阅了现阶段使用较多的几种方案,&#xff0c;大概有以下几种方式&#xff1a; 一、原生window.print()方法导出pdf 二、jspdf 三、jspdf html2canvas 四、pdfmake 方案优点缺点window.print()1、兼容性最好 2、可以将任意内容导出成 pdf 文档, 甚至是非改页面上的内容1…

⑦【从0制作自己的ros导航小车:上位机篇】02、ros1多机通讯与坐标变换可视化

从0制作自己的ros导航小车 前言一、ros1多机通讯二、rviz可视化小车坐标系 系列文章&#xff1a; ①【从0制作自己的ros导航小车&#xff1a;介绍及准备】 ②【从0制作自己的ros导航小车&#xff1a;下位机篇】01、工程准备_标准库移植freertos ③【从0制作自己的ros导航小车&a…

Sobel Operator

什么是图像边缘&#xff1f; 边缘是指图像中灰度或颜色强度发生显著变化的区域。 什么是Sobel operator Sobel算子是一种用于图像处理的边缘检测算子。它通过计算图像灰度值的梯度来检测图像中的边缘。 什么情况产生梯度&#xff1f; 黑与白交界处。 Sobel 算子原理 计算P…