Multi-objective Optimization with Optuna

This tutorial showcases Optuna's multi-objective optimization feature by jointly optimizing the validation accuracy on the Fashion-MNIST dataset and the FLOPS of a model implemented in PyTorch.

We use thop to measure FLOPS.
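As a rough cross-check of what the FLOPS objective measures, the sketch below counts the multiply-accumulate operations (MACs) of a stack of fully connected layers by hand. It assumes that thop counts one operation per multiply-accumulate in each nn.Linear layer, ignores the bias addition, and assigns zero cost to ReLU and Dropout. This matches our understanding of recent thop versions, but the exact convention may differ between releases, so treat the function (the name `linear_macs` is our own, not part of thop) as an illustration only.

```python
from typing import List


def linear_macs(widths: List[int]) -> int:
    """Count multiply-accumulates of a stack of fully connected layers.

    `widths` lists the feature sizes from input to output, e.g.
    [784, 64, 10] describes Linear(784, 64) followed by Linear(64, 10).
    Assumption: one op per multiply-accumulate; biases and activations are free.
    """
    if len(widths) < 2:
        raise ValueError("need at least an input and an output width")
    # Each Linear(in, out) applied to a single sample costs in * out MACs.
    return sum(w_in * w_out for w_in, w_out in zip(widths, widths[1:]))


if __name__ == "__main__":
    # Smallest model allowed by define_model: one hidden layer with 4 units.
    print(linear_macs([784, 4, 10]))  # 3176
    # Largest model allowed by define_model: three hidden layers of 128 units.
    print(linear_macs([784, 128, 128, 128, 10]))  # 134400
```

Under these assumptions the FLOPS objective spans roughly two orders of magnitude across the search space, which is why it trades off so visibly against accuracy.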

import thop
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import optuna


DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
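DIR = ".."  # root directory where FashionMNIST is downloaded and read from
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30  # 3840 training samples per trial, to keep trials fast
N_VALID_EXAMPLES = BATCHSIZE * 10  # 1280 validation samples per trial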


def define_model(trial):
    n_layers = trial.suggest_int("n_layers", 1, 3)
    layers = []

    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float("dropout_{}".format(i), 0.2, 0.5)
        layers.append(nn.Dropout(p))

        in_features = out_features

    layers.append(nn.Linear(in_features, 10))
    layers.append(nn.LogSoftmax(dim=1))

    return nn.Sequential(*layers)


# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        F.nll_loss(model(data), target).backward()
        optimizer.step()


def eval_model(model, valid_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / N_VALID_EXAMPLES

    flops, _ = thop.profile(model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),), verbose=False)
    return flops, accuracy
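The accuracy in eval_model takes the arg-max over the 10 log-probabilities of each sample and counts how often it equals the label. The plain-Python sketch below mirrors that bookkeeping on nested lists instead of tensors; `batch_accuracy` is a hypothetical helper, not part of PyTorch. Note that eval_model divides by N_VALID_EXAMPLES rather than by the number of samples actually seen, which is equivalent here only because the validation Subset contains exactly N_VALID_EXAMPLES samples.

```python
from typing import Sequence


def batch_accuracy(log_probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose arg-max index equals the target label."""
    if len(log_probs) != len(targets):
        raise ValueError("log_probs and targets must have the same length")
    correct = 0
    for row, target in zip(log_probs, targets):
        # Plays the role of tensor.argmax(dim=1): index of the largest score.
        pred = max(range(len(row)), key=row.__getitem__)
        correct += int(pred == target)
    # Divide by the number of samples actually evaluated.
    return correct / len(targets)
```

For example, `batch_accuracy([[-0.1, -2.3, -3.0], [-1.5, -0.3, -2.0]], [0, 2])` predicts classes 0 and 1, gets the first sample right, and returns 0.5.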

Define the multi-objective function. Its objectives are FLOPS and accuracy, returned together as a tuple with FLOPS first.

def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    model = define_model(trial).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)
    flops, accuracy = eval_model(model, val_loader)
    return flops, accuracy
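Optuna pairs the i-th value returned by the objective with the i-th entry of the `directions` argument of `optuna.create_study`, so the order (FLOPS first, accuracy second) matters. A common way to reason about mixed directions is to negate every maximized objective so that all values can be compared as minimization targets. The helper below is a hypothetical sketch of that convention, not an Optuna API.

```python
from typing import Sequence, Tuple


def as_minimization(values: Sequence[float], directions: Sequence[str]) -> Tuple[float, ...]:
    """Map objective values to an equivalent all-minimize vector.

    Hypothetical helper: "maximize" objectives are negated,
    "minimize" objectives are kept unchanged.
    """
    if len(values) != len(directions):
        raise ValueError("expected one direction per objective value")
    out = []
    for value, direction in zip(values, directions):
        if direction == "minimize":
            out.append(value)
        elif direction == "maximize":
            out.append(-value)
        else:
            raise ValueError(f"unknown direction: {direction!r}")
    return tuple(out)
```

With made-up values, `as_minimization((50816.0, 0.82), ["minimize", "maximize"])` returns `(50816.0, -0.82)`; swapping the order of the returned values without swapping `directions` would silently optimize the wrong thing.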

Run the multi-objective optimization.

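If your optimization problem is multi-objective, Optuna expects you to specify an optimization direction for each objective, in the same order as the values returned by the objective function. In this example, we want to minimize the FLOPS (we want a faster model) and maximize the accuracy, so we set directions to ["minimize", "maximize"].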

study = optuna.create_study(directions=["minimize", "maximize"])
study.optimize(objective, n_trials=30, timeout=300)

print("Number of finished trials: ", len(study.trials))

Out:

Number of finished trials:  30
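A trial lies on the Pareto front when no other trial is at least as good in every objective and strictly better in at least one; here, no other model has FLOPS that are lower or equal and accuracy that is higher or equal, with at least one of the two strictly better. Optuna exposes these trials as `study.best_trials`. The quadratic-time sketch below only illustrates the dominance rule on plain (flops, accuracy) tuples; the names `dominates` and `pareto_front` are our own.

```python
from typing import List, Sequence, Tuple

DIRECTIONS = ("minimize", "maximize")  # FLOPS, accuracy: same order as the study


def dominates(a: Sequence[float], b: Sequence[float], directions: Sequence[str] = DIRECTIONS) -> bool:
    """Return True if point `a` Pareto-dominates point `b`."""
    strictly_better = False
    for x, y, d in zip(a, b, directions):
        if d == "minimize":
            better, worse = x < y, x > y
        else:  # "maximize"
            better, worse = x > y, x < y
        if worse:
            return False  # worse in one objective: cannot dominate
        if better:
            strictly_better = True
    return strictly_better


def pareto_front(points: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
    """Keep the points that no other point dominates (O(n^2) comparisons).

    A point never dominates itself, so no self-exclusion is needed.
    """
    return [p for p in points if not any(dominates(q, p) for q in points)]
```

With made-up points `[(3176.0, 0.70), (50816.0, 0.84), (134400.0, 0.83), (20000.0, 0.69)]`, `pareto_front` keeps `(3176.0, 0.70)` and `(50816.0, 0.84)`: the third model is beaten by the second on both axes, and the fourth by the first.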

Visually inspect the trials on the Pareto front.

optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])

Out:

ExperimentalWarning: plot_pareto_front is experimental (supported from v2.4.0). The interface can change in the future.


Total running time of the script: (1 minute 48.248 seconds)
