Note
Click here to download the full example code
Multi-objective Optimization with Optuna
This tutorial showcases Optuna’s multi-objective optimization feature by optimizing the validation accuracy of Fashion MNIST dataset and the FLOPS of the model implemented in PyTorch.
We use thop to measure FLOPS.
import thop
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import optuna
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
def define_model(trial):
    """Build an MLP for flattened 28x28 images whose depth, layer widths,
    and dropout rates are sampled from *trial*.

    The network ends with LogSoftmax, so it pairs with ``F.nll_loss``.
    """
    depth = trial.suggest_int("n_layers", 1, 3)
    modules = []
    width_in = 28 * 28
    for layer_idx in range(depth):
        # Sampling order per layer (int then float) matches the original code.
        width_out = trial.suggest_int(f"n_units_l{layer_idx}", 4, 128)
        drop_rate = trial.suggest_float(f"dropout_{layer_idx}", 0.2, 0.5)
        modules.extend(
            [
                nn.Linear(width_in, width_out),
                nn.ReLU(),
                nn.Dropout(drop_rate),
            ]
        )
        width_in = width_out
    modules.append(nn.Linear(width_in, 10))
    modules.append(nn.LogSoftmax(dim=1))
    return nn.Sequential(*modules)
# Defines training and evaluation.
# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
    """Train *model* for one epoch over *train_loader* with NLL loss.

    Each batch of images is flattened to 784-dim vectors and moved to
    DEVICE before the forward pass; *optimizer* steps once per batch.
    """
    model.train()
    # The batch index was unused, so iterate the loader directly
    # instead of wrapping it in enumerate().
    for data, target in train_loader:
        data = data.view(-1, 28 * 28).to(DEVICE)
        target = target.to(DEVICE)
        optimizer.zero_grad()
        F.nll_loss(model(data), target).backward()
        optimizer.step()
def eval_model(model, valid_loader):
    """Evaluate *model* and measure its cost.

    Returns:
        A ``(flops, accuracy)`` tuple: FLOPS measured by thop on one
        flattened dummy input, and accuracy as correct predictions
        divided by N_VALID_EXAMPLES.
    """
    model.eval()
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs = inputs.view(-1, 28 * 28).to(DEVICE)
            labels = labels.to(DEVICE)
            predictions = model(inputs).argmax(dim=1, keepdim=True)
            n_correct += predictions.eq(labels.view_as(predictions)).sum().item()
    accuracy = n_correct / N_VALID_EXAMPLES

    dummy_input = torch.randn(1, 28 * 28).to(DEVICE)
    flops, _ = thop.profile(model, inputs=(dummy_input,), verbose=False)
    return flops, accuracy
Define the multi-objective function. The two objectives are FLOPS and accuracy.
def objective(trial):
    """Optuna multi-objective function.

    Trains a trial-sampled MLP on a subset of FashionMNIST for 10 epochs
    and returns ``(flops, accuracy)`` — the two objectives of the study.
    """
    to_tensor = torchvision.transforms.ToTensor()

    train_set = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=to_tensor
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_set, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    valid_set = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=to_tensor
    )
    # NOTE(review): shuffling the validation split does not change the
    # accuracy; kept as-is to preserve the original behavior.
    valid_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(valid_set, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    model = define_model(trial).to(DEVICE)
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)

    for _ in range(10):
        train_model(model, optimizer, train_loader)
    flops, accuracy = eval_model(model, valid_loader)
    return flops, accuracy
Run multi-objective optimization
If your optimization problem is multi-objective,
Optuna assumes that you will specify the optimization direction for each objective.
Specifically, in this example, we want to minimize the FLOPS (we want a faster model)
and maximize the accuracy. So we set directions to ["minimize", "maximize"].
# One direction per objective, in the order `objective` returns them:
# minimize FLOPS (faster model), maximize validation accuracy.
study = optuna.create_study(directions=["minimize", "maximize"])
# Stop after 30 trials or 300 seconds of wall time, whichever comes first.
study.optimize(objective, n_trials=30, timeout=300)
print("Number of finished trials: ", len(study.trials))
Out:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../FashionMNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/26421880 [00:00<?, ?it/s]
0%| | 14336/26421880 [00:00<03:05, 142618.53it/s]
0%| | 44032/26421880 [00:00<01:54, 231233.05it/s]
0%| | 104448/26421880 [00:00<01:06, 398070.98it/s]
1%| | 219136/26421880 [00:00<00:38, 687642.77it/s]
2%|1 | 453632/26421880 [00:00<00:20, 1273571.84it/s]
3%|3 | 923648/26421880 [00:00<00:10, 2415632.52it/s]
7%|7 | 1862656/26421880 [00:00<00:05, 4649921.84it/s]
13%|#3 | 3440640/26421880 [00:00<00:02, 8127285.57it/s]
19%|#8 | 5015552/26421880 [00:00<00:02, 10440205.71it/s]
25%|##4 | 6562816/26421880 [00:01<00:01, 11924170.68it/s]
31%|### | 8124416/26421880 [00:01<00:01, 12988161.51it/s]
37%|###6 | 9708544/26421880 [00:01<00:01, 13787455.53it/s]
43%|####2 | 11267072/26421880 [00:01<00:01, 14262193.98it/s]
48%|####8 | 12813312/26421880 [00:01<00:00, 14557581.96it/s]
54%|#####4 | 14369792/26421880 [00:01<00:00, 14796188.18it/s]
60%|###### | 15984640/26421880 [00:01<00:00, 15138332.32it/s]
67%|######6 | 17598464/26421880 [00:01<00:00, 15368272.33it/s]
73%|#######2 | 19220480/26421880 [00:01<00:00, 15551683.38it/s]
79%|#######8 | 20835328/26421880 [00:01<00:00, 15663337.46it/s]
85%|########4 | 22445056/26421880 [00:02<00:00, 15721821.81it/s]
91%|#########1| 24056832/26421880 [00:02<00:00, 15770060.68it/s]
97%|#########7| 25668608/26421880 [00:02<00:00, 15799566.34it/s]
26422272it [00:02, 11750246.90it/s]
Extracting ../FashionMNIST/raw/train-images-idx3-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../FashionMNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/29515 [00:00<?, ?it/s]
49%|####8 | 14336/29515 [00:00<00:00, 142394.61it/s]
29696it [00:00, 293427.37it/s]
Extracting ../FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../FashionMNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/4422102 [00:00<?, ?it/s]
0%| | 14336/4422102 [00:00<00:30, 142852.66it/s]
1%| | 44032/4422102 [00:00<00:18, 232291.46it/s]
2%|2 | 103424/4422102 [00:00<00:10, 395280.76it/s]
5%|4 | 218112/4422102 [00:00<00:06, 688442.64it/s]
10%|# | 452608/4422102 [00:00<00:03, 1279255.17it/s]
21%|## | 922624/4422102 [00:00<00:01, 2429480.37it/s]
42%|####2 | 1862656/4422102 [00:00<00:00, 4683369.48it/s]
78%|#######8 | 3459072/4422102 [00:00<00:00, 8228785.60it/s]
4422656it [00:00, 5386922.26it/s]
Extracting ../FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/5148 [00:00<?, ?it/s]
6144it [00:00, 34359738.37it/s]
Extracting ../FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../FashionMNIST/raw
Processing...
/home/docs/checkouts/readthedocs.org/user_builds/optuna/envs/v2.10.1/lib/python3.8/site-packages/torchvision/datasets/mnist.py:479: UserWarning:
The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:143.)
Done!
Number of finished trials: 30
Check the trials on the Pareto front visually.
# Axis labels correspond to the (FLOPS, accuracy) order of the returned objectives.
optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])
Out:
/home/docs/checkouts/readthedocs.org/user_builds/optuna/checkouts/v2.10.1/tutorial/20_recipes/002_multi_objective.py:123: ExperimentalWarning:
plot_pareto_front is experimental (supported from v2.4.0). The interface can change in the future.
Learn which hyperparameters affect the FLOPS most by using hyperparameter importance.
# Importance is computed against a single scalar, so `target` selects the
# first objective value (FLOPS) from each trial's (flops, accuracy) pair.
optuna.visualization.plot_param_importances(
    study, target=lambda t: t.values[0], target_name="flops"
)
Total running time of the script: ( 1 minutes 37.262 seconds)