Quick Visualization for Hyperparameter Optimization Analysis

Optuna provides a range of visualization functions in optuna.visualization for analyzing optimization results.

This tutorial walks you through the module by visualizing the optimization history of lightgbm on the breast cancer dataset.

import lightgbm as lgb
import numpy as np
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split

import optuna
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_intermediate_values
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_slice

SEED = 42

np.random.seed(SEED)

Define the objective function.

def objective(trial):
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)
    dtrain = lgb.Dataset(train_x, label=train_y)
    dvalid = lgb.Dataset(valid_x, label=valid_y)

    param = {
        "objective": "binary",
        "metric": "auc",
        "verbosity": -1,
        "boosting_type": "gbdt",
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
        "min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
    }

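    # Add a callback that reports the validation AUC to Optuna for pruning.
    pruning_callback = optuna.integration.LightGBMPruningCallback(trial, "auc")
    # `verbose_eval` was removed in LightGBM 4.0; logging is controlled by "verbosity" in `param`.
    gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])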

    preds = gbm.predict(valid_x)
    pred_labels = np.rint(preds)
    accuracy = sklearn.metrics.accuracy_score(valid_y, pred_labels)
    return accuracy


study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
)
study.optimize(objective, n_trials=100, timeout=600)

Plot functions

Visualize the optimization history. See plot_optimization_history() for details.

plot_optimization_history(study)
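
To make the "best value" curve in this plot concrete, here is a small sketch of how such a curve can be derived from per-trial objective values. The numbers are hypothetical, not results from the study above, and this illustrates the idea rather than Optuna's internal code.

```python
from itertools import accumulate

# Hypothetical objective values of completed trials, in trial order.
values = [0.91, 0.94, 0.93, 0.97, 0.95]

# For direction="maximize", the best-so-far curve is a running maximum.
best_so_far = list(accumulate(values, max))
print(best_so_far)  # [0.91, 0.94, 0.94, 0.97, 0.97]
```

For a study created with direction="minimize", the same idea applies with `min` in place of `max`.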


Visualize the learning curves of the trials. See plot_intermediate_values() for details.

plot_intermediate_values(study)
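
Curves that stop early in this plot belong to pruned trials. The sketch below shows a simplified version of the median rule behind that decision, assuming a maximized metric. `should_prune` and the sample curves are hypothetical; Optuna's MedianPruner has further details (such as startup trials and using the trial's best intermediate value) that are left out here.

```python
from statistics import median


def should_prune(current_value, step, earlier_curves, n_warmup_steps=10):
    """Simplified median rule for a metric that is being maximized."""
    # Never prune during the warm-up steps.
    if step < n_warmup_steps:
        return False
    # Values that earlier trials reported at the same step.
    peers = [curve[step] for curve in earlier_curves if len(curve) > step]
    if not peers:
        return False
    # Prune when the current trial is below the median of its peers.
    return current_value < median(peers)


# Hypothetical intermediate values of three earlier trials (index = step).
earlier = [
    [0.80, 0.85, 0.90, 0.92],
    [0.70, 0.75, 0.80, 0.82],
    [0.75, 0.82, 0.88, 0.90],
]

print(should_prune(0.60, 1, earlier, n_warmup_steps=2))  # False (warm-up)
print(should_prune(0.81, 2, earlier, n_warmup_steps=2))  # True  (0.81 < 0.88)
print(should_prune(0.91, 3, earlier, n_warmup_steps=2))  # False (0.91 >= 0.90)
```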


Visualize high-dimensional parameter relationships. See plot_parallel_coordinate() for details.

plot_parallel_coordinate(study)


Select the parameters to visualize.

plot_parallel_coordinate(study, params=["bagging_freq", "bagging_fraction"])


Visualize hyperparameter relationships. See plot_contour() for details.

plot_contour(study)


Select the parameters to visualize.

plot_contour(study, params=["bagging_freq", "bagging_fraction"])


Visualize each hyperparameter as a slice plot. See plot_slice() for details.

plot_slice(study)


Select the parameters to visualize.

plot_slice(study, params=["bagging_freq", "bagging_fraction"])
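
A slice plot is essentially one scatter panel per parameter: sampled value on the x-axis, objective value on the y-axis. The sketch below shows how such points could be gathered from trial records. The `trials` list and the `slice_points` helper are hypothetical illustrations, not Optuna data structures.

```python
# Hypothetical trial records: sampled parameters plus the objective value.
trials = [
    {"params": {"bagging_freq": 3, "bagging_fraction": 0.62}, "value": 0.95},
    {"params": {"bagging_freq": 5, "bagging_fraction": 0.81}, "value": 0.97},
    {"params": {"bagging_freq": 1, "bagging_fraction": 0.45}, "value": 0.93},
]


def slice_points(trials, name):
    """Return (parameter value, objective value) pairs for one parameter."""
    return [(t["params"][name], t["value"]) for t in trials if name in t["params"]]


# One list of points per requested parameter, i.e. one panel each.
for name in ["bagging_freq", "bagging_fraction"]:
    print(name, slice_points(trials, name))
```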


Visualize hyperparameter importances. See plot_param_importances() for details.

plot_param_importances(study)


Visualize the empirical distribution function. See plot_edf() for details.

plot_edf(study)
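
As a reminder of what an empirical distribution function measures, the sketch below evaluates one by hand: for a threshold x, it is the fraction of trials whose objective value is at most x. The values are hypothetical and `empirical_cdf` is an illustrative helper, not an Optuna function.

```python
def empirical_cdf(values, x):
    """Fraction of values that are less than or equal to x."""
    return sum(v <= x for v in values) / len(values)


# Hypothetical objective values of completed trials.
values = [0.91, 0.94, 0.93, 0.97, 0.95]

for x in (0.90, 0.93, 0.95, 0.97):
    print(x, empirical_cdf(values, x))
# 0.9 0.0
# 0.93 0.4
# 0.95 0.8
# 0.97 1.0
```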

