from __future__ import annotations
from typing import NamedTuple
import tqdm
import optuna
from optuna._experimental import experimental_func
from optuna.logging import get_logger
from optuna.study.study import Study
from optuna.terminator import BaseErrorEvaluator
from optuna.terminator import BaseImprovementEvaluator
from optuna.terminator import CrossValidationErrorEvaluator
from optuna.terminator import RegretBoundEvaluator
from optuna.terminator.improvement.evaluator import DEFAULT_MIN_N_TRIALS
from optuna.visualization._plotly_imports import _imports
if _imports.is_successful():
    # ``go`` is only bound when the plotly import succeeded; the public entry
    # point calls ``_imports.check()`` before any plotting code touches it.
    from optuna.visualization._plotly_imports import go

_logger = get_logger(__name__)

# Fraction of the plotted value span added below the minimum and above the
# maximum when computing the y-axis range.
PADDING_RATIO_Y = 0.05
# Alpha applied to the improvement points drawn before ``min_n_trials`` is
# reached (the "lighter color" segment).
OPACITY = 0.25
class _ImprovementInfo(NamedTuple):
    """Per-trial series collected by ``_get_improvement_info`` for plotting."""

    # Trial numbers that received a plotted point (x-axis values).
    trial_numbers: list[int]
    # Improvement values from the improvement evaluator, aligned with ``trial_numbers``.
    improvements: list[float]
    # Error values aligned with ``trial_numbers``, or ``None`` when no error was evaluated.
    errors: list[float] | None
@experimental_func("3.2.0")
def plot_terminator_improvement(
    study: Study,
    plot_error: bool = False,
    improvement_evaluator: BaseImprovementEvaluator | None = None,
    error_evaluator: BaseErrorEvaluator | None = None,
    min_n_trials: int = DEFAULT_MIN_N_TRIALS,
) -> "go.Figure":
    """Plot the potentials for future objective improvement.

    This function visualizes the objective improvement potentials, evaluated
    with ``improvement_evaluator``.
    It helps to determine whether we should continue the optimization or not.
    You can also plot the error evaluated with
    ``error_evaluator`` if the ``plot_error`` argument is set to :obj:`True`.
    Note that this function may take some time to compute
    the improvement potentials.

    Example:

        The following code snippet shows how to plot improvement potentials,
        together with cross-validation errors.

        .. plotly::

            from lightgbm import LGBMClassifier
            from sklearn.datasets import load_wine
            from sklearn.model_selection import cross_val_score
            from sklearn.model_selection import KFold

            import optuna
            from optuna.terminator import report_cross_validation_scores
            from optuna.visualization import plot_terminator_improvement


            def objective(trial):
                X, y = load_wine(return_X_y=True)
                clf = LGBMClassifier(
                    reg_alpha=trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
                    reg_lambda=trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
                    num_leaves=trial.suggest_int("num_leaves", 2, 256),
                    colsample_bytree=trial.suggest_float("colsample_bytree", 0.4, 1.0),
                    subsample=trial.suggest_float("subsample", 0.4, 1.0),
                    subsample_freq=trial.suggest_int("subsample_freq", 1, 7),
                    min_child_samples=trial.suggest_int("min_child_samples", 5, 100),
                )
                scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
                report_cross_validation_scores(trial, scores)
                return scores.mean()


            study = optuna.create_study(direction="maximize")
            study.optimize(objective, n_trials=30)

            fig = plot_terminator_improvement(study, plot_error=True)
            fig.show()

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted
            for their improvement.
        plot_error:
            A flag to show the error. If it is set to :obj:`True`, errors
            evaluated by ``error_evaluator`` are also plotted as a line graph.
            Defaults to :obj:`False`.
        improvement_evaluator:
            An object that evaluates the improvement of the objective function.
            Defaults to :class:`~optuna.terminator.RegretBoundEvaluator`.
        error_evaluator:
            An object that evaluates the error inherent in the objective function.
            Defaults to :class:`~optuna.terminator.CrossValidationErrorEvaluator`.
        min_n_trials:
            The minimum number of trials before termination is considered.
            Terminator improvements for trials below this value are
            shown in a lighter color. Defaults to ``20``.

    Returns:
        A :class:`plotly.graph_objs.Figure` object.
    """
    # Fail early with a helpful message if plotly is not installed.
    _imports.check()
    info = _get_improvement_info(study, plot_error, improvement_evaluator, error_evaluator)
    return _get_improvement_plot(info, min_n_trials)
def _get_improvement_info(
    study: Study,
    get_error: bool = False,
    improvement_evaluator: BaseImprovementEvaluator | None = None,
    error_evaluator: BaseErrorEvaluator | None = None,
) -> _ImprovementInfo:
    """Evaluate the improvement (and optionally the error) after every trial.

    Trials are walked in ``study.trials`` order. Starting from the first
    ``COMPLETE`` trial, every trial (whatever its state) gets one point, which
    is evaluated on all ``COMPLETE`` trials seen so far.

    Args:
        study:
            The study to analyze. Must be single-objective.
        get_error:
            If :obj:`True`, also evaluate ``error_evaluator`` for every point.
        improvement_evaluator:
            Evaluator for the improvement. Defaults to ``RegretBoundEvaluator()``.
        error_evaluator:
            Evaluator for the error. Defaults to ``CrossValidationErrorEvaluator()``.

    Returns:
        An ``_ImprovementInfo``. Its ``errors`` is ``None`` when no error was
        evaluated (``get_error`` is :obj:`False` or there is no complete trial).

    Raises:
        ValueError: If ``study`` is a multi-objective study.
    """
    if study._is_multi_objective():
        raise ValueError("This function does not support multi-objective optimization study.")

    if improvement_evaluator is None:
        improvement_evaluator = RegretBoundEvaluator()
    if error_evaluator is None:
        error_evaluator = CrossValidationErrorEvaluator()

    trial_numbers: list[int] = []
    completed_trials: list[optuna.trial.FrozenTrial] = []
    improvements: list[float] = []
    errors: list[float] = []

    for trial in tqdm.tqdm(study.trials):
        is_complete = trial.state == optuna.trial.TrialState.COMPLETE
        if is_complete:
            completed_trials.append(trial)

        # Nothing can be evaluated before the first trial completes.
        if not completed_trials:
            continue

        trial_numbers.append(trial.number)

        if not is_complete and improvements:
            # ``completed_trials`` did not grow, so the evaluators would receive
            # exactly the same input as for the previous point. Reuse those
            # values instead of re-running a potentially expensive evaluation.
            improvements.append(improvements[-1])
            if get_error:
                errors.append(errors[-1])
            continue

        improvements.append(
            improvement_evaluator.evaluate(
                trials=completed_trials, study_direction=study.direction
            )
        )
        if get_error:
            errors.append(
                error_evaluator.evaluate(
                    trials=completed_trials, study_direction=study.direction
                )
            )

    return _ImprovementInfo(
        trial_numbers=trial_numbers,
        improvements=improvements,
        errors=errors if errors else None,
    )
def _get_improvement_scatter(
    trial_numbers: list[int],
    improvements: list[float],
    opacity: float = 1.0,
    showlegend: bool = True,
) -> "go.Scatter":
    """Build one scatter trace of the terminator-improvement series.

    Args:
        trial_numbers: x-axis values.
        improvements: y-axis values, aligned with ``trial_numbers``.
        opacity: Alpha channel applied to both markers and line.
        showlegend: Whether this trace contributes a legend entry.

    Returns:
        A ``go.Scatter`` in the ``"improvement"`` legend group, so several
        segments of the same series can share a single legend entry.
    """
    # Plotly blue (#636EFA) with the requested alpha.
    color = f"rgba(99, 110, 250, {opacity})"
    return go.Scatter(
        name="Terminator Improvement",
        legendgroup="improvement",
        showlegend=showlegend,
        x=trial_numbers,
        y=improvements,
        mode="markers+lines",
        marker={"color": color},
        line={"color": color},
    )
def _get_error_scatter(
    trial_numbers: list[int],
    errors: list[float] | None,
) -> "go.Scatter":
    """Build the scatter trace of the error series.

    When ``errors`` is ``None`` an empty ``go.Scatter`` is returned, so the
    caller can add the trace unconditionally.
    """
    if errors is not None:
        # Plotly red (#EF553B).
        red = "rgb(239, 85, 59)"
        return go.Scatter(
            name="Error",
            x=trial_numbers,
            y=errors,
            mode="markers+lines",
            marker={"color": red},
            line={"color": red},
        )
    return go.Scatter()
def _get_y_range(info: _ImprovementInfo, min_n_trials: int) -> tuple[float, float]:
    """Compute a padded ``(low, high)`` y-axis range for the improvement plot.

    The lower bound covers every improvement and, if present, every error.
    The upper bound skips the first ``min_n_trials`` improvement points when
    later points exist (presumably so the faded early values do not dictate
    the scale); otherwise it uses all improvements. Errors are always
    included. Both ends are widened by ``PADDING_RATIO_Y`` of the span.
    """
    errors = info.errors

    low = min(info.improvements)
    if errors is not None:
        low = min(low, min(errors))

    if len(info.trial_numbers) > min_n_trials:
        considered = info.improvements[min_n_trials:]
    else:
        # Not enough points past min_n_trials: fall back to all of them.
        considered = info.improvements
    high = max(considered)
    if errors is not None:
        high = max(high, max(errors))

    margin = (high - low) * PADDING_RATIO_Y
    return (low - margin, high + margin)
def _get_improvement_plot(info: _ImprovementInfo, min_n_trials: int) -> "go.Figure":
    """Render ``info`` as a plotly figure.

    The first ``min_n_trials + 1`` improvement points are drawn with
    ``OPACITY``. The remaining points are drawn opaque, starting at index
    ``min_n_trials`` so the two segments join. An error trace is always
    appended (empty when ``info.errors`` is ``None``), and the y-axis range
    comes from ``_get_y_range``. An empty figure is returned, with a warning,
    when ``info`` has no points.
    """
    layout = go.Layout(
        title="Terminator Improvement Plot",
        xaxis=dict(title="Trial"),
        yaxis=dict(title="Terminator Improvement"),
    )
    fig = go.Figure(layout=layout)

    n_trials = len(info.trial_numbers)
    if not n_trials:
        _logger.warning("There are no complete trials.")
        return fig

    faded_end = min_n_trials + 1
    has_opaque_part = n_trials > min_n_trials

    # Lighter segment up to min_n_trials; it owns the legend entry only when
    # no opaque segment follows (avoids a duplicated legend entry).
    fig.add_trace(
        _get_improvement_scatter(
            info.trial_numbers[:faded_end],
            info.improvements[:faded_end],
            opacity=OPACITY,
            showlegend=not has_opaque_part,
        )
    )
    if has_opaque_part:
        fig.add_trace(
            _get_improvement_scatter(
                info.trial_numbers[min_n_trials:],
                info.improvements[min_n_trials:],
            )
        )

    fig.add_trace(_get_error_scatter(info.trial_numbers, info.errors))
    fig.update_yaxes(range=_get_y_range(info, min_n_trials))
    return fig