from __future__ import annotations
from collections.abc import Callable
from optuna._experimental import experimental_func
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.visualization._rank import _get_rank_info
from optuna.visualization._rank import _get_tick_info
from optuna.visualization._rank import _RankPlotInfo
from optuna.visualization._rank import _RankSubplotInfo
from optuna.visualization.matplotlib._matplotlib_imports import _imports
if _imports.is_successful():
from optuna.visualization.matplotlib._matplotlib_imports import Axes
from optuna.visualization.matplotlib._matplotlib_imports import PathCollection
from optuna.visualization.matplotlib._matplotlib_imports import plt
[docs]
@experimental_func("3.2.0")
def plot_rank(
study: Study,
params: list[str] | None = None,
*,
target: Callable[[FrozenTrial], float] | None = None,
target_name: str = "Objective Value",
) -> "Axes":
"""Plot parameter relations as scatter plots with colors indicating ranks of target value.
Note that trials missing the specified parameters will not be plotted.
.. seealso::
Please refer to :func:`optuna.visualization.plot_rank` for an example.
Args:
study:
A :class:`~optuna.study.Study` object whose trials are plotted for their target values.
params:
Parameter list to visualize. The default is all parameters.
target:
A function to specify the value to display. If it is :obj:`None` and ``study`` is being
used for single-objective optimization, the objective values are plotted.
.. note::
Specify this argument if ``study`` is being used for multi-objective optimization.
target_name:
Target's name to display on the color bar.
Returns:
A :class:`matplotlib.axes.Axes` object.
"""
_imports.check()
info = _get_rank_info(study, params, target, target_name)
return _get_rank_plot(info)
def _get_rank_plot(
info: _RankPlotInfo,
) -> "Axes":
params = info.params
sub_plot_infos = info.sub_plot_infos
plt.style.use("ggplot") # Use ggplot style sheet for similar outputs to plotly.
title = f"Rank ({info.target_name})"
n_params = len(params)
if n_params == 0:
_, ax = plt.subplots()
ax.set_title(title)
return ax
if n_params == 1 or n_params == 2:
fig, axs = plt.subplots()
axs.set_title(title)
pc = _add_rank_subplot(axs, sub_plot_infos[0][0])
else:
fig, axs = plt.subplots(n_params, n_params)
fig.suptitle(title)
for x_i in range(n_params):
for y_i in range(n_params):
ax = axs[x_i, y_i]
# Set the x or y label only if the subplot is in the edge of the overall figure.
pc = _add_rank_subplot(
ax,
sub_plot_infos[x_i][y_i],
set_x_label=x_i == (n_params - 1),
set_y_label=y_i == 0,
)
tick_info = _get_tick_info(info.zs)
pc.set_cmap(plt.get_cmap("RdYlBu_r"))
cbar = fig.colorbar(pc, ax=axs, ticks=tick_info.coloridxs)
cbar.ax.set_yticklabels(tick_info.text)
cbar.outline.set_edgecolor("gray")
return axs
def _add_rank_subplot(
ax: "Axes", info: _RankSubplotInfo, set_x_label: bool = True, set_y_label: bool = True
) -> "PathCollection":
if set_x_label:
ax.set_xlabel(info.xaxis.name)
if set_y_label:
ax.set_ylabel(info.yaxis.name)
if not info.xaxis.is_cat:
ax.set_xlim(info.xaxis.range[0], info.xaxis.range[1])
if not info.yaxis.is_cat:
ax.set_ylim(info.yaxis.range[0], info.yaxis.range[1])
if info.xaxis.is_log:
ax.set_xscale("log")
if info.yaxis.is_log:
ax.set_yscale("log")
return ax.scatter(
x=[str(x) for x in info.xs] if info.xaxis.is_cat else info.xs,
y=[str(y) for y in info.ys] if info.yaxis.is_cat else info.ys,
c=info.colors / 255,
edgecolors="grey",
)