Saving/Resuming a Study with an RDB Backend

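An RDB backend enables persistent experiments (i.e., saving and resuming a study) and access to the history of studies. This feature also makes distributed optimization possible, which is described in Easy Parallelization.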

In this section, we try a simple example that runs in a local environment with a SQLite DB.

Note

You can also use other RDB backends, such as PostgreSQL or MySQL, by setting the storage argument to the DB's URL. See the SQLAlchemy documentation for how to set up the URL.
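As a rough illustration of the URL format, the sketch below assembles SQLAlchemy-style URLs of the form `dialect+driver://user:password@host:port/database`. The helper name `make_storage_url` and every host, user, password, and database name are hypothetical placeholders. Percent-encoding the password with `urllib.parse.quote_plus` follows the approach the SQLAlchemy documentation suggests for passwords that contain special characters.

```python
from urllib.parse import quote_plus


def make_storage_url(dialect, user, password, host, port, database):
    """Hypothetical helper that builds a SQLAlchemy-style database URL.

    The password is percent-encoded so characters such as "@" or "/"
    are not mistaken for URL delimiters.
    """
    return "{}://{}:{}@{}:{}/{}".format(
        dialect, user, quote_plus(password), host, port, database
    )


# SQLite needs no credentials: three slashes give a path relative to the
# current working directory, four slashes give an absolute path.
sqlite_relative = "sqlite:///example-study.db"
sqlite_absolute = "sqlite:////tmp/example-study.db"

# Placeholder servers and credentials, for illustration only.
postgres_url = make_storage_url(
    "postgresql", "optuna", "p@ss/word", "localhost", 5432, "optuna_db"
)
mysql_url = make_storage_url(
    "mysql+pymysql", "optuna", "secret", "127.0.0.1", 3306, "optuna_db"
)

print(postgres_url)  # postgresql://optuna:p%40ss%2Fword@localhost:5432/optuna_db
print(mysql_url)  # mysql+pymysql://optuna:secret@127.0.0.1:3306/optuna_db
```

A string like these is what you would pass as `storage=` to `optuna.create_study()`. The `mysql+pymysql` prefix selects the PyMySQL driver in SQLAlchemy and assumes that package is installed.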

New Study

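We can create a persistent study by calling the create_study() function as follows. Creating the new study automatically initializes a SQLite file, example-study.db, named after the storage URL sqlite:///example-study.db.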

import logging
import sys

import optuna

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "example-study"  # Unique identifier of the study.
storage_name = "sqlite:///{}.db".format(study_name)
study = optuna.create_study(study_name=study_name, storage=storage_name)

Out:

A new study created in RDB with name: example-study

To run a study, call the optimize() method, passing the objective function.

def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


study.optimize(objective, n_trials=3)

Out:

Trial 0 finished with value: 26.616254945485093 and parameters: {'x': 7.1590943919921735}. Best is trial 0 with value: 26.616254945485093.
Trial 1 finished with value: 136.0228813400637 and parameters: {'x': -9.662884777792486}. Best is trial 0 with value: 26.616254945485093.
Trial 2 finished with value: 6.92296253664903 and parameters: {'x': 4.631152321065626}. Best is trial 2 with value: 6.92296253664903.

Resume Study

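To resume a study, create a Study object again, passing the study name example-study and the DB URL sqlite:///example-study.db. Setting load_if_exists=True tells create_study() to load the existing study instead of raising an error.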

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
study.optimize(objective, n_trials=3)

Out:

Using an existing study with name 'example-study' instead of creating a new one.
Trial 3 finished with value: 88.36747504461586 and parameters: {'x': -7.400397600347331}. Best is trial 2 with value: 6.92296253664903.
Trial 4 finished with value: 9.371463685595762 and parameters: {'x': 5.061284646287529}. Best is trial 2 with value: 6.92296253664903.
Trial 5 finished with value: 11.281489026771837 and parameters: {'x': 5.358792793068938}. Best is trial 2 with value: 6.92296253664903.

Experimental History

We can access the histories of studies and trials via the Study class. For example, we can get all trials of example-study as follows.

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))

Out:

Using an existing study with name 'example-study' instead of creating a new one.

The trials_dataframe() method returns a pandas DataFrame like the following:

print(df)

Out:

   number       value  params_x     state
0       0   26.616255  7.159094  COMPLETE
1       1  136.022881 -9.662885  COMPLETE
2       2    6.922963  4.631152  COMPLETE
3       3   88.367475 -7.400398  COMPLETE
4       4    9.371464  5.061285  COMPLETE
5       5   11.281489  5.358793  COMPLETE

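A Study object also provides other properties such as trials, best_value, best_params, and best_trial (see also Lightweight, versatile, and platform agnostic architecture).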

print("Best params: ", study.best_params)
print("Best value: ", study.best_value)
print("Best Trial: ", study.best_trial)
print("Trials: ", study.trials)

Out:

Best params:  {'x': 4.631152321065626}
Best value:  6.92296253664903
Best Trial:  FrozenTrial(number=2, values=[6.92296253664903], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 730964), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 746113), params={'x': 4.631152321065626}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=3, state=TrialState.COMPLETE, value=None)
Trials:  [FrozenTrial(number=0, values=[26.616254945485093], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 635850), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 655944), params={'x': 7.1590943919921735}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=1, state=TrialState.COMPLETE, value=None), FrozenTrial(number=1, values=[136.0228813400637], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 691172), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 707448), params={'x': -9.662884777792486}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=2, state=TrialState.COMPLETE, value=None), FrozenTrial(number=2, values=[6.92296253664903], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 730964), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 746113), params={'x': 4.631152321065626}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=3, state=TrialState.COMPLETE, value=None), FrozenTrial(number=3, values=[88.36747504461586], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 812319), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 833198), params={'x': -7.400397600347331}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=4, state=TrialState.COMPLETE, value=None), FrozenTrial(number=4, values=[9.371463685595762], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 862712), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 878204), params={'x': 5.061284646287529}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=5, state=TrialState.COMPLETE, value=None), FrozenTrial(number=5, values=[11.281489026771837], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 901207), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 916775), params={'x': 5.358792793068938}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=6, state=TrialState.COMPLETE, value=None)]
