Pruning Unpromising Trials
This feature automatically stops unpromising trials at an early stage of training (also known as automated early stopping). Optuna provides interfaces for implementing the pruning mechanism concisely in iterative training algorithms.
Activating Pruners
To turn on the pruning feature, call report() and should_prune() after each step of the iterative training. report() periodically monitors the intermediate objective values. should_prune() decides whether to terminate a trial that does not meet a predefined condition.
"""filename: prune.py"""

import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection

import optuna


def objective(trial):
    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25, random_state=0
    )

    alpha = trial.suggest_loguniform('alpha', 1e-5, 1e-1)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(valid_x, valid_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

    return 1.0 - clf.score(valid_x, valid_y)


# Set up the median stopping rule as the pruning condition.
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
Executing the script above:
$ python prune.py
[I 2020-06-12 16:54:23,876] Trial 0 finished with value: 0.3157894736842105 and parameters: {'alpha': 0.00181467547181131}. Best is trial 0 with value: 0.3157894736842105.
[I 2020-06-12 16:54:23,981] Trial 1 finished with value: 0.07894736842105265 and parameters: {'alpha': 0.015378744419287613}. Best is trial 1 with value: 0.07894736842105265.
[I 2020-06-12 16:54:24,083] Trial 2 finished with value: 0.21052631578947367 and parameters: {'alpha': 0.04089428832878595}. Best is trial 1 with value: 0.07894736842105265.
[I 2020-06-12 16:54:24,185] Trial 3 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.004018735937374473}. Best is trial 3 with value: 0.052631578947368474.
[I 2020-06-12 16:54:24,303] Trial 4 finished with value: 0.07894736842105265 and parameters: {'alpha': 2.805688697062864e-05}. Best is trial 3 with value: 0.052631578947368474.
[I 2020-06-12 16:54:24,315] Trial 5 pruned.
[I 2020-06-12 16:54:24,355] Trial 6 pruned.
[I 2020-06-12 16:54:24,511] Trial 7 finished with value: 0.052631578947368474 and parameters: {'alpha': 2.243775785299103e-05}. Best is trial 3 with value: 0.052631578947368474.
[I 2020-06-12 16:54:24,625] Trial 8 finished with value: 0.1842105263157895 and parameters: {'alpha': 0.007021209286214553}. Best is trial 3 with value: 0.052631578947368474.
[I 2020-06-12 16:54:24,629] Trial 9 pruned.
...
Log messages such as "Trial 5 pruned." mean that several trials were stopped before they finished all of the iterations.
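The idea behind a median stopping rule can be sketched in plain Python. This is a simplified illustration and not Optuna's implementation: the function name should_prune_by_median, the histories data, and the n_startup_trials argument of the sketch are all assumptions made for the example, and the objective is assumed to be minimized. A running trial is flagged for pruning when its intermediate value at a step is worse than the median of what completed trials reported at the same step.

```python
from statistics import median


def should_prune_by_median(current_value, step, finished_histories, n_startup_trials=5):
    """Return True if the running trial looks worse than the median at `step`.

    Hypothetical helper for illustration only. Assumes minimization.
    `finished_histories` is a list of dicts mapping step -> intermediate value,
    one dict per completed trial.
    """
    # With too few completed trials there is no meaningful median: never prune.
    if len(finished_histories) < n_startup_trials:
        return False

    # Collect what earlier trials reported at the same step.
    values_at_step = [h[step] for h in finished_histories if step in h]
    if not values_at_step:
        return False

    # Larger is worse because the objective is minimized.
    return current_value > median(values_at_step)


# Made-up intermediate error rates of five completed trials.
histories = [
    {0: 0.50, 1: 0.40},
    {0: 0.45, 1: 0.30},
    {0: 0.60, 1: 0.35},
    {0: 0.55, 1: 0.20},
    {0: 0.40, 1: 0.25},
]

print(should_prune_by_median(0.70, 1, histories))  # 0.70 is above the median at step 1
print(should_prune_by_median(0.10, 1, histories))  # 0.10 is below the median at step 1
```

Requiring a minimum number of completed trials before pruning is one way to avoid stopping trials against a median built from too little data.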
Integration Modules for Pruning
To implement the pruning mechanism more simply, Optuna provides integration modules for the following libraries.
TensorFlow: optuna.integration.TensorFlowPruningHook
PyTorch Ignite: optuna.integration.PyTorchIgnitePruningHandler
PyTorch Lightning: optuna.integration.PyTorchLightningPruningCallback
For example, XGBoostPruningCallback introduces pruning without directly changing the logic of the training iteration. (See also the example for the entire script.)
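# Fragment from inside an objective function: assumes `import xgboost as xgb`
# and that trial, param, dtrain, and dvalid are already defined.
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-error')
bst = xgb.train(param, dtrain, evals=[(dvalid, 'validation')], callbacks=[pruning_callback])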
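How a callback can add pruning without touching the training loop may be easier to see in a library-free sketch. Everything below is hypothetical and written only for illustration: PruningCallback, ThresholdTrial, the local TrialPruned exception, and the train function are stand-ins, not the source of XGBoostPruningCallback or any Optuna class. The training loop only knows that it must call its callbacks after each step. The callback is what forwards the metric to report() and checks should_prune().

```python
class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned so the sketch runs without Optuna."""


class ThresholdTrial:
    """Hypothetical trial stand-in: asks for pruning above a fixed threshold."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.reported = []

    def report(self, value, step):
        # Record the intermediate value for this step.
        self.reported.append((step, value))

    def should_prune(self):
        # Judge only the most recently reported value.
        return self.reported[-1][1] > self.threshold


class PruningCallback:
    """Hypothetical callback that reports one named metric and prunes on request."""

    def __init__(self, trial, metric_name):
        self.trial = trial
        self.metric_name = metric_name

    def __call__(self, step, metrics):
        self.trial.report(metrics[self.metric_name], step)
        if self.trial.should_prune():
            raise TrialPruned(f"pruned at step {step}")


def train(errors, callbacks=()):
    """Generic loop that knows nothing about pruning; it only invokes callbacks."""
    for step, error in enumerate(errors):
        for callback in callbacks:
            callback(step, {"validation-error": error})
    return errors[-1]


trial = ThresholdTrial(threshold=0.5)
callback = PruningCallback(trial, "validation-error")
try:
    # Made-up validation errors, one per training step.
    train([0.4, 0.6, 0.3], callbacks=[callback])
    outcome = "finished"
except TrialPruned:
    outcome = "pruned"

print(outcome, trial.reported)
```

Because the pruning decision lives entirely in the callback, the same train function runs unchanged with or without pruning.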