from __future__ import annotations
import datetime
import math
from typing import Any
from typing import cast
from typing import overload
from typing import TYPE_CHECKING
from optuna import distributions
from optuna import logging
from optuna._convert_positional_args import convert_positional_args
from optuna._deprecated import deprecated_func
from optuna._warnings import optuna_warn
from optuna.distributions import _convert_old_distribution_to_new_distribution
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS
from optuna.trial._base import BaseTrial
from optuna.trial._state import TrialState
if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Sequence
from optuna._typing import JSONSerializable
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalChoiceType
# Module logger (obtained from optuna's logging wrapper); not referenced in the code shown here.
_logger = logging.get_logger(__name__)

# Message template containing an ``{args}`` placeholder. It is not used in the visible code;
# presumably callers elsewhere fill it with ``.format(args=...)`` — confirm.
_suggest_deprecated_msg: str = "Use suggest_float{args} instead."
class FrozenTrial(BaseTrial):
    """Status and results of a :class:`~optuna.trial.Trial`.

    An object of this class has the same methods as :class:`~optuna.trial.Trial`, but is not
    associated with, nor has any references to a :class:`~optuna.study.Study`.

    It is therefore not possible to make persistent changes to a storage from this object by
    itself, for instance by using :func:`~optuna.trial.FrozenTrial.set_user_attr`.

    It will suggest the parameter values stored in :attr:`params` and will not sample values from
    any distributions.

    It can be passed to objective functions (see :func:`~optuna.study.Study.optimize`) and is
    useful for deploying optimization results.

    Example:

        Re-evaluate an objective function with the parameter values of an optimized study.

        .. testcode::

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", -1, 1)
                return x**2


            study = optuna.create_study()
            study.optimize(objective, n_trials=3)

            assert objective(study.best_trial) == study.best_value

    .. note::
        Instances are mutable, despite the name.
        For instance, :func:`~optuna.trial.FrozenTrial.set_user_attr` will update user attributes
        of objects in-place.


        Example:

            Overwritten attributes.

            .. testcode::

                import copy
                import datetime

                import optuna


                def objective(trial):
                    x = trial.suggest_float("x", -1, 1)

                    # this user attribute always differs
                    trial.set_user_attr("evaluation time", datetime.datetime.now())

                    return x**2


                study = optuna.create_study()
                study.optimize(objective, n_trials=3)

                best_trial = study.best_trial
                best_trial_copy = copy.deepcopy(best_trial)

                # re-evaluate
                objective(best_trial)

                # the user attribute is overwritten by re-evaluation
                assert best_trial.user_attrs != best_trial_copy.user_attrs

    .. note::
        Please refer to :class:`~optuna.trial.Trial` for details of methods and properties.


    Attributes:
        number:
            Unique and consecutive number of :class:`~optuna.trial.Trial` for each
            :class:`~optuna.study.Study`. Note that this field uses zero-based numbering.
        state:
            :class:`TrialState` of the :class:`~optuna.trial.Trial`.
        value:
            Objective value of the :class:`~optuna.trial.Trial`.
            ``value`` and ``values`` must not be specified at the same time.
        values:
            Sequence of objective values of the :class:`~optuna.trial.Trial`.
            The length is greater than 1 if the problem is multi-objective optimization.
            ``value`` and ``values`` must not be specified at the same time.
        datetime_start:
            Datetime when the :class:`~optuna.trial.Trial` started.
        datetime_complete:
            Datetime when the :class:`~optuna.trial.Trial` finished.
        params:
            Dictionary that contains suggested parameters.
        distributions:
            Dictionary that contains the distributions of :attr:`params`.
        user_attrs:
            Dictionary that contains the attributes of the :class:`~optuna.trial.Trial` set with
            :func:`optuna.trial.Trial.set_user_attr`.
        system_attrs:
            Dictionary that contains the attributes of the :class:`~optuna.trial.Trial` set with
            :func:`optuna.trial.Trial.set_system_attr`.
        intermediate_values:
            Intermediate objective values set with :func:`optuna.trial.Trial.report`.
    """

    def __init__(
        self,
        number: int,
        state: TrialState,
        value: float | None,
        datetime_start: datetime.datetime | None,
        datetime_complete: datetime.datetime | None,
        params: dict[str, Any],
        distributions: dict[str, BaseDistribution],
        user_attrs: dict[str, Any],
        system_attrs: dict[str, Any],
        intermediate_values: dict[int, float],
        trial_id: int,
        *,
        values: Sequence[float] | None = None,
    ) -> None:
        # NOTE: the attribute assignment order below defines ``__dict__`` order, which
        # ``__repr__`` and ``__hash__`` iterate over; keep it stable.
        self._number = number
        self.state = state

        # Objective values are always stored as a list (or None); ``value`` is a
        # single-objective view over it.
        self._values: list[float] | None = None
        if value is not None and values is not None:
            raise ValueError("Specify only one of `value` and `values`.")
        elif value is not None:
            self._values = [value]
        elif values is not None:
            self._values = list(values)

        self._datetime_start = datetime_start
        self.datetime_complete = datetime_complete
        # The dicts below are stored by reference, not copied.
        self._params = params
        self._user_attrs = user_attrs
        self._system_attrs = system_attrs
        self.intermediate_values = intermediate_values
        self._distributions = distributions
        self._trial_id = trial_id

    def __eq__(self, other: Any) -> bool:
        # Two frozen trials are equal when every instance attribute matches.
        if not isinstance(other, FrozenTrial):
            return NotImplemented
        return other.__dict__ == self.__dict__

    def __lt__(self, other: Any) -> bool:
        # Ordering is by trial number only.
        if not isinstance(other, FrozenTrial):
            return NotImplemented
        return self.number < other.number

    def __le__(self, other: Any) -> bool:
        if not isinstance(other, FrozenTrial):
            return NotImplemented
        return self.number <= other.number

    def __hash__(self) -> int:
        # NOTE(review): ``__dict__`` always contains dict-valued attributes (``_params``,
        # ``_user_attrs``, ...), so building this tuple's hash raises TypeError; instances are
        # effectively unhashable. Kept as-is to preserve behavior — confirm intent.
        return hash(tuple(getattr(self, field) for field in self.__dict__))

    def __repr__(self) -> str:
        # Private attributes are shown without their leading underscore so the output mirrors
        # the constructor's keyword arguments. Objective values live only in ``_values``
        # (rendered as ``values=``), but ``__init__`` also takes a ``value`` argument, so
        # ``value=None`` is appended to cover every constructor parameter.
        cls = self.__class__.__name__
        kwargs = (
            ", ".join(
                f"{field if not field.startswith('_') else field[1:]}={repr(getattr(self, field))}"
                for field in self.__dict__
            )
            + ", value=None"
        )
        return f"{cls}({kwargs})"

    # The suggest_* methods never sample: they return the value stored in ``params`` for
    # ``name`` and record the requested distribution (see ``_suggest``).
    def suggest_float(
        self,
        name: str,
        low: float,
        high: float,
        *,
        step: float | None = None,
        log: bool = False,
    ) -> float:
        return self._suggest(name, FloatDistribution(low, high, log=log, step=step))

    @convert_positional_args(
        previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS,
        deprecated_version="3.5.0",
        removed_version="5.0.0",
    )
    def suggest_int(
        self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
    ) -> int:
        return int(self._suggest(name, IntDistribution(low, high, log=log, step=step)))

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[None]) -> None: ...

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[bool]) -> bool: ...

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[int]) -> int: ...

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[float]) -> float: ...

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[str]) -> str: ...

    @overload
    def suggest_categorical(
        self, name: str, choices: Sequence[CategoricalChoiceType]
    ) -> CategoricalChoiceType: ...

    def suggest_categorical(
        self, name: str, choices: Sequence[CategoricalChoiceType]
    ) -> CategoricalChoiceType:
        return self._suggest(name, CategoricalDistribution(choices=choices))

    def report(self, value: float, step: int) -> None:
        """Interface of report function.

        Since :class:`~optuna.trial.FrozenTrial` is not pruned,
        this report function does nothing.

        .. seealso::
            Please refer to :func:`~optuna.trial.FrozenTrial.should_prune`.

        Args:
            value:
                A value returned from the objective function.
            step:
                Step of the trial (e.g., Epoch of neural network training). Note that pruners
                assume that ``step`` starts at zero. For example,
                :class:`~optuna.pruners.MedianPruner` simply checks if ``step`` is less than
                ``n_warmup_steps`` as the warmup mechanism.
        """

    def should_prune(self) -> bool:
        """Suggest whether the trial should be pruned or not.

        The suggestion is always :obj:`False` regardless of a pruning algorithm.

        .. note::
            :class:`~optuna.trial.FrozenTrial` only samples one combination of parameters.

        Returns:
            :obj:`False`.
        """
        return False

    def set_user_attr(self, key: str, value: Any) -> None:
        # Updates the in-memory dict only; nothing is written to a storage.
        self._user_attrs[key] = value

    @deprecated_func("3.1.0", "5.0.0")
    def set_system_attr(self, key: str, value: Any) -> None:
        self._system_attrs[key] = value

    def _validate(self) -> None:
        """Check that the trial's fields are mutually consistent.

        Raises:
            ValueError:
                If the start/complete timestamps do not match ``state``, if a failed trial has
                values, if a complete trial lacks values or has a NaN value, if ``params`` and
                ``distributions`` have different keys, or if a parameter value lies outside its
                distribution.
        """
        if self.state != TrialState.WAITING and self.datetime_start is None:
            raise ValueError(
                "`datetime_start` is supposed to be set when the trial state is not waiting."
            )

        if self.state.is_finished():
            if self.datetime_complete is None:
                raise ValueError("`datetime_complete` is supposed to be set for a finished trial.")
        elif self.datetime_complete is not None:
            raise ValueError(
                "`datetime_complete` is supposed to be None for an unfinished trial."
            )

        if self.state == TrialState.FAIL and self._values is not None:
            raise ValueError(f"values should be None for a failed trial, but got {self._values}.")
        if self.state == TrialState.COMPLETE:
            if self._values is None:
                raise ValueError("values should be set for a complete trial.")
            elif any(math.isnan(x) for x in self._values):
                raise ValueError("values should not contain NaN.")

        if set(self.params.keys()) != set(self.distributions.keys()):
            raise ValueError(
                f"Inconsistent parameters {set(self.params.keys())} and "
                f"distributions {set(self.distributions.keys())}."
            )

        for param_name, param_value in self.params.items():
            distribution = self.distributions[param_name]
            # Membership is checked on the distribution's internal representation.
            param_value_in_internal_repr = distribution.to_internal_repr(param_value)
            if not distribution._contains(param_value_in_internal_repr):
                raise ValueError(
                    f"The value {param_value} of parameter '{param_name}' isn't contained in "
                    f"the distribution {distribution}."
                )

    def _suggest(self, name: str, distribution: BaseDistribution) -> Any:
        """Return the stored value of parameter ``name`` instead of sampling it.

        An out-of-range stored value only triggers a warning. If a distribution is already
        recorded for ``name``, it is checked against ``distribution`` via
        ``distributions.check_distribution_compatibility`` (which presumably raises on a
        mismatch). ``distribution`` is then recorded for ``name``.

        Raises:
            ValueError: If ``name`` is not among the stored parameters.
        """
        if name not in self._params:
            raise ValueError(
                f"The value of the parameter '{name}' is not found. "
                f"Please set it at the construction of the FrozenTrial object."
            )

        value = self._params[name]
        param_value_in_internal_repr = distribution.to_internal_repr(value)
        if not distribution._contains(param_value_in_internal_repr):
            optuna_warn(
                f"The value {value} of the parameter '{name}' is out of "
                f"the range of the distribution {distribution}."
            )

        if name in self._distributions:
            distributions.check_distribution_compatibility(self._distributions[name], distribution)
        self._distributions[name] = distribution

        return value

    @property
    def number(self) -> int:
        return self._number

    @number.setter
    def number(self, value: int) -> None:
        self._number = value

    @property
    def value(self) -> float | None:
        # Single-objective view of ``values``; not defined for multi-objective trials.
        if self._values is not None:
            if len(self._values) > 1:
                raise RuntimeError(
                    "This attribute is not available during multi-objective optimization."
                )
            return self._values[0]
        return None

    @value.setter
    def value(self, v: float | None) -> None:
        if self._values is not None:
            if len(self._values) > 1:
                raise RuntimeError(
                    "This attribute is not available during multi-objective optimization."
                )

        if v is not None:
            self._values = [v]
        else:
            self._values = None

    # These `_get_values`, `_set_values`, and `values = property(_get_values, _set_values)` are
    # defined to pass the mypy.
    # See https://github.com/python/mypy/issues/3004#issuecomment-726022329.
    def _get_values(self) -> list[float] | None:
        return self._values

    def _set_values(self, v: Sequence[float] | None) -> None:
        if v is not None:
            self._values = list(v)
        else:
            self._values = None

    values = property(_get_values, _set_values)

    @property
    def datetime_start(self) -> datetime.datetime | None:
        return self._datetime_start

    @datetime_start.setter
    def datetime_start(self, value: datetime.datetime | None) -> None:
        self._datetime_start = value

    @property
    def params(self) -> dict[str, Any]:
        return self._params

    @params.setter
    def params(self, params: dict[str, Any]) -> None:
        self._params = params

    @property
    def distributions(self) -> dict[str, BaseDistribution]:
        return self._distributions

    @distributions.setter
    def distributions(self, value: dict[str, BaseDistribution]) -> None:
        self._distributions = value

    @property
    def user_attrs(self) -> dict[str, Any]:
        return self._user_attrs

    @user_attrs.setter
    def user_attrs(self, value: dict[str, Any]) -> None:
        self._user_attrs = value

    @property
    def system_attrs(self) -> dict[str, Any]:
        return self._system_attrs

    @system_attrs.setter
    def system_attrs(self, value: Mapping[str, JSONSerializable]) -> None:
        self._system_attrs = cast("dict[str, Any]", value)

    @property
    def last_step(self) -> int | None:
        """Return the maximum step of :attr:`intermediate_values` in the trial.

        Returns:
            The maximum step of intermediates, or :obj:`None` if nothing has been reported.
        """
        return max(self.intermediate_values, default=None)

    @property
    def duration(self) -> datetime.timedelta | None:
        """Return the elapsed time taken to complete the trial.

        Returns:
            The duration, or :obj:`None` if either the start or the completion time is unset.
        """
        if self.datetime_start is not None and self.datetime_complete is not None:
            return self.datetime_complete - self.datetime_start
        return None
def create_trial(
    *,
    state: TrialState = TrialState.COMPLETE,
    value: float | None = None,
    values: Sequence[float] | None = None,
    params: dict[str, Any] | None = None,
    distributions: dict[str, BaseDistribution] | None = None,
    user_attrs: dict[str, Any] | None = None,
    system_attrs: dict[str, Any] | None = None,
    intermediate_values: dict[int, float] | None = None,
) -> FrozenTrial:
    """Create a new :class:`~optuna.trial.FrozenTrial`.

    Example:

        .. testcode::

            import optuna
            from optuna.distributions import CategoricalDistribution
            from optuna.distributions import FloatDistribution

            trial = optuna.trial.create_trial(
                params={"x": 1.0, "y": 0},
                distributions={
                    "x": FloatDistribution(0, 10),
                    "y": CategoricalDistribution([-1, 0, 1]),
                },
                value=5.0,
            )

            assert isinstance(trial, optuna.trial.FrozenTrial)
            assert trial.value == 5.0
            assert trial.params == {"x": 1.0, "y": 0}

    .. seealso::

        See :func:`~optuna.study.Study.add_trial` for how this function can be used to create a
        study from existing trials.

    .. note::

        Please note that this is a low-level API. In general, trials that are passed to objective
        functions are created inside :func:`~optuna.study.Study.optimize`.

    .. note::

        When ``state`` is :class:`TrialState.COMPLETE`, the following parameters are
        required:

        * ``params``
        * ``distributions``
        * ``value`` or ``values``

    Args:
        state:
            Trial state.
        value:
            Trial objective value. Must be specified if ``state`` is :class:`TrialState.COMPLETE`.
            ``value`` and ``values`` must not be specified at the same time.
        values:
            Sequence of the trial objective values. The length is greater than 1 if the problem is
            multi-objective optimization.
            Must be specified if ``state`` is :class:`TrialState.COMPLETE`.
            ``value`` and ``values`` must not be specified at the same time.
        params:
            Dictionary with suggested parameters of the trial.
        distributions:
            Dictionary with parameter distributions of the trial.
        user_attrs:
            Dictionary with user attributes.
        system_attrs:
            Dictionary with system attributes. Should not have to be used for most users.
        intermediate_values:
            Dictionary with intermediate objective values of the trial.

    Returns:
        Created trial.

    Raises:
        ValueError:
            If both ``value`` and ``values`` are given, or if the resulting trial fails the
            consistency checks run on it before it is returned (timestamps vs. ``state``,
            missing/NaN values for a complete trial, values on a failed trial, mismatched
            ``params``/``distributions`` keys, or a parameter outside its distribution).
    """
    params = params or {}
    # Each distribution is passed through ``_convert_old_distribution_to_new_distribution``;
    # by its name this upgrades legacy distribution classes — confirm against its definition.
    distributions = distributions or {}
    distributions = {
        key: _convert_old_distribution_to_new_distribution(dist)
        for key, dist in distributions.items()
    }
    user_attrs = user_attrs or {}
    system_attrs = system_attrs or {}
    intermediate_values = intermediate_values or {}

    # A waiting trial has not started; any other state is stamped with the current local time.
    if state == TrialState.WAITING:
        datetime_start = None
    else:
        datetime_start = datetime.datetime.now()

    # A finished trial is recorded as completing at the same instant it started.
    if state.is_finished():
        datetime_complete: datetime.datetime | None = datetime_start
    else:
        datetime_complete = None

    # ``-1`` ids are placeholders; presumably real ones are assigned when the trial is added
    # to a study (e.g. via ``Study.add_trial``) — confirm.
    trial = FrozenTrial(
        number=-1,
        trial_id=-1,
        state=state,
        value=value,
        values=values,
        datetime_start=datetime_start,
        datetime_complete=datetime_complete,
        params=params,
        distributions=distributions,
        user_attrs=user_attrs,
        system_attrs=system_attrs,
        intermediate_values=intermediate_values,
    )

    trial._validate()

    return trial