import math
import os
import sys
from collections import OrderedDict
from dataclasses import astuple, dataclass
from typing import Any, NamedTuple, Optional

from typing_extensions import Self

import torch
from torch import nan, nn, UntypedStorage
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.common_utils import get_untyped_storages
from torch.distributed._tools.mod_tracker import ModTracker
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.testing._internal.composite_compliance import (
    is_inplace,
    is_inplace_view_fn,
    is_view_fn,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
from torch.utils.checkpoint import SAC_IGNORED_OPS


__all__ = ["SACEstimator", "SACStats", "MSPS", "SACTradeOffStats", "SACGreedyOrderMeta"]

aten = torch.ops.aten
_ADDITIONAL_IGNORED_OPS = {
    aten.lift_fresh.default,  # type: ignore[attr-defined]
    torch.ops.profiler._record_function_exit._RecordFunction,  # type: ignore[attr-defined]
    aten.clone.default,  # type: ignore[attr-defined] # seems needed for torch.compile
}
OPS_TO_ALWAYS_SKIP = SAC_IGNORED_OPS | _ADDITIONAL_IGNORED_OPS
# This value is hard-coded here:
# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
_PYTORCH_MIN_ALLOCATE = (
    2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1
)
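# Illustrative example (an added note, not from the original file): with caching
# enabled, the allocator hands out storage in 512 B blocks, so a CUDA tensor holding
# 100 float32 values (400 B of data) is accounted as math.ceil(400 / 512) * 512 = 512 B.
# This is exactly the rounding applied to CUDA storages in ``__torch_dispatch__`` below.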
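
def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None:
    try:
        from tabulate import tabulate
    except ImportError as err:
        raise ImportError("Please install tabulate.") from err

    # Use tabulate to print the table
    print(tabulate(table_data, headers=headers, tablefmt="rst"))


# Based on:
# https://github.com/facebookresearch/xformers/blob/main/xformers/checkpoint.py#L71
@dataclass
class _SACMetadata:
    """
    Stores metadata for a single operator for SAC.

    Attributes:
        func (Any): The operator function.
        time_taken (float): The time taken by the operator.
        memory_used (float): The memory used by the operator.
        curr_idx (int): The current operator index.
        output_ids (Tuple[int, ...]): The storage IDs of the operator's outputs.
        inplace_info (Tuple[int, ...]): Tuple of self and parent operator for in-place operator.
        is_view_like (bool): Whether the operator is view-like.
        is_rand_op (bool): Whether the operator is a random operator.
    """

    func: Any
    time_taken: float
    memory_used: float
    curr_idx: int
    output_ids: tuple[int, ...]
    inplace_info: tuple[int, ...]
    is_view_like: bool
    is_rand_op: bool


@dataclass
class _SACModMetadata:
    """
    Stores metadata for a module for SAC.

    Attributes:
        start_idx (int): The starting index of the module's operators.
        force_store_random (bool): Whether to force store random operators in the module.
        sac_metadata (List[_SACMetadata]): List of metadata for each operator in the module.
    """

    start_idx: int
    force_store_random: bool
    sac_metadata: list[_SACMetadata]


@dataclass
class SACStats:
    """
    A class for storing Activation Checkpointing statistics corresponding to a module.

    Attributes:
        func_names (List[str]): List of operator names.
        runtimes (List[float]): List of operator runtimes in milliseconds.
        memory (List[int]): List of operator memory usage in bytes.
        view_like_ops (List[int]): Indices of view-like operators.
        rand_ops (List[int]): Indices of random operators.
        saved_autograd_ops (List[int]): Indices of operator results saved by the autograd engine.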
        inplace_ops (List[Tuple[int, int]]): Tuples of indices of an in-place op and its first parent.
        force_store_random (bool): Whether to force store random operator results.
    """

    func_names: list[str]
    runtimes: list[float]
    memory: list[int]
    view_like_ops: list[int]
    rand_ops: list[int]
    saved_autograd_ops: list[int]
    inplace_ops: list[tuple[int, int]]
    force_store_random: bool


class MSPS(NamedTuple):
    """
    Represents Memory and Runtime Statistics for an operator/operator group.

    Attributes:
        func_names (set[str]): Set of operator/operator group names.
        op_idx (int): Operator index (group head index in case of operator groups).
        memory (int): Memory usage in bytes.
        runtime (float): Runtime in milliseconds.
        msps (float): Memory savings per unit of recompute time, calculated as memory/runtime.
    """

    func_names: set[str]
    op_idx: int
    memory: int
    runtime: float
    msps: float
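# Illustrative example (hypothetical numbers): an operator group whose outputs occupy
# 2_097_152 B (2 MiB) with an estimated runtime of 0.5 ms has
# msps = 2_097_152 / 0.5 = 4_194_304.0, i.e. bytes of memory saved per millisecond of
# recompute time. Candidates with a higher msps are cheaper to recompute per byte freed.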
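
@dataclass
class SACTradeOffStats:
    """
    Stores statistics for activation-checkpointing trade-off.

    Attributes:
        n_segments (int): Number of piecewise linear segments fitted to the trade-off curve.
        slopes (List[float]): Slopes of the piecewise linear segments fitted to the trade-off curve.
        intercepts (List[float]): Intercepts of the piecewise linear segments fitted to the trade-off curve.
        fit_breaks (List[float]): Breakpoints of the piecewise linear segments fitted to the trade-off curve.
        tradeoff_curve (OrderedDict[float, float]): Trade-off curve data of memory discarded vs recomputation time.
        sac_memory (int): Total memory of operations available for activation checkpointing in bytes.
        sac_runtime (float): Total runtime of operations available for activation checkpointing in milliseconds.
    """

    n_segments: int
    slopes: list[float]
    intercepts: list[float]
    fit_breaks: list[float]
    tradeoff_curve: OrderedDict[float, float]
    sac_memory: int
    sac_runtime: float


@dataclass
class SACGreedyOrderMeta:
    """
    Stores metadata for Greedy-order SAC.

    Attributes:
        recomputed_ops (set[int]): Set of operator indices to be recomputed.
        stored_ops (set[int]): Set of operator indices to be stored.
        inplace_op_groups (dict[int, set[int]]): Dictionary of in-place operator groups from group-head to operators.
        random_ops_group (dict[int, set[int]]): Dictionary of random op group head to random ops.
        msps_meta (list[MSPS]): List of Memory and Runtime Statistics for operators.
    """

    recomputed_ops: set[int]
    stored_ops: set[int]
    inplace_op_groups: dict[int, set[int]]
    random_ops_group: dict[int, set[int]]
    msps_meta: list[MSPS]


class SACEstimator(TorchDispatchMode):
    """
    Estimates the memory and recomputation time trade-offs for applying Selective Activation Checkpointing (SAC).

    This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the memory and
    runtime trade-offs of functions or ``torch.nn.Module``s for Selective Activation Checkpointing (SAC). It provides
    detailed statistics and metadata information for operators of each module and provides a greedy order for
    selecting the operators to be recomputed/checkpointed. It also constructs the per-module trade-off graph of
    discarded memory vs recomputation time for the obtained greedy order. Using ``RuntimeEstimator`` under the hood,
    it supports two estimation modes, `operator-level-benchmark` and `operator-level-cost-model` (roofline model).

    Attributes:
        sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fully qualified name) to ``SACStats``.
        sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``.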
        sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to
            ``SACGreedyOrderMeta``.

    Note:
        1) This class is designed to be used under ``FakeTensorMode``.
        2) Currently, it only supports estimation of compute time and memory usage, and does not consider
           communication.

    Example usage:

        .. code-block:: python

            sac_estimator = SACEstimator()
            with FakeTensorMode():
                module = ...
                inp = ...
                with sac_estimator("operator-level-cost-model"):
                    output = module(inp)
            sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True)
    """

    def __init__(self) -> None:
        self.sac_mod_stats: dict[str, SACStats] = {}
        self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {}
        self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {}
        self._mod_tracker = ModTracker()
        self._sac_metadata: list[_SACMetadata] = []
        self._sac_mod_metadata: dict[str, _SACModMetadata] = {}
        self._leaf_modules: set[str] = set()
        self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks(
            self._pack_hook, lambda x: x
        )
        self._saved_tensor_ids: set[int] = set()
        self._estimate_runtime = RuntimeEstimator._roofline_estimate

    def _pack_hook(self, x: torch.Tensor) -> torch.Tensor:
        # Hook function to track the underlying storage IDs of tensors.
        # Updates the _saved_tensor_ids set with the IDs of the tensor's storages.
        # Used in conjunction with torch.autograd.graph.saved_tensors_hooks.
        untyped_storages = get_untyped_storages(x)
        storage_ids = (hash(st) for st in untyped_storages)
        self._saved_tensor_ids.update(storage_ids)
        return x

    def _pre_fw_hook(self, mod: nn.Module, inputs: Any) -> None:
        # Pre-forward hook function to prepare module metadata.
        # Tracks the module FQN, the force store random flag, and ``_SACModMetadata``.
        # Initializes metadata for non-leaf modules, marks leaf modules.
        mod_fqn = self._mod_tracker.get_known_fqn(mod)
        assert mod_fqn is not None
        num_children = sum(1 for _ in mod.children())
        if num_children > 0:
            force_store_random = self._get_force_store_random(inputs)
            self._sac_mod_metadata[mod_fqn] = _SACModMetadata(
                start_idx=len(self._sac_metadata),
                force_store_random=force_store_random,
                sac_metadata=[],
            )
        else:
            self._leaf_modules.add(mod_fqn)

    def _post_fw_hook(self, mod: nn.Module, inputs: Any, outputs: Any) -> None:
        # 1. Retrieves the module's FQN and checks if it's a leaf module
        # 2. If not a leaf module, computes:
        #    - ``SACStats`` using the module's metadata and force store random flag
        #    - ``SACGreedyOrderMeta`` using the computed SAC statistics
        mod_fqn = self._mod_tracker.get_known_fqn(mod)
        assert mod_fqn is not None
        if mod_fqn in self._leaf_modules:
            return
        else:
            self.sac_mod_stats[mod_fqn] = self._get_sac_stats(
                data=self._sac_mod_metadata[mod_fqn].sac_metadata,
                force_store_random=self._sac_mod_metadata[mod_fqn].force_store_random,
            )
            self.sac_mod_greedy_order_meta[mod_fqn] = self._get_greedy_order_meta(
                self.sac_mod_stats[mod_fqn]
            )

    def _get_force_store_random(self, inputs: Any) -> bool:
        flat_inputs, _ = tree_flatten(inputs)
        return all(not isinstance(x, torch.Tensor) for x in flat_inputs)
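    # Illustrative behavior of ``_get_force_store_random`` (an added note): it returns
    # True only when none of the flattened inputs are tensors, e.g.
    #   self._get_force_store_random((torch.randn(2, 2),))  # -> False
    #   self._get_force_store_random((42, "config"))        # -> True
    # so modules that receive no tensor inputs are forced to store their random ops.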
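    def _get_sac_stats(
        self, data: list[_SACMetadata], force_store_random: bool
    ) -> SACStats:
        # 1. Ignore the operations that should be skipped by SAC, such as aten.detach.default, because
        # autograd inserts those during backward and they break the fwd-bwd alignment
        filtered_data = [x for x in data if x.func not in OPS_TO_ALWAYS_SKIP]

        (
            ops,
            runtimes_,
            memory_,
            new_ids,
            output_ids,
            inplace_ops_,
            view_like_ops_,
            rand_ops_,
        ) = zip(*[astuple(x) for x in filtered_data], strict=True)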
        # 2. Extract the metadata information
        runtimes = list(runtimes_)
        memory = list(memory_)
        func_names = [op._overloadpacket.__name__ for op in ops]
        view_like_ops = [i for i, x in enumerate(view_like_ops_) if x]
        rand_ops = [i for i, x in enumerate(rand_ops_) if x]
        saved_autograd_ops = [
            i
            for i, out_ids in enumerate(output_ids)
            if set(out_ids).issubset(self._saved_tensor_ids)
        ]

        # 3. Remap the in-place indices, as we have removed OPS_TO_ALWAYS_SKIP
        # FIXME @sanketpurandare: Fix this by changing the parent of the inplace-op
        # to itself if the original parent is in OPS_TO_ALWAYS_SKIP.
        try:
            inplace_ops = [tuple(map(new_ids.index, x)) for x in inplace_ops_ if x]
        except ValueError as err:
            raise ValueError(
                f"The remapping of in-place ops failed, since one of the in-place op parents"
                f" must have been present in {OPS_TO_ALWAYS_SKIP}"
            ) from err

        # 4. The last operation is always stored as the output of the checkpoint
        # block, so we can avoid recomputing it. We set the memory to zero
        # instead of adding a new constraint because we want both the 0 and 1
        # endpoints for memory_budget to be valid.
        # FIXME @sanketpurandare: this heuristic for finding the last non-view non-inplace op
        # might not always be correct, which would yield suboptimal policies
        last_op = len(ops) - 1
        skip_ops_ = set(view_like_ops) | {x[0] for x in inplace_ops}
        reversed_skip_ops = sorted(skip_ops_, reverse=True)
        for op in reversed_skip_ops:
            if op == last_op:
                last_op -= 1

        memory[last_op] = 0

        # 5. Create a single ``SACStats`` object for the entire block of ``_SACMetadata``.
        return SACStats(
            func_names=func_names,
            runtimes=runtimes,
            memory=memory,
            view_like_ops=view_like_ops,
            rand_ops=rand_ops,
            saved_autograd_ops=saved_autograd_ops,
            inplace_ops=inplace_ops,  # type: ignore[arg-type]
            force_store_random=force_store_random,
        )
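    # Illustrative remapping for step 3 above (hypothetical indices): if the unfiltered
    # ops were [detach, mul, add_] with inplace_info (2, 1) recorded for add_, dropping
    # detach leaves new_ids = (1, 2), so the pair remaps to (1, 0): add_ at filtered
    # index 1 writes into the storage produced by mul at filtered index 0.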
    def _get_inplace_metadata(
        self, func: Any, out_storages: set[UntypedStorage]
    ) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]:
        # 1. Get the current index of the metadata obtained so far
        curr_idx = len(self._sac_metadata)
        # 2. Get the set of active modules that are not leaf
        active_mod_fqns: set[str] = {
            par for par in self._mod_tracker.parents if par not in self._leaf_modules
        }
        # 3. Output ids are the identifiers of the storage objects corresponding to the tensors
        output_ids = tuple(hash(st) for st in out_storages)
        # 4. If the function is not in-place, return
        if not is_inplace(func):
            return curr_idx, output_ids, dict.fromkeys(active_mod_fqns, ())

        op_idx = curr_idx
        # 5. Initialize the parent op ids of the in-place op for each of the active modules
        mod_op_parent_idxs: dict[str, int] = dict.fromkeys(active_mod_fqns, -1)
        for i, d in enumerate(self._sac_metadata):
            # 6. Find the first occurrence of a tensor corresponding to each module that
            # shares the same storage as the current tensor
            past_output_ids = d.output_ids
            if set(output_ids).issubset(set(past_output_ids)):
                for mod_fqn, op_parent_idx in mod_op_parent_idxs.items():
                    if op_parent_idx == -1:
                        if acm_stats := self._sac_mod_metadata.get(mod_fqn, None):
                            if i >= acm_stats.start_idx:
                                mod_op_parent_idxs[mod_fqn] = i
                        else:
                            assert mod_fqn == "Global"
                            mod_op_parent_idxs[mod_fqn] = i
        # 7. If no parent tensor is found, then it's probably an in-place op on the arguments,
        # so one can just store the current-op idx as the parent idx
        for mod_fqn, op_parent_idx in mod_op_parent_idxs.items():
            if op_parent_idx < 0:
                mod_op_parent_idxs[mod_fqn] = op_idx
        mod_inplace_info = {
            mod_fqn: (op_idx, mod_op_parent_idxs[mod_fqn])
            for mod_fqn in active_mod_fqns
        }
        return curr_idx, output_ids, mod_inplace_info  # type: ignore[return-value]

    def __torch_dispatch__(  # type: ignore[no-untyped-def]
        self, func, types, args=..., kwargs=None
    ):
        # 1. Get the runtime estimate
        out, op_time = self._estimate_runtime(func, args, kwargs)
        flat_outs, _ = tree_flatten(out)
        out_storages_cuda: set[UntypedStorage] = set()
        out_storages_cpu: set[UntypedStorage] = set()
        cuda_devices: set[torch.device] = set()
        for o in flat_outs:
            if isinstance(o, torch.Tensor):
                if o.device.type == "cuda":
                    out_storages_cuda.update(get_untyped_storages(o))
                    cuda_devices.add(o.device)
                else:
                    out_storages_cpu.update(get_untyped_storages(o))

        # Check if there's more than one CUDA device
        assert len(cuda_devices) <= 1, (
            f"{func.__name__}'s output has more than one CUDA device: {cuda_devices}"
        )

        # 2. Get the memory consumed by the output
        nbytes_cuda = sum(
            math.ceil(st.nbytes() / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE
            for st in out_storages_cuda
        )
        nbytes_cpu = sum(st.nbytes() for st in out_storages_cpu)
        nbytes = nbytes_cuda + nbytes_cpu
        # 3. Get the current operator index, output storage identifiers and in-place metadata
        out_storages = out_storages_cuda | out_storages_cpu
        curr_idx, output_ids, mod_inplace_info = self._get_inplace_metadata(
            func, out_storages
        )
        # 4. Determine if the function is in-place, a random op or view-like
        is_view_like = is_view_fn(func) or is_inplace_view_fn(func)
        is_rand_op = torch.Tag.nondeterministic_seeded in func.tags
        if is_view_like:
            nbytes = 0
        # sdpa has a non-deterministic seed, but might be deterministic
        # if no dropout is applied
        if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention":
            is_rand_op = kwargs.get("dropout_p", 0) != 0
        # 5. Create metadata information per active non-leaf module
        for mod_fqn in self._mod_tracker.parents:
            if mod_fqn in self._leaf_modules:
                continue
            acm = _SACMetadata(
                func=func,
                time_taken=op_time,
                memory_used=nbytes,
                curr_idx=curr_idx,
                output_ids=output_ids,
                inplace_info=mod_inplace_info[mod_fqn],
                is_view_like=is_view_like,
                is_rand_op=is_rand_op,
            )
            if acm_stats := self._sac_mod_metadata.get(mod_fqn, None):
                acm_stats.sac_metadata.append(acm)
            else:
                assert mod_fqn == "Global", (
                    f"Module {mod_fqn} not found in AC Mod Stats"
                )
                self._sac_metadata.append(acm)

        return out
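    # Illustrative grouping for the method below (hypothetical indices): with
    # sac_stats.inplace_ops = [(3, 1), (5, 1)], ops 3 and 5 write into the storage of
    # op 1, so the loop over inplace_op_to_group_head builds
    # inplace_op_groups = {1: {1, 3, 5}}, and the whole group is then stored or
    # recomputed as one unit.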
    def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta:
        # An inplace-op group is a set of inplace-ops that operate on the same underlying tensor storage.
        # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group.
        #    The top-most op can itself be an inplace-op or can be a non-inplace op.
        # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads.
        inplace_op_groups: dict[int, set[int]] = {}
        inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops)

        # Initialize inplace_op_groups using inplace_op_to_group_head
        for op_idx, group_head_idx in inplace_op_to_group_head.items():
            op_group = inplace_op_groups.setdefault(group_head_idx, {group_head_idx})
            op_group.add(op_idx)

        # Like in-place ops, all of the random ops in the function/module should be either recomputed or
        # saved as a group, because they affect the random seed generator. If force_store_random is set to
        # True, all of the random ops will be stored by default. For ease of manageability, we store the
        # top-most random op as the leader of the random_ops_group.
        random_ops_group: dict[int, set[int]] = {}
        random_group_head_idx = min(sac_stats.rand_ops, default=-1)
        has_rand_ops = bool(sac_stats.rand_ops)
        if has_rand_ops:
            random_ops_group[random_group_head_idx] = set(sac_stats.rand_ops)

        # 1. Random ops are stored if force_store_random is set
        # 2. View-like ops are recomputed by default
        # 3. For inplace_op_groups:
        #    a) If the head of this group is an in-place op, then we have to store the entire group.
        #    b) If any op in the group is random and force_store_random is set, then the entire group will be stored.
        #    c) If none of the ops in the group are random and the head of the group is not an in-place op, then
        #       this group can be considered for recomputation in its entirety.
        stored_ops: set[int] = set()
        recomputed_ops: set[int] = set()
        # Case 1:
        if has_rand_ops and sac_stats.force_store_random:
            stored_ops.add(random_group_head_idx)
        # Case 2:
        recomputed_ops.update(set(sac_stats.view_like_ops))

        for group_head_idx, op_group in inplace_op_groups.items():
            # Case 3a:
            if group_head_idx in inplace_op_to_group_head:
                stored_ops.add(group_head_idx)
            # Case 3b:
            if (
                sac_stats.force_store_random
                and len(op_group & set(sac_stats.rand_ops)) > 0
            ):
                stored_ops.add(group_head_idx)

        # The potential recompute candidates are populated as:
        recompute_candidates: set[int] = set()
        # 1) The random group head, if it is not stored
        if has_rand_ops and random_group_head_idx not in stored_ops:
            recompute_candidates.add(random_group_head_idx)
        # 2) The in-place op group heads that are not stored
        recompute_candidates.update(set(inplace_op_groups.keys()) - stored_ops)
        # 3) The non-in-place and non-random ops that are neither stored nor recomputed by default
        recompute_candidates.update(
            set(range(len(sac_stats.memory)))
            - recomputed_ops
            - stored_ops
            - set(inplace_op_to_group_head.keys())
            - set(sac_stats.rand_ops)
        )

        # We define msps for a recompute candidate as the ratio of memory/runtime,
        # a.k.a. the memory savings per unit of recompute time
        msps_meta: list[MSPS] = []
        for cand_idx in recompute_candidates:
            op_indices = {cand_idx}
            if cand_idx in inplace_op_groups:
                op_indices.update(inplace_op_groups[cand_idx])
            if has_rand_ops and cand_idx == random_group_head_idx:
                op_indices.update(sac_stats.rand_ops)

            mem = sum(sac_stats.memory[op_idx] for op_idx in op_indices)
            runtime = sum(sac_stats.runtimes[op_idx] for op_idx in op_indices)
            func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices}
            msps = (mem / runtime) if runtime > 0 else sys.float_info.max
            msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps))
        # We choose the candidates to be recomputed in order of decreasing msps
        msps_meta.sort(key=lambda x: x.msps, reverse=True)

        return SACGreedyOrderMeta(
            recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta
        )
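    # Illustrative greedy order (hypothetical numbers): given the candidates
    #   MSPS({"mm"}, 2, memory=1024, runtime=0.1, msps=10240.0)
    #   MSPS({"gelu"}, 4, memory=512, runtime=0.2, msps=2560.0)
    # msps_meta is sorted to [mm, gelu]; discarding "mm" first frees the most memory
    # per millisecond of recompute time, which is the order the trade-off curve below
    # walks.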
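    def _get_sac_tradeoff_pwlf_stats(
        self,
        sac_stats: SACStats,
        greedy_order_meta: SACGreedyOrderMeta,
        n_segments: int = 2,
        save_tradeoff_graph: bool = False,
        filename: str = "ac_tradeoff",
    ) -> SACTradeOffStats:
        try:
            import numpy as np  # type: ignore[import-not-found]
            import pwlf  # type: ignore[import-untyped, import-not-found]
        except ImportError as err:
            raise ImportError("Please install the pwlf and numpy packages.") from err

        stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = (
            greedy_order_meta.stored_ops,
            greedy_order_meta.recomputed_ops,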
            greedy_order_meta.inplace_op_groups,
            greedy_order_meta.random_ops_group,
            greedy_order_meta.msps_meta,
        )

        # 1. Initialize the discarded memory and recomputation runtime to the sum of the already chosen
        #    recomputed_ops
        recomp_indices: set[int] = set()
        for r_idx in recomputed_ops:
            recomp_indices.add(r_idx)
            if r_idx in inplace_op_groups:
                recomp_indices.update(inplace_op_groups[r_idx])
            if r_idx in random_ops_group:
                recomp_indices.update(random_ops_group[r_idx])

        discarded_mem = sum(sac_stats.memory[op_idx] for op_idx in recomp_indices)
        recomp_runtime = sum(sac_stats.runtimes[op_idx] for op_idx in recomp_indices)
        # 2. Initialize the max recomputation time and total recomputation memory
        sac_runtime = sum(sac_stats.runtimes)
        sac_memory = sum(sac_stats.memory)
        # 3. The trade-off curve stores key-value pairs of (discarded memory / total memory) to
        #    (recomputation time / total runtime) incurred.
        delta = 1e-2
        tradeoff_curve = OrderedDict()
        # 4. Initialize the trade-off curve with the stats of the already chosen recomputed_ops
        tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
            recomp_runtime / sac_runtime
        )
        # 5. Update the trade-off curve with the memory and runtime stats of SAC candidates in the
        #    greedy order of their ``MSPS``.
        for cand in msps_meta:
            discarded_mem += cand.memory
            recomp_runtime += cand.runtime
            tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
                recomp_runtime / sac_runtime
            )
        # 6. Finally, we add the memory and recomputation time of the always-stored ops.
        stored_indices: set[int] = set()
        for s_idx in stored_ops:
            stored_indices.add(s_idx)
            if s_idx in inplace_op_groups:
                stored_indices.update(inplace_op_groups[s_idx])
            if s_idx in random_ops_group:
                stored_indices.update(random_ops_group[s_idx])

        discarded_mem += sum(sac_stats.memory[op_idx] for op_idx in stored_indices)
        recomp_runtime += sum(sac_stats.runtimes[op_idx] for op_idx in stored_indices)
        tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
            recomp_runtime / sac_runtime
        )
        x_ = list(tradeoff_curve.keys())
        y_ = list(tradeoff_curve.values())
        # 7. We shift the y values to the left and the x values to the right to upper-bound the
        #    trade-off function
        # TODO: Write a better explanation why this needs to be done
        x = x_[: len(x_) - 1]
        y = y_[1:]
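        # Added interpretation (not from the original authors): pairing each x_i with
        # y_{i+1} assigns to a discarded-memory fraction the recompute time of the
        # *next* greedy step, e.g. curve points (0.2, 0.1) and (0.5, 0.3) contribute
        # the shifted point (0.2, 0.3). Interpolating through the shifted points thus
        # over-estimates recompute time, making the fit an upper bound on the true
        # trade-off.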
        tradeoff_pwlf = pwlf.PiecewiseLinFit(x, y)
        # 8. Fit a piecewise linear function with the specified number of segments to the trade-off curve.
        n_segments = max(min(len(x) - 2, n_segments), 1)
        tradeoff_pwlf.fit(n_segments=n_segments)

        # save prediction graph
        def save_prediction_graph(
            pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str
        ) -> None:
            try:
                import matplotlib.pyplot as plt  # type: ignore[import-not-found]
                import numpy as np  # type: ignore[import-not-found]
            except ImportError as err:
                raise ImportError(
                    "Install matplotlib and numpy using pip: pip install matplotlib numpy"
                ) from err
            # predict for the determined points
            xHat = np.linspace(min(x), max(x), num=10000)
            yHat = pwlf_.predict(xHat)

            # plot the results
            plt.figure()
            plt.plot(x, y, "o", label="Shifted")
            plt.plot(xHat, yHat, "-", label="Predicted")
            plt.plot(x_, y_, "x", label="Original")
            plt.ylabel("Recomp time / Total recomp time")
            plt.xlabel("Memory discarded / Total memory")
            plt.legend()
            plt.title(f"{filename}")
            plt.suptitle(
                f"Total Memory = {sac_memory} B Total Runtime = {sac_runtime:.4f} ms",
                fontsize=10,
            )
            folder_name = "tradeoff_graphs"
            if not os.path.exists(folder_name):
                os.makedirs(folder_name)
            # Save the plots in the folder
            plt.savefig(os.path.join(folder_name, f"{filename}.png"))

        if save_tradeoff_graph:
            save_prediction_graph(tradeoff_pwlf, x, y, filename)
        # 9. Obtain the slopes, intercepts and breakpoints of the fitted piecewise linear functions
        slopes = tradeoff_pwlf.calc_slopes().tolist()
        assert isinstance(tradeoff_pwlf.intercepts, np.ndarray) and isinstance(
            tradeoff_pwlf.fit_breaks, np.ndarray
        )
        intercepts = tradeoff_pwlf.intercepts.tolist()
        fit_breaks = tradeoff_pwlf.fit_breaks.tolist()
        return SACTradeOffStats(
            n_segments=n_segments,
            slopes=slopes,
            intercepts=intercepts,  # type: ignore[arg-type]
            fit_breaks=fit_breaks,  # type: ignore[arg-type]
            tradeoff_curve=tradeoff_curve,
            sac_memory=sac_memory,
            sac_runtime=sac_runtime,
        )
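    # Illustrative consumer (a sketch, not part of the original API): evaluating the
    # fitted trade-off from a ``SACTradeOffStats`` instance, assuming segment k covers
    # [fit_breaks[k], fit_breaks[k + 1]]:
    #
    #   def eval_tradeoff(stats: SACTradeOffStats, mem_frac: float) -> float:
    #       for k in range(stats.n_segments):
    #           if mem_frac <= stats.fit_breaks[k + 1]:
    #               return stats.intercepts[k] + stats.slopes[k] * mem_frac
    #       return stats.intercepts[-1] + stats.slopes[-1] * mem_frac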
""" print( f"Total Memory: {sum(sac_stats.memory)} B Total Runtime: {sum(sac_stats.runtimes)} ms" f" Store Random: {sac_stats.force_store_random}" ) table_data = [] op_parent = dict(sac_stats.inplace_ops) for i, fn_name in enumerate(sac_stats.func_names): row = [ str(i), fn_name, f"{sac_stats.runtimes[i]:.4f}", str(sac_stats.memory[i]), str(i in sac_stats.view_like_ops), str(i in sac_stats.rand_ops), str(i in sac_stats.saved_autograd_ops), str(op_parent.get(i, None)), ] table_data.append(row) # Define headers headers = [ "Op Idx", "Op Name", "Runtimes(ms)", "Memory (B)", "View-like", "Random", "Saved Autograd", "In-place", ] if print_tabular: _display_stats_tabular(headers, table_data) else: max_widths = [0 for _ in range(len(headers))] table_data.insert(0, headers) for row in table_data: for i, elem in enumerate(row): max_widths[i] = max(max_widths[i], len(elem)) for row in table_data: print( "\t".join( [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] ) ) def display_sac_tradeoff_stats( self, greedy_order_meta: SACGreedyOrderMeta, sac_stats: SACStats, print_tabular: bool = False, ) -> None: """ Displays the SAC trade-off statistics. Args: greedy_order_meta (SACGreedyOrderMeta): The SAC greedy order metadata. sac_stats (SACStats): The SAC statistics. print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. Prints: A table with the following columns: 1. Op Id(s): The operator index(es). 2. Op Name(s): The operator name(s). 3. Discarded Mem (%): The percentage of discarded memory. 4. Discarded Mem (B): The discarded memory in bytes. 5. Recomp time (%): The percentage of recomputed time. 6. Recomp time (ms): The recomputed time in milliseconds. 7. MSPS: The memory per second. 8. Always Stored: A flag indicating whether the operator is always stored. 9. Always Recomputed: A flag indicating whether the operator is always recomputed. If print_tabular is True, the table is printed in a tabular format. Otherwise, the table is printed in a plain text format. 
""" table_data = [] total_memory, total_runtime = sum(sac_stats.memory), sum(sac_stats.runtimes) discarded_mem: int = 0 recomp_runtime: float = 0.0 def append_row( op_indices: set[int], func_names: set[str], msps: Optional[float] = None, stored: Optional[bool] = False, recomputed: Optional[bool] = False, ) -> None: row = [ str(op_indices), str(func_names), f"{discarded_mem / total_memory:.4f}", str(discarded_mem), f"{recomp_runtime / total_runtime:.4f}", str(recomp_runtime), f"{msps:.2e}" if msps is not None else str(nan), str(stored), str(recomputed), ] table_data.append(row) stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( greedy_order_meta.stored_ops, greedy_order_meta.recomputed_ops, greedy_order_meta.inplace_op_groups, greedy_order_meta.random_ops_group, greedy_order_meta.msps_meta, ) for op_idx in recomputed_ops: op_indices: set[int] = {op_idx} if op_idx in inplace_op_groups: op_indices.update(inplace_op_groups[op_idx]) if op_idx in random_ops_group: op_indices.update(random_ops_group[op_idx]) discarded_mem += sum(sac_stats.memory[i] for i in op_indices) recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) func_names = {sac_stats.func_names[i] for i in op_indices} append_row(op_indices, func_names, recomputed=True) for cand in msps_meta: discarded_mem += cand.memory recomp_runtime += cand.runtime op_indices = {cand.op_idx} if cand.op_idx in inplace_op_groups: op_indices.update(inplace_op_groups[cand.op_idx]) if cand.op_idx in random_ops_group: op_indices.update(random_ops_group[cand.op_idx]) append_row(op_indices, cand.func_names, msps=cand.msps) for op_idx in stored_ops: op_indices = {op_idx} if op_idx in inplace_op_groups: op_indices.update(inplace_op_groups[op_idx]) if op_idx in random_ops_group: op_indices.update(random_ops_group[op_idx]) discarded_mem += sum(sac_stats.memory[i] for i in op_indices) recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) func_names = {sac_stats.func_names[i] for i in op_indices} append_row(op_indices, func_names, stored=True) headers = [ "Op Id(s)", "Op Name(s)", "Discarded Mem (%)", "Discarded Mem (B)", "Recomp time (%)", "Recomp time (ms)", "MSPS", "Always Stored", "Always Recomputed", ] if print_tabular: _display_stats_tabular(headers, table_data) else: max_widths = [0 for _ in range(len(headers))] table_data.insert(0, headers) for row in table_data: for i, elem in enumerate(row): max_widths[i] = max(max_widths[i], len(elem)) for row in table_data: print( "\t".join( [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] ) ) def pwlf_sac_tradeoff_curve( self, n_segments: int = 2, save_tradeoff_graphs: bool = False, ) -> None: """ Fits a piecewise linear function with the specified sumber of segments to the SAC trade-off curve of discarded memory vs recomputation time. Args: n_segments (int, optional): The number of segments to be used for fitting the piecewise linear function to the trade-off curve. Defaults to 2. save_tradeoff_graphs (bool, optional): Whether to save the trade-off graphs to file. Defaults to False. If save_tradeoff_graphs is True, the trade-off graphs are saved to file using the module FQN as the filename. 
""" for mod_fqn, sac_stats in self.sac_mod_stats.items(): self.sac_mod_tradeoff_stats[mod_fqn] = self._get_sac_tradeoff_pwlf_stats( sac_stats=sac_stats, greedy_order_meta=self.sac_mod_greedy_order_meta[mod_fqn], n_segments=n_segments, save_tradeoff_graph=save_tradeoff_graphs, filename=mod_fqn, ) def display_modulewise_sac_stats( self, depth: int = 2, print_tabular: bool = False ) -> None: """ Displays the SAC and trade-off statistics for each module. Args: depth (int, optional): The maximum depth of modules to display. Defaults to 2. print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. Prints: For each module with depth less than or equal to the specified depth: 1. The SAC statistics for the module (using display_sac_stats). 2. The SAC trade-off statistics for the module (using display_sac_tradeoff_stats). If print_tabular is True, the statistics are printed in a tabular format. Otherwise, the statistics are printed in a plain text format. """ for mod_fqn, sac_stats in self.sac_mod_stats.items(): mod_depth = mod_fqn.count(".") + 1 if mod_depth > depth: continue print(f"Module: {mod_fqn}") self.display_sac_stats(sac_stats, print_tabular) print(f"AC Trade-off for Module: {mod_fqn} MSPS = Memory/Runtime") self.display_sac_tradeoff_stats( self.sac_mod_greedy_order_meta[mod_fqn], sac_stats, print_tabular ) def __call__(self, estimate_mode_type: str) -> Self: """ Sets the estimate mode type. Currently supported modes: - "operator-level-benchmark": Estimates runtime using operator benchmarking. - "operator-level-cost-model": Estimates runtime using roofline cost model. Args: estimate_mode_type (str): The type of estimate mode to use. Returns: SACEstimator: The SAC estimator instance. Raises: NotImplementedError: If the estimate mode type is not supported. """ if estimate_mode_type == "operator-level-benchmark": self._estimate_runtime = RuntimeEstimator._benchmark_estimate elif estimate_mode_type == "operator-level-cost-model": self._estimate_runtime = RuntimeEstimator._roofline_estimate else: raise NotImplementedError( f"estimate_mode_type {estimate_mode_type} not supported" ) return self def __enter__(self) -> Self: # type: ignore[no-untyped-def] fake_mode = active_fake_mode() assert isinstance(fake_mode, FakeTensorMode), ( "SAC Estimator should be called in FakeTensorMode" ) RuntimeEstimator.fake_mode = fake_mode self._mod_tracker.register_user_hooks( pre_fw_hook=self._pre_fw_hook, post_fw_hook=self._post_fw_hook, ) self._mod_tracker.__enter__() self._saved_tensor_hook_ctx.__enter__() return super().__enter__() def __exit__(self, *args: Any) -> None: # type: ignore[no-untyped-def] self._saved_tensor_hook_ctx.__exit__() self._mod_tracker.__exit__(*args) super().__exit__(*args)