# Source code for optuna.visualization._terminator_improvement

from __future__ import annotations

from typing import NamedTuple

import tqdm

import optuna
from optuna._experimental import experimental_func
from optuna.logging import get_logger
from optuna.study import Study
from optuna.terminator import BaseErrorEvaluator
from optuna.terminator import BaseImprovementEvaluator
from optuna.terminator import CrossValidationErrorEvaluator
from optuna.terminator import RegretBoundEvaluator
from optuna.terminator.erroreval import StaticErrorEvaluator
from optuna.terminator.improvement.evaluator import BestValueStagnationEvaluator
from optuna.terminator.improvement.evaluator import DEFAULT_MIN_N_TRIALS
from optuna.visualization._plotly_imports import _imports

if _imports.is_successful():
    from optuna.visualization._plotly_imports import go

_logger = get_logger(__name__)

# Fraction of the plotted y-value span that ``_get_y_range`` adds as a margin
# above and below the y-axis range.
PADDING_RATIO_Y = 0.05
# Alpha of the improvement trace drawn before ``min_n_trials`` is reached.
OPACITY = 0.25

class _ImprovementInfo(NamedTuple):
    trial_numbers: list[int]
    improvements: list[float]
    errors: list[float] | None

@experimental_func("3.2.0")
def plot_terminator_improvement(
    study: Study,
    plot_error: bool = False,
    improvement_evaluator: BaseImprovementEvaluator | None = None,
    error_evaluator: BaseErrorEvaluator | None = None,
    min_n_trials: int = DEFAULT_MIN_N_TRIALS,
) -> "go.Figure":
    """Plot the potentials for future objective improvement.

    This function visualizes the objective improvement potentials, evaluated
    with ``improvement_evaluator``.
    It helps to determine whether we should continue the optimization or not.
    You can also plot the error evaluated with ``error_evaluator`` if the
    ``plot_error`` argument is set to :obj:`True`.
    Note that this function may take some time to compute the improvement
    potentials.

    Example:

        The following code snippet shows how to plot improvement potentials,
        together with cross-validation errors.

        .. plotly::

            from lightgbm import LGBMClassifier
            from sklearn.datasets import load_wine
            from sklearn.model_selection import cross_val_score
            from sklearn.model_selection import KFold

            import optuna
            from optuna.terminator import report_cross_validation_scores
            from optuna.visualization import plot_terminator_improvement


            def objective(trial):
                X, y = load_wine(return_X_y=True)
                clf = LGBMClassifier(
                    reg_alpha=trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
                    reg_lambda=trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
                    num_leaves=trial.suggest_int("num_leaves", 2, 256),
                    colsample_bytree=trial.suggest_float("colsample_bytree", 0.4, 1.0),
                    subsample=trial.suggest_float("subsample", 0.4, 1.0),
                    subsample_freq=trial.suggest_int("subsample_freq", 1, 7),
                    min_child_samples=trial.suggest_int("min_child_samples", 5, 100),
                )
                scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
                report_cross_validation_scores(trial, scores)
                return scores.mean()


            # The objective returns mean accuracy, so it must be maximized.
            study = optuna.create_study(direction="maximize")
            study.optimize(objective, n_trials=30)

            fig = plot_terminator_improvement(study, plot_error=True)

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for
            their improvement. Must be single-objective.
        plot_error:
            A flag to show the error. If it is set to :obj:`True`, errors
            evaluated by ``error_evaluator`` are also plotted as a line graph.
            Defaults to :obj:`False`.
        improvement_evaluator:
            An object that evaluates the improvement of the objective function.
            Defaults to :class:`~optuna.terminator.RegretBoundEvaluator`.
        error_evaluator:
            An object that evaluates the error inherent in the objective
            function. Defaults to
            :class:`~optuna.terminator.CrossValidationErrorEvaluator`, or to a
            :class:`~optuna.terminator.erroreval.StaticErrorEvaluator` with
            ``constant=0`` when ``improvement_evaluator`` is a
            :class:`~optuna.terminator.improvement.evaluator.BestValueStagnationEvaluator`.
        min_n_trials:
            The minimum number of trials before termination is considered.
            Terminator improvements for trials below this value are shown in a
            lighter color. Defaults to ``20``.

    Returns:
        A :class:`plotly.graph_objects.Figure` object.

    Raises:
        ValueError: If ``study`` is a multi-objective study.
    """
    # Fail early with an informative message if plotly is not installed.
    _imports.check()

    info = _get_improvement_info(study, plot_error, improvement_evaluator, error_evaluator)
    return _get_improvement_plot(info, min_n_trials)
def _get_improvement_info(
    study: Study,
    get_error: bool = False,
    improvement_evaluator: BaseImprovementEvaluator | None = None,
    error_evaluator: BaseErrorEvaluator | None = None,
) -> _ImprovementInfo:
    """Replay the study's trials and evaluate the improvement potential after each.

    Args:
        study: Single-objective study whose trials are replayed in order.
        get_error: Whether to also evaluate ``error_evaluator`` at each step.
        improvement_evaluator: Evaluator of the improvement potential. Defaults to
            :class:`~optuna.terminator.RegretBoundEvaluator`.
        error_evaluator: Evaluator of the error. Defaults to a
            ``StaticErrorEvaluator(constant=0)`` when ``improvement_evaluator`` is a
            ``BestValueStagnationEvaluator``, and to
            :class:`~optuna.terminator.CrossValidationErrorEvaluator` otherwise.

    Returns:
        An ``_ImprovementInfo`` with one entry per trial, starting from the first
        completed trial. ``errors`` is ``None`` when no error was evaluated.

    Raises:
        ValueError: If ``study`` is multi-objective.
    """
    if study._is_multi_objective():
        raise ValueError("This function does not support multi-objective optimization study.")

    if improvement_evaluator is None:
        improvement_evaluator = RegretBoundEvaluator()
    if error_evaluator is None:
        if isinstance(improvement_evaluator, BestValueStagnationEvaluator):
            error_evaluator = StaticErrorEvaluator(constant=0)
        else:
            error_evaluator = CrossValidationErrorEvaluator()

    # The study direction is invariant for the whole loop; read it once instead of
    # once (or twice) per evaluated trial.
    direction = study.direction

    trial_numbers: list[int] = []
    completed_trials: list[optuna.trial.FrozenTrial] = []
    improvements: list[float] = []
    errors: list[float] = []

    for trial in tqdm.tqdm(study.trials):
        if trial.state == optuna.trial.TrialState.COMPLETE:
            completed_trials.append(trial)

        # Nothing can be evaluated until at least one trial has completed.
        if not completed_trials:
            continue

        # From here on every trial gets a point, evaluated on the trials completed so far.
        trial_numbers.append(trial.number)
        improvements.append(
            improvement_evaluator.evaluate(trials=completed_trials, study_direction=direction)
        )
        if get_error:
            errors.append(
                error_evaluator.evaluate(trials=completed_trials, study_direction=direction)
            )

    return _ImprovementInfo(
        trial_numbers=trial_numbers,
        improvements=improvements,
        errors=errors or None,
    )


def _get_improvement_scatter(
    trial_numbers: list[int],
    improvements: list[float],
    opacity: float = 1.0,
    showlegend: bool = True,
) -> "go.Scatter":
    """Build the marker+line trace of improvement potentials.

    Every improvement trace uses ``legendgroup="improvement"``, so the faded and the
    solid segment are grouped under a single legend entry.
    """
    plotly_blue_with_opacity = f"rgba(99, 110, 250, {opacity})"
    return go.Scatter(
        x=trial_numbers,
        y=improvements,
        mode="markers+lines",
        marker=dict(color=plotly_blue_with_opacity),
        line=dict(color=plotly_blue_with_opacity),
        name="Terminator Improvement",
        showlegend=showlegend,
        legendgroup="improvement",
    )


def _get_error_scatter(
    trial_numbers: list[int],
    errors: list[float] | None,
) -> "go.Scatter":
    """Build the error trace, or an empty scatter when errors were not evaluated."""
    if errors is None:
        return go.Scatter()

    plotly_red = "rgb(239, 85, 59)"
    return go.Scatter(
        x=trial_numbers,
        y=errors,
        mode="markers+lines",
        name="Error",
        marker=dict(color=plotly_red),
        line=dict(color=plotly_red),
    )


def _get_y_range(info: _ImprovementInfo, min_n_trials: int) -> tuple[float, float]:
    """Compute a padded y-axis range for the plot.

    The lower bound covers every improvement (and error). The upper bound ignores the
    improvements of the first ``min_n_trials`` points whenever later points exist,
    presumably so the lighter warm-up segment does not dictate the vertical scale.
    Expects ``info.improvements`` to be non-empty.
    """
    min_value = min(info.improvements)
    if info.errors is not None:
        min_value = min(min_value, min(info.errors))

    if len(info.trial_numbers) > min_n_trials:
        # Determine the display range based on trials after min_n_trials.
        max_value = max(info.improvements[min_n_trials:])
    else:
        # No trials after min_n_trials: determine the display range from all trials.
        max_value = max(info.improvements)
    if info.errors is not None:
        max_value = max(max_value, max(info.errors))

    padding = (max_value - min_value) * PADDING_RATIO_Y
    return (min_value - padding, max_value + padding)


def _get_improvement_plot(info: _ImprovementInfo, min_n_trials: int) -> "go.Figure":
    """Assemble the terminator improvement figure from precomputed ``info``.

    Points up to index ``min_n_trials`` are drawn with reduced opacity. The solid segment
    starts at that same index, so both segments share one point and form a continuous
    line. If ``info`` is empty, a warning is logged and the empty figure is returned.
    """
    n_trials = len(info.trial_numbers)

    fig = go.Figure(
        layout=go.Layout(
            title="Terminator Improvement Plot",
            xaxis=dict(title="Trial"),
            yaxis=dict(title="Terminator Improvement"),
        )
    )

    if n_trials == 0:
        _logger.warning("There are no complete trials.")
        return fig

    fig.add_trace(
        _get_improvement_scatter(
            info.trial_numbers[: min_n_trials + 1],
            info.improvements[: min_n_trials + 1],
            # Plot line with a lighter color until the number of trials reaches min_n_trials.
            opacity=OPACITY,
            # Only list this trace in the legend when no solid segment follows,
            # to avoid showing the legend entry twice.
            showlegend=n_trials <= min_n_trials,
        )
    )
    if n_trials > min_n_trials:
        fig.add_trace(
            _get_improvement_scatter(
                info.trial_numbers[min_n_trials:],
                info.improvements[min_n_trials:],
            )
        )
    fig.add_trace(_get_error_scatter(info.trial_numbers, info.errors))
    fig.update_yaxes(range=_get_y_range(info, min_n_trials))

    return fig