import copy
import threading
import time
from functools import wraps
from typing import Optional, Union

from .flags import get_profiling_on
from triton._C.libproton import proton as libproton

# Per-thread stack of (scope_id, name) pairs opened by `enter_scope` and
# closed, innermost first, by `exit_scope`.
thread_local_scopes = threading.local()

MetricValueType = Union[float, int]


class scope:
    """
    A context manager and decorator for entering and exiting a scope.

    Usage:
        context manager:
        ```python
        with proton.scope("test0", {metric_name: metric_value}):
            foo[1,](x, y)
        ```

        decorator:
        ```python
        @proton.scope("test0", {metric_name: metric_value})
        def foo(x, y):
            ...
        ```

    Args:
        name (str): The name of the scope.
        metrics (dict[str, float], optional): The metrics of the scope. Default is None.
    """

    def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None) -> None:
        self.name = name
        self.metrics = metrics
        # Id of the currently open scope; None when no scope is open or when
        # profiling was off at the time the scope was entered.
        self.id: Optional[int] = None

    def _enter_scope(self) -> None:
        """Record a new scope id and enter it (no-op when profiling is off)."""
        if not get_profiling_on():
            return
        self.id = libproton.record_scope()
        libproton.enter_scope(self.id, self.name)
        if self.metrics:
            libproton.add_metrics(self.id, self.metrics)

    def _exit_scope(self) -> None:
        """Exit the open scope, if any, and forget its id."""
        # Clear the id unconditionally so an id from an earlier use of this
        # instance can never be exited a second time.
        scope_id, self.id = self.id, None
        if scope_id is None or not get_profiling_on():
            return
        libproton.exit_scope(scope_id, self.name)

    def __enter__(self):
        self._enter_scope()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._exit_scope()

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Per-scope state (id, and start_time in subclasses) is stored on
            # the instance, so each call works on its own shallow copy;
            # otherwise a recursive call would overwrite the outer call's id
            # and the outer scope would never be exited.
            active = copy.copy(self)
            active._enter_scope()
            try:
                return func(*args, **kwargs)
            finally:
                active._exit_scope()

        return wrapper


class cpu_timed_scope(scope):
    """
    A scope that measures elapsed time (cpu_time).

    Args:
        name (str): The name of the scope.
        metrics (dict[str, float], optional): Additional metrics to add. Default is None.

    Raises:
        ValueError: If ``metrics`` contains the reserved key ``"cpu_time"``.
    """

    def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None) -> None:
        super().__init__(name, metrics)
        # Wall-clock start in ns (time.time_ns), set only while a scope is open.
        self.start_time: Optional[int] = None
        if metrics and "cpu_time" in metrics:
            raise ValueError("The metric name 'cpu_time' is reserved.")

    def _enter_scope(self) -> None:
        if not get_profiling_on():
            return
        self.start_time = time.time_ns()
        super()._enter_scope()

    def _exit_scope(self) -> None:
        # Capture state before the base class clears self.id, and reset the
        # start time so it cannot leak into a later use of this instance.
        scope_id, start_time = self.id, self.start_time
        self.start_time = None
        super()._exit_scope()
        if start_time is None or scope_id is None or not get_profiling_on():
            return
        cpu_time = time.time_ns() - start_time
        libproton.add_metrics(scope_id, {"cpu_time (ns)(exc)": cpu_time})


def enter_scope(name: str, *, metrics: Optional[dict[str, MetricValueType]] = None) -> Optional[int]:
    """Open a scope named ``name`` on the calling thread's scope stack.

    Args:
        name: The name of the scope.
        metrics: Optional metrics attached to the new scope.

    Returns:
        The new scope id, or None when profiling is off (nothing is pushed).
    """
    if not get_profiling_on():
        return None
    scope_id = libproton.record_scope()
    if not hasattr(thread_local_scopes, "scopes"):
        thread_local_scopes.scopes = []
    thread_local_scopes.scopes.append((scope_id, name))
    libproton.enter_scope(scope_id, name)
    if metrics:
        libproton.add_metrics(scope_id, metrics)
    return scope_id


def exit_scope(name: Optional[str] = None, *, metrics: Optional[dict[str, MetricValueType]] = None) -> Optional[int]:
    """Close the innermost scope opened by `enter_scope` on this thread.

    Args:
        name: Optional; when given it must match the innermost open scope's
            name. Accepted only so the API mirrors `enter_scope` and
            `proton.language.exit_scope`.
        metrics: Optional metrics attached to the scope before it is closed.

    Returns:
        The closed scope id, or None when profiling is off.

    Raises:
        IndexError: If this thread has no open scope.
        ValueError: If ``name`` does not match the innermost open scope; the
            scope stack is left unchanged in that case.
    """
    if not get_profiling_on():
        return None
    scopes = getattr(thread_local_scopes, "scopes", None)
    if not scopes:
        raise IndexError("exit_scope called without a matching enter_scope on this thread")
    # Validate before popping so a mismatch does not silently drop the scope.
    scope_id, open_name = scopes[-1]
    if name and name != open_name:
        raise ValueError(f"Scope name mismatch: {name} != {open_name}")
    scopes.pop()
    libproton.exit_scope(scope_id, open_name)
    if metrics:
        libproton.add_metrics(scope_id, metrics)
    return scope_id