plot_terminator_improvement

optuna.visualization.matplotlib.plot_terminator_improvement(study, plot_error=False, improvement_evaluator=None, error_evaluator=None, min_n_trials=20)

Plot the potentials for future objective improvement.

This function visualizes the potential for future objective improvement, as evaluated by improvement_evaluator, to help you decide whether to continue the optimization. If plot_error is set to True, it also plots the error evaluated by error_evaluator. Note that computing the improvement potentials can take some time, because they are evaluated once for every completed trial.
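Reading the plot comes down to comparing the two curves. The sketch below assumes the rule that optuna.terminator.Terminator applies, as we understand it: termination is considered only once at least min_n_trials trials have completed, and is suggested as soon as the improvement potential drops below the error. first_stop_trial is a hypothetical helper, not part of Optuna, and its inputs stand in for per-trial values like those shown on the plot.

```python
from typing import Optional, Sequence


def first_stop_trial(
    improvements: Sequence[float],
    errors: Sequence[float],
    min_n_trials: int = 20,
) -> Optional[int]:
    """Return the index of the first trial whose improvement potential is
    below the error, or None if that never happens (keep optimizing).

    Hypothetical helper illustrating the assumed "improvement < error" rule.
    """
    if len(improvements) != len(errors):
        raise ValueError("improvements and errors must have the same length")
    for i, (improvement, error) in enumerate(zip(improvements, errors)):
        # i + 1 trials have completed at index i; before min_n_trials
        # trials, termination is not considered at all.
        if i + 1 < min_n_trials:
            continue
        if improvement < error:
            return i
    return None
```

For example, with improvements [5, 4, 3, 2, 1, 0.5], a constant error of 1, and min_n_trials=3, the rule fires at index 5: index 4 has an improvement equal to the error, which does not count as "below".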

Parameters:
  • study (Study) – A Study object whose trials are plotted for their improvement.

  • plot_error (bool) – A flag to show the error. If it is set to True, errors evaluated by error_evaluator are also plotted as a line graph. Defaults to False.

  • improvement_evaluator (BaseImprovementEvaluator | None) – An object that evaluates the improvement of the objective function. Defaults to RegretBoundEvaluator.

  • error_evaluator (BaseErrorEvaluator | None) – An object that evaluates the error inherent in the objective function. Defaults to CrossValidationErrorEvaluator.

  • min_n_trials (int) – The minimum number of trials before termination is considered. Terminator improvements for trials below this value are shown in a lighter color. Defaults to 20.

Returns:

A matplotlib.axes.Axes object.

Return type:

Axes
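Because the return value is an Axes rather than a Figure, saving or closing the plot goes through its parent figure. A minimal sketch using plain matplotlib; save_axes is a hypothetical helper, and the file name in the usage line is illustrative.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen so this also works without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def save_axes(ax: Axes, path: str, dpi: int = 150) -> None:
    """Save the figure that owns ``ax`` to ``path``, then close it.

    Hypothetical helper: it works for any Axes, including the one returned
    by plot_terminator_improvement.
    """
    fig = ax.get_figure()
    fig.tight_layout()          # keep the title and axis labels inside the canvas
    fig.savefig(path, dpi=dpi)  # file format is inferred from the extension
    plt.close(fig)              # release the figure tracked by pyplot
```

With a finished study you would call save_axes(plot_terminator_improvement(study, plot_error=True), "terminator_improvement.png"). Closing the figure matters mainly when generating many plots in a loop, since pyplot keeps every open figure in memory.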

Note

Added in v3.2.0 as an experimental feature. The interface may change in newer versions without prior notice. See https://github.com/optuna/optuna/releases/tag/v3.2.0.

The following code snippet shows how to plot improvement potentials, together with cross-validation errors.

[Figure: Terminator Improvement Plot. x-axis: Trial; y-axis: Terminator Improvement.]

import matplotlib.pyplot as plt
from lightgbm import LGBMClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
import optuna
from optuna.terminator import report_cross_validation_scores
from optuna.visualization.matplotlib import plot_terminator_improvement


def objective(trial):
    X, y = load_wine(return_X_y=True)
    # Sample the LightGBM hyperparameters for this trial.
    clf = LGBMClassifier(
        reg_alpha=trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
        reg_lambda=trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
        num_leaves=trial.suggest_int("num_leaves", 2, 256),
        colsample_bytree=trial.suggest_float("colsample_bytree", 0.4, 1.0),
        subsample=trial.suggest_float("subsample", 0.4, 1.0),
        subsample_freq=trial.suggest_int("subsample_freq", 1, 7),
        min_child_samples=trial.suggest_int("min_child_samples", 5, 100),
    )
    # 5-fold cross-validation; each fold yields one accuracy score.
    scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
    # Store the per-fold scores on the trial so that CrossValidationErrorEvaluator,
    # the default error_evaluator, can estimate the error.
    report_cross_validation_scores(trial, scores)
    # Mean accuracy: higher is better.
    return scores.mean()


# Accuracy must be maximized; create_study() minimizes by default.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=30)

# plot_error=True also draws the cross-validation error as a line graph.
plot_terminator_improvement(study, plot_error=True)
plt.show()
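The error line drawn with plot_error=True is derived from the per-fold scores stored by report_cross_validation_scores. To illustrate how fold scores can be turned into an error level, the sketch below computes the corrected standard deviation of Nadeau and Bengio, in which the variance of the k fold scores is scaled by 1/k + 1/(k - 1) to account for the overlapping training sets of k-fold cross-validation. We believe CrossValidationErrorEvaluator applies a similar correction to the best trial's scores, but that is an assumption; cv_error_estimate is a hypothetical helper, not Optuna's implementation.

```python
import math
from statistics import pvariance
from typing import Sequence


def cv_error_estimate(fold_scores: Sequence[float]) -> float:
    """Corrected standard deviation of k-fold cross-validation scores.

    The naive variance of fold scores underestimates the true variance
    because the folds' training sets overlap; 1/k + 1/(k - 1) is the
    1/k + n_test/n_train correction specialized to k-fold CV.
    """
    k = len(fold_scores)
    if k < 2:
        raise ValueError("need at least two fold scores")
    scale = 1.0 / k + 1.0 / (k - 1)
    # pvariance is the population variance (divides by k, not k - 1).
    return math.sqrt(scale * pvariance(fold_scores))
```

For five fold accuracies [0.9, 0.95, 1.0, 0.9, 0.95], the population variance is 0.0014, the scale is 0.45, and the estimate is about 0.025.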
