备注
Click here to download the full example code
Study.optimize 的回调
本教程展示了如何使用和实现一个 用于 optimize()
的Optuna Callback
.
Callback
会在 objective
每次求值以后被调用一次, 它接受 Study
和 FrozenTrial
作为参数并进行处理.
MLflowCallback
是个好例子.
在特定数量的 trial 被剪枝后终止优化
本例实现了一个有状态的回调函数. 如果特定数目的 trial 被剪枝了, 它将终止优化. 被剪枝的 trial 数是通过 threshold
来指定的.
import optuna
class StopWhenTrialKeepBeingPrunedCallback:
def __init__(self, threshold: int):
self.threshold = threshold
self._consequtive_pruned_count = 0
def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
if trial.state == optuna.trial.TrialState.PRUNED:
self._consequtive_pruned_count += 1
else:
self._consequtive_pruned_count = 0
if self._consequtive_pruned_count >= self.threshold:
study.stop()
该目标函数会对除了前五个 trial 之外的所有trial 进行剪枝 (trial.number
从 0 开始计数).
def objective(trial):
if trial.number > 4:
raise optuna.TrialPruned
return trial.suggest_float("x", 0, 1)
在这里, 我们将阈值设置为 2
: 优化过程会在一旦两个 trial 被剪枝后发生. 因此, 我们预期该 study 会在 7 个 trial 后停止.
import logging
import sys
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_stop_cb = StopWhenTrialKeepBeingPrunedCallback(2)
study = optuna.create_study()
study.optimize(objective, n_trials=10, callbacks=[study_stop_cb])
Out:
A new study created in memory with name: no-name-be0f83c6-46e8-42fd-9d9e-f6ebd9175ba1
Trial 0 finished with value: 0.03979847340429532 and parameters: {'x': 0.03979847340429532}. Best is trial 0 with value: 0.03979847340429532.
Trial 1 finished with value: 0.32272885613475355 and parameters: {'x': 0.32272885613475355}. Best is trial 0 with value: 0.03979847340429532.
Trial 2 finished with value: 0.04095547461074589 and parameters: {'x': 0.04095547461074589}. Best is trial 0 with value: 0.03979847340429532.
Trial 3 finished with value: 0.36590369991767835 and parameters: {'x': 0.36590369991767835}. Best is trial 0 with value: 0.03979847340429532.
Trial 4 finished with value: 0.23228097273032844 and parameters: {'x': 0.23228097273032844}. Best is trial 0 with value: 0.03979847340429532.
Trial 5 pruned.
Trial 6 pruned.
从上面的日志中可以看出, study 如期在 7 个trial 之后停止了.
Total running time of the script: ( 0 minutes 0.008 seconds)