# Source code for optuna.visualization.matplotlib._parallel_coordinate

from __future__ import annotations

from typing import Callable

import numpy as np

from optuna._experimental import experimental_func
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.visualization._parallel_coordinate import _get_parallel_coordinate_info
from optuna.visualization._parallel_coordinate import _ParallelCoordinateInfo
from optuna.visualization.matplotlib._matplotlib_imports import _imports

if _imports.is_successful():
    from optuna.visualization.matplotlib._matplotlib_imports import Axes
    from optuna.visualization.matplotlib._matplotlib_imports import LineCollection
    from optuna.visualization.matplotlib._matplotlib_imports import plt

@experimental_func("2.2.0")
def plot_parallel_coordinate(
    study: Study,
    params: list[str] | None = None,
    *,
    target: Callable[[FrozenTrial], float] | None = None,
    target_name: str = "Objective Value",
) -> "Axes":
    """Plot the high-dimensional parameter relationships in a study with Matplotlib.

    Note that, if a parameter contains missing values, a trial with missing values is not plotted.

    .. seealso::
        Please refer to :func:`optuna.visualization.plot_parallel_coordinate` for an example.

    Example:

        The following code snippet shows how to plot the high-dimensional parameter relationships.

        .. plot::

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", -100, 100)
                y = trial.suggest_categorical("y", [-1, 0, 1])
                return x ** 2 + y


            sampler = optuna.samplers.TPESampler(seed=10)
            study = optuna.create_study(sampler=sampler)
            study.optimize(objective, n_trials=10)

            optuna.visualization.matplotlib.plot_parallel_coordinate(study, params=["x", "y"])

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their target values.
        params:
            Parameter list to visualize. The default is all parameters.
        target:
            A function to specify the value to display. If it is :obj:`None` and ``study`` is being
            used for single-objective optimization, the objective values are plotted.

            .. note::
                Specify this argument if ``study`` is being used for multi-objective optimization.
        target_name:
            Target's name to display on the axis label and the legend.

    Returns:
        A :class:`matplotlib.axes.Axes` object.

    .. note::
        The colormap is reversed when the ``target`` argument isn't :obj:`None` or ``direction``
        of :class:`~optuna.study.Study` is ``minimize``.
    """
    # Fail early if the guarded Matplotlib imports at module load did not succeed
    # (the exact exception type is defined by ``_imports``; presumably ImportError — confirm).
    _imports.check()
    # Backend-agnostic data extraction is shared with the Plotly implementation's module;
    # this function only renders the result with Matplotlib.
    info = _get_parallel_coordinate_info(study, params, target, target_name)
    return _get_parallel_coordinate_plot(info)
def _get_parallel_coordinate_plot(info: _ParallelCoordinateInfo) -> "Axes":
    """Render a precomputed parallel-coordinate description onto a new Matplotlib figure.

    The objective occupies the leftmost vertical axis (x == 0), and each parameter gets its own
    twinned y-axis whose spine is placed at x == 1, ..., n_params. Parameter values are linearly
    rescaled into the objective's value range, so every trial can be drawn as a single polyline
    in the data coordinates of the main axes. The twinned axes only supply per-parameter ticks.

    Args:
        info:
            Plot description, as produced by ``_get_parallel_coordinate_info``.

    Returns:
        The main :class:`matplotlib.axes.Axes`. If there are no parameters or no objective
        values, the titled but otherwise empty axes is returned.
    """
    reversescale = info.reverse_scale
    target_name = info.target_name

    # Set up the graph style.
    fig, ax = plt.subplots()
    cmap = plt.get_cmap("Blues_r" if reversescale else "Blues")
    ax.set_title("Parallel Coordinate Plot")
    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)

    # Nothing to draw: return the styled but empty axes.
    if len(info.dims_params) == 0 or len(info.dim_objective.values) == 0:
        return ax

    # Build one row per trial: its objective value, followed by each parameter value mapped
    # linearly onto [obj_min, obj_max] so all vertices share the main axes' y coordinates.
    obj_min = info.dim_objective.range[0]
    obj_max = info.dim_objective.range[1]
    obj_w = obj_max - obj_min
    center = obj_w / 2 + obj_min
    dims_obj_base = [[o] for o in info.dim_objective.values]
    for dim in info.dims_params:
        p_min = dim.range[0]
        p_max = dim.range[1]
        p_w = p_max - p_min
        for i, v in enumerate(dim.values):
            if p_w == 0.0:
                # A constant parameter has no spread to rescale; pin it to the vertical centre.
                dims_obj_base[i].append(center)
            else:
                dims_obj_base[i].append((v - p_min) / p_w * obj_w + obj_min)

    # Draw every trial as one polyline in a single LineCollection on the main axes; the
    # per-parameter y-axes are added afterwards as twins whose right spines are shifted
    # to each parameter's x position.
    n_params = len(info.dims_params)
    ax.set_xlim(0, n_params)
    ax.set_ylim(obj_min, obj_max)
    xs = range(n_params + 1)  # Same x positions for every polyline.
    segments = [np.column_stack([xs, ys]) for ys in dims_obj_base]
    lc = LineCollection(segments, cmap=cmap)
    # Colour each polyline by its objective value.
    lc.set_array(np.asarray(info.dim_objective.values))
    axcb = fig.colorbar(lc, pad=0.1, ax=ax)
    axcb.set_label(target_name)

    # Label the x positions on ``ax`` directly. ``plt.xticks`` would target pyplot's implicit
    # "current" axes rather than the axes this function builds and returns.
    var_names = [info.dim_objective.label] + [dim.label for dim in info.dims_params]
    ax.set_xticks(range(n_params + 1))
    ax.set_xticklabels(var_names, rotation=330)

    for i, dim in enumerate(info.dims_params):
        ax2 = ax.twinx()
        if dim.is_log:
            # For log dimensions, ``dim.range`` holds base-10 exponents; convert back to raw
            # values. A float base avoids NumPy's ValueError for integer ** negative integer.
            ax2.set_ylim(np.power(10.0, dim.range[0]), np.power(10.0, dim.range[1]))
            ax2.set_yscale("log")
        else:
            ax2.set_ylim(dim.range[0], dim.range[1])
        ax2.spines["top"].set_visible(False)
        ax2.spines["bottom"].set_visible(False)
        ax2.xaxis.set_visible(False)
        # Axes fraction (i + 1) / n_params corresponds to x == i + 1 given xlim (0, n_params).
        ax2.spines["right"].set_position(("axes", (i + 1) / n_params))
        if dim.is_cat:
            ax2.set_yticks(dim.tickvals)
            ax2.set_yticklabels(dim.ticktext)

    ax.add_collection(lc)

    return ax