# Source code for optuna.visualization.matplotlib._optimization_history

from __future__ import annotations

from typing import Callable
from typing import Sequence

import numpy as np

from optuna._experimental import experimental_func
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.visualization._optimization_history import _get_optimization_history_info_list
from optuna.visualization._optimization_history import _OptimizationHistoryInfo
from optuna.visualization.matplotlib._matplotlib_imports import _imports

if _imports.is_successful():
    from optuna.visualization.matplotlib._matplotlib_imports import Axes
    from optuna.visualization.matplotlib._matplotlib_imports import plt

@experimental_func("2.2.0")
def plot_optimization_history(
    study: Study | Sequence[Study],
    *,
    target: Callable[[FrozenTrial], float] | None = None,
    target_name: str = "Objective Value",
    error_bar: bool = False,
) -> "Axes":
    """Plot optimization history of all trials in a study with Matplotlib.

    .. seealso::
        Please refer to :func:`optuna.visualization.plot_optimization_history` for an example.

    Example:

        The following code snippet shows how to plot optimization history.

        .. plot::

            import optuna
            import matplotlib.pyplot as plt


            def objective(trial):
                x = trial.suggest_float("x", -100, 100)
                y = trial.suggest_categorical("y", [-1, 0, 1])
                return x ** 2 + y


            sampler = optuna.samplers.TPESampler(seed=10)
            study = optuna.create_study(sampler=sampler)
            study.optimize(objective, n_trials=10)

            optuna.visualization.matplotlib.plot_optimization_history(study)
            plt.tight_layout()

        .. note::
            You need to adjust the size of the plot by yourself using ``plt.tight_layout()`` or
            ``plt.savefig(IMAGE_NAME, bbox_inches='tight')``.

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their target
            values. You can pass multiple studies if you want to compare those optimization
            histories.
        target:
            A function to specify the value to display. If it is :obj:`None` and ``study`` is
            being used for single-objective optimization, the objective values are plotted.

            .. note::
                Specify this argument if ``study`` is being used for multi-objective
                optimization.
        target_name:
            Target's name to display on the axis label and the legend.
        error_bar:
            A flag to show the error bar.

    Returns:
        A :class:`matplotlib.axes.Axes` object.
    """
    # The Matplotlib names used for plotting are only imported when
    # ``_imports.is_successful()``; presumably ``check()`` raises otherwise — confirm.
    _imports.check()
    # Collect the per-study (or aggregated, when ``error_bar``) history data, then draw it.
    info_list = _get_optimization_history_info_list(study, target, target_name, error_bar)
    return _get_optimization_history_plot(info_list, target_name)
def _get_optimization_history_plot(
    info_list: list[_OptimizationHistoryInfo],
    target_name: str,
) -> "Axes":
    """Draw the optimization histories in ``info_list`` on a newly created Matplotlib axes.

    Args:
        info_list:
            Entries that each unpack into ``(trial_numbers, values_info, best_values_info)``.
            ``values_info`` and ``best_values_info`` expose ``values``, ``stds`` and
            ``label_name``. When ``stds`` is not :obj:`None`, error bars (for values) or a
            shaded band (for best values) are drawn. ``best_values_info`` may be :obj:`None`,
            in which case no best-value line is drawn for that entry.
        target_name:
            Label for the y axis.

    Returns:
        The :class:`matplotlib.axes.Axes` the histories were drawn on.
    """
    # Set up the graph style. Note that ``plt.style.use`` updates Matplotlib's global
    # rcParams, so the style persists after this call.
    plt.style.use("ggplot")  # Use ggplot style sheet for similar outputs to plotly.
    _, ax = plt.subplots()
    ax.set_title("Optimization History Plot")
    ax.set_xlabel("Trial")
    ax.set_ylabel(target_name)
    cmap = plt.get_cmap("tab10")  # Use tab10 colormap for similar outputs to plotly.
    single_history = len(info_list) == 1

    for i, (trial_numbers, values_info, best_values_info) in enumerate(info_list):
        # A single history uses tab10's blue (index 0) and red (index 3); multiple histories
        # take consecutive color pairs so each one stays distinguishable.
        value_color = cmap(0) if single_history else cmap(2 * i)
        best_color = cmap(3) if single_history else cmap(2 * i + 1)

        if values_info.stds is not None:
            # Draw on ``ax`` explicitly instead of pyplot's implicit "current axes".
            ax.errorbar(
                x=trial_numbers,
                y=values_info.values,
                yerr=values_info.stds,
                capsize=5,
                fmt="o",
                color=value_color,
            )
        ax.scatter(
            x=trial_numbers,
            y=values_info.values,
            color=value_color,
            alpha=1,
            label=values_info.label_name,
        )
        if best_values_info is not None:
            ax.plot(
                trial_numbers,
                best_values_info.values,
                marker="o",
                color=best_color,
                alpha=0.5,
                label=best_values_info.label_name,
            )
            if best_values_info.stds is not None:
                # Shade a +/- one-std band around the best-value line.
                best_values = np.asarray(best_values_info.values)
                best_stds = np.asarray(best_values_info.stds)
                ax.fill_between(
                    x=trial_numbers,
                    y1=best_values - best_stds,
                    y2=best_values + best_stds,
                    color=best_color,
                    alpha=0.4,
                )

    # Place a single legend outside the plot area, to the right of the axes.
    ax.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
    return ax