Saving/Resuming a Study with an RDB Backend

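An RDB backend makes experiments persistent (that is, a study can be saved and resumed) and gives access to a study's history. The same feature can also be used for distributed optimization. See Easy Parallelization for details.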

In this section, we walk through a simple example that uses an SQLite DB in a local environment.

Note

By setting the storage URL of the DB, you can also use other RDB backends such as PostgreSQL or MySQL. See the SQLAlchemy documentation for how to write the URL.
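
As a rough illustration of this note, the sketch below assembles storage URLs in the `dialect+driver://user:password@host:port/database` form that SQLAlchemy documents. The helper name `build_storage_url`, the credentials, the host names, and the `pymysql`/`psycopg2` driver choices are assumptions made for illustration only; the matching driver package would have to be installed separately.

```python
from urllib.parse import quote_plus


def build_storage_url(dialect, user, password, host, port, database):
    """Hypothetical helper: format an SQLAlchemy-style URL for a server-based RDB."""
    # Percent-encode the credentials so characters such as "@" or "/" do not break the URL.
    return "{}://{}:{}@{}:{}/{}".format(
        dialect, quote_plus(user), quote_plus(password), host, port, database
    )


# Placeholder credentials and hosts, not real endpoints.
mysql_url = build_storage_url(
    "mysql+pymysql", "optuna", "p@ss/word", "localhost", 3306, "optuna_db"
)
postgres_url = build_storage_url(
    "postgresql+psycopg2", "optuna", "secret", "db.example.com", 5432, "optuna_db"
)

print(mysql_url)
print(postgres_url)
```

Either string could then be passed as the `storage` argument of `create_study()` in place of the SQLite URL used in the rest of this section.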

New Study

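Calling create_study() creates a persistent study. Creating the study also initializes the SQLite file automatically; with the storage URL used in the code that follows, that file is example-study.db.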

import logging
import sys

import optuna

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "example-study"  # Unique identifier of the study.
storage_name = "sqlite:///{}.db".format(study_name)
study = optuna.create_study(study_name=study_name, storage=storage_name)

Out:

A new study created in RDB with name: example-study

To run the study, pass the objective function to the optimize() method and call it.

def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


study.optimize(objective, n_trials=3)

Out:

Trial 0 finished with value: 1.5067618568079724 and parameters: {'x': 3.2275022838300433}. Best is trial 0 with value: 1.5067618568079724.
Trial 1 finished with value: 24.96165877167141 and parameters: {'x': -2.996164405988999}. Best is trial 0 with value: 1.5067618568079724.
Trial 2 finished with value: 0.45814470516718514 and parameters: {'x': 2.676863874916652}. Best is trial 2 with value: 0.45814470516718514.

Resume Study

To resume a study, obtain a Study object by calling create_study() again with the study name example-study, the DB URL sqlite:///example-study.db, and load_if_exists=True.

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
study.optimize(objective, n_trials=3)

Out:

Using an existing study with name 'example-study' instead of creating a new one.
Trial 3 finished with value: 54.55663920492726 and parameters: {'x': 9.38624662497315}. Best is trial 2 with value: 0.45814470516718514.
Trial 4 finished with value: 14.274094669669244 and parameters: {'x': -1.7781072866806262}. Best is trial 2 with value: 0.45814470516718514.
Trial 5 finished with value: 99.93271079015234 and parameters: {'x': -7.996634973337395}. Best is trial 2 with value: 0.45814470516718514.

Experimental History

The Study class gives access to the history of a study and its trials. For example, the following statements retrieve all trials of example-study.

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))

Out:

Using an existing study with name 'example-study' instead of creating a new one.

The trials_dataframe() method returns a pandas DataFrame like the following:

print(df)

Out:

   number      value  params_x     state
0       0   1.506762  3.227502  COMPLETE
1       1  24.961659 -2.996164  COMPLETE
2       2   0.458145  2.676864  COMPLETE
3       3  54.556639  9.386247  COMPLETE
4       4  14.274095 -1.778107  COMPLETE
5       5  99.932711 -7.996635  COMPLETE
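
Because the result is an ordinary pandas DataFrame, the usual pandas operations apply to it. The sketch below rebuilds the printed table by hand, with the values copied from the output above so that it runs without a database, and ranks the completed trials by objective value. The variable names are illustrative.

```python
import pandas as pd

# Values copied from the printed DataFrame above.
df = pd.DataFrame(
    {
        "number": [0, 1, 2, 3, 4, 5],
        "value": [1.506762, 24.961659, 0.458145, 54.556639, 14.274095, 99.932711],
        "params_x": [3.227502, -2.996164, 2.676864, 9.386247, -1.778107, -7.996635],
        "state": ["COMPLETE"] * 6,
    }
)

# Keep completed trials and rank them by objective value (lower is better here).
completed = df[df["state"] == "COMPLETE"]
ranked = completed.sort_values("value").reset_index(drop=True)

best_number = int(ranked.loc[0, "number"])
top_three = ranked["number"].head(3).tolist()
print(best_number, top_three)
```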

A Study object also exposes other attributes such as trials, best_value, and best_params (see Lightweight, versatile, and platform agnostic architecture).

print("Best params: ", study.best_params)
print("Best value: ", study.best_value)
print("Best Trial: ", study.best_trial)
print("Trials: ", study.trials)

Out:

Best params:  {'x': 2.676863874916652}
Best value:  0.45814470516718514
Best Trial:  FrozenTrial(number=2, values=[0.45814470516718514], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 892134), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 908612), params={'x': 2.676863874916652}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=3, state=TrialState.COMPLETE, value=None)
Trials:  [FrozenTrial(number=0, values=[1.5067618568079724], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 796078), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 816506), params={'x': 3.2275022838300433}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=1, state=TrialState.COMPLETE, value=None), FrozenTrial(number=1, values=[24.96165877167141], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 852903), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 867951), params={'x': -2.996164405988999}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=2, state=TrialState.COMPLETE, value=None), FrozenTrial(number=2, values=[0.45814470516718514], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 892134), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 908612), params={'x': 2.676863874916652}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=3, state=TrialState.COMPLETE, value=None), FrozenTrial(number=3, values=[54.55663920492726], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 976636), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 995327), params={'x': 9.38624662497315}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=4, state=TrialState.COMPLETE, value=None), FrozenTrial(number=4, values=[14.274094669669244], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 43, 26577), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 43, 42780), params={'x': -1.7781072866806262}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=5, state=TrialState.COMPLETE, value=None), FrozenTrial(number=5, 
values=[99.93271079015234], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 43, 66882), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 43, 81061), params={'x': -7.996634973337395}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=6, state=TrialState.COMPLETE, value=None)]
