Source code for optuna.integration.catalyst

from typing import Any  # NOQA

import optuna

from optuna._experimental import experimental
from optuna._imports import try_import

with try_import() as _imports:
    from catalyst.dl import Callback

if not _imports.is_successful():
    Callback = object  # NOQA


[docs]@experimental("2.0.0") class CatalystPruningCallback(Callback): """Catalyst callback to prune unpromising trials. See `the example <https://github.com/optuna/optuna/blob/master/ examples/catalyst_simple.py>`_ if you want to add a pruning callback which observes the accuracy of Catalyst's ``SupervisedRunner``. Args: trial: A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the objective function. metric (str): Name of a metric, which is passed to `catalyst.core.State.valid_metrics` dictionary to fetch the value of metric computed on validation set. Pruning decision is made based on this value. """
[docs] def __init__(self, trial, metric="loss"): # type: (optuna.trial.Trial, str) -> None # set order=1000 to run pruning callback after other callbacks # refer to `catalyst.core.CallbackOrder` _imports.check() super(CatalystPruningCallback, self).__init__(order=1000) self._trial = trial self.metric = metric
def on_epoch_end(self, state): # type: (Any) -> None current_score = state.valid_metrics[self.metric] self._trial.report(current_score, state.epoch) if self._trial.should_prune(): message = "Trial was pruned at epoch {}.".format(state.epoch) raise optuna.TrialPruned(message)