from __future__ import annotations from triton.compiler.compiler import ASTSource from triton.backends.compiler import Language from triton.runtime.jit import JITFunction, constexpr_function from typing import TypeVar, Optional, Callable, Iterable, Union from triton._C.libtriton import ir T = TypeVar("T") __all__ = ["constexpr_function", "jit"] class GluonASTSource(ASTSource): def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: super().__init__(fn, signature, constexprs, attrs) self.language = Language.GLUON self.ext = "ttgir" def make_ir(self, target, options, codegen_fns, module_map, context): from triton.compiler.compiler import make_backend from triton.compiler.code_generator import ast_to_ttir builder = ir.builder(context) module = builder.create_module() # Assign module attributes eagerly, as they are needed to verify layouts backend = make_backend(target) target = backend.get_target_name(options) module.set_attr("ttg.target", builder.get_string_attr(target)) module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps)) module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas)) module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size)) is_cuda = options.backend_name == "cuda" if is_cuda and options.maxnreg is not None: module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg)) module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, module_map=module_map, module=module) return module class GluonJITFunction(JITFunction[T]): def create_binder(self): result = super().create_binder() self.ASTSource = GluonASTSource return result def is_gluon(self): return True def jit( fn: Optional[T] = None, *, version=None, repr: Optional[Callable] = None, launch_metadata: Optional[Callable] = None, do_not_specialize: Optional[Iterable[int | str]] = None, do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, debug: Optional[bool] = None, noinline: Optional[bool] = None, ) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]: """ Decorator for JIT-compiling a function using the Triton compiler. :note: When a jit'd function is called, arguments are implicitly converted to pointers if they have a :code:`.data_ptr()` method and a `.dtype` attribute. :note: This function will be compiled and run on the GPU. It will only have access to: * python primitives, * builtins within the triton package, * arguments to this function, * other jit'd functions :param fn: the function to be jit-compiled :type fn: Callable """ def decorator(fn: T) -> JITFunction[T]: assert callable(fn) return GluonJITFunction( fn, version=version, do_not_specialize=do_not_specialize, do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, noinline=noinline, repr=repr, launch_metadata=launch_metadata, ) if fn is not None: return decorator(fn) else: return decorator