User-Defined Sampler
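With a user-defined sampler, you can:
experiment with your own sampling algorithms,
implement task-specific algorithms to improve optimization performance, or
wrap other optimization libraries to integrate them into Optuna pipelines (e.g., SkoptSampler).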
This section describes the internal behavior of sampler classes and shows an example of implementing a user-defined sampler.
Overview of Sampler
A sampler is responsible for determining the parameter values to be evaluated in a trial. When a suggest API (e.g., suggest_float()) is called inside an objective function, the corresponding distribution object (e.g., UniformDistribution) is created internally. The sampler samples a parameter value from that distribution. The sampled value is returned to the caller of the suggest API and evaluated in the objective function.
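As a rough illustration of this flow, the sketch below uses only the standard library. `ToyUniformDistribution`, `ToySampler`, and `toy_suggest_float` are hypothetical stand-ins invented for this example; they are not Optuna classes or functions, and Optuna's real objects carry considerably more state.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyUniformDistribution:
    """Hypothetical stand-in for a distribution object: a closed range."""

    low: float
    high: float


class ToySampler:
    """Hypothetical sampler that draws uniformly from a distribution."""

    def __init__(self, seed=None):
        self._rng = random.Random(seed)

    def sample(self, distribution):
        return self._rng.uniform(distribution.low, distribution.high)


def toy_suggest_float(sampler, low, high):
    # 1. The suggest call builds a distribution object from its arguments.
    distribution = ToyUniformDistribution(low, high)
    # 2. The sampler draws a value from that distribution.
    value = sampler.sample(distribution)
    # 3. The value goes back to the caller (the objective function).
    return value


def toy_objective(sampler):
    # The objective only sees the sampled value, not the distribution.
    x = toy_suggest_float(sampler, -10, 10)
    return x ** 2
```

The point of the separation is that the objective function only states what range a parameter may take, while the sampler alone decides which value inside that range to try.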
To create a new sampler, define a class that inherits from BaseSampler. The base class has three abstract methods: infer_relative_search_space(), sample_relative(), and sample_independent().
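As the method names suggest, Optuna supports two types of sampling: relative sampling, which can take the correlation between parameters within a trial into account, and independent sampling, which samples each parameter independently.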
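At the beginning of a trial, infer_relative_search_space() is called to provide the relative search space for that trial. Then sample_relative() is invoked to sample the relative parameters from that search space. During the execution of the objective function, sample_independent() is used to sample parameters that do not belong to the relative search space.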
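The call order described above can be modeled with a small standard-library sketch. `RecordingSampler` and `run_toy_trial` are hypothetical names used only for illustration; the real Optuna methods also receive `study` and `trial` objects and work with distribution objects rather than the `(low, high)` tuples assumed here.

```python
class RecordingSampler:
    """Hypothetical sampler that only records which hook was called."""

    def __init__(self, relative_names):
        self._relative_names = set(relative_names)
        self.calls = []

    def infer_relative_search_space(self):
        self.calls.append("infer_relative_search_space")
        # Map each relative parameter to a placeholder (low, high) range.
        return {name: (0.0, 1.0) for name in sorted(self._relative_names)}

    def sample_relative(self, search_space):
        self.calls.append("sample_relative")
        # Use the midpoint as a deterministic placeholder value.
        return {name: (low + high) / 2 for name, (low, high) in search_space.items()}

    def sample_independent(self, name):
        self.calls.append(f"sample_independent({name})")
        return 0.0


def run_toy_trial(sampler, requested_names):
    # At the start of the trial: infer the space, then sample it jointly.
    search_space = sampler.infer_relative_search_space()
    relative_params = sampler.sample_relative(search_space)

    # During the objective: each suggest call uses the relative result if
    # there is one and falls back to independent sampling otherwise.
    params = {}
    for name in requested_names:
        if name in relative_params:
            params[name] = relative_params[name]
        else:
            params[name] = sampler.sample_independent(name)
    return params


sampler = RecordingSampler(relative_names=["x"])
params = run_toy_trial(sampler, requested_names=["x", "y"])
print(sampler.calls)
```

In this toy run, `x` belongs to the relative search space and is sampled once up front, while `y` does not and is therefore sampled independently when the objective asks for it.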
Note
See the documentation of BaseSampler for further details.
An Example: Implementing SimulatedAnnealingSampler
The following code defines a sampler based on Simulated Annealing (SA):
import numpy as np
import optuna


class SimulatedAnnealingSampler(optuna.samplers.BaseSampler):
    def __init__(self, temperature=100):
        self._rng = np.random.RandomState()
        self._temperature = temperature  # Current temperature.
        self._current_trial = None  # Current state.

    def sample_relative(self, study, trial, search_space):
        if search_space == {}:
            return {}

        # Simulated Annealing algorithm.
        # 1. Calculate transition probability.
        prev_trial = study.trials[-2]
        if self._current_trial is None or prev_trial.value <= self._current_trial.value:
            probability = 1.0
        else:
            probability = np.exp(
                (self._current_trial.value - prev_trial.value) / self._temperature
            )
        self._temperature *= 0.9  # Decrease temperature.

        # 2. Transit the current state if the previous result is accepted.
        if self._rng.uniform(0, 1) < probability:
            self._current_trial = prev_trial

        # 3. Sample parameters from the neighborhood of the current point.
        # The sampled parameters will be used during the next execution of
        # the objective function passed to the study.
        params = {}
        for param_name, param_distribution in search_space.items():
            if not isinstance(param_distribution, optuna.distributions.UniformDistribution):
                raise NotImplementedError("Only suggest_float() is supported")

            current_value = self._current_trial.params[param_name]
            width = (param_distribution.high - param_distribution.low) * 0.1
            neighbor_low = max(current_value - width, param_distribution.low)
            neighbor_high = min(current_value + width, param_distribution.high)
            params[param_name] = self._rng.uniform(neighbor_low, neighbor_high)

        return params

    # The rest are unrelated to SA algorithm: boilerplate
    def infer_relative_search_space(self, study, trial):
        return optuna.samplers.intersection_search_space(study)

    def sample_independent(self, study, trial, param_name, param_distribution):
        independent_sampler = optuna.samplers.RandomSampler()
        return independent_sampler.sample_independent(study, trial, param_name, param_distribution)
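Steps 1 and 3 of sample_relative above can be isolated as two pure functions, which makes the arithmetic easier to check by hand. This is a minimal sketch using only the standard library; the function names and the `ratio` parameter are hypothetical (the sampler above hard-codes the factor 0.1), and only minimization is handled, as in the code above.

```python
import math


def transition_probability(current_value, candidate_value, temperature):
    """Probability of moving to the candidate state (minimization only)."""
    if current_value is None or candidate_value <= current_value:
        return 1.0  # Always accept the first state or a non-worse one.
    # A worse candidate is accepted with probability exp(-delta / T).
    return math.exp((current_value - candidate_value) / temperature)


def neighborhood(current, low, high, ratio=0.1):
    """Range around `current`, clipped to the parameter bounds."""
    width = (high - low) * ratio
    return max(current - width, low), min(current + width, high)


# The same worsening (1.0 -> 2.0) is accepted almost surely while the
# temperature is high and almost never once it has cooled down.
p_hot = transition_probability(1.0, 2.0, temperature=100)
p_cold = transition_probability(1.0, 2.0, temperature=0.1)

# Near the upper bound, the neighborhood is cut off at the bound.
bounds_near_edge = neighborhood(9.5, -10, 10)
```

Because the temperature is multiplied by 0.9 on every call, the sampler explores freely in early trials and becomes increasingly greedy later on.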
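Note
For the sake of code simplicity, the implementation above does not support some features (e.g., maximization). If you are interested in how to support those features, see examples/samplers/simulated_annealing.py.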
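One common way to handle maximization in simulated annealing is to flip the sign of the objective values before comparing them. The sketch below shows that idea only; `directed_transition_probability` and its `maximize` flag are hypothetical, and this is not necessarily how the referenced example file does it.

```python
import math


def directed_transition_probability(
    current_value, candidate_value, temperature, maximize=False
):
    """Acceptance probability for either optimization direction."""
    if current_value is None:
        return 1.0  # No current state yet: accept unconditionally.

    # Flip signs so that "smaller is better" holds in both directions.
    sign = -1.0 if maximize else 1.0
    current = sign * current_value
    candidate = sign * candidate_value

    if candidate <= current:
        return 1.0
    return math.exp((current - candidate) / temperature)
```

With `maximize=True`, a larger candidate value is always accepted and a smaller one is accepted only with a probability that shrinks as the temperature drops.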
You can use SimulatedAnnealingSampler in the same way as the built-in samplers:
def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    y = trial.suggest_float("y", -5, 5)
    return x ** 2 + y


sampler = SimulatedAnnealingSampler()
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=100)

best_trial = study.best_trial
print("Best value: ", best_trial.value)
print("Parameters that achieve the best value: ", best_trial.params)
Out:
Best value: -4.823887940146913
Parameters that achieve the best value: {'x': -0.16867044642607176, 'y': -4.852337659644483}
In this optimization, the values of the parameters x and y are sampled by the SimulatedAnnealingSampler.sample_relative method.
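Note
Strictly speaking, in the first trial, SimulatedAnnealingSampler.sample_independent is used to sample the parameter values, because intersection_search_space(), which is called in SimulatedAnnealingSampler.infer_relative_search_space, cannot infer the search space when there are no completed trials.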
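Why the first trial ends up with an empty relative search space can be seen in a simplified model of the intersection. `toy_intersection_search_space` is a hypothetical helper written for this illustration, with each search space assumed to be a dict of `(low, high)` tuples; it is not Optuna's implementation of intersection_search_space().

```python
def toy_intersection_search_space(completed_trial_spaces):
    """Keep parameters that appear, with the same range, in every trial.

    `completed_trial_spaces` holds one dict per completed trial, mapping
    a parameter name to its (low, high) range.
    """
    if not completed_trial_spaces:
        return {}  # Nothing to intersect before any trial has completed.

    common = dict(completed_trial_spaces[0])
    for space in completed_trial_spaces[1:]:
        common = {
            name: bounds
            for name, bounds in common.items()
            if space.get(name) == bounds
        }
    return common


# Before any trial has completed, the relative search space is empty.
first_trial_space = toy_intersection_search_space([])

# After one completed trial, its parameters form the search space.
later_space = toy_intersection_search_space(
    [{"x": (-10, 10), "y": (-5, 5)}]
)
```

With an empty search space, sample_relative in the sampler above returns an empty dict right away, so each parameter requested by the objective falls back to sample_independent.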