# Source code for optuna.visualization._timeline

from __future__ import annotations

import datetime
from typing import NamedTuple

from optuna.logging import get_logger
from optuna.samplers._base import _CONSTRAINTS_KEY
from optuna.study import Study
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _make_hovertext


# Bring plotly's `go` into scope only if the guarded import in
# `_plotly_imports` succeeded; `plot_timeline` calls `_imports.check()`
# before anything here touches `go`.
if _imports.is_successful():
    from optuna.visualization._plotly_imports import go

# Module logger; used below to warn about empty studies and reversed timestamps.
_logger = get_logger(__name__)


class _TimelineBarInfo(NamedTuple):
    """Everything needed to draw one trial as a horizontal bar in the timeline plot."""

    number: int  # Trial number; used as the bar's y coordinate.
    start: datetime.datetime  # Left end of the bar (also the bar's `base`).
    complete: datetime.datetime  # Right end of the bar; length is `complete - start`.
    state: TrialState  # Groups bars into per-state traces / legend entries.
    hovertext: str  # Text shown when hovering over the bar.
    infeasible: bool  # Any constraint value > 0; only consulted for COMPLETE trials.


class _TimelineInfo(NamedTuple):
    """Container handed from `_get_timeline_info` to `_get_timeline_plot`."""

    bars: list[_TimelineBarInfo]  # Built with one entry per trial by `_get_timeline_info`.


def plot_timeline(study: Study) -> "go.Figure":
    """Plot the timeline of a study.

    Args:
        study:
            A :class:`~optuna.study.Study` object. Each of its trials is drawn as a
            bar spanning its lifetime.

    Returns:
        A :class:`plotly.graph_objects.Figure` object.
    """
    # Make sure the plotly import succeeded before building anything.
    _imports.check()
    return _get_timeline_plot(_get_timeline_info(study))
def _get_max_datetime_complete(study: Study) -> datetime.datetime:
    """Return the datetime that marks the right end of the timeline.

    If the study appears to have live ``RUNNING`` trials (see
    ``_is_running_trials_in_study``), this is the current time. Otherwise it is
    the latest ``datetime_complete`` among all trials, falling back to the
    current time when no trial has completed.
    """
    # ``study.trials`` deep-copies every trial on each access. The trials are only
    # read here, so take one shallow snapshot and reuse it for both passes.
    trials = study.get_trials(deepcopy=False)
    max_run_duration = max(
        (
            t.datetime_complete - t.datetime_start
            for t in trials
            if t.datetime_complete is not None and t.datetime_start is not None
        ),
        default=None,
    )
    if _is_running_trials_in_study(study, max_run_duration):
        return datetime.datetime.now()

    return max(
        (t.datetime_complete for t in trials if t.datetime_complete is not None),
        default=datetime.datetime.now(),
    )


def _is_running_trials_in_study(
    study: Study, max_run_duration: datetime.timedelta | None
) -> bool:
    """Return whether ``study`` has ``RUNNING`` trials that look alive.

    Args:
        study: The study to inspect.
        max_run_duration: The longest ``datetime_complete - datetime_start`` among
            the study's trials, or ``None`` if no trial has both timestamps.

    Returns:
        If ``max_run_duration`` is ``None``, whether any trial is ``RUNNING``.
        Otherwise, whether some ``RUNNING`` trial started less than
        ``5 * max_run_duration`` ago.
    """
    running_trials = study.get_trials(states=(TrialState.RUNNING,), deepcopy=False)
    if max_run_duration is None:
        return len(running_trials) > 0
    now = datetime.datetime.now()
    # This heuristic is to check whether we have trials that were somehow killed,
    # still remain as `RUNNING` in `study`.
    return any(
        now - t.datetime_start < 5 * max_run_duration
        for t in running_trials
        # MyPy redefinition: Running trial should have datetime_start.
        if t.datetime_start is not None
    )


def _get_timeline_info(study: Study) -> _TimelineInfo:
    """Build one ``_TimelineBarInfo`` per trial of ``study``.

    Trials without ``datetime_start`` start at the timeline's right end.
    ``RUNNING`` trials extend just past it. Other trials without
    ``datetime_complete`` get a 1-second bar so they stay visible.
    """
    bars = []
    max_datetime = _get_max_datetime_complete(study)
    timedelta_for_small_bar = datetime.timedelta(seconds=1)
    for trial in study.get_trials(deepcopy=False):
        datetime_start = trial.datetime_start or max_datetime
        if trial.state == TrialState.RUNNING:
            # A stale RUNNING trial (see `_is_running_trials_in_study`) may have
            # started after the latest completion, i.e. after `max_datetime`.
            # Clamp so its bar is never drawn backwards.
            datetime_complete = max(max_datetime, datetime_start) + timedelta_for_small_bar
        else:
            datetime_complete = (
                trial.datetime_complete or datetime_start + timedelta_for_small_bar
            )
        # A trial is infeasible if any of its recorded constraint values is positive.
        constraints = trial.system_attrs.get(_CONSTRAINTS_KEY)
        infeasible = constraints is not None and any(c > 0 for c in constraints)
        if datetime_complete < datetime_start:
            _logger.warning(
                (
                    f"The start and end times for Trial {trial.number} seem to be reversed. "
                    f"The start time is {datetime_start} and the end time is {datetime_complete}."
                )
            )
        bars.append(
            _TimelineBarInfo(
                number=trial.number,
                start=datetime_start,
                complete=datetime_complete,
                state=trial.state,
                hovertext=_make_hovertext(trial),
                infeasible=infeasible,
            )
        )

    if len(bars) == 0:
        _logger.warning("Your study does not have any trials.")

    return _TimelineInfo(bars)


def _get_timeline_plot(info: _TimelineInfo) -> "go.Figure":
    """Render ``info`` as a plotly figure with one bar trace per trial state.

    COMPLETE trials are split into an "INFEASIBLE" trace (gray) and a regular
    COMPLETE trace. States are iterated in name order.
    """
    _cm = {
        "COMPLETE": "blue",
        "FAIL": "red",
        "PRUNED": "orange",
        "RUNNING": "green",
        "WAITING": "gray",
    }

    fig = go.Figure()
    for state in sorted(TrialState, key=lambda x: x.name):
        if state.name == "COMPLETE":
            infeasible_bars = [b for b in info.bars if b.state == state and b.infeasible]
            feasible_bars = [b for b in info.bars if b.state == state and not b.infeasible]
            _plot_bars(infeasible_bars, "#cccccc", "INFEASIBLE", fig)
            _plot_bars(feasible_bars, _cm[state.name], state.name, fig)
        else:
            bars = [b for b in info.bars if b.state == state]
            _plot_bars(bars, _cm[state.name], state.name, fig)
    fig.update_xaxes(type="date")
    fig.update_layout(
        go.Layout(
            title="Timeline Plot",
            xaxis={"title": "Datetime"},
            yaxis={"title": "Trial"},
        )
    )
    fig.update_layout(showlegend=True)  # Draw a legend even if all TrialStates are the same.
    return fig


def _plot_bars(bars: list[_TimelineBarInfo], color: str, name: str, fig: go.Figure) -> None:
    """Add ``bars`` to ``fig`` as one horizontal bar trace; no-op when ``bars`` is empty."""
    if len(bars) == 0:
        return

    fig.add_trace(
        go.Bar(
            name=name,
            # Bar lengths are given in milliseconds and offset from `base`, to pair
            # with the date-typed x axis set in `_get_timeline_plot`.
            x=[(b.complete - b.start).total_seconds() * 1000 for b in bars],
            y=[b.number for b in bars],
            base=[b.start.isoformat() for b in bars],
            text=[b.hovertext for b in bars],
            hovertemplate="%{text}<extra>" + name + "</extra>",
            orientation="h",
            marker=dict(color=color),
            textposition="none",  # Avoid drawing hovertext in a bar.
        )
    )