# Source code for optuna.visualization._param_importances

from __future__ import annotations

from typing import Callable
from typing import NamedTuple

import optuna
from optuna.distributions import BaseDistribution
from optuna.importance._base import BaseImportanceEvaluator
from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _check_plot_args
from optuna.visualization._utils import _filter_nonfinite


if _imports.is_successful():
    import plotly

    from optuna.visualization._plotly_imports import go


# Module-level logger; used to warn when a study has no usable completed trials.
logger = get_logger(__name__)


class _ImportancesInfo(NamedTuple):
    """Data needed to draw a hyperparameter-importance bar chart.

    The three list fields are parallel: index ``i`` of each one refers to the
    same parameter.
    """

    # Importance score of each parameter.
    importance_values: list[float]
    # Name of each parameter.
    param_names: list[str]
    # Text shown for each bar: the score with two decimals, or "<0.01".
    importance_labels: list[str]
    # Name of the target; it appears in the x-axis title of the plot.
    target_name: str


def _get_importances_info(
    study: Study,
    evaluator: BaseImportanceEvaluator | None,
    params: list[str] | None,
    target: Callable[[FrozenTrial], float] | None,
    target_name: str,
) -> _ImportancesInfo:
    """Collect the values, names and labels needed for the importance plot.

    ``_check_plot_args`` is called on the arguments first. If no completed trial
    is left after ``_filter_nonfinite``, a warning is logged and an info object
    with empty lists is returned without computing any importances.

    Args:
        study:
            Study whose completed trials are assessed.
        evaluator:
            Forwarded to :func:`optuna.importance.get_param_importances`.
        params:
            Forwarded to :func:`optuna.importance.get_param_importances`.
        target:
            Forwarded to ``_filter_nonfinite`` and to
            :func:`optuna.importance.get_param_importances`.
        target_name:
            Stored unchanged in the returned info object.

    Returns:
        An :class:`_ImportancesInfo` whose entries are in the reverse of the order
        returned by :func:`optuna.importance.get_param_importances`.
    """
    _check_plot_args(study, target, target_name)

    completed_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
    usable_trials = _filter_nonfinite(completed_trials, target=target)

    if len(usable_trials) == 0:
        logger.warning("Study instance does not contain completed trials.")
        return _ImportancesInfo(
            importance_values=[],
            param_names=[],
            importance_labels=[],
            target_name=target_name,
        )

    scores = optuna.importance.get_param_importances(
        study, evaluator=evaluator, params=params, target=target
    )

    names: list[str] = []
    values: list[float] = []
    labels: list[str] = []
    # Walk the result back to front.
    # NOTE(review): presumably this puts the first-ranked parameter at the top of
    # the horizontal bar chart -- confirm against the plotting code.
    for name, value in reversed(list(scores.items())):
        names.append(name)
        values.append(value)
        # Anything that is not >= 0.01 (including NaN) gets the "<0.01" label.
        if value >= 0.01:
            labels.append(f"{value:.2f}")
        else:
            labels.append("<0.01")

    return _ImportancesInfo(
        importance_values=values,
        param_names=names,
        importance_labels=labels,
        target_name=target_name,
    )


def plot_param_importances(
    study: Study,
    evaluator: BaseImportanceEvaluator | None = None,
    params: list[str] | None = None,
    *,
    target: Callable[[FrozenTrial], float] | None = None,
    target_name: str = "Objective Value",
) -> "go.Figure":
    """Plot hyperparameter importances.

    Example:

        The following code snippet shows how to plot hyperparameter importances.

        .. plotly::

            import optuna


            def objective(trial):
                x = trial.suggest_int("x", 0, 2)
                y = trial.suggest_float("y", -1.0, 1.0)
                z = trial.suggest_float("z", 0.0, 1.5)
                return x ** 2 + y ** 3 - z ** 4


            sampler = optuna.samplers.RandomSampler(seed=10)
            study = optuna.create_study(sampler=sampler)
            study.optimize(objective, n_trials=100)

            fig = optuna.visualization.plot_param_importances(study)
            fig.show()

    .. seealso::

        This function visualizes the results of :func:`optuna.importance.get_param_importances`.

    Args:
        study:
            An optimized study.
        evaluator:
            An importance evaluator object that specifies which algorithm to base the importance
            assessment on.
            Defaults to
            :class:`~optuna.importance.FanovaImportanceEvaluator`.
        params:
            A list of names of parameters to assess.
            If :obj:`None`, all parameters that are present in all of the completed trials are
            assessed.
        target:
            A function to specify the value to display. If it is :obj:`None` and ``study`` is being
            used for single-objective optimization, the objective values are plotted.

            .. note::
                Specify this argument if ``study`` is being used for multi-objective
                optimization. For example, to get the hyperparameter importance of the first
                objective, use ``target=lambda t: t.values[0]`` for the target parameter.
        target_name:
            Target's name to display on the axis label.

    Returns:
        A :class:`plotly.graph_objs.Figure` object.
    """

    _imports.check()
    importances_info = _get_importances_info(study, evaluator, params, target, target_name)
    # The hover strings are built from the same info object so they line up with the bars.
    hover_template = _get_hover_template(importances_info, study)
    return _get_importances_plot(importances_info, hover_template)


def _get_importances_plot(info: _ImportancesInfo, hover_template: list[str]) -> "go.Figure":
    """Build the horizontal bar chart of hyperparameter importances.

    Args:
        info:
            Values, names, labels and target name to draw.
        hover_template:
            One hover template string per bar, in the same order as ``info.param_names``.

    Returns:
        A figure with one horizontal bar trace, or a figure carrying only the layout when
        ``info.importance_values`` is empty.
    """
    layout = go.Layout(
        title="Hyperparameter Importances",
        xaxis={"title": f"Importance for {info.target_name}"},
        yaxis={"title": "Hyperparameter"},
        showlegend=False,
    )

    param_names = info.param_names
    importance_values = info.importance_values

    if len(importance_values) == 0:
        return go.Figure(data=[], layout=layout)

    fig = go.Figure(
        data=[
            go.Bar(
                x=importance_values,
                y=param_names,
                text=info.importance_labels,
                textposition="outside",
                cliponaxis=False,  # Ensure text is not clipped.
                hovertemplate=hover_template,
                marker_color=plotly.colors.sequential.Blues[-4],
                orientation="h",
            )
        ],
        layout=layout,
    )

    return fig


def _get_distribution(param_name: str, study: Study) -> BaseDistribution:
    """Return the distribution of ``param_name`` from the first trial that has it.

    Args:
        param_name:
            Name of the parameter to look up.
        study:
            Study whose trials are searched in order.

    Returns:
        The distribution object stored on the trial. It is not a copy, so callers must not
        mutate it.

    Raises:
        AssertionError:
            If no trial of ``study`` has a distribution for ``param_name``.
    """
    # Read-only scan: ``deepcopy=False`` avoids the full deep copy of every trial that the
    # ``study.trials`` property performs. This helper is called once per plotted parameter.
    for trial in study.get_trials(deepcopy=False):
        if param_name in trial.distributions:
            return trial.distributions[param_name]
    # Raised explicitly so the failure is not stripped under ``python -O``.
    raise AssertionError(f"No trial has a distribution for parameter {param_name!r}.")


def _make_hovertext(param_name: str, importance: float, study: Study) -> str:
    """Format the hover template for one bar.

    Args:
        param_name:
            Name of the parameter.
        importance:
            Importance score of the parameter.
        study:
            Study used to look up the class name of the parameter's distribution.

    Returns:
        A string of the form ``"<name> (<DistributionClass>): <importance><extra></extra>"``.
    """
    distribution_name = _get_distribution(param_name, study).__class__.__name__
    # Per the plotly hovertemplate reference, an empty <extra> tag hides the secondary hover box.
    return f"{param_name} ({distribution_name}): {importance}<extra></extra>"


def _get_hover_template(importances_info: _ImportancesInfo, study: Study) -> list[str]:
    """Build one hover template per parameter, in the order of ``importances_info``.

    Args:
        importances_info:
            Source of the parameter names and their importance values.
        study:
            Study forwarded to :func:`_make_hovertext`.

    Returns:
        A list with one hover template string for each parameter.
    """
    return [
        _make_hovertext(param_name, importance, study)
        for param_name, importance in zip(
            importances_info.param_names, importances_info.importance_values
        )
    ]