# Source code for optuna.terminator.erroreval

from __future__ import annotations

import abc
from typing import cast

import numpy as np

from optuna._experimental import experimental_class
from optuna.study import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import Trial
from optuna.trial._state import TrialState


# System-attribute key under which ``report_cross_validation_scores`` stores a trial's
# cross-validation scores; ``CrossValidationErrorEvaluator`` reads them back from the same key.
_CROSS_VALIDATION_SCORES_KEY = "terminator:cv_scores"


class BaseErrorEvaluator(metaclass=abc.ABCMeta):
    """Abstract interface for error evaluators.

    An error evaluator turns a collection of trials into a single float error estimate.
    Concrete subclasses must override :meth:`evaluate`.
    """

    @abc.abstractmethod
    def evaluate(
        self,
        trials: list[FrozenTrial],
        study_direction: StudyDirection,
    ) -> float:
        """Compute an error estimate from ``trials``.

        Args:
            trials: The trials the estimate is based on.
            study_direction: The optimization direction of the study.

        Returns:
            The estimated error.
        """
@experimental_class("3.2.0")
class CrossValidationErrorEvaluator(BaseErrorEvaluator):
    """An error evaluator for objective functions based on cross-validation.

    This evaluator evaluates the objective function's statistical error, which comes from the
    randomness of the dataset. This evaluator assumes that the objective function is the average
    of the cross-validation scores, and uses the scaled variance of the cross-validation scores
    of the best trial at the moment as the statistical error.

    """

    def evaluate(
        self,
        trials: list[FrozenTrial],
        study_direction: StudyDirection,
    ) -> float:
        """Evaluate the statistical error of the objective function based on cross-validation.

        Args:
            trials:
                A list of trials to consider. Only trials in the ``COMPLETE`` state are used,
                and the best of them is used to compute the statistical error.
            study_direction:
                The direction of the study. It decides whether the best trial is the one with
                the largest (``MAXIMIZE``) or the smallest (otherwise) value.

        Returns:
            A float representing the statistical error of the objective function.

        Raises:
            ValueError:
                If ``trials`` contains no completed trial, or if cross-validation scores have
                not been reported for the best trial.
        """
        completed_trials = [trial for trial in trials if trial.state == TrialState.COMPLETE]
        # Validate explicitly instead of with ``assert``: assertions are stripped under
        # ``python -O``, after which ``max``/``min`` would fail with an unhelpful message.
        if len(completed_trials) == 0:
            raise ValueError(
                "No completed trials were found in `trials`. At least one trial in the "
                "COMPLETE state is required to evaluate the cross-validation error."
            )

        if study_direction == StudyDirection.MAXIMIZE:
            best_trial = max(completed_trials, key=lambda t: cast(float, t.value))
        else:
            best_trial = min(completed_trials, key=lambda t: cast(float, t.value))

        best_trial_attrs = best_trial.system_attrs
        if _CROSS_VALIDATION_SCORES_KEY not in best_trial_attrs:
            raise ValueError(
                "Cross-validation scores have not been reported. Please call "
                "`report_cross_validation_scores(trial, scores)` during a trial and pass the "
                "list of scores as `scores`."
            )
        cv_scores = best_trial_attrs[_CROSS_VALIDATION_SCORES_KEY]

        k = len(cv_scores)
        # Internal invariant: the reporting function rejects fewer than two scores. It also
        # keeps ``k - 1`` below from being zero.
        assert k > 1, "Should be guaranteed by `report_cross_validation_scores`."
        # NOTE(review): the 1/k + 1/(k-1) factor looks like the Nadeau-Bengio corrected variance
        # (1/k + n_test/n_train, where n_test/n_train = 1/(k-1) for k-fold CV) - confirm. Note
        # that ``np.var`` uses its default ``ddof=0`` (population variance) here.
        scale = 1 / k + 1 / (k - 1)

        var = scale * np.var(cv_scores)
        std = np.sqrt(var)

        return float(std)
@experimental_class("3.2.0")
def report_cross_validation_scores(trial: Trial, scores: list[float]) -> None:
    """Record the cross-validation scores of a running trial.

    Call this from inside the objective function. The stored scores are later read back to
    estimate the statistical error used for the termination judgement.

    Args:
        trial:
            The :class:`~optuna.trial.Trial` whose cross-validation scores are being reported.
        scores:
            The per-fold cross-validation scores of ``trial``. At least two are required.

    Raises:
        ValueError: If ``scores`` contains fewer than two elements.
    """
    # Guard clause: a variance-based error estimate needs at least two folds.
    if len(scores) < 2:
        raise ValueError("The length of `scores` is expected to be greater than one.")

    storage = trial.storage
    storage.set_trial_system_attr(trial._trial_id, _CROSS_VALIDATION_SCORES_KEY, scores)
@experimental_class("3.2.0")
class StaticErrorEvaluator(BaseErrorEvaluator):
    """An error evaluator whose estimate never changes.

    Because the returned error is fixed, this evaluator lets the optimization be terminated
    once the evaluated improvement potential falls below a user-chosen threshold.

    Args:
        constant:
            The value returned as the error estimate on every call to :meth:`evaluate`.
    """

    def __init__(self, constant: float) -> None:
        self._constant = constant

    def evaluate(
        self,
        trials: list[FrozenTrial],
        study_direction: StudyDirection,
    ) -> float:
        """Return the constant supplied at construction, regardless of the arguments."""
        del trials, study_direction  # Unused: the estimate does not depend on the study.
        return self._constant