# Source code for optuna.integration.chainer

from typing import Tuple
from typing import Union

import optuna

# Chainer is an optional dependency: import it lazily inside ``try_import`` so
# that importing this module does not fail when Chainer is absent. The failure
# is surfaced later by ``_imports.check()`` when the extension is actually used.
with optuna._imports.try_import() as _imports:
    import chainer
    from chainer.training.extension import Extension
    from chainer.training.triggers import IntervalTrigger
    from chainer.training.triggers import ManualScheduleTrigger

if not _imports.is_successful():
    # Fallback base class so the ``ChainerPruningExtension`` class statement
    # below can still be evaluated when Chainer is not installed.
    Extension = object  # type: ignore # NOQA

class ChainerPruningExtension(Extension):
    """Chainer extension to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/tree/main/chainer/>`__
    if you want to add a pruning extension which observes validation accuracy of a
    `Chainer Trainer
    <https://docs.chainer.org/en/stable/reference/generated/chainer.training.Trainer.html>`_.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        observation_key:
            An evaluation metric for pruning, e.g., ``main/loss`` and
            ``validation/main/accuracy``. Please refer to
            `chainer.Reporter reference
            <https://docs.chainer.org/en/stable/reference/util/generated/chainer.Reporter.html>`_
            for further details.
        pruner_trigger:
            A trigger to execute pruning. ``pruner_trigger`` is an instance of
            `IntervalTrigger
            <https://docs.chainer.org/en/stable/reference/generated/chainer.training.triggers.IntervalTrigger.html>`_
            or
            `ManualScheduleTrigger
            <https://docs.chainer.org/en/stable/reference/generated/chainer.training.triggers.ManualScheduleTrigger.html>`_.
            An ``IntervalTrigger`` can also be specified by a tuple of the interval length
            and its unit like ``(1, 'epoch')``.

    Raises:
        TypeError:
            If ``pruner_trigger`` does not resolve to an ``IntervalTrigger`` or a
            ``ManualScheduleTrigger``.
    """

    def __init__(
        self,
        trial: optuna.trial.Trial,
        observation_key: str,
        pruner_trigger: Union[Tuple[int, str], "IntervalTrigger", "ManualScheduleTrigger"],
    ) -> None:
        _imports.check()

        self._trial = trial
        self._observation_key = observation_key
        # ``get_trigger`` turns an ``(interval, unit)`` tuple into an
        # ``IntervalTrigger`` and returns callables (trigger objects) unchanged,
        # which is what makes the tuple form documented above work.
        self._pruner_trigger = chainer.training.get_trigger(pruner_trigger)
        # Only these two trigger types expose the ``unit`` attribute that
        # ``__call__`` relies on to look up the current step.
        if not isinstance(self._pruner_trigger, (IntervalTrigger, ManualScheduleTrigger)):
            pruner_type = type(self._pruner_trigger)
            raise TypeError(
                "Invalid trigger class: " + str(pruner_type) + "\n"
                "Pruner trigger is supposed to be an instance of "
                "IntervalTrigger or ManualScheduleTrigger."
            )

    @staticmethod
    def _get_float_value(observation_value: Union[float, "chainer.Variable"]) -> float:
        """Convert an observed metric value into a Python ``float``.

        Args:
            observation_value:
                The raw value stored in ``trainer.observation``; either a number or a
                ``chainer.Variable``.

        Returns:
            The value cast with ``float()``.

        Raises:
            TypeError:
                If the (unwrapped) value cannot be cast to ``float``.
        """
        _imports.check()
        if isinstance(observation_value, chainer.Variable):
            # Unwrap the Variable to its underlying array before casting.
            observation_value = observation_value.data  # type: ignore
        try:
            observation_value = float(observation_value)  # type: ignore
        except TypeError:
            raise TypeError(
                "Type of observation value is not supported by ChainerPruningExtension.\n"
                "{} cannot be cast to float.".format(type(observation_value))
            ) from None
        return observation_value

    def _observation_exists(self, trainer: "chainer.training.Trainer") -> bool:
        """Return whether pruning should be evaluated for ``trainer`` right now.

        The trigger is called first; because ``and`` short-circuits, the
        observation key is only looked up when the trigger fires.
        """
        return self._pruner_trigger(trainer) and self._observation_key in trainer.observation

    def __call__(self, trainer: "chainer.training.Trainer") -> None:
        """Report the observed metric to the trial and prune if requested.

        Args:
            trainer:
                The Chainer trainer invoking this extension.

        Raises:
            optuna.TrialPruned:
                If the trial's pruner decides the trial should be pruned.
        """
        if not self._observation_exists(trainer):
            return

        current_score = self._get_float_value(trainer.observation[self._observation_key])
        # The trigger's ``unit`` names the updater attribute that holds the
        # matching step counter (e.g. ``epoch`` or ``iteration``).
        current_step = getattr(trainer.updater, self._pruner_trigger.unit)
        # Report before asking ``should_prune`` so the pruner sees this step.
        self._trial.report(current_score, step=current_step)
        if self._trial.should_prune():
            message = "Trial was pruned at {} {}.".format(self._pruner_trigger.unit, current_step)
            raise optuna.TrialPruned(message)