""" Template heuristic registry system for PyTorch Inductor. This module provides a centralized registration system for template heuristics, allowing automatic registration based on device type and conditional registration for CUDA vs ROCm based on torch.version.hip. """ from __future__ import annotations import contextlib import logging from typing import Any, Optional, TYPE_CHECKING, Union from .base import TemplateConfigHeuristics if TYPE_CHECKING: from collections.abc import Iterator # Module-wide registry for template heuristics _TEMPLATE_HEURISTIC_REGISTRY: dict[ tuple[Union[str, None], ...], type[TemplateConfigHeuristics] ] = {} # Manual cache for successful lookups only (fallback instances are not cached) _HEURISTIC_CACHE: dict[tuple[str, str, str], TemplateConfigHeuristics] = {} log = logging.getLogger(__name__) def register_template_heuristic( template_name: str, device_type: Union[str, None], register: bool = True, op_name: Optional[str] = None, ) -> Any: """ Decorator to register template heuristic classes. Args: template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm") device_type: Device type ("cuda", "cpu", "xpu") Set this to None to indicate that the heuristic is applicable to all device types. register: Whether to register this heuristic. Caller should pass the condition directly. op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm"). This is optional and is only used when a template uses different heuristics for different ops Returns: Decorator function that registers the class if conditions are met. Example: @register_template_heuristic("mm", "cuda", register=torch.version.hip is None) class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): pass """ def decorator( cls: type[TemplateConfigHeuristics], ) -> type[TemplateConfigHeuristics]: if register: key: tuple[Union[str, None], ...] = (template_name, device_type, op_name) _TEMPLATE_HEURISTIC_REGISTRY[key] = cls log.info( f"Registered template heuristic: {cls.__name__} for '{template_name=}', '{device_type=}', '{op_name=}'" # noqa: G004 ) return cls return decorator def get_template_heuristic( template_name: str, device_type: str, op_name: str ) -> TemplateConfigHeuristics: """ Retrieve a template heuristic instance for the given template and device type. Args: template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm") device_type: Device type ("cuda", "cpu", "xpu") op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm") Returns: Template heuristic instance. If no specific heuristic is found, returns a fallback TemplateConfigHeuristics() instance (uncached). """ # Check cache first cache_key = (template_name, device_type, op_name) if cache_key in _HEURISTIC_CACHE: return _HEURISTIC_CACHE[cache_key] keys = [ # everything is specified (template_name, device_type, op_name), # heuristic is valid across all devices (template_name, None, op_name), # heuristic is valid across all ops for that device (template_name, device_type, None), # heuristic is always valid for that template (template_name, None, None), ] # Look up in registry heuristic_class = None for key in keys: if key in _TEMPLATE_HEURISTIC_REGISTRY: heuristic_class = _TEMPLATE_HEURISTIC_REGISTRY[key] break if heuristic_class is None: # Log error and return fallback instance (uncached) log.error( "No template heuristic found - template_name=%s, device_type=%s, op_name=%s. " "Available combinations: %s. 
Using fallback TemplateConfigHeuristics instance.", template_name, device_type, op_name, list(_TEMPLATE_HEURISTIC_REGISTRY.keys()), ) return TemplateConfigHeuristics() # Cache successful lookup and return instance = heuristic_class() _HEURISTIC_CACHE[cache_key] = instance return instance def clear_registry() -> None: """ Clear all registered template heuristics. This is primarily useful for testing purposes to ensure a clean state. """ _TEMPLATE_HEURISTIC_REGISTRY.clear() _HEURISTIC_CACHE.clear() @contextlib.contextmanager def override_template_heuristics( device_type: str, template_op_pairs: list[tuple[str, str]], ) -> Iterator[None]: """ Context manager to temporarily override template heuristics with an empty heuristic. This is useful for testing purposes, where we want to ensure a specific template/op pair is not used Args: device_type: Device type ("cuda", "cpu", "xpu") template_op_pairs: List of (template_name, op_name) pairs to override. """ # Save original entries to restore later original_entries = {} new_keys = [] _HEURISTIC_CACHE.clear() try: for template_name, op_name in template_op_pairs: assert op_name is not None key = (device_type, template_name, op_name) if key in _TEMPLATE_HEURISTIC_REGISTRY: original_entries[key] = _TEMPLATE_HEURISTIC_REGISTRY[key] # TemplateConfigHeuristics base class returns no entries # so we use it for overriding _TEMPLATE_HEURISTIC_REGISTRY[key] = TemplateConfigHeuristics new_keys.append(key) yield finally: # Restore original entries or remove if they didn't exist before for key in new_keys: _TEMPLATE_HEURISTIC_REGISTRY.pop(key, None) if key in original_entries: _TEMPLATE_HEURISTIC_REGISTRY[key] = original_entries[key] _HEURISTIC_CACHE.clear()
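

# Usage sketch (illustrative only, not part of the library API): the template and
# op names below ("my_template", "mm") are hypothetical and simply demonstrate how
# registration, lookup with fallback, and the test-only override interact. The
# block is guarded so it never executes on import.
if __name__ == "__main__":

    @register_template_heuristic("my_template", "cuda", op_name="mm")
    class _ExampleHeuristic(TemplateConfigHeuristics):
        pass

    # Exact match on (template_name, device_type, op_name) returns (and caches)
    # an instance of the registered class.
    assert isinstance(
        get_template_heuristic("my_template", "cuda", "mm"), _ExampleHeuristic
    )

    # An unregistered combination logs an error and falls back to the
    # (uncached) base heuristic.
    fallback = get_template_heuristic("unknown_template", "cpu", "mm")
    assert type(fallback) is TemplateConfigHeuristics

    # Inside the override, the registered pair temporarily resolves to the
    # empty base heuristic; the original entry is restored on exit.
    with override_template_heuristics("cuda", [("my_template", "mm")]):
        overridden = get_template_heuristic("my_template", "cuda", "mm")
        assert type(overridden) is TemplateConfigHeuristics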