from __future__ import annotations

from collections.abc import Container
from collections.abc import Sequence
import copy
from datetime import datetime
import threading
from typing import Any
import uuid

import optuna
from optuna import distributions  # NOQA
from optuna._typing import JSONSerializable
from optuna.exceptions import DuplicatedStudyError
from optuna.storages import BaseStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.study._frozen import FrozenStudy
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


_logger = optuna.logging.get_logger(__name__)


class InMemoryStorage(BaseStorage):
    """Storage class that stores data in memory of the Python process.

    Every public method acquires ``self._lock`` (a :class:`threading.RLock`) for the
    duration of the call. Stored trials are never mutated in place by the
    ``set_trial_*`` methods: each one shallow-copies the trial (and the container it
    is about to modify) and swaps the new object into the study's trial list.

    Example:

        Create an :class:`~optuna.storages.InMemoryStorage` instance.

        .. testcode::

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", -100, 100)
                return x**2


            storage = optuna.storages.InMemoryStorage()

            study = optuna.create_study(storage=storage)
            study.optimize(objective, n_trials=10)
    """

    def __init__(self) -> None:
        # trial_id -> (study_id, trial number within that study).
        self._trial_id_to_study_id_and_number: dict[int, tuple[int, int]] = {}
        self._study_name_to_id: dict[str, int] = {}
        self._studies: dict[int, _StudyInfo] = {}

        # Last ids handed out; -1 means that none has been issued yet.
        self._max_study_id = -1
        self._max_trial_id = -1

        # Re-entrant because locked methods call other locked methods
        # (e.g. `get_best_trial` -> `get_trial`).
        self._lock = threading.RLock()

        # Per study: index from which `get_all_trials` starts scanning for WAITING trials.
        self._prev_waiting_trial_number: dict[int, int] = {}

    def __getstate__(self) -> dict[Any, Any]:
        """Return the instance state for pickling, without the lock object."""
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state: dict[Any, Any]) -> None:
        """Restore pickled state and create a fresh lock for this instance."""
        self.__dict__.update(state)
        self._lock = threading.RLock()

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: str | None = None
    ) -> int:
        """Register a new study and return its id.

        Args:
            directions: Optimization directions; stored as a new list.
            study_name: Name of the study. When :obj:`None`, a name is generated from
                ``DEFAULT_STUDY_NAME_PREFIX`` and a random UUID.

        Returns:
            The id of the created study.

        Raises:
            :exc:`~optuna.exceptions.DuplicatedStudyError`:
                If a study named ``study_name`` already exists.
        """
        with self._lock:
            # Validate before allocating an id so that a rejected call does not
            # consume a study id.
            if study_name is not None:
                if study_name in self._study_name_to_id:
                    raise DuplicatedStudyError
            else:
                study_uuid = str(uuid.uuid4())
                study_name = DEFAULT_STUDY_NAME_PREFIX + study_uuid

            self._max_study_id += 1
            study_id = self._max_study_id

            self._studies[study_id] = _StudyInfo(study_name, list(directions))
            self._study_name_to_id[study_name] = study_id
            self._prev_waiting_trial_number[study_id] = 0

            _logger.info(f"A new study created in memory with name: {study_name}")

            return study_id

    def delete_study(self, study_id: int) -> None:
        """Remove a study together with all of its trials and lookup entries.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            for trial in self._studies[study_id].trials:
                del self._trial_id_to_study_id_and_number[trial._trial_id]
            study_name = self._studies[study_id].name
            del self._study_name_to_id[study_name]
            del self._studies[study_id]
            del self._prev_waiting_trial_number[study_id]

    def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None:
        """Set ``key`` to ``value`` in the study's user attributes.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            self._studies[study_id].user_attrs[key] = value

    def set_study_system_attr(self, study_id: int, key: str, value: JSONSerializable) -> None:
        """Set ``key`` to ``value`` in the study's system attributes.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            self._studies[study_id].system_attrs[key] = value

    def get_study_id_from_name(self, study_name: str) -> int:
        """Return the id of the study called ``study_name``.

        Raises:
            :exc:`KeyError`: If no study has that name.
        """
        with self._lock:
            if study_name not in self._study_name_to_id:
                raise KeyError(f"No such study {study_name}.")

            return self._study_name_to_id[study_name]

    def get_study_name_from_id(self, study_id: int) -> str:
        """Return the name of the study with ``study_id``.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].name

    def get_study_directions(self, study_id: int) -> list[StudyDirection]:
        """Return the study's directions.

        The returned list is the stored object itself, not a copy.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].directions

    def get_study_user_attrs(self, study_id: int) -> dict[str, Any]:
        """Return the study's user attributes.

        The returned dict is the stored object itself, not a copy.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].user_attrs

    def get_study_system_attrs(self, study_id: int) -> dict[str, Any]:
        """Return the study's system attributes.

        The returned dict is the stored object itself, not a copy.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].system_attrs

    def get_all_studies(self) -> list[FrozenStudy]:
        """Return a :class:`FrozenStudy` snapshot for every stored study."""
        with self._lock:
            return [self._build_frozen_study(study_id) for study_id in self._studies]

    def _build_frozen_study(self, study_id: int) -> FrozenStudy:
        """Build a :class:`FrozenStudy` for ``study_id`` with deep-copied attributes.

        Does not acquire the lock and does not validate ``study_id``; the caller is
        expected to have done both.
        """
        study = self._studies[study_id]
        return FrozenStudy(
            study_name=study.name,
            direction=None,
            directions=study.directions,
            user_attrs=copy.deepcopy(study.user_attrs),
            system_attrs=copy.deepcopy(study.system_attrs),
            study_id=study_id,
        )

    def create_new_trial(self, study_id: int, template_trial: FrozenTrial | None = None) -> int:
        """Append a new trial to a study and return its trial id.

        Args:
            study_id: Id of the study that receives the trial.
            template_trial: Trial to deep-copy as the new trial. When :obj:`None`, an
                empty trial in the ``RUNNING`` state is created instead.

        Returns:
            The id of the created trial.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            if template_trial is None:
                trial = self._create_running_trial()
            else:
                trial = copy.deepcopy(template_trial)

            trial_id = self._max_trial_id + 1
            self._max_trial_id += 1
            # The trial number is the trial's index in the study's trial list.
            trial.number = len(self._studies[study_id].trials)
            trial._trial_id = trial_id
            self._trial_id_to_study_id_and_number[trial_id] = (study_id, trial.number)
            self._studies[study_id].trials.append(trial)
            # A template may already be COMPLETE, so the best trial may change.
            self._update_cache(trial_id, study_id)
            return trial_id

    @staticmethod
    def _create_running_trial() -> FrozenTrial:
        """Return an empty ``RUNNING`` trial whose id and number are placeholders."""
        return FrozenTrial(
            trial_id=-1,  # dummy value.
            number=-1,  # dummy value.
            state=TrialState.RUNNING,
            params={},
            distributions={},
            user_attrs={},
            system_attrs={},
            value=None,
            intermediate_values={},
            datetime_start=datetime.now(),
            datetime_complete=None,
        )

    def set_trial_param(
        self,
        trial_id: int,
        param_name: str,
        param_value_internal: float,
        distribution: distributions.BaseDistribution,
    ) -> None:
        """Record a parameter value and its distribution on a trial.

        The value is stored in external representation, obtained through
        ``distribution.to_external_repr``. The distribution is also recorded at study
        level, after being checked for compatibility with the one previously recorded
        under the same name (if any).

        Raises:
            :exc:`KeyError`: If no trial with ``trial_id`` exists.

        Any exception raised by ``check_trial_is_updatable`` or
        ``check_distribution_compatibility`` propagates to the caller.
        """
        with self._lock:
            trial = self._get_trial(trial_id)

            self.check_trial_is_updatable(trial_id, trial.state)

            study_id = self._trial_id_to_study_id_and_number[trial_id][0]
            # Check param distribution compatibility with previous trial(s).
            if param_name in self._studies[study_id].param_distribution:
                distributions.check_distribution_compatibility(
                    self._studies[study_id].param_distribution[param_name], distribution
                )

            # Set param distribution.
            self._studies[study_id].param_distribution[param_name] = distribution

            # Set param.
            trial = copy.copy(trial)
            trial.params = copy.copy(trial.params)
            trial.params[param_name] = distribution.to_external_repr(param_value_internal)
            trial.distributions = copy.copy(trial.distributions)
            trial.distributions[param_name] = distribution
            self._set_trial(trial_id, trial)

    def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
        """Return the id of the trial numbered ``trial_number`` in a study.

        Raises:
            :exc:`KeyError`:
                If no study with ``study_id`` exists, or if ``trial_number`` is
                negative or not smaller than the number of trials in the study.
        """
        with self._lock:
            self._check_study_id(study_id)

            trials = self._studies[study_id].trials
            # Reject negative numbers explicitly: they would otherwise index from the
            # end of the list and resolve to an unrelated trial.
            if not 0 <= trial_number < len(trials):
                raise KeyError(
                    f"No trial with trial number {trial_number} exists in study with "
                    f"study_id {study_id}."
                )

            trial = trials[trial_number]
            assert trial.number == trial_number

            return trial._trial_id

    def get_trial_number_from_id(self, trial_id: int) -> int:
        """Return the number of the trial with ``trial_id`` within its study.

        Raises:
            :exc:`KeyError`: If no trial with ``trial_id`` exists.
        """
        with self._lock:
            self._check_trial_id(trial_id)

            return self._trial_id_to_study_id_and_number[trial_id][1]

    def get_best_trial(self, study_id: int) -> FrozenTrial:
        """Return the cached best trial of a study.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
            :exc:`ValueError`: If the study has no cached best trial yet.
            :exc:`RuntimeError`:
                If a best trial is cached but the study has more than one direction.
        """
        with self._lock:
            self._check_study_id(study_id)

            best_trial_id = self._studies[study_id].best_trial_id

            if best_trial_id is None:
                raise ValueError("No trials are completed yet.")
            elif len(self._studies[study_id].directions) > 1:
                raise RuntimeError(
                    "Best trial can be obtained only for single-objective optimization."
                )
            return self.get_trial(best_trial_id)

    def get_trial_param(self, trial_id: int, param_name: str) -> float:
        """Return a trial's parameter value in internal representation.

        Raises:
            :exc:`KeyError`:
                If no trial with ``trial_id`` exists or the trial has no parameter
                named ``param_name``.
        """
        with self._lock:
            trial = self._get_trial(trial_id)

            distribution = trial.distributions[param_name]
            return distribution.to_internal_repr(trial.params[param_name])

    def set_trial_state_values(
        self, trial_id: int, state: TrialState, values: Sequence[float] | None = None
    ) -> bool:
        """Update a trial's state and, optionally, its objective values.

        Args:
            trial_id: Id of the trial to update.
            state: New state of the trial.
            values: New objective values; left untouched when :obj:`None`.

        Returns:
            :obj:`False` when ``state`` is ``RUNNING`` but the trial is not currently
            ``WAITING`` (nothing is modified in that case), :obj:`True` otherwise.

        Raises:
            :exc:`KeyError`: If no trial with ``trial_id`` exists.

        Any exception raised by ``check_trial_is_updatable`` propagates to the caller.
        """
        with self._lock:
            trial = copy.copy(self._get_trial(trial_id))
            self.check_trial_is_updatable(trial_id, trial.state)

            if state == TrialState.RUNNING and trial.state != TrialState.WAITING:
                return False

            trial.state = state

            if values is not None:
                trial.values = values

            if state == TrialState.RUNNING:
                trial.datetime_start = datetime.now()

            if state.is_finished():
                trial.datetime_complete = datetime.now()
                self._set_trial(trial_id, trial)
                # Only a finished trial can change which trial is the best one.
                study_id = self._trial_id_to_study_id_and_number[trial_id][0]
                self._update_cache(trial_id, study_id)
            else:
                self._set_trial(trial_id, trial)

            return True

    def _update_cache(self, trial_id: int, study_id: int) -> None:
        """Update the study's cached ``best_trial_id`` in light of ``trial_id``.

        Nothing happens unless the trial is ``COMPLETE``. The first complete trial
        always becomes the cached best trial. After that, the cache is only updated
        for studies with a single direction, and only when the new value is strictly
        better than the cached one.

        Does not acquire the lock up front; callers in this class invoke it while
        holding ``self._lock``.
        """
        trial = self._get_trial(trial_id)

        if trial.state != TrialState.COMPLETE:
            return

        best_trial_id = self._studies[study_id].best_trial_id
        if best_trial_id is None:
            self._studies[study_id].best_trial_id = trial_id
            return

        _directions = self.get_study_directions(study_id)
        if len(_directions) > 1:
            return
        direction = _directions[0]

        best_trial = self._get_trial(best_trial_id)
        assert best_trial is not None
        if best_trial.value is None:
            self._studies[study_id].best_trial_id = trial_id
            return
        # Complete trials do not have `None` values.
        assert trial.value is not None
        best_value = best_trial.value
        new_value = trial.value

        if direction == StudyDirection.MAXIMIZE:
            if best_value < new_value:
                self._studies[study_id].best_trial_id = trial_id
        else:
            if best_value > new_value:
                self._studies[study_id].best_trial_id = trial_id

    def set_trial_intermediate_value(
        self, trial_id: int, step: int, intermediate_value: float
    ) -> None:
        """Store ``intermediate_value`` for ``step`` on a trial.

        Raises:
            :exc:`KeyError`: If no trial with ``trial_id`` exists.

        Any exception raised by ``check_trial_is_updatable`` propagates to the caller.
        """
        with self._lock:
            trial = self._get_trial(trial_id)
            self.check_trial_is_updatable(trial_id, trial.state)

            trial = copy.copy(trial)
            trial.intermediate_values = copy.copy(trial.intermediate_values)
            trial.intermediate_values[step] = intermediate_value
            self._set_trial(trial_id, trial)

    def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
        """Set ``key`` to ``value`` in a trial's user attributes.

        Raises:
            :exc:`KeyError`: If no trial with ``trial_id`` exists.

        Any exception raised by ``check_trial_is_updatable`` propagates to the caller.
        """
        with self._lock:
            # `_get_trial` validates `trial_id`, so no separate check is needed.
            trial = self._get_trial(trial_id)
            self.check_trial_is_updatable(trial_id, trial.state)

            trial = copy.copy(trial)
            trial.user_attrs = copy.copy(trial.user_attrs)
            trial.user_attrs[key] = value
            self._set_trial(trial_id, trial)

    def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
        """Set ``key`` to ``value`` in a trial's system attributes.

        Raises:
            :exc:`KeyError`: If no trial with ``trial_id`` exists.

        Any exception raised by ``check_trial_is_updatable`` propagates to the caller.
        """
        with self._lock:
            trial = self._get_trial(trial_id)
            self.check_trial_is_updatable(trial_id, trial.state)

            trial = copy.copy(trial)
            trial.system_attrs = copy.copy(trial.system_attrs)
            trial.system_attrs[key] = value
            self._set_trial(trial_id, trial)

    def get_trial(self, trial_id: int) -> FrozenTrial:
        """Return the stored trial with ``trial_id`` (the stored object, not a copy).

        Raises:
            :exc:`KeyError`: If no trial with ``trial_id`` exists.
        """
        with self._lock:
            return self._get_trial(trial_id)

    def _get_trial(self, trial_id: int) -> FrozenTrial:
        """Look up a trial by id without acquiring the lock.

        Raises:
            :exc:`KeyError`: If no trial with ``trial_id`` exists.
        """
        self._check_trial_id(trial_id)
        study_id, trial_number = self._trial_id_to_study_id_and_number[trial_id]
        return self._studies[study_id].trials[trial_number]

    def _set_trial(self, trial_id: int, trial: FrozenTrial) -> None:
        """Replace the stored trial for ``trial_id`` with ``trial``.

        Does not acquire the lock and does not validate ``trial_id``.
        """
        study_id, trial_number = self._trial_id_to_study_id_and_number[trial_id]
        self._studies[study_id].trials[trial_number] = trial

    def get_all_trials(
        self,
        study_id: int,
        deepcopy: bool = True,
        states: Container[TrialState] | None = None,
    ) -> list[FrozenTrial]:
        """Return the trials of a study, optionally filtered by state.

        Args:
            study_id: Id of the study.
            deepcopy: When :obj:`True`, the returned trials are deep copies. When
                :obj:`False`, a new list holding the stored trial objects is returned.
            states: Only trials whose state is contained in ``states`` are returned;
                :obj:`None` disables filtering.

        Returns:
            A new list of trials.

        Raises:
            :exc:`KeyError`: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            # Optimized retrieval of trials in the WAITING state to improve performance
            # for the call, `get_all_trials(states=(TrialState.WAITING,))`.
            if states == (TrialState.WAITING,):
                trials: list[FrozenTrial] = []
                for trial in self._studies[study_id].trials[
                    self._prev_waiting_trial_number[study_id] :
                ]:
                    if trial.state == TrialState.WAITING:
                        if not trials:
                            # Next scan starts at the first WAITING trial found.
                            self._prev_waiting_trial_number[study_id] = trial.number
                        trials.append(trial)
                if not trials:
                    # Nothing is waiting: next scan only needs to look at new trials.
                    self._prev_waiting_trial_number[study_id] = len(self._studies[study_id].trials)
            else:
                trials = self._studies[study_id].trials
                if states is not None:
                    trials = [t for t in trials if t.state in states]

            if deepcopy:
                trials = copy.deepcopy(trials)
            else:
                # This copy is required for the replacing trick in `set_trial_xxx`.
                trials = copy.copy(trials)

            return trials

    def _check_study_id(self, study_id: int) -> None:
        """Raise :exc:`KeyError` if no study with ``study_id`` exists."""
        if study_id not in self._studies:
            raise KeyError(f"No study with study_id {study_id} exists.")

    def _check_trial_id(self, trial_id: int) -> None:
        """Raise :exc:`KeyError` if no trial with ``trial_id`` exists."""
        if trial_id not in self._trial_id_to_study_id_and_number:
            raise KeyError(f"No trial with trial_id {trial_id} exists.")


class _StudyInfo:
    """Mutable per-study record held by :class:`InMemoryStorage`."""

    def __init__(self, name: str, directions: list[StudyDirection]) -> None:
        # Trials of the study; a trial's list index equals its trial number.
        self.trials: list[FrozenTrial] = []
        # Most recently recorded distribution for each parameter name.
        self.param_distribution: dict[str, distributions.BaseDistribution] = {}
        self.user_attrs: dict[str, Any] = {}
        self.system_attrs: dict[str, Any] = {}
        self.name: str = name
        self.directions: list[StudyDirection] = directions
        # Id of the cached best trial; `None` until a trial completes.
        self.best_trial_id: int | None = None