from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Sequence
import datetime
import math
from typing import Any
from typing import cast
from typing import Dict
from typing import overload
import warnings
from optuna import distributions
from optuna import logging
from optuna._convert_positional_args import convert_positional_args
from optuna._deprecated import deprecated_func
from optuna._typing import JSONSerializable
from optuna.distributions import _convert_old_distribution_to_new_distribution
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalChoiceType
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS
from optuna.trial._base import BaseTrial
from optuna.trial._state import TrialState
_logger = logging.get_logger(__name__)
_suggest_deprecated_msg = "Use suggest_float{args} instead."
[docs]
class FrozenTrial(BaseTrial):
"""Status and results of a :class:`~optuna.trial.Trial`.
An object of this class has the same methods as :class:`~optuna.trial.Trial`, but is not
associated with, nor has any references to a :class:`~optuna.study.Study`.
It is therefore not possible to make persistent changes to a storage from this object by
itself, for instance by using :func:`~optuna.trial.FrozenTrial.set_user_attr`.
It will suggest the parameter values stored in :attr:`params` and will not sample values from
any distributions.
It can be passed to objective functions (see :func:`~optuna.study.Study.optimize`) and is
useful for deploying optimization results.
Example:
Re-evaluate an objective function with parameter values optimized study.
.. testcode::
import optuna
def objective(trial):
x = trial.suggest_float("x", -1, 1)
return x**2
study = optuna.create_study()
study.optimize(objective, n_trials=3)
assert objective(study.best_trial) == study.best_value
.. note::
Instances are mutable, despite the name.
For instance, :func:`~optuna.trial.FrozenTrial.set_user_attr` will update user attributes
of objects in-place.
Example:
Overwritten attributes.
.. testcode::
import copy
import datetime
import optuna
def objective(trial):
x = trial.suggest_float("x", -1, 1)
# this user attribute always differs
trial.set_user_attr("evaluation time", datetime.datetime.now())
return x**2
study = optuna.create_study()
study.optimize(objective, n_trials=3)
best_trial = study.best_trial
best_trial_copy = copy.deepcopy(best_trial)
# re-evaluate
objective(best_trial)
# the user attribute is overwritten by re-evaluation
assert best_trial.user_attrs != best_trial_copy.user_attrs
.. note::
Please refer to :class:`~optuna.trial.Trial` for details of methods and properties.
Attributes:
number:
Unique and consecutive number of :class:`~optuna.trial.Trial` for each
:class:`~optuna.study.Study`. Note that this field uses zero-based numbering.
state:
:class:`TrialState` of the :class:`~optuna.trial.Trial`.
value:
Objective value of the :class:`~optuna.trial.Trial`.
``value`` and ``values`` must not be specified at the same time.
values:
Sequence of objective values of the :class:`~optuna.trial.Trial`.
The length is greater than 1 if the problem is multi-objective optimization.
``value`` and ``values`` must not be specified at the same time.
datetime_start:
Datetime where the :class:`~optuna.trial.Trial` started.
datetime_complete:
Datetime where the :class:`~optuna.trial.Trial` finished.
params:
Dictionary that contains suggested parameters.
distributions:
Dictionary that contains the distributions of :attr:`params`.
user_attrs:
Dictionary that contains the attributes of the :class:`~optuna.trial.Trial` set with
:func:`optuna.trial.Trial.set_user_attr`.
system_attrs:
Dictionary that contains the attributes of the :class:`~optuna.trial.Trial` set with
:func:`optuna.trial.Trial.set_system_attr`.
intermediate_values:
Intermediate objective values set with :func:`optuna.trial.Trial.report`.
"""
def __init__(
self,
number: int,
state: TrialState,
value: float | None,
datetime_start: datetime.datetime | None,
datetime_complete: datetime.datetime | None,
params: dict[str, Any],
distributions: dict[str, BaseDistribution],
user_attrs: dict[str, Any],
system_attrs: dict[str, Any],
intermediate_values: dict[int, float],
trial_id: int,
*,
values: Sequence[float] | None = None,
) -> None:
self._number = number
self.state = state
self._values: list[float] | None = None
if value is not None and values is not None:
raise ValueError("Specify only one of `value` and `values`.")
elif value is not None:
self._values = [value]
elif values is not None:
self._values = list(values)
self._datetime_start = datetime_start
self.datetime_complete = datetime_complete
self._params = params
self._user_attrs = user_attrs
self._system_attrs = system_attrs
self.intermediate_values = intermediate_values
self._distributions = distributions
self._trial_id = trial_id
def __eq__(self, other: Any) -> bool:
if not isinstance(other, FrozenTrial):
return NotImplemented
return other.__dict__ == self.__dict__
def __lt__(self, other: Any) -> bool:
if not isinstance(other, FrozenTrial):
return NotImplemented
return self.number < other.number
def __le__(self, other: Any) -> bool:
if not isinstance(other, FrozenTrial):
return NotImplemented
return self.number <= other.number
def __hash__(self) -> int:
return hash(tuple(getattr(self, field) for field in self.__dict__))
def __repr__(self) -> str:
return "{cls}({kwargs})".format(
cls=self.__class__.__name__,
kwargs=", ".join(
"{field}={value}".format(
field=field if not field.startswith("_") else field[1:],
value=repr(getattr(self, field)),
)
for field in self.__dict__
)
+ ", value=None",
)
def suggest_float(
self,
name: str,
low: float,
high: float,
*,
step: float | None = None,
log: bool = False,
) -> float:
return self._suggest(name, FloatDistribution(low, high, log=log, step=step))
@convert_positional_args(previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS)
def suggest_int(
self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
) -> int:
return int(self._suggest(name, IntDistribution(low, high, log=log, step=step)))
@overload
def suggest_categorical(self, name: str, choices: Sequence[None]) -> None: ...
@overload
def suggest_categorical(self, name: str, choices: Sequence[bool]) -> bool: ...
@overload
def suggest_categorical(self, name: str, choices: Sequence[int]) -> int: ...
@overload
def suggest_categorical(self, name: str, choices: Sequence[float]) -> float: ...
@overload
def suggest_categorical(self, name: str, choices: Sequence[str]) -> str: ...
@overload
def suggest_categorical(
self, name: str, choices: Sequence[CategoricalChoiceType]
) -> CategoricalChoiceType: ...
def suggest_categorical(
self, name: str, choices: Sequence[CategoricalChoiceType]
) -> CategoricalChoiceType:
return self._suggest(name, CategoricalDistribution(choices=choices))
[docs]
def report(self, value: float, step: int) -> None:
"""Interface of report function.
Since :class:`~optuna.trial.FrozenTrial` is not pruned,
this report function does nothing.
.. seealso::
Please refer to :func:`~optuna.trial.FrozenTrial.should_prune`.
Args:
value:
A value returned from the objective function.
step:
Step of the trial (e.g., Epoch of neural network training). Note that pruners
assume that ``step`` starts at zero. For example,
:class:`~optuna.pruners.MedianPruner` simply checks if ``step`` is less than
``n_warmup_steps`` as the warmup mechanism.
"""
pass
[docs]
def should_prune(self) -> bool:
"""Suggest whether the trial should be pruned or not.
The suggestion is always :obj:`False` regardless of a pruning algorithm.
.. note::
:class:`~optuna.trial.FrozenTrial` only samples one combination of parameters.
Returns:
:obj:`False`.
"""
return False
def set_user_attr(self, key: str, value: Any) -> None:
self._user_attrs[key] = value
[docs]
@deprecated_func("3.1.0", "5.0.0")
def set_system_attr(self, key: str, value: Any) -> None:
self._system_attrs[key] = value
def _validate(self) -> None:
if self.state != TrialState.WAITING and self.datetime_start is None:
raise ValueError(
"`datetime_start` is supposed to be set when the trial state is not waiting."
)
if self.state.is_finished():
if self.datetime_complete is None:
raise ValueError("`datetime_complete` is supposed to be set for a finished trial.")
else:
if self.datetime_complete is not None:
raise ValueError(
"`datetime_complete` is supposed to be None for an unfinished trial."
)
if self.state == TrialState.FAIL and self._values is not None:
raise ValueError(f"values should be None for a failed trial, but got {self._values}.")
if self.state == TrialState.COMPLETE:
if self._values is None:
raise ValueError("values should be set for a complete trial.")
elif any(math.isnan(x) for x in self._values):
raise ValueError("values should not contain NaN.")
if set(self.params.keys()) != set(self.distributions.keys()):
raise ValueError(
"Inconsistent parameters {} and distributions {}.".format(
set(self.params.keys()), set(self.distributions.keys())
)
)
for param_name, param_value in self.params.items():
distribution = self.distributions[param_name]
param_value_in_internal_repr = distribution.to_internal_repr(param_value)
if not distribution._contains(param_value_in_internal_repr):
raise ValueError(
"The value {} of parameter '{}' isn't contained in the distribution "
"{}.".format(param_value, param_name, distribution)
)
def _suggest(self, name: str, distribution: BaseDistribution) -> Any:
if name not in self._params:
raise ValueError(
"The value of the parameter '{}' is not found. Please set it at "
"the construction of the FrozenTrial object.".format(name)
)
value = self._params[name]
param_value_in_internal_repr = distribution.to_internal_repr(value)
if not distribution._contains(param_value_in_internal_repr):
warnings.warn(
"The value {} of the parameter '{}' is out of "
"the range of the distribution {}.".format(value, name, distribution)
)
if name in self._distributions:
distributions.check_distribution_compatibility(self._distributions[name], distribution)
self._distributions[name] = distribution
return value
@property
def number(self) -> int:
return self._number
@number.setter
def number(self, value: int) -> None:
self._number = value
@property
def value(self) -> float | None:
if self._values is not None:
if len(self._values) > 1:
raise RuntimeError(
"This attribute is not available during multi-objective optimization."
)
return self._values[0]
return None
@value.setter
def value(self, v: float | None) -> None:
if self._values is not None:
if len(self._values) > 1:
raise RuntimeError(
"This attribute is not available during multi-objective optimization."
)
if v is not None:
self._values = [v]
else:
self._values = None
# These `_get_values`, `_set_values`, and `values = property(_get_values, _set_values)` are
# defined to pass the mypy.
# See https://github.com/python/mypy/issues/3004#issuecomment-726022329.
def _get_values(self) -> list[float] | None:
return self._values
def _set_values(self, v: Sequence[float] | None) -> None:
if v is not None:
self._values = list(v)
else:
self._values = None
values = property(_get_values, _set_values)
@property
def datetime_start(self) -> datetime.datetime | None:
return self._datetime_start
@datetime_start.setter
def datetime_start(self, value: datetime.datetime | None) -> None:
self._datetime_start = value
@property
def params(self) -> dict[str, Any]:
return self._params
@params.setter
def params(self, params: dict[str, Any]) -> None:
self._params = params
@property
def distributions(self) -> dict[str, BaseDistribution]:
return self._distributions
@distributions.setter
def distributions(self, value: dict[str, BaseDistribution]) -> None:
self._distributions = value
@property
def user_attrs(self) -> dict[str, Any]:
return self._user_attrs
@user_attrs.setter
def user_attrs(self, value: dict[str, Any]) -> None:
self._user_attrs = value
@property
def system_attrs(self) -> dict[str, Any]:
return self._system_attrs
@system_attrs.setter
def system_attrs(self, value: Mapping[str, JSONSerializable]) -> None:
self._system_attrs = cast(Dict[str, Any], value)
@property
def last_step(self) -> int | None:
"""Return the maximum step of :attr:`intermediate_values` in the trial.
Returns:
The maximum step of intermediates.
"""
if len(self.intermediate_values) == 0:
return None
else:
return max(self.intermediate_values.keys())
@property
def duration(self) -> datetime.timedelta | None:
"""Return the elapsed time taken to complete the trial.
Returns:
The duration.
"""
if self.datetime_start and self.datetime_complete:
return self.datetime_complete - self.datetime_start
else:
return None
[docs]
def create_trial(
*,
state: TrialState = TrialState.COMPLETE,
value: float | None = None,
values: Sequence[float] | None = None,
params: dict[str, Any] | None = None,
distributions: dict[str, BaseDistribution] | None = None,
user_attrs: dict[str, Any] | None = None,
system_attrs: dict[str, Any] | None = None,
intermediate_values: dict[int, float] | None = None,
) -> FrozenTrial:
"""Create a new :class:`~optuna.trial.FrozenTrial`.
Example:
.. testcode::
import optuna
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
trial = optuna.trial.create_trial(
params={"x": 1.0, "y": 0},
distributions={
"x": FloatDistribution(0, 10),
"y": CategoricalDistribution([-1, 0, 1]),
},
value=5.0,
)
assert isinstance(trial, optuna.trial.FrozenTrial)
assert trial.value == 5.0
assert trial.params == {"x": 1.0, "y": 0}
.. seealso::
See :func:`~optuna.study.Study.add_trial` for how this function can be used to create a
study from existing trials.
.. note::
Please note that this is a low-level API. In general, trials that are passed to objective
functions are created inside :func:`~optuna.study.Study.optimize`.
.. note::
When ``state`` is :class:`TrialState.COMPLETE`, the following parameters are
required:
* ``params``
* ``distributions``
* ``value`` or ``values``
Args:
state:
Trial state.
value:
Trial objective value. Must be specified if ``state`` is :class:`TrialState.COMPLETE`.
``value`` and ``values`` must not be specified at the same time.
values:
Sequence of the trial objective values. The length is greater than 1 if the problem is
multi-objective optimization.
Must be specified if ``state`` is :class:`TrialState.COMPLETE`.
``value`` and ``values`` must not be specified at the same time.
params:
Dictionary with suggested parameters of the trial.
distributions:
Dictionary with parameter distributions of the trial.
user_attrs:
Dictionary with user attributes.
system_attrs:
Dictionary with system attributes. Should not have to be used for most users.
intermediate_values:
Dictionary with intermediate objective values of the trial.
Returns:
Created trial.
"""
params = params or {}
distributions = distributions or {}
distributions = {
key: _convert_old_distribution_to_new_distribution(dist)
for key, dist in distributions.items()
}
user_attrs = user_attrs or {}
system_attrs = system_attrs or {}
intermediate_values = intermediate_values or {}
if state == TrialState.WAITING:
datetime_start = None
else:
datetime_start = datetime.datetime.now()
if state.is_finished():
datetime_complete: datetime.datetime | None = datetime_start
else:
datetime_complete = None
trial = FrozenTrial(
number=-1,
trial_id=-1,
state=state,
value=value,
values=values,
datetime_start=datetime_start,
datetime_complete=datetime_complete,
params=params,
distributions=distributions,
user_attrs=user_attrs,
system_attrs=system_attrs,
intermediate_values=intermediate_values,
)
trial._validate()
return trial