Source code for optuna.integration.wandb

from typing import Any
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Union

import optuna
from optuna._experimental import experimental
from optuna._imports import try_import

with try_import() as _imports:
    import wandb

[docs]@experimental("2.9.0") class WeightsAndBiasesCallback(object): """Callback to track Optuna trials with Weights & Biases. This callback enables tracking of Optuna study in Weights & Biases. The study is tracked as a single experiment run, where all suggested hyperparameters and optimized metrics are logged and plotted as a function of optimizer steps. .. note:: User needs to be logged in to Weights & Biases before using this callback in online mode. For more information, please refer to `wandb setup <>`_. .. note:: Users who want to run multiple Optuna studies within the same process should call ``wandb.finish()`` between subsequent calls to ``study.optimize()``. Calling ``wandb.finish()`` is not necessary if you are running one Optuna study per process. .. note:: To ensure correct trial order in Weights & Biases, this callback should only be used with ``study.optimize(n_jobs=1)``. Example: Add Weights & Biases callback to Optuna optimization. .. code:: import optuna from optuna.integration.wandb import WeightsAndBiasesCallback def objective(trial): x = trial.suggest_float("x", -10, 10) return (x - 2) ** 2 wandb_kwargs = {"project": "my-project"} wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs) study = optuna.create_study(study_name="my_study") study.optimize(objective, n_trials=10, callbacks=[wandbc]) Args: metric_name: Name assigned to optimized metric. In case of multi-objective optimization, list of names can be passed. Those names will be assigned to metrics in the order returned by objective function. If single name is provided, or this argument is left to default value, it will be broadcasted to each objective with a number suffix in order returned by objective function e.g. two objectives and default metric name will be logged as ``value_0`` and ``value_1``. wandb_kwargs: Set of arguments passed when initializing Weights & Biases run. Please refer to `Weights & Biases API documentation <>`_ for more details. Raises: :exc:`ValueError`: If there are missing or extra metric names in multi-objective optimization. :exc:`TypeError`: When metric names are not passed as sequence. """ def __init__( self, metric_name: Union[str, Sequence[str]] = "value", wandb_kwargs: Optional[Dict[str, Any]] = None, ) -> None: _imports.check() if not isinstance(metric_name, Sequence): raise TypeError( "Expected metric_name to be string or sequence of strings, got {}.".format( type(metric_name) ) ) self._metric_name = metric_name self._wandb_kwargs = wandb_kwargs or {} self._initialize_run() def __call__(self, study:, trial: optuna.trial.FrozenTrial) -> None: if isinstance(self._metric_name, str): if len(trial.values) > 1: # Broadcast default name for multi-objective optimization. names = ["{}_{}".format(self._metric_name, i) for i in range(len(trial.values))] else: names = [self._metric_name] else: if len(self._metric_name) != len(trial.values): raise ValueError( "Running multi-objective optimization " "with {} objective values, but {} names specified. " "Match objective values and names, or use default broadcasting.".format( len(trial.values), len(self._metric_name) ) ) else: names = [*self._metric_name] metrics = {name: value for name, value in zip(names, trial.values)} attributes = {"direction": [ for d in study.directions]} wandb.config.update(attributes) wandb.log({**trial.params, **metrics}, step=trial.number) def _initialize_run(self) -> None: """Initializes Weights & Biases run.""" wandb.init(**self._wandb_kwargs)