Source code for optuna.trial._fixed

import datetime
from typing import Any
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Union

from optuna import distributions
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalChoiceType
from optuna.distributions import CategoricalDistribution
from optuna.distributions import DiscreteUniformDistribution
from optuna.distributions import IntLogUniformDistribution
from optuna.distributions import IntUniformDistribution
from optuna.distributions import LogUniformDistribution
from optuna.distributions import UniformDistribution
from optuna.trial._base import BaseTrial


class FixedTrial(BaseTrial):
    """A trial class which suggests a fixed value for each parameter.

    This object has the same methods as :class:`~optuna.trial.Trial`, and it suggests pre-defined
    parameter values. The parameter values can be determined at the construction of the
    :class:`~optuna.trial.FixedTrial` object. In contrast to :class:`~optuna.trial.Trial`,
    :class:`~optuna.trial.FixedTrial` does not depend on :class:`~optuna.study.Study`, and it is
    useful for deploying optimization results.

    Example:

        Evaluate an objective function with parameter values given by a user.

        .. testcode::

            import optuna


            def objective(trial):
                x = trial.suggest_uniform("x", -100, 100)
                y = trial.suggest_categorical("y", [-1, 0, 1])
                return x ** 2 + y


            assert objective(optuna.trial.FixedTrial({"x": 1, "y": 0})) == 1

    .. note::
        Please refer to :class:`~optuna.trial.Trial` for details of methods and properties.

    Args:
        params:
            A dictionary containing all parameters.
        number:
            A trial number. Defaults to ``0``.

    """

    def __init__(self, params: Dict[str, Any], number: int = 0) -> None:
        # The user-supplied values are kept by reference (not copied).
        self._params = params
        # Filled in lazily by ``_suggest`` as parameters are actually requested.
        self._suggested_params: Dict[str, Any] = {}
        self._distributions: Dict[str, BaseDistribution] = {}
        self._user_attrs: Dict[str, Any] = {}
        self._system_attrs: Dict[str, Any] = {}
        # Naive local timestamp taken at construction time.
        self._datetime_start = datetime.datetime.now()
        self._number = number

    def suggest_float(
        self,
        name: str,
        low: float,
        high: float,
        *,
        step: Optional[float] = None,
        log: bool = False,
    ) -> float:
        """Return the fixed value of ``name`` as a float parameter.

        The distribution used for validation depends on the arguments:
        a discrete uniform one when ``step`` is given, a log-uniform one when
        ``log`` is ``True``, and a plain uniform one otherwise.

        Raises:
            ValueError:
                If both ``step`` and ``log`` are specified, or if
                :meth:`_suggest` rejects the fixed value.
        """
        if step is not None:
            if log:
                raise ValueError("The parameter `step` is not supported when `log` is True.")
            return self._suggest(name, DiscreteUniformDistribution(low=low, high=high, q=step))

        if log:
            return self._suggest(name, LogUniformDistribution(low=low, high=high))
        return self._suggest(name, UniformDistribution(low=low, high=high))

    def suggest_uniform(self, name: str, low: float, high: float) -> float:
        """Return the fixed value of ``name``, validated against a uniform distribution."""
        return self._suggest(name, UniformDistribution(low=low, high=high))

    def suggest_loguniform(self, name: str, low: float, high: float) -> float:
        """Return the fixed value of ``name``, validated against a log-uniform distribution."""
        return self._suggest(name, LogUniformDistribution(low=low, high=high))

    def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float:
        """Return the fixed value of ``name``, validated against a discrete uniform distribution."""
        discrete = DiscreteUniformDistribution(low=low, high=high, q=q)
        return self._suggest(name, discrete)

    def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
        """Return the fixed value of ``name`` as an integer parameter.

        An integer log-uniform distribution is used for validation when ``log``
        is ``True``; otherwise an integer uniform distribution with the given
        ``step`` is used. The stored value is converted with :func:`int`
        before being returned.

        Raises:
            ValueError:
                If ``log`` is ``True`` while ``step`` is not ``1``, or if
                :meth:`_suggest` rejects the fixed value.
        """
        if log and step != 1:
            # The two literals are concatenated, so the first one must end with
            # a space to keep the sentences separated in the final message.
            raise ValueError(
                "The parameter `step != 1` is not supported when `log` is True. "
                "The specified `step` is {}.".format(step)
            )

        distribution: Union[IntUniformDistribution, IntLogUniformDistribution]
        if log:
            distribution = IntLogUniformDistribution(low=low, high=high)
        else:
            distribution = IntUniformDistribution(low=low, high=high, step=step)
        return int(self._suggest(name, distribution))

    def suggest_categorical(
        self, name: str, choices: Sequence[CategoricalChoiceType]
    ) -> CategoricalChoiceType:
        """Return the fixed value of ``name``, validated against the given ``choices``."""
        return self._suggest(name, CategoricalDistribution(choices=choices))

    def report(self, value: float, step: int) -> None:
        """Accept an intermediate value and discard it (no-op)."""
        pass

    def should_prune(self) -> bool:
        """Always return ``False``."""
        return False

    def set_user_attr(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the user attributes of this trial."""
        self._user_attrs[key] = value

    def set_system_attr(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the system attributes of this trial."""
        self._system_attrs[key] = value

    def _suggest(self, name: str, distribution: BaseDistribution) -> Any:
        """Look up the fixed value of ``name`` and validate it against ``distribution``.

        On success the value and the distribution are recorded so that they
        show up in :attr:`params` and :attr:`distributions`.

        Args:
            name:
                Name of the parameter to look up.
            distribution:
                Distribution the fixed value must belong to.

        Returns:
            The value given for ``name`` at construction, unmodified.

        Raises:
            ValueError:
                If ``name`` was not given at construction, or if the value is
                not contained in ``distribution``.
        """
        if name not in self._params:
            raise ValueError(
                "The value of the parameter '{}' is not found. Please set it at "
                "the construction of the FixedTrial object.".format(name)
            )

        value = self._params[name]
        param_value_in_internal_repr = distribution.to_internal_repr(value)
        if not distribution._contains(param_value_in_internal_repr):
            raise ValueError(
                "The value {} of the parameter '{}' is out of "
                "the range of the distribution {}.".format(value, name, distribution)
            )

        # A parameter requested more than once is checked against the
        # distribution recorded for it earlier.
        # NOTE(review): presumably this raises on a mismatch -- confirm against
        # ``optuna.distributions.check_distribution_compatibility``.
        if name in self._distributions:
            distributions.check_distribution_compatibility(self._distributions[name], distribution)

        self._suggested_params[name] = value
        self._distributions[name] = distribution

        return value

    @property
    def params(self) -> Dict[str, Any]:
        """Parameters that have been suggested so far, keyed by name."""
        return self._suggested_params

    @property
    def distributions(self) -> Dict[str, BaseDistribution]:
        """Distributions of the parameters that have been suggested so far."""
        return self._distributions

    @property
    def user_attrs(self) -> Dict[str, Any]:
        """Attributes stored through :meth:`set_user_attr`."""
        return self._user_attrs

    @property
    def system_attrs(self) -> Dict[str, Any]:
        """Attributes stored through :meth:`set_system_attr`."""
        return self._system_attrs

    @property
    def datetime_start(self) -> Optional[datetime.datetime]:
        """Timestamp taken when this object was constructed."""
        return self._datetime_start

    @property
    def number(self) -> int:
        """The trial number given at construction."""
        return self._number