from __future__ import annotations

import datetime
from typing import NamedTuple

from optuna.logging import get_logger
from optuna.samplers._base import _CONSTRAINTS_KEY
from optuna.study import Study
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _make_hovertext


if _imports.is_successful():
    from optuna.visualization._plotly_imports import go

_logger = get_logger(__name__)


class _TimelineBarInfo(NamedTuple):
    """Everything needed to draw one trial as a horizontal bar."""

    # Trial number; used as the bar's position on the y axis.
    number: int
    # Left edge of the bar.
    start: datetime.datetime
    # Right edge of the bar.
    complete: datetime.datetime
    # Trial state; selects the trace (and therefore the color) of the bar.
    state: TrialState
    # Text shown when hovering over the bar.
    hovertext: str
    # True when at least one stored constraint value is positive.
    infeasible: bool


class _TimelineInfo(NamedTuple):
    """Container for all bars of a timeline plot."""

    bars: list[_TimelineBarInfo]


def plot_timeline(study: Study, n_recent_trials: int | None = None) -> "go.Figure":
    """Plot the timeline of a study.

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted with
            their lifetime.
        n_recent_trials:
            The number of recent trials to plot. If :obj:`None`, all trials are plotted.
            If specified, only the most recent ``n_recent_trials`` will be displayed.
            Must be a positive integer.

    Returns:
        A :class:`plotly.graph_objects.Figure` object.

    Raises:
        ValueError: if ``n_recent_trials`` is 0 or negative.
    """

    # Validate the argument before checking the optional plotly dependency so
    # that a bad argument is always reported as ``ValueError``.
    if n_recent_trials is not None and n_recent_trials <= 0:
        raise ValueError("n_recent_trials must be a positive integer or None.")

    _imports.check()
    info = _get_timeline_info(study, n_recent_trials=n_recent_trials)
    return _get_timeline_plot(info)


def _get_max_datetime_complete(study: Study) -> datetime.datetime:
    """Return the datetime used as the right-hand reference of the timeline.

    This is the current time when the study appears to have live running
    trials (see :func:`_is_running_trials_in_study`) or when no trial has a
    completion time; otherwise it is the latest ``datetime_complete``.
    """

    # The trials are only read here, so one shallow fetch replaces the two deep
    # copies that reading ``study.trials`` twice would make.
    trials = study.get_trials(deepcopy=False)

    max_run_duration = max(
        (
            t.datetime_complete - t.datetime_start
            for t in trials
            if t.datetime_complete is not None and t.datetime_start is not None
        ),
        default=None,
    )
    if _is_running_trials_in_study(study, max_run_duration):
        return datetime.datetime.now()

    completion_times = [t.datetime_complete for t in trials if t.datetime_complete is not None]
    if not completion_times:
        return datetime.datetime.now()
    return max(completion_times)


def _is_running_trials_in_study(study: Study, max_run_duration: datetime.timedelta | None) -> bool:
    """Tell whether ``study`` has ``RUNNING`` trials that look alive.

    Args:
        study:
            The study to inspect.
        max_run_duration:
            The longest observed run duration of a trial, or :obj:`None` when
            no duration is available. With :obj:`None`, any ``RUNNING`` trial
            counts as running.

    Returns:
        :obj:`True` if at least one ``RUNNING`` trial started less than
        ``5 * max_run_duration`` ago (or, without a duration, if any
        ``RUNNING`` trial exists).
    """

    running_trials = study.get_trials(states=(TrialState.RUNNING,), deepcopy=False)
    if max_run_duration is None:
        return len(running_trials) > 0

    now = datetime.datetime.now()
    # This heuristic is to check whether we have trials that were somehow killed,
    # still remain as `RUNNING` in `study`.
    return any(
        now - t.datetime_start < 5 * max_run_duration
        for t in running_trials
        # MyPy redefinition: Running trial should have datetime_start.
        if t.datetime_start is not None
    )


def _get_timeline_info(study: Study, n_recent_trials: int | None = None) -> _TimelineInfo:
    """Collect the bar description of every trial to be plotted.

    Args:
        study:
            The study whose trials are converted into bars.
        n_recent_trials:
            If not :obj:`None`, only the last ``n_recent_trials`` trials are used.

    Returns:
        A :class:`_TimelineInfo` holding one :class:`_TimelineBarInfo` per trial.
    """

    bars = []
    # The reference time is computed from the whole study, not only from the
    # trials selected by ``n_recent_trials``.
    max_datetime = _get_max_datetime_complete(study)
    # Width given to bars that would otherwise have no extent.
    timedelta_for_small_bar = datetime.timedelta(seconds=1)

    trials = study.get_trials(deepcopy=False)
    if n_recent_trials is not None:
        trials = trials[-n_recent_trials:]

    for trial in trials:
        # A trial without a start time is anchored at the reference time.
        datetime_start = trial.datetime_start or max_datetime
        # Running trials extend slightly past the reference time; trials with
        # no completion time get a small bar right after their start.
        datetime_complete = (
            max_datetime + timedelta_for_small_bar
            if trial.state == TrialState.RUNNING
            else trial.datetime_complete or datetime_start + timedelta_for_small_bar
        )
        # A trial is infeasible when any stored constraint value is positive.
        infeasible = _CONSTRAINTS_KEY in trial.system_attrs and any(
            x > 0 for x in trial.system_attrs[_CONSTRAINTS_KEY]
        )
        if datetime_complete < datetime_start:
            _logger.warning(
                (
                    f"The start and end times for Trial {trial.number} seem to be reversed. "
                    f"The start time is {datetime_start} and the end time is {datetime_complete}."
                )
            )
        bars.append(
            _TimelineBarInfo(
                number=trial.number,
                start=datetime_start,
                complete=datetime_complete,
                state=trial.state,
                hovertext=_make_hovertext(trial),
                infeasible=infeasible,
            )
        )

    if not bars:
        _logger.warning("Your study does not have any trials.")

    return _TimelineInfo(bars)


def _get_timeline_plot(info: _TimelineInfo) -> "go.Figure":
    """Build the plotly figure from the collected bars.

    One trace is added per trial state that has bars (in alphabetical order of
    the state names); ``COMPLETE`` trials are further split into an
    ``INFEASIBLE`` trace and a feasible one.
    """

    # Bar color for each trial state name.
    _cm = {
        "COMPLETE": "blue",
        "FAIL": "red",
        "PRUNED": "orange",
        "RUNNING": "green",
        "WAITING": "gray",
    }

    fig = go.Figure()
    for state in sorted(TrialState, key=lambda x: x.name):
        if state.name == "COMPLETE":
            infeasible_bars = [b for b in info.bars if b.state == state and b.infeasible]
            feasible_bars = [b for b in info.bars if b.state == state and not b.infeasible]
            _plot_bars(infeasible_bars, "#cccccc", "INFEASIBLE", fig)
            _plot_bars(feasible_bars, _cm[state.name], state.name, fig)
        else:
            bars = [b for b in info.bars if b.state == state]
            _plot_bars(bars, _cm[state.name], state.name, fig)

    fig.update_xaxes(type="date")
    fig.update_layout(
        go.Layout(
            title="Timeline Plot",
            xaxis={"title": "Datetime"},
            yaxis={"title": "Trial"},
        )
    )
    fig.update_layout(showlegend=True)  # Draw a legend even if all TrialStates are the same.
    return fig


def _plot_bars(bars: list[_TimelineBarInfo], color: str, name: str, fig: go.Figure) -> None:
    """Add ``bars`` to ``fig`` as one horizontal bar trace.

    Args:
        bars:
            The bars to draw. Nothing is added when the list is empty.
        color:
            Marker color of the trace.
        name:
            Trace name, shown in the legend and in the hover label.
        fig:
            The figure to add the trace to; modified in place.
    """

    if not bars:
        return

    fig.add_trace(
        go.Bar(
            name=name,
            # Bar lengths are durations expressed in milliseconds, measured
            # from ``base``.
            x=[(b.complete - b.start).total_seconds() * 1000 for b in bars],
            y=[b.number for b in bars],
            base=[b.start.isoformat() for b in bars],
            text=[b.hovertext for b in bars],
            # ``<extra>`` puts the trace name in the hover label's secondary
            # box instead of appending it to the hover text itself.
            hovertemplate="%{text}<extra>" + name + "</extra>",
            orientation="h",
            marker=dict(color=color),
            textposition="none",  # Avoid drawing hovertext in a bar.
        )
    )