Ask-and-Tell Interface
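Optuna has an ask-and-tell interface, which gives you a more flexible way to run hyperparameter optimization. This tutorial shows three situations in which the ask-and-tell interface is useful.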
Apply Optuna to an existing optimization problem with minimal modifications
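Consider a traditional supervised classification problem in which you want to maximize the validation accuracy. To do so, you train a LogisticRegression as a simple model.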
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import optuna
X, y = make_classification(n_features=10)
X_train, X_test, y_train, y_test = train_test_split(X, y)
C = 0.01
clf = LogisticRegression(C=C)
clf.fit(X_train, y_train)
val_accuracy = clf.score(X_test, y_test) # the objective
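Next, you try to optimize the hyperparameters C and solver of the classifier with Optuna. When you introduce Optuna naively, you define an objective function that takes trial and calls the suggest_* methods of trial to sample the hyperparameters: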
def objective(trial):
    X, y = make_classification(n_features=10)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
    C = trial.suggest_float("C", 1e-7, 10.0, log=True)
    solver = trial.suggest_categorical("solver", ("lbfgs", "saga"))
    clf = LogisticRegression(C=C, solver=solver)
    clf.fit(X_train, y_train)
    val_accuracy = clf.score(X_test, y_test)
    return val_accuracy


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)
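This interface is not very flexible. For example, if objective requires additional arguments other than trial, you need to define a class, as in How to define objective functions that have own arguments?. The ask-and-tell interface provides a more flexible syntax for optimizing hyperparameters. The following example is equivalent to the code block above.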
study = optuna.create_study(direction="maximize")

n_trials = 10
for _ in range(n_trials):
    trial = study.ask()  # `trial` is a `Trial` and not a `FrozenTrial`.

    C = trial.suggest_float("C", 1e-7, 10.0, log=True)
    solver = trial.suggest_categorical("solver", ("lbfgs", "saga"))

    clf = LogisticRegression(C=C, solver=solver)
    clf.fit(X_train, y_train)
    val_accuracy = clf.score(X_test, y_test)

    study.tell(trial, val_accuracy)  # tell the pair of trial and objective value
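The main difference is the use of two methods: optuna.study.Study.ask() and optuna.study.Study.tell(). optuna.study.Study.ask() creates a trial that can sample hyperparameters. optuna.study.Study.tell() finishes that trial when you pass it the trial and an objective value. This lets you apply Optuna's hyperparameter optimization to your original code without defining an objective function.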
If you want to speed up the optimization with a pruner, you need to pass the state of the trial explicitly to optuna.study.Study.tell(), as follows:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

import optuna

X, y = load_iris(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
classes = np.unique(y)
n_train_iter = 100

# define study with hyperband pruner.
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.HyperbandPruner(
        min_resource=1, max_resource=n_train_iter, reduction_factor=3
    ),
)

for _ in range(20):
    trial = study.ask()

    # suggest_float replaces the deprecated suggest_uniform.
    alpha = trial.suggest_float("alpha", 0.0, 1.0)

    clf = SGDClassifier(alpha=alpha)
    pruned_trial = False

    for step in range(n_train_iter):
        clf.partial_fit(X_train, y_train, classes=classes)

        # report the intermediate score so the pruner can act on it
        intermediate_value = clf.score(X_valid, y_valid)
        trial.report(intermediate_value, step)

        if trial.should_prune():
            pruned_trial = True
            break

    if pruned_trial:
        study.tell(trial, state=optuna.trial.TrialState.PRUNED)  # tell the pruned state
    else:
        score = clf.score(X_valid, y_valid)
        study.tell(trial, score)  # tell objective value
Note
optuna.study.Study.tell() can take a trial number instead of the trial object itself. study.tell(trial.number, y) is equivalent to study.tell(trial, y).
Define-and-Run
The ask-and-tell interface supports both the define-by-run and the define-and-run APIs. The example above uses define-by-run, and this section shows the define-and-run API.
For the define-and-run API, define the hyperparameter distributions before calling the optuna.study.Study.ask() method. For example,
distributions = {
    "C": optuna.distributions.LogUniformDistribution(1e-7, 10.0),
    "solver": optuna.distributions.CategoricalDistribution(("lbfgs", "saga")),
}
Pass distributions to optuna.study.Study.ask() on every call. The returned trial already contains the suggested hyperparameters.
study = optuna.create_study(direction="maximize")
n_trials = 10
for _ in range(n_trials):
    trial = study.ask(distributions)  # pass the pre-defined distributions.

    # two hyperparameters are already sampled from the pre-defined distributions
    C = trial.params["C"]
    solver = trial.params["solver"]

    clf = LogisticRegression(C=C, solver=solver)
    clf.fit(X_train, y_train)
    val_accuracy = clf.score(X_test, y_test)

    study.tell(trial, val_accuracy)
Batch Optimization
The ask-and-tell interface lets you optimize a batched objective for faster optimization, for example through parallel evaluation or vectorized operations.
The following objective takes a batch of hyperparameter values xs instead of a single value and computes the objective over the whole vector.
def batched_objective(xs: np.ndarray):
    return xs ** 2 + 1
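In the following example, each batch contains 10 hyperparameter values, and batched_objective is evaluated three times, so there are 30 trials in total. Note that you need to keep either the trial_ids (trial numbers) or the trial objects so that you can call optuna.study.Study.tell() after the batched evaluation.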
batch_size = 10
study = optuna.create_study()

for _ in range(3):

    # create batch
    trial_ids = []
    samples = []
    for _ in range(batch_size):
        trial = study.ask()
        trial_ids.append(trial.number)
        x = trial.suggest_int("x", -10, 10)
        samples.append(x)

    # evaluate batched objective
    samples = np.array(samples)
    objectives = batched_objective(samples)

    # finish all trials in the batch
    for trial_id, objective in zip(trial_ids, objectives):
        study.tell(trial_id, objective)