Callback for Study.optimize

This tutorial shows how to use and implement an Optuna callback for Study.optimize().

A callback is called after every evaluation of the objective function. It receives the Study and the FrozenTrial of the finished trial as arguments and can do arbitrary work, such as logging results or stopping the study.

MLflowCallback, which logs each trial to MLflow, is a good example.

Stop optimization after some trials are pruned in a row

This example implements a stateful callback that stops the optimization once a certain number of trials have been pruned in a row. That number is specified by threshold.

import optuna


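class StopWhenTrialKeepBeingPrunedCallback:
    """Stop the study once `threshold` trials in a row have been pruned."""

    def __init__(self, threshold: int):
        self.threshold = threshold
        # Number of trials pruned in a row so far; any non-pruned trial resets it.
        self._consecutive_pruned_count = 0

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        # Optuna calls this after every trial finishes.
        if trial.state == optuna.trial.TrialState.PRUNED:
            self._consecutive_pruned_count += 1
        else:
            self._consecutive_pruned_count = 0

        # Ask Study.optimize to finish; it will not start further trials.
        if self._consecutive_pruned_count >= self.threshold:
            study.stop()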
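Because the counter resets on any non-pruned trial, only an unbroken run of pruned trials can trigger the stop. The rule can be traced without Optuna in a small pure-Python sketch; the helper first_stop_index is hypothetical and only mirrors the counting logic of __call__.

```python
from typing import Optional, Sequence


def first_stop_index(pruned_flags: Sequence[bool], threshold: int) -> Optional[int]:
    """Return the index of the trial after which the study would stop, or None."""
    count = 0
    for index, pruned in enumerate(pruned_flags):
        # Same rule as __call__: increment on a pruned trial, reset otherwise.
        count = count + 1 if pruned else 0
        if count >= threshold:
            return index
    return None


print(first_stop_index([True, False, True, True], threshold=2))  # 3
print(first_stop_index([False] * 5 + [True] * 5, threshold=2))   # 6
```

In the first call, the isolated pruned trial at index 0 does not count toward the run, because trial 1 resets the counter.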

This objective function prunes every trial except the first five (trial.number starts at 0).

def objective(trial):
    # Trials 0 to 4 complete normally; every later trial is pruned.
    if trial.number > 4:
        raise optuna.TrialPruned()

    return trial.suggest_float("x", 0, 1)

Here, we set the threshold to 2, so the optimization finishes once two trials are pruned in a row. Trials 0 to 4 complete and trials 5 and 6 are pruned, so we expect the study to stop after 7 trials.

import logging
import sys

# Add a stream handler for stdout so that Optuna's log messages are shown
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

study_stop_cb = StopWhenTrialKeepBeingPrunedCallback(2)
study = optuna.create_study()
study.optimize(objective, n_trials=10, callbacks=[study_stop_cb])

Out:

A new study created in memory with name: no-name-60f02b20-ff32-4eda-88e9-004318d07678
Trial 0 finished with value: 0.06540296186670091 and parameters: {'x': 0.06540296186670091}. Best is trial 0 with value: 0.06540296186670091.
Trial 1 finished with value: 0.9711070774818921 and parameters: {'x': 0.9711070774818921}. Best is trial 0 with value: 0.06540296186670091.
Trial 2 finished with value: 0.6300442846013438 and parameters: {'x': 0.6300442846013438}. Best is trial 0 with value: 0.06540296186670091.
Trial 3 finished with value: 0.018927519201614063 and parameters: {'x': 0.018927519201614063}. Best is trial 3 with value: 0.018927519201614063.
Trial 4 finished with value: 0.8810086329091955 and parameters: {'x': 0.8810086329091955}. Best is trial 3 with value: 0.018927519201614063.
Trial 5 pruned.
Trial 6 pruned.

As you can see in the log above, the study stopped after 7 trials as expected.
