Source code for optuna.pruners._patient

from __future__ import annotations

import numpy as np

import optuna
from optuna._experimental import experimental_class
from optuna.pruners import BasePruner
from optuna.study._study_direction import StudyDirection


@experimental_class("2.8.0")
class PatientPruner(BasePruner):
    """Pruner which wraps another pruner with tolerance.

    This pruner monitors intermediate values in a trial and prunes the trial if the improvement in
    the intermediate values after a patience period is less than a threshold.

    The pruner handles NaN values in the following manner:
        1. If all intermediate values before or during the patience period are NaN, the trial
           will not be pruned.
        2. During the pruning calculations, NaN values are ignored. Only valid numeric values are
           considered.

    Example:
        .. testcode::

            import numpy as np
            from sklearn.datasets import load_iris
            from sklearn.linear_model import SGDClassifier
            from sklearn.model_selection import train_test_split

            import optuna

            X, y = load_iris(return_X_y=True)
            X_train, X_valid, y_train, y_valid = train_test_split(X, y)
            classes = np.unique(y)


            def objective(trial):
                alpha = trial.suggest_float("alpha", 0.0, 1.0)
                clf = SGDClassifier(alpha=alpha)
                n_train_iter = 100

                for step in range(n_train_iter):
                    clf.partial_fit(X_train, y_train, classes=classes)

                    intermediate_value = clf.score(X_valid, y_valid)
                    trial.report(intermediate_value, step)

                    if trial.should_prune():
                        raise optuna.TrialPruned()

                return clf.score(X_valid, y_valid)


            study = optuna.create_study(
                direction="maximize",
                pruner=optuna.pruners.PatientPruner(optuna.pruners.MedianPruner(), patience=1),
            )
            study.optimize(objective, n_trials=20)

    Args:
        wrapped_pruner:
            Wrapped pruner to perform pruning when :class:`~optuna.pruners.PatientPruner` allows a
            trial to be pruned. If it is :obj:`None`, this pruner is equivalent to
            early-stopping based only on the intermediate values of the individual trial.
        patience:
            Pruning is disabled until the objective doesn't improve for
            ``patience`` consecutive steps.
        min_delta:
            Tolerance value to check whether or not the objective improves.
            This value should be non-negative.

    Raises:
        ValueError:
            If ``patience`` or ``min_delta`` is negative.

    """

    def __init__(
        self, wrapped_pruner: BasePruner | None, patience: int, min_delta: float = 0.0
    ) -> None:
        if patience < 0:
            raise ValueError(f"patience cannot be negative but got {patience}.")

        if min_delta < 0:
            raise ValueError(f"min_delta cannot be negative but got {min_delta}.")

        self._wrapped_pruner = wrapped_pruner
        self._patience = patience
        self._min_delta = min_delta

    def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
        """Judge whether the trial should be pruned.

        The reported intermediate values are ordered by step and split into the values reported
        before the last ``patience + 1`` steps and the values of those last ``patience + 1``
        steps. The trial becomes a pruning candidate only when the best value of the later
        window is worse than the best value of the earlier window by more than ``min_delta``.
        A candidate is pruned immediately when no pruner is wrapped; otherwise the decision is
        delegated to the wrapped pruner.

        Args:
            study:
                Study the trial belongs to. Its ``direction`` selects whether smaller or larger
                values count as better.
            trial:
                Trial whose ``intermediate_values`` are inspected.

        Returns:
            :obj:`True` if the trial should be pruned, :obj:`False` otherwise.

        """
        if trial.last_step is None:
            return False

        intermediate_values = trial.intermediate_values

        # Do not prune if the number of reported steps is insufficient to fill both windows.
        if len(intermediate_values) <= self._patience + 1:
            return False

        # Read the direction before any NaN-based early return so that it is always accessed
        # (and raises in the same situations) once enough steps have been reported.
        direction = study.direction

        # Scores ordered by step; the last ``patience + 1`` entries form the patience window.
        scores = np.asarray(
            [intermediate_values[step] for step in sorted(intermediate_values)], dtype=float
        )
        window_start = -self._patience - 1
        scores_before_patience = scores[:window_start]
        scores_after_patience = scores[window_start:]

        # ``np.nanmin``/``np.nanmax`` emit a RuntimeWarning and return NaN for an all-NaN input.
        # Any comparison with NaN is False, so such a trial was never pruned; make that explicit
        # and avoid the spurious warning.
        if np.isnan(scores_before_patience).all() or np.isnan(scores_after_patience).all():
            return False

        if direction == StudyDirection.MINIMIZE:
            maybe_prune = (
                np.nanmin(scores_before_patience) + self._min_delta
                < np.nanmin(scores_after_patience)
            )
        else:
            maybe_prune = (
                np.nanmax(scores_before_patience) - self._min_delta
                > np.nanmax(scores_after_patience)
            )

        if not maybe_prune:
            return False

        # Patience is exhausted: prune directly, or let the wrapped pruner decide.
        if self._wrapped_pruner is None:
            return True
        return self._wrapped_pruner.prune(study, trial)