from __future__ import annotations

from collections.abc import Container
from collections.abc import Iterable
from collections.abc import Sequence
import copy
import datetime
import enum
import pickle
import threading
from typing import Any
import uuid

import optuna
from optuna._typing import JSONSerializable
from optuna.distributions import BaseDistribution
from optuna.distributions import check_distribution_compatibility
from optuna.distributions import distribution_to_json
from optuna.distributions import json_to_distribution
from optuna.exceptions import DuplicatedStudyError
from optuna.exceptions import UpdateFinishedTrialError
from optuna.storages import BaseStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.storages.journal._base import BaseJournalBackend
from optuna.storages.journal._base import BaseJournalSnapshot
from optuna.study._frozen import FrozenStudy
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


_logger = optuna.logging.get_logger(__name__)

NOT_FOUND_MSG = "Record does not exist."
UNUPDATABLE_MSG = "Trial#{trial_number} has already finished and can not be updated."

# A heuristic interval: a snapshot is dumped every SNAPSHOT_INTERVAL-th created study or trial.
SNAPSHOT_INTERVAL = 100


class JournalOperation(enum.IntEnum):
    CREATE_STUDY = 0
    DELETE_STUDY = 1
    SET_STUDY_USER_ATTR = 2
    SET_STUDY_SYSTEM_ATTR = 3
    CREATE_TRIAL = 4
    SET_TRIAL_PARAM = 5
    SET_TRIAL_STATE_VALUES = 6
    SET_TRIAL_INTERMEDIATE_VALUE = 7
    SET_TRIAL_USER_ATTR = 8
    SET_TRIAL_SYSTEM_ATTR = 9


class JournalStorage(BaseStorage):
    """Storage class for the journal storage backend.

    Note that library users can instantiate this class, but the attributes
    provided by this class are not supposed to be directly accessed by them.

    Journal storage writes a record of every operation to the backend as it is executed and,
    at the same time, keeps the latest snapshot of the database in memory. If the database
    crashes for any reason, the storage can re-establish the contents in memory by replaying
    the stored operations from the beginning.

    Journal storage has several benefits over conventional value-logging storages:

    1. The number of I/O operations can be reduced because of the larger granularity of logs.
    2. Journal storage has a simpler backend API than value-logging storage.
    3. Journal storage keeps a snapshot in memory, so no additional cache is needed.

    Example:

        .. code::

            import optuna


            def objective(trial):
                ...


            storage = optuna.storages.JournalStorage(
                optuna.storages.journal.JournalFileBackend("./optuna_journal_storage.log")
            )

            study = optuna.create_study(storage=storage)
            study.optimize(objective)

    In a Windows environment, an error message "A required privilege is not held by the
    client" may appear. In this case, you can solve the problem by creating the storage
    with :class:`~optuna.storages.journal.JournalFileOpenLock` as follows.
    .. code::

        file_path = "./optuna_journal_storage.log"
        lock_obj = optuna.storages.journal.JournalFileOpenLock(file_path)

        storage = optuna.storages.JournalStorage(
            optuna.storages.journal.JournalFileBackend(file_path, lock_obj=lock_obj),
        )
    """

    def __init__(self, log_storage: BaseJournalBackend) -> None:
        self._worker_id_prefix = str(uuid.uuid4()) + "-"
        self._backend = log_storage
        self._thread_lock = threading.Lock()
        self._replay_result = JournalStorageReplayResult(self._worker_id_prefix)

        with self._thread_lock:
            if isinstance(self._backend, BaseJournalSnapshot):
                snapshot = self._backend.load_snapshot()
                if snapshot is not None:
                    self.restore_replay_result(snapshot)
            self._sync_with_backend()

    def __getstate__(self) -> dict[Any, Any]:
        state = self.__dict__.copy()
        del state["_worker_id_prefix"]
        del state["_replay_result"]
        del state["_thread_lock"]
        return state

    def __setstate__(self, state: dict[Any, Any]) -> None:
        self.__dict__.update(state)
        self._worker_id_prefix = str(uuid.uuid4()) + "-"
        self._replay_result = JournalStorageReplayResult(self._worker_id_prefix)
        self._thread_lock = threading.Lock()

    def restore_replay_result(self, snapshot: bytes) -> None:
        try:
            r: JournalStorageReplayResult | None = pickle.loads(snapshot)
        except (pickle.UnpicklingError, KeyError):
            _logger.warning("Failed to restore `JournalStorageReplayResult`.")
            return

        if r is None:
            return
        if not isinstance(r, JournalStorageReplayResult):
            _logger.warning("The restored object is not `JournalStorageReplayResult`.")
            return

        r._worker_id_prefix = self._worker_id_prefix
        r._worker_id_to_owned_trial_id = {}
        r._last_created_trial_id_by_this_process = -1
        self._replay_result = r

    def _write_log(self, op_code: int, extra_fields: dict[str, Any]) -> None:
        worker_id = self._replay_result.worker_id
        self._backend.append_logs([{"op_code": op_code, "worker_id": worker_id, **extra_fields}])

    def _sync_with_backend(self) -> None:
        logs = self._backend.read_logs(self._replay_result.log_number_read)
        self._replay_result.apply_logs(logs)

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: str | None = None
    ) -> int:
        study_name = study_name or DEFAULT_STUDY_NAME_PREFIX + str(uuid.uuid4())

        with self._thread_lock:
            self._write_log(
                JournalOperation.CREATE_STUDY, {"study_name": study_name, "directions": directions}
            )
            self._sync_with_backend()

            for frozen_study in self._replay_result.get_all_studies():
                if frozen_study.study_name != study_name:
                    continue
                _logger.info("A new study created in Journal with name: {}".format(study_name))
                study_id = frozen_study._study_id

                # Dump snapshot here.
                if (
                    isinstance(self._backend, BaseJournalSnapshot)
                    and study_id != 0
                    and study_id % SNAPSHOT_INTERVAL == 0
                ):
                    self._backend.save_snapshot(pickle.dumps(self._replay_result))

                return study_id

            assert False, "Should not reach."
    def delete_study(self, study_id: int) -> None:
        with self._thread_lock:
            self._write_log(JournalOperation.DELETE_STUDY, {"study_id": study_id})
            self._sync_with_backend()

    def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None:
        log: dict[str, Any] = {"study_id": study_id, "user_attr": {key: value}}
        with self._thread_lock:
            self._write_log(JournalOperation.SET_STUDY_USER_ATTR, log)
            self._sync_with_backend()

    def set_study_system_attr(self, study_id: int, key: str, value: JSONSerializable) -> None:
        log: dict[str, Any] = {"study_id": study_id, "system_attr": {key: value}}
        with self._thread_lock:
            self._write_log(JournalOperation.SET_STUDY_SYSTEM_ATTR, log)
            self._sync_with_backend()

    def get_study_id_from_name(self, study_name: str) -> int:
        with self._thread_lock:
            self._sync_with_backend()
            for study in self._replay_result.get_all_studies():
                if study.study_name == study_name:
                    return study._study_id
            raise KeyError(NOT_FOUND_MSG)

    def get_study_name_from_id(self, study_id: int) -> str:
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_study(study_id).study_name

    def get_study_directions(self, study_id: int) -> list[StudyDirection]:
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_study(study_id).directions

    def get_study_user_attrs(self, study_id: int) -> dict[str, Any]:
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_study(study_id).user_attrs

    def get_study_system_attrs(self, study_id: int) -> dict[str, Any]:
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_study(study_id).system_attrs

    def get_all_studies(self) -> list[FrozenStudy]:
        with self._thread_lock:
            self._sync_with_backend()
            return copy.deepcopy(self._replay_result.get_all_studies())

    # Basic trial manipulation
    def create_new_trial(self, study_id: int, template_trial: FrozenTrial | None = None) -> int:
        log: dict[str, Any] = {
            "study_id": study_id,
            "datetime_start": datetime.datetime.now().isoformat(timespec="microseconds"),
        }

        if template_trial:
            log["state"] = template_trial.state
            if template_trial.values is not None and len(template_trial.values) > 1:
                log["value"] = None
                log["values"] = template_trial.values
            else:
                log["value"] = template_trial.value
                log["values"] = None
            if template_trial.datetime_start:
                log["datetime_start"] = template_trial.datetime_start.isoformat(
                    timespec="microseconds"
                )
            else:
                log["datetime_start"] = None
            if template_trial.datetime_complete:
                log["datetime_complete"] = template_trial.datetime_complete.isoformat(
                    timespec="microseconds"
                )
            log["distributions"] = {
                k: distribution_to_json(dist) for k, dist in template_trial.distributions.items()
            }
            log["params"] = {
                k: template_trial.distributions[k].to_internal_repr(param)
                for k, param in template_trial.params.items()
            }
            log["user_attrs"] = template_trial.user_attrs
            log["system_attrs"] = template_trial.system_attrs
            log["intermediate_values"] = template_trial.intermediate_values

        with self._thread_lock:
            self._write_log(JournalOperation.CREATE_TRIAL, log)
            self._sync_with_backend()
            trial_id = self._replay_result._last_created_trial_id_by_this_process

            # Dump snapshot here.
            if (
                isinstance(self._backend, BaseJournalSnapshot)
                and trial_id != 0
                and trial_id % SNAPSHOT_INTERVAL == 0
            ):
                self._backend.save_snapshot(pickle.dumps(self._replay_result))

            return trial_id

    def set_trial_param(
        self,
        trial_id: int,
        param_name: str,
        param_value_internal: float,
        distribution: BaseDistribution,
    ) -> None:
        log: dict[str, Any] = {
            "trial_id": trial_id,
            "param_name": param_name,
            "param_value_internal": param_value_internal,
            "distribution": distribution_to_json(distribution),
        }

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_PARAM, log)
            self._sync_with_backend()

    def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
        with self._thread_lock:
            self._sync_with_backend()
            if len(self._replay_result._study_id_to_trial_ids[study_id]) <= trial_number:
                raise KeyError(
                    "No trial with trial number {} exists in study with study_id {}.".format(
                        trial_number, study_id
                    )
                )
            return self._replay_result._study_id_to_trial_ids[study_id][trial_number]

    def set_trial_state_values(
        self, trial_id: int, state: TrialState, values: Sequence[float] | None = None
    ) -> bool:
        log: dict[str, Any] = {
            "trial_id": trial_id,
            "state": state,
            "values": values,
        }
        if state == TrialState.RUNNING:
            log["datetime_start"] = datetime.datetime.now().isoformat(timespec="microseconds")
        elif state.is_finished():
            log["datetime_complete"] = datetime.datetime.now().isoformat(timespec="microseconds")

        with self._thread_lock:
            if state == TrialState.RUNNING:
                # NOTE(nabenabe): This sync is not strictly necessary because the last
                # set_trial_state_values call by the same thread always syncs before the true
                # pop, but it is kept here to avoid confusion. Anyway, this section is not
                # triggered often because it is only exercised by enqueue_trial.
                self._sync_with_backend()

                # NOTE(nabenabe): This section is triggered only when we are using
                # `enqueue_trial` and `GrpcProxyStorage` in distributed optimization setups and
                # solves the issue https://github.com/optuna/optuna/issues/6084.
                # When using gRPC, the current thread may already have popped the trial with
                # trial_id for another process, potentially leading to a false positive in the
                # return statement of trial_id == _replay_result.owned_trial_id. To eliminate
                # false positives, we verify whether another process is already evaluating the
                # trial with trial_id. If so, this query does not update the trial state.
                existing_trial = self._replay_result._trials.get(trial_id)
                assert (
                    existing_trial is not None
                ), "Please report your bug on GitHub if this line fails your script."
                if existing_trial.state.is_finished():
                    raise UpdateFinishedTrialError(
                        UNUPDATABLE_MSG.format(trial_number=existing_trial.number)
                    )
                if existing_trial.state != TrialState.WAITING:
                    # This line is equivalent to `existing_trial.state == TrialState.RUNNING`.
                    return False

            self._write_log(JournalOperation.SET_TRIAL_STATE_VALUES, log)
            self._sync_with_backend()
            return state != TrialState.RUNNING or trial_id == self._replay_result.owned_trial_id

    def set_trial_intermediate_value(
        self, trial_id: int, step: int, intermediate_value: float
    ) -> None:
        log: dict[str, Any] = {
            "trial_id": trial_id,
            "step": step,
            "intermediate_value": intermediate_value,
        }

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_INTERMEDIATE_VALUE, log)
            self._sync_with_backend()

    def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
        log: dict[str, Any] = {
            "trial_id": trial_id,
            "user_attr": {key: value},
        }

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_USER_ATTR, log)
            self._sync_with_backend()

    def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
        log: dict[str, Any] = {
            "trial_id": trial_id,
            "system_attr": {key: value},
        }

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_SYSTEM_ATTR, log)
            self._sync_with_backend()

    def get_trial(self, trial_id: int) -> FrozenTrial:
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_trial(trial_id)

    def get_all_trials(
        self,
        study_id: int,
        deepcopy: bool = True,
        states: Container[TrialState] | None = None,
    ) -> list[FrozenTrial]:
        with self._thread_lock:
            self._sync_with_backend()
            frozen_trials = self._replay_result.get_all_trials(study_id, states)
            if deepcopy:
                return copy.deepcopy(frozen_trials)
            return frozen_trials


class JournalStorageReplayResult:
    """In-memory state reconstructed by replaying journal logs.

    This class is an implementation detail of :class:`JournalStorage`: it applies each logged
    operation in order to rebuild studies and trials, and records how many logs have already
    been read so that only new logs are fetched on the next sync.
    """

    def __init__(self, worker_id_prefix: str) -> None:
        self.log_number_read = 0
        self._worker_id_prefix = worker_id_prefix
        self._studies: dict[int, FrozenStudy] = {}
        self._trials: dict[int, FrozenTrial] = {}
        self._study_id_to_trial_ids: dict[int, list[int]] = {}
        self._trial_id_to_study_id: dict[int, int] = {}
        self._next_study_id: int = 0
        self._worker_id_to_owned_trial_id: dict[str, int] = {}
        # Sentinel value; restore_replay_result resets this attribute to -1 as well.
        self._last_created_trial_id_by_this_process: int = -1

    def apply_logs(self, logs: Iterable[dict[str, Any]]) -> None:
        for log in logs:
            self.log_number_read += 1
            op = log["op_code"]
            if op == JournalOperation.CREATE_STUDY:
                self._apply_create_study(log)
            elif op == JournalOperation.DELETE_STUDY:
                self._apply_delete_study(log)
            elif op == JournalOperation.SET_STUDY_USER_ATTR:
                self._apply_set_study_user_attr(log)
            elif op == JournalOperation.SET_STUDY_SYSTEM_ATTR:
                self._apply_set_study_system_attr(log)
            elif op == JournalOperation.CREATE_TRIAL:
                self._apply_create_trial(log)
            elif op == JournalOperation.SET_TRIAL_PARAM:
                self._apply_set_trial_param(log)
            elif op == JournalOperation.SET_TRIAL_STATE_VALUES:
                self._apply_set_trial_state_values(log)
            elif op == JournalOperation.SET_TRIAL_INTERMEDIATE_VALUE:
                self._apply_set_trial_intermediate_value(log)
            elif op == JournalOperation.SET_TRIAL_USER_ATTR:
                self._apply_set_trial_user_attr(log)
            elif op == JournalOperation.SET_TRIAL_SYSTEM_ATTR:
                self._apply_set_trial_system_attr(log)
            else:
                assert False, "Should not reach."
    def get_study(self, study_id: int) -> FrozenStudy:
        if study_id not in self._studies:
            raise KeyError(NOT_FOUND_MSG)
        return self._studies[study_id]

    def get_all_studies(self) -> list[FrozenStudy]:
        return list(self._studies.values())

    def get_trial(self, trial_id: int) -> FrozenTrial:
        if trial_id not in self._trials:
            raise KeyError(NOT_FOUND_MSG)
        return self._trials[trial_id]

    def get_all_trials(
        self, study_id: int, states: Container[TrialState] | None
    ) -> list[FrozenTrial]:
        if study_id not in self._studies:
            raise KeyError(NOT_FOUND_MSG)
        frozen_trials: list[FrozenTrial] = []
        for trial_id in self._study_id_to_trial_ids[study_id]:
            trial = self._trials[trial_id]
            if states is None or trial.state in states:
                frozen_trials.append(trial)
        return frozen_trials

    @property
    def worker_id(self) -> str:
        return self._worker_id_prefix + str(threading.get_ident())

    @property
    def owned_trial_id(self) -> int | None:
        return self._worker_id_to_owned_trial_id.get(self.worker_id)

    def _is_issued_by_this_worker(self, log: dict[str, Any]) -> bool:
        return log["worker_id"] == self.worker_id

    def _study_exists(self, study_id: int, log: dict[str, Any]) -> bool:
        if study_id in self._studies:
            return True
        if self._is_issued_by_this_worker(log):
            raise KeyError(NOT_FOUND_MSG)
        return False

    def _apply_create_study(self, log: dict[str, Any]) -> None:
        study_name = log["study_name"]
        directions = [StudyDirection(d) for d in log["directions"]]

        if study_name in [s.study_name for s in self._studies.values()]:
            if self._is_issued_by_this_worker(log):
                raise DuplicatedStudyError(
                    "Another study with name '{}' already exists. "
                    "Please specify a different name, or reuse the existing one "
                    "by setting `load_if_exists` (for Python API) or "
                    "`--skip-if-exists` flag (for CLI).".format(study_name)
                )
            return

        study_id = self._next_study_id
        self._next_study_id += 1
        self._studies[study_id] = FrozenStudy(
            study_name=study_name,
            direction=None,
            user_attrs={},
            system_attrs={},
            study_id=study_id,
            directions=directions,
        )
        self._study_id_to_trial_ids[study_id] = []

    def _apply_delete_study(self, log: dict[str, Any]) -> None:
        study_id = log["study_id"]
        if self._study_exists(study_id, log):
            fs = self._studies.pop(study_id)
            assert fs._study_id == study_id

    def _apply_set_study_user_attr(self, log: dict[str, Any]) -> None:
        study_id = log["study_id"]
        if self._study_exists(study_id, log):
            assert len(log["user_attr"]) == 1
            self._studies[study_id].user_attrs.update(log["user_attr"])

    def _apply_set_study_system_attr(self, log: dict[str, Any]) -> None:
        study_id = log["study_id"]
        if self._study_exists(study_id, log):
            assert len(log["system_attr"]) == 1
            self._studies[study_id].system_attrs.update(log["system_attr"])

    def _apply_create_trial(self, log: dict[str, Any]) -> None:
        study_id = log["study_id"]
        if not self._study_exists(study_id, log):
            return

        trial_id = len(self._trials)
        distributions = {}
        if "distributions" in log:
            distributions = {k: json_to_distribution(v) for k, v in log["distributions"].items()}
        params = {}
        if "params" in log:
            params = {k: distributions[k].to_external_repr(p) for k, p in log["params"].items()}
        if log["datetime_start"] is not None:
            datetime_start = datetime.datetime.fromisoformat(log["datetime_start"])
        else:
            datetime_start = None
        if "datetime_complete" in log:
            datetime_complete = datetime.datetime.fromisoformat(log["datetime_complete"])
        else:
            datetime_complete = None

        self._trials[trial_id] = FrozenTrial(
            trial_id=trial_id,
            number=len(self._study_id_to_trial_ids[study_id]),
            state=TrialState(log.get("state", TrialState.RUNNING.value)),
            params=params,
            distributions=distributions,
            user_attrs=log.get("user_attrs", {}),
            system_attrs=log.get("system_attrs", {}),
            value=log.get("value", None),
            intermediate_values={
                int(k): v for k, v in log.get("intermediate_values", {}).items()
            },
            datetime_start=datetime_start,
            datetime_complete=datetime_complete,
            values=log.get("values", None),
        )
        self._study_id_to_trial_ids[study_id].append(trial_id)
        self._trial_id_to_study_id[trial_id] = study_id

        if self._is_issued_by_this_worker(log):
            self._last_created_trial_id_by_this_process = trial_id
            if self._trials[trial_id].state == TrialState.RUNNING:
                self._worker_id_to_owned_trial_id[self.worker_id] = trial_id

    def _apply_set_trial_param(self, log: dict[str, Any]) -> None:
        trial_id = log["trial_id"]
        if not self._trial_exists_and_updatable(trial_id, log):
            return

        param_name = log["param_name"]
        param_value_internal = log["param_value_internal"]
        distribution = json_to_distribution(log["distribution"])

        study_id = self._trial_id_to_study_id[trial_id]
        for prev_trial_id in self._study_id_to_trial_ids[study_id]:
            prev_trial = self._trials[prev_trial_id]
            if param_name in prev_trial.params.keys():
                try:
                    check_distribution_compatibility(
                        prev_trial.distributions[param_name], distribution
                    )
                except Exception:
                    if self._is_issued_by_this_worker(log):
                        raise
                    return
                break

        trial = copy.copy(self._trials[trial_id])
        trial.params = {
            **copy.copy(trial.params),
            param_name: distribution.to_external_repr(param_value_internal),
        }
        trial.distributions = {**copy.copy(trial.distributions), param_name: distribution}
        self._trials[trial_id] = trial

    def _apply_set_trial_state_values(self, log: dict[str, Any]) -> None:
        trial_id = log["trial_id"]
        if not self._trial_exists_and_updatable(trial_id, log):
            return

        state = TrialState(log["state"])
        if state == self._trials[trial_id].state and state == TrialState.RUNNING:
            # Reject the operation as the popped trial is already run by another process.
            return

        trial = copy.copy(self._trials[trial_id])
        if state == TrialState.RUNNING:
            trial.datetime_start = datetime.datetime.fromisoformat(log["datetime_start"])
            if self._is_issued_by_this_worker(log):
                self._worker_id_to_owned_trial_id[self.worker_id] = trial_id
        if state.is_finished():
            trial.datetime_complete = datetime.datetime.fromisoformat(log["datetime_complete"])
        trial.state = state
        if log["values"] is not None:
            trial.values = log["values"]
        self._trials[trial_id] = trial

    def _apply_set_trial_intermediate_value(self, log: dict[str, Any]) -> None:
        trial_id = log["trial_id"]
        if self._trial_exists_and_updatable(trial_id, log):
            trial = copy.copy(self._trials[trial_id])
            trial.intermediate_values = {
                **copy.copy(trial.intermediate_values),
                log["step"]: log["intermediate_value"],
            }
            self._trials[trial_id] = trial

    def _apply_set_trial_user_attr(self, log: dict[str, Any]) -> None:
        trial_id = log["trial_id"]
        if self._trial_exists_and_updatable(trial_id, log):
            assert len(log["user_attr"]) == 1
            trial = copy.copy(self._trials[trial_id])
            trial.user_attrs = {**copy.copy(trial.user_attrs), **log["user_attr"]}
            self._trials[trial_id] = trial

    def _apply_set_trial_system_attr(self, log: dict[str, Any]) -> None:
        trial_id = log["trial_id"]
        if self._trial_exists_and_updatable(trial_id, log):
            assert len(log["system_attr"]) == 1
            trial = copy.copy(self._trials[trial_id])
            trial.system_attrs = {
                **copy.copy(trial.system_attrs),
                **log["system_attr"],
            }
            self._trials[trial_id] = trial

    def _trial_exists_and_updatable(self, trial_id: int, log: dict[str, Any]) -> bool:
        if trial_id not in self._trials:
            if self._is_issued_by_this_worker(log):
                raise KeyError(NOT_FOUND_MSG)
            return False
        elif self._trials[trial_id].state.is_finished():
            if self._is_issued_by_this_worker(log):
                raise UpdateFinishedTrialError(
                    UNUPDATABLE_MSG.format(trial_number=self._trials[trial_id].number)
                )
            return False
        else:
            return True
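

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the storage API).
# It shows the replay behavior described in the JournalStorage docstring: a fresh
# JournalStorage instance pointed at the same journal file rebuilds the study by
# replaying the logged operations. The study name and the log file path below are
# placeholders chosen for this example, and the sketch assumes an installed Optuna
# that provides `optuna.storages.journal.JournalFileBackend`.
if __name__ == "__main__":

    def _objective(trial: optuna.Trial) -> float:
        x = trial.suggest_float("x", -10.0, 10.0)
        return x**2

    _log_file = "./optuna_journal_storage.log"  # placeholder path

    # First run: create the study and append operation logs to the journal file.
    storage = optuna.storages.JournalStorage(
        optuna.storages.journal.JournalFileBackend(_log_file)
    )
    study = optuna.create_study(study_name="journal-demo", storage=storage)
    study.optimize(_objective, n_trials=5)

    # Simulated restart: a new storage instance replays the same log and recovers
    # the trials without any extra persistence step.
    storage_after_restart = optuna.storages.JournalStorage(
        optuna.storages.journal.JournalFileBackend(_log_file)
    )
    reloaded = optuna.load_study(study_name="journal-demo", storage=storage_after_restart)
    print(len(reloaded.trials), reloaded.best_value)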