from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from enum import Enum
from types import ModuleType
from typing import Dict, Union

# Divisibility advertised by the "D" specialization tag. `parse_attr` and
# `get_arg_specialization` must agree on this value, so it lives in one place.
_DIVISIBILITY = 16


@dataclass(frozen=True)
class GPUTarget(object):
    """Immutable description of the device a kernel is compiled for."""

    # Target backend, e.g., cuda, hip
    backend: str
    # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip)
    arch: Union[int, str]
    warp_size: int


class Language(Enum):
    """The input language being compiled by the backend."""

    TRITON = 0
    GLUON = 1


class BaseBackend(metaclass=ABCMeta):
    """Abstract interface that every compilation backend implements."""

    def __init__(self, target: GPUTarget) -> None:
        """Store `target` and check that this backend supports it."""
        self.target = target
        # Kept as an `assert` (not an explicit raise) so the exception type and
        # the behavior under `python -O` stay exactly as before.
        assert self.supports_target(target), f"unsupported target: {target!r}"

    @staticmethod
    @abstractmethod
    def supports_target(target: GPUTarget) -> bool:
        """Return whether this backend can compile for `target`."""
        raise NotImplementedError

    @abstractmethod
    def hash(self) -> str:
        """Returns a unique identifier for this backend"""
        raise NotImplementedError

    @abstractmethod
    def parse_options(self, options: dict) -> object:
        """
        Converts an `options` dictionary into an arbitrary object and returns it.
        This function may contain target-specific heuristics and check the legality of the provided options
        """
        raise NotImplementedError

    @abstractmethod
    def add_stages(self, stages: dict, options: object) -> None:
        """
        Populates `stages` dictionary with entries of the form:
        ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
        The value of each entry may populate a `metadata` dictionary.
        Stages will be run sequentially (in insertion order) and can communicate using `metadata`.
        All stages are expected to return a `str` object, except for the last stage
        which returns a `bytes` object for execution by the launcher.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dialects(self, context):
        """
        Load additional MLIR dialects into the provided `context`
        """
        raise NotImplementedError

    @abstractmethod
    def get_module_map(self) -> Dict[str, ModuleType]:
        """
        Return a map of interface modules to their device-specific implementations
        """
        raise NotImplementedError

    @staticmethod
    def parse_attr(desc):
        """
        Translate a specialization string into a list of `[name, value]` attributes.

        A `desc` containing "D" yields `[["tt.divisibility", 16]]`; any other
        string yields an empty list. A fresh list is returned on every call.
        """
        assert isinstance(desc, str)
        if "D" in desc:
            return [["tt.divisibility", _DIVISIBILITY]]
        return []

    @staticmethod
    def get_arg_specialization(arg, ty, **kwargs):
        """
        Return a string unique to each possible specialization of the argument

        Returns "D" when the `align` keyword is truthy and either `ty` is "int"
        with `arg` a multiple of 16, or `ty` is "tensor" with `arg.data_ptr()`
        a multiple of 16. Returns "" in every other case.
        """
        # Consult the flag first: when alignment specialization is off there is
        # no reason to touch `arg` at all (and `data_ptr()` is never called).
        if not kwargs.get("align", False):
            return ""
        if ty == "int" and arg % _DIVISIBILITY == 0:
            return "D"
        if ty == "tensor" and arg.data_ptr() % _DIVISIBILITY == 0:
            return "D"
        return ""