from collections import defaultdict
from typing import Callable
from typing import cast
from typing import DefaultDict
from typing import List
from typing import Optional
import numpy as np
from optuna._experimental import experimental
from optuna.logging import get_logger
from optuna.study import Study
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization._utils import _check_plot_args
from optuna.visualization._utils import _filter_nonfinite
from optuna.visualization._utils import _get_skipped_trial_numbers
from optuna.visualization.matplotlib._matplotlib_imports import _imports
from optuna.visualization.matplotlib._utils import _is_categorical
from optuna.visualization.matplotlib._utils import _is_log_scale
from optuna.visualization.matplotlib._utils import _is_numerical
if _imports.is_successful():
from optuna.visualization.matplotlib._matplotlib_imports import Axes
from optuna.visualization.matplotlib._matplotlib_imports import LineCollection
from optuna.visualization.matplotlib._matplotlib_imports import plt
_logger = get_logger(__name__)
[docs]@experimental("2.2.0")
def plot_parallel_coordinate(
study: Study,
params: Optional[List[str]] = None,
*,
target: Optional[Callable[[FrozenTrial], float]] = None,
target_name: str = "Objective Value",
) -> "Axes":
"""Plot the high-dimensional parameter relationships in a study with Matplotlib.
Note that, if a parameter contains missing values, a trial with missing values is not plotted.
.. seealso::
Please refer to :func:`optuna.visualization.plot_parallel_coordinate` for an example.
Example:
The following code snippet shows how to plot the high-dimensional parameter relationships.
.. plot::
import optuna
def objective(trial):
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
return x ** 2 + y
sampler = optuna.samplers.TPESampler(seed=10)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=10)
optuna.visualization.matplotlib.plot_parallel_coordinate(study, params=["x", "y"])
Args:
study:
A :class:`~optuna.study.Study` object whose trials are plotted for their target values.
params:
Parameter list to visualize. The default is all parameters.
target:
A function to specify the value to display. If it is :obj:`None` and ``study`` is being
used for single-objective optimization, the objective values are plotted.
.. note::
Specify this argument if ``study`` is being used for multi-objective optimization.
target_name:
Target's name to display on the axis label and the legend.
Returns:
A :class:`matplotlib.axes.Axes` object.
"""
_imports.check()
_check_plot_args(study, target, target_name)
return _get_parallel_coordinate_plot(study, params, target, target_name)
def _get_parallel_coordinate_plot(
study: Study,
params: Optional[List[str]] = None,
target: Optional[Callable[[FrozenTrial], float]] = None,
target_name: str = "Objective Value",
) -> "Axes":
if target is None:
def _target(t: FrozenTrial) -> float:
return cast(float, t.value)
target = _target
reversescale = study.direction == StudyDirection.MINIMIZE
else:
reversescale = True
# Set up the graph style.
fig, ax = plt.subplots()
cmap = plt.get_cmap("Blues_r" if reversescale else "Blues")
ax.set_title("Parallel Coordinate Plot")
ax.spines["top"].set_visible(False)
ax.spines["bottom"].set_visible(False)
# Prepare data for plotting.
trials = _filter_nonfinite(
study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,)), target=target
)
if len(trials) == 0:
_logger.warning("Your study does not have any completed trials.")
return ax
all_params = {p_name for t in trials for p_name in t.params.keys()}
if params is not None:
for input_p_name in params:
if input_p_name not in all_params:
raise ValueError("Parameter {} does not exist in your study.".format(input_p_name))
all_params = set(params)
sorted_params = sorted(all_params)
skipped_trial_numbers = _get_skipped_trial_numbers(trials, sorted_params)
obj_org = [target(t) for t in trials if t.number not in skipped_trial_numbers]
if len(obj_org) == 0:
_logger.warning("Your study has only completed trials with missing parameters.")
return ax
obj_min = min(obj_org)
obj_max = max(obj_org)
obj_w = obj_max - obj_min
dims_obj_base = [[o] for o in obj_org]
cat_param_names = []
cat_param_values = []
cat_param_ticks = []
param_values = []
var_names = [target_name]
numeric_cat_params_indices: List[int] = []
for param_index, p_name in enumerate(sorted_params):
values = [t.params[p_name] for t in trials if t.number not in skipped_trial_numbers]
if _is_categorical(trials, p_name):
vocab = defaultdict(lambda: len(vocab)) # type: DefaultDict[str, int]
if _is_numerical(trials, p_name):
_ = [vocab[v] for v in sorted(values)]
numeric_cat_params_indices.append(param_index)
values = [vocab[v] for v in values]
cat_param_names.append(p_name)
vocab_item_sorted = sorted(vocab.items(), key=lambda x: x[1])
cat_param_values.append([v[0] for v in vocab_item_sorted])
cat_param_ticks.append([v[1] for v in vocab_item_sorted])
if _is_log_scale(trials, p_name):
values_for_lc = [np.log10(v) for v in values]
else:
values_for_lc = values
p_min = min(values_for_lc)
p_max = max(values_for_lc)
p_w = p_max - p_min
if p_w == 0.0:
center = obj_w / 2 + obj_min
for i in range(len(values)):
dims_obj_base[i].append(center)
else:
for i, v in enumerate(values_for_lc):
dims_obj_base[i].append((v - p_min) / p_w * obj_w + obj_min)
var_names.append(p_name if len(p_name) < 20 else "{}...".format(p_name[:17]))
param_values.append(values)
if numeric_cat_params_indices:
# np.lexsort consumes the sort keys the order from back to front.
# So the values of parameters have to be reversed the order.
sorted_idx = np.lexsort(
[param_values[index] for index in numeric_cat_params_indices][::-1]
)
# Since the values are mapped to other categories by the index,
# the index will be swapped according to the sorted index of numeric params.
param_values = [list(np.array(v)[sorted_idx]) for v in param_values]
# Draw multiple line plots and axes.
# Ref: https://stackoverflow.com/a/50029441
ax.set_xlim(0, len(sorted_params))
ax.set_ylim(obj_min, obj_max)
xs = [range(len(sorted_params) + 1) for _ in range(len(dims_obj_base))]
segments = [np.column_stack([x, y]) for x, y in zip(xs, dims_obj_base)]
lc = LineCollection(segments, cmap=cmap)
lc.set_array(np.asarray(obj_org))
axcb = fig.colorbar(lc, pad=0.1)
axcb.set_label(target_name)
plt.xticks(range(len(sorted_params) + 1), var_names, rotation=330)
for i, p_name in enumerate(sorted_params):
ax2 = ax.twinx()
ax2.set_ylim(min(param_values[i]), max(param_values[i]))
if _is_log_scale(trials, p_name):
ax2.set_yscale("log")
ax2.spines["top"].set_visible(False)
ax2.spines["bottom"].set_visible(False)
ax2.xaxis.set_visible(False)
ax2.spines["right"].set_position(("axes", (i + 1) / len(sorted_params)))
if p_name in cat_param_names:
idx = cat_param_names.index(p_name)
tick_pos = cat_param_ticks[idx]
tick_labels = cat_param_values[idx]
ax2.set_yticks(tick_pos)
ax2.set_yticklabels(tick_labels)
ax.add_collection(lc)
return ax