Home triton-shared CPU后端分析
Post
Cancel

triton-shared CPU后端分析

triton-shared CPU后端编译

triton-shared由triton源码作为子模块编译,大致代码与triton相同。仓库新增了一个CPU后端,挖一下实现流程。
@triton.jit开始与原生triton一样,一直走到JitFunction封装kernel:

1
2
3
4
5
6
7
# compile the kernel
    src = ASTSource(self, signature, constants, configs[0])
    self.cache[device][key] = compile(
        src,
        target=target,
        options=options.__dict__,
    )

src是kernel的ast,进入triton-shared/triton/python/triton/compiler.py执行编译,compiler有一套编译流水线:

1
2
3
4
# run compilation pipeline  and populate metadata
    stages = dict()
    backend.add_stages(stages, options)
    first_stage = list(stages.keys()).index(src.ext)

流水线的具体步骤由后端配置,CPU后端在triton-shared/backend/compiler.py内通过add_stages配置流水线步骤:

1
2
3
4
5
def add_stages(self, stages, options):
        """Register the CPU backend's compilation stages in ``stages``.

        Each entry maps a stage name to a callable that takes
        ``(src, metadata)`` and returns that stage's output. Entries are
        inserted in the order ttir -> ttsharedir -> llir -> cpuasm. Only the
        "ttir" and "cpuasm" callables forward ``metadata``; only "ttir"
        uses ``options``.
        """
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
        stages["ttsharedir"] = lambda src, metadata: _optimize_ttsharedir(_ttir_to_ttsharedir(src))
        stages["llir"] = lambda src, metadata: _optimize_llir(_ttsharedir_to_llir(src))
        stages["cpuasm"] = lambda src, metadata: _llir_to_bin(src, metadata)

在for循环内执行上述ir迭代流程。写入metadata:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
    try:
        module = src.make_ir(options, context)
    except Exception as e:
        filter_traceback(e)
        raise
    for ext, compile_ir in list(stages.items())[first_stage:]:
        next_module = compile_ir(module, metadata)
        ir_filename = f"{src.name}.{ext}"
        metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
        if fn_dump_manager is not None:
            fn_dump_manager.put(next_module, ir_filename)
        if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)):
            print(f"\nOverriding kernel with file {ir_filename}")
            full_name = fn_override_manager.get_file(ir_filename)
            next_module = parse(full_name, ext, context)
        module = next_module
    # write-back metadata
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                             binary=False)
    fn_cache_manager.put_group(metadata_filename, metadata_group)
    # return handle to compiled kernel
    return CompiledKernel(src, metadata_group, hash)

最终返回一个CompiledKernel实例。当对该实例使用kernel[grid]语法配置grid时,会触发以下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
def __getitem__(self, grid):
    """Return a callable that launches this kernel over ``grid``.

    ``grid`` is indexed at positions 0, 1 and 2 when the returned runner
    is invoked.
    """
    # Ensure the launcher and binary handles exist before building the runner.
    self._init_handles()

    def runner(*args, stream=None):
        # Fall back to the active driver's current stream when none is given.
        if stream is None:
            device = driver.active.get_current_device()
            stream = driver.active.get_current_stream(device)
        md = self.metadata
        # Forward grid dimensions, launch metadata and hooks to the launcher
        # created in _init_handles, followed by the user's kernel arguments.
        self.run(grid[0], grid[1], grid[2], md.num_warps, md.num_ctas, md.cluster_dims[0], md.cluster_dims[1],
                    md.cluster_dims[2], md.shared, stream, self.function, CompiledKernel.launch_enter_hook,
                    CompiledKernel.launch_exit_hook, md, *args)

    return runner

先执行一次self._init_handles(),代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
def _init_handles(self):
    """Create the launcher and load the kernel binary, once per instance.

    Returns immediately when ``self.module`` is already set.

    Raises:
        OutOfResources: if ``self.metadata.shared`` exceeds the device's
            reported ``max_shared_mem``.
    """
    # Already initialised: nothing to do.
    if self.module is not None:
        return
    device = driver.active.get_current_device()
    # create launcher
    self.run = driver.active.launcher_cls(self.src, self.metadata)
    # not enough shared memory to run the kernel
    max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
    if self.metadata.shared > max_shared:
        raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
    # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
    self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
        self.name, self.kernel, self.metadata.shared, device)

这里通过driver的launcher_cls创建了一个启动器(launcher);如果是CPU后端的driver,launcher_cls指向的就是CPULauncher:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class CPULauncher(object):
    """Callable wrapper around the launch function built by ``compile_module``."""

    def __init__(self, src, metadata):
        """Generate the launcher source for ``src`` and build its launch function.

        ``src.constants`` is used when present, otherwise an empty dict.
        ``metadata`` is accepted but not read in this constructor.
        """
        constants = src.constants if hasattr(src, "constants") else dict()

        kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER"
        launcher_src = _generate_launcher(constants, src.signature, kernel_placeholder_name)
        # Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name
        # in the following launch function.
        self.launch = compile_module(launcher_src, kernel_placeholder_name)

    def __call__(self, *args, **kwargs):
        """Forward all arguments to the stored launch function; returns None."""
        self.launch(*args, **kwargs)

class CPUDriver(DriverBase):
    """Driver for the CPU backend: wires up its utilities, launcher class and binary extension."""

    def __init__(self):
        """Initialise the base driver, then set the CPU-specific attributes."""
        super().__init__()
        self.utils = CPUUtils()
        # Class (not instance) used to build a launcher for each kernel.
        self.launcher_cls = CPULauncher
        # Matches the name of the final compilation stage ("cpuasm").
        self.binary_ext = "cpuasm"

其中_generate_launcher()生成了一段C代码的主程序,在grid维度下运行kernel代码,并封装成可供python调用的外部C模块:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
...
extern "C" {{
  // Pointer type (=Memref) becomes int64_t + MemRef struct
  // FIXME: understand what this int64_t is used for.
  void {kernel_name}({', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants)},
                       int, int, int, int, int, int);
}}

static void _launch(int gridX, int gridY, int gridZ, {arg_decls}) {{
  if (gridX*gridY*gridZ > 0) {{
    // Cast "function" to the real function type.
    for(int x = 0; x < gridX; x++) {{
      for(int y = 0; y < gridY; y++) {{
        for(int z = 0; z < gridZ; z++) {{
          // Use some random type "char" here.
          {' '.join(f'StridedMemRefType<char, 0> ptr_arg{i} = {{static_cast<char *>(arg{i}), static_cast<char *>(arg{i}), 0}};' for i, ty in signature.items() if i not in constants and ty[0] == "*")}
          {kernel_name}({', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if i not in constants)},
                        gridX, gridY, gridZ, x, y, z);
        }}
      }}
    }}
  }}
}}
...

_generate_launcher()执行完成后,执行compile_module(),返回一个launch()方法供后续调用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def compile_module(launcher_src, kernel_placeholder_name):
    """Build and return a ``launch`` function for the given launcher source.

    Args:
        launcher_src: launcher source text containing ``kernel_placeholder_name``.
        kernel_placeholder_name: placeholder string replaced by the kernel
            name (``metadata.name``) at launch time.

    Returns:
        A ``launch`` closure that, when called, compiles the launcher together
        with the kernel assembly with ``g++`` on a cache miss, loads the
        resulting shared object as a Python module, and calls its ``launch``.
    """
    # This function was renamed and made public in Python 3.10
    if hasattr(sysconfig, 'get_default_scheme'):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == 'posix_local':
        scheme = 'posix_prefix'
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
    # Headers shipped next to this file, in its "include" directory.
    cpu_backend_path = Path(__file__).resolve().parent
    include_dir = os.path.join(cpu_backend_path, "include")

    def launch(
        gridX, gridY, gridZ, num_warps, num_ctas, clusterDim0, clusterDim1, clusterDim2,
        shared, stream, cu_function, launch_enter_hook, launch_exit_hook, metadata,
        *args):
        # Unlike CUDA/HIP, we cannot easily pass function pointer across different pybind libraries.
        # Let's compile a kernel every time.
        # The cu_function parameter actually contains our assembly source code.
        # See CPUUtils.load_binary method.
        asm_src = cu_function
        src = launcher_src.replace(kernel_placeholder_name, metadata.name)

        # The cache key is the MD5 of the launcher source after substitution.
        # NOTE(review): asm_src is not part of the key here; presumably a kernel
        # whose body changes while its name and signature stay the same would
        # hit the same cache entry -- confirm against get_cache_manager.
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = get_cache_manager(key)
        name = "__triton_shared_ref_cpu_kernel_launcher"
        filename = f"{name}.so"
        cache_path = cache.get_file(filename)

        # Cache miss: build the shared object in a temporary directory.
        if cache_path is None:
          with tempfile.TemporaryDirectory() as tmpdir:
              asm_src_path = os.path.join(tmpdir, "kernel.s")
              launcher_src_path = os.path.join(tmpdir, "main.cxx")
              so_path = os.path.join(tmpdir, "kernel.so")
              # asm_src is written as bytes, the launcher source as text.
              Path(asm_src_path).write_bytes(asm_src)
              Path(launcher_src_path).write_text(src)
              # Compile it together.
              subprocess.check_call([
                "g++", launcher_src_path, asm_src_path,
                f"-I{py_include_dir}", f"-I{include_dir}",
                "-shared", "-fPIC", "-o", so_path
              ])

              # Store the built library in the cache before the temp dir is removed.
              with open(so_path, "rb") as f:
                cache_path = cache.put(f.read(), filename, binary=True)

        # Load and launch the compiled kernel.
        spec = importlib.util.spec_from_file_location(name, cache_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        # Only the grid, hooks, metadata and kernel args are forwarded; the
        # num_warps/num_ctas/clusterDim*/shared/stream parameters are unused here.
        return mod.launch(gridX, gridY, gridZ, launch_enter_hook, launch_exit_hook, metadata, *args)

    return launch

launch方法接受类似原生triton compiler的run方法的参数,如grid维度等,其中cu_function是核函数的lower ir,描述为asm汇编代码,总共编译两个文件,一个是刚刚生成的main.cxx代码,占位符被配置的核函数名替代;另一个是核函数kernel.s的asm代码。
compile_module()返回的launch闭包被赋给CPULauncher的self.launch,调用CompiledKernel时调用栈最终会走到该方法;编译完成后的C模块在该闭包内被加载到python,并调用模块自身的launch方法。每次触发上述launch()时,会根据替换核函数名后的启动器源码计算md5,判断缓存中是否已有对应的编译产物,若没有则要进行一次g++编译。注释中也说明了原因:因为不能像CUDA/HIP那样在不同的模块间传递函数指针,所以把核函数的asm代码加进来混合编译。

This post is licensed under CC BY 4.0 by the author.