# Source code for optuna.visualization._slice

from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import NamedTuple
from typing import TYPE_CHECKING

from optuna.distributions import CategoricalChoiceType
from optuna.distributions import CategoricalDistribution
from optuna.logging import get_logger
from optuna.samplers._base import _CONSTRAINTS_KEY
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports


if TYPE_CHECKING:
    from collections.abc import Callable

    from optuna.study import Study
    from optuna.trial import FrozenTrial
from optuna.visualization._utils import _check_plot_args
from optuna.visualization._utils import _filter_nonfinite
from optuna.visualization._utils import _is_log_scale


if _imports.is_successful():
    from optuna.visualization._plotly_imports import go
    from optuna.visualization._plotly_imports import make_subplots
    from optuna.visualization._plotly_imports import Scatter
    from optuna.visualization._utils import COLOR_SCALE

# Module logger; `_get_slice_plot_info` uses it to warn when no completed trials remain to plot.
_logger = get_logger(__name__)


class _SliceSubplotInfo(NamedTuple):
    """Data needed to draw the slice subplot of a single parameter.

    The lists ``x``, ``y``, ``trial_numbers`` and ``constraints`` are parallel: entry ``i``
    of each list describes the same trial.
    """

    # Name of the parameter shown on this subplot's x-axis.
    param_name: str
    # Parameter values, one per trial that sampled ``param_name``.
    x: list[Any]
    # Target values of the same trials.
    y: list[float]
    # Numbers of the same trials; the feasible-trial markers are colored by these.
    trial_numbers: list[int]
    # Whether the x-axis should use a log scale.
    is_log: bool
    # False when the parameter's distribution is categorical.
    is_numerical: bool
    # Per-trial feasibility: True when no constraints were stored or all of them are <= 0.
    constraints: list[bool]
    # Ordered categorical choices used as axis labels; None for numerical parameters.
    x_labels: tuple[CategoricalChoiceType, ...] | None


class _SlicePlotInfo(NamedTuple):
    """Data for a whole slice plot: the target's axis label plus one entry per subplot."""

    # Label used for the target (y) axis.
    target_name: str
    # One entry per plotted parameter; empty when there is nothing to plot.
    subplots: list[_SliceSubplotInfo]


class _PlotValues(NamedTuple):
    """Parallel lists of x values, y values and trial numbers for one group of markers."""

    # x-axis values (parameter values).
    x: list[Any]
    # y-axis values (target values).
    y: list[float]
    # Trial number of each point.
    trial_numbers: list[int]


def _get_slice_subplot_info(
    trials: list[FrozenTrial],
    param: str,
    target: Callable[[FrozenTrial], float] | None,
    log_scale: bool,
    numerical: bool,
    x_labels: tuple[CategoricalChoiceType, ...] | None,
) -> _SliceSubplotInfo:
    """Collect the points of the slice subplot for a single parameter.

    Args:
        trials: Trials to plot. Trials that did not sample ``param`` are skipped.
        param: Name of the parameter shown on the x-axis.
        target: Function mapping a trial to its y value. If :obj:`None`, ``trial.value``
            is used.
        log_scale: Whether the x-axis should use a log scale.
        numerical: Whether the parameter is numerical (as opposed to categorical).
        x_labels: Ordered categorical choices, or :obj:`None` for numerical parameters.

    Returns:
        A :class:`_SliceSubplotInfo` whose ``x``, ``y``, ``trial_numbers`` and
        ``constraints`` lists are aligned trial by trial.
    """
    if target is None:

        def _target(t: FrozenTrial) -> float:
            return cast("float", t.value)

        target = _target

    plot_info = _SliceSubplotInfo(
        param_name=param,
        x=[],
        y=[],
        trial_numbers=[],
        is_log=log_scale,
        is_numerical=numerical,
        x_labels=x_labels,
        constraints=[],
    )

    for t in trials:
        if param not in t.params:
            continue
        plot_info.x.append(t.params[param])
        plot_info.y.append(target(t))
        plot_info.trial_numbers.append(t.number)
        # A trial is feasible when it stored no constraint values, or when every
        # constraint value is <= 0. A generator lets all() stop at the first violation.
        constraints = t.system_attrs.get(_CONSTRAINTS_KEY)
        plot_info.constraints.append(constraints is None or all(c <= 0.0 for c in constraints))

    return plot_info


def _get_slice_plot_info(
    study: Study,
    params: list[str] | None,
    target: Callable[[FrozenTrial], float] | None,
    target_name: str,
) -> _SlicePlotInfo:
    """Build the data for a slice plot from the completed trials of ``study``.

    Only ``COMPLETE`` trials that survive ``_filter_nonfinite`` are used.

    Args:
        study: Study whose trials are plotted.
        params: Parameters to plot. If :obj:`None`, every parameter found in the
            remaining trials is plotted. Duplicates are ignored and the result is sorted
            by parameter name.
        target: Function giving the value to plot, or :obj:`None` for ``trial.value``.
        target_name: Label for the target axis.

    Returns:
        A :class:`_SlicePlotInfo` with one subplot per selected parameter, or with no
        subplots (after logging a warning) when no trials remain.

    Raises:
        ValueError: If ``params`` names a parameter that none of the remaining trials has.
    """
    _check_plot_args(study, target, target_name)

    completed = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
    trials = _filter_nonfinite(completed, target=target)

    if not trials:
        _logger.warning("Your study does not have any completed trials.")
        return _SlicePlotInfo(target_name, [])

    known_params = {name for t in trials for name in t.params}

    # The first distribution encountered for a parameter name wins.
    distributions = {}
    for t in trials:
        for name, dist in t.distributions.items():
            distributions.setdefault(name, dist)

    categorical_choices = {
        name: dist.choices
        for name, dist in distributions.items()
        if isinstance(dist, CategoricalDistribution)
    }

    if params is not None:
        for name in params:
            if name not in known_params:
                raise ValueError(f"Parameter {name} does not exist in your study.")
        selected = sorted(set(params))
    else:
        selected = sorted(known_params)

    subplots = [
        _get_slice_subplot_info(
            trials=trials,
            param=name,
            target=target,
            log_scale=_is_log_scale(trials, name),
            numerical=not isinstance(distributions[name], CategoricalDistribution),
            x_labels=categorical_choices.get(name),
        )
        for name in selected
    ]
    return _SlicePlotInfo(target_name=target_name, subplots=subplots)


def plot_slice(
    study: Study,
    params: list[str] | None = None,
    *,
    target: Callable[[FrozenTrial], float] | None = None,
    target_name: str = "Objective Value",
) -> "go.Figure":
    """Plot the parameter relationship as slice plot in a study.

    Only completed trials are plotted. A trial that has no value for a parameter is not
    drawn on that parameter's subplot.

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their target
            values.
        params:
            Parameter list to visualize. The default is all parameters.
        target:
            A function to specify the value to display. If it is :obj:`None` and ``study``
            is being used for single-objective optimization, the objective values are
            plotted.

            .. note::
                Specify this argument if ``study`` is being used for multi-objective
                optimization.
        target_name:
            Target's name to display on the axis label.

    Returns:
        A :class:`plotly.graph_objects.Figure` object.

    Raises:
        ValueError: If ``params`` contains a parameter that does not exist in the study.
    """
    _imports.check()
    info = _get_slice_plot_info(study, params, target, target_name)
    return _get_slice_plot(info)
def _get_slice_plot(info: _SlicePlotInfo) -> "go.Figure":
    """Render a :class:`_SlicePlotInfo` as a plotly figure.

    With no subplots an empty figure is returned. A single subplot is drawn on a plain
    figure. Several subplots are laid out in one row sharing the y-axis, and the
    trial-number colorbar is shown only on the first one.
    """
    layout = go.Layout(title="Slice Plot")

    if len(info.subplots) == 0:
        return go.Figure(data=[], layout=layout)
    elif len(info.subplots) == 1:
        figure = go.Figure(data=_generate_slice_subplot(info.subplots[0]), layout=layout)
        figure.update_xaxes(title_text=info.subplots[0].param_name)
        figure.update_yaxes(title_text=info.target_name)
        if not info.subplots[0].is_numerical:
            figure.update_xaxes(
                type="category",
                categoryorder="array",
                categoryarray=_get_categorical_labels(info.subplots[0].x_labels),
            )
        elif info.subplots[0].is_log:
            figure.update_xaxes(type="log")
    else:
        figure = make_subplots(rows=1, cols=len(info.subplots), shared_yaxes=True)
        figure.update_layout(layout)
        showscale = True  # showscale option only needs to be specified once.
        for column_index, subplot_info in enumerate(info.subplots, start=1):
            trace = _generate_slice_subplot(subplot_info)
            trace[0].update(marker={"showscale": showscale})  # showscale's default is True.
            showscale = False
            for t in trace:
                figure.add_trace(t, row=1, col=column_index)
            figure.update_xaxes(title_text=subplot_info.param_name, row=1, col=column_index)
            if column_index == 1:
                figure.update_yaxes(title_text=info.target_name, row=1, col=column_index)
            if not subplot_info.is_numerical:
                figure.update_xaxes(
                    type="category",
                    categoryorder="array",
                    categoryarray=_get_categorical_labels(subplot_info.x_labels),
                    row=1,
                    col=column_index,
                )
            elif subplot_info.is_log:
                figure.update_xaxes(type="log", row=1, col=column_index)
        if len(info.subplots) > 3:
            # Ensure that each subplot has a minimum width without relying on autosizing.
            figure.update_layout(width=300 * len(info.subplots))

    return figure


def _generate_slice_subplot(subplot_info: _SliceSubplotInfo) -> list[Scatter]:
    """Build the scatter traces of one slice subplot.

    The first returned trace holds the feasible trials, colored by trial number. A grey
    trace of infeasible trials follows only when there is at least one. For numerical
    parameters, trials whose value is ``None`` are skipped.
    """
    trace = []

    feasible = _PlotValues([], [], [])
    infeasible = _PlotValues([], [], [])

    for x, y, num, c in zip(
        subplot_info.x, subplot_info.y, subplot_info.trial_numbers, subplot_info.constraints
    ):
        if subplot_info.is_numerical and x is None:
            continue
        # Keep all three lists of a group aligned. _get_categorical_plot_values zips them
        # together, so a missing trial number would silently drop the point.
        group = feasible if c else infeasible
        group.x.append(x)
        group.y.append(y)
        group.trial_numbers.append(num)

    if subplot_info.is_numerical:
        feasible_x = feasible.x
        feasible_y = feasible.y
        feasible_c = feasible.trial_numbers
        infeasible_x = infeasible.x
        infeasible_y = infeasible.y
    else:
        feasible_x, feasible_y, feasible_c = _get_categorical_plot_values(subplot_info, feasible)
        infeasible_x, infeasible_y, _ = _get_categorical_plot_values(subplot_info, infeasible)

    trace.append(
        go.Scatter(
            x=feasible_x,
            y=feasible_y,
            mode="markers",
            name="Feasible Trial",
            marker={
                "line": {"width": 0.5, "color": "Grey"},
                "color": feasible_c,
                "colorscale": COLOR_SCALE,
                "colorbar": {
                    "title": "Trial",
                    "x": 1.0,  # Offset the colorbar position with a fixed width `xpad`.
                    "xpad": 40,
                },
            },
            showlegend=False,
        )
    )
    if len(infeasible_x) > 0:
        trace.append(
            go.Scatter(
                x=infeasible_x,
                y=infeasible_y,
                mode="markers",
                name="Infeasible Trial",
                marker={
                    "color": "#cccccc",
                },
                showlegend=False,
            )
        )

    return trace


def _get_categorical_plot_values(
    subplot_info: _SliceSubplotInfo, values: _PlotValues
) -> tuple[list[Any], list[float], list[int]]:
    """Reorder the points of a categorical subplot to follow the order of its choices.

    Points are grouped by parameter value and emitted in ``subplot_info.x_labels`` order.
    Within a group, the input order is kept. Values not listed in ``x_labels`` are dropped.
    x values are converted with :func:`repr` so they match the axis labels produced by
    :func:`_get_categorical_labels`.

    Returns:
        A tuple ``(x, y, trial_numbers)`` of parallel lists.
    """
    assert subplot_info.x_labels is not None

    value_x = []
    value_y = []
    value_c = []
    points_dict = defaultdict(list)
    for x, y, number in zip(values.x, values.y, values.trial_numbers):
        points_dict[x].append((y, number))
    for x_label in subplot_info.x_labels:
        for y, number in points_dict[x_label]:
            value_x.append(repr(x_label))
            value_y.append(y)
            value_c.append(number)

    return value_x, value_y, value_c


def _get_categorical_labels(
    x_labels: tuple[CategoricalChoiceType, ...] | None,
) -> list[str] | None:
    """Return the ``repr`` of each categorical choice, or :obj:`None` if there are none."""
    if x_labels is None:
        return None
    return [repr(x_label) for x_label in x_labels]