"""
Checkpoint writer functionality for machine learning models.

This module provides classes for writing checkpoints to storage: a writer
that saves one ``checkpoint_<global_rank>.pt`` file per rank, a configuration
dataclass for that writer, and a hook interface for custom actions performed
around the checkpoint commit.
"""

import abc
import logging
import os
from concurrent.futures import Future
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import torch

from .barriers import Barrier
from .types import RankInfo, STATE_DICT

logger = logging.getLogger(__name__)


class WriterHook(abc.ABC):
    """
    Abstract base class for checkpoint commit hooks.

    A commit hook provides callbacks that are executed before and after a
    checkpoint is committed to storage. This allows for custom actions to be
    performed at specific points in the checkpoint writing process, such as
    metadata updates, cleanup operations, or notifications.
    """

    @abc.abstractmethod
    def pre_commit(self, path: str, **kwargs: Any) -> None:
        """
        Performs actions before committing the checkpoint.

        Args:
            path: The checkpoint directory that was passed to
                ``CheckpointWriter.write``.
            **kwargs: Extra keyword arguments forwarded unchanged from
                ``CheckpointWriter.write``.
        """

    @abc.abstractmethod
    def post_commit(self, path: str, **kwargs: Any) -> None:
        """
        Performs actions after committing the checkpoint.

        Args:
            path: The checkpoint directory that was passed to
                ``CheckpointWriter.write``.
            **kwargs: Extra keyword arguments forwarded unchanged from
                ``CheckpointWriter.write``.
        """


@dataclass
class CheckpointWriterConfig:
    """
    Configuration options for the CheckpointWriter.

    Attributes:
        write_barrier_timeout_secs: Maximum time in seconds to wait for all
            ranks to reach the checkpoint barrier before timing out. Default
            is 600 seconds. Note: ``CheckpointWriter.write`` only logs this
            value; the barrier itself is invoked without a timeout argument,
            so the effective timeout is whatever the barrier was constructed
            with.
    """

    write_barrier_timeout_secs: int = 600


class CheckpointWriter:
    """
    Handles writing state dictionaries to storage.

    Each rank writes its state dict to ``<path>/checkpoint_<global_rank>.pt``.
    An optional synchronization barrier lets all ranks in a distributed
    setting wait for each other before the checkpoint is considered
    committed, and an optional ``WriterHook`` is invoked around that commit.
    """

    def __init__(
        self,
        config: CheckpointWriterConfig,
        rank_info: RankInfo,
        barrier: Optional[Barrier] = None,
        commit_hook: Optional[WriterHook] = None,
    ):
        """
        Initialize a CheckpointWriter.

        Args:
            config: Configuration options for the checkpoint writer.
            rank_info: Information about the current rank in a distributed
                setting; ``rank_info.global_rank`` names the output file.
            barrier: Optional synchronization barrier for distributed
                checkpointing. Note: The barrier should be initialized with
                the appropriate barrier_prefix and timeout_secs parameters.
            commit_hook: Optional hook for custom actions before and after
                checkpoint commits.
        """
        self._config = config
        self._rank_info = rank_info
        self._commit_hook = commit_hook
        self._barrier = barrier

    def write(
        self,
        path: str,
        state_dict: STATE_DICT,
        **kwargs: Any,
    ) -> Optional[Future[None]]:
        """
        Writes the state_dict to storage.

        Order of operations:
          1. Save ``state_dict`` to ``<path>/checkpoint_<global_rank>.pt``
             (creating ``path`` if needed).
          2. Call ``commit_hook.pre_commit`` (if a hook is configured).
          3. Wait at the barrier (if one is configured).
          4. Call ``commit_hook.post_commit`` (if a hook is configured).

        Args:
            path (str): The directory to write the checkpoint to.
            state_dict (STATE_DICT): The state_dict to write.
            **kwargs: Additional keyword arguments passed to hooks.

        Returns:
            Optional[Future[None]]: A future for tracking the write operation,
            if applicable. This implementation writes synchronously and
            always returns None.
        """
        rank = self._rank_info.global_rank
        logger.debug(
            "Writing checkpoint to %s for rank %s",
            path,
            rank,
        )

        dir_path = Path(path)
        full_path = dir_path / f"checkpoint_{rank}.pt"
        # Create the target directory directly. Deriving it via
        # os.path.dirname(full_path) yields "" when path is "." or "" (pathlib
        # drops the "." component), and os.makedirs("") raises
        # FileNotFoundError.
        os.makedirs(dir_path, exist_ok=True)
        torch.save(state_dict, full_path)
        logger.debug("Successfully saved checkpoint file to %s", full_path)

        # Execute pre-commit hook if available (local file is already saved,
        # other ranks may still be writing).
        commit_hook = self._commit_hook
        if commit_hook is not None:
            logger.debug("Executing pre-commit hook for %s", path)
            commit_hook.pre_commit(path, **kwargs)

        # Wait for all ranks to finish writing if barrier is available.
        barrier = self._barrier
        if barrier is not None:
            # The configured timeout is only reported here; execute_barrier()
            # takes no timeout, so the barrier's own setting applies.
            logger.info(
                "Waiting for all ranks at barrier with timeout %ss",
                self._config.write_barrier_timeout_secs,
            )
            barrier.execute_barrier()
            logger.info("All ranks passed barrier")
        else:
            logger.info("No barrier configured, skipping synchronization")

        # Execute commit hook if available (after all ranks are synchronized
        # when a barrier is configured).
        if commit_hook is not None:
            logger.debug("Executing commit hook for %s", path)
            commit_hook.post_commit(path, **kwargs)

        logger.info(
            "Successfully wrote checkpoint to %s for rank %s",
            path,
            rank,
        )
        return None

    def close(self) -> None:
        """
        Close the writer and release any resources.

        This is a no-op for the base CheckpointWriter but may be overridden by
        subclasses that need to perform cleanup.
        """
        logger.debug("Closing checkpoint writer")