Source code for optuna.visualization.matplotlib._pareto_front

from typing import List
from typing import Optional

import optuna
from optuna._experimental import experimental
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization.matplotlib._matplotlib_imports import _imports


if _imports.is_successful():
    from optuna.visualization.matplotlib._matplotlib_imports import Axes
    from optuna.visualization.matplotlib._matplotlib_imports import plt

_logger = optuna.logging.get_logger(__name__)


[docs]@experimental("2.8.0") def plot_pareto_front( study: Study, *, target_names: Optional[List[str]] = None, include_dominated_trials: bool = True, axis_order: Optional[List[int]] = None, ) -> "Axes": """Plot the Pareto front of a study. .. seealso:: Please refer to :func:`optuna.visualization.plot_pareto_front` for an example. Example: The following code snippet shows how to plot the Pareto front of a study. .. plot:: import optuna def objective(trial): x = trial.suggest_float("x", 0, 5) y = trial.suggest_float("y", 0, 3) v0 = 4 * x ** 2 + 4 * y ** 2 v1 = (x - 5) ** 2 + (y - 5) ** 2 return v0, v1 study = optuna.create_study(directions=["minimize", "minimize"]) study.optimize(objective, n_trials=50) optuna.visualization.matplotlib.plot_pareto_front(study) Args: study: A :class:`~optuna.study.Study` object whose trials are plotted for their objective values. target_names: Objective name list used as the axis titles. If :obj:`None` is specified, "Objective {objective_index}" is used instead. include_dominated_trials: A flag to include all dominated trial's objective values. axis_order: A list of indices indicating the axis order. If :obj:`None` is specified, default order is used. Returns: A :class:`matplotlib.axes.Axes` object. Raises: :exc:`ValueError`: If the number of objectives of ``study`` isn't 2 or 3. """ _imports.check() if len(study.directions) == 2: return _get_pareto_front_2d(study, target_names, include_dominated_trials, axis_order) elif len(study.directions) == 3: return _get_pareto_front_3d(study, target_names, include_dominated_trials, axis_order) else: raise ValueError("`plot_pareto_front` function only supports 2 or 3 objective studies.")
def _get_non_pareto_front_trials( study: Study, pareto_trials: List[FrozenTrial] ) -> List[FrozenTrial]: non_pareto_trials = [] for trial in study.get_trials(): if trial.state == TrialState.COMPLETE and trial not in pareto_trials: non_pareto_trials.append(trial) return non_pareto_trials def _get_pareto_front_2d( study: Study, target_names: Optional[List[str]], include_dominated_trials: bool = False, axis_order: Optional[List[int]] = None, ) -> "Axes": # Set up the graph style. plt.style.use("ggplot") # Use ggplot style sheet for similar outputs to plotly. _, ax = plt.subplots() ax.set_title("Pareto-front Plot") cmap = plt.get_cmap("tab10") # Use tab10 colormap for similar outputs to plotly. if target_names is None: target_names = ["Objective 0", "Objective 1"] elif len(target_names) != 2: raise ValueError("The length of `target_names` is supposed to be 2.") # Prepare data for plotting. trials = study.best_trials if len(trials) == 0: _logger.warning("Your study does not have any completed trials.") if include_dominated_trials: non_pareto_trials = _get_non_pareto_front_trials(study, trials) trials += non_pareto_trials if axis_order is None: axis_order = list(range(2)) else: if len(axis_order) != 2: raise ValueError( f"Size of `axis_order` {axis_order}. Expect: 2, Actual: {len(axis_order)}." ) if len(set(axis_order)) != 2: raise ValueError(f"Elements of given `axis_order` {axis_order} are not unique!") if max(axis_order) > 1: raise ValueError( f"Given `axis_order` {axis_order} contains invalid index {max(axis_order)} " "higher than 1." ) if min(axis_order) < 0: raise ValueError( f"Given `axis_order` {axis_order} contains invalid index {min(axis_order)} " "lower than 0." ) ax.set_xlabel(target_names[axis_order[0]]) ax.set_ylabel(target_names[axis_order[1]]) if len(trials) - len(study.best_trials) != 0: ax.scatter( x=[t.values[axis_order[0]] for t in trials[len(study.best_trials) :]], y=[t.values[axis_order[1]] for t in trials[len(study.best_trials) :]], color=cmap(0), label="Trial", ) if len(study.best_trials): ax.scatter( x=[t.values[axis_order[0]] for t in trials[: len(study.best_trials)]], y=[t.values[axis_order[1]] for t in trials[: len(study.best_trials)]], color=cmap(3), label="Best Trial", ) if include_dominated_trials and ax.has_data(): ax.legend() return ax def _get_pareto_front_3d( study: Study, target_names: Optional[List[str]], include_dominated_trials: bool = False, axis_order: Optional[List[int]] = None, ) -> "Axes": # Set up the graph style. plt.style.use("ggplot") # Use ggplot style sheet for similar outputs to plotly. fig = plt.figure() ax = fig.add_subplot(projection="3d") ax.set_title("Pareto-front Plot") cmap = plt.get_cmap("tab10") # Use tab10 colormap for similar outputs to plotly. if target_names is None: target_names = ["Objective 0", "Objective 1", "Objective 2"] elif len(target_names) != 3: raise ValueError("The length of `target_names` is supposed to be 3.") trials = study.best_trials if len(trials) == 0: _logger.warning("Your study does not have any completed trials.") if include_dominated_trials: non_pareto_trials = _get_non_pareto_front_trials(study, trials) trials += non_pareto_trials if axis_order is None: axis_order = list(range(3)) else: if len(axis_order) != 3: raise ValueError( f"Size of `axis_order` {axis_order}. Expect: 3, Actual: {len(axis_order)}." ) if len(set(axis_order)) != 3: raise ValueError(f"Elements of given `axis_order` {axis_order} are not unique!.") if max(axis_order) > 2: raise ValueError( f"Given `axis_order` {axis_order} contains invalid index {max(axis_order)} " "higher than 2." ) if min(axis_order) < 0: raise ValueError( f"Given `axis_order` {axis_order} contains invalid index {min(axis_order)} " "lower than 0." ) ax.set_xlabel(target_names[axis_order[0]]) ax.set_ylabel(target_names[axis_order[1]]) ax.set_zlabel(target_names[axis_order[2]]) if len(trials) - len(study.best_trials) != 0: ax.scatter( xs=[t.values[axis_order[0]] for t in trials[len(study.best_trials) :]], ys=[t.values[axis_order[1]] for t in trials[len(study.best_trials) :]], zs=[t.values[axis_order[2]] for t in trials[len(study.best_trials) :]], color=cmap(0), label="Trial", ) if len(study.best_trials): ax.scatter( xs=[t.values[axis_order[0]] for t in trials[: len(study.best_trials)]], ys=[t.values[axis_order[1]] for t in trials[: len(study.best_trials)]], zs=[t.values[axis_order[2]] for t in trials[: len(study.best_trials)]], color=cmap(3), label="Best Trial", ) if include_dominated_trials and ax.has_data(): ax.legend() return ax