Source code for optuna.storages._journal.storage

import copy
import datetime
import enum
import pickle
import threading
from typing import Any
from typing import Container
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
import uuid

import optuna
from optuna._experimental import experimental_class
from optuna._typing import JSONSerializable
from optuna.distributions import BaseDistribution
from optuna.distributions import check_distribution_compatibility
from optuna.distributions import distribution_to_json
from optuna.distributions import json_to_distribution
from optuna.exceptions import DuplicatedStudyError
from optuna.storages import BaseStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.storages._journal.base import BaseJournalLogSnapshot
from optuna.storages._journal.base import BaseJournalLogStorage
from optuna.study._frozen import FrozenStudy
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


# Module-level logger obtained through Optuna's logging helper.
_logger = optuna.logging.get_logger(__name__)

# Message carried by the ``KeyError`` raised when a requested study or trial is not present.
NOT_FOUND_MSG = "Record does not exist."
# A heuristic interval number to dump snapshots: a snapshot is written when a newly created
# study id or trial id is a non-zero multiple of this value.
SNAPSHOT_INTERVAL = 100


@enum.unique
class JournalOperation(enum.IntEnum):
    """Operation codes stored in the ``op_code`` field of every journal log record.

    The integer values are written to the journal and compared against when the journal is
    replayed, so existing values must never be renumbered. ``enum.unique`` makes an accidental
    duplicate value fail at import time instead of silently turning one member into an alias
    of another.
    """

    # Study-level operations.
    CREATE_STUDY = 0
    DELETE_STUDY = 1
    SET_STUDY_USER_ATTR = 2
    SET_STUDY_SYSTEM_ATTR = 3
    # Trial-level operations.
    CREATE_TRIAL = 4
    SET_TRIAL_PARAM = 5
    SET_TRIAL_STATE_VALUES = 6
    SET_TRIAL_INTERMEDIATE_VALUE = 7
    SET_TRIAL_USER_ATTR = 8
    SET_TRIAL_SYSTEM_ATTR = 9


@experimental_class("3.1.0")
class JournalStorage(BaseStorage):
    """Storage class for Journal storage backend.

    Note that library users can instantiate this class, but the attributes
    provided by this class are not supposed to be directly accessed by them.

    Journal storage writes a record of every operation to the database as it is executed and
    at the same time, keeps a latest snapshot of the database in-memory. If the database crashes
    for any reason, the storage can re-establish the contents in memory by replaying the
    operations stored from the beginning.

    Journal storage has several benefits over the conventional value logging storages.

    1. The number of IOs can be reduced because of larger granularity of logs.
    2. Journal storage has simpler backend API than value logging storage.
    3. Journal storage keeps a snapshot in-memory so no need to add more cache.

    Example:

        .. code::

            import optuna


            def objective(trial):
                ...


            storage = optuna.storages.JournalStorage(
                optuna.storages.JournalFileStorage("./journal.log"),
            )

            study = optuna.create_study(storage=storage)
            study.optimize(objective)

        In a Windows environment, an error message "A required privilege is not held by the
        client" may appear. In this case, you can solve the problem with creating storage
        by specifying :class:`~optuna.storages.JournalFileOpenLock` as follows.

        .. code::

            file_path = "./journal.log"
            lock_obj = optuna.storages.JournalFileOpenLock(file_path)

            storage = optuna.storages.JournalStorage(
                optuna.storages.JournalFileStorage(file_path, lock_obj=lock_obj),
            )
    """

    def __init__(self, log_storage: BaseJournalLogStorage) -> None:
        # A fresh UUID prefix per storage instance; together with the thread id it forms the
        # worker id attached to every log written through this instance.
        self._worker_id_prefix = str(uuid.uuid4()) + "-"
        self._backend = log_storage
        self._thread_lock = threading.Lock()
        self._replay_result = JournalStorageReplayResult(self._worker_id_prefix)

        with self._thread_lock:
            # Start from a snapshot when the backend offers one, then replay whatever logs
            # have not been applied yet.
            if isinstance(self._backend, BaseJournalLogSnapshot):
                snapshot = self._backend.load_snapshot()
                if snapshot is not None:
                    self.restore_replay_result(snapshot)
            self._sync_with_backend()

    def __getstate__(self) -> Dict[Any, Any]:
        """Return the instance state for pickling, without the per-instance members."""
        state = self.__dict__.copy()
        del state["_worker_id_prefix"]
        del state["_replay_result"]
        del state["_thread_lock"]
        return state

    def __setstate__(self, state: Dict[Any, Any]) -> None:
        """Restore pickled state and recreate the members dropped by ``__getstate__``.

        The replay result starts empty, so the next synchronisation replays the journal.
        """
        self.__dict__.update(state)
        self._worker_id_prefix = str(uuid.uuid4()) + "-"
        self._replay_result = JournalStorageReplayResult(self._worker_id_prefix)
        self._thread_lock = threading.Lock()

    def restore_replay_result(self, snapshot: bytes) -> None:
        """Replace the in-memory replay result with the one pickled in ``snapshot``.

        Restoring is best effort: if the snapshot cannot be unpickled or does not contain a
        :class:`JournalStorageReplayResult`, a warning is logged and the current replay result
        is kept, so the journal is simply replayed from where the current result stands.

        Args:
            snapshot:
                Pickled :class:`JournalStorageReplayResult`.
        """
        try:
            # NOTE(security): ``pickle.loads`` can execute arbitrary code. The snapshot must
            # come from a backend whose contents are trusted.
            r: Optional[JournalStorageReplayResult] = pickle.loads(snapshot)
        except (
            pickle.UnpicklingError,
            # ``pickle.loads`` is documented to raise these as well, e.g. ``EOFError`` on
            # empty input or ``AttributeError``/``ImportError`` for a stale class layout.
            AttributeError,
            EOFError,
            ImportError,
            IndexError,
            KeyError,
        ):
            _logger.warning("Failed to restore `JournalStorageReplayResult`.")
            return

        if r is None:
            return

        if not isinstance(r, JournalStorageReplayResult):
            _logger.warning("The restored object is not `JournalStorageReplayResult`.")
            return

        # Worker-specific bookkeeping stored in the snapshot belongs to whoever wrote it;
        # reset it for this instance.
        r._worker_id_prefix = self._worker_id_prefix
        r._worker_id_to_owned_trial_id = {}
        r._last_created_trial_id_by_this_process = -1

        self._replay_result = r

    def _write_log(self, op_code: int, extra_fields: Dict[str, Any]) -> None:
        """Append one log record tagged with ``op_code`` and this worker's id to the backend."""
        worker_id = self._replay_result.worker_id
        self._backend.append_logs([{"op_code": op_code, "worker_id": worker_id, **extra_fields}])

    def _sync_with_backend(self) -> None:
        """Read logs from the backend and apply them to the in-memory replay result.

        The number of logs already applied is passed to ``read_logs``.
        """
        logs = self._backend.read_logs(self._replay_result.log_number_read)
        self._replay_result.apply_logs(logs)

    def _save_snapshot_if_due(self, record_id: int) -> None:
        """Dump a snapshot when ``record_id`` is a non-zero multiple of ``SNAPSHOT_INTERVAL``.

        Nothing happens when the backend does not implement ``BaseJournalLogSnapshot``.
        Callers hold ``self._thread_lock`` while calling this.
        """
        if (
            isinstance(self._backend, BaseJournalLogSnapshot)
            and record_id != 0
            and record_id % SNAPSHOT_INTERVAL == 0
        ):
            self._backend.save_snapshot(pickle.dumps(self._replay_result))

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: Optional[str] = None
    ) -> int:
        """Create a study and return its id.

        Args:
            directions:
                Optimization directions of the study.
            study_name:
                Name of the study. When omitted (or empty), a name is generated from
                ``DEFAULT_STUDY_NAME_PREFIX`` and a random UUID.

        Returns:
            ID of the created study.

        Raises:
            :exc:`optuna.exceptions.DuplicatedStudyError`:
                Raised while replaying this worker's log if a study with the same name
                already exists.
        """
        study_name = study_name or DEFAULT_STUDY_NAME_PREFIX + str(uuid.uuid4())

        with self._thread_lock:
            self._write_log(
                JournalOperation.CREATE_STUDY, {"study_name": study_name, "directions": directions}
            )
            self._sync_with_backend()

            for frozen_study in self._replay_result.get_all_studies():
                if frozen_study.study_name != study_name:
                    continue

                _logger.info("A new study created in Journal with name: {}".format(study_name))
                study_id = frozen_study._study_id

                # Dump snapshot here.
                self._save_snapshot_if_due(study_id)
                return study_id

            # Raised explicitly (rather than ``assert False``) so that the guard also holds
            # when Python runs with ``-O``.
            raise AssertionError("Should not reach.")

    def delete_study(self, study_id: int) -> None:
        """Log the deletion of the study ``study_id`` and replay it.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the study does not exist.
        """
        with self._thread_lock:
            self._write_log(JournalOperation.DELETE_STUDY, {"study_id": study_id})
            self._sync_with_backend()

    def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None:
        """Set the user attribute ``key`` of the study ``study_id`` to ``value``.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the study does not exist.
        """
        log: Dict[str, Any] = {"study_id": study_id, "user_attr": {key: value}}
        with self._thread_lock:
            self._write_log(JournalOperation.SET_STUDY_USER_ATTR, log)
            self._sync_with_backend()

    def set_study_system_attr(self, study_id: int, key: str, value: JSONSerializable) -> None:
        """Set the system attribute ``key`` of the study ``study_id`` to ``value``.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the study does not exist.
        """
        log: Dict[str, Any] = {"study_id": study_id, "system_attr": {key: value}}
        with self._thread_lock:
            self._write_log(JournalOperation.SET_STUDY_SYSTEM_ATTR, log)
            self._sync_with_backend()

    def get_study_id_from_name(self, study_name: str) -> int:
        """Return the id of the study named ``study_name``.

        Raises:
            :exc:`KeyError`:
                If no study with that name exists.
        """
        with self._thread_lock:
            self._sync_with_backend()
            for study in self._replay_result.get_all_studies():
                if study.study_name == study_name:
                    return study._study_id
            raise KeyError(NOT_FOUND_MSG)

    def get_study_name_from_id(self, study_id: int) -> str:
        """Return the name of the study ``study_id``; raise :exc:`KeyError` if it is missing."""
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_study(study_id).study_name

    def get_study_directions(self, study_id: int) -> List[StudyDirection]:
        """Return the directions of the study ``study_id``; raise :exc:`KeyError` if missing."""
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_study(study_id).directions

    def get_study_user_attrs(self, study_id: int) -> Dict[str, Any]:
        """Return the user attributes of ``study_id``; raise :exc:`KeyError` if missing."""
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_study(study_id).user_attrs

    def get_study_system_attrs(self, study_id: int) -> Dict[str, Any]:
        """Return the system attributes of ``study_id``; raise :exc:`KeyError` if missing."""
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_study(study_id).system_attrs

    def get_all_studies(self) -> List[FrozenStudy]:
        """Return deep copies of all studies currently known to the storage."""
        with self._thread_lock:
            self._sync_with_backend()
            return copy.deepcopy(self._replay_result.get_all_studies())

    # Basic trial manipulation

    def create_new_trial(self, study_id: int, template_trial: Optional[FrozenTrial] = None) -> int:
        """Create a trial in the study ``study_id`` and return its id.

        Args:
            study_id:
                ID of the study the trial is added to.
            template_trial:
                Optional trial whose state, values, timestamps, distributions, parameters,
                attributes and intermediate values are copied into the new trial. Without a
                template only the study id and the current time as start time are logged.

        Returns:
            ID of the created trial.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the study does not exist.
        """
        log: Dict[str, Any] = {
            "study_id": study_id,
            "datetime_start": datetime.datetime.now().isoformat(timespec="microseconds"),
        }

        if template_trial:
            log["state"] = template_trial.state

            # Log either ``value`` or ``values``, never both: ``values`` is used only for
            # more than one objective value.
            if template_trial.values is not None and len(template_trial.values) > 1:
                log["value"] = None
                log["values"] = template_trial.values
            else:
                log["value"] = template_trial.value
                log["values"] = None

            # The template's start time replaces the default "now"; a template without a
            # start time yields an explicit ``None``.
            if template_trial.datetime_start:
                log["datetime_start"] = template_trial.datetime_start.isoformat(
                    timespec="microseconds"
                )
            else:
                log["datetime_start"] = None

            if template_trial.datetime_complete:
                log["datetime_complete"] = template_trial.datetime_complete.isoformat(
                    timespec="microseconds"
                )

            log["distributions"] = {
                k: distribution_to_json(dist) for k, dist in template_trial.distributions.items()
            }
            log["params"] = {
                k: template_trial.distributions[k].to_internal_repr(param)
                for k, param in template_trial.params.items()
            }
            log["user_attrs"] = template_trial.user_attrs
            log["system_attrs"] = template_trial.system_attrs
            log["intermediate_values"] = template_trial.intermediate_values

        with self._thread_lock:
            self._write_log(JournalOperation.CREATE_TRIAL, log)
            self._sync_with_backend()
            trial_id = self._replay_result._last_created_trial_id_by_this_process

            # Dump snapshot here.
            self._save_snapshot_if_due(trial_id)
            return trial_id

    def set_trial_param(
        self,
        trial_id: int,
        param_name: str,
        param_value_internal: float,
        distribution: BaseDistribution,
    ) -> None:
        """Set the parameter ``param_name`` of the trial ``trial_id``.

        Args:
            trial_id:
                ID of the trial.
            param_name:
                Name of the parameter.
            param_value_internal:
                Internal representation of the parameter value.
            distribution:
                Distribution of the parameter; it is logged in its JSON form.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the trial does not exist.
            :exc:`RuntimeError`:
                Raised while replaying this worker's log if the trial is already finished.
        """
        log: Dict[str, Any] = {
            "trial_id": trial_id,
            "param_name": param_name,
            "param_value_internal": param_value_internal,
            "distribution": distribution_to_json(distribution),
        }

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_PARAM, log)
            self._sync_with_backend()

    def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
        """Return the id of the trial numbered ``trial_number`` within the study ``study_id``.

        Raises:
            :exc:`KeyError`:
                If the study id is unknown, or ``trial_number`` is negative or not smaller
                than the number of trials recorded for the study.
        """
        with self._thread_lock:
            self._sync_with_backend()
            trial_ids = self._replay_result._study_id_to_trial_ids[study_id]
            # Negative numbers are rejected explicitly; plain list indexing would wrap around
            # and return an unrelated trial.
            if not 0 <= trial_number < len(trial_ids):
                raise KeyError(
                    "No trial with trial number {} exists in study with study_id {}.".format(
                        trial_number, study_id
                    )
                )
            return trial_ids[trial_number]

    def set_trial_state_values(
        self, trial_id: int, state: TrialState, values: Optional[Sequence[float]] = None
    ) -> bool:
        """Update the state and, optionally, the values of the trial ``trial_id``.

        A start timestamp is logged when ``state`` is ``RUNNING`` and a completion timestamp
        when ``state`` is a finished state.

        Returns:
            :obj:`False` if ``state`` is ``RUNNING`` and, after replaying the journal, this
            worker does not own the trial; :obj:`True` otherwise.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the trial does not exist.
            :exc:`RuntimeError`:
                Raised while replaying this worker's log if the trial is already finished.
        """
        log: Dict[str, Any] = {
            "trial_id": trial_id,
            "state": state,
            "values": values,
        }

        if state == TrialState.RUNNING:
            log["datetime_start"] = datetime.datetime.now().isoformat(timespec="microseconds")
        elif state.is_finished():
            log["datetime_complete"] = datetime.datetime.now().isoformat(timespec="microseconds")

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_STATE_VALUES, log)
            self._sync_with_backend()
            lost_ownership = (
                state == TrialState.RUNNING and trial_id != self._replay_result.owned_trial_id
            )
            return not lost_ownership

    def set_trial_intermediate_value(
        self, trial_id: int, step: int, intermediate_value: float
    ) -> None:
        """Record ``intermediate_value`` at ``step`` for the trial ``trial_id``.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the trial does not exist.
            :exc:`RuntimeError`:
                Raised while replaying this worker's log if the trial is already finished.
        """
        log: Dict[str, Any] = {
            "trial_id": trial_id,
            "step": step,
            "intermediate_value": intermediate_value,
        }

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_INTERMEDIATE_VALUE, log)
            self._sync_with_backend()

    def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
        """Set the user attribute ``key`` of the trial ``trial_id`` to ``value``.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the trial does not exist.
            :exc:`RuntimeError`:
                Raised while replaying this worker's log if the trial is already finished.
        """
        log: Dict[str, Any] = {
            "trial_id": trial_id,
            "user_attr": {key: value},
        }

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_USER_ATTR, log)
            self._sync_with_backend()

    def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
        """Set the system attribute ``key`` of the trial ``trial_id`` to ``value``.

        Raises:
            :exc:`KeyError`:
                Raised while replaying this worker's log if the trial does not exist.
            :exc:`RuntimeError`:
                Raised while replaying this worker's log if the trial is already finished.
        """
        log: Dict[str, Any] = {
            "trial_id": trial_id,
            "system_attr": {key: value},
        }

        with self._thread_lock:
            self._write_log(JournalOperation.SET_TRIAL_SYSTEM_ATTR, log)
            self._sync_with_backend()

    def get_trial(self, trial_id: int) -> FrozenTrial:
        """Return the trial ``trial_id``; raise :exc:`KeyError` if it does not exist."""
        with self._thread_lock:
            self._sync_with_backend()
            return self._replay_result.get_trial(trial_id)

    def get_all_trials(
        self,
        study_id: int,
        deepcopy: bool = True,
        states: Optional[Container[TrialState]] = None,
    ) -> List[FrozenTrial]:
        """Return the trials of the study ``study_id``.

        Args:
            study_id:
                ID of the study.
            deepcopy:
                Whether to return deep copies of the stored trials instead of the stored
                objects themselves.
            states:
                Trial states to keep; :obj:`None` keeps every trial.

        Raises:
            :exc:`KeyError`:
                If the study does not exist.
        """
        with self._thread_lock:
            self._sync_with_backend()
            frozen_trials = self._replay_result.get_all_trials(study_id, states)
            if deepcopy:
                return copy.deepcopy(frozen_trials)
            return frozen_trials
class JournalStorageReplayResult:
    """In-memory state of studies and trials obtained by replaying journal logs.

    Args:
        worker_id_prefix:
            Prefix of the worker id. The id itself is the prefix followed by the identifier of
            the calling thread and is compared with the ``worker_id`` field of each log to
            tell whether the log was issued by the current worker.
    """

    def __init__(self, worker_id_prefix: str) -> None:
        # Number of logs applied so far.
        self.log_number_read = 0
        self._worker_id_prefix = worker_id_prefix
        self._studies: Dict[int, FrozenStudy] = {}
        self._trials: Dict[int, FrozenTrial] = {}

        self._study_id_to_trial_ids: Dict[int, List[int]] = {}
        self._trial_id_to_study_id: Dict[int, int] = {}
        self._next_study_id: int = 0
        self._worker_id_to_owned_trial_id: Dict[str, int] = {}
        # -1 until this worker creates a trial; initialised here so that the attribute always
        # exists, matching the value used when a snapshot is restored.
        self._last_created_trial_id_by_this_process: int = -1

    def apply_logs(self, logs: List[Dict[str, Any]]) -> None:
        """Apply ``logs`` in order, dispatching each one on its ``op_code``.

        ``log_number_read`` is incremented before a log is applied, so a log whose application
        raises is still counted as read.

        Raises:
            :exc:`AssertionError`:
                If a log carries an unknown ``op_code``.
        """
        for log in logs:
            self.log_number_read += 1
            op = log["op_code"]
            if op == JournalOperation.CREATE_STUDY:
                self._apply_create_study(log)
            elif op == JournalOperation.DELETE_STUDY:
                self._apply_delete_study(log)
            elif op == JournalOperation.SET_STUDY_USER_ATTR:
                self._apply_set_study_user_attr(log)
            elif op == JournalOperation.SET_STUDY_SYSTEM_ATTR:
                self._apply_set_study_system_attr(log)
            elif op == JournalOperation.CREATE_TRIAL:
                self._apply_create_trial(log)
            elif op == JournalOperation.SET_TRIAL_PARAM:
                self._apply_set_trial_param(log)
            elif op == JournalOperation.SET_TRIAL_STATE_VALUES:
                self._apply_set_trial_state_values(log)
            elif op == JournalOperation.SET_TRIAL_INTERMEDIATE_VALUE:
                self._apply_set_trial_intermediate_value(log)
            elif op == JournalOperation.SET_TRIAL_USER_ATTR:
                self._apply_set_trial_user_attr(log)
            elif op == JournalOperation.SET_TRIAL_SYSTEM_ATTR:
                self._apply_set_trial_system_attr(log)
            else:
                # Raised explicitly (rather than ``assert False``) so that unknown operations
                # are not silently skipped when Python runs with ``-O``.
                raise AssertionError("Should not reach.")

    def get_study(self, study_id: int) -> FrozenStudy:
        """Return the study ``study_id``; raise :exc:`KeyError` if it does not exist."""
        if study_id not in self._studies:
            raise KeyError(NOT_FOUND_MSG)
        return self._studies[study_id]

    def get_all_studies(self) -> List[FrozenStudy]:
        """Return a new list holding the stored study objects (not copies of them)."""
        return list(self._studies.values())

    def get_trial(self, trial_id: int) -> FrozenTrial:
        """Return the trial ``trial_id``; raise :exc:`KeyError` if it does not exist."""
        if trial_id not in self._trials:
            raise KeyError(NOT_FOUND_MSG)
        return self._trials[trial_id]

    def get_all_trials(
        self, study_id: int, states: Optional[Container[TrialState]]
    ) -> List[FrozenTrial]:
        """Return the stored trials of ``study_id`` whose state is contained in ``states``.

        Args:
            study_id:
                ID of the study.
            states:
                States to keep; :obj:`None` keeps every trial.

        Raises:
            :exc:`KeyError`:
                If the study does not exist.
        """
        if study_id not in self._studies:
            raise KeyError(NOT_FOUND_MSG)
        frozen_trials: List[FrozenTrial] = []
        for trial_id in self._study_id_to_trial_ids[study_id]:
            trial = self._trials[trial_id]
            if states is None or trial.state in states:
                frozen_trials.append(trial)
        return frozen_trials

    @property
    def worker_id(self) -> str:
        """Worker id: the prefix followed by the identifier of the calling thread."""
        return self._worker_id_prefix + str(threading.get_ident())

    @property
    def owned_trial_id(self) -> Optional[int]:
        """ID of the trial recorded as owned by the current worker, or :obj:`None`."""
        return self._worker_id_to_owned_trial_id.get(self.worker_id)

    def _is_issued_by_this_worker(self, log: Dict[str, Any]) -> bool:
        """Return whether ``log`` carries the current worker's id."""
        return log["worker_id"] == self.worker_id

    def _study_exists(self, study_id: int, log: Dict[str, Any]) -> bool:
        """Return whether ``study_id`` exists.

        A missing study raises :exc:`KeyError` when ``log`` was issued by this worker; logs of
        other workers yield :obj:`False` instead so that they are skipped.
        """
        if study_id in self._studies:
            return True
        if self._is_issued_by_this_worker(log):
            raise KeyError(NOT_FOUND_MSG)
        return False

    def _apply_create_study(self, log: Dict[str, Any]) -> None:
        """Register a new study, or reject the log if the study name is already taken.

        A duplicate name raises ``DuplicatedStudyError`` only when the log was issued by this
        worker; duplicate logs of other workers are ignored.
        """
        study_name = log["study_name"]
        directions = [StudyDirection(d) for d in log["directions"]]

        if any(s.study_name == study_name for s in self._studies.values()):
            if self._is_issued_by_this_worker(log):
                raise DuplicatedStudyError(
                    "Another study with name '{}' already exists. "
                    "Please specify a different name, or reuse the existing one "
                    "by setting `load_if_exists` (for Python API) or "
                    "`--skip-if-exists` flag (for CLI).".format(study_name)
                )
            return

        study_id = self._next_study_id
        self._next_study_id += 1

        self._studies[study_id] = FrozenStudy(
            study_name=study_name,
            direction=None,
            user_attrs={},
            system_attrs={},
            study_id=study_id,
            directions=directions,
        )
        self._study_id_to_trial_ids[study_id] = []

    def _apply_delete_study(self, log: Dict[str, Any]) -> None:
        """Remove the study from ``_studies``; the trial bookkeeping is left untouched."""
        study_id = log["study_id"]

        if self._study_exists(study_id, log):
            fs = self._studies.pop(study_id)
            assert fs._study_id == study_id

    def _apply_set_study_user_attr(self, log: Dict[str, Any]) -> None:
        """Merge the single logged user attribute into the study."""
        study_id = log["study_id"]

        if self._study_exists(study_id, log):
            assert len(log["user_attr"]) == 1
            self._studies[study_id].user_attrs.update(log["user_attr"])

    def _apply_set_study_system_attr(self, log: Dict[str, Any]) -> None:
        """Merge the single logged system attribute into the study."""
        study_id = log["study_id"]

        if self._study_exists(study_id, log):
            assert len(log["system_attr"]) == 1
            self._studies[study_id].system_attrs.update(log["system_attr"])

    def _apply_create_trial(self, log: Dict[str, Any]) -> None:
        """Create a trial from ``log`` and record it under its study.

        The trial id is the number of trials stored so far and the trial number is the number
        of trials already recorded for the study. If the log was issued by this worker, the id
        is remembered as the last created one and, when the trial is ``RUNNING``, the trial is
        recorded as owned by this worker.
        """
        study_id = log["study_id"]

        if not self._study_exists(study_id, log):
            return

        trial_id = len(self._trials)

        distributions = {}
        if "distributions" in log:
            distributions = {k: json_to_distribution(v) for k, v in log["distributions"].items()}

        params = {}
        if "params" in log:
            params = {k: distributions[k].to_external_repr(p) for k, p in log["params"].items()}

        if log["datetime_start"] is not None:
            datetime_start = datetime.datetime.fromisoformat(log["datetime_start"])
        else:
            datetime_start = None

        if "datetime_complete" in log:
            datetime_complete = datetime.datetime.fromisoformat(log["datetime_complete"])
        else:
            datetime_complete = None

        self._trials[trial_id] = FrozenTrial(
            trial_id=trial_id,
            number=len(self._study_id_to_trial_ids[study_id]),
            state=TrialState(log.get("state", TrialState.RUNNING.value)),
            params=params,
            distributions=distributions,
            user_attrs=log.get("user_attrs", {}),
            system_attrs=log.get("system_attrs", {}),
            value=log.get("value", None),
            # Steps are coerced to ``int`` (presumably because a serialised log may carry
            # them as strings; verify against the backend).
            intermediate_values={int(k): v for k, v in log.get("intermediate_values", {}).items()},
            datetime_start=datetime_start,
            datetime_complete=datetime_complete,
            values=log.get("values", None),
        )

        self._study_id_to_trial_ids[study_id].append(trial_id)
        self._trial_id_to_study_id[trial_id] = study_id

        if self._is_issued_by_this_worker(log):
            self._last_created_trial_id_by_this_process = trial_id
            if self._trials[trial_id].state == TrialState.RUNNING:
                self._worker_id_to_owned_trial_id[self.worker_id] = trial_id

    def _apply_set_trial_param(self, log: Dict[str, Any]) -> None:
        """Set a parameter on a trial after checking distribution compatibility.

        The logged distribution is checked against the first trial of the same study that
        already has the parameter. An incompatibility is re-raised for logs of this worker
        and causes logs of other workers to be skipped.
        """
        trial_id = log["trial_id"]

        if not self._trial_exists_and_updatable(trial_id, log):
            return

        param_name = log["param_name"]
        param_value_internal = log["param_value_internal"]
        distribution = json_to_distribution(log["distribution"])

        study_id = self._trial_id_to_study_id[trial_id]

        for prev_trial_id in self._study_id_to_trial_ids[study_id]:
            prev_trial = self._trials[prev_trial_id]
            if param_name in prev_trial.params.keys():
                try:
                    check_distribution_compatibility(
                        prev_trial.distributions[param_name], distribution
                    )
                except Exception:
                    if self._is_issued_by_this_worker(log):
                        raise
                    return
                # Only the first trial holding the parameter is consulted.
                break

        # Replace the stored trial with a modified shallow copy instead of mutating it.
        trial = copy.copy(self._trials[trial_id])
        trial.params = {
            **copy.copy(trial.params),
            param_name: distribution.to_external_repr(param_value_internal),
        }
        trial.distributions = {**copy.copy(trial.distributions), param_name: distribution}
        self._trials[trial_id] = trial

    def _apply_set_trial_state_values(self, log: Dict[str, Any]) -> None:
        """Update the state, timestamps and values of a trial.

        A request to move an already ``RUNNING`` trial to ``RUNNING`` is ignored. When a trial
        becomes ``RUNNING`` through a log of this worker, the worker is recorded as its owner.
        """
        trial_id = log["trial_id"]

        if not self._trial_exists_and_updatable(trial_id, log):
            return

        state = TrialState(log["state"])
        if state == self._trials[trial_id].state and state == TrialState.RUNNING:
            # The trial is already running; ignore the repeated transition.
            return

        # Replace the stored trial with a modified shallow copy instead of mutating it.
        trial = copy.copy(self._trials[trial_id])
        if state == TrialState.RUNNING:
            trial.datetime_start = datetime.datetime.fromisoformat(log["datetime_start"])
            if self._is_issued_by_this_worker(log):
                self._worker_id_to_owned_trial_id[self.worker_id] = trial_id
        if state.is_finished():
            trial.datetime_complete = datetime.datetime.fromisoformat(log["datetime_complete"])
        trial.state = state
        if log["values"] is not None:
            trial.values = log["values"]

        self._trials[trial_id] = trial

    def _apply_set_trial_intermediate_value(self, log: Dict[str, Any]) -> None:
        """Store the logged intermediate value under its step on a copy of the trial."""
        trial_id = log["trial_id"]

        if self._trial_exists_and_updatable(trial_id, log):
            trial = copy.copy(self._trials[trial_id])
            trial.intermediate_values = {
                **copy.copy(trial.intermediate_values),
                log["step"]: log["intermediate_value"],
            }
            self._trials[trial_id] = trial

    def _apply_set_trial_user_attr(self, log: Dict[str, Any]) -> None:
        """Merge the single logged user attribute into a copy of the trial."""
        trial_id = log["trial_id"]

        if self._trial_exists_and_updatable(trial_id, log):
            assert len(log["user_attr"]) == 1
            trial = copy.copy(self._trials[trial_id])
            trial.user_attrs = {**copy.copy(trial.user_attrs), **log["user_attr"]}
            self._trials[trial_id] = trial

    def _apply_set_trial_system_attr(self, log: Dict[str, Any]) -> None:
        """Merge the single logged system attribute into a copy of the trial."""
        trial_id = log["trial_id"]

        if self._trial_exists_and_updatable(trial_id, log):
            assert len(log["system_attr"]) == 1
            trial = copy.copy(self._trials[trial_id])
            trial.system_attrs = {
                **copy.copy(trial.system_attrs),
                **log["system_attr"],
            }
            self._trials[trial_id] = trial

    def _trial_exists_and_updatable(self, trial_id: int, log: Dict[str, Any]) -> bool:
        """Return whether ``trial_id`` exists and has not finished yet.

        For logs issued by this worker a missing trial raises :exc:`KeyError` and a finished
        trial raises :exc:`RuntimeError`; for logs of other workers :obj:`False` is returned
        in both cases so that the log is skipped.
        """
        if trial_id not in self._trials:
            if self._is_issued_by_this_worker(log):
                raise KeyError(NOT_FOUND_MSG)
            return False
        elif self._trials[trial_id].state.is_finished():
            if self._is_issued_by_this_worker(log):
                raise RuntimeError(
                    "Trial#{} has already finished and can not be updated.".format(
                        self._trials[trial_id].number
                    )
                )
            return False
        else:
            return True