0 Overview
TorchDynamo is a Python-level just-in-time (JIT) compiler designed to make unmodified PyTorch programs run faster. It achieves this through Python frame evaluation hooks, which let it dynamically generate and optimize code at runtime. As a result, TorchDynamo can handle a wide variety of Python code effectively, including code containing control flow (such as loops and conditionals), without requiring any modification.
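As a quick illustration (a minimal sketch assuming PyTorch >= 2.0; the function f below is made up for this example), torch.compile can be applied to an unmodified function containing data-dependent control flow; Dynamo captures what it can at the frame level and handles the branch via guards and graph breaks instead of failing:

import torch

def f(x):
    # Data-dependent branch: Dynamo deals with it through a guard/graph break,
    # with no change to the source required.
    if x.sum() > 0:
        return torch.sin(x)
    return torch.cos(x)

f_opt = torch.compile(f)
print(f_opt(torch.randn(4)))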
The full PyTorch compilation stack is as follows:
The overall backend flow is as follows:
Taking Triton as an example, a lowering pass runs first, then scheduling, and only then are the Triton kernels generated.
1 Loop-level IR
The lowering here is expressed as a loop-level IR, which interprets each statement of the aten IR, relating every interpretation to what came before. The node types at this IR level include:
- PointWise
- Reduction
- TensorBox
- MatrixMultiplyAdd
Besides these, there are several other types.
The processing flow at this level (see the sketch after this list):
- For the aten IR received from the frontend:
- Each operation in it is translated into loop-level IR:
  - convert_element_type:
  - amax: the result of this computation is stored into buf0
  - sub: because amax stored its result into buf0, this step can load directly from buf0
  - exp: if the previous IR node is pointwise, it is fused with the current one; here, for example, tmp4 = exp(tmp3) is simply appended to the sub IR and the return value is changed to tmp4

This pass interprets each statement of the aten IR, relating every interpretation to what came before, and finally yields a fused loop-level IR.
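As a toy sketch of this fusion behavior (the names Pointwise, lower_sub, and lower_exp are invented here, not Inductor's actual classes), each lowering step can be viewed as building a loop body as a closure over its producer's loop body, so exp extends sub's body instead of materializing an intermediate buffer:

import math

class Pointwise:
    """A loop body: maps an index to a scalar value."""
    def __init__(self, inner_fn):
        self.inner_fn = inner_fn

def lower_sub(a, b):
    # tmp3 = a[i] - b[i]
    return Pointwise(lambda i: a.inner_fn(i) - b.inner_fn(i))

def lower_exp(x):
    # Fusion: tmp4 = exp(tmp3) is appended to the producer's body.
    return Pointwise(lambda i: math.exp(x.inner_fn(i)))

load_x = Pointwise(lambda i: float(i))  # stands in for load(input, i)
load_m = Pointwise(lambda i: 3.0)       # stands in for the amax result in buf0

fused = lower_exp(lower_sub(load_x, load_m))
print([fused.inner_fn(i) for i in range(4)])  # one loop body computing exp(x - m)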
2 Schedule
Take the following code as an example:
import torch
from torch import nn

if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2d(16, 32, 3),
        nn.BatchNorm2d(32),
        nn.ReLU(),
    ).cuda()
    model = torch.compile(model)
    x = torch.randn((2, 16, 8, 8), requires_grad=True, device="cuda")
    model(x)
At the loop-level layer this builds 11 buffers, which are then scheduled; the toy sketch after this paragraph illustrates the grouping idea. Some of the buffers here have Reduction enabled, meaning the reduction is defined per buffer. Such buffers are grouped together to generate one kernel, while each of the other buffers generates its own kernel (a kernel here means a Triton kernel, which in practice can be thought of as a single function). Only reduction kernels contain loop statements; purely pointwise computations generate no loops.
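The grouping decision can be pictured with a toy sketch (the Buf class and the (numel, reduction_numel) group key are hypothetical simplifications, not Inductor's scheduler): buffers whose iteration group matches are emitted into one kernel, and everything else gets its own kernel:

from dataclasses import dataclass

@dataclass
class Buf:
    name: str
    group: tuple          # hypothetical key: (numel, reduction_numel)
    is_reduction: bool

def schedule(bufs):
    kernels, current = [], []
    for b in bufs:
        if current and current[-1].group == b.group:
            current.append(b)             # fuse into the running kernel
        else:
            if current:
                kernels.append(current)
            current = [b]
    if current:
        kernels.append(current)
    return kernels

bufs = [
    Buf("buf2", (32, 72), True),    # mean reduction
    Buf("buf4", (32, 72), True),    # variance reduction -> fused with buf2
    Buf("buf8", (2304, 1), False),  # pointwise, gets its own kernel
]
for i, kernel in enumerate(schedule(bufs)):
    print(f"kernel {i}: {[b.name for b in kernel]}")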
3 Triton Kernel
Finally comes the generation of the Triton kernels, which follows this strategy:
- first generate the load statements
- then generate the compute statements
- then generate the store statements
- combine the three kinds of statements into one kernel
- combine all kernels, together with a call function and a main module, into a single .py file
The file generated for the example above is as follows:
from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
triton__0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[4096], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2304
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x3 = xindex
x1 = (xindex // 36) % 32
tmp0 = tl.load(in_out_ptr0 + (x3), xmask)
tmp1 = tl.load(in_ptr0 + (x1), xmask)
tmp2 = tmp0 + tmp1
tl.store(in_out_ptr0 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
''')
triton__1 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(
size_hints=[32, 128],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 32
rnumel = 72
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
_tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex % 36
r2 = (rindex // 36)
tmp0 = tl.load(in_ptr0 + (r1 + (36*x0) + (1152*r2)), rmask & xmask, eviction_policy='evict_last', other=0)
_tmp1 = tl.where(rmask & xmask, _tmp1 + tmp0, _tmp1)
tmp1 = tl.sum(_tmp1, 1)[:, None]
tmp6 = tl.load(in_ptr1 + (x0), xmask)
tmp2 = 72.0
tmp3 = tmp1 / tmp2
tmp4 = 0.1
tmp5 = tmp3 * tmp4
tmp7 = 0.9
tmp8 = tmp6 * tmp7
tmp9 = tmp5 + tmp8
tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp3, xmask)
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp9, xmask)
_tmp13 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex % 36
r2 = (rindex // 36)
tmp10 = tl.load(in_ptr0 + (r1 + (36*x0) + (1152*r2)), rmask & xmask, eviction_policy='evict_last', other=0)
tmp11 = tmp10 - tmp3
tmp12 = tmp11 * tmp11
_tmp13 = tl.where(rmask & xmask, _tmp13 + tmp12, _tmp13)
tmp13 = tl.sum(_tmp13, 1)[:, None]
tl.store(out_ptr1 + x0, tmp13, xmask)
tmp23 = tl.load(in_ptr2 + (x0), xmask)
tmp14 = 72.0
tmp15 = tmp13 / tmp14
tmp16 = 1e-05
tmp17 = tmp15 + tmp16
tmp18 = tl.libdevice.rsqrt(tmp17)
tmp19 = 1.0140845070422535
tmp20 = tmp15 * tmp19
tmp21 = 0.1
tmp22 = tmp20 * tmp21
tmp24 = 0.9
tmp25 = tmp23 * tmp24
tmp26 = tmp22 + tmp25
tl.store(out_ptr2 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp18, xmask)
tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp26, xmask)
''')
triton__2 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[4096], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*i1', 7: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
xnumel = 2304
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x3 = xindex
x1 = (xindex // 36) % 32
tmp0 = tl.load(in_ptr0 + (x3), xmask)
tmp1 = tl.load(in_ptr1 + (x1), xmask)
tmp3 = tl.load(in_ptr2 + (x1), xmask)
tmp10 = tl.load(in_ptr3 + (x1), xmask)
tmp12 = tl.load(in_ptr4 + (x1), xmask)
tmp2 = tmp0 - tmp1
tmp4 = 72.0
tmp5 = tmp3 / tmp4
tmp6 = 1e-05
tmp7 = tmp5 + tmp6
tmp8 = tl.libdevice.rsqrt(tmp7)
tmp9 = tmp2 * tmp8
tmp11 = tmp9 * tmp10
tmp13 = tmp11 + tmp12
tmp14 = tl.where(0 != 0, 0, tl.where(0 > tmp13, 0, tmp13))
tmp15 = 0.0
tmp16 = tmp14 <= tmp15
tl.store(out_ptr0 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp14, xmask)
tl.store(out_ptr1 + (x3 + tl.zeros([XBLOCK], tl.int32)), tmp16, xmask)
''')
triton__3 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[1], filename=__file__, meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 1
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
tmp0_load = tl.load(in_ptr0 + (0))
tmp0 = tl.broadcast_to(tmp0_load, [XBLOCK])
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.store(out_ptr0 + (0 + tl.zeros([XBLOCK], tl.int32)), tmp2, None)
''')
async_compile.wait(globals())
del async_compile
def call(args):
primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8 = args
args.clear()
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0) # no-op to ensure context
buf0 = aten.convolution(primals_8, primals_1, None, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
assert_size_stride(buf0, (2, 32, 6, 6), (1152, 36, 6, 1))
buf1 = buf0; del buf0 # reuse
stream0 = get_cuda_stream(0)
triton__0.run(buf1, primals_2, 2304, grid=grid(2304), stream=stream0)
del primals_2
buf2 = empty_strided((1, 32, 1, 1), (32, 1, 32, 32), device='cuda', dtype=torch.float32)
buf3 = buf2; del buf2 # reuse
buf6 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
buf4 = empty_strided((1, 32, 1, 1), (32, 1, 32, 32), device='cuda', dtype=torch.float32)
buf5 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
buf7 = empty_strided((32, ), (1, ), device='cuda', dtype=torch.float32)
triton__1.run(buf3, buf1, primals_5, primals_6, buf6, buf4, buf5, buf7, 32, 72, grid=grid(32), stream=stream0)
del primals_5
del primals_6
buf8 = empty_strided((2, 32, 6, 6), (1152, 36, 6, 1), device='cuda', dtype=torch.float32)
buf9 = empty_strided((2, 32, 6, 6), (1152, 36, 6, 1), device='cuda', dtype=torch.bool)
triton__2.run(buf1, buf3, buf4, primals_3, primals_4, buf8, buf9, 2304, grid=grid(2304), stream=stream0)
del buf4
del primals_4
buf10 = empty_strided((), (), device='cuda', dtype=torch.int64)
triton__3.run(primals_7, buf10, 1, grid=grid(1), stream=stream0)
del primals_7
return (buf6, buf7, buf10, buf8, primals_1, primals_3, primals_8, buf1, buf5, buf9, as_strided(buf3, (1, 32, 1, 1), (32, 1, 1, 1)), )
if __name__ == "__main__":
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
primals_1 = rand_strided((32, 16, 3, 3), (144, 9, 3, 1), device='cuda:0', dtype=torch.float32)
primals_2 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_3 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_4 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_5 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_6 = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_7 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
primals_8 = rand_strided((2, 16, 8, 8), (1024, 64, 8, 1), device='cuda:0', dtype=torch.float32)
print_performance(lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8]))
This file is generated in the /tmp directory with a .py suffix. Running it directly afterwards yields the performance number, and the computed values can likewise be captured during the run.
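Instead of hunting through /tmp, Inductor's debug dumps can be enabled (assuming a PyTorch version around 2.0; debug facilities vary across versions), which persist the lowered IR and the generated Triton code:

import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # set before importing torch

import torch

model = torch.compile(torch.nn.Linear(4, 4).cuda())
model(torch.randn(2, 4, device="cuda"))
# the IR and kernel files are written under ./torch_compile_debug/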
4 loop-level IR --> Triton kernel
The Triton kernels are generated through the run(*example_input) method of the GraphLowering data structure, where the example input is a Fake Tensor:
Graph ID : 0
Input : {
'primals_1': TensorBox(StorageBox(
InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1]))
)),
'primals_2': TensorBox(StorageBox(
InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)),
'primals_3': TensorBox(StorageBox(
InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)),
'primals_4': TensorBox(StorageBox(
InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)),
'primals_5': TensorBox(StorageBox(
InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)),
'primals_6': TensorBox(StorageBox(
InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
)),
'primals_7': TensorBox(StorageBox(
InputBuffer(name='primals_7', layout=FixedLayout('cuda', torch.int64, size=[], stride=[]))
)),
'primals_8': TensorBox(StorageBox(
InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))
))}
Origin Input : {
'primals_1': InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1])),
'primals_2': InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])),
'primals_3': InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])),
'primals_4': InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])),
'primals_5': InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])),
'primals_6': InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1])),
'primals_7': InputBuffer(name='primals_7', layout=FixedLayout('cuda', torch.int64, size=[], stride=[])),
'primals_8': InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))}
Output : [
StorageBox(ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf3, i0)
tmp1 = constant(0.1, torch.float32)
tmp2 = tmp0 * tmp1
tmp3 = load(primals_5, i0)
tmp4 = constant(0.9, torch.float32)
tmp5 = tmp3 * tmp4
tmp6 = tmp2 + tmp5
return tmp6
,
ranges=[32],
origins={add_2}
))
),
StorageBox(ComputedBuffer(name='buf7', layout=FixedLayout('cuda', torch.float32, size=(32,), stride=[1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf4, i0)
tmp1 = index_expr(72, torch.float32)
tmp2 = tmp0 / tmp1
tmp3 = constant(1.0140845070422535, torch.float32)
tmp4 = tmp2 * tmp3
tmp5 = constant(0.1, torch.float32)
tmp6 = tmp4 * tmp5
tmp7 = load(primals_6, i0)
tmp8 = constant(0.9, torch.float32)
tmp9 = tmp7 * tmp8
tmp10 = tmp6 + tmp9
return tmp10
,
ranges=(32,),
origins={add_3}
))
),
StorageBox(ComputedBuffer(name='buf10', layout=FixedLayout('cuda', torch.int64, size=[], stride=[]), data=Pointwise(
'cuda',
torch.int64,
tmp0 = load(primals_7, 0)
tmp1 = constant(1, torch.int64)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[],
origins={primals_7, clone_2, add}
))
),
StorageBox(ComputedBuffer(name='buf8', layout=FixedLayout('cuda', torch.float32, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf1, i3 + 6 * i2 + 36 * i1 + 1152 * i0)
tmp1 = load(buf3, i1)
tmp2 = tmp0 - tmp1
tmp3 = load(buf4, i1)
tmp4 = index_expr(72, torch.float32)
tmp5 = tmp3 / tmp4
tmp6 = constant(1e-05, torch.float32)
tmp7 = tmp5 + tmp6
tmp8 = rsqrt(tmp7)
tmp9 = tmp2 * tmp8
tmp10 = load(primals_3, i1)
tmp11 = tmp9 * tmp10
tmp12 = load(primals_4, i1)
tmp13 = tmp11 + tmp12
tmp14 = relu(tmp13)
return tmp14
,
ranges=[2, 32, 6, 6],
origins={relu}
))
),
StorageBox(InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[32, 16, 3, 3], stride=[144, 9, 3, 1]))
),
StorageBox(InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[32], stride=[1]))
),
StorageBox(InputBuffer(name='primals_8', layout=FixedLayout('cuda', torch.float32, size=[2, 16, 8, 8], stride=[1024, 64, 8, 1]))
),
StorageBox(ComputedBuffer(name='buf1', layout=FixedLayout('cuda', torch.float32, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf0, i3 + 6 * i2 + 36 * i1 + 1152 * i0)
tmp1 = load(primals_2, i1)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[2, 32, 6, 6],
origins={convolution}
))
),
StorageBox(ComputedBuffer(name='buf5', layout=FixedLayout('cuda', torch.float32, size=(32,), stride=[1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf4, i0)
tmp1 = index_expr(72, torch.float32)
tmp2 = tmp0 / tmp1
tmp3 = constant(1e-05, torch.float32)
tmp4 = tmp2 + tmp3
tmp5 = rsqrt(tmp4)
return tmp5
,
ranges=(32,),
origins={squeeze_1}
))
),
StorageBox(ComputedBuffer(name='buf9', layout=FixedLayout('cuda', torch.bool, size=[2, 32, 6, 6], stride=[1152, 36, 6, 1]), data=Pointwise(
'cuda',
torch.bool,
tmp0 = load(buf8, i3 + 6 * i2 + 36 * i1 + 1152 * i0)
tmp1 = constant(0, torch.float32)
tmp2 = tmp0 <= tmp1
return tmp2
,
ranges=[2, 32, 6, 6],
origins={le}
))
),
ReinterpretView(StorageBox(ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.float32, size=[1, 32, 1, 1], stride=[32, 1, 32, 32]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf2, i1)
tmp1 = index_expr(72, torch.float32)
tmp2 = tmp0 / tmp1
return tmp2
,
ranges=[1, 32, 1, 1],
origins={convolution, var_mean}
))
),
FixedLayout('cuda', torch.float32, size=[1, 32, 1, 1], stride=[32, 1, 1, 1]),
no origins?
)]
5 The purpose of scheduling
- The purpose of scheduling: decomposition has already happened earlier (generally completed when converting to aten operators), so the goal here is to reorder the buffers, i.e., to schedule memory, in order to make memory access more efficient (see the toy last-use sketch below).
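A toy last-use analysis (plan_frees is a hypothetical helper, not Inductor's memory planner) shows the kind of information this scheduling exposes; compare the "buf1 = buf0; del buf0  # reuse" lines in the generated call() above:

def plan_frees(ops):
    """ops: list of (output_name, [input_names]).
    Returns name -> index of the op after which it can be freed/reused."""
    last_use = {}
    for i, (out, ins) in enumerate(ops):
        for name in ins:
            last_use[name] = i
    return last_use

ops = [
    ("buf0", ["primals_8", "primals_1"]),  # convolution
    ("buf1", ["buf0", "primals_2"]),       # bias add (reuses buf0's storage)
    ("buf8", ["buf1", "buf3"]),            # normalization + relu
]
print(plan_frees(ops))  # buf0's last use is op 1, so it can be reused there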
6 aten IR --> loop-level IR
The conversion from aten IR to loop-level IR is done at line 179 of torch/_inductor/compile_fx.py; the code stored in the incoming gm is:
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8):
clone = torch.ops.aten.clone.default(primals_5); primals_5 = None
clone_1 = torch.ops.aten.clone.default(primals_6); primals_6 = None
clone_2 = torch.ops.aten.clone.default(primals_7); primals_7 = None
convolution = torch.ops.aten.convolution.default(primals_8, primals_1, primals_2, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); primals_2 = None
add = torch.ops.aten.add.Tensor(clone_2, 1); clone_2 = None
var_mean = torch.ops.aten.var_mean.correction(convolution, [0, 2, 3], correction = 0, keepdim = True)
getitem = var_mean[0]
getitem_1 = var_mean[1]; var_mean = None
add_1 = torch.ops.aten.add.Tensor(getitem, 1e-05)
rsqrt = torch.ops.aten.rsqrt.default(add_1); add_1 = None
sub = torch.ops.aten.sub.Tensor(convolution, getitem_1)
mul = torch.ops.aten.mul.Tensor(sub, rsqrt); sub = None
squeeze = torch.ops.aten.squeeze.dims(getitem_1, [0, 2, 3]); getitem_1 = None
squeeze_1 = torch.ops.aten.squeeze.dims(rsqrt, [0, 2, 3]); rsqrt = None
mul_1 = torch.ops.aten.mul.Tensor(squeeze, 0.1)
mul_2 = torch.ops.aten.mul.Tensor(clone, 0.9); clone = None
add_2 = torch.ops.aten.add.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
squeeze_2 = torch.ops.aten.squeeze.dims(getitem, [0, 2, 3]); getitem = None
mul_3 = torch.ops.aten.mul.Tensor(squeeze_2, 1.0140845070422535); squeeze_2 = None
mul_4 = torch.ops.aten.mul.Tensor(mul_3, 0.1); mul_3 = None
mul_5 = torch.ops.aten.mul.Tensor(clone_1, 0.9); clone_1 = None
add_3 = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
unsqueeze = torch.ops.aten.unsqueeze.default(primals_3, -1)
unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, -1); unsqueeze = None
unsqueeze_2 = torch.ops.aten.unsqueeze.default(primals_4, -1); primals_4 = None
unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, -1); unsqueeze_2 = None
mul_6 = torch.ops.aten.mul.Tensor(mul, unsqueeze_1); mul = unsqueeze_1 = None
add_4 = torch.ops.aten.add.Tensor(mul_6, unsqueeze_3); mul_6 = unsqueeze_3 = None
relu = torch.ops.aten.relu.default(add_4); add_4 = None
le = torch.ops.aten.le.Scalar(relu, 0)
unsqueeze_4 = torch.ops.aten.unsqueeze.default(squeeze, 0); squeeze = None
unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2); unsqueeze_4 = None
unsqueeze_6 = torch.ops.aten.unsqueeze.default(unsqueeze_5, 3); unsqueeze_5 = None
return [add_2, add_3, add, relu, primals_1, primals_3, primals_8, convolution, squeeze_1, le, unsqueeze_6]
The resulting loop-level IR is converted to Triton by compile_to_fn() on the following line; the generated Triton code is stored in a .py file under the /tmp/ directory. The return value is a function compiled_fn, whose __module__ attribute records the path of that file (a small probe of this is sketched below).
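A small probe (internal, version-dependent APIs; torch._inductor.codecache.cache_dir is assumed from around PyTorch 2.0) confirms where the generated files live:

import torch

def g(x):
    return torch.relu(x) + 1

g_c = torch.compile(g)
g_c(torch.randn(8))

from torch._inductor import codecache
print(codecache.cache_dir())  # e.g. /tmp/torchinductor_<user>/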
7 GraphLowering --> Triton kernel
- Call graph.compile_to_fn() (a condensed toy paraphrase of this whole chain follows the list).
- This function first calls the graph's compile_to_module(), takes the call attribute of the return value, and returns it.
- compile_to_module() first calls self.codegen() to generate the Triton code (returning a .py file), then renames the code and returns it.
- Inside codegen, self.init_wrapper_code() is called first; this function merely checks whether a C++ wrapper is needed (usually it is not), in which case it instantiates a WrapperCodeGen() object and returns.
- The graph's scheduler is then instantiated. The objects being scheduled are what the loop-level IR constructed, which can effectively be viewed as compute nodes. The instantiation process:
  - Declare an empty node list for the new construction.
  - Obtain the names of the buffers that subsequent computation depends on.
  - Iterate over the argument list passed in (the list handed over earlier), and for each element:
    - Check whether the node has incoming edges (i.e., where its data comes from, usually buffer names).
    - Inspect the node's type. Since the nodes passed in here are all buffers, the is_no_op path is not taken. Next, check whether it is a ComputedBuffer or a TemplateBuffer, where TemplateBuffer is documented as "Represents a Triton (in the future other types) template operator that we can fuse an epilogue onto." (Clearly, every subsequent ComputedBuffer enters this branch and executes self.get_backend(node.get_device()).group_fn.) A convolution, by contrast, is defined here as an ExternKernel.
    - Append the node just produced to the node list created at the start.
  - Once this is done, dead-node elimination and node fusion follow.
- After scheduling completes, kernel generation begins directly. Note that special operators (such as convolution) are not translated to Triton but are emitted as aten calls directly; otherwise we enter the codegen_kernel stage.
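The chain above can be condensed into a toy paraphrase (the method names match GraphLowering's real ones, but the bodies are stand-ins, not PyTorch's implementation):

import types

class ToyGraphLowering:
    def codegen(self):
        # init_wrapper_code(), then the Scheduler fuses nodes and emits kernels
        return "def call(args):\n    return [a + 1 for a in args]\n"

    def compile_to_module(self):
        ns = {}
        exec(self.codegen(), ns)              # stands in for writing and importing the /tmp .py file
        return types.SimpleNamespace(**ns)

    def compile_to_fn(self):
        return self.compile_to_module().call  # return the wrapper's call()

print(ToyGraphLowering().compile_to_fn()([1, 2, 3]))  # [2, 3, 4]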