from __future__ import annotations

import hashlib
import json
from .._C.libtriton import get_cache_invalidating_env_vars, ir
from ..backends import backends
from ..backends.compiler import Language
from ..backends.compiler import BaseBackend, GPUTarget
from .. import __version__, knobs
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key
from ..runtime.driver import driver
from ..tools.disasm import get_sass
from pathlib import Path
import re
import functools
import os
import time
import copy

# - \.(?:visible|extern)\s+ : match a leading .visible or .extern directive and any
#   following whitespace
# - \.(?:entry|func)\s+ : match the .entry or .func directive introducing the kernel,
#   and any following whitespace
# - (\w+) : match the function name (one or more letters, digits, or underscores)
#   and capture it as group 1
# - \s*\(([^)]*)\) : match a pair of parentheses and capture the enclosed
#   argument list as group 2
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
    "ptx": ptx_prototype_pattern,
}

ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
    "ptx": ptx_arg_type_pattern,
}


def convert_type_repr(x):
    # Currently we only capture the pointer type and assume the pointer is on global memory.
    # TODO: Capture and support shared memory space
    match = re.search(r'!tt\.ptr<([^,]+)', x)
    tma = re.search(r'tt.nv_tma_desc = 1', x)
    if tma is not None:
        return 'nvTmaDesc'
    x = re.sub(r' {[^}]+}', '', x)
    if match is not None:
        return '*' + convert_type_repr(match.group(1))
    return x


class ASTSource:

    def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
        self.fn = fn
        self.language = Language.TRITON
        self.ext = "ttir"
        self.name = fn.__name__
        self.signature = signature
        self.constants = dict()
        if constexprs is not None:
            for k, v in constexprs.items():
                k = (fn.arg_names.index(k), ) if isinstance(k, str) else k
                assert isinstance(k, tuple)
                self.constants[k] = v
        self.attrs = attrs or dict()
        for k in self.signature.keys():
            if not isinstance(k, str):
                raise TypeError("Signature keys must be strings")

    def hash(self):
        sorted_sig = [v for k, v in sorted(self.signature.items())]
        get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x)
        constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())])
        key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
        from .code_generator import ast_to_ttir
        return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                           module_map=module_map)

    def parse_options(self):
        return dict()
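
# Example (illustrative only; kernel, names, and types are hypothetical): for a kernel
#     @triton.jit
#     def add(x_ptr, y_ptr, n, BLOCK: tl.constexpr): ...
# the JIT machinery would build something like
#     ASTSource(add, signature={"x_ptr": "*fp32", "y_ptr": "*fp32", "n": "i32"},
#               constexprs={"BLOCK": 128})
# String constexpr keys are normalized to argument-index tuples, so "BLOCK"
# would be stored as (3,) in self.constants.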


class IRSource:

    def __init__(self, path, context, backend):
        self.path = path
        path = Path(path)
        self.ext = path.suffix[1:]
        self.language = Language.TRITON
        self.src = path.read_text()
        ir.load_dialects(context)
        backend.load_dialects(context)

        # We don't have an easy-to-use PTX parser that we can use, so keep that regex for now.
        # TODO - replace with a proper parser
        if self.ext == "ptx":
            match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
            self.name = match.group(1)
            signature = match.group(2)
            types = re.findall(arg_type_pattern[self.ext], signature)
            self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
        else:
            self.module = ir.parse_mlir_module(self.path, context)
            fn_name = self.module.get_entry_func_name()
            self.name = "@" + fn_name
            funcOp = self.module.get_function(fn_name)
            func_ty = self.module.get_function_signature(funcOp)
            self.signature = {k: ty for k, ty in enumerate(func_ty)}

    def hash(self):
        return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

    def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
        self.module.context = context
        return self.module

    def parse_options(self):
        if self.ext == "ttgir":
            num_warps = self.module.get_int_attr("ttg.num-warps")
            assert num_warps is not None, "Unable to parse ttg.num-warps attribute"
            return {'num_warps': num_warps}
        return dict()
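
# Example (illustrative; the file name is hypothetical): IRSource lets compile()
# start from an IR file rather than a Python function, e.g.
#     src = IRSource("add_kernel.ttgir", context, backend)
# which parses the module to recover src.name and src.signature, and feeds the
# ttg.num-warps attribute back into the backend options via parse_options().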


@functools.lru_cache()
def max_shared_mem(device):
    return driver.active.utils.get_device_properties(device)["max_shared_mem"]


def parse(full_name, ext, context):
    if ext == "ttir" or ext == "ttgir":
        module = ir.parse_mlir_module(full_name, context)
        module.context = context
        return module
    if ext == "llir" or ext == "ptx" or ext == "amdgcn":
        return Path(full_name).read_text()
    if ext == "cubin" or ext == "hsaco":
        return Path(full_name).read_bytes()


def filter_traceback(e: BaseException):
    """
    Removes code_generator.py and related files from tracebacks.

    These are uninteresting to the user -- "just show me *my* code!"
    """
    if knobs.compilation.front_end_debugging:
        return

    if e.__cause__ is not None:
        filter_traceback(e.__cause__)
    if e.__context__ is not None:
        filter_traceback(e.__context__)

    # If a user has a file that matches one of these, they're out of luck.
    BAD_FILES = [
        "/triton/compiler/code_generator.py",
        "/ast.py",
    ]
    BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES]

    tb = e.__traceback__
    frames = []
    while tb is not None:
        if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)):
            frames.append(tb)
        tb = tb.tb_next

    for (cur_frame, next_frame) in zip(frames, frames[1:]):
        cur_frame.tb_next = next_frame

    if not frames:
        e.__traceback__ = None
    else:
        frames[-1].tb_next = None
        e.__traceback__ = frames[0]


class CompileTimer:

    def __init__(self) -> None:
        self.start: float = time.time()
        self.ir_initialization_end: float | None = None
        self.lowering_stage_ends: list[tuple[str, float]] = []
        self.store_results_end: float | None = None

    def finished_ir_initialization(self) -> None:
        self.ir_initialization_end = time.time()

    def stage_finished(self, stage_name: str) -> None:
        self.lowering_stage_ends.append((stage_name, time.time()))

    def end(self) -> knobs.CompileTimes:
        timestamp = time.time()
        if self.ir_initialization_end is None:
            self.ir_initialization_end = timestamp
        else:
            self.store_results_end = timestamp

        def delta(start: float, end: float | None) -> int:
            if end is None:
                return 0
            return int((end - start) * 1000000)

        lowering_stage_durations = []
        stage_start = self.ir_initialization_end
        for stage_name, stage_end in self.lowering_stage_ends:
            lowering_stage_durations.append((stage_name, delta(stage_start, stage_end)))
            stage_start = stage_end

        return knobs.CompileTimes(
            ir_initialization=delta(self.start, self.ir_initialization_end),
            lowering_stages=lowering_stage_durations,
            store_results=delta(stage_start, self.store_results_end),
        )
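
# Example (illustrative): compile() accepts either frontend, e.g.
#     compile(ASTSource(fn, signature, constexprs))  # from the @triton.jit machinery
#     compile("add_kernel.ttgir")                    # from an IR file on disk
# Both paths share the caching, metadata, and lowering-stage logic below.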


def compile(src, target=None, options=None, _env_vars=None):
    compilation_listener = knobs.compilation.listener
    if compilation_listener:
        timer = CompileTimer()

    if target is None:
        target = driver.active.get_current_target()
    assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
    backend = make_backend(target)
    ir_source = not isinstance(src, ASTSource)
    # wrap a file path into an IRSource
    if ir_source:
        assert isinstance(src, str), "source must be either AST or a filepath"
        context = ir.context()
        src = IRSource(src, context, backend)

    extra_options = src.parse_options()
    options = backend.parse_options(dict(options or dict(), **extra_options))
    # create cache manager
    env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars
    key = get_cache_key(src, backend, options, env_vars=env_vars)
    hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
    fn_cache_manager = get_cache_manager(hash)
    # For dumping/overriding, only hash the source: we want it to be independent of triton
    # core changes so that kernels are easier to track by hash.
    enable_override = knobs.compilation.override
    enable_ir_dump = knobs.compilation.dump_ir
    store_only_binary = knobs.compilation.store_binary_only
    fn_override_manager = get_override_manager(src.hash()) if enable_override else None
    fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
    # Pre-truncate the file name here to avoid hitting the 255-character limit on common platforms.
    # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}".
    # A PID string can be 5 characters long and a UUID string is typically 36 characters, so truncate
    # the file name to 150 characters to be safe.
    file_name = src.name[:150]
    metadata_filename = f"{file_name}.json"
    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
    metadata_path = metadata_group.get(metadata_filename)
    always_compile = knobs.compilation.always_compile
    if not always_compile and metadata_path is not None:
        # cache hit!
        res = CompiledKernel(src, metadata_group, hash)
        if compilation_listener:
            compilation_listener(
                src=src,
                metadata=res.metadata._asdict(),
                metadata_group=metadata_group,
                times=timer.end(),
                cache_hit=True,
            )
        return res

    # initialize metadata
    metadata = {
        "hash": hash,
        "target": target,
        **options.__dict__,
        **env_vars,
    }
    metadata["triton_version"] = __version__

    # run compilation pipeline and populate metadata
    stages = dict()
    backend.add_stages(stages, options, src.language)
    first_stage = list(stages.keys()).index(src.ext)
    # When the source is an IR file, don't apply the passes related to this stage.
    # This makes it easier to write IR-level tests.
    if ir_source:
        first_stage += 1

    # For IRSource, we have already grabbed the context and called both
    # ir.load_dialects and backend.load_dialects.
    if not isinstance(src, IRSource):
        context = ir.context()
        ir.load_dialects(context)
        backend.load_dialects(context)

    codegen_fns = backend.get_codegen_implementation(options)
    module_map = backend.get_module_map()
    try:
        module = src.make_ir(target, options, codegen_fns, module_map, context)
    except Exception as e:
        filter_traceback(e)
        raise
    if ir_source:
        ir_filename = f"{file_name}.{src.ext}"
        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
    else:
        ir_filename = f"{file_name}.source"
        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)

    use_ir_loc = knobs.compilation.use_ir_loc
    if ir_source and use_ir_loc:
        module.create_location_snapshot(src.path)
        print(f"Creating new locations for {src.path}")
    if compilation_listener:
        timer.finished_ir_initialization()
    for ext, compile_ir in list(stages.items())[first_stage:]:
        next_module = compile_ir(module, metadata)
        ir_filename = f"{file_name}.{ext}"
        if fn_override_manager is None:
            # Users can override kernels at scale by setting `ir_override` in the autotune config
            # without TRITON_KERNEL_OVERRIDE
            if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
                next_module = parse(ir_override, ext, context)
        elif full_name := fn_override_manager.get_file(ir_filename):
            print(f"\nOverriding kernel with file {full_name}")
            next_module = parse(full_name, ext, context)
        # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
        if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")):
            metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
        if fn_dump_manager is not None:
            fn_dump_manager.put(next_module, ir_filename)
            if ext == "cubin":
                sass = get_sass(next_module)
                fn_dump_manager.put(sass, file_name + ".sass")
        # use an env variable to parse ir from file
        if use_ir_loc == ext:
            ir_full_name = fn_cache_manager.get_file(ir_filename)
            next_module.create_location_snapshot(ir_full_name)
            print(f"Creating new locations for {ir_full_name}")
        module = next_module
        if compilation_listener:
            timer.stage_finished(ext)
    # write-back metadata
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                             binary=False)
    fn_cache_manager.put_group(metadata_filename, metadata_group)
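    # Illustrative cache-group layout at this point (the kernel name is hypothetical,
    # and AMD backends produce amdgcn/hsaco instead of ptx/cubin):
    #     {"add_kernel.json": ..., "add_kernel.source": ..., "add_kernel.ttir": ...,
    #      "add_kernel.ttgir": ..., "add_kernel.llir": ..., "add_kernel.ptx": ...,
    #      "add_kernel.cubin": ...}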
    # Compilation completed, disabling multithreading in context.
    # This is needed to safely finalize the thread pool inside the context: if the current
    # process forks before the Python GC deletes the context object, the thread pool in the
    # child process will be invalid, which could lead to a crash or hang in the child.
    #
    # However, disabling multithreading causes the code to hang if the ASAN pass is enabled,
    # likely because llvm-symbolizer forks a process.
    # TODO: Reconcile the difference here between the ASAN and non-ASAN paths with enabling
    # multithreading in the MLIR context
    if not knobs.compilation.enable_asan:
        context.disable_multithreading()
    # notify any listener
    if compilation_listener:
        compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
                             cache_hit=False)
    # return handle to compiled kernel
    return CompiledKernel(src, metadata_group, hash)


def make_backend(target: GPUTarget) -> BaseBackend:
    actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
    if len(actives) != 1:
        raise RuntimeError(
            f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.")
    return actives[0](target)


class LazyDict:

    def __init__(self, data):
        self.data = data
        self.extras = []

    def get(self):
        for func, args in self.extras:
            self.data = self.data | func(*args)
        self.extras.clear()
        return self.data

    def add(self, func, args):
        self.extras.append((func, args))


class AsmDict(dict):

    def __missing__(self, key):
        if key == "sass":
            value = get_sass(self["cubin"])
        else:
            raise KeyError("Unknown key: '%s'" % key)

        self[key] = value
        return value


def _raise_error(err, *args, **kwargs):
    raise copy.deepcopy(err)


class CompiledKernel:

    def __init__(self, src, metadata_group, hash):
        from collections import namedtuple
        metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
        metadata = json.loads(metadata_path.read_text())
        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
        # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
        target = metadata['target']
        metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
        KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
        self.metadata = KernelMetadata(**metadata)
        backend = make_backend(self.metadata.target)
        self.packed_metadata = backend.pack_metadata(self.metadata)
        self.src = src
        self.hash = hash
        self.name = self.metadata.name
        # stores the text of each level of IR that was generated during compilation
        asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
        binary_ext = backend.binary_ext
        self.asm = AsmDict({
            file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
            for file in asm_files
        })
        self.metadata_group = metadata_group
        self.kernel = self.asm[binary_ext]
        # binaries are lazily initialized, because initialization involves runtime
        # queries (e.g., the amount of shared memory on the current device)
        self.module = None
        self.function = None
        self._run = None
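
    # Example (illustrative): on an NVIDIA target, self.asm typically exposes
    #     kernel.asm["ttir"], kernel.asm["ttgir"], kernel.asm["llir"],
    #     kernel.asm["ptx"], kernel.asm["cubin"]
    # and AsmDict lazily disassembles kernel.asm["sass"] from the cubin on
    # first access.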
    def _init_handles(self):
        if self.module is not None:
            return

        def raise_(err):
            # Clone the exception object so that the one saved in the closure
            # of the partial function below doesn't get assigned a stack trace
            # after the subsequent raise. Otherwise, the CompiledKernel instance
            # saved in the (global) kernel cache would keep references to all the
            # locals in the traceback via the exception instance in the closure.
            cloned_err = copy.deepcopy(err)
            self._run = functools.partial(_raise_error, cloned_err)
            raise err

        device = driver.active.get_current_device()
        # create launcher
        self._run = driver.active.launcher_cls(self.src, self.metadata)
        # not enough shared memory to run the kernel
        max_shared = max_shared_mem(device)
        if self.metadata.shared > max_shared:
            raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
        if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
            # Use Blackwell's max tmem size for now; this should be moved into device properties.
            max_tmem_size = 512  # tmem size in number of columns
            if self.metadata.tmem_size > max_tmem_size:
                raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory"))
        if knobs.runtime.kernel_load_start_hook is not None:
            knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group,
                                                 self.hash)
        # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
        self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
            self.name, self.kernel, self.metadata.shared, device)
        warp_size = driver.active.get_current_target().warp_size
        if self.metadata.num_warps * warp_size > self.n_max_threads:
            raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
        if knobs.runtime.kernel_load_end_hook is not None:
            knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)

    @property
    def run(self):
        if self._run is None:
            self._init_handles()
        return self._run

    def launch_metadata(self, grid, stream, *args):
        if knobs.runtime.launch_enter_hook is None:
            return None
        self._init_handles()
        ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
        if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
            return ret
        arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
        ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
        return ret

    def __getitem__(self, grid):
        self._init_handles()

        def runner(*args, stream=None):
            if stream is None:
                device = driver.active.get_current_device()
                stream = driver.active.get_current_stream(device)
            launch_metadata = self.launch_metadata(grid, stream, *args)
            self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
                     knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)

        return runner
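
# Example (illustrative; arguments are hypothetical): a CompiledKernel can be
# launched directly. The grid must be a 3-tuple; arguments follow the kernel
# signature.
#     kernel = compile(ASTSource(fn, signature, constexprs))
#     kernel[(1024, 1, 1)](x_ptr, y_ptr, n)
# The first launch triggers _init_handles(), which loads the binary and checks
# the shared-memory, tensor-memory, and thread-count limits.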