# Source code for optuna.trial._frozen

from __future__ import annotations

import datetime
import math
from typing import Any
from typing import cast
from typing import overload
from typing import TYPE_CHECKING

from optuna import distributions
from optuna import logging
from optuna._convert_positional_args import convert_positional_args
from optuna._deprecated import deprecated_func
from optuna._warnings import optuna_warn
from optuna.distributions import _convert_old_distribution_to_new_distribution
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.trial._base import _SUGGEST_INT_POSITIONAL_ARGS
from optuna.trial._base import BaseTrial
from optuna.trial._state import TrialState


if TYPE_CHECKING:
    from collections.abc import Mapping
    from collections.abc import Sequence

    from optuna._typing import JSONSerializable
    from optuna.distributions import BaseDistribution
    from optuna.distributions import CategoricalChoiceType


# Module-level logger obtained through Optuna's own ``logging`` helper module.
_logger = logging.get_logger(__name__)
# Template for the deprecation text of the legacy float ``suggest_*`` methods; ``{args}`` is
# filled in with the ``suggest_float`` argument hint to show (possibly an empty string).
_suggest_deprecated_msg = "Use suggest_float{args} instead."


class FrozenTrial(BaseTrial):
    """Status and results of a :class:`~optuna.trial.Trial`.

    An object of this class exposes the same methods as :class:`~optuna.trial.Trial`, but it is
    neither associated with nor holds any reference to a :class:`~optuna.study.Study`.

    Consequently, this object cannot make persistent changes to a storage by itself, for
    instance through :func:`~optuna.trial.FrozenTrial.set_user_attr`.

    It suggests the parameter values stored in :attr:`params` and never samples values from any
    distribution.

    It can be passed to objective functions (see :func:`~optuna.study.Study.optimize`) and is
    useful for deploying optimization results.

    Example:

        Re-evaluate an objective function with the parameter values of an optimized study.

        .. testcode::

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", -1, 1)
                return x**2


            study = optuna.create_study()
            study.optimize(objective, n_trials=3)

            assert objective(study.best_trial) == study.best_value

    .. note::
        Instances are mutable, despite the name.
        For instance, :func:`~optuna.trial.FrozenTrial.set_user_attr` will update user attributes
        of objects in-place.


        Example:

            Overwritten attributes.

            .. testcode::

                import copy
                import datetime

                import optuna


                def objective(trial):
                    x = trial.suggest_float("x", -1, 1)

                    # this user attribute always differs
                    trial.set_user_attr("evaluation time", datetime.datetime.now())

                    return x**2


                study = optuna.create_study()
                study.optimize(objective, n_trials=3)

                best_trial = study.best_trial
                best_trial_copy = copy.deepcopy(best_trial)

                # re-evaluate
                objective(best_trial)

                # the user attribute is overwritten by re-evaluation
                assert best_trial.user_attrs != best_trial_copy.user_attrs

    .. note::
        Please refer to :class:`~optuna.trial.Trial` for details of methods and properties.


    Attributes:
        number:
            Unique and consecutive number of :class:`~optuna.trial.Trial` for each
            :class:`~optuna.study.Study`. Note that this field uses zero-based numbering.
        state:
            :class:`TrialState` of the :class:`~optuna.trial.Trial`.
        value:
            Objective value of the :class:`~optuna.trial.Trial`.
            ``value`` and ``values`` must not be specified at the same time.
        values:
            Sequence of objective values of the :class:`~optuna.trial.Trial`.
            The length is greater than 1 if the problem is multi-objective optimization.
            ``value`` and ``values`` must not be specified at the same time.
        datetime_start:
            Datetime where the :class:`~optuna.trial.Trial` started.
        datetime_complete:
            Datetime where the :class:`~optuna.trial.Trial` finished.
        params:
            Dictionary that contains suggested parameters.
        distributions:
            Dictionary that contains the distributions of :attr:`params`.
        user_attrs:
            Dictionary that contains the attributes of the :class:`~optuna.trial.Trial` set with
            :func:`optuna.trial.Trial.set_user_attr`.
        system_attrs:
            Dictionary that contains the attributes of the :class:`~optuna.trial.Trial` set with
            :func:`optuna.trial.Trial.set_system_attr`.
        intermediate_values:
            Intermediate objective values set with :func:`optuna.trial.Trial.report`.
    """

    def __init__(
        self,
        number: int,
        state: TrialState,
        value: float | None,
        datetime_start: datetime.datetime | None,
        datetime_complete: datetime.datetime | None,
        params: dict[str, Any],
        distributions: dict[str, BaseDistribution],
        user_attrs: dict[str, Any],
        system_attrs: dict[str, Any],
        intermediate_values: dict[int, float],
        trial_id: int,
        *,
        values: Sequence[float] | None = None,
    ) -> None:
        # The order of the attribute assignments below is observable: ``__repr__`` and
        # ``__hash__`` walk ``self.__dict__`` in insertion order.
        self._number = number
        self.state = state

        # A single objective value is stored as a one-element list; ``value`` and ``values``
        # are mutually exclusive.
        self._values: list[float] | None = None
        if value is not None:
            if values is not None:
                raise ValueError("Specify only one of `value` and `values`.")
            self._values = [value]
        elif values is not None:
            self._values = list(values)

        self._datetime_start = datetime_start
        self.datetime_complete = datetime_complete
        self._params = params
        self._user_attrs = user_attrs
        self._system_attrs = system_attrs
        self.intermediate_values = intermediate_values
        self._distributions = distributions
        self._trial_id = trial_id

    def __eq__(self, other: Any) -> bool:
        # Two frozen trials are equal when all of their instance attributes are equal.
        if isinstance(other, FrozenTrial):
            return other.__dict__ == self.__dict__
        return NotImplemented

    def __lt__(self, other: Any) -> bool:
        # Ordering is defined by the trial number only.
        if isinstance(other, FrozenTrial):
            return self.number < other.number
        return NotImplemented

    def __le__(self, other: Any) -> bool:
        if isinstance(other, FrozenTrial):
            return self.number <= other.number
        return NotImplemented

    def __hash__(self) -> int:
        # NOTE(review): attributes such as ``_params`` are annotated as dicts and ``_values`` may
        # be a list, so hashing this tuple presumably raises ``TypeError`` for ordinary
        # instances -- confirm whether hashing is meant to be supported.
        attribute_values = tuple(getattr(self, attr_name) for attr_name in self.__dict__)
        return hash(attribute_values)

    def __repr__(self) -> str:
        rendered_fields = []
        for attr_name in self.__dict__:
            # Private attributes are shown under their public (underscore-less) name.
            shown_name = attr_name[1:] if attr_name.startswith("_") else attr_name
            rendered_fields.append(f"{shown_name}={getattr(self, attr_name)!r}")

        # ``value`` is a required constructor argument, while the stored objective values are
        # rendered through ``values``; hence the constant ``value=None`` suffix.
        kwargs = ", ".join(rendered_fields) + ", value=None"
        cls = self.__class__.__name__
        return f"{cls}({kwargs})"

    def suggest_float(
        self,
        name: str,
        low: float,
        high: float,
        *,
        step: float | None = None,
        log: bool = False,
    ) -> float:
        # Returns the stored value of ``name``; the distribution is only used for the checks
        # performed by ``_suggest``.
        float_distribution = FloatDistribution(low, high, log=log, step=step)
        return self._suggest(name, float_distribution)

    @deprecated_func("3.0.0", "6.0.0", text=_suggest_deprecated_msg.format(args=""))
    def suggest_uniform(self, name: str, low: float, high: float) -> float:
        # Deprecated alias that forwards to ``suggest_float``.
        return self.suggest_float(name, low, high)

    @deprecated_func("3.0.0", "6.0.0", text=_suggest_deprecated_msg.format(args="(..., log=True)"))
    def suggest_loguniform(self, name: str, low: float, high: float) -> float:
        # Deprecated alias that forwards to ``suggest_float`` with ``log=True``.
        return self.suggest_float(name, low, high, log=True)

    @deprecated_func("3.0.0", "6.0.0", text=_suggest_deprecated_msg.format(args="(..., step=...)"))
    def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float:
        # Deprecated alias that forwards to ``suggest_float`` with ``step=q``.
        return self.suggest_float(name, low, high, step=q)

    @convert_positional_args(
        previous_positional_arg_names=_SUGGEST_INT_POSITIONAL_ARGS,
        deprecated_version="3.5.0",
        removed_version="5.0.0",
    )
    def suggest_int(
        self, name: str, low: int, high: int, *, step: int = 1, log: bool = False
    ) -> int:
        # The stored value is coerced to ``int`` before being returned.
        int_distribution = IntDistribution(low, high, log=log, step=step)
        return int(self._suggest(name, int_distribution))

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[None]) -> None: ...

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[bool]) -> bool: ...

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[int]) -> int: ...

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[float]) -> float: ...

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[str]) -> str: ...

    @overload
    def suggest_categorical(
        self, name: str, choices: Sequence[CategoricalChoiceType]
    ) -> CategoricalChoiceType: ...

    def suggest_categorical(
        self, name: str, choices: Sequence[CategoricalChoiceType]
    ) -> CategoricalChoiceType:
        categorical_distribution = CategoricalDistribution(choices=choices)
        return self._suggest(name, categorical_distribution)

    def report(self, value: float, step: int) -> None:
        """Interface of report function.

        A :class:`~optuna.trial.FrozenTrial` is never pruned, so this method does nothing.

        .. seealso::
            Please refer to :func:`~optuna.trial.FrozenTrial.should_prune`.

        Args:
            value:
                A value returned from the objective function.
            step:
                Step of the trial (e.g., Epoch of neural network training). Note that pruners
                assume that ``step`` starts at zero. For example,
                :class:`~optuna.pruners.MedianPruner` simply checks if ``step`` is less than
                ``n_warmup_steps`` as the warmup mechanism.
        """
        return None

    def should_prune(self) -> bool:
        """Suggest whether the trial should be pruned or not.

        The suggestion is always :obj:`False` regardless of a pruning algorithm.

        .. note::
            :class:`~optuna.trial.FrozenTrial` only samples one combination of parameters.

        Returns:
            :obj:`False`.
        """
        return False

    def set_user_attr(self, key: str, value: Any) -> None:
        # Updates the in-memory dictionary held by this object only.
        self._user_attrs[key] = value

    @deprecated_func("3.1.0", "5.0.0")
    def set_system_attr(self, key: str, value: Any) -> None:
        self._system_attrs[key] = value

    def _validate(self) -> None:
        """Check that the fields of this trial are consistent with each other.

        Raises:
            ValueError:
                If the start/completion datetimes do not match the trial state, if the
                objective values do not match the trial state (set for a failed trial, missing
                or NaN for a complete trial), if ``params`` and ``distributions`` do not have
                the same keys, or if a parameter value is not contained in its distribution.
        """
        if self.state != TrialState.WAITING and self.datetime_start is None:
            raise ValueError(
                "`datetime_start` is supposed to be set when the trial state is not waiting."
            )

        completed_at = self.datetime_complete
        if self.state.is_finished():
            if completed_at is None:
                raise ValueError("`datetime_complete` is supposed to be set for a finished trial.")
        elif completed_at is not None:
            raise ValueError(
                "`datetime_complete` is supposed to be None for an unfinished trial."
            )

        if self.state == TrialState.FAIL and self._values is not None:
            raise ValueError(f"values should be None for a failed trial, but got {self._values}.")

        if self.state == TrialState.COMPLETE:
            if self._values is None:
                raise ValueError("values should be set for a complete trial.")
            if any(math.isnan(objective) for objective in self._values):
                raise ValueError("values should not contain NaN.")

        param_names = set(self.params.keys())
        distribution_names = set(self.distributions.keys())
        if param_names != distribution_names:
            raise ValueError(
                f"Inconsistent parameters {param_names} and "
                f"distributions {distribution_names}."
            )

        for param_name, param_value in self.params.items():
            distribution = self.distributions[param_name]
            internal_value = distribution.to_internal_repr(param_value)
            if distribution._contains(internal_value):
                continue
            raise ValueError(
                f"The value {param_value} of parameter '{param_name}' isn't contained in "
                f"the distribution {distribution}."
            )

    def _suggest(self, name: str, distribution: BaseDistribution) -> Any:
        """Return the stored value of ``name`` and record ``distribution`` for it.

        A value outside ``distribution`` only triggers a warning; the stored value is still
        returned. If a distribution is already recorded for ``name``, it is checked for
        compatibility with the new one before being replaced.

        Raises:
            ValueError:
                If ``name`` is not present in the stored parameters.
        """
        if name not in self._params:
            raise ValueError(
                f"The value of the parameter '{name}' is not found. "
                "Please set it at the construction of the FrozenTrial object."
            )

        value = self._params[name]
        internal_value = distribution.to_internal_repr(value)
        in_range = distribution._contains(internal_value)
        if not in_range:
            optuna_warn(
                f"The value {value} of the parameter '{name}' is out of "
                f"the range of the distribution {distribution}."
            )

        if name in self._distributions:
            recorded_distribution = self._distributions[name]
            distributions.check_distribution_compatibility(recorded_distribution, distribution)

        self._distributions[name] = distribution
        return value

    @property
    def number(self) -> int:
        return self._number

    @number.setter
    def number(self, value: int) -> None:
        self._number = value

    @property
    def value(self) -> float | None:
        if self._values is None:
            return None
        if len(self._values) > 1:
            raise RuntimeError(
                "This attribute is not available during multi-objective optimization."
            )
        return self._values[0]

    @value.setter
    def value(self, v: float | None) -> None:
        # Refuse to overwrite multi-objective values through the single-objective attribute.
        if self._values is not None and len(self._values) > 1:
            raise RuntimeError(
                "This attribute is not available during multi-objective optimization."
            )

        self._values = None if v is None else [v]

    # `_get_values`, `_set_values`, and `values = property(_get_values, _set_values)` are spelled
    # out this way (instead of using decorators) to pass mypy.
    # See https://github.com/python/mypy/issues/3004#issuecomment-726022329.
    def _get_values(self) -> list[float] | None:
        return self._values

    def _set_values(self, v: Sequence[float] | None) -> None:
        # A copy is stored so that the trial owns its own list.
        self._values = list(v) if v is not None else None

    values = property(_get_values, _set_values)

    @property
    def datetime_start(self) -> datetime.datetime | None:
        return self._datetime_start

    @datetime_start.setter
    def datetime_start(self, value: datetime.datetime | None) -> None:
        self._datetime_start = value

    @property
    def params(self) -> dict[str, Any]:
        return self._params

    @params.setter
    def params(self, params: dict[str, Any]) -> None:
        self._params = params

    @property
    def distributions(self) -> dict[str, BaseDistribution]:
        return self._distributions

    @distributions.setter
    def distributions(self, value: dict[str, BaseDistribution]) -> None:
        self._distributions = value

    @property
    def user_attrs(self) -> dict[str, Any]:
        return self._user_attrs

    @user_attrs.setter
    def user_attrs(self, value: dict[str, Any]) -> None:
        self._user_attrs = value

    @property
    def system_attrs(self) -> dict[str, Any]:
        return self._system_attrs

    @system_attrs.setter
    def system_attrs(self, value: Mapping[str, JSONSerializable]) -> None:
        # ``cast`` only informs the type checker; the object is stored as given.
        self._system_attrs = cast("dict[str, Any]", value)

    @property
    def last_step(self) -> int | None:
        """Return the maximum step of :attr:`intermediate_values` in the trial.

        Returns:
            The maximum step of intermediates, or :obj:`None` if no intermediate value is
            stored.
        """
        if len(self.intermediate_values) > 0:
            return max(self.intermediate_values.keys())
        return None

    @property
    def duration(self) -> datetime.timedelta | None:
        """Return the elapsed time taken to complete the trial.

        Returns:
            The duration, or :obj:`None` if either the start or the completion datetime is not
            set.
        """
        if not (self.datetime_start and self.datetime_complete):
            return None
        return self.datetime_complete - self.datetime_start
def create_trial(
    *,
    state: TrialState = TrialState.COMPLETE,
    value: float | None = None,
    values: Sequence[float] | None = None,
    params: dict[str, Any] | None = None,
    distributions: dict[str, BaseDistribution] | None = None,
    user_attrs: dict[str, Any] | None = None,
    system_attrs: dict[str, Any] | None = None,
    intermediate_values: dict[int, float] | None = None,
) -> FrozenTrial:
    """Create a new :class:`~optuna.trial.FrozenTrial`.

    The trial is created with ``number`` and ``trial_id`` set to ``-1`` and is validated before
    it is returned.

    Example:

        .. testcode::

            import optuna
            from optuna.distributions import CategoricalDistribution
            from optuna.distributions import FloatDistribution

            trial = optuna.trial.create_trial(
                params={"x": 1.0, "y": 0},
                distributions={
                    "x": FloatDistribution(0, 10),
                    "y": CategoricalDistribution([-1, 0, 1]),
                },
                value=5.0,
            )

            assert isinstance(trial, optuna.trial.FrozenTrial)
            assert trial.value == 5.0
            assert trial.params == {"x": 1.0, "y": 0}

    .. seealso::

        See :func:`~optuna.study.Study.add_trial` for how this function can be used to create a
        study from existing trials.

    .. note::

        Please note that this is a low-level API. In general, trials that are passed to objective
        functions are created inside :func:`~optuna.study.Study.optimize`.

    .. note::
        When ``state`` is :class:`TrialState.COMPLETE`, the following parameters are
        required:

        * ``params``
        * ``distributions``
        * ``value`` or ``values``

    Args:
        state:
            Trial state.
        value:
            Trial objective value. Must be specified if ``state`` is
            :class:`TrialState.COMPLETE`. ``value`` and ``values`` must not be specified at
            the same time.
        values:
            Sequence of the trial objective values. The length is greater than 1 if the problem
            is multi-objective optimization.
            Must be specified if ``state`` is :class:`TrialState.COMPLETE`.
            ``value`` and ``values`` must not be specified at the same time.
        params:
            Dictionary with suggested parameters of the trial.
        distributions:
            Dictionary with parameter distributions of the trial.
        user_attrs:
            Dictionary with user attributes.
        system_attrs:
            Dictionary with system attributes. Should not have to be used for most users.
        intermediate_values:
            Dictionary with intermediate objective values of the trial.

    Returns:
        Created trial.

    Raises:
        ValueError:
            If both ``value`` and ``values`` are given, or if the created trial does not pass
            validation.
    """
    # Bring every distribution to the current distribution classes before anything else.
    converted_distributions = {
        param_name: _convert_old_distribution_to_new_distribution(old_or_new)
        for param_name, old_or_new in (distributions or {}).items()
    }

    # A waiting trial has not started; a finished trial completes at its start time.
    started_at = None if state == TrialState.WAITING else datetime.datetime.now()
    completed_at: datetime.datetime | None = started_at if state.is_finished() else None

    trial = FrozenTrial(
        number=-1,
        trial_id=-1,
        state=state,
        value=value,
        values=values,
        datetime_start=started_at,
        datetime_complete=completed_at,
        params=params or {},
        distributions=converted_distributions,
        user_attrs=user_attrs or {},
        system_attrs=system_attrs or {},
        intermediate_values=intermediate_values or {},
    )

    trial._validate()

    return trial