optuna.integration
The integration module contains classes used to integrate Optuna with external machine learning frameworks.
For most of the ML frameworks supported by Optuna, the corresponding Optuna integration class serves only to implement a callback object and functions that comply with the framework’s specific callback API and are called at each intermediate step of model training. The functionality implemented in these callbacks across the different ML frameworks includes:
1. Reporting intermediate model scores back to the Optuna trial using optuna.trial.Trial.report(),
2. Pruning the current model by raising optuna.TrialPruned(), according to the result of optuna.trial.Trial.should_prune(), and
3. Reporting intermediate Optuna data, such as the current trial number, back to the framework, as done in MLflowCallback.
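The first two behaviours listed above can be sketched as a minimal, framework-agnostic callback. The class name EpochEndPruningCallback and the hook on_epoch_end(epoch, logs) are hypothetical stand-ins for whatever hook a given framework provides; they are not Optuna APIs, and each real integration class adapts this pattern to its own framework. The callback only relies on report() and should_prune() from optuna.trial.Trial, and it falls back to a local stand-in exception when Optuna is not installed so the sketch runs on its own.

```python
# Sketch of the report / should_prune / raise pattern used by pruning callbacks.
# EpochEndPruningCallback and on_epoch_end are hypothetical names, not Optuna APIs.
try:
    from optuna import TrialPruned  # the real exception when Optuna is installed
except ImportError:  # fallback so the sketch runs without Optuna

    class TrialPruned(Exception):
        """Stand-in for optuna.TrialPruned."""


class EpochEndPruningCallback:
    """Report a monitored metric after each epoch and prune when told to."""

    def __init__(self, trial, monitor):
        # Any object with report(value, step) and should_prune(),
        # e.g. the optuna.trial.Trial passed to an objective function.
        self.trial = trial
        # Key of the metric to read from the framework's logs.
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs):
        value = logs[self.monitor]
        # 1. Report the intermediate score back to the trial.
        self.trial.report(value, step=epoch)
        # 2. Ask the trial (and thus the study's pruner) whether to stop early.
        if self.trial.should_prune():
            # 3. Prune by raising the exception Optuna recognizes.
            raise TrialPruned(f"Trial was pruned at epoch {epoch}.")
```

Inside an objective function, the callback would be built from the trial that Optuna passes in, for example EpochEndPruningCallback(trial, monitor="val_loss"), and handed to the framework's training loop. When the exception propagates out of the objective, Study.optimize() records the trial as pruned rather than failed; which trials get pruned is decided by the study's pruner, such as optuna.pruners.MedianPruner.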
For scikit-learn, an integrated OptunaSearchCV estimator is available that combines scikit-learn BaseEstimator functionality with access to a class-level Study object.
AllenNLP
AllenNLP extension to use Optuna with a Jsonnet config file.
Save a JSON config file with environment variables and the best-performing hyperparameters.
AllenNLP callback to prune unpromising trials.
BoTorch
A sampler that uses BoTorch, a Bayesian optimization library built on top of PyTorch.
Quasi MC-based batch Expected Improvement (qEI).
Quasi MC-based batch Expected Hypervolume Improvement (qEHVI).
Quasi MC-based extended ParEGO (qParEGO) for constrained multi-objective optimization.
Catalyst
Catalyst callback to prune unpromising trials.
Chainer
Chainer extension to prune unpromising trials.
A wrapper of
fast.ai
FastAI callback to prune unpromising trials for fastai.
FastAI callback to prune unpromising trials for fastai.
alias of
Keras
Keras callback to prune unpromising trials.
LightGBM
Callback for LightGBM to prune unpromising trials.
Wrapper of LightGBM Training API to tune hyperparameters.
Hyperparameter tuner for LightGBM.
Hyperparameter tuner for LightGBM with cross-validation.
MLflow
Callback to track Optuna trials with MLflow.
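Tracking callbacks of this kind are passed to Study.optimize() through its callbacks argument, and Optuna calls each one with the Study and a FrozenTrial after every trial finishes. The sketch below illustrates that shape, assuming a hypothetical in-memory list in place of a real tracking backend; InMemoryTrackingCallback is an illustrative name, not part of Optuna, and it assumes a single-objective study.

```python
class InMemoryTrackingCallback:
    """Hypothetical tracking callback for Study.optimize(callbacks=[...]).

    Appends one record per finished trial to a list that stands in for an
    external tracking backend (the role MLflow plays for MLflowCallback).
    """

    def __init__(self):
        self.runs = []

    def __call__(self, study, trial):
        # Optuna passes the Study and a FrozenTrial after each trial finishes.
        self.runs.append(
            {
                "study": study.study_name,
                "trial_number": trial.number,
                "params": dict(trial.params),
                # Single-objective assumption; None for trials without a value.
                "value": trial.value,
                "state": trial.state.name,  # e.g. "COMPLETE" or "PRUNED"
            }
        )
```

With Optuna installed, it would be attached with tracker = InMemoryTrackingCallback() followed by study.optimize(objective, n_trials=20, callbacks=[tracker]); a real integration would forward each record to its service instead of appending it to a list.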
Weights & Biases
Callback to track Optuna trials with Weights & Biases.
MXNet
MXNet callback to prune unpromising trials.
pycma
A sampler using the cma library as the backend.
Wrapper class of PyCmaSampler for backward compatibility.
PyTorch
PyTorch Ignite handler to prune unpromising trials.
PyTorch Lightning callback to prune unpromising trials.
A wrapper of
scikit-learn
Hyperparameter search with cross-validation.
scikit-optimize
Sampler using Scikit-Optimize as the backend.
skorch
Skorch callback to prune unpromising trials.
TensorFlow
Callback to track Optuna trials with TensorBoard.
TensorFlow SessionRunHook to prune unpromising trials.
tf.keras callback to prune unpromising trials.
XGBoost
Callback for XGBoost to prune unpromising trials.