from abc import ABCMeta, abstractmethod
from typing import Callable, List, Protocol, Sequence


class Benchmarker(Protocol):
    """Structural type for a benchmarking callable.

    Any callable that accepts a kernel invocation, a keyword-only list of
    quantiles and arbitrary extra keyword arguments, and returns a sequence
    of floats satisfies this protocol.
    """

    def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
        """
        Benchmark ``kernel_call``.

        Args:
            kernel_call (Callable): The callable to benchmark.
            quantiles (List[float]): Keyword-only list of quantiles.
            **kwargs: Additional implementation-specific options.

        Returns:
            Sequence[float]: The measurements (presumably one value per
            requested quantile -- confirm against the implementations).
        """
        ...


class DriverBase(metaclass=ABCMeta):
    """Abstract interface that every backend driver must implement.

    Concrete drivers have to override all abstract methods below; until they
    do, the class cannot be instantiated.
    """

    @classmethod
    @abstractmethod
    def is_active(cls):
        """
        Class-level query that concrete drivers must implement.

        NOTE(review): presumably reports whether this backend is usable in
        the current environment -- confirm against the implementations.
        """

    @abstractmethod
    def map_python_to_cpp_type(self, ty: str) -> str:
        """
        Converts a Triton type string to its corresponding C++ type string for this backend.

        Args:
            ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.

        Returns:
            str: The C++ type string.
        """

    @abstractmethod
    def get_current_target(self):
        """
        Return the current target for this backend.
        """

    @abstractmethod
    def get_active_torch_device(self):
        """
        Return the active torch device for this backend.
        """

    @abstractmethod
    def get_benchmarker(self) -> Benchmarker:
        """
        Return the benchmarking function that this backend should use by default.
        """
        raise NotImplementedError

    def __init__(self) -> None:
        """Initialize the driver; the base class holds no state."""


class GPUDriver(DriverBase):
    """Driver base for GPU backends that binds torch's CUDA helpers.

    The abstract methods of ``DriverBase`` are not implemented here, so this
    class is still abstract and must be subclassed.
    """

    def __init__(self):
        """
        Bind the torch CUDA device/stream helpers as instance attributes.

        The following attributes are set:
            get_device_capability: ``torch.cuda.get_device_capability``
            get_current_stream: ``torch._C._cuda_getCurrentRawStream`` when it
                can be imported, otherwise a fallback built on
                ``torch.cuda.current_stream``.
            get_current_device: ``torch.cuda.current_device``
            set_current_device: ``torch.cuda.set_device``
        """
        super().__init__()
        # TODO: support other frameworks than torch
        # torch is imported here rather than at module level, so importing
        # this module does not itself import torch.
        import torch

        self.get_device_capability = torch.cuda.get_device_capability
        try:
            from torch._C import _cuda_getCurrentRawStream
        except ImportError:
            # The private helper is unavailable: read the raw stream handle
            # from the public stream object instead.
            self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
        else:
            self.get_current_stream = _cuda_getCurrentRawStream
        self.get_current_device = torch.cuda.current_device
        self.set_current_device = torch.cuda.set_device

    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        """
        Return ``args`` unchanged; ``tensormaps_info`` is ignored.
        """
        return args