optuna.integration.FastAIV2PruningCallback

class optuna.integration.FastAIV2PruningCallback(trial, monitor='valid_loss')

A fastai callback that prunes unpromising trials.

Note

This callback is for fastai>=2.0.

See the example if you want to add a pruning callback that monitors the validation loss of a Learner.
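Whether a trial is "unpromising" is decided by the study's pruner, not by the callback: the callback reports the monitored value each epoch and then asks the trial whether it should stop. The sketch below imitates that decision with a simplified median rule, loosely modeled on Optuna's MedianPruner but without its warm-up or startup-trial settings. The functions median_should_prune and run_trial, and the loss histories, are hypothetical and for illustration only.

```python
from statistics import median
from typing import Dict, List, Optional


def median_should_prune(
    step: int, current: List[float], previous_trials: List[Dict[int, float]]
) -> bool:
    """Hypothetical, simplified median rule for a metric to minimize.

    Prune when the best value the current trial has reported so far is worse
    than the median of the values earlier trials reported at the same step.
    """
    peers = [t[step] for t in previous_trials if step in t]
    if not peers:
        return False  # nothing to compare against yet
    return min(current) > median(peers)


def run_trial(
    losses: List[float], previous_trials: List[Dict[int, float]]
) -> Optional[int]:
    """Report one value per epoch; return the epoch at which the trial is
    pruned, or None if it runs to completion."""
    reported: List[float] = []
    for epoch, value in enumerate(losses):
        reported.append(value)  # plays the role of trial.report(value, epoch)
        if median_should_prune(epoch, reported, previous_trials):
            return epoch  # plays the role of stopping training early
    return None


# Hypothetical per-epoch valid_loss values of two finished trials.
history = [{0: 0.90, 1: 0.60, 2: 0.40}, {0: 0.80, 1: 0.55, 2: 0.35}]
print(run_trial([0.84, 0.70, 0.65], history))  # 1
print(run_trial([0.80, 0.50, 0.30], history))  # None
```

The first trial is stopped after epoch 1 because its best loss (0.70) is worse than the median of the earlier trials at that step (0.575); the second trial stays below the median at every epoch and runs to completion.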

Example

Register a pruning callback with learn.fit or learn.fit_one_cycle through the cbs argument.

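from fastai.vision.all import cnn_learner, error_rate, resnet18
from optuna.integration import FastAIV2PruningCallback

# `dls` (DataLoaders), `trial`, `n_epochs` and `lr_max` are assumed to be
# defined elsewhere, e.g. inside an Optuna objective function.
learn = cnn_learner(dls, resnet18, metrics=[error_rate])
learn.fit(n_epochs, cbs=[FastAIV2PruningCallback(trial)])  # Monitor "valid_loss"
learn.fit_one_cycle(
    n_epochs,
    lr_max,
    cbs=[FastAIV2PruningCallback(trial, monitor="error_rate")],  # Monitor "error_rate"
)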
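
Parameters

trial: A Trial corresponding to the current evaluation of the objective function.

monitor: The name of the evaluation metric used for pruning, e.g. valid_loss or error_rate. Defaults to valid_loss.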

Methods

after_epoch()

after_fit()
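A plausible division of work between the two hooks, sketched here without fastai or Optuna: after_epoch reports the value named by monitor and ends training early when the trial asks to be pruned, and after_fit turns that early stop into an explicit "pruned" outcome. MiniPruningCallback, FakeTrial, StopFit, TrialPruned and fit are hypothetical stand-ins, not fastai or Optuna classes, and each epoch's results are assumed to be a dict keyed by metric name so that monitor selects one entry. The real callback builds on fastai's callback machinery and may differ in detail.

```python
from typing import Dict, List


class StopFit(Exception):
    """Hypothetical signal that ends the training loop early."""


class TrialPruned(Exception):
    """Hypothetical stand-in for a 'trial was pruned' signal."""


class FakeTrial:
    """Hypothetical trial that asks to prune once the last value exceeds a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.reports: Dict[int, float] = {}

    def report(self, value: float, step: int) -> None:
        self.reports[step] = value

    def should_prune(self) -> bool:
        if not self.reports:
            return False
        return self.reports[max(self.reports)] > self.threshold


class MiniPruningCallback:
    def __init__(self, trial: FakeTrial, monitor: str = "valid_loss") -> None:
        self.trial = trial
        self.monitor = monitor
        self.epoch = 0

    def after_epoch(self, record: Dict[str, float]) -> None:
        # Report the monitored metric for this epoch, then stop if asked to.
        self.trial.report(record[self.monitor], step=self.epoch)
        if self.trial.should_prune():
            raise StopFit()

    def after_fit(self) -> None:
        # Convert the early stop into an explicit "pruned" outcome.
        if self.trial.should_prune():
            raise TrialPruned(f"Trial was pruned at epoch {self.epoch}.")


def fit(records: List[Dict[str, float]], cb: MiniPruningCallback) -> int:
    """Hypothetical training loop; returns the number of epochs run."""
    epochs_run = 0
    try:
        for epoch, record in enumerate(records):
            cb.epoch = epoch
            epochs_run += 1
            cb.after_epoch(record)
    except StopFit:
        pass  # training ends early; after_fit still runs
    cb.after_fit()
    return epochs_run
```

With records such as [{"valid_loss": 0.9, "error_rate": 0.30}, {"valid_loss": 0.7, "error_rate": 0.45}], monitoring error_rate with a threshold of 0.4 stops training after epoch 1 and raises TrialPruned, while monitoring valid_loss with a threshold of 1.0 lets every epoch run. Keeping the stop (after_epoch) separate from the pruned signal (after_fit) lets the training loop shut down cleanly before the trial is marked as pruned.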