# Source code for optuna.visualization._timeline

from __future__ import annotations

import datetime
from typing import NamedTuple

from optuna._experimental import experimental_func
from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _make_hovertext


# ``go`` is imported only when ``_imports`` reports success, so importing this
# module does not require the plotting backend. ``plot_timeline`` calls
# ``_imports.check()`` before any code path that touches ``go``.
# NOTE(review): presumably ``go`` is ``plotly.graph_objects`` -- confirm in
# ``optuna.visualization._plotly_imports``.
if _imports.is_successful():
    from optuna.visualization._plotly_imports import go

# Module-level logger; used for the warnings emitted by ``_get_timeline_info``.
_logger = get_logger(__name__)


class _TimelineBarInfo(NamedTuple):
    """Data needed to draw one trial as a single horizontal bar in the timeline."""

    # Trial number; used as the bar's y coordinate by ``_get_timeline_plot``.
    number: int
    # Left edge of the bar. ``_get_timeline_info`` falls back to ``complete``
    # when the trial has no recorded start time.
    start: datetime.datetime
    # Right edge of the bar. ``_get_timeline_info`` substitutes the current time
    # when the trial has no recorded completion time.
    complete: datetime.datetime
    # Trial state; bars are grouped into one trace (and one color) per state.
    state: TrialState
    # Text produced by ``_make_hovertext`` and shown when hovering over the bar.
    hovertext: str

class _TimelineInfo(NamedTuple):
    """Plot-independent description of a timeline, consumed by ``_get_timeline_plot``."""

    # One bar per trial, in the order returned by ``study.get_trials`` inside
    # ``_get_timeline_info``.
    bars: list[_TimelineBarInfo]

@experimental_func("3.2.0")
def plot_timeline(study: Study) -> "go.Figure":
    """Plot the timeline of a study.

    Every trial is drawn as a horizontal bar that spans from its start time to
    its completion time, grouped and colored by trial state.

    Example:

        The following code snippet shows how to plot the timeline of a study.
        Timeline plot can visualize trials with overlapping execution time
        (e.g., in distributed environments).

        .. plotly::

            import time

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", 0, 1)
                time.sleep(x * 0.1)
                if x > 0.8:
                    raise ValueError()
                if x > 0.4:
                    raise optuna.TrialPruned()
                return x ** 2


            study = optuna.create_study(direction="minimize")
            study.optimize(
                objective, n_trials=50, n_jobs=2, catch=(ValueError,)
            )

            fig = optuna.visualization.plot_timeline(study)
            fig.show()

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted with
            their lifetime.

    Returns:
        A :class:`plotly.graph_objs.Figure` object.
    """
    # Validate the optional plotting imports before any code that uses ``go``.
    _imports.check()

    # Keep data collection and rendering separate: the first step only reads the
    # study, the second step only builds the figure.
    info = _get_timeline_info(study)
    return _get_timeline_plot(info)
def _get_timeline_info(study: Study) -> _TimelineInfo:
    """Collect one timeline bar per trial of ``study``.

    Trials without a completion time are drawn up to the current time, and
    trials without a start time get a zero-length bar at their end time.

    Args:
        study: Study whose trials are read via ``study.get_trials(deepcopy=False)``.

    Returns:
        A ``_TimelineInfo`` holding one ``_TimelineBarInfo`` per trial. The list is
        empty (and a warning is logged) when the study has no trials.
    """
    trials = study.get_trials(deepcopy=False)

    # Take a single "now" snapshot *after* fetching the trials so that every
    # unfinished trial shares the same right edge, and so that the snapshot
    # is not older than the trial list it is combined with.
    now = datetime.datetime.now()

    bars = []
    for t in trials:
        date_complete = t.datetime_complete or now
        date_start = t.datetime_start or date_complete
        if date_complete < date_start:
            # Still plotted (with a negative duration); the warning only flags it.
            _logger.warning(
                f"The start and end times for Trial {t.number} seem to be reversed. "
                f"The start time is {date_start} and the end time is {date_complete}."
            )
        bars.append(
            _TimelineBarInfo(
                number=t.number,
                start=date_start,
                complete=date_complete,
                state=t.state,
                hovertext=_make_hovertext(t),
            )
        )

    if not bars:
        _logger.warning("Your study does not have any trials.")

    return _TimelineInfo(bars)


def _get_timeline_plot(info: _TimelineInfo) -> "go.Figure":
    """Render ``info`` as a horizontal bar chart with one trace per trial state.

    Args:
        info: Bars produced by ``_get_timeline_info``.

    Returns:
        A figure with a date-typed x axis, trial numbers on the y axis, and a
        legend that is always shown.
    """
    # Bar color for each trial state, keyed by the state's name.
    _cm = {
        "COMPLETE": "blue",
        "FAIL": "red",
        "PRUNED": "orange",
        "RUNNING": "green",
        "WAITING": "gray",
    }

    fig = go.Figure()
    # Iterate states in name order so the trace/legend order is deterministic.
    for s in sorted(TrialState, key=lambda x: x.name):
        bars = [b for b in info.bars if b.state == s]
        if not bars:
            continue
        fig.add_trace(
            go.Bar(
                name=s.name,
                # Bar length is the trial duration in milliseconds, measured
                # from ``base`` on the date-typed x axis.
                x=[(b.complete - b.start).total_seconds() * 1000 for b in bars],
                y=[b.number for b in bars],
                base=[b.start.isoformat() for b in bars],
                text=[b.hovertext for b in bars],
                hovertemplate="%{text}<extra>" + s.name + "</extra>",
                orientation="h",
                marker=dict(color=_cm[s.name]),
                textposition="none",  # Avoid drawing hovertext in a bar.
            )
        )

    fig.update_xaxes(type="date")
    fig.update_layout(
        go.Layout(
            title="Timeline Plot",
            xaxis={"title": "Datetime"},
            yaxis={"title": "Trial"},
        )
    )
    fig.update_layout(showlegend=True)  # Draw a legend even if all TrialStates are the same.
    return fig