Source code for optuna.importance._fanova._evaluator

from collections import OrderedDict
from typing import Dict
from typing import List
from typing import Optional

import numpy

from optuna.distributions import CategoricalDistribution
from optuna.distributions import DiscreteUniformDistribution
from optuna.distributions import IntUniformDistribution
from optuna.distributions import LogUniformDistribution
from optuna.distributions import UniformDistribution
from optuna.importance._base import _get_distributions
from optuna.importance._base import _get_study_data
from optuna.importance._base import BaseImportanceEvaluator
from optuna.importance._fanova._fanova import _Fanova
from optuna.study import Study


class FanovaImportanceEvaluator(BaseImportanceEvaluator):
    """fANOVA importance evaluator.

    Implements the fANOVA hyperparameter importance evaluation algorithm in
    `An Efficient Approach for Assessing Hyperparameter Importance
    <http://proceedings.mlr.press/v32/hutter14.html>`_.

    Given a study, fANOVA fits a random forest regression model that predicts the objective value
    given a parameter configuration. The more accurate this model is, the more reliable the
    importances assessed by this class are.

    .. note::

        Requires the `sklearn <https://github.com/scikit-learn/scikit-learn>`_ Python package.

    .. note::

        Pairwise and higher order importances are not supported through this class. They can be
        computed using :class:`~optuna.importance._fanova._fanova._Fanova` directly, but this is
        not recommended as interfaces may change without prior notice.

    .. note::

        The performance of fANOVA depends on the prediction performance of the underlying
        random forest model. In order to obtain high prediction performance, it is necessary to
        cover a wide range of the hyperparameter search space. It is recommended to use an
        exploration-oriented sampler such as :class:`~optuna.samplers.RandomSampler`.

    .. note::

        For how to cite the original work, please refer to
        https://automl.github.io/fanova/cite.html.

    Args:
        n_trees:
            The number of trees in the forest.
        max_depth:
            The maximum depth of the trees in the forest.
        seed:
            Controls the randomness of the forest. For deterministic behavior, specify a value
            other than :obj:`None`.

    """
    def __init__(
        self, *, n_trees: int = 64, max_depth: int = 64, seed: Optional[int] = None
    ) -> None:
        self._evaluator = _Fanova(
            n_trees=n_trees,
            max_depth=max_depth,
            min_samples_split=2,
            min_samples_leaf=1,
            seed=seed,
        )
    def evaluate(self, study: Study, params: Optional[List[str]] = None) -> Dict[str, float]:
        distributions = _get_distributions(study, params)
        params_data, values_data = _get_study_data(study, distributions)

        if params_data.size == 0:  # `params` were given but as an empty list.
            return OrderedDict()

        # Many (deep) copies of the search spaces are required during the tree traversal and using
        # Optuna distributions will create a bottleneck.
        # Therefore, search spaces (parameter distributions) are represented by a single
        # `numpy.ndarray`, coupled with a list of flags that indicate whether they are categorical
        # or not.
        search_spaces = numpy.empty((len(distributions), 2), dtype=numpy.float64)
        search_spaces_is_categorical = []

        for i, distribution in enumerate(distributions.values()):
            if isinstance(distribution, CategoricalDistribution):
                search_spaces[i, 0] = 0
                search_spaces[i, 1] = len(distribution.choices)
                search_spaces_is_categorical.append(True)
            elif isinstance(
                distribution,
                (
                    DiscreteUniformDistribution,
                    IntUniformDistribution,
                    LogUniformDistribution,
                    UniformDistribution,
                ),
            ):
                search_spaces[i, 0] = distribution.low
                search_spaces[i, 1] = distribution.high
                search_spaces_is_categorical.append(False)
            else:
                assert False

        evaluator = self._evaluator
        evaluator.fit(
            X=params_data,
            y=values_data,
            search_spaces=search_spaces,
            search_spaces_is_categorical=search_spaces_is_categorical,
        )

        importances = {}
        for i, name in enumerate(distributions.keys()):
            importance, _ = evaluator.get_importance((i,))
            importances[name] = importance

        total_importance = sum(importances.values())
        for name in importances.keys():
            importances[name] /= total_importance

        sorted_importances = OrderedDict(
            reversed(
                sorted(
                    importances.items(),
                    key=lambda name_and_importance: name_and_importance[1],
                )
            )
        )
        return sorted_importances
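
In practice this evaluator is usually consumed through the public API
``optuna.importance.get_param_importances``, which accepts an ``evaluator`` argument. The sketch
below shows that wiring under the setup the docstring recommends (an exploration-oriented
:class:`~optuna.samplers.RandomSampler`); the two-parameter objective is hypothetical and only
there to produce a study with both a numerical and a categorical parameter.

import optuna
from optuna.importance import FanovaImportanceEvaluator


def objective(trial):
    # Hypothetical objective mixing a float parameter and a categorical one.
    x = trial.suggest_uniform("x", -10.0, 10.0)
    booster = trial.suggest_categorical("booster", ["gbtree", "dart"])
    penalty = 0.5 if booster == "dart" else 0.0
    return (x - 2) ** 2 + penalty


study = optuna.create_study(sampler=optuna.samplers.RandomSampler(seed=0))
study.optimize(objective, n_trials=100)

# As in `evaluate` above, the returned importances are normalized to sum to 1
# and ordered from most to least important.
importances = optuna.importance.get_param_importances(
    study, evaluator=FanovaImportanceEvaluator(seed=0)
)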