# mypy: allow-untyped-defs import functools import inspect import itertools import logging from dataclasses import dataclass from typing import Any, Callable, Optional, Union import torch import torch.utils._pytree as pytree log = logging.getLogger(__name__) trace_shape_events_log = torch._logging.getArtifactLogger( __name__, "trace_shape_events" ) __all__ = [ "ShapeEnvEvent", "record_shapeenv_event", "replay_shape_env_events", "FakeTensorMeta", "shape_env_check_state_equal", "NotEqualError", ] # [Note: Recording ShapeEnv Events] # ================================= # # What is a ShapeEnv event? # ------------------------- # We consider a ShapeEnv event every function call (ShapeEnv method or # independent function) that modifies the state of the ShapeEnv instance. # Such calls are recorded alongside their positional and keyword arguments, # so that it may be replayed over a different ShapeEnv instance. # # See [Note: ShapeEnv State Equality] for what is considered the state # of a ShapeEnv instance. # # What is it for? # --------------- # ShapeEnv events recording is used for reconstructing the ShapeEnv in an # arbitrary state in time. # # Being able to arbitrarily replay events like so is useful, mainly for # translation validation bisection. i.e. if a ValidationException has been # raised, find the earliest point in time where the translation validation # fails. # # Besides that, it also allows us to inspect the given instance and, # for example, check the guards that would actually be issued at that point. # # What kind of arguments can be stored in an event? # ------------------------------------------------- # There's no specific rule for what cannot be used as an argument. # That said, pay special attention to the following cases: # # 1. Tensor inputs: there are some tests that check whether the inputs # were garbage collected after execution. These will fail if there's # an event that is holding a reference to those inputs. # # 2. 
ShapeEnv arguments: if there is an argument of ShapeEnv type, that # will be automatically replaced by the new given ShapeEnv instance. # # 3. SymTypes arguments: they also hold references to ShapeEnv. So, # whenever we see them, we create a new instance, replacing the # ShapeEnv reference. # # 4. FX nodes: specifically, FX nodes from the FX graph for symbolic # shapes. That argument must be replaced when replaying the event at # ShapeEnvEvent.run, since it has to reference a node from the given # instance, and not from the recorded instance. # Event class for reconstructing ShapeEnv at arbitrary time. # # Represents a method call that mutates ShapeEnv in a way that affects the # issued guards, when ShapeEnv.produce_guards is called. @dataclass class ShapeEnvEvent: # ShapeEnv method. f: Callable # Arguments and keyword arguments called with. args: Optional[list[Any]] = None kwargs: Optional[dict[str, Any]] = None # List of tracked_fakes at the time the method was called. tracked_fakes: Optional[list[Any]] = None # Name of the captured event. # Used for special handling of particular methods. name: Optional[str] = None # Replay itself, but using shape_env as self. def run(self, shape_env=None) -> Any: from torch.fx.experimental.symbolic_shapes import ( is_symbolic, ShapeEnv, SymTypes, ) # Special handling for the constructor event. if self.f is ShapeEnv: assert shape_env is None and self.args is None and self.kwargs is not None return ShapeEnv(**self.kwargs) assert shape_env is not None args = list(self.args or []) kwargs = dict(self.kwargs or {}) # Replace any argument of type ShapeEnv by the given one. args, kwargs = pytree.tree_map_only( ShapeEnv, lambda _: shape_env, (args, kwargs) ) # Replace any argument of type SymTypes by a new instance, # replacing its ShapeEnv reference. 
args, kwargs = pytree.tree_map_only( lambda x: isinstance(x, SymTypes) and is_symbolic(x), lambda a: type(a)(a.node.with_shape_env(shape_env)), (args, kwargs), ) # Converts FX nodes using the mapping argument. def maybe_convert_node(x: Any) -> Any: if not isinstance(x, torch.fx.Node): # Don't do anything to x if it's not an FX node. return x # If, at some point, we created an FX node, it means that translation validation is on. # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and # we are tracking node names at shape_env.name_to_node. assert hasattr(shape_env, "name_to_node") name_to_node = shape_env.name_to_node # type: ignore[attr-defined] assert x.name in name_to_node return name_to_node[x.name] # Replaces the value of an specific argument by the result of fn. def replacearg(index: int, key: str, fn: Callable): if index < len(args): args[index] = fn(args[index]) if key in kwargs: kwargs[key] = fn(kwargs[key]) if self.is_create_fx_call_function(): # ShapeEnv.create_fx_call_function: # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv. # They must be replaced, since a "call_function" FX node with this tuple as argument # will be added to the FX graph of the new shape_env. replacearg( index=2, key="args", fn=lambda args: tuple(maybe_convert_node(a) for a in args), ) if self.is_evaluate_expr() or self.is_defer_runtime_assert(): # ShapeEnv.evaluate_expr and ShapeEnv.guard_or_defer_runtime_assert: # "fx_node" parameter is an (optional) FX node that represents the evaluate expression. # They must be replaced, since it will be part of a "call_function" FX node for # torch._assert, which will be added to the FX graph of the new shape_env. replacearg(index=3, key="fx_node", fn=maybe_convert_node) # Actually call the method with the converted arguments. 
return self.f(*args, **kwargs) def __str__(self) -> str: name = self.name if self.name is not None else self.f.__name__ return f"event: {name} ({self.args}, {self.kwargs})" def is_create_fx_call_function(self) -> bool: return self.name == "_create_fx_call_function" def is_evaluate_expr(self) -> bool: return self.name == "evaluate_expr" def is_defer_runtime_assert(self) -> bool: return self.name == "guard_or_defer_runtime_assert" NEST = 0 # Extracts a ShapeEnv instance inside args and kwargs. # Specifically, it looks for: # 1. ShapeEnv arguments # 2. SymInt, SymFloat, or SymBool arguments # If we find more than one object of any of the above types, we # also check that the ShapeEnv instance is the same for all of them. def _extract_shape_env_and_assert_equal(args, kwargs): from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: if old is not None: assert old is new, "call with different ShapeEnv" return new shape_env = None for val in itertools.chain(args, kwargs.values()): if isinstance(val, ShapeEnv): shape_env = assert_equal(shape_env, val) if isinstance(val, SymTypes) and is_symbolic(val): shape_env = assert_equal(shape_env, val.node.shape_env) return shape_env # Decorator for recording the given function as a replayable event. # # This decorator should be used at every function that mutates the state of # ShapeEnv in some way that affects the resulting issued guards (i.e. when # ShapeEnv.produce_guards is called). # # save_tracked_fakes: saves a snapshot of the TrackedFake list. # This is used when calling ShapeEnv.produce_guards at arbitrary points in time. # # name: the name of the function being recorded. Normally (and by default) this # is taken from the decorated function but can be set if you need to override # it. # # When to save the list of TrackedFake? 
# ===================================== # We should save the list of TrackedFake whenever the translation validation # bisection may actually stop and call the produce_guards method at the moment # right after the recorded function was played. In other words, since the # bisection bisects through torch._assert calls, we should save in all methods # that adds a torch._assert call to the symbolic shapes FX graph. # # At the moment, there are 2 methods that save the list: # - ShapeEnv.evaluate_expr # - ShapeEnv.guard_or_defer_runtime_assert def record_shapeenv_event( *, save_tracked_fakes: bool = False, name: Optional[str] = None ) -> Callable: def decorator(fn: Callable) -> Callable: assert callable(fn) args = inspect.getfullargspec(fn).args assert args and args[0] == "self", ( "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your " "code so that it calls into a method on ShapeEnv" ) nonlocal name if name is None: name = fn.__name__ @functools.wraps(fn) def wrapper(*args, **kwargs): from torch.fx.experimental.symbolic_shapes import ShapeEnv assert isinstance(args[0], ShapeEnv) global NEST trace_shape_events_log.debug( "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs ) NEST += 1 def retlog(r): trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r) return r shape_env = args[0] try: if not shape_env.should_record_events or shape_env.is_recording: # type: ignore[has-type] # If ShapeEnv is already recording an event, call the wrapped # function directly. # # NB: here, we skip the check of whether all ShapeEnv instances # are equal, in favor of a faster dispatch. return retlog(fn(*args, **kwargs)) # Retrieve an instance of ShapeEnv. # Assumption: the collection of args and kwargs may not reference # different ShapeEnv instances. self = _extract_shape_env_and_assert_equal(args, kwargs) # If we are calling this function without any ShapeEnv instance # alive in its arguments, we don't record and call the original. 
if self is None: return retlog(fn(*args, **kwargs)) # Otherwise, start recording and call the function. with self._recording(): # Take a snapshot of the current tracked_fakes. tracked_fakes = ( self._snapshot_tracked_fakes() if save_tracked_fakes else None ) # Record the event for 'fn'. event = ShapeEnvEvent( fn, list(args), kwargs, tracked_fakes, name=name, ) # Play the event on this ShapeEnv. # NB: It's important to put the event first, because running # the event can trigger internal events that must be ordered # after this event. However, if an exception happens, we do # NOT want to have the event in the list, so pop it off from # the record if an error happened self.events.append(event) try: return retlog(event.run(self)) except Exception: self.events.pop() raise except Exception: if not shape_env.should_record_events or shape_env.is_recording: # If ShapeEnv is disabled or already recording an event, re-raise the exception without logging. raise log.error( # noqa: G201 "failed while running %s(*%s, **%s)", name, args[1:], kwargs, exc_info=log.isEnabledFor(logging.INFO), ) raise finally: NEST -= 1 return wrapper return decorator # Replays the ShapeEnvEvents list. # It assumes the first event is the constructor call. # # fn: transforms an old FX node into one corresponding to the newly created ShapeEnv. def replay_shape_env_events(events): from torch.fx.experimental.symbolic_shapes import ShapeEnv constructor_event = events[0] assert constructor_event.f == ShapeEnv # Constructs the new ShapeEnv. shape_env = constructor_event.run() for event in events[1:]: try: # Actually replays each event. # We need to call create_mapping_fn every time, since the node list might # change after each event is replayed. event.run(shape_env) except Exception: log.error("failed when running event: %s", event) raise return shape_env # FakeTensor metadata. # This is to be used in place of FakeTensor placeholders when calling # ShapeEnv.produce_guards. 
@dataclass
class FakeTensorMeta:
    # Sizes, strides and storage offset captured from a (fake) tensor; each
    # entry may be a plain int or a torch.SymInt.
    tensor_size: tuple[Union[int, torch.SymInt], ...]
    tensor_stride: tuple[Union[int, torch.SymInt], ...]
    tensor_storage_offset: Union[int, torch.SymInt]
    is_nested: bool

    def size(self) -> tuple[Union[int, torch.SymInt], ...]:
        """Return the captured sizes, mirroring ``Tensor.size()``."""
        return self.tensor_size

    def stride(self) -> tuple[Union[int, torch.SymInt], ...]:
        """Return the captured strides, mirroring ``Tensor.stride()``."""
        return self.tensor_stride

    def storage_offset(self) -> Union[int, torch.SymInt]:
        """Return the captured storage offset, mirroring ``Tensor.storage_offset()``."""
        return self.tensor_storage_offset

    def dim(self) -> int:
        """Return the number of dimensions (the length of the captured sizes)."""
        sizes = self.tensor_size
        return len(sizes)

    @staticmethod
    def from_fake(fake) -> "FakeTensorMeta":
        """Capture size, stride, storage offset and nestedness of ``fake``."""
        return FakeTensorMeta(
            tensor_size=fake.size(),
            tensor_stride=fake.stride(),
            tensor_storage_offset=fake.storage_offset(),
            is_nested=fake.is_nested,
        )


# [Note: ShapeEnv State Equality]
# ===============================
#
# What is considered ShapeEnv state?
# ----------------------------------
# We consider the state of a ShapeEnv instance to be everything that is not
# in the inline tuple inside the remove_nonstate_variables function (i.e. not
# passed as non_state_variable_names to shape_env_check_state_equal).
# That is: the fields within ShapeEnv that modify the flow of execution
# of the program.
#
# So, for example: the replacements field might influence how an
# expression is simplified. That, in turn, may result in a guard being
# statically known (i.e. not added).
#
# On the other hand, var_to_stack only changes what is printed
# on the screen, i.e. it is used only for debugging purposes. Therefore, we
# should not consider it when comparing states.
#
# What to do on NotEqualError?
# ----------------------------
# Here are a few possible causes for getting a NotEqualError raised:
#
#   1. New field that does not belong in the ShapeEnv state.
#      For example: log field of type ShapeEnvLoggerAdapter. Different
#      ShapeEnv instances will always have different ShapeEnvLoggerAdapter
#      instances, i.e. equality comparison would fail.
#      Solution: add it to the inlined tuple inside remove_nonstate_variables
#      function inside check_equal method.
#
#   2. New field that is not directly comparable across instances.
#      For example: guards field of type List[ShapeGuard]. More specifically,
#      the ShapeGuard type holds an expression and a stack information
#      for debugging purposes. When replaying the event on a new ShapeEnv
#      instance, the stack would be different, which would trigger this error.
#      Solution: add a special case to the map_value function inside
#      check_equal function.
#
#   3. Mutation of ShapeEnv on some not recorded function.
#      If a mutation of the state of ShapeEnv happens inside a function
#      that is not recorded (or that no caller in the stack is recorded),
#      then, the replayed ShapeEnv won't catch that.
#      Solution: decorate the function with record_shapeenv_event.


# Checks whether the state of two ShapeEnv are equal w.r.t. the guards
# returned by ShapeEnv.produce_guards.
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
    """Raise NotEqualError if env1 and env2 differ in any state field.

    Args:
        env1, env2: the objects whose ``vars()`` are compared.
        non_state_variable_names: attribute names to ignore on both sides.
        map_value: ``(field_name, value) -> comparable`` applied to each field
            value before comparison (e.g. to strip debug-only information).

    Raises:
        NotEqualError: if the two objects have different sets of fields
            ("field set mismatch:"), or if any mapped field values differ
            ("field values don't match:").
    """
    # Collect and remove variables that don't necessarily represent the state
    # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
    # instance itself.
    env1_vars = vars(env1).copy()
    env2_vars = vars(env2).copy()

    for v in non_state_variable_names:
        env1_vars.pop(v, None)
        env2_vars.pop(v, None)

    # Function for transforming the mismatched values into string.
    # Needed, since dict and set entries order might not be the same every time.
    # Both dict keys and set elements are ordered by their string form: the
    # elements may not be mutually orderable (e.g. sympy symbols, whose `<`
    # builds a Relational, or mixed types), and a TypeError here would hide
    # the NotEqualError we are trying to report.
    def value_to_str(value: Any) -> str:
        if isinstance(value, dict):
            return (
                "{"
                + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
                + "}"
            )
        if isinstance(value, set):
            return "{" + ", ".join(f"{v}" for v in sorted(value, key=str)) + "}"
        return str(value)

    # Compares env1_vars with env2_vars.
    # Here, we allow the value of each field to be mapped, so that we appropriately
    # compare the two values.
    def compare_vars(
        map_value: Callable[[str, Any], Any],
    ) -> list[tuple[str, str, str]]:
        env1_set, env2_set = set(env1_vars), set(env2_vars)

        # First, compare the set of keys in each vars dictionary.
        if env1_set != env2_set:
            raise NotEqualError(
                "field set mismatch:",
                [
                    (
                        "found unique fields:",
                        str(sorted(env1_set - env2_set)),
                        str(sorted(env2_set - env1_set)),
                    ),
                ],
            )

        # Then, sort the keys, and compare the mapped values of each key.
        mapped = [
            (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
            for k in sorted(env1_set)
        ]

        # Return a list of tuples representing the fields that did not match
        # alongside their respective mapped values.
        return [
            (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
            for k, val1, val2 in mapped
            if val1 != val2
        ]

    # Accumulate the mismatching fields.
    errors = compare_vars(map_value)

    if errors:
        raise NotEqualError("field values don't match:", errors)


class NotEqualError(Exception):
    """Raised when two ShapeEnv states differ.

    Args:
        msg: summary line of the failure.
        mismatched: ``(description, left_str, right_str)`` triples, one per
            mismatching item, rendered as "==> description / > Left / > Right".
    """

    def __init__(
        self,
        msg: str,
        mismatched: list[tuple[str, str, str]],
    ) -> None:
        details = "\n".join(
            [
                "\n".join(
                    [
                        f"==> {inner_msg}",
                        f"  >  Left: {str1}",
                        f"  > Right: {str2}",
                    ]
                )
                for inner_msg, str1, str2 in mismatched
            ]
        )

        super().__init__(
            f"""\
ShapeEnv not equal: {msg}

{details}
"""
        )