Study.optimize 的回调

本教程展示了如何使用和实现一个 用于 optimize() 的Optuna Callback .

Callback 会在 objective 每次求值以后被调用一次, 它接受 StudyFrozenTrial 作为参数并进行处理.

MLflowCallback 是个好例子.

在特定数量的 trial 被剪枝后终止优化

本例实现了一个有状态的回调函数. 如果特定数目的 trial 被剪枝了, 它将终止优化. 被剪枝的 trial 数是通过 threshold 来指定的.

import optuna


class StopWhenTrialKeepBeingPrunedCallback:
    def __init__(self, threshold: int):
        self.threshold = threshold
        self._consequtive_pruned_count = 0

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        if trial.state == optuna.trial.TrialState.PRUNED:
            self._consequtive_pruned_count += 1
        else:
            self._consequtive_pruned_count = 0

        if self._consequtive_pruned_count >= self.threshold:
            study.stop()

该目标函数会对除了前五个 trial 之外的所有trial 进行剪枝 (trial.number 从 0 开始计数).

def objective(trial):
    if trial.number > 4:
        raise optuna.TrialPruned

    return trial.suggest_float("x", 0, 1)

在这里, 我们将阈值设置为 2: 优化过程会在一旦两个 trial 被剪枝后发生. 因此, 我们预期该 study 会在 7 个 trial 后停止.

import logging
import sys

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

study_stop_cb = StopWhenTrialKeepBeingPrunedCallback(2)
study = optuna.create_study()
study.optimize(objective, n_trials=10, callbacks=[study_stop_cb])

Out:

A new study created in memory with name: no-name-be0f83c6-46e8-42fd-9d9e-f6ebd9175ba1
Trial 0 finished with value: 0.03979847340429532 and parameters: {'x': 0.03979847340429532}. Best is trial 0 with value: 0.03979847340429532.
Trial 1 finished with value: 0.32272885613475355 and parameters: {'x': 0.32272885613475355}. Best is trial 0 with value: 0.03979847340429532.
Trial 2 finished with value: 0.04095547461074589 and parameters: {'x': 0.04095547461074589}. Best is trial 0 with value: 0.03979847340429532.
Trial 3 finished with value: 0.36590369991767835 and parameters: {'x': 0.36590369991767835}. Best is trial 0 with value: 0.03979847340429532.
Trial 4 finished with value: 0.23228097273032844 and parameters: {'x': 0.23228097273032844}. Best is trial 0 with value: 0.03979847340429532.
Trial 5 pruned.
Trial 6 pruned.

从上面的日志中可以看出, study 如期在 7 个trial 之后停止了.

Total running time of the script: ( 0 minutes 0.008 seconds)

Gallery generated by Sphinx-Gallery