"""Source code for ``optuna.integration.fastai``: an Optuna pruning callback for fastai<2.0."""

import optuna
from optuna._imports import try_import
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA

with try_import() as _imports:
    from fastai.basic_train import Learner  # NOQA
    from fastai.callbacks import TrackerCallback

# fastai is an optional dependency. If importing it failed, fall back to ``object``
# as the base class so that the ``FastAIPruningCallback`` class definition below
# still succeeds and this module remains importable without fastai installed.
if not _imports.is_successful():
    TrackerCallback = object  # NOQA

class FastAIPruningCallback(TrackerCallback):
    """FastAI callback to prune unpromising trials for fastai.

    .. note::
        This callback is for fastai<2.0, not the coming version developed in
        fastai/fastai_dev.

    See the fastai example in Optuna's ``examples/`` directory if you want to add a
    pruning callback which monitors validation loss of a ``Learner``.

    Example:

        Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``.

        .. code::

            learn.fit(n_epochs, callbacks=[FastAIPruningCallback(learn, trial, 'valid_loss')])
            learn.fit_one_cycle(
                n_epochs, cyc_len, max_lr,
                callbacks=[FastAIPruningCallback(learn, trial, 'valid_loss')])

    Args:
        learn:
            A ``fastai.basic_train.Learner``.
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g. ``valid_loss`` and ``Accuracy``.
            Please refer to the fastai ``Callback`` reference for further details.
    """

    def __init__(self, learn, trial, monitor):
        # type: (Learner, optuna.trial.Trial, str) -> None

        # Verify the fastai import before touching the base class. Without fastai,
        # ``TrackerCallback`` is ``object``, and ``object.__init__(learn, monitor)``
        # would fail with an unrelated TypeError that hides the missing dependency.
        _imports.check()

        super(FastAIPruningCallback, self).__init__(learn, monitor)

        self._trial = trial

    def on_epoch_end(self, epoch, **kwargs):
        # type: (int, Any) -> None
        """Report the monitored value for ``epoch`` and prune the trial if requested.

        Args:
            epoch:
                Index of the epoch that just finished; used as the report step.
            kwargs:
                Extra callback state passed by fastai; unused here.

        Raises:
            optuna.TrialPruned:
                If the trial should be pruned after reporting this epoch's value.
        """

        value = self.get_monitor_value()
        if value is None:
            # Nothing to report for this epoch; skip reporting and the pruning check.
            return

        # This conversion is necessary to avoid problems reported in issues.
        # NOTE(review): presumably the monitored value may be a tensor/NumPy scalar
        # rather than a Python float; casting keeps ``Trial.report`` happy — confirm.
        self._trial.report(float(value), step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)