Multi-objective Optimization with Optuna

This tutorial showcases Optuna's multi-objective optimization feature: we simultaneously maximize the validation accuracy of a PyTorch model on the Fashion-MNIST dataset and minimize the model's FLOPS.

We use thop to measure FLOPS.
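
thop is not part of PyTorch itself and can be installed with pip install thop. As a quick illustration (a minimal sketch, separate from the tutorial code), thop.profile returns an operation count (multiply-accumulates, which this tutorial reports as FLOPS) together with a parameter count:

import thop
import torch
import torch.nn as nn

# Profile a single 784 -> 10 linear layer on one example input.
ops, params = thop.profile(
    nn.Linear(28 * 28, 10), inputs=(torch.randn(1, 28 * 28),), verbose=False
)
print(ops, params)  # expected: 784 * 10 = 7840 ops and 7850 parameters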

import thop
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import optuna


DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."  # download location for Fashion-MNIST
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30  # train on a small subset to keep the tutorial fast
N_VALID_EXAMPLES = BATCHSIZE * 10  # likewise, validate on a small subset


def define_model(trial):
    n_layers = trial.suggest_int("n_layers", 1, 3)
    layers = []

    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float("dropout_{}".format(i), 0.2, 0.5)
        layers.append(nn.Dropout(p))

        in_features = out_features

    layers.append(nn.Linear(in_features, 10))
    layers.append(nn.LogSoftmax(dim=1))

    return nn.Sequential(*layers)
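
To see what kind of network this search space produces, you can build one concrete instance outside of a study with optuna.trial.FixedTrial (a sketch; the parameter values below are arbitrary):

fixed_trial = optuna.trial.FixedTrial(
    {"n_layers": 2, "n_units_l0": 64, "dropout_0": 0.3, "n_units_l1": 32, "dropout_1": 0.25}
)
print(define_model(fixed_trial))  # prints the resulting nn.Sequential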


Define the training and evaluation functions.

def train_model(model, optimizer, train_loader):
    model.train()
    for data, target in train_loader:
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)  # the model outputs log-probabilities
        loss.backward()
        optimizer.step()


def eval_model(model, valid_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in valid_loader:
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / N_VALID_EXAMPLES

    # Count forward-pass FLOPS for a single flattened 28x28 input.
    flops, _ = thop.profile(model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),), verbose=False)
    return flops, accuracy

Define the multi-objective function. The two objectives are FLOPS and accuracy.

def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=False,  # evaluation does not need shuffling
    )
    model = define_model(trial).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)
    flops, accuracy = eval_model(model, val_loader)
    return flops, accuracy

Run multi-objective optimization

If your optimization problem is multi-objective, Optuna expects you to specify the optimization direction for each objective. In this example we want to minimize the FLOPS (we want a faster model) and maximize the accuracy, so we set directions to ["minimize", "maximize"].

study = optuna.create_study(directions=["minimize", "maximize"])
study.optimize(objective, n_trials=30, timeout=300)

print("Number of finished trials: ", len(study.trials))
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting ../FashionMNIST/raw/train-images-idx3-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../FashionMNIST/raw
(download progress bars omitted)

Number of finished trials:  30
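
As an aside, recent Optuna releases automatically pick a multi-objective-capable sampler (NSGA-II) when a study has several directions. If you need reproducibility or a specific algorithm, you can pass the sampler explicitly (a sketch; the seed is arbitrary):

sampler = optuna.samplers.NSGAIISampler(seed=42)
study = optuna.create_study(directions=["minimize", "maximize"], sampler=sampler)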

Check trials on Pareto front visually.

optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])
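
Note that plot_pareto_front returns a plotly figure. If you run this as a plain script rather than in a notebook, you can display or save it yourself (a minimal sketch; the file name is arbitrary):

fig = optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])
fig.show()  # opens the interactive plot in a browser
fig.write_html("pareto_front.html")  # or save it as a standalone HTML file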


Fetch the list of trials on the Pareto front with best_trials.

For example, the following code shows the number of trials on the Pareto front and picks the trial with the highest accuracy.

print(f"Number of trials on the Pareto front: {len(study.best_trials)}")

trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[1])
print(f"Trial with highest accuracy: ")
print(f"\tnumber: {trial_with_highest_accuracy.number}")
print(f"\tparams: {trial_with_highest_accuracy.params}")
print(f"\tvalues: {trial_with_highest_accuracy.values}")
Number of trials on the Pareto front: 5
Trial with highest accuracy:
        number: 22
        params: {'n_layers': 1, 'n_units_l0': 63, 'dropout_0': 0.3004585640715691, 'lr': 0.004986868776421742}
        values: [50022.0, 0.8265625]
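
Analogously, the cheapest model on the front is the trial with the lowest FLOPS (a sketch; the variable name is ours):

trial_with_lowest_flops = min(study.best_trials, key=lambda t: t.values[0])
print(f"Trial with lowest FLOPS: number={trial_with_lowest_flops.number}, values={trial_with_lowest_flops.values}")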

Learn which hyperparameters affect the FLOPS most with hyperparameter importance. Because the study is multi-objective, we select the FLOPS objective by passing target=lambda t: t.values[0].

optuna.visualization.plot_param_importances(
    study, target=lambda t: t.values[0], target_name="flops"
)
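
The same call with the second objective selected ranks the importances for accuracy instead:

optuna.visualization.plot_param_importances(
    study, target=lambda t: t.values[1], target_name="accuracy"
)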


Total running time of the script: (1 minute 56.803 seconds)

Gallery generated by Sphinx-Gallery