import abc
import builtins
import importlib
import inspect
import logging
import pickle
import types
from dataclasses import dataclass
from typing import Any, Callable, Optional

import torch
import torch.fx
from torch._dynamo.precompile_context import PrecompileContext

from . import convert_frame
from .hooks import Hooks


log = logging.getLogger(__name__)


class SerializableCallable(abc.ABC):
    """Interface for compiled callables that can round-trip through bytes.

    ``aot_compile_fullgraph`` requires the backend's output to implement this
    interface so that ``AOTCompiledFunction.serialize`` can persist it.
    """

    @classmethod
    @abc.abstractmethod
    def serialize_compile_artifacts(cls, fn: Any) -> bytes:
        """Return a byte representation of ``fn``."""
        pass

    @classmethod
    @abc.abstractmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild a callable from bytes made by ``serialize_compile_artifacts``."""
        pass


def bind_locals(
    signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
    """Bind call arguments to ``signature`` and return a name -> value mapping.

    Parameters omitted by the caller are filled with their declared defaults.

    Raises:
        TypeError: if ``args``/``kwargs`` do not match ``signature``.
    """
    bound_arguments = signature.bind(*args, **kwargs)
    bound_arguments.apply_defaults()
    return bound_arguments.arguments


@dataclass
class CompileArtifacts:
    """Everything needed to rebuild and guard an AOT-compiled function."""

    # Signature of the user function; used to bind call arguments for guards.
    signature: inspect.Signature
    # Dynamo-transformed code object; turned into a function in
    # AOTCompiledFunction.__post_init__.
    bytecode: types.CodeType
    # Live guard manager. Set to None when serialized and rebuilt from
    # ``guards_state`` in AOTCompiledFunction.__post_init__.
    guard_manager: Optional[torch._dynamo.guards.GuardManagerWrapper]
    # Pickled guards state (exposes ``output_graph`` and ``shape_code_parts``).
    guards_state: bytes
    # Global alias -> module name; each module is re-imported into the
    # rebuilt function's globals.
    import_sources: dict[str, str]
    # Global name under which ``compiled_fn`` is visible to ``bytecode``.
    backend_id: str
    # Backend output invoked by the transformed bytecode.
    compiled_fn: SerializableCallable
    # The user function's original code object; used to rebuild guards.
    original_code: types.CodeType
    # The user function's closure cells, reused for the rebuilt function.
    closure: Optional[tuple[Any, ...]]


@dataclass
class AOTCompiledFunction:
    """A fullgraph-compiled function together with the guards protecting it.

    Calling an instance checks the guards against the bound call arguments
    and, if they pass, runs the Dynamo-transformed bytecode.
    """

    _artifacts: CompileArtifacts

    def guard_check(self, *args: Any, **kwargs: Any) -> bool:
        """Return whether the guards accept the given call arguments."""
        f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
        assert self._artifacts.guard_manager is not None
        return self._artifacts.guard_manager.check(f_locals)

    def __post_init__(self) -> None:
        # Rebuild the globals the transformed bytecode refers to: the
        # recorded imports plus the compiled backend function under its id.
        import_sources = {
            alias: importlib.import_module(module_name)
            for alias, module_name in self._artifacts.import_sources.items()
        }
        f_globals = {
            **import_sources,
            self._artifacts.backend_id: self._artifacts.compiled_fn,
        }
        self.fn = types.FunctionType(
            self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
        )

        if self._artifacts.guard_manager is None:
            # Deserialized path: ``serialize`` drops the guard manager, so it
            # is rebuilt here from the pickled guards state.
            # NOTE: pickle.loads can execute arbitrary code; only load
            # artifacts that come from a trusted source.
            guards_state = pickle.loads(self._artifacts.guards_state)
            self._artifacts.guard_manager = torch._dynamo.guards.CheckFunctionManager(
                self._artifacts.original_code,
                guards_state.output_graph,
                shape_code_parts=guards_state.shape_code_parts,
                runtime_global_scope=f_globals,
            ).guard_manager

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Check guards, then run the compiled function.

        Raises:
            RuntimeError: if the guards reject the arguments; the message
                includes the reason reported by ``check_verbose``.
        """
        assert self._artifacts.guard_manager is not None
        if not self.guard_check(*args, **kwargs):
            f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
            reason = str(self._artifacts.guard_manager.check_verbose(f_locals))
            raise RuntimeError(f"GuardManager check failed, reason: {reason}")
        return self.fn(*args, **kwargs)

    def save_compiled_function(self, path: str) -> None:
        """Write ``serialize(self)`` to the file at ``path``."""
        with open(path, "wb") as f:
            f.write(type(self).serialize(self))

    @classmethod
    def serialize(cls, fn: "AOTCompiledFunction") -> bytes:
        """Pickle ``fn``'s artifacts.

        The guard manager is dropped (it is rebuilt on load). Code objects are
        wrapped with ``SerializedCode``. The compiled function is stored as a
        ``(deserializer, payload)`` pair produced by its own
        ``SerializableCallable`` methods.
        """
        from torch._dynamo.package import SerializedCode

        state = fn._artifacts.__dict__.copy()
        state["guard_manager"] = None
        state["bytecode"] = SerializedCode.from_code_object(state["bytecode"])
        compiled_fn = state["compiled_fn"]
        state["compiled_fn"] = (
            type(compiled_fn).deserialize_compile_artifacts,
            type(compiled_fn).serialize_compile_artifacts(compiled_fn),
        )
        state["original_code"] = SerializedCode.from_code_object(state["original_code"])
        return pickle.dumps(state)

    @classmethod
    def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
        """Inverse of ``serialize``.

        NOTE: this unpickles ``data``; only pass bytes from a trusted source.
        """
        from torch._dynamo.package import SerializedCode

        state = pickle.loads(data)
        state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
        deserializer, compiled_fn_state = state["compiled_fn"]
        state["compiled_fn"] = deserializer(compiled_fn_state)
        state["original_code"] = SerializedCode.to_code_object(state["original_code"])
        artifacts = CompileArtifacts(**state)
        return cls(artifacts)


class BundledAOTAutogradSerializableCallable(SerializableCallable):
    """
    Represents a serializable callable generated by compile_fx.
    This class wraps around the compiled function generated by AOTAutograd.

    TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
    this object should be what's *returned* by aot_module_simplified.
    We'll do that refactor in a later PR.
    """

    def __init__(self, artifact: Any) -> None:
        """
        Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
        of a compiled function generated by AOTAutograd.
        """
        self.compiled_fn = artifact.after_deserialization()
        self.data = artifact.content

    def __getattr__(self, attr: str) -> Any:
        # Python only calls __getattr__ after normal lookup has failed, so
        # querying ``self`` again (e.g. via hasattr) would recurse forever.
        # Read ``compiled_fn`` straight from __dict__ instead. That way a
        # not-yet-initialized instance raises AttributeError rather than
        # recursing.
        # Dunder names are not forwarded. Protocol probes made by copy and
        # pickle (``__setstate__``, ``__deepcopy__``, ...) must see this
        # wrapper, not the wrapped function.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        try:
            compiled_fn = self.__dict__["compiled_fn"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(compiled_fn, attr)

    @classmethod
    def from_backend_id(
        cls, backend_id: str
    ) -> "BundledAOTAutogradSerializableCallable":
        """
        Takes in a backend_id, and returns a BundledAOTAutogradSerializableCallable
        that wraps around the compiled function generated by AOTAutograd.

        Raises:
            RuntimeError: if PrecompileContext has no artifact for ``backend_id``.
        """
        artifact = PrecompileContext.serialize_artifact_by_key(backend_id)
        if artifact is None:
            raise RuntimeError("No artifact found for backend_id: " + backend_id)
        return cls(artifact)

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "BundledAOTAutogradSerializableCallable"
    ) -> bytes:
        """Return the raw artifact content captured at construction time."""
        return fn.data

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild the wrapper from bytes returned by ``serialize_compile_artifacts``."""
        from torch._functorch._aot_autograd.autograd_cache import (
            BundledAOTAutogradCacheArtifact,
        )

        # The key in the artifact is not important here since we're not populating a cache,
        # we just want to grab the callable back out of the serialized entry
        artifact = BundledAOTAutogradCacheArtifact("", data)
        return cls(artifact)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.compiled_fn(*args, **kwargs)


def aot_compile_fullgraph(
    model: Any,
    example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
    hooks: Hooks,
    backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
) -> AOTCompiledFunction:
    """Trace ``model`` as a single graph, compile it, and package the result.

    Args:
        model: a plain Python function, or a bound method. For a bound method,
            its ``__self__`` is prepended to the positional example args.
        example_inputs: ``(args, kwargs)`` used for tracing.
        hooks: Dynamo hooks forwarded to guard construction. If
            ``hooks.guard_filter_fn`` is unset, a filter is installed on this
            object. It drops global guards and guard types listed in
            ``CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES``.
        backend: compiler applied to the captured graph module. Its output
            must be a ``SerializableCallable``. For the Inductor wrapper, the
            callable is instead fetched from PrecompileContext.

    Returns:
        An ``AOTCompiledFunction`` bundling the transformed bytecode, guards
        and compiled backend function.

    Raises:
        RuntimeError: if ``model`` is neither a function nor a bound method, or
            if the compiled function does not implement ``SerializableCallable``.
    """
    from torch._dynamo.guards import CheckFunctionManager
    from torch._dynamo.utils import dynamo_timed, get_metrics_context
    from torch._guards import compile_context, CompileContext, TracingContext

    args, kwargs = example_inputs
    # Bound method: compile the underlying function with the receiver passed
    # as the first positional argument. ``__func__`` is checked too because
    # builtin bound methods expose ``__self__`` without ``__func__``. Those
    # now reach the "Unsupported" error below instead of an AttributeError.
    if hasattr(model, "__self__") and hasattr(model, "__func__"):
        fn = model.__func__
        args = (model.__self__,) + args
    elif inspect.isfunction(model):
        fn = model
    else:
        raise RuntimeError(f"Unsupported model code type {model}")

    signature = inspect.signature(fn)
    f_locals = bind_locals(signature, *args, **kwargs)
    # Make closure variables visible to capture by adding them to f_locals.
    if fn.__code__.co_freevars or fn.__closure__:
        assert len(fn.__closure__) == len(fn.__code__.co_freevars)
        f_locals.update(
            {
                name: cell.cell_contents
                for name, cell in zip(fn.__code__.co_freevars, fn.__closure__)
            }
        )

    with (
        compile_context(CompileContext(convert_frame.get_compile_id({}))),
        get_metrics_context(),
        dynamo_timed("fullgraph_capture"),
    ):
        capture_output = convert_frame.fullgraph_capture(
            convert_frame.FrameInfo(
                fn.__code__,
                fn.__globals__,
                f_locals,
                builtins.__dict__,
                closure=fn.__closure__ or (),  # type: ignore[arg-type]
            )
        )
        dynamo_output = capture_output.dynamo_output

        if not hooks.guard_filter_fn:
            from torch._dynamo.types import GuardFilterEntry

            # Default filter: keep only guards that are neither global nor of a
            # type CheckFunctionManager marks as unsupported for serialization.
            # NOTE: this mutates the caller's ``hooks`` object.
            def new_guard_filter_fn(
                guard_entries: list[GuardFilterEntry],
            ) -> list[bool]:
                return [
                    (
                        not (
                            g.is_global
                            or g.guard_type
                            in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
                        )
                    )
                    for g in guard_entries
                ]

            hooks.guard_filter_fn = new_guard_filter_fn

        check_fn = dynamo_output.build_guards(
            fn.__code__, hooks=hooks, save=True, strict_error=True
        )
        assert check_fn.guards_state is not None

    backend_input = capture_output.backend_input
    backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
    output_graph = dynamo_output.tracer_output.output_graph
    assert output_graph is not None
    import_sources = output_graph.import_sources
    with (
        torch._guards.tracing(TracingContext(backend_input.fake_mode)),
        torch._functorch.config.patch("bundled_autograd_cache", True),
    ):
        compiled_fn = backend(backend_input.graph_module, backend_input.example_inputs)

        # If Inductor backend is used, grab the compiled_fn from PrecompileContext
        # TODO: this should be replaced once we make the backend return the SerializableCallable directly.
        if isinstance(backend, torch._TorchCompileInductorWrapper):
            compiled_fn = BundledAOTAutogradSerializableCallable.from_backend_id(
                backend_input.backend_id
            )

    if not isinstance(compiled_fn, SerializableCallable):
        if hasattr(backend, "compiler_fn"):
            compiler_fn = backend.compiler_fn
        else:
            compiler_fn = backend
        raise RuntimeError(
            f"Compiled function type {type(compiled_fn)} (produced "
            + f"from backend {compiler_fn}) does not implement SerializableCallable."
        )

    artifacts = CompileArtifacts(
        signature=signature,
        bytecode=dynamo_output.bytecode,
        guard_manager=check_fn.guard_manager,
        guards_state=check_fn.guards_state,
        import_sources=import_sources,
        backend_id=backend_input.backend_id,
        compiled_fn=compiled_fn,
        original_code=fn.__code__,
        closure=fn.__closure__,
    )
    aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

    return aot_compiled_fn