optuna.integration
The integration module contains classes used to integrate Optuna with external machine learning frameworks.
Note
Optuna’s integration modules for third-party libraries have started migrating from Optuna itself to a separate package called optuna-integration. Please check the optuna-integration repository and its documentation.
For most of the ML frameworks supported by Optuna, the corresponding integration class only implements a callback object and functions that comply with the framework’s callback API and are called at each intermediate step of model training. Across the different ML frameworks, these callbacks provide the following functionality:
- Reporting intermediate model scores back to the Optuna trial using optuna.trial.Trial.report();
- Pruning the current model by raising optuna.TrialPruned(), according to the result of optuna.trial.Trial.should_prune(); and
- Reporting intermediate Optuna data, such as the current trial number, back to the framework, as done in MLflowCallback.
For scikit-learn, an integrated OptunaSearchCV estimator is available that combines scikit-learn BaseEstimator functionality with access to a class-level Study object.
Dependencies of each integration
We summarize the necessary dependencies for each integration.
Integration | Dependencies |
---|---|
AllenNLP | allennlp, torch, psutil, jsonnet |
BoTorch | botorch, gpytorch, torch |
CatBoost | catboost |
ChainerMN | chainermn |
Chainer | chainer |
pycma | cma |
Dask | distributed |
FastAI | fastai |
Keras | keras |
LightGBM Tuner | lightgbm, scikit-learn |
LightGBM | lightgbm |
MLflow | mlflow |
MXNet | mxnet |
PyTorch Distributed | torch |
PyTorch (Ignite) | pytorch-ignite |
PyTorch (Lightning) | pytorch-lightning |
SHAP | scikit-learn, shap |
Scikit-learn | pandas, scipy, scikit-learn |
skorch | skorch |
TensorBoard | tensorboard, tensorflow |
TensorFlow | tensorflow, tensorflow-estimator |
TensorFlow + Keras | tensorflow |
Weights & Biases | wandb |
XGBoost | xgboost |
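The dependencies above are pip distribution names, so it can help to check which of them are installed before importing an integration. The following is a minimal stdlib-only sketch, assuming Python 3.10 or later (where importlib.metadata normalizes distribution names); missing_dependencies is a hypothetical helper, not part of Optuna.

```python
from importlib.metadata import PackageNotFoundError, version


def missing_dependencies(dist_names):
    """Hypothetical helper: return the distributions in dist_names that are not installed."""
    missing = []
    for dist in dist_names:
        try:
            version(dist)  # raises PackageNotFoundError if the distribution is absent
        except PackageNotFoundError:
            missing.append(dist)
    return missing


# Example: the dependencies listed for the LightGBM-based tuner.
print(missing_dependencies(["lightgbm", "scikit-learn"]))
```

An empty list means every listed dependency is installed; otherwise the returned names can be passed straight to pip.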