from __future__ import annotations
import abc
import copy
import decimal
import json
import math
from numbers import Real
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union
import warnings
from optuna._deprecated import deprecated_class
if TYPE_CHECKING:
from collections.abc import Sequence
# Value types a categorical choice may take while staying compatible with persistent
# storage: None, bool, int, float or str.
CategoricalChoiceType = Union[None, bool, int, float, str]
# Message fragments that point users of deprecated distributions to the replacement class.
# NOTE(review): presumably consumed by ``deprecated_class`` decorators on classes that are
# not visible in this part of the file — confirm.
_float_distribution_deprecated_msg = (
"Use :class:`~optuna.distributions.FloatDistribution` instead."
)
_int_distribution_deprecated_msg = "Use :class:`~optuna.distributions.IntDistribution` instead."
class BaseDistribution(abc.ABC):
    """Abstract base of every parameter distribution.

    Distribution objects are not meant to be created by library users. They are used
    internally by :class:`~optuna.trial.Trial` and :class:`~optuna.samplers`.
    """

    def to_external_repr(self, param_value_in_internal_repr: float) -> Any:
        """Translate a parameter value from the internal to the external representation.

        Args:
            param_value_in_internal_repr:
                Optuna's internal representation of a parameter value.

        Returns:
            Optuna's external representation of a parameter value. The base
            implementation hands the value back unchanged.
        """

        return param_value_in_internal_repr

    @abc.abstractmethod
    def to_internal_repr(self, param_value_in_external_repr: Any) -> float:
        """Translate a parameter value from the external to the internal representation.

        Args:
            param_value_in_external_repr:
                Optuna's external representation of a parameter value.

        Returns:
            Optuna's internal representation of a parameter value.
        """

        raise NotImplementedError

    @abc.abstractmethod
    def single(self) -> bool:
        """Tell whether the range of this distribution holds exactly one value.

        Returns:
            :obj:`True` if the range of this distribution contains just a single value,
            otherwise :obj:`False`.
        """

        raise NotImplementedError

    @abc.abstractmethod
    def _contains(self, param_value_in_internal_repr: float) -> bool:
        """Tell whether a parameter value lies inside the range of this distribution.

        Args:
            param_value_in_internal_repr:
                Optuna's internal representation of a parameter value.

        Returns:
            :obj:`True` if the parameter value is contained in the range of this
            distribution, otherwise :obj:`False`.
        """

        raise NotImplementedError

    def _asdict(self) -> dict:
        # The live instance dictionary is returned, not a copy.
        return self.__dict__

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, BaseDistribution):
            # Distributions are equal only when both the exact type and all attributes match.
            return type(self) is type(other) and self.__dict__ == other.__dict__
        return NotImplemented

    def __hash__(self) -> int:
        attribute_items = sorted(self.__dict__.items())
        return hash((self.__class__, *attribute_items))

    def __repr__(self) -> str:
        rendered = [f"{key}={value}" for key, value in sorted(self._asdict().items())]
        return f"{self.__class__.__name__}({', '.join(rendered)})"
class FloatDistribution(BaseDistribution):
    """A distribution over floating-point values.

    Instances are created by :func:`~optuna.trial.Trial.suggest_float` and are generally
    handed over to :mod:`~optuna.samplers`.

    .. note::
        When ``step`` is not :obj:`None` and the range :math:`[\\mathsf{low}, \\mathsf{high}]`
        is not divisible by :math:`\\mathsf{step}`, :math:`\\mathsf{high}` is replaced with
        the maximum of :math:`k \\times \\mathsf{step} + \\mathsf{low} < \\mathsf{high}`,
        where :math:`k` is an integer.

    Attributes:
        low:
            Lower endpoint of the range (inclusive). ``low`` must be less than or equal to
            ``high``. If ``log`` is :obj:`True`, ``low`` must be larger than 0.
        high:
            Upper endpoint of the range (inclusive). ``high`` must be greater than or equal
            to ``low``.
        log:
            If :obj:`True`, the distribution is in the log-scaled domain. In this case,
            all parameters enqueued to the distribution must be positive values.
            Must be :obj:`False` when ``step`` is not :obj:`None`.
        step:
            A discretization step. ``step`` must be larger than 0.
            Must be :obj:`None` when ``log`` is :obj:`True`.
    """

    def __init__(
        self, low: float, high: float, log: bool = False, step: None | float = None
    ) -> None:
        if log and step is not None:
            raise ValueError("The parameter `step` is not supported when `log` is true.")

        if low > high:
            raise ValueError(
                "The `low` value must be smaller than or equal to the `high` value "
                f"(low={low}, high={high})."
            )

        if log and low <= 0.0:
            raise ValueError(
                "The `low` value must be larger than 0 for a log distribution "
                f"(low={low}, high={high})."
            )

        if step is not None and step <= 0:
            raise ValueError(f"The `step` value must be non-zero positive value, but step={step}.")

        # ``step`` is assigned first so the attribute order (step, low, high, log) is stable.
        if step is None:
            self.step = None
        else:
            # Align ``high`` with the step grid anchored at ``low`` before storing it.
            high = _adjust_discrete_uniform_high(low, high, step)
            self.step = float(step)

        self.low = float(low)
        self.high = float(high)
        self.log = log

    def single(self) -> bool:
        """Tell whether the range holds exactly one value.

        Returns:
            :obj:`True` if ``low`` equals ``high``, or if a ``step`` is set and the width of
            the range is smaller than that step; otherwise :obj:`False`.
        """

        if self.low == self.high:
            return True
        if self.step is None:
            return False

        # Decimal arithmetic on the string forms avoids binary floating-point noise.
        width = decimal.Decimal(str(self.high)) - decimal.Decimal(str(self.low))
        return width < decimal.Decimal(str(self.step))

    def _contains(self, param_value_in_internal_repr: float) -> bool:
        value = param_value_in_internal_repr
        in_range = self.low <= value <= self.high
        if self.step is None:
            return in_range

        # With a step, the value must additionally sit (within 1e-8) on the grid
        # ``low + k * step``.
        steps_from_low = (value - self.low) / self.step
        return in_range and abs(steps_from_low - round(steps_from_low)) < 1.0e-8

    def to_internal_repr(self, param_value_in_external_repr: float) -> float:
        """Convert an external parameter value to its internal float representation.

        Args:
            param_value_in_external_repr:
                Optuna's external representation of a parameter value.

        Returns:
            The value cast to :class:`float`.

        Raises:
            ValueError:
                If the value cannot be cast to float, is NaN, or is not positive while
                ``log`` is :obj:`True`.
        """

        external = param_value_in_external_repr
        try:
            internal_repr = float(external)
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"'{external}' is not a valid type. float-castable value is expected."
            ) from e

        if math.isnan(internal_repr):
            raise ValueError(f"`{external}` is invalid value.")
        if self.log and internal_repr <= 0.0:
            raise ValueError(f"`{external}` is invalid value for the case log=True.")
        return internal_repr
class IntDistribution(BaseDistribution):
    """A distribution over integers.

    Instances are created by :func:`~optuna.trial.Trial.suggest_int` and are generally
    handed over to :mod:`~optuna.samplers`.

    .. note::
        When ``step`` is not :obj:`None` and the range :math:`[\\mathsf{low}, \\mathsf{high}]`
        is not divisible by :math:`\\mathsf{step}`, :math:`\\mathsf{high}` is replaced with
        the maximum of :math:`k \\times \\mathsf{step} + \\mathsf{low} < \\mathsf{high}`,
        where :math:`k` is an integer.

    Attributes:
        low:
            Lower endpoint of the range (inclusive). ``low`` must be less than or equal to
            ``high``. If ``log`` is :obj:`True`, ``low`` must be larger than or equal to 1.
        high:
            Upper endpoint of the range (inclusive). ``high`` must be greater than or equal
            to ``low``.
        log:
            If :obj:`True`, the distribution is in the log-scaled domain. In this case,
            all parameters enqueued to the distribution must be positive values.
            Must be :obj:`False` when ``step`` is not 1.
        step:
            A discretization step. ``step`` must be a positive integer and must be 1 when
            ``log`` is :obj:`True`.
    """

    def __init__(self, low: int, high: int, log: bool = False, step: int = 1) -> None:
        if log and step != 1:
            raise ValueError(
                "Samplers and other components in Optuna only accept step is 1 "
                "when `log` argument is True."
            )

        if low > high:
            raise ValueError(
                "The `low` value must be smaller than or equal to the `high` value "
                f"(low={low}, high={high})."
            )

        if log and low < 1:
            raise ValueError(
                "The `low` value must be equal to or greater than 1 for a log distribution "
                f"(low={low}, high={high})."
            )

        if step <= 0:
            raise ValueError(f"The `step` value must be non-zero positive value, but step={step}.")

        # Attribute order (log, step, low, high) is kept stable on purpose.
        self.log = log
        self.step = int(step)
        self.low = int(low)
        # Align ``high`` with the step grid anchored at ``low``.
        self.high = _adjust_int_uniform_high(self.low, int(high), self.step)

    def to_external_repr(self, param_value_in_internal_repr: float) -> int:
        """Convert an internal (float) parameter value to the external integer value."""

        return int(param_value_in_internal_repr)

    def to_internal_repr(self, param_value_in_external_repr: int) -> float:
        """Convert an external parameter value to its internal float representation.

        Args:
            param_value_in_external_repr:
                Optuna's external representation of a parameter value.

        Returns:
            The value cast to :class:`float`.

        Raises:
            ValueError:
                If the value cannot be cast to float, is NaN, or is not positive while
                ``log`` is :obj:`True`.
        """

        external = param_value_in_external_repr
        try:
            internal_repr = float(external)
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"'{external}' is not a valid type. float-castable value is expected."
            ) from e

        if math.isnan(internal_repr):
            raise ValueError(f"`{external}` is invalid value.")
        if self.log and internal_repr <= 0.0:
            raise ValueError(f"`{external}` is invalid value for the case log=True.")
        return internal_repr

    def single(self) -> bool:
        """Tell whether the range holds exactly one value.

        Returns:
            :obj:`True` if ``low`` equals ``high``, or if the distribution is not
            log-scaled and the width of the range is smaller than ``step``; otherwise
            :obj:`False`.
        """

        if self.low == self.high:
            return True
        if self.log:
            return False
        return (self.high - self.low) < self.step

    def _contains(self, param_value_in_internal_repr: float) -> bool:
        value = param_value_in_internal_repr
        if not (self.low <= value <= self.high):
            return False
        # Inside the range, the value must also sit on the grid ``low + k * step``.
        return (value - self.low) % self.step == 0
def _categorical_choice_equal(
value1: CategoricalChoiceType, value2: CategoricalChoiceType
) -> bool:
"""A function to check two choices equal considering NaN.
This function can handle NaNs like np.float32("nan") other than float.
"""
value1_is_nan = isinstance(value1, Real) and math.isnan(float(value1))
value2_is_nan = isinstance(value2, Real) and math.isnan(float(value2))
return (value1 == value2) or (value1_is_nan and value2_is_nan)
class CategoricalDistribution(BaseDistribution):
    """A categorical distribution.

    This object is instantiated by :func:`~optuna.trial.Trial.suggest_categorical`, and
    passed to :mod:`~optuna.samplers` in general.

    Args:
        choices:
            Parameter value candidates. ``choices`` must have one element at least.

    .. note::

        Not all types are guaranteed to be compatible with all storages. It is recommended to
        restrict the types of the choices to :obj:`None`, :class:`bool`, :class:`int`,
        :class:`float` and :class:`str`.

    Attributes:
        choices:
            Parameter value candidates.
    """

    def __init__(self, choices: Sequence[CategoricalChoiceType]) -> None:
        if len(choices) == 0:
            raise ValueError("The `choices` must contain one or more elements.")
        for choice in choices:
            # Unsupported types are tolerated, but the user is warned about them.
            if choice is not None and not isinstance(choice, (bool, int, float, str)):
                message = (
                    "Choices for a categorical distribution should be a tuple of None, bool, "
                    "int, float and str for persistent storage but contains {} which is of type "
                    "{}.".format(choice, type(choice).__name__)
                )
                warnings.warn(message)

        self.choices = tuple(choices)

    def to_external_repr(self, param_value_in_internal_repr: float) -> CategoricalChoiceType:
        """Return the choice stored at the index given by the internal representation."""

        return self.choices[int(param_value_in_internal_repr)]

    def to_internal_repr(self, param_value_in_external_repr: CategoricalChoiceType) -> float:
        """Return the index of a choice, which is its internal representation.

        Args:
            param_value_in_external_repr:
                A value expected to be one of ``choices``.

        Returns:
            The index of the first matching choice.

        Raises:
            ValueError:
                If the value matches none of the choices.
        """

        try:
            # NOTE(nabenabe): With this implementation, we cannot distinguish some values
            # such as True and 1, or 1.0 and 1. For example, if choices=[True, 1] and external_repr
            # is 1, this method wrongly returns 0 instead of 1. However, we decided to accept this
            # bug for such exceptional choices for less complexity and faster processing.
            return self.choices.index(param_value_in_external_repr)
        except ValueError:  # ValueError: param_value_in_external_repr is not in choices.
            # ValueError also happens if external_repr is nan or includes precision error in float.
            for index, choice in enumerate(self.choices):
                if _categorical_choice_equal(param_value_in_external_repr, choice):
                    return index

        raise ValueError(f"'{param_value_in_external_repr}' not in {self.choices}.")

    def single(self) -> bool:
        """Return :obj:`True` if there is exactly one choice, otherwise :obj:`False`."""

        return len(self.choices) == 1

    def _contains(self, param_value_in_internal_repr: float) -> bool:
        index = int(param_value_in_internal_repr)
        return 0 <= index < len(self.choices)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, BaseDistribution):
            return NotImplemented
        if not isinstance(other, self.__class__):
            return False
        if self.__dict__.keys() != other.__dict__.keys():
            return False
        for key, value in self.__dict__.items():
            if key == "choices":
                # Choices are compared element-wise so that NaN choices count as equal.
                if len(value) != len(getattr(other, key)):
                    return False
                for choice, other_choice in zip(value, getattr(other, key)):
                    if not _categorical_choice_equal(choice, other_choice):
                        return False
            else:
                if value != getattr(other, key):
                    return False
        return True

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__``, including for NaN choices.

        ``__eq__`` treats every NaN choice as equal, but the hash of a NaN depends on
        object identity since Python 3.10. Hashing the raw choices could therefore give
        two equal distributions different hashes. Each NaN choice is replaced by the
        single ``math.nan`` object before hashing. For choices without NaN the result is
        ``hash((cls, ("choices", choices)))``.
        """

        canonical_choices = tuple(
            # ``x != x`` holds only for NaN among real numbers and needs no float cast.
            math.nan if isinstance(choice, Real) and choice != choice else choice
            for choice in self.choices
        )
        items = tuple(
            (key, canonical_choices if key == "choices" else value)
            for key, value in sorted(self.__dict__.items())
        )
        return hash((self.__class__,) + items)
# Distribution classes that ``json_to_distribution`` can rebuild by class name from the
# ``{"name": ..., "attributes": ...}`` JSON format.
# NOTE(review): only IntDistribution, FloatDistribution and CategoricalDistribution are
# defined in the visible part of this file; the remaining names are presumably the
# deprecated distributions defined elsewhere — confirm.
DISTRIBUTION_CLASSES = (
IntDistribution,
IntLogUniformDistribution,
IntUniformDistribution,
FloatDistribution,
UniformDistribution,
LogUniformDistribution,
DiscreteUniformDistribution,
CategoricalDistribution,
)
def json_to_distribution(json_str: str) -> BaseDistribution:
    """Rebuild a distribution from its JSON form.

    Two layouts are understood: the full one carrying ``"name"`` and ``"attributes"``,
    and an abbreviated one carrying ``"type"`` (``"categorical"``, ``"float"`` or
    ``"int"``) together with the distribution's parameters.

    Args:
        json_str: A JSON-serialized distribution.

    Returns:
        A deserialized distribution.

    Raises:
        ValueError: If the class name or the abbreviated type is not recognised.
    """

    payload = json.loads(json_str)

    if "name" not in payload:
        # Abbreviated format.
        kind = payload["type"]
        if kind == "categorical":
            return CategoricalDistribution(payload["choices"])
        if kind not in ("float", "int"):
            raise ValueError("Unknown distribution type: {}".format(kind))

        low = payload["low"]
        high = payload["high"]
        step = payload.get("step")
        log = payload.get("log", False)
        if kind == "float":
            return FloatDistribution(low, high, log=log, step=step)
        # Integer distributions fall back to a step of 1 when none is given.
        return IntDistribution(low=low, high=high, log=log, step=1 if step is None else step)

    # Full format: look the class up by name.
    class_name = payload["name"]
    if class_name == CategoricalDistribution.__name__:
        # JSON has no tuples, so the decoded list of choices is converted back to one.
        payload["attributes"]["choices"] = tuple(payload["attributes"]["choices"])

    for cls in DISTRIBUTION_CLASSES:
        if class_name == cls.__name__:
            return cls(**payload["attributes"])

    raise ValueError("Unknown distribution class: {}".format(class_name))
def distribution_to_json(dist: BaseDistribution) -> str:
    """Dump a distribution as a JSON string.

    The output is an object with the class name under ``"name"`` and the result of
    ``dist._asdict()`` under ``"attributes"``.

    Args:
        dist: A distribution to be serialized.

    Returns:
        A JSON string of a given distribution.
    """

    serializable = {
        "name": dist.__class__.__name__,
        "attributes": dist._asdict(),
    }
    return json.dumps(serializable)
def check_distribution_compatibility(
    dist_old: BaseDistribution, dist_new: BaseDistribution
) -> None:
    """Verify that two distributions are compatible with each other.

    Both distributions must be of the same class. Float and integer distributions must
    also agree on ``log``, and categorical distributions must be equal.

    Note that this method is not supposed to be called by library users.

    Args:
        dist_old:
            A distribution previously recorded in storage.
        dist_new:
            A distribution newly added to storage.

    Raises:
        ValueError: If the two distributions are not compatible.
    """

    if dist_old.__class__ != dist_new.__class__:
        raise ValueError("Cannot set different distribution kind to the same parameter name.")

    numeric_kinds = (FloatDistribution, IntDistribution)
    if isinstance(dist_old, numeric_kinds):
        # Holds because the classes were just checked to match.
        assert isinstance(dist_new, numeric_kinds)
        if dist_old.log != dist_new.log:
            raise ValueError("Cannot set different log configuration to the same parameter name.")

    both_categorical = isinstance(dist_old, CategoricalDistribution) and isinstance(
        dist_new, CategoricalDistribution
    )
    if both_categorical and dist_old != dist_new:
        raise ValueError(
            CategoricalDistribution.__name__ + " does not support dynamic value space."
        )
def _adjust_discrete_uniform_high(low: float, high: float, step: float) -> float:
d_high = decimal.Decimal(str(high))
d_low = decimal.Decimal(str(low))
d_step = decimal.Decimal(str(step))
d_r = d_high - d_low
if d_r % d_step != decimal.Decimal("0"):
old_high = high
high = float((d_r // d_step) * d_step + d_low)
warnings.warn(
"The distribution is specified by [{low}, {old_high}] and step={step}, but the range "
"is not divisible by `step`. It will be replaced by [{low}, {high}].".format(
low=low, old_high=old_high, high=high, step=step
)
)
return high
def _adjust_int_uniform_high(low: int, high: int, step: int) -> int:
r = high - low
if r % step != 0:
old_high = high
high = r // step * step + low
warnings.warn(
"The distribution is specified by [{low}, {old_high}] and step={step}, but the range "
"is not divisible by `step`. It will be replaced by [{low}, {high}].".format(
low=low, old_high=old_high, high=high, step=step
)
)
return high
def _get_single_value(distribution: BaseDistribution) -> int | float | CategoricalChoiceType:
assert distribution.single()
if isinstance(
distribution,
(
FloatDistribution,
IntDistribution,
),
):
return distribution.low
elif isinstance(distribution, CategoricalDistribution):
return distribution.choices[0]
assert False
# TODO(himkt): Remove this method with the deletion of deprecated distributions.
# https://github.com/optuna/optuna/issues/2941
def _convert_old_distribution_to_new_distribution(
    distribution: BaseDistribution,
    suppress_warning: bool = False,
) -> BaseDistribution:
    """Map a legacy distribution onto ``FloatDistribution`` / ``IntDistribution``.

    Any distribution that is not one of the legacy kinds handled below is returned
    as is.

    Args:
        distribution: The distribution to convert.
        suppress_warning: If :obj:`True`, no ``FutureWarning`` is emitted when a
            conversion changed the distribution.

    Returns:
        The converted distribution, or ``distribution`` itself when no conversion applies.
    """
    converted: BaseDistribution

    # Float distributions.
    if isinstance(distribution, UniformDistribution):
        converted = FloatDistribution(distribution.low, distribution.high, log=False, step=None)
    elif isinstance(distribution, LogUniformDistribution):
        converted = FloatDistribution(distribution.low, distribution.high, log=True, step=None)
    elif isinstance(distribution, DiscreteUniformDistribution):
        # The legacy discrete distribution stores its step under ``q``.
        converted = FloatDistribution(
            distribution.low, distribution.high, log=False, step=distribution.q
        )

    # Integer distributions.
    elif isinstance(distribution, IntUniformDistribution):
        converted = IntDistribution(
            distribution.low, distribution.high, log=False, step=distribution.step
        )
    elif isinstance(distribution, IntLogUniformDistribution):
        converted = IntDistribution(
            distribution.low, distribution.high, log=True, step=distribution.step
        )

    # Categorical distribution.
    else:
        converted = distribution

    # Warn only when the conversion actually produced a different distribution.
    if converted != distribution and not suppress_warning:
        warnings.warn(
            f"{distribution} is deprecated and internally converted to"
            f" {converted}. See https://github.com/optuna/optuna/issues/2941.",
            FutureWarning,
        )

    return converted
def _is_distribution_log(distribution: BaseDistribution) -> bool:
    """Return the ``log`` flag of a float or integer distribution, else :obj:`False`."""
    if not isinstance(distribution, (FloatDistribution, IntDistribution)):
        return False
    return distribution.log