from __future__ import annotations

import logging
from typing import Any
from typing import TYPE_CHECKING
import warnings

from tqdm.auto import tqdm

from optuna import logging as optuna_logging


if TYPE_CHECKING:
    from optuna.study import Study

# Handler most recently installed by a ``_ProgressBar``. Kept at module level for backward
# compatibility; each ``_ProgressBar`` also remembers the handler it installed itself so that
# ``close`` removes exactly that handler.
_tqdm_handler: _TqdmLoggingHandler | None = None


# Reference: https://gist.github.com/hvy/8b80c2cedf02b15c24f85d1fa17ebe02
class _TqdmLoggingHandler(logging.StreamHandler):
    """Logging handler that emits formatted records through ``tqdm.write``.

    ``tqdm.write`` prints without overlapping active tqdm bars, so log output and the
    progress bar can coexist on the same terminal.
    """

    def emit(self, record: Any) -> None:
        try:
            msg = self.format(record)
            tqdm.write(msg)
            self.flush()
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            # Standard ``logging.Handler`` convention: report the failure via
            # ``handleError`` instead of propagating it into the logging call site.
            self.handleError(record)


class _ProgressBar:
    """Progress Bar implementation for :func:`~optuna.study.Study.optimize` on the top of `tqdm`.

    While the bar is active, Optuna's default log handler is disabled and a
    :class:`_TqdmLoggingHandler` is attached to the library root logger instead;
    :meth:`close` reverses both steps.

    Args:
        is_valid:
            Whether to show progress bars in :func:`~optuna.study.Study.optimize`.
        n_trials:
            The number of trials.
        timeout:
            Stop study after the given number of second(s).
    """

    def __init__(
        self,
        is_valid: bool,
        n_trials: int | None = None,
        timeout: float | None = None,
    ) -> None:
        global _tqdm_handler

        if is_valid and n_trials is None and timeout is None:
            warnings.warn("Progress bar won't be displayed because n_trials and timeout are None.")

        # Only "both unset" disables the bar (matching the warning above). Compare against
        # ``None`` explicitly so that a falsy ``n_trials=0`` is not mistaken for "unset".
        self._is_valid = is_valid and (n_trials is not None or timeout is not None)
        self._n_trials = n_trials
        self._timeout = timeout
        self._last_elapsed_seconds = 0.0
        # The handler this instance installed; ``close`` removes exactly this one.
        self._tqdm_handler: _TqdmLoggingHandler | None = None

        if self._is_valid:
            if self._n_trials is not None:
                self._progress_bar = tqdm(total=self._n_trials)
            elif self._timeout is not None:
                # Time-based bar: the total is the timeout in seconds, shown as
                # "<elapsed>/<formatted timeout>".
                total = tqdm.format_interval(self._timeout)
                fmt = "{desc} {percentage:3.0f}%|{bar}| {elapsed}/" + total
                self._progress_bar = tqdm(total=self._timeout, bar_format=fmt)
            else:
                assert False

            handler = _TqdmLoggingHandler()
            handler.setLevel(logging.INFO)
            handler.setFormatter(optuna_logging.create_default_formatter())
            self._tqdm_handler = handler
            _tqdm_handler = handler

            # Route library logs through tqdm instead of the default stream handler.
            optuna_logging.disable_default_handler()
            optuna_logging._get_library_root_logger().addHandler(handler)

    def update(self, elapsed_seconds: float, study: Study) -> None:
        """Update the progress bars if ``is_valid`` is :obj:`True`.

        Args:
            elapsed_seconds:
                The time past since :func:`~optuna.study.Study.optimize` started.
            study:
                The current study object.
        """
        if not self._is_valid:
            return

        if not study._is_multi_objective():
            # Not updating the progress bar when there are no complete trial.
            try:
                msg = (
                    f"Best trial: {study.best_trial.number}. "
                    f"Best value: {study.best_value:.6g}"
                )
                self._progress_bar.set_description(msg)
            except ValueError:
                pass

        if self._n_trials is not None:
            self._progress_bar.update(1)
            if self._timeout is not None:
                self._progress_bar.set_postfix_str(
                    "{:.02f}/{} seconds".format(elapsed_seconds, self._timeout)
                )
        elif self._timeout is not None:
            # Clip elapsed time to the bar's total (``timeout``) to avoid tqdm warnings, and
            # only ever advance. Tracking the *clipped* value means repeated calls after the
            # timeout add nothing instead of a negative increment. A smaller elapsed value
            # than already seen is ignored rather than rewinding the bar.
            clipped_elapsed = min(elapsed_seconds, self._timeout)
            time_diff = clipped_elapsed - self._last_elapsed_seconds
            if time_diff > 0:
                self._progress_bar.update(time_diff)
                self._last_elapsed_seconds = clipped_elapsed
        else:
            assert False

    def close(self) -> None:
        """Close progress bars.

        Also detaches the logging handler installed by this instance and re-enables
        Optuna's default log handler.
        """
        if self._is_valid:
            self._progress_bar.close()
            assert self._tqdm_handler is not None
            optuna_logging._get_library_root_logger().removeHandler(self._tqdm_handler)
            optuna_logging.enable_default_handler()