import copy
from datetime import datetime
import pickle
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
import optuna
from optuna import distributions
from optuna import exceptions
from optuna._experimental import experimental
from optuna._imports import try_import
from optuna._study_direction import StudyDirection
from optuna._study_summary import StudySummary
from optuna.storages import BaseStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
_logger = optuna.logging.get_logger(__name__)

# The import of ``redis`` is wrapped in ``try_import``; ``RedisStorage.__init__``
# calls ``_imports.check()`` before the module is used.
with try_import() as _imports:
    import redis
[文档]@experimental("1.4.0")
class RedisStorage(BaseStorage):
"""Storage class for Redis backend.
Note that library users can instantiate this class, but the attributes
provided by this class are not supposed to be directly accessed by them.
Example:
We create an :class:`~optuna.storages.RedisStorage` instance using
the given redis database URL.
.. code::
import optuna
def objective(trial):
...
storage = optuna.storages.RedisStorage(
url="redis://passwd@localhost:port/db",
)
study = optuna.create_study(storage=storage)
study.optimize(objective)
Args:
url: URL of the redis storage; password and db are optional (e.g. redis://localhost:6379).
.. note::
If you plan to use Redis as a storage mechanism for optuna,
make sure Redis is installed and running.
Please execute ``$ pip install -U redis`` to install redis python library.
"""
def __init__(self, url: str) -> None:
    """Create a storage bound to the Redis server at ``url``.

    Args:
        url: Connection URL handed to ``redis.Redis.from_url``.
    """
    # Verify the guarded module-level ``import redis`` before touching ``redis``.
    _imports.check()

    self._url = url
    client = redis.Redis.from_url(url)
    self._redis = client
def create_new_study(self, study_name: Optional[str] = None) -> int:
    """Register a new study and return its ID.

    Args:
        study_name: Name of the study. When ``None``, a name is generated from
            ``DEFAULT_STUDY_NAME_PREFIX`` and the new study ID.

    Returns:
        The ID assigned to the new study.

    Raises:
        optuna.exceptions.DuplicatedStudyError: If ``study_name`` is given and a
            study with that name already exists.
    """
    if study_name is not None and self._redis.exists(self._key_study_name(study_name)):
        raise exceptions.DuplicatedStudyError

    # Seed the counter with -1 so that the first ``INCR`` yields 0. ``SETNX`` writes
    # only when the key is absent, so a concurrent creator cannot reset a counter
    # that has already been incremented (which would hand out a duplicate ID).
    self._redis.setnx("study_counter", -1)
    study_id = self._redis.incr("study_counter", 1)
    # We need the trial_number counter to start with 0.
    self._redis.set("study_id:{:010d}:trial_number".format(study_id), -1)

    if study_name is None:
        study_name = "{}{:010d}".format(DEFAULT_STUDY_NAME_PREFIX, study_id)

    study_summary = StudySummary(
        study_name=study_name,
        direction=StudyDirection.NOT_SET,
        best_trial=None,
        user_attrs={},
        system_attrs={},
        n_trials=0,
        datetime_start=None,
        study_id=study_id,
    )

    # All per-study records are written in one MULTI/EXEC transaction.
    with self._redis.pipeline() as pipe:
        pipe.multi()
        pipe.set(self._key_study_name(study_name), pickle.dumps(study_id))
        pipe.set("study_id:{:010d}:study_name".format(study_id), pickle.dumps(study_name))
        pipe.set(
            "study_id:{:010d}:directions".format(study_id),
            pickle.dumps([StudyDirection.NOT_SET]),
        )
        pipe.rpush("study_list", pickle.dumps(study_id))
        pipe.set(self._key_study_summary(study_id), pickle.dumps(study_summary))
        pipe.execute()

    _logger.info("A new study created in Redis with name: {}".format(study_name))

    return study_id
def delete_study(self, study_id: int) -> None:
    """Delete a study together with the records of all of its trials.

    The summary, trial records, counters, name mappings, directions, cached best
    trial and parameter distributions are removed in one MULTI/EXEC transaction.

    Args:
        study_id: ID of the study to delete.

    Raises:
        KeyError: If no study with ``study_id`` exists.
    """
    self._check_study_id(study_id)

    # These reads hit the server immediately, so gather them before queueing deletes.
    trial_ids = self._get_study_trials(study_id)
    study_name = self.get_study_name_from_id(study_id)

    with self._redis.pipeline() as pipe:
        pipe.multi()

        # Summaries
        pipe.delete(self._key_study_summary(study_id))
        pipe.lrem("study_list", 0, pickle.dumps(study_id))

        # Trials
        for trial_id in trial_ids:
            pipe.delete(self._key_trial(trial_id))
            pipe.delete("trial_id:{:010d}:study_id".format(trial_id))
        pipe.delete("study_id:{:010d}:trial_list".format(study_id))
        pipe.delete("study_id:{:010d}:trial_number".format(study_id))

        # Study
        pipe.delete(self._key_study_name(study_name))
        pipe.delete("study_id:{:010d}:study_name".format(study_id))
        pipe.delete(self._key_study_direction(study_id))
        pipe.delete(self._key_best_trial(study_id))
        pipe.delete(self._key_study_param_distribution(study_id))

        pipe.execute()
@staticmethod
def _key_study_name(study_name: str) -> str:
return "study_name:{}:study_id".format(study_name)
@staticmethod
def _key_study_summary(study_id: int) -> str:
return "study_id:{:010d}:study_summary".format(study_id)
def _set_study_summary(self, study_id: int, study_summary: StudySummary) -> None:
    """Pickle ``study_summary`` and store it under the summary key of ``study_id``."""
    summary_key = self._key_study_summary(study_id)
    serialized = pickle.dumps(study_summary)
    self._redis.set(summary_key, serialized)
def _get_study_summary(self, study_id: int) -> StudySummary:
    """Load and unpickle the summary stored for ``study_id``.

    The stored value is asserted to exist; callers check the study ID first.
    """
    summary_key = self._key_study_summary(study_id)
    pickled_summary = self._redis.get(summary_key)
    assert pickled_summary is not None
    summary = pickle.loads(pickled_summary)
    return summary
def _del_study_summary(self, study_id: int) -> None:
self._redis.delete(self._key_study_summary(study_id))
@staticmethod
def _key_study_direction(study_id: int) -> str:
return "study_id:{:010d}:directions".format(study_id)
def set_study_directions(self, study_id: int, directions: Sequence[StudyDirection]) -> None:
    """Set the optimization directions of a study.

    Directions may be set while the stored first direction is still
    ``StudyDirection.NOT_SET``, or re-set to the same values afterwards.

    Args:
        study_id: ID of the study.
        directions: Directions to store, one per objective.

    Raises:
        KeyError: If no study with ``study_id`` exists.
        ValueError: If directions were already set to different values.
    """
    self._check_study_id(study_id)

    # Normalise to a list so that e.g. a stored list and an equivalent tuple compare
    # equal; ``[x] != (x,)`` is always true and used to raise a spurious error.
    new_directions = list(directions)

    # A single GET both detects a missing key and fetches the value.
    direction_pkl = self._redis.get(self._key_study_direction(study_id))
    if direction_pkl is not None:
        current_directions = pickle.loads(direction_pkl)
        if (
            current_directions[0] != StudyDirection.NOT_SET
            and list(current_directions) != new_directions
        ):
            raise ValueError(
                "Cannot overwrite study direction from {} to {}.".format(
                    current_directions, directions
                )
            )

    study_summary = self._get_study_summary(study_id)
    study_summary._directions = new_directions

    # Keep the directions key and the summary in sync within one transaction.
    with self._redis.pipeline() as pipe:
        pipe.multi()
        pipe.set(self._key_study_direction(study_id), pickle.dumps(new_directions))
        pipe.set(self._key_study_summary(study_id), pickle.dumps(study_summary))
        pipe.execute()
def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in the user attributes of a study.

    Args:
        study_id: ID of the study.
        key: Attribute name.
        value: Attribute value.

    Raises:
        KeyError: If no study with ``study_id`` exists.
    """
    self._check_study_id(study_id)

    # User attributes live inside the pickled summary: load, modify, write back.
    summary = self._get_study_summary(study_id)
    summary.user_attrs.update({key: value})
    self._set_study_summary(study_id, summary)
def set_study_system_attr(self, study_id: int, key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in the system attributes of a study.

    Args:
        study_id: ID of the study.
        key: Attribute name.
        value: Attribute value.

    Raises:
        KeyError: If no study with ``study_id`` exists.
    """
    self._check_study_id(study_id)

    # System attributes live inside the pickled summary: load, modify, write back.
    summary = self._get_study_summary(study_id)
    summary.system_attrs.update({key: value})
    self._set_study_summary(study_id, summary)
def get_study_id_from_name(self, study_name: str) -> int:
    """Return the ID of the study called ``study_name``.

    Args:
        study_name: Name of the study to look up.

    Returns:
        The study ID stored for ``study_name``.

    Raises:
        KeyError: If no study with that name exists.
    """
    # One GET replaces the former EXISTS + GET pair: the key can no longer vanish
    # between the two calls, and a missing key is always reported as ``KeyError``.
    study_id_pkl = self._redis.get(self._key_study_name(study_name))
    if study_id_pkl is None:
        raise KeyError("No such study {}.".format(study_name))

    return pickle.loads(study_id_pkl)
def get_study_id_from_trial_id(self, trial_id: int) -> int:
    """Return the ID of the study that owns trial ``trial_id``.

    Raises:
        KeyError: If no study ID is stored for ``trial_id``.
    """
    lookup_key = "trial_id:{:010d}:study_id".format(trial_id)
    pickled_study_id = self._redis.get(lookup_key)
    if pickled_study_id is not None:
        return pickle.loads(pickled_study_id)

    raise KeyError("No such trial: {}.".format(trial_id))
def get_study_name_from_id(self, study_id: int) -> str:
    """Return the name of the study with ID ``study_id``.

    Raises:
        KeyError: If the study does not exist or has no stored name.
    """
    self._check_study_id(study_id)

    name_key = "study_id:{:010d}:study_name".format(study_id)
    pickled_name = self._redis.get(name_key)
    if pickled_name is not None:
        return pickle.loads(pickled_name)

    raise KeyError("No such study: {}.".format(study_id))
def get_study_directions(self, study_id: int) -> List[StudyDirection]:
    """Return the stored optimization directions of a study as a list.

    Raises:
        KeyError: If no directions are stored for ``study_id``.
    """
    pickled_directions = self._redis.get(self._key_study_direction(study_id))
    if pickled_directions is not None:
        stored = pickle.loads(pickled_directions)
        return list(stored)

    raise KeyError("No such study: {}.".format(study_id))
def get_study_user_attrs(self, study_id: int) -> Dict[str, Any]:
    """Return a deep copy of the user attributes of a study.

    Raises:
        KeyError: If no study with ``study_id`` exists.
    """
    self._check_study_id(study_id)

    stored_attrs = self._get_study_summary(study_id).user_attrs
    return copy.deepcopy(stored_attrs)
def get_study_system_attrs(self, study_id: int) -> Dict[str, Any]:
    """Return a deep copy of the system attributes of a study.

    Raises:
        KeyError: If no study with ``study_id`` exists.
    """
    self._check_study_id(study_id)

    stored_attrs = self._get_study_summary(study_id).system_attrs
    return copy.deepcopy(stored_attrs)
@staticmethod
def _key_study_param_distribution(study_id: int) -> str:
return "study_id:{:010d}:params_distribution".format(study_id)
def _get_study_param_distribution(self, study_id: int) -> Dict:
if self._redis.exists(self._key_study_param_distribution(study_id)):
param_distribution_pkl = self._redis.get(self._key_study_param_distribution(study_id))
assert param_distribution_pkl is not None
return pickle.loads(param_distribution_pkl)
else:
return {}
def _set_study_param_distribution(self, study_id: int, param_distribution: Dict) -> None:
self._redis.set(
self._key_study_param_distribution(study_id), pickle.dumps(param_distribution)
)
def get_all_study_summaries(self) -> List[StudySummary]:
    """Return the summaries of all studies listed in ``study_list``, in list order."""
    # ``study_list`` holds pickled study IDs.
    pickled_ids = self._redis.lrange("study_list", 0, -1)
    study_ids = [pickle.loads(raw_id) for raw_id in pickled_ids]
    return [self._get_study_summary(study_id) for study_id in study_ids]
def create_new_trial(self, study_id: int, template_trial: Optional[FrozenTrial] = None) -> int:
    """Create a trial in a study and return its ID.

    Args:
        study_id: ID of the study that will own the trial.
        template_trial: Trial to copy. When ``None``, a fresh trial in the
            ``RUNNING`` state is created instead.

    Returns:
        The storage-wide ID of the new trial.

    Raises:
        KeyError: If no study with ``study_id`` exists.
    """
    self._check_study_id(study_id)

    if template_trial is None:
        trial = self._create_running_trial()
    else:
        # Never mutate the caller's template.
        trial = copy.deepcopy(template_trial)

    # Seed the counter with -1 so that the first ``INCR`` yields 0. ``SETNX`` writes
    # only when the key is absent, so a concurrent creator cannot reset a counter
    # that has already been incremented (which would hand out a duplicate ID).
    self._redis.setnx("trial_counter", -1)

    trial_id = self._redis.incr("trial_counter", 1)
    trial_number = self._redis.incr("study_id:{:010d}:trial_number".format(study_id))
    trial.number = trial_number
    trial._trial_id = trial_id

    with self._redis.pipeline() as pipe:
        pipe.multi()
        pipe.set(self._key_trial(trial_id), pickle.dumps(trial))
        pipe.set("trial_id:{:010d}:study_id".format(trial_id), pickle.dumps(study_id))
        pipe.rpush("study_id:{:010d}:trial_list".format(study_id), trial_id)
        pipe.execute()

        # The summary is recomputed from the stored trials, so the transaction above
        # has to be executed before the reads below.
        pipe.multi()
        study_summary = self._get_study_summary(study_id)
        study_summary.n_trials = len(self._get_study_trials(study_id))
        min_datetime_start = min(
            [
                cast(datetime, t.datetime_start)
                for t in self.get_all_trials(study_id)
                if t.datetime_start is not None
            ],
            default=None,
        )
        study_summary.datetime_start = min_datetime_start
        pipe.set(self._key_study_summary(study_id), pickle.dumps(study_summary))
        pipe.execute()

    # A template may already be finished; keep the cached best trial current.
    if trial.state.is_finished():
        self._update_cache(trial_id)

    return trial_id
@staticmethod
def _create_running_trial() -> FrozenTrial:
    """Build an empty ``RUNNING`` trial whose start time is the current time.

    ``trial_id`` and ``number`` are dummy values (-1).
    """
    started_at = datetime.now()
    return FrozenTrial(
        number=-1,  # dummy value.
        trial_id=-1,  # dummy value.
        state=TrialState.RUNNING,
        value=None,
        datetime_start=started_at,
        datetime_complete=None,
        params={},
        distributions={},
        user_attrs={},
        system_attrs={},
        intermediate_values={},
    )
def set_trial_state(self, trial_id: int, state: TrialState) -> bool:
    """Update the state of a trial.

    Args:
        trial_id: ID of the trial.
        state: New state of the trial.

    Returns:
        ``False`` when ``state`` is ``RUNNING`` but the trial is not currently
        ``WAITING`` (nothing is written in that case); ``True`` otherwise.

    Raises:
        KeyError: If no trial with ``trial_id`` exists.
    """
    self._check_trial_id(trial_id)

    trial = self.get_trial(trial_id)
    self.check_trial_is_updatable(trial_id, trial.state)

    starting = state == TrialState.RUNNING
    if starting and trial.state != TrialState.WAITING:
        return False

    trial.state = state
    if starting:
        trial.datetime_start = datetime.now()

    finished = state.is_finished()
    if finished:
        trial.datetime_complete = datetime.now()

    self._redis.set(self._key_trial(trial_id), pickle.dumps(trial))

    # The best-trial cache is refreshed only after the finished trial is stored.
    if finished:
        self._update_cache(trial_id)

    return True
def set_trial_param(
    self,
    trial_id: int,
    param_name: str,
    param_value_internal: float,
    distribution: distributions.BaseDistribution,
) -> None:
    """Record a parameter value and its distribution for a trial.

    Args:
        trial_id: ID of the trial.
        param_name: Name of the parameter.
        param_value_internal: Internal representation of the parameter value; it
            is converted with ``distribution.to_external_repr`` before storing.
        distribution: Distribution the parameter was drawn from.

    Raises:
        KeyError: If no trial with ``trial_id`` exists.
    """
    self._check_trial_id(trial_id)

    # Fetch the trial once and reuse it for both the state check and the update,
    # instead of loading and unpickling it twice.
    trial = self.get_trial(trial_id)
    self.check_trial_is_updatable(trial_id, trial.state)

    # Check param distribution compatibility with previous trial(s).
    study_id = self.get_study_id_from_trial_id(trial_id)
    param_distribution = self._get_study_param_distribution(study_id)
    if param_name in param_distribution:
        distributions.check_distribution_compatibility(
            param_distribution[param_name], distribution
        )

    # Set study param distribution.
    param_distribution[param_name] = distribution

    # Set params.
    trial.params[param_name] = distribution.to_external_repr(param_value_internal)
    trial.distributions[param_name] = distribution

    # Write the study-level distributions and the trial in one transaction.
    with self._redis.pipeline() as pipe:
        pipe.multi()
        pipe.set(
            self._key_study_param_distribution(study_id), pickle.dumps(param_distribution)
        )
        pipe.set(self._key_trial(trial_id), pickle.dumps(trial))
        pipe.execute()
def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
    """Return the ID of the trial with number ``trial_number`` in a study.

    Args:
        study_id: ID of the study.
        trial_number: Position of the trial in the study's trial list (0-based).

    Returns:
        The trial ID at that position.

    Raises:
        KeyError: If ``trial_number`` is negative or not smaller than the number of
            trials in the study.
    """
    trial_ids = self._get_study_trials(study_id)

    # Reject negative numbers explicitly: list indexing would otherwise silently
    # count from the end and return an unrelated trial.
    if not 0 <= trial_number < len(trial_ids):
        raise KeyError(
            "No trial with trial number {} exists in study with study_id {}.".format(
                trial_number, study_id
            )
        )

    return trial_ids[trial_number]
def get_trial_number_from_id(self, trial_id: int) -> int:
    """Return the ``number`` attribute of the stored trial ``trial_id``."""
    trial = self.get_trial(trial_id)
    return trial.number
@staticmethod
def _key_best_trial(study_id: int) -> str:
return "study_id:{:010d}:best_trial_id".format(study_id)
def get_best_trial(self, study_id: int) -> FrozenTrial:
    """Return the best completed trial of a single-objective study.

    The cached best trial ID is used when present. Otherwise the best trial is
    computed from all ``COMPLETE`` trials and then cached.

    Args:
        study_id: ID of the study.

    Returns:
        The completed trial with the highest value for ``MAXIMIZE`` studies, or
        the lowest value otherwise.

    Raises:
        ValueError: If no trial is completed yet, or if the study has more than
            one direction.
    """
    best_trial_id_pkl = self._redis.get(self._key_best_trial(study_id))
    if best_trial_id_pkl is not None:
        return self.get_trial(pickle.loads(best_trial_id_pkl))

    all_trials = self.get_all_trials(study_id, deepcopy=False)
    all_trials = [t for t in all_trials if t.state is TrialState.COMPLETE]

    if len(all_trials) == 0:
        raise ValueError("No trials are completed yet.")

    _direction = self.get_study_directions(study_id)
    if len(_direction) > 1:
        raise ValueError("Best trial can be obtained only for single-objective optimization.")
    direction = _direction[0]

    if direction == StudyDirection.MAXIMIZE:
        best_trial = max(all_trials, key=lambda t: cast(float, t.value))
    else:
        best_trial = min(all_trials, key=lambda t: cast(float, t.value))

    # ``_set_best_trial`` expects the storage-wide trial ID. The per-study
    # ``number`` differs from it as soon as the storage holds several studies,
    # which made the cache point at an unrelated trial.
    self._set_best_trial(study_id, best_trial._trial_id)

    return best_trial
def _set_best_trial(self, study_id: int, trial_id: int) -> None:
    """Cache ``trial_id`` as the best trial of ``study_id``.

    The best-trial key and the ``best_trial`` field of the study summary are
    written in one MULTI/EXEC transaction.
    """
    # Reads hit the server immediately, so prepare the summary before queueing.
    study_summary = self._get_study_summary(study_id)
    study_summary.best_trial = self.get_trial(trial_id)

    best_key = self._key_best_trial(study_id)
    summary_key = self._key_study_summary(study_id)

    with self._redis.pipeline() as pipe:
        pipe.multi()
        pipe.set(best_key, pickle.dumps(trial_id))
        pipe.set(summary_key, pickle.dumps(study_summary))
        pipe.execute()
def get_trial_param(self, trial_id: int, param_name: str) -> float:
    """Return the internal representation of a trial's parameter value.

    Args:
        trial_id: ID of the trial.
        param_name: Name of the parameter.

    Returns:
        The stored parameter value converted with the ``to_internal_repr`` method
        of the parameter's distribution.
    """
    # Load the trial once; the distribution and the value then come from the same
    # snapshot and only one round trip is needed.
    trial = self.get_trial(trial_id)
    distribution = trial.distributions[param_name]
    return distribution.to_internal_repr(trial.params[param_name])
def set_trial_values(self, trial_id: int, values: Sequence[float]) -> None:
    """Store the objective values of a trial.

    Args:
        trial_id: ID of the trial.
        values: Values assigned to the trial's ``values`` attribute.

    Raises:
        KeyError: If no trial with ``trial_id`` exists.
    """
    self._check_trial_id(trial_id)

    trial = self.get_trial(trial_id)
    self.check_trial_is_updatable(trial_id, trial.state)

    trial.values = values
    serialized = pickle.dumps(trial)
    self._redis.set(self._key_trial(trial_id), serialized)
def _update_cache(self, trial_id: int) -> None:
    """Refresh the cached best trial of a study after ``trial_id`` changed.

    Nothing happens unless the trial is ``COMPLETE`` and its study has a single
    direction. The trial becomes the cached best trial when none is cached yet or
    when its value is strictly better than the cached one.
    """
    trial = self.get_trial(trial_id)
    if trial.state != TrialState.COMPLETE:
        return

    study_id = self.get_study_id_from_trial_id(trial_id)

    directions = self.get_study_directions(study_id)
    if len(directions) > 1:
        return
    direction = directions[0]

    if not self._redis.exists(self._key_best_trial(study_id)):
        self._set_best_trial(study_id, trial_id)
        return

    incumbent_value = self.get_best_trial(study_id).value
    # Complete trials do not have `None` values.
    assert incumbent_value is not None
    assert trial.value is not None
    best_value = float(incumbent_value)
    new_value = float(trial.value)

    if direction == StudyDirection.MAXIMIZE:
        improved = new_value > best_value
    else:
        improved = new_value < best_value

    if improved:
        self._set_best_trial(study_id, trial_id)
def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in the user attributes of a trial.

    Args:
        trial_id: ID of the trial.
        key: Attribute name.
        value: Attribute value.

    Raises:
        KeyError: If no trial with ``trial_id`` exists.
    """
    self._check_trial_id(trial_id)

    trial = self.get_trial(trial_id)
    self.check_trial_is_updatable(trial_id, trial.state)

    trial.user_attrs.update({key: value})
    self._set_trial(trial_id, trial)
def set_trial_system_attr(self, trial_id: int, key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in the system attributes of a trial.

    Args:
        trial_id: ID of the trial.
        key: Attribute name.
        value: Attribute value.

    Raises:
        KeyError: If no trial with ``trial_id`` exists.
    """
    self._check_trial_id(trial_id)

    trial = self.get_trial(trial_id)
    self.check_trial_is_updatable(trial_id, trial.state)

    trial.system_attrs.update({key: value})
    self._set_trial(trial_id, trial)
@staticmethod
def _key_trial(trial_id: int) -> str:
return "trial_id:{:010d}:frozentrial".format(trial_id)
def get_trial(self, trial_id: int) -> FrozenTrial:
    """Load and unpickle the trial stored under ``trial_id``.

    Raises:
        KeyError: If no trial with ``trial_id`` exists.
    """
    self._check_trial_id(trial_id)

    trial_key = self._key_trial(trial_id)
    pickled_trial = self._redis.get(trial_key)
    assert pickled_trial is not None
    trial = pickle.loads(pickled_trial)
    return trial
def _set_trial(self, trial_id: int, trial: FrozenTrial) -> None:
    """Pickle ``trial`` and store it under the key of ``trial_id``."""
    trial_key = self._key_trial(trial_id)
    serialized = pickle.dumps(trial)
    self._redis.set(trial_key, serialized)
def _del_trial(self, trial_id: int) -> None:
    """Delete the trial record and its study-ID mapping in one transaction."""
    keys_to_delete = (
        self._key_trial(trial_id),
        "trial_id:{:010d}:study_id".format(trial_id),
    )
    with self._redis.pipeline() as pipe:
        pipe.multi()
        for key in keys_to_delete:
            pipe.delete(key)
        pipe.execute()
def _get_study_trials(self, study_id: int) -> List[int]:
self._check_study_id(study_id)
study_trial_list_key = "study_id:{:010d}:trial_list".format(study_id)
return [int(tid) for tid in self._redis.lrange(study_trial_list_key, 0, -1)]
def get_all_trials(
    self,
    study_id: int,
    deepcopy: bool = True,
    states: Optional[Tuple[TrialState, ...]] = None,
) -> List[FrozenTrial]:
    """Return the trials of a study, optionally filtered by state.

    Args:
        study_id: ID of the study.
        deepcopy: Whether to return a deep copy of the list of trials.
        states: Trial states to keep. ``None`` keeps every trial.

    Returns:
        The matching trials in the order of the study's trial list.

    Raises:
        KeyError: If no study with ``study_id`` exists.
    """
    self._check_study_id(study_id)

    selected = []
    for trial_id in self._get_study_trials(study_id):
        candidate = self.get_trial(trial_id)
        if states is not None and candidate.state not in states:
            continue
        selected.append(candidate)

    return copy.deepcopy(selected) if deepcopy else selected
def read_trials_from_remote_storage(self, study_id: int) -> None:
    """Validate ``study_id``; this storage does nothing else here.

    Raises:
        KeyError: If no study with ``study_id`` exists.
    """
    self._check_study_id(study_id)
    return None
def _check_study_id(self, study_id: int) -> None:
if not self._redis.exists("study_id:{:010d}:study_name".format(study_id)):
raise KeyError("study_id {} does not exist.".format(study_id))
def _check_trial_id(self, trial_id: int) -> None:
if not self._redis.exists(self._key_trial(trial_id)):
raise KeyError("study_id {} does not exist.".format(trial_id))