Quick Visualization for Hyperparameter Optimization Analysis
Optuna provides various visualization features in optuna.visualization
for analyzing optimization results.
Note that this tutorial requires Plotly to be installed:
$ pip install plotly
# Required if you are running this tutorial in Jupyter Notebook.
$ pip install nbformat
If you prefer to use Matplotlib instead of Plotly, please run the following command:
$ pip install matplotlib
This tutorial walks you through this module by visualizing the optimization results of a PyTorch model trained on the FashionMNIST dataset.
For visualizing multi-objective optimization (i.e., the usage of optuna.visualization.plot_pareto_front()),
please refer to the tutorial of Multi-objective Optimization with Optuna.
Note
By using Optuna Dashboard, you can also check the optimization history, hyperparameter importances, hyperparameter relationships, etc. in graphs and tables. Please make your study persistent using the RDB backend and execute the following commands to run Optuna Dashboard.
$ pip install optuna-dashboard
$ optuna-dashboard sqlite:///example-study.db
Please check out the GitHub repository for more details.
(Screenshots: managing studies and visualizing results with interactive graphs in Optuna Dashboard.)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import optuna
# You can use Matplotlib instead of Plotly for visualization by simply replacing `optuna.visualization` with
# `optuna.visualization.matplotlib` in the following examples.
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_intermediate_values
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_rank
from optuna.visualization import plot_slice
from optuna.visualization import plot_timeline
SEED = 13
torch.manual_seed(SEED)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
def define_model(trial):
    n_layers = trial.suggest_int("n_layers", 1, 2)
    layers = []

    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 64, 512)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        in_features = out_features

    layers.append(nn.Linear(in_features, 10))
    layers.append(nn.LogSoftmax(dim=1))
    return nn.Sequential(*layers)
# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        F.nll_loss(model(data), target).backward()
        optimizer.step()


def eval_model(model, valid_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / N_VALID_EXAMPLES
    return accuracy
Define the objective function.
def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    model = define_model(trial).to(DEVICE)
    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)

        val_accuracy = eval_model(model, val_loader)
        trial.report(val_accuracy, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_accuracy
study = optuna.create_study(
direction="maximize",
sampler=optuna.samplers.TPESampler(seed=SEED),
pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=30, timeout=300)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting ../FashionMNIST/raw/train-images-idx3-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../FashionMNIST/raw
Plot functions
Visualize the optimization history. See plot_optimization_history() for the details.
plot_optimization_history(study)
Visualize the learning curves of the trials. See plot_intermediate_values() for the details.
plot_intermediate_values(study)
Visualize high-dimensional parameter relationships. See plot_parallel_coordinate() for the details.
plot_parallel_coordinate(study)
Select parameters to visualize.
plot_parallel_coordinate(study, params=["lr", "n_layers"])
Visualize hyperparameter relationships. See plot_contour() for the details.
plot_contour(study)
Select parameters to visualize.
plot_contour(study, params=["lr", "n_layers"])
Visualize individual hyperparameters as a slice plot. See plot_slice() for the details.
plot_slice(study)
Select parameters to visualize.
plot_slice(study, params=["lr", "n_layers"])
Visualize parameter importances. See plot_param_importances() for the details.
plot_param_importances(study)
Hyperparameter importances can also reveal which hyperparameters affect the trial duration.
optuna.visualization.plot_param_importances(
    study, target=lambda t: t.duration.total_seconds(), target_name="duration"
)
Visualize the empirical distribution function. See plot_edf() for the details.
plot_edf(study)
Visualize parameter relations with scatter plots colored by objective values. See plot_rank() for the details.
plot_rank(study)
Visualize the optimization timeline of the performed trials. See plot_timeline() for the details.
plot_timeline(study)
Customize generated figures
In optuna.visualization and optuna.visualization.matplotlib, each function returns an editable figure object: a plotly.graph_objects.Figure or a matplotlib.axes.Axes, depending on the module.
This allows users to adapt a generated figure to their needs using the API of the underlying visualization library.
The following example manually replaces the title and axis labels of the figure drawn by the Plotly-based plot_intermediate_values().
fig = plot_intermediate_values(study)
fig.update_layout(
    title="Hyperparameter optimization for FashionMNIST classification",
    xaxis_title="Epoch",
    yaxis_title="Validation Accuracy",
)
Total running time of the script: (2 minutes 4.251 seconds)