Triton.jit Source Code Analysis
Scope
This post analyzes OpenAI's Triton compiler, not to be confused with NVIDIA's Triton Inference Server.
Motivation
While reading the official tutorial, the launch syntax of a kernel function after it has been decorated looked interesting but also confusing, so this post tries to work out how it is implemented.
The @triton.jit decorator
To implement a device-side kernel with Triton, you write the kernel in Python-like code and decorate it with @triton.jit, as in the vector-add example below:
@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    ...
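For reference, the body elided above (reproduced here from the official tutorial, so it is worth double-checking against the current docs) computes each program's element offsets and performs masked loads, the add, and a masked store:

```python
    pid = tl.program_id(axis=0)  # ID of the current program instance on the 1D grid
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # indices handled by this program
    mask = offsets < n_elements  # guard against the last, partially filled block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
```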
The code above is the Triton version of a GPU vector-add kernel. To launch it, the official tutorial provides the following host-side wrapper:
def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keywords arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output
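For completeness, a typical call of this wrapper, along the lines of the official tutorial (the size below is arbitrary):

```python
import torch

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
out_triton = add(x, y)   # launches add_kernel on a 1D grid of cdiv(size, 1024) programs
out_torch = x + y
print(torch.max(torch.abs(out_torch - out_triton)))  # expected to print 0 (or ~0)
```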
The confusing part is how add_kernel[grid](*args) is implemented. In CUDA one would write something like add_kernel<<<GD, BD>>>(*args); since both are parallel programming models, it is reasonable to guess that Triton likewise launches the kernel function onto a grid-sized collection of threads that execute in parallel. So how does the @triton.jit decorator map the kernel onto that grid? Its implementation lives in triton/runtime/jit.py:
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

        * python primitives,
        * builtins within the triton package,
        * arguments to this function,
        * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        if os.getenv("TRITON_INTERPRET", "0") == "1":
            return InterpretedFunction(fn)
        else:
            return JITFunction(
                fn,
                version=version,
                do_not_specialize=do_not_specialize,
                debug=debug,
                noinline=noinline,
            )

    if fn is not None:
        return decorator(fn)
    else:
        return decorator
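One detail worth noting is the `fn is not None` branch at the end: it allows the decorator to be used either bare or with keyword arguments. A small sketch (the kernels below are placeholders, not from the tutorial):

```python
import triton
import triton.language as tl

@triton.jit                  # bare form: jit() receives fn directly and returns JITFunction(fn)
def kernel_a(x_ptr, BLOCK_SIZE: tl.constexpr):
    pass

@triton.jit(noinline=True)   # parameterized form: jit() returns `decorator`, which Python then applies to fn
def kernel_b(x_ptr, BLOCK_SIZE: tl.constexpr):
    pass
```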
In the end, the decorated func is used to construct an instance of the JITFunction class. Its definition, JITFunction(KernelInterface[T]), inherits from KernelInterface, which is implemented as follows:
class KernelInterface(Generic[T]):
    run: T

    def __getitem__(self, grid) -> T:
        """
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        """
        return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
The __getitem__() defined above explains the add_kernel[grid] syntax: the grid is captured by functools.partial as a keyword argument, producing a new callable that wraps run; the standalone sketch below illustrates the same pattern.
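To make fn[grid](*args, **kwargs) concrete, here is a minimal, self-contained reproduction of that pattern outside Triton (the class and names are illustrative, not Triton source):

```python
import functools
from typing import Callable

class LaunchableDemo:
    """Toy version of KernelInterface: indexing with a grid returns a
    partial of `run` that remembers the grid."""

    def run(self, *args, grid=None, **kwargs):
        print(f"run(grid={grid}, args={args}, kwargs={kwargs})")

    def __getitem__(self, grid) -> Callable:
        return functools.partial(self.run, grid=grid)

demo = LaunchableDemo()
# demo[(4,)](1, 2, BLOCK_SIZE=1024) is exactly demo.run(1, 2, BLOCK_SIZE=1024, grid=(4,))
demo[(4,)](1, 2, BLOCK_SIZE=1024)
```

With that in mind, the next stop is the rather long run function itself: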
def run(self, *args, **kwargs):
    from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps

    # Get a compiler-flags arg like `num_warps` and remove it from kwargs.
    def get_special_arg(name: str, default=None):
        if name not in kwargs:
            return default
        ret = kwargs[name]
        del kwargs[name]
        return ret

    grid = get_special_arg("grid")
    num_warps = get_special_arg("num_warps")
    num_ctas = get_special_arg("num_ctas", 1)
    num_stages = get_special_arg("num_stages")
    enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
    enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
    extern_libs = get_special_arg("extern_libs")
    stream = get_special_arg("stream")
    warmup = get_special_arg("warmup", False)
    device = get_special_arg("device")
    device_type = get_special_arg("device_type")

    # Bind the remaining arguments to `fn`.
    bound_args = self.signature.bind(*args, **kwargs)
    bound_args.apply_defaults()

    assert len(bound_args.arguments) == len(self.params)
    args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
    non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr]

    sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr)
    spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
    constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)

    assert num_ctas > 0
    assert grid is not None
    if callable(grid):
        # Arguments are passed as a dict to `grid`, by contract.
        # TODO(jlebar): In the new launch API, pass the compiler flags as a
        # second parameter to `grid`.
        grid = grid(dict(bound_args.arguments))
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] if grid_size > 1 else 1
    grid_2 = grid[2] if grid_size > 2 else 1

    if device_type is None:
        device_types = [self._device_of(arg) for arg in non_constexpr_arg_values]
        device_types = [_device_type for _device_type in device_types if _device_type != ""]
        device_type = self._conclude_device_type(device_types,
                                                 [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values])

    device_backend = None
    if device_type not in ["cuda"]:
        device_backend = get_backend(device_type)
        if device_backend is None:
            raise ValueError("Cannot find backend for " + device_type)

    if device is None:
        if device_type in ["cuda"]:
            device = get_current_device()
            set_current_device(device)
        else:
            device = device_backend.get_current_device()
            device_backend.set_current_device(device)

    if stream is None and not warmup:
        if device_type in ["cuda"]:
            stream = get_cuda_stream(device)
        else:
            stream = device_backend.get_stream()

    if num_warps is None:
        num_warps = get_arch_default_num_warps(device_type)
    if num_stages is None:
        num_stages = get_arch_default_num_stages(device_type)

    if device_type in ["cuda"]:
        version_key = get_cuda_version_key()
    else:
        version_key = device_backend.get_version_key()
    key = (
        version_key,
        sig_key,
        constexpr_key,
        spec_key,
        num_warps,
        num_ctas,
        num_stages,
        enable_warp_specialization,
        enable_fp_fusion,
        self.debug,
    )
    if extern_libs is not None:
        key = (key, tuple(extern_libs.items()))

    # Kernel is not cached; we have to compile.
    if key not in self.cache[device]:
        configs = (self._get_config(*[arg.value for arg in args]), )
        constants = {
            arg.param.num: arg.value
            for arg in args
            if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
        }
        for i, arg in constants.items():
            if callable(arg):
                raise TypeError(f"Callable constexpr at index {i} is not supported")

        # Build kernel signature -- doesn't include constexpr arguments.
        signature = {
            arg.param.num: self._type_of(self._key_of(arg.value))
            for arg in args
            if not arg.param.is_constexpr
        }

        if self._call_hook(
            key,
            signature,
            device,
            constants,
            num_warps,
            num_ctas,
            num_stages,
            enable_warp_specialization,
            enable_fp_fusion,
            extern_libs,
            configs,
        ):
            return None

        self.cache[device][key] = compile(
            self,
            signature=signature,
            device=device,
            constants=constants,
            num_warps=num_warps,
            num_ctas=num_ctas,
            num_stages=num_stages,
            enable_warp_specialization=enable_warp_specialization,
            enable_fp_fusion=enable_fp_fusion,
            extern_libs=extern_libs,
            configs=configs,
            debug=self.debug,
            device_type=device_type,
        )

    bin = self.cache[device][key]
    if not warmup:
        bin.c_wrapper(
            grid_0,
            grid_1,
            grid_2,
            bin.num_warps,
            bin.num_ctas,
            bin.clusterDims[0],
            bin.clusterDims[1],
            bin.clusterDims[2],
            bin.shared,
            stream,
            bin.cu_function,
            CompiledKernel.launch_enter_hook,
            CompiledKernel.launch_exit_hook,
            bin,
            *bin.assemble_tensormap_to_arg(non_constexpr_arg_values),
        )
    return bin
Triton-specific launch parameters such as grid, num_warps, and num_stages are first pulled out of kwargs one by one; bound_args then holds the full formal-parameter list of fn. If grid is a callable, fn's bound arguments are passed to it: in the example above, grid is a lambda that computes how many blocks are needed from BLOCK_SIZE and returns the block count as a tuple of up to three dimensions, from which run extracts the size of each dimension, as illustrated below.
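As a concrete illustration of this step (the numbers are hypothetical; the logic mirrors the snippet above):

```python
# Hypothetical bound arguments for the vector-add example.
n_elements = 98432
bound_arguments = {"x_ptr": "<ptr>", "y_ptr": "<ptr>", "output_ptr": "<ptr>",
                   "n_elements": n_elements, "BLOCK_SIZE": 1024}

cdiv = lambda a, b: (a + b - 1) // b  # same formula as triton.cdiv
grid = lambda meta: (cdiv(meta["n_elements"], meta["BLOCK_SIZE"]),)

grid = grid(bound_arguments)               # run() calls the callable grid with the bound-args dict
grid_0 = grid[0]                           # 97
grid_1 = grid[1] if len(grid) > 1 else 1   # 1
grid_2 = grid[2] if len(grid) > 2 else 1   # 1
print(grid_0, grid_1, grid_2)              # 97 1 1
```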
The handling of device, stream, num_warps, num_stages, and so on in the middle is left for a later post.
From the variables and parameters computed so far, a key is formed. My understanding is that each key corresponds to one compiled, specialized variant of the kernel; the JITFunction class maintains a cache, so if this key has already been compiled, the cached result is returned directly, otherwise a kernel compilation is triggered.
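In other words, the caching boils down to a per-device dict keyed by that tuple. A rough sketch of the pattern (not Triton source; names are illustrative):

```python
from collections import defaultdict

class KernelCacheSketch:
    def __init__(self):
        # one dict per device, keyed by (version, signature, constexpr values, specialization, ...)
        self.cache = defaultdict(dict)

    def get_or_compile(self, device, key, compile_fn):
        if key not in self.cache[device]:
            self.cache[device][key] = compile_fn()  # compile only on a cache miss
        return self.cache[device][key]

# Usage: repeated launches with the same key reuse the compiled binary.
cache = KernelCacheSketch()
binary = cache.get_or_compile(0, ("v1", ("fp32*", "i32"), (1024,)), lambda: "compiled-kernel")
```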
Compilation itself happens in compile() in triton/compiler/compiler.py, which is left for another time; it ultimately returns a CompiledKernel object.
The CompiledKernel class wraps a c_wrapper() function, which is invoked like this:
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
               self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
               CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
This calls into the low-level driver backend. From a quick look, Triton's driver provides two backends, HIP and CUDA, and the grid layout computed earlier is handed to the corresponding backend for the actual launch. Without digging into every detail, taking CUDA as an example, the call above eventually executes the following generated launcher (the {{ }} braces and { ... } placeholders remain because this C code appears to be emitted from a Python format string in the driver):
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
  if (gridX*gridY*gridZ > 0) {{
    if (num_ctas == 1) {{
      CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
    }} else {{
      CUlaunchAttribute launchAttr[2];
      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      launchAttr[0].value.clusterDim.x = clusterDimX;
      launchAttr[0].value.clusterDim.y = clusterDimY;
      launchAttr[0].value.clusterDim.z = clusterDimZ;
      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
      CUlaunchConfig config;
      config.gridDimX = gridX * clusterDimX;
      config.gridDimY = gridY * clusterDimY;
      config.gridDimZ = gridZ * clusterDimZ;
      config.blockDimX = 32 * num_warps;
      config.blockDimY = 1;
      config.blockDimZ = 1;
      config.sharedMemBytes = shared_memory;
      config.hStream = stream;
      config.attrs = launchAttr;
      config.numAttrs = 2;
      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
      if (cuLaunchKernelExHandle == NULL) {{
        cuLaunchKernelExHandle = getLaunchKernelExHandle();
      }}
      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
    }}
  }}
}}
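Tying this back to the vector-add example: each Triton program becomes one CUDA thread block of 32*num_warps threads, so with BLOCK_SIZE=1024 and (assuming) a default of num_warps=4, the launch dimensions handed to cuLaunchKernel would look like this:

```python
# Hypothetical numbers; num_warps=4 is an assumed default, not read from the source above.
n_elements, BLOCK_SIZE, num_warps = 98432, 1024, 4

gridX = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # 97 Triton programs -> 97 CUDA blocks
blockDimX = 32 * num_warps                           # 128 threads per block, as in the launcher above
print(gridX, blockDimX)                              # 97 128
```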
Summary
This post walked through part of the code behind @triton.jit, mainly to figure out how the grid structure is realized in Triton, and the picture is slowly becoming clearer. Python also has a lot of clever tricks, and reading this part of the implementation taught me a few of them. Honestly, it is dizzying just to read; the people who wrote this source are seriously good, and there is quite a gap to close...