optuna.integration.PyTorchLightningPruningCallback
- class optuna.integration.PyTorchLightningPruningCallback(trial, monitor)
PyTorch Lightning callback to prune unpromising trials.
See the example if you want to add a pruning callback which observes accuracy.
- Parameters:
trial (Trial) – A Trial corresponding to the current evaluation of the objective function.
monitor (str) – An evaluation metric for pruning, e.g., val_loss or val_acc. The metrics are obtained from the dictionaries returned by, e.g., pytorch_lightning.LightningModule.training_step or pytorch_lightning.LightningModule.validation_epoch_end, so the names depend on how these dictionaries are formatted.
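The following is a minimal, stdlib-only model of how monitor is used, intended to make the naming requirement concrete. It is not Optuna's implementation. TrialPruned, ToyTrial, ToyPruningCallback, and the fixed-threshold pruning rule are hypothetical stand-ins, and a plain dict stands in for the metrics logged during validation. The sketch shows that monitor must match a logged key exactly, and that a pruning decision surfaces as an exception that ends the trial.

```python
from dataclasses import dataclass, field


class TrialPruned(Exception):
    """Hypothetical stand-in for optuna.TrialPruned."""


@dataclass
class ToyTrial:
    """Hypothetical stand-in for optuna.Trial plus a fixed-threshold pruner."""

    threshold: float  # prune once the latest reported value falls below this
    intermediate_values: dict = field(default_factory=dict)

    def report(self, value: float, step: int) -> None:
        self.intermediate_values[step] = value

    def should_prune(self) -> bool:
        # Assumed rule: compare only the most recent step with the threshold.
        latest = self.intermediate_values[max(self.intermediate_values)]
        return latest < self.threshold


class ToyPruningCallback:
    """Hypothetical model of the per-epoch check; not Optuna's implementation."""

    def __init__(self, trial: ToyTrial, monitor: str) -> None:
        self.trial = trial
        self.monitor = monitor

    def on_validation_end(self, metrics: dict, epoch: int) -> None:
        # The metric is looked up by name, so `monitor` must match the key
        # exactly as it was logged. This sketch raises on a miss to make a
        # mismatch visible; the real callback's handling may differ.
        if self.monitor not in metrics:
            raise KeyError(f"{self.monitor!r} not in logged metrics {sorted(metrics)}")
        self.trial.report(metrics[self.monitor], step=epoch)
        if self.trial.should_prune():
            raise TrialPruned(f"Trial was pruned at epoch {epoch}.")


# Usage: two epochs of logged metrics; val_acc drops below 0.5 at epoch 1.
trial = ToyTrial(threshold=0.5)
callback = ToyPruningCallback(trial, monitor="val_acc")
history = [{"val_loss": 0.9, "val_acc": 0.6}, {"val_loss": 1.1, "val_acc": 0.4}]
pruned_at = None
for epoch, metrics in enumerate(history):
    try:
        callback.on_validation_end(metrics, epoch)
    except TrialPruned:
        pruned_at = epoch
        break
print(pruned_at)  # 1
```

Passing monitor="val_accuracy" here would fail on the first epoch, because no metric was logged under that name.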
Note
For distributed data parallel (DDP) training, PyTorch Lightning v1.6.0 or later is required. In addition, the Study should be instantiated with RDB storage.
Note
If you would like to use PyTorchLightningPruningCallback in a distributed training environment, you need to call PyTorchLightningPruningCallback.check_pruned() manually so that TrialPruned is properly handled.
Methods
check_pruned() – Raise optuna.TrialPruned manually if pruned.
on_fit_start(trainer, pl_module)
on_validation_end(trainer, pl_module)
- check_pruned()
Raise optuna.TrialPruned manually if pruned.
Currently, intermediate_values are not properly propagated between processes due to the storage cache. Therefore, the necessary information is kept in trial_system_attrs when the trial runs in a distributed setting. Call this method right after calling pytorch_lightning.Trainer.fit(). If the callback does not have any backend storage for DDP, this method does nothing.
- Return type:
None
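Below is a stdlib-only model of the DDP flow described for check_pruned(). It rests on explicit assumptions. A plain dict plays the role of trial_system_attrs in shared RDB storage. The keys "intermediate_values", "pruned", and "epoch", and the ToyTrial and ToyDDPCallback classes, are hypothetical and are not Optuna's internals. record_on_rank_zero is an invented helper that stands in for what happens during fit(). After fit() returns, check_pruned() re-reports the stored values to the trial and raises if pruning was recorded. With no backend storage, it does nothing.

```python
from typing import Dict, Optional


class TrialPruned(Exception):
    """Hypothetical stand-in for optuna.TrialPruned."""


class ToyTrial:
    """Hypothetical stand-in for optuna.Trial that only records reports."""

    def __init__(self) -> None:
        self.intermediate_values: Dict[int, float] = {}

    def report(self, value: float, step: int) -> None:
        self.intermediate_values[step] = value


class ToyDDPCallback:
    """Hypothetical model of the DDP path; key names are illustrative only."""

    def __init__(self, trial: ToyTrial, shared_attrs: Optional[dict]) -> None:
        self.trial = trial
        # Plays the role of trial_system_attrs in RDB storage.
        # None models a callback without backend storage for DDP.
        self.shared_attrs = shared_attrs

    def record_on_rank_zero(self, epoch: int, value: float, prune: bool) -> None:
        # During fit(), rank 0 stores values and the pruning decision in the
        # shared attributes instead of raising inside the training processes.
        if self.shared_attrs is None:
            return
        self.shared_attrs.setdefault("intermediate_values", {})[epoch] = value
        if prune:
            self.shared_attrs["pruned"] = True
            self.shared_attrs["epoch"] = epoch

    def check_pruned(self) -> None:
        # Meant to be called right after fit() returns.
        if self.shared_attrs is None:
            return  # no backend storage: nothing to do
        for step, value in self.shared_attrs.get("intermediate_values", {}).items():
            self.trial.report(value, step=step)
        if self.shared_attrs.get("pruned"):
            raise TrialPruned(f"Trial was pruned at epoch {self.shared_attrs['epoch']}.")


# Usage: a simulated fit() in which rank 0 decides to prune at epoch 2.
shared: dict = {}
callback = ToyDDPCallback(ToyTrial(), shared)
for epoch, (value, prune) in enumerate([(0.7, False), (0.6, False), (0.3, True)]):
    callback.record_on_rank_zero(epoch, value, prune)
    if prune:
        break  # models training stopping early
try:
    callback.check_pruned()
    outcome = "completed"
except TrialPruned as exc:
    outcome = str(exc)
print(outcome)  # Trial was pruned at epoch 2.
```

With the real API, the documented order inside the objective function is pytorch_lightning.Trainer.fit() followed immediately by check_pruned() on the callback instance.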