Efficient Optimization Algorithms

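Optuna enables efficient hyperparameter optimization by adopting state-of-the-art algorithms for sampling hyperparameters and for pruning unpromising trials.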

Sampling Algorithms

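Using the records of suggested parameter values and evaluated objective values, samplers continually narrow down the search space until they find an optimal search space, one whose parameters lead to better objective values. A more detailed explanation of how samplers suggest parameters is given in optuna.samplers.BaseSampler.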
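The idea of narrowing the search space from a history of (parameter, objective value) records can be sketched in plain Python. This is a toy heuristic written for illustration only: the names `toy_sampler`, `shrink`, and `objective` are assumptions, and the shrinking-window rule is not the algorithm used by optuna.samplers.TPESampler or any other Optuna sampler.

```python
import random


def toy_sampler(history, low, high, rng, shrink=0.8):
    """Suggest a value, narrowing the range around the best value seen so far.

    NOTE: hypothetical toy heuristic; this is NOT how Optuna's samplers work.
    """
    if not history:
        # No records yet: sample uniformly from the whole search space.
        return rng.uniform(low, high)
    best_x, _ = min(history, key=lambda record: record[1])
    # Shrink the sampling window geometrically as records accumulate.
    width = (high - low) * shrink ** len(history)
    window_low = max(low, best_x - width / 2)
    window_high = min(high, best_x + width / 2)
    return rng.uniform(window_low, window_high)


def objective(x):
    # Toy objective to minimize; its minimum is at x = 2.
    return (x - 2.0) ** 2


rng = random.Random(0)
history = []  # records of (suggested value, objective value)
for _ in range(20):
    x = toy_sampler(history, -10.0, 10.0, rng)
    history.append((x, objective(x)))

best_x, best_value = min(history, key=lambda record: record[1])
print(f"best x = {best_x:.3f}, best value = {best_value:.5f}")
```

The point of the sketch is the feedback loop: each suggestion depends on the recorded history, and the region being sampled contracts around the most promising value found so far.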

Optuna provides several sampling algorithms. The default sampler is optuna.samplers.TPESampler.

Switching Samplers

import optuna

By default, Optuna uses TPESampler as follows.

study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")

Out:

Sampler is TPESampler

If you want to use a different sampler, for example RandomSampler or CmaEsSampler, pass it to create_study():

study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")

study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")

Out:

Sampler is RandomSampler
Sampler is CmaEsSampler

Pruning Algorithms

Pruners automatically stop unpromising trials at the early stages of training (in other words, automated early stopping).

Optuna provides several pruning algorithms.

Most examples use optuna.pruners.MedianPruner, although it is generally outperformed by optuna.pruners.SuccessiveHalvingPruner and optuna.pruners.HyperbandPruner, as shown in this benchmark result.
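To give a feel for the idea behind successive halving, here is a minimal synchronous, textbook-style sketch in plain Python. It is an illustration under stated assumptions, not Optuna's implementation: optuna.pruners.SuccessiveHalvingPruner works asynchronously on running trials, and the names `successive_halving`, `toy_loss`, `min_budget`, and `eta` are hypothetical.

```python
import random


def successive_halving(configs, evaluate, min_budget=1, eta=3):
    """Keep the best 1/eta of the configurations at each rung.

    `evaluate(config, budget)` returns a loss; lower is better.
    """
    survivors = list(configs)
    budget = min_budget
    while len(survivors) > 1:
        # Rank the survivors using the current (small) budget.
        ranked = sorted(survivors, key=lambda config: evaluate(config, budget))
        # Promote only the top 1/eta, then give them eta times more budget.
        survivors = ranked[: max(1, len(ranked) // eta)]
        budget *= eta
    return survivors[0]


def toy_loss(alpha, budget):
    # Hypothetical loss: approaches (alpha - 0.3)^2 as the budget grows.
    return (alpha - 0.3) ** 2 + 1.0 / budget


rng = random.Random(0)
candidates = [rng.uniform(0.0, 1.0) for _ in range(27)]
best = successive_halving(candidates, toy_loss)
print(f"selected alpha = {best:.3f}")
```

With 27 candidates and `eta=3`, the rungs hold 27, 9, 3, and finally 1 configuration, so most of the budget goes to the few candidates that looked good early.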

Activating Pruners

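To turn on the pruning feature, call report() and should_prune() after each step of the iterative training. report() periodically records the intermediate values of the objective function. should_prune() decides whether to terminate a trial that does not meet a predefined condition.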

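We recommend using the integration modules for the major machine learning frameworks. The full list of modules is in optuna.integration, and usage examples are in optuna/examples.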

import logging
import sys

import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection


def objective(trial):
    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25, random_state=0
    )

    alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(valid_x, valid_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

    return 1.0 - clf.score(valid_x, valid_y)

Set the median stopping rule as the pruning condition.

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)

Out:

A new study created in memory with name: no-name-1891047c-93d9-4f83-8cb8-145a51e6203f
Trial 0 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.012161563144007067}. Best is trial 0 with value: 0.052631578947368474.
Trial 1 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.02145981426399537}. Best is trial 0 with value: 0.052631578947368474.
Trial 2 finished with value: 0.42105263157894735 and parameters: {'alpha': 1.7345323169766322e-05}. Best is trial 0 with value: 0.052631578947368474.
Trial 3 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.00027693947091443974}. Best is trial 0 with value: 0.052631578947368474.
Trial 4 finished with value: 0.07894736842105265 and parameters: {'alpha': 0.019560045732631263}. Best is trial 0 with value: 0.052631578947368474.
Trial 5 pruned.
Trial 6 pruned.
Trial 7 pruned.
Trial 8 pruned.
Trial 9 pruned.
Trial 10 pruned.
Trial 11 finished with value: 0.3421052631578947 and parameters: {'alpha': 0.0901342111716878}. Best is trial 0 with value: 0.052631578947368474.
Trial 12 pruned.
Trial 13 pruned.
Trial 14 pruned.
Trial 15 pruned.
Trial 16 finished with value: 0.21052631578947367 and parameters: {'alpha': 0.04042583142588879}. Best is trial 0 with value: 0.052631578947368474.
Trial 17 finished with value: 0.23684210526315785 and parameters: {'alpha': 0.0065084109366433385}. Best is trial 0 with value: 0.052631578947368474.
Trial 18 pruned.
Trial 19 pruned.

As you can see, several trials were pruned (stopped) before they finished all of their iterations. The message format is "Trial <Trial Number> pruned.".
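The log also shows that trials 0 through 4 all ran to completion and pruning only began with trial 5. That is consistent with a median rule that waits for a few finished trials before it starts comparing. The following plain-Python sketch is a simplified reading of that rule, not Optuna's code: the function name `should_prune_median` and its arguments are hypothetical, and details such as warm-up steps are ignored.

```python
import statistics


def should_prune_median(current_values, completed_trials, n_startup_trials=5):
    """Return True if the running trial looks worse than the median.

    current_values: intermediate values reported so far by the running trial.
    completed_trials: one list of intermediate values per finished trial.
    Lower values are better (minimization).
    """
    if len(completed_trials) < n_startup_trials:
        return False  # Not enough history to compare against yet.
    step = len(current_values) - 1
    # Intermediate values that finished trials reported at the same step.
    peers = [values[step] for values in completed_trials if len(values) > step]
    if not peers:
        return False
    # Compare the trial's best value so far with the median of its peers.
    return min(current_values) > statistics.median(peers)


completed = [
    [0.40, 0.20, 0.10],
    [0.50, 0.30, 0.20],
    [0.30, 0.20, 0.10],
    [0.60, 0.40, 0.30],
    [0.50, 0.25, 0.15],
]
print(should_prune_median([0.50, 0.45], completed))  # True: 0.45 > median 0.25
print(should_prune_median([0.30, 0.20], completed))  # False: 0.20 <= median 0.25
```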

Which Pruner Should Be Used?

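For non-deep-learning tasks, our recommendations are based on the benchmark results in the optuna/optuna wiki page "Benchmarks with Kurobako".

Note, however, that this benchmark does not cover deep learning. For deep learning tasks, consult the following table from Ozaki et al., Hyperparameter Optimization Methods: Overview and Characteristics, in IEICE Trans, Vol.J103-D No.9 pp.615-631, 2020.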

Parallel Compute Resource | Categorical/Conditional Hyperparameters | Recommended Algorithms
Limited | No | TPE. GP-EI if search space is low-dimensional and continuous.
Limited | Yes | TPE. GP-EI if search space is low-dimensional and continuous.
Sufficient | No | CMA-ES, Random Search
Sufficient | Yes | Random Search or Genetic Algorithm
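For readers who prefer code to a table, the same recommendations can be written as a small lookup. This is only a convenience sketch: the names `RECOMMENDATIONS` and `recommend_algorithm` are hypothetical, and the strings are copied from the table without any added interpretation.

```python
# Keys: (parallel compute resource, has categorical/conditional hyperparameters)
RECOMMENDATIONS = {
    ("limited", False): "TPE. GP-EI if search space is low-dimensional and continuous.",
    ("limited", True): "TPE. GP-EI if search space is low-dimensional and continuous.",
    ("sufficient", False): "CMA-ES, Random Search",
    ("sufficient", True): "Random Search or Genetic Algorithm",
}


def recommend_algorithm(parallel_resource, has_categorical_or_conditional):
    """Look up the table row for the given situation."""
    key = (parallel_resource.lower(), bool(has_categorical_or_conditional))
    if key not in RECOMMENDATIONS:
        raise ValueError(f"unknown parallel compute resource: {parallel_resource!r}")
    return RECOMMENDATIONS[key]


print(recommend_algorithm("Sufficient", True))
```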

Integration Modules for Pruning

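To implement pruning in the simplest possible form, Optuna provides integration modules for a number of libraries.

For the complete list of Optuna's integration modules, see optuna.integration.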

For example, XGBoostPruningCallback introduces pruning without directly changing the logic of training iteration. (See also example for the entire script.)

pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-error')
bst = xgb.train(param, dtrain, evals=[(dvalid, 'validation')], callbacks=[pruning_callback])
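The general pattern behind such a callback can be sketched without XGBoost or Optuna installed. Everything below is a hypothetical stand-in (`ToyTrial`, `ToyPruningCallback`, `TrialPruned`, `train`) meant to show how a callback can report a metric and stop training while the training loop itself stays unchanged. It makes no claim about the internals of XGBoostPruningCallback.

```python
class TrialPruned(Exception):
    """Hypothetical stand-in for optuna.TrialPruned."""


class ToyTrial:
    """Hypothetical trial: prunes once the last reported value exceeds a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.reported = {}
        self._last = None

    def report(self, value, step):
        self.reported[step] = value
        self._last = value

    def should_prune(self):
        return self._last is not None and self._last > self.threshold


class ToyPruningCallback:
    """Reports one named metric after each round and raises if pruning is due."""

    def __init__(self, trial, metric_name):
        self.trial = trial
        self.metric_name = metric_name

    def __call__(self, step, metrics):
        self.trial.report(metrics[self.metric_name], step)
        if self.trial.should_prune():
            raise TrialPruned(f"pruned at step {step}")


def train(n_rounds, error_at, callbacks):
    """Toy training loop: it only knows about callbacks, not about pruning."""
    metrics = {}
    for step in range(n_rounds):
        metrics = {"validation-error": error_at(step)}
        for callback in callbacks:
            callback(step, metrics)
    return metrics


trial = ToyTrial(threshold=0.35)
callback = ToyPruningCallback(trial, "validation-error")
try:
    # A diverging run: the error grows by 0.1 every round.
    train(10, lambda step: 0.1 * (step + 1), [callback])
except TrialPruned as exc:
    print(exc)
```

The design choice worth noting is the separation of concerns: the pruning decision lives entirely in the callback, so the training function needs no pruning-specific code.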
