Source code for optuna.visualization._terminator_improvement

from __future__ import annotations

from typing import NamedTuple

import tqdm

import optuna
from optuna._experimental import experimental_func
from optuna.logging import get_logger
from optuna.study.study import Study
from optuna.terminator import BaseErrorEvaluator
from optuna.terminator import BaseImprovementEvaluator
from optuna.terminator import CrossValidationErrorEvaluator
from optuna.terminator import RegretBoundEvaluator
from optuna.terminator.improvement.evaluator import DEFAULT_MIN_N_TRIALS
from optuna.visualization._plotly_imports import _imports


if _imports.is_successful():
    from optuna.visualization._plotly_imports import go

# Fraction of the y-value span added as empty margin below and above the plotted data.
PADDING_RATIO_Y = 0.05
# Opacity of the improvement curve drawn up to ``min_n_trials`` (the lighter segment).
OPACITY = 0.25

_logger = get_logger(__name__)


class _ImprovementInfo(NamedTuple):
    """Series consumed by the terminator improvement plot.

    The lists are index-aligned: entry ``i`` of ``improvements`` (and of ``errors``
    when it is not :obj:`None`) belongs to trial ``trial_numbers[i]``.
    """

    trial_numbers: list[int]  # Numbers of the trials that received an entry.
    improvements: list[float]  # Improvement value for each entry.
    errors: list[float] | None  # Error value for each entry, or None if not computed.

@experimental_func("3.2.0")
def plot_terminator_improvement(
    study: Study,
    plot_error: bool = False,
    improvement_evaluator: BaseImprovementEvaluator | None = None,
    error_evaluator: BaseErrorEvaluator | None = None,
    min_n_trials: int = DEFAULT_MIN_N_TRIALS,
) -> "go.Figure":
    """Plot the potentials for future objective improvement.

    This function visualizes the objective improvement potentials, evaluated
    with ``improvement_evaluator``.
    It helps to determine whether we should continue the optimization or not.
    You can also plot the error evaluated with
    ``error_evaluator`` if the ``plot_error`` argument is set to :obj:`True`.
    Note that this function may take some time to compute
    the improvement potentials.

    Example:

        The following code snippet shows how to plot improvement potentials,
        together with cross-validation errors. The objective returns the mean
        cross-validation accuracy, so the study maximizes it.

        .. plotly::

            from lightgbm import LGBMClassifier
            from sklearn.datasets import load_wine
            from sklearn.model_selection import cross_val_score
            from sklearn.model_selection import KFold

            import optuna
            from optuna.terminator import report_cross_validation_scores
            from optuna.visualization import plot_terminator_improvement


            def objective(trial):
                X, y = load_wine(return_X_y=True)
                clf = LGBMClassifier(
                    reg_alpha=trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
                    reg_lambda=trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
                    num_leaves=trial.suggest_int("num_leaves", 2, 256),
                    colsample_bytree=trial.suggest_float("colsample_bytree", 0.4, 1.0),
                    subsample=trial.suggest_float("subsample", 0.4, 1.0),
                    subsample_freq=trial.suggest_int("subsample_freq", 1, 7),
                    min_child_samples=trial.suggest_int("min_child_samples", 5, 100),
                )
                scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
                report_cross_validation_scores(trial, scores)
                return scores.mean()


            study = optuna.create_study(direction="maximize")
            study.optimize(objective, n_trials=30)

            fig = plot_terminator_improvement(study, plot_error=True)
            fig.show()

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their improvement.
            Only single-objective studies are supported.
        plot_error:
            A flag to show the error. If it is set to :obj:`True`, errors
            evaluated by ``error_evaluator`` are also plotted as a line graph.
            Defaults to :obj:`False`.
        improvement_evaluator:
            An object that evaluates the improvement of the objective function.
            Defaults to :class:`~optuna.terminator.RegretBoundEvaluator`.
        error_evaluator:
            An object that evaluates the error inherent in the objective function.
            Defaults to :class:`~optuna.terminator.CrossValidationErrorEvaluator`.
        min_n_trials:
            The minimum number of trials before termination is considered.
            Terminator improvements for trials below this value are
            shown in a lighter color. Defaults to ``20``.

    Returns:
        A :class:`plotly.graph_objects.Figure` object.

    Raises:
        ValueError:
            If ``study`` is a multi-objective study.
    """
    # Fail early if plotly is unavailable, before any (possibly slow) evaluation runs.
    _imports.check()
    info = _get_improvement_info(study, plot_error, improvement_evaluator, error_evaluator)
    return _get_improvement_plot(info, min_n_trials)
def _get_improvement_info(
    study: Study,
    get_error: bool = False,
    improvement_evaluator: BaseImprovementEvaluator | None = None,
    error_evaluator: BaseErrorEvaluator | None = None,
) -> _ImprovementInfo:
    """Evaluate the improvement (and optionally the error) after each trial of ``study``.

    Trials are visited in the order of ``study.trials``. Trials before the first COMPLETE
    trial are skipped. Every later trial gets one entry, computed from all COMPLETE trials
    seen so far (including itself when it is COMPLETE).

    Args:
        study:
            A single-objective study.
        get_error:
            If :obj:`True`, ``error_evaluator`` is also evaluated for every entry.
        improvement_evaluator:
            Evaluator for the improvement. Defaults to ``RegretBoundEvaluator()``.
        error_evaluator:
            Evaluator for the error. Defaults to ``CrossValidationErrorEvaluator()``.

    Returns:
        An :class:`_ImprovementInfo` whose lists are index-aligned. ``errors`` is
        :obj:`None` when ``get_error`` is :obj:`False` or when no entry was produced.

    Raises:
        ValueError:
            If ``study`` is a multi-objective study.
    """
    if study._is_multi_objective():
        raise ValueError("This function does not support multi-objective optimization study.")

    if improvement_evaluator is None:
        improvement_evaluator = RegretBoundEvaluator()
    if error_evaluator is None:
        error_evaluator = CrossValidationErrorEvaluator()

    # Loop-invariant; hoisted so it is read once instead of once per evaluation.
    direction = study.direction

    trial_numbers: list[int] = []
    completed_trials: list[optuna.trial.FrozenTrial] = []
    improvements: list[float] = []
    errors: list[float] = []

    for trial in tqdm.tqdm(study.trials):
        is_newly_completed = trial.state == optuna.trial.TrialState.COMPLETE
        if is_newly_completed:
            completed_trials.append(trial)

        if len(completed_trials) == 0:
            continue

        trial_numbers.append(trial.number)

        if not is_newly_completed:
            # ``completed_trials`` is unchanged since the previous entry, so the evaluators
            # would receive exactly the same input. Reuse their last results instead of
            # re-running a potentially slow evaluation for every pruned/failed/running trial.
            # ``improvements`` (and ``errors`` when requested) are non-empty here because the
            # first completed trial always produced an evaluated entry.
            improvements.append(improvements[-1])
            if get_error:
                errors.append(errors[-1])
            continue

        improvements.append(
            improvement_evaluator.evaluate(trials=completed_trials, study_direction=direction)
        )
        if get_error:
            errors.append(
                error_evaluator.evaluate(trials=completed_trials, study_direction=direction)
            )

    return _ImprovementInfo(
        trial_numbers=trial_numbers,
        improvements=improvements,
        errors=errors if len(errors) > 0 else None,
    )


def _get_improvement_scatter(
    trial_numbers: list[int],
    improvements: list[float],
    opacity: float = 1.0,
    showlegend: bool = True,
) -> "go.Scatter":
    """Build the marker+line trace for (a segment of) the improvement curve.

    ``opacity`` is used as the alpha channel of the blue used for both markers and line.
    All improvement traces are assigned the legend group ``"improvement"`` and the name
    ``"Terminator Improvement"``.
    """
    plotly_blue_with_opacity = f"rgba(99, 110, 250, {opacity})"
    return go.Scatter(
        x=trial_numbers,
        y=improvements,
        mode="markers+lines",
        marker=dict(color=plotly_blue_with_opacity),
        line=dict(color=plotly_blue_with_opacity),
        name="Terminator Improvement",
        showlegend=showlegend,
        legendgroup="improvement",
    )


def _get_error_scatter(
    trial_numbers: list[int],
    errors: list[float] | None,
) -> "go.Scatter":
    """Build the red marker+line trace for the error curve.

    Returns an empty ``go.Scatter()`` when ``errors`` is :obj:`None`.
    """
    if errors is None:
        return go.Scatter()

    plotly_red = "rgb(239, 85, 59)"
    return go.Scatter(
        x=trial_numbers,
        y=errors,
        mode="markers+lines",
        name="Error",
        marker=dict(color=plotly_red),
        line=dict(color=plotly_red),
    )
def _get_y_range(info: _ImprovementInfo, min_n_trials: int) -> tuple[float, float]:
    """Return the ``(low, high)`` y-axis range for the improvement plot.

    ``low`` is the smallest plotted value: all improvements, plus all errors when present.
    ``high`` uses only the improvements after the first ``min_n_trials`` entries when such
    entries exist, otherwise all improvements. It also includes all errors when present.
    Both ends are widened by ``PADDING_RATIO_Y`` times ``high - low``.

    Raises:
        ValueError: If ``info.improvements`` is empty (``min`` of an empty sequence).
    """
    errors = info.errors

    low = min(info.improvements)
    if errors is not None:
        low = min(low, min(errors))

    has_entries_after_min = len(info.trial_numbers) > min_n_trials
    high = max(info.improvements[min_n_trials:] if has_entries_after_min else info.improvements)
    if errors is not None:
        high = max(high, max(errors))

    margin = (high - low) * PADDING_RATIO_Y
    return low - margin, high + margin


def _get_improvement_plot(info: _ImprovementInfo, min_n_trials: int) -> "go.Figure":
    """Assemble the terminator improvement figure from precomputed ``info``.

    The improvement curve is drawn in two traces. The first covers entries up to
    ``min_n_trials`` (inclusive) at ``OPACITY``. The second, at full opacity, covers entries
    from ``min_n_trials`` on and is only drawn when such entries exist. The error trace is
    added after them, and then the y-axis range is set. With no entries, a warning is logged
    and the titled but empty figure is returned.
    """
    layout = go.Layout(
        title="Terminator Improvement Plot",
        xaxis={"title": "Trial"},
        yaxis={"title": "Terminator Improvement"},
    )
    fig = go.Figure(layout=layout)

    n_trials = len(info.trial_numbers)
    if not n_trials:
        _logger.warning("There are no complete trials.")
        return fig

    beyond_min = n_trials > min_n_trials
    # The lighter segment ends one entry past ``min_n_trials`` so both segments connect.
    lighter_end = min_n_trials + 1
    fig.add_trace(
        _get_improvement_scatter(
            info.trial_numbers[:lighter_end],
            info.improvements[:lighter_end],
            opacity=OPACITY,
            # Only one of the two segments contributes the legend entry.
            showlegend=not beyond_min,
        )
    )

    if beyond_min:
        fig.add_trace(
            _get_improvement_scatter(
                info.trial_numbers[min_n_trials:],
                info.improvements[min_n_trials:],
            )
        )

    fig.add_trace(_get_error_scatter(info.trial_numbers, info.errors))
    fig.update_yaxes(range=_get_y_range(info, min_n_trials))
    return fig