Saving/Resuming a Study with an RDB Backend
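An RDB backend makes a study persistent: you can save and resume the study, and access its history. The same mechanism can also be used for distributed optimization; see Easy Parallelization for details.
In this section, we walk through a simple example that uses a SQLite DB running in a local environment.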
Note
You can also use other RDB backends, such as PostgreSQL or MySQL, by setting the storage argument to the DB's URL. See the SQLAlchemy documentation for how to construct the URL.
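As a rough illustration of the note above, the sketch below assembles storage URLs in the `dialect://user:password@host:port/database` form described in the SQLAlchemy documentation. The helper `build_storage_url` and every credential, host, and database name are hypothetical placeholders, not part of Optuna. The credentials are percent-encoded because characters such as `@` would otherwise make the URL ambiguous.

```python
from urllib.parse import quote_plus


def build_storage_url(dialect, user, password, host, port, database):
    """Assemble an RDB URL (hypothetical helper, not an Optuna API)."""
    # Percent-encode the credentials so characters like "@" or "/" stay unambiguous.
    return "{}://{}:{}@{}:{}/{}".format(
        dialect, quote_plus(user), quote_plus(password), host, port, database
    )


# Placeholder connection details; replace them with your own.
postgres_url = build_storage_url("postgresql", "optuna", "p@ss", "localhost", 5432, "optuna_db")
mysql_url = build_storage_url("mysql", "optuna", "p@ss", "localhost", 3306, "optuna_db")
print(postgres_url)
print(mysql_url)
```

Either string could then be passed as `storage=` to `create_study()`, provided the matching database driver is installed and the database already exists.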
New Study
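We can create a persistent study by calling the create_study() function. Creating a new study automatically initializes the SQLite file example-study.db, whose name is derived from study_name in the code below.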
import logging
import sys
import optuna
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "example-study" # Unique identifier of the study.
storage_name = "sqlite:///{}.db".format(study_name)
study = optuna.create_study(study_name=study_name, storage=storage_name)
Out:
A new study created in RDB with name: example-study
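To make the link between the storage URL and the file on disk concrete, here is a small sketch. It relies only on the SQLAlchemy convention that `sqlite:///` followed by a relative path refers to a file relative to the current working directory (a fourth slash makes the path absolute). The helper `sqlite_path_from_url` is a hypothetical name introduced for illustration.

```python
import os

SQLITE_PREFIX = "sqlite:///"


def sqlite_path_from_url(url):
    """Return the file path encoded in a SQLite storage URL (hypothetical helper)."""
    if not url.startswith(SQLITE_PREFIX):
        raise ValueError("not a SQLite file URL: {}".format(url))
    return url[len(SQLITE_PREFIX):]


study_name = "example-study"
storage_name = "sqlite:///{}.db".format(study_name)

db_path = sqlite_path_from_url(storage_name)
print(db_path)
# True only after create_study() has been run in this working directory.
print(os.path.exists(db_path))
```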
To run a study, call the optimize() method and pass it the objective function.
def objective(trial):
    # Sample x uniformly from [-10, 10].
    x = trial.suggest_float("x", -10, 10)
    # The objective is minimized at x = 2.
    return (x - 2) ** 2

study.optimize(objective, n_trials=3)
Out:
Trial 0 finished with value: 26.616254945485093 and parameters: {'x': 7.1590943919921735}. Best is trial 0 with value: 26.616254945485093.
Trial 1 finished with value: 136.0228813400637 and parameters: {'x': -9.662884777792486}. Best is trial 0 with value: 26.616254945485093.
Trial 2 finished with value: 6.92296253664903 and parameters: {'x': 4.631152321065626}. Best is trial 2 with value: 6.92296253664903.
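The objective can also be checked in isolation, without any storage backend. The sketch below uses a hypothetical stand-in class, `StubTrial`, that returns preset parameter values from `suggest_float()`; it is an assumption made for illustration, not an Optuna class. It confirms that the objective is 0 at x = 2 and reproduces the value logged for trial 2.

```python
class StubTrial:
    """Hypothetical stand-in for a trial that returns preset parameter values."""

    def __init__(self, params):
        self.params = params

    def suggest_float(self, name, low, high):
        value = self.params[name]
        # Reject values outside the declared search range.
        if not low <= value <= high:
            raise ValueError("{} is outside [{}, {}]".format(value, low, high))
        return value


def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


at_minimum = objective(StubTrial({"x": 2.0}))
# x taken from the log line for trial 2 above.
at_trial_2 = objective(StubTrial({"x": 4.631152321065626}))
print(at_minimum, at_trial_2)
```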
Resume Study
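To resume a study, instantiate a Study object by passing the study name example-study and the DB URL sqlite:///example-study.db. With load_if_exists=True, create_study() loads the existing study instead of raising an error because the name is already taken.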
study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
study.optimize(objective, n_trials=3)
Out:
Using an existing study with name 'example-study' instead of creating a new one.
Trial 3 finished with value: 88.36747504461586 and parameters: {'x': -7.400397600347331}. Best is trial 2 with value: 6.92296253664903.
Trial 4 finished with value: 9.371463685595762 and parameters: {'x': 5.061284646287529}. Best is trial 2 with value: 6.92296253664903.
Trial 5 finished with value: 11.281489026771837 and parameters: {'x': 5.358792793068938}. Best is trial 2 with value: 6.92296253664903.
Experimental History
We can access the history of a study and its trials via the Study class. For example, the following statements retrieve all trials of example-study.
study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
Out:
Using an existing study with name 'example-study' instead of creating a new one.
The trials_dataframe() method returns a pandas DataFrame like this:
print(df)
Out:
   number       value  params_x     state
0       0   26.616255  7.159094  COMPLETE
1       1  136.022881 -9.662885  COMPLETE
2       2    6.922963  4.631152  COMPLETE
3       3   88.367475 -7.400398  COMPLETE
4       4    9.371464  5.061285  COMPLETE
5       5   11.281489  5.358793  COMPLETE
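Because the result is an ordinary pandas DataFrame, the usual pandas operations apply. As a sketch, the snippet below rebuilds the table above by hand (values copied from the printed output, so they are rounded) and picks out the best trial. In practice, `df` would come straight from `study.trials_dataframe()`.

```python
import pandas as pd

# Values copied from the printed DataFrame above (rounded as displayed).
df = pd.DataFrame(
    {
        "number": [0, 1, 2, 3, 4, 5],
        "value": [26.616255, 136.022881, 6.922963, 88.367475, 9.371464, 11.281489],
        "params_x": [7.159094, -9.662885, 4.631152, -7.400398, 5.061285, 5.358793],
        "state": ["COMPLETE"] * 6,
    }
)

# Row with the smallest objective value.
best_row = df.loc[df["value"].idxmin()]
# Trial numbers ordered from best to worst objective value.
ranking = df.sort_values("value")["number"].tolist()
print(best_row["number"], ranking)
```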
A Study object also provides properties such as trials, best_value, and best_params (see Lightweight, versatile, and platform agnostic architecture).
print("Best params: ", study.best_params)
print("Best value: ", study.best_value)
print("Best Trial: ", study.best_trial)
print("Trials: ", study.trials)
Out:
Best params: {'x': 4.631152321065626}
Best value: 6.92296253664903
Best Trial: FrozenTrial(number=2, values=[6.92296253664903], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 730964), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 746113), params={'x': 4.631152321065626}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=3, state=TrialState.COMPLETE, value=None)
Trials: [FrozenTrial(number=0, values=[26.616254945485093], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 635850), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 655944), params={'x': 7.1590943919921735}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=1, state=TrialState.COMPLETE, value=None), FrozenTrial(number=1, values=[136.0228813400637], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 691172), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 707448), params={'x': -9.662884777792486}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=2, state=TrialState.COMPLETE, value=None), FrozenTrial(number=2, values=[6.92296253664903], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 730964), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 746113), params={'x': 4.631152321065626}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=3, state=TrialState.COMPLETE, value=None), FrozenTrial(number=3, values=[88.36747504461586], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 812319), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 833198), params={'x': -7.400397600347331}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=4, state=TrialState.COMPLETE, value=None), FrozenTrial(number=4, values=[9.371463685595762], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 862712), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 878204), params={'x': 5.061284646287529}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=5, state=TrialState.COMPLETE, value=None), FrozenTrial(number=5, 
values=[11.281489026771837], datetime_start=datetime.datetime(2022, 5, 26, 12, 5, 42, 901207), datetime_complete=datetime.datetime(2022, 5, 26, 12, 5, 42, 916775), params={'x': 5.358792793068938}, distributions={'x': UniformDistribution(high=10.0, low=-10.0)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=6, state=TrialState.COMPLETE, value=None)]