Source code for optuna.visualization.matplotlib._contour

from __future__ import annotations

from typing import Callable
from typing import Sequence

import numpy as np

from optuna._experimental import experimental_func
from optuna._imports import try_import
from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.visualization._contour import _AxisInfo
from optuna.visualization._contour import _ContourInfo
from optuna.visualization._contour import _get_contour_info
from optuna.visualization._contour import _SubContourInfo
from optuna.visualization.matplotlib._matplotlib_imports import _imports


with try_import() as _optuna_imports:
    import scipy

if _imports.is_successful():
    from optuna.visualization.matplotlib._matplotlib_imports import Axes
    from optuna.visualization.matplotlib._matplotlib_imports import Colormap
    from optuna.visualization.matplotlib._matplotlib_imports import ContourSet
    from optuna.visualization.matplotlib._matplotlib_imports import plt

_logger = get_logger(__name__)


CONTOUR_POINT_NUM = 100


[docs]@experimental_func("2.2.0") def plot_contour( study: Study, params: list[str] | None = None, *, target: Callable[[FrozenTrial], float] | None = None, target_name: str = "Objective Value", ) -> "Axes": """Plot the parameter relationship as contour plot in a study with Matplotlib. Note that, if a parameter contains missing values, a trial with missing values is not plotted. .. seealso:: Please refer to :func:`optuna.visualization.plot_contour` for an example. Warnings: Output figures of this Matplotlib-based :func:`~optuna.visualization.matplotlib.plot_contour` function would be different from those of the Plotly-based :func:`~optuna.visualization.plot_contour`. Example: The following code snippet shows how to plot the parameter relationship as contour plot. .. plot:: import optuna def objective(trial): x = trial.suggest_float("x", -100, 100) y = trial.suggest_categorical("y", [-1, 0, 1]) return x ** 2 + y sampler = optuna.samplers.TPESampler(seed=10) study = optuna.create_study(sampler=sampler) study.optimize(objective, n_trials=30) optuna.visualization.matplotlib.plot_contour(study, params=["x", "y"]) Args: study: A :class:`~optuna.study.Study` object whose trials are plotted for their target values. params: Parameter list to visualize. The default is all parameters. target: A function to specify the value to display. If it is :obj:`None` and ``study`` is being used for single-objective optimization, the objective values are plotted. .. note:: Specify this argument if ``study`` is being used for multi-objective optimization. target_name: Target's name to display on the color bar. Returns: A :class:`matplotlib.axes.Axes` object. .. note:: The colormap is reversed when the ``target`` argument isn't :obj:`None` or ``direction`` of :class:`~optuna.study.Study` is ``minimize``. """ _imports.check() _logger.warning( "Output figures of this Matplotlib-based `plot_contour` function would be different from " "those of the Plotly-based `plot_contour`." ) info = _get_contour_info(study, params, target, target_name) return _get_contour_plot(info)
def _get_contour_plot(info: _ContourInfo) -> "Axes": sorted_params = info.sorted_params sub_plot_infos = info.sub_plot_infos reverse_scale = info.reverse_scale target_name = info.target_name if len(sorted_params) <= 1: _, ax = plt.subplots() return ax n_params = len(sorted_params) plt.style.use("ggplot") # Use ggplot style sheet for similar outputs to plotly. if n_params == 2: # Set up the graph style. fig, axs = plt.subplots() axs.set_title("Contour Plot") cmap = _set_cmap(reverse_scale) cs = _generate_contour_subplot(sub_plot_infos[0][0], axs, cmap) if isinstance(cs, ContourSet): axcb = fig.colorbar(cs) axcb.set_label(target_name) else: # Set up the graph style. fig, axs = plt.subplots(n_params, n_params) fig.suptitle("Contour Plot") cmap = _set_cmap(reverse_scale) # Prepare data and draw contour plots. cs_list = [] for x_i in range(len(sorted_params)): for y_i in range(len(sorted_params)): ax = axs[y_i, x_i] cs = _generate_contour_subplot(sub_plot_infos[y_i][x_i], ax, cmap) if isinstance(cs, ContourSet): cs_list.append(cs) if cs_list: axcb = fig.colorbar(cs_list[0], ax=axs) axcb.set_label(target_name) return axs def _set_cmap(reverse_scale: bool) -> "Colormap": cmap = "Blues_r" if not reverse_scale else "Blues" return plt.get_cmap(cmap) class _LabelEncoder: def __init__(self) -> None: self.labels: list[str] = [] def fit(self, labels: list[str]) -> "_LabelEncoder": self.labels = sorted(set(labels)) return self def transform(self, labels: list[str]) -> list[int]: return [self.labels.index(label) for label in labels] def fit_transform(self, labels: list[str]) -> list[int]: return self.fit(labels).transform(labels) def get_labels(self) -> list[str]: return self.labels def get_indices(self) -> list[int]: return list(range(len(self.labels))) def _calculate_griddata( xaxis: _AxisInfo, yaxis: _AxisInfo, z_values_dict: dict[tuple[int, int], float], ) -> tuple[ np.ndarray, np.ndarray, np.ndarray, list[int], list[str], list[int], list[str], list[int | float], list[int | float], ]: x_values = [] y_values = [] z_values = [] for x_value, y_value in zip(xaxis.values, yaxis.values): if x_value is not None and y_value is not None: x_values.append(x_value) y_values.append(y_value) x_i = xaxis.indices.index(x_value) y_i = yaxis.indices.index(y_value) z_values.append(z_values_dict[(x_i, y_i)]) # Return empty values when x or y has no value. if len(x_values) == 0 or len(y_values) == 0: return np.array([]), np.array([]), np.array([]), [], [], [], [], [], [] def _calculate_axis_data( axis: _AxisInfo, values: Sequence[str | float], ) -> tuple[np.ndarray, list[str], list[int], list[int | float]]: # Convert categorical values to int. cat_param_labels: list[str] = [] cat_param_pos: list[int] = [] returned_values: Sequence[int | float] if axis.is_cat: enc = _LabelEncoder() returned_values = enc.fit_transform(list(map(str, values))) cat_param_labels = enc.get_labels() cat_param_pos = enc.get_indices() else: returned_values = list(map(lambda x: float(x), values)) # For x and y, create 1-D array of evenly spaced coordinates on linear or log scale. if axis.is_log: ci = np.logspace(np.log10(axis.range[0]), np.log10(axis.range[1]), CONTOUR_POINT_NUM) else: ci = np.linspace(axis.range[0], axis.range[1], CONTOUR_POINT_NUM) return ci, cat_param_labels, cat_param_pos, list(returned_values) xi, cat_param_labels_x, cat_param_pos_x, transformed_x_values = _calculate_axis_data( xaxis, x_values, ) yi, cat_param_labels_y, cat_param_pos_y, transformed_y_values = _calculate_axis_data( yaxis, y_values, ) # Calculate grid data points. zi: np.ndarray = np.array([]) # Create irregularly spaced map of trial values # and interpolate it with Plotly's interpolation formulation. if xaxis.name != yaxis.name: zmap = _create_zmap(transformed_x_values, transformed_y_values, z_values, xi, yi) zi = _interpolate_zmap(zmap, CONTOUR_POINT_NUM) return ( xi, yi, zi, cat_param_pos_x, cat_param_labels_x, cat_param_pos_y, cat_param_labels_y, transformed_x_values, transformed_y_values, ) def _generate_contour_subplot(info: _SubContourInfo, ax: "Axes", cmap: "Colormap") -> "ContourSet": if len(info.xaxis.indices) < 2 or len(info.yaxis.indices) < 2: ax.label_outer() return ax ax.set(xlabel=info.xaxis.name, ylabel=info.yaxis.name) ax.set_xlim(info.xaxis.range[0], info.xaxis.range[1]) ax.set_ylim(info.yaxis.range[0], info.yaxis.range[1]) if info.xaxis.name == info.yaxis.name: ax.label_outer() return ax ( xi, yi, zi, x_cat_param_pos, x_cat_param_label, y_cat_param_pos, y_cat_param_label, x_values, y_values, ) = _calculate_griddata(info.xaxis, info.yaxis, info.z_values) cs = None if len(zi) > 0: if info.xaxis.is_log: ax.set_xscale("log") if info.yaxis.is_log: ax.set_yscale("log") if info.xaxis.name != info.yaxis.name: # Contour the gridded data. ax.contour(xi, yi, zi, 15, linewidths=0.5, colors="k") cs = ax.contourf(xi, yi, zi, 15, cmap=cmap.reversed()) # Plot data points. ax.scatter( x_values, y_values, marker="o", c="black", s=20, edgecolors="grey", linewidth=2.0, ) if info.xaxis.is_cat: ax.set_xticks(x_cat_param_pos) ax.set_xticklabels(x_cat_param_label) if info.yaxis.is_cat: ax.set_yticks(y_cat_param_pos) ax.set_yticklabels(y_cat_param_label) ax.label_outer() return cs def _create_zmap( x_values: list[int | float], y_values: list[int | float], z_values: list[float], xi: np.ndarray, yi: np.ndarray, ) -> dict[tuple[int, int], float]: # Creates z-map from trial values and params. # z-map is represented by hashmap of coordinate and trial value pairs. # # Coordinates are represented by tuple of integers, where the first item # indicates x-axis index and the second item indicates y-axis index # and refer to a position of trial value on irregular param grid. # # Since params were resampled either with linspace or logspace # original params might not be on the x and y axes anymore # so we are going with close approximations of trial value positions. zmap = dict() for x, y, z in zip(x_values, y_values, z_values): xindex = int(np.argmin(np.abs(xi - x))) yindex = int(np.argmin(np.abs(yi - y))) zmap[(xindex, yindex)] = z return zmap def _interpolate_zmap(zmap: dict[tuple[int, int], float], contour_plot_num: int) -> np.ndarray: # Implements interpolation formulation used in Plotly # to interpolate heatmaps and contour plots # https://github.com/plotly/plotly.js/blob/95b3bd1bb19d8dc226627442f8f66bce9576def8/src/traces/heatmap/interp2d.js#L15-L20 # citing their doc: # # > Fill in missing data from a 2D array using an iterative # > poisson equation solver with zero-derivative BC at edges. # > Amazingly, this just amounts to repeatedly averaging all the existing # > nearest neighbors # # Plotly's algorithm is equivalent to solve the following linear simultaneous equation. # It is discretization form of the Poisson equation. # # z[x, y] = zmap[(x, y)] (if zmap[(x, y)] is given) # 4 * z[x, y] = z[x-1, y] + z[x+1, y] + z[x, y-1] + z[x, y+1] (if zmap[(x, y)] is not given) a_data = [] a_row = [] a_col = [] b = np.zeros(contour_plot_num**2) for x in range(contour_plot_num): for y in range(contour_plot_num): grid_index = y * contour_plot_num + x if (x, y) in zmap: a_data.append(1) a_row.append(grid_index) a_col.append(grid_index) b[grid_index] = zmap[(x, y)] else: for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)): if 0 <= x + dx < contour_plot_num and 0 <= y + dy < contour_plot_num: a_data.append(1) a_row.append(grid_index) a_col.append(grid_index) a_data.append(-1) a_row.append(grid_index) a_col.append(grid_index + dy * contour_plot_num + dx) z = scipy.sparse.linalg.spsolve(scipy.sparse.csc_matrix((a_data, (a_row, a_col))), b) return z.reshape((contour_plot_num, contour_plot_num))