# Source code for optuna.visualization._parallel_coordinate

from __future__ import annotations

from collections import defaultdict
import math
from typing import cast
from typing import NamedTuple
from typing import TYPE_CHECKING

import numpy as np

from optuna.distributions import CategoricalDistribution
from optuna.logging import get_logger
from optuna.trial import TrialState


if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    from optuna.study import Study
    from optuna.trial import FrozenTrial

from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _check_plot_args
from optuna.visualization._utils import _filter_nonfinite
from optuna.visualization._utils import _get_skipped_trial_numbers
from optuna.visualization._utils import _is_log_scale
from optuna.visualization._utils import _is_numerical
from optuna.visualization._utils import _is_reverse_scale


if _imports.is_successful():
    from optuna.visualization._plotly_imports import go
    from optuna.visualization._utils import COLOR_SCALE

_logger = get_logger(__name__)


class _DimensionInfo(NamedTuple):
    """Description of one axis (dimension) of the parallel coordinate plot.

    One instance is built for the objective and one for each plotted parameter.
    """

    # Axis label: the target name for the objective, or the (possibly
    # truncated) parameter name for a parameter axis.
    label: str
    # One value per plotted trial. Log-scale parameters hold log10 values and
    # categorical parameters hold integer category codes.
    values: tuple[float, ...]
    # (min, max) of ``values``; the objective uses a dummy (0, 0) when there
    # are no values to plot.
    range: tuple[float, float]
    # True when ``values`` were log10-transformed.
    is_log: bool
    # True when ``values`` are integer codes of categorical choices.
    is_cat: bool
    # Tick positions and their display strings. Only populated for log-scale
    # or categorical axes; empty lists otherwise.
    tickvals: list[int | float]
    ticktext: list[str]


class _ParallelCoordinateInfo(NamedTuple):
    """Everything needed to render the parallel coordinate figure."""

    # Axis for the objective (target) values; it is emitted as the first
    # dimension and its values also drive the line colors.
    dim_objective: _DimensionInfo
    # Axes for the plotted parameters, ordered by sorted parameter name.
    # Empty when there is nothing to plot.
    dims_params: list[_DimensionInfo]
    # Forwarded to the ``reversescale`` option of the plotted lines.
    reverse_scale: bool
    # Objective axis label and color bar title.
    target_name: str


def plot_parallel_coordinate(
    study: Study,
    params: list[str] | None = None,
    *,
    target: Callable[[FrozenTrial], float] | None = None,
    target_name: str = "Objective Value",
) -> "go.Figure":
    """Plot the high-dimensional parameter relationships in a study.

    Note that, if a parameter contains missing values, a trial with missing values is not
    plotted.

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their target
            values.
        params:
            Parameter list to visualize. The default is all parameters.
        target:
            A function to specify the value to display. If it is :obj:`None` and ``study`` is
            being used for single-objective optimization, the objective values are plotted.

            .. note::
                Specify this argument if ``study`` is being used for multi-objective
                optimization.
        target_name:
            Target's name to display on the axis label and the legend.

    Returns:
        A :class:`plotly.graph_objects.Figure` object.

    Raises:
        ValueError:
            If ``params`` contains a name that is not a parameter of any plotted trial.

    .. note::
        The colormap is reversed when the ``target`` argument isn't :obj:`None` or
        ``direction`` of :class:`~optuna.study.Study` is ``minimize``.
    """

    # Fail early with a helpful message when plotly is not available.
    _imports.check()
    info = _get_parallel_coordinate_info(study, params, target, target_name)
    return _get_parallel_coordinate_plot(info)
def _get_parallel_coordinate_plot(info: _ParallelCoordinateInfo) -> "go.Figure": layout = go.Layout(title="Parallel Coordinate Plot") if len(info.dims_params) == 0 or len(info.dim_objective.values) == 0: return go.Figure(data=[], layout=layout) dims = _get_dims_from_info(info) reverse_scale = info.reverse_scale target_name = info.target_name traces = [ go.Parcoords( dimensions=dims, labelangle=30, labelside="bottom", line={ "color": dims[0]["values"], "colorscale": COLOR_SCALE, "colorbar": {"title": target_name}, "showscale": True, "reversescale": reverse_scale, }, ) ] figure = go.Figure(data=traces, layout=layout) return figure def _get_parallel_coordinate_info( study: Study, params: list[str] | None = None, target: Callable[[FrozenTrial], float] | None = None, target_name: str = "Objective Value", ) -> _ParallelCoordinateInfo: _check_plot_args(study, target, target_name) reverse_scale = _is_reverse_scale(study, target) trials = _filter_nonfinite( study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,)), target=target ) all_params = {p_name for t in trials for p_name in t.params.keys()} if params is not None: for input_p_name in params: if input_p_name not in all_params: raise ValueError(f"Parameter {input_p_name} does not exist in your study.") all_params = set(params) sorted_params = sorted(all_params) if target is None: def _target(t: FrozenTrial) -> float: return cast(float, t.value) target = _target skipped_trial_numbers = _get_skipped_trial_numbers(trials, sorted_params) objectives = tuple([target(t) for t in trials if t.number not in skipped_trial_numbers]) # The value of (0, 0) is a dummy range. It is ignored when we plot. 
objective_range = (min(objectives), max(objectives)) if len(objectives) > 0 else (0, 0) dim_objective = _DimensionInfo( label=target_name, values=objectives, range=objective_range, is_log=False, is_cat=False, tickvals=[], ticktext=[], ) if len(trials) == 0: _logger.warning("Your study does not have any completed trials.") return _ParallelCoordinateInfo( dim_objective=dim_objective, dims_params=[], reverse_scale=reverse_scale, target_name=target_name, ) if len(objectives) == 0: _logger.warning("Your study has only completed trials with missing parameters.") return _ParallelCoordinateInfo( dim_objective=dim_objective, dims_params=[], reverse_scale=reverse_scale, target_name=target_name, ) numeric_cat_params_indices: list[int] = [] dims = [] for dim_index, p_name in enumerate(sorted_params, start=1): values = [] is_categorical = False for t in trials: if t.number in skipped_trial_numbers: continue if p_name in t.params: values.append(t.params[p_name]) is_categorical |= isinstance(t.distributions[p_name], CategoricalDistribution) if _is_log_scale(trials, p_name): values = [math.log10(v) for v in values] min_value = min(values) max_value = max(values) tickvals: list[int | float] = list( range(math.ceil(min_value), math.floor(max_value) + 1) ) if min_value not in tickvals: tickvals = [min_value] + tickvals if max_value not in tickvals: tickvals = tickvals + [max_value] dim = _DimensionInfo( label=_truncate_label(p_name), values=tuple(values), range=(min_value, max_value), is_log=True, is_cat=False, tickvals=tickvals, ticktext=[f"{math.pow(10, x):.3g}" for x in tickvals], ) elif is_categorical: vocab: defaultdict[int | str, int] = defaultdict(lambda: len(vocab)) ticktext: list[str] if _is_numerical(trials, p_name): _ = [vocab[v] for v in sorted(values)] values = [vocab[v] for v in values] ticktext = [str(v) for v in list(sorted(vocab.keys()))] numeric_cat_params_indices.append(dim_index) else: values = [vocab[v] for v in values] ticktext = [str(v) for v in 
list(sorted(vocab.keys(), key=lambda x: vocab[x]))] dim = _DimensionInfo( label=_truncate_label(p_name), values=tuple(values), range=(min(values), max(values)), is_log=False, is_cat=True, tickvals=list(range(len(vocab))), ticktext=ticktext, ) else: dim = _DimensionInfo( label=_truncate_label(p_name), values=tuple(values), range=(min(values), max(values)), is_log=False, is_cat=False, tickvals=[], ticktext=[], ) dims.append(dim) if numeric_cat_params_indices: dims.insert(0, dim_objective) # np.lexsort consumes the sort keys the order from back to front. # So the values of parameters have to be reversed the order. idx = np.lexsort([dims[index].values for index in numeric_cat_params_indices][::-1]) updated_dims = [] for dim in dims: # Since the values are mapped to other categories by the index, # the index will be swapped according to the sorted index of numeric params. updated_dims.append( _DimensionInfo( label=dim.label, values=tuple(np.array(dim.values)[idx]), range=dim.range, is_log=dim.is_log, is_cat=dim.is_cat, tickvals=dim.tickvals, ticktext=dim.ticktext, ) ) dim_objective = updated_dims[0] dims = updated_dims[1:] return _ParallelCoordinateInfo( dim_objective=dim_objective, dims_params=dims, reverse_scale=reverse_scale, target_name=target_name, ) def _get_dims_from_info(info: _ParallelCoordinateInfo) -> list[dict[str, Any]]: dims = [ { "label": info.dim_objective.label, "values": info.dim_objective.values, "range": info.dim_objective.range, } ] for dim in info.dims_params: if dim.is_log or dim.is_cat: dims.append( { "label": dim.label, "values": dim.values, "range": dim.range, "tickvals": dim.tickvals, "ticktext": dim.ticktext, } ) else: dims.append({"label": dim.label, "values": dim.values, "range": dim.range}) return dims def _truncate_label(label: str) -> str: return label if len(label) < 20 else f"{label[:17]}..."