Source code for optuna.visualization._slice

from typing import List
from typing import Optional

from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _is_log_scale

if _imports.is_successful():
    from optuna.visualization._plotly_imports import go
    from optuna.visualization._plotly_imports import make_subplots
    from optuna.visualization._plotly_imports import Scatter

_logger = get_logger(__name__)


def plot_slice(study: Study, params: Optional[List[str]] = None) -> "go.Figure":
    """Plot the parameter relationship as slice plot in a study.

    Each selected parameter is drawn as a scatter of parameter value against
    objective value. Note that if a trial is missing a value for a parameter,
    that trial is not plotted for that parameter.

    Example:

        The following code snippet shows how to plot the parameter relationship as slice plot.

        .. testcode::

            import optuna

            def objective(trial):
                x = trial.suggest_uniform('x', -100, 100)
                y = trial.suggest_categorical('y', [-1, 0, 1])
                return x ** 2 + y

            study = optuna.create_study()
            study.optimize(objective, n_trials=10)

            optuna.visualization.plot_slice(study, params=['x', 'y'])

        .. raw:: html

            <iframe src="../../_static/plot_slice.html"
             width="100%" height="500px" frameborder="0">
            </iframe>

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their objective
            values.
        params:
            Parameter list to visualize. The default is all parameters.

    Returns:
        A :class:`plotly.graph_objs.Figure` object.
    """

    # The plotly names are imported at module level only when the import succeeded;
    # `check()` presumably raises here otherwise (confirm in `_plotly_imports`).
    _imports.check()
    figure = _get_slice_plot(study, params)
    return figure
def _get_slice_plot(study: Study, params: Optional[List[str]] = None) -> "go.Figure":
    """Build the slice-plot figure from the completed trials of ``study``.

    Args:
        study:
            Study whose trials are plotted. Only trials in the ``COMPLETE`` state are used.
        params:
            Names of the parameters to plot. Duplicates are ignored and the names are
            plotted in sorted order. ``None`` selects every parameter that appears in at
            least one completed trial.

    Returns:
        A figure with one scatter trace per selected parameter. An empty figure that only
        carries the title is returned when there are no completed trials or when there is
        no parameter to plot.

    Raises:
        ValueError:
            If a name in ``params`` does not appear in any completed trial.
    """

    layout = go.Layout(title="Slice Plot",)

    trials = [trial for trial in study.trials if trial.state == TrialState.COMPLETE]

    if not trials:
        _logger.warning("Your study does not have any completed trials.")
        return go.Figure(data=[], layout=layout)

    all_params = {p_name for t in trials for p_name in t.params}
    if params is None:
        sorted_params = sorted(all_params)
    else:
        for input_p_name in params:
            if input_p_name not in all_params:
                raise ValueError("Parameter {} does not exist in your study.".format(input_p_name))
        sorted_params = sorted(set(params))

    n_params = len(sorted_params)

    if n_params == 0:
        # Completed trials without any parameter (or an empty `params` list) leave nothing
        # to draw. Return the empty titled figure instead of asking `make_subplots` for
        # zero columns, which it rejects.
        return go.Figure(data=[], layout=layout)

    if n_params == 1:
        param = sorted_params[0]
        figure = go.Figure(data=[_generate_slice_subplot(study, trials, param)], layout=layout)
        figure.update_xaxes(title_text=param)
        figure.update_yaxes(title_text="Objective Value")
        if _is_log_scale(trials, param):
            figure.update_xaxes(type="log")
    else:
        figure = make_subplots(rows=1, cols=n_params, shared_yaxes=True)
        figure.update_layout(layout)
        for i, param in enumerate(sorted_params):
            col = i + 1
            trace = _generate_slice_subplot(study, trials, param)
            # The colorbar only needs to be shown once, so enable it for the first trace
            # only (showscale's default is True).
            trace.update(marker={"showscale": i == 0})
            figure.add_trace(trace, row=1, col=col)
            figure.update_xaxes(title_text=param, row=1, col=col)
            if i == 0:
                # The y-axis is shared, so only the leftmost subplot gets a title.
                figure.update_yaxes(title_text="Objective Value", row=1, col=1)
            if _is_log_scale(trials, param):
                figure.update_xaxes(type="log", row=1, col=col)
        if n_params > 3:
            # Ensure that each subplot has a minimum width without relying on autosizing.
            figure.update_layout(width=300 * n_params)

    return figure


def _generate_slice_subplot(study: Study, trials: List[FrozenTrial], param: str) -> "Scatter":
    """Create the scatter trace of objective value against one parameter.

    Args:
        study:
            Unused; kept so the signature stays compatible with existing callers.
        trials:
            Trials to plot. Trials that do not contain ``param`` are skipped.
        param:
            Name of the parameter plotted on the x-axis.

    Returns:
        A marker-only scatter trace whose points are colored by trial number.
    """

    # Filter once so the x, y and color sequences are guaranteed to stay aligned.
    plotted = [t for t in trials if param in t.params]

    return go.Scatter(
        x=[t.params[param] for t in plotted],
        y=[t.value for t in plotted],
        mode="markers",
        marker={
            "line": {"width": 0.5, "color": "Grey",},
            "color": [t.number for t in plotted],
            "colorscale": "Blues",
            "colorbar": {
                "title": "#Trials",
                "x": 1.0,  # Offset the colorbar position with a fixed width `xpad`.
                "xpad": 40,
            },
        },
        showlegend=False,
    )