from __future__ import annotations import functools import operator from typing import Any, TYPE_CHECKING import torch # NOTE: other files rely on the imports below from torch._dynamo import callback as compilation_callback # noqa: F401 from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401 cache_dir, default_cache_dir, triton_cache_dir, ) if TYPE_CHECKING: from collections.abc import Hashable from .triton_compat import Config def conditional_product(*args: int) -> int: return functools.reduce(operator.mul, [x for x in args if x]) def ceildiv(number: int, denom: int) -> int: return -(number // -denom) def is_power_of_2(n: int) -> bool: """Returns whether n = 2 ** m for some integer m.""" return n > 0 and n & n - 1 == 0 def next_power_of_2(n: int) -> int: """Return the smallest power of 2 greater than or equal to n""" n -= 1 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 n |= n >> 32 n += 1 return n def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: """ Return the total number of bytes the arguments of tensor type takes. For in/out args, tensor sizes are counted twice: once for reading and once for writing. The first num_in_out_args arguments are in out tensors. """ return sum( arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args)) for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) ) def triton_config_to_hashable(cfg: Config) -> Hashable: """ Convert triton config to a tuple that can uniquely identify it. We can use the return value as a dictionary key. """ items = sorted(cfg.kwargs.items()) items.append(("num_warps", cfg.num_warps)) items.append(("num_stages", cfg.num_stages)) return tuple(items) def validate_triton_config(cfg: Config) -> None: # [Note: Triton pre_hook in inductor] # pre-hook is a lambda function, which we don't attempt to serialize. # right now, if a pre-hook is attached to the config, it will not be saved; # and then it won't be used when the config is loaded from cache. # So we assert - if we do get a pre_hook, it might get ignored after caching. assert getattr(cfg, "pre_hook", None) is None, ( "triton configs with pre_hooks not supported" ) def create_bandwidth_info_str( ms: float, num_gb: float, gb_per_s: float, prefix: str = "", suffix: str = "", color: bool = True, ) -> str: info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" slow = ms > 0.012 and gb_per_s < 650 return red_text(info_str) if color and slow else info_str def get_max_y_grid() -> int: return 65535 try: import colorama HAS_COLORAMA = True except ModuleNotFoundError: HAS_COLORAMA = False colorama = None # type: ignore[assignment] if HAS_COLORAMA: def _color_text(msg: str, color: str) -> str: return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET else: def _color_text(msg: str, color: str) -> str: return msg def green_text(msg: str) -> str: return _color_text(msg, "green") def yellow_text(msg: str) -> str: return _color_text(msg, "yellow") def red_text(msg: str) -> str: return _color_text(msg, "red") def blue_text(msg: str) -> str: return _color_text(msg, "blue") def get_first_attr(obj: Any, *attrs: str) -> Any: """ Return the first available attribute or throw an exception if none is present. """ for attr in attrs: if hasattr(obj, attr): return getattr(obj, attr) raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") dynamo_timed = torch._dynamo.utils.dynamo_timed # type: ignore[has-type] def triton_hash_to_path_key(key: str) -> str: # In early versions of Triton, the hash is directly used in the path name. # Later, the hash is converted to base64 before being used in the path name. # Later, the base64 conversion was replaced to the base32 # # This code tries to import _base64 and falls back to _base32 if _base64 is unavailable. # # To handle this, try to import the to-base64-conversion function. # If it exists, use it; otherwise, try using _base32; if both are unavailable, use the hash directly. try: from triton.runtime.cache import _base64 return _base64(key) except Exception: try: from triton.runtime.cache import _base32 return _base32(key) except Exception: return key def compile_mps_shader(source: str) -> Any: """ Compiles shader source but raise more actionable error message when needed """ try: return torch.mps.compile_shader(source) except SyntaxError as err: raise SyntaxError(f"failed to compile {source} with {err.msg}") from err