Training CIFAR-10 in 36 seconds on a single A100¶
In this example, we’ll show how to use FFCV with a ResNet-9 architecture to train a CIFAR-10 classifier to 92.6% accuracy in 36 seconds on a single NVIDIA A100 GPU.
The complete code and the corresponding training script are also available alongside this example.
Below, we give a step-by-step walkthrough.
First, we import torch and the necessary components from ffcv.
from typing import List
import torch as ch
import torchvision
from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter
Step 1: Create an FFCV-compatible CIFAR-10 dataset¶
First, we’ll use DatasetWriter to convert torchvision.datasets.CIFAR10 to FFCV format (see Writing datasets for more details).
We use a single RGBImageField to store each image and a single IntField to store each label.
datasets = {
    'train': torchvision.datasets.CIFAR10('/tmp', train=True, download=True),
    'test': torchvision.datasets.CIFAR10('/tmp', train=False, download=True)
}

for (name, ds) in datasets.items():
    writer = DatasetWriter(f'/tmp/cifar_{name}.beton', {
        'image': RGBImageField(),
        'label': IntField()
    })
    writer.from_indexed_dataset(ds)
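By default, RGBImageField stores raw pixel data, which is fine for a dataset as small as CIFAR-10. For larger datasets, the field can trade disk space for decode work. As a rough sketch (check your FFCV version for the exact parameter set), the writer could instead be configured to store JPEG-compressed images:

writer = DatasetWriter(f'/tmp/cifar_{name}.beton', {
    # 'jpg' stores JPEG-compressed images; max_resolution caps the stored side length
    'image': RGBImageField(write_mode='jpg', max_resolution=32, jpeg_quality=90),
    'label': IntField()
})

For CIFAR-10 we keep the raw defaults shown above.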
Step 2: Create data loaders¶
Next, we construct FFCV dataloaders from the .beton dataset files created above (see Making an FFCV dataloader for more details).
For the training set, we use a set of standard data augmentations: random horizontal flip, random translation, and Cutout.
Note that the transformation pipeline can mix standard transforms from ffcv with operations from other sources, such as any torch.nn.Module.
# Note that the statistics below are with respect to the uint8 range, [0, 255].
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
BATCH_SIZE = 512

loaders = {}
for name in ['train', 'test']:
    # Labels: decode, move to GPU, and squeeze the trailing dimension.
    label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice('cuda:0'), Squeeze()]
    image_pipeline: List[Operation] = [SimpleRGBImageDecoder()]

    # Add image transforms and normalization
    if name == 'train':
        image_pipeline.extend([
            RandomHorizontalFlip(),
            RandomTranslate(padding=2),
            Cutout(8, tuple(map(int, CIFAR_MEAN))),  # Note Cutout is done before normalization.
        ])
    image_pipeline.extend([
        ToTensor(),
        ToDevice('cuda:0', non_blocking=True),
        ToTorchImage(),
        Convert(ch.float16),
        torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    # Create loaders
    loaders[name] = Loader(f'/tmp/cifar_{name}.beton',
                           batch_size=BATCH_SIZE,
                           num_workers=8,
                           order=OrderOption.RANDOM,
                           drop_last=(name == 'train'),
                           pipelines={'image': image_pipeline,
                                      'label': label_pipeline})
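As a quick sanity check (this snippet is our addition, not part of the original walkthrough), we can pull a single batch from the training loader; FFCV loaders are ordinary Python iterables:

ims, labs = next(iter(loaders['train']))
# The pipeline delivers batches already on the GPU, converted and normalized.
print(ims.shape, ims.dtype, ims.device)  # expected: torch.Size([512, 3, 32, 32]) torch.float16 cuda:0
print(labs.shape)                        # expected: torch.Size([512]), thanks to Squeeze()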
Step 3: Set up the model architecture and optimization parameters¶
For the model, we use a custom ResNet-9 architecture from KakaoBrain.
class Mul(ch.nn.Module):
    def __init__(self, weight):
        super(Mul, self).__init__()
        self.weight = weight
    def forward(self, x): return x * self.weight

class Flatten(ch.nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)

class Residual(ch.nn.Module):
    def __init__(self, module):
        super(Residual, self).__init__()
        self.module = module
    def forward(self, x): return x + self.module(x)

def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1):
    return ch.nn.Sequential(
        ch.nn.Conv2d(channels_in, channels_out,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     groups=groups, bias=False),
        ch.nn.BatchNorm2d(channels_out),
        ch.nn.ReLU(inplace=True)
    )

NUM_CLASSES = 10
model = ch.nn.Sequential(
    conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
    conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
    Residual(ch.nn.Sequential(conv_bn(128, 128), conv_bn(128, 128))),
    conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
    ch.nn.MaxPool2d(2),
    Residual(ch.nn.Sequential(conv_bn(256, 256), conv_bn(256, 256))),
    conv_bn(256, 128, kernel_size=3, stride=1, padding=0),
    ch.nn.AdaptiveMaxPool2d((1, 1)),
    Flatten(),
    ch.nn.Linear(128, NUM_CLASSES, bias=False),
    Mul(0.2)
)
model = model.to(memory_format=ch.channels_last).cuda()
Note the ch.channels_last memory format, set when we move the model to the GPU.
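As a quick check that the network is wired correctly (again, our addition rather than part of the original example), we can count parameters and run a dummy forward pass:

# This ResNet-9 is small: roughly 2.3M parameters.
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.2f}M parameters')

with ch.no_grad():
    dummy = ch.randn(8, 3, 32, 32, device='cuda').to(memory_format=ch.channels_last)
    print(model(dummy).shape)  # torch.Size([8, 10])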
Next, we define the optimizer and hyperparameters. We use standard SGD on a cross-entropy loss with label smoothing, together with a cyclic (triangular) learning rate schedule.
import numpy as np
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler
EPOCHS = 24
opt = SGD(model.parameters(), lr=.5, momentum=0.9, weight_decay=5e-4)
iters_per_epoch = 50000 // BATCH_SIZE
# Triangular schedule: linear warmup to the peak LR over 5 epochs, then linear decay to 0.
lr_schedule = np.interp(np.arange((EPOCHS+1) * iters_per_epoch),
                        [0, 5 * iters_per_epoch, EPOCHS * iters_per_epoch],
                        [0, 1, 0])
scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
scaler = GradScaler()
loss_fn = CrossEntropyLoss(label_smoothing=0.1)
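The schedule array holds one multiplier per optimizer step, and LambdaLR scales the base learning rate (0.5) by it. A few spot checks (our addition) confirm the triangular shape:

print(lr_schedule[0])                    # 0.0 (start of warmup)
print(lr_schedule[5 * iters_per_epoch])  # 1.0 (peak, after 5 epochs of warmup)
print(lr_schedule[-1])                   # 0.0 (end of training)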
Step 4: Train and evaluate the model¶
Finally, we’re ready to train our model.
from tqdm import tqdm
for ep in range(EPOCHS):
    for ims, labs in tqdm(loaders['train']):
        opt.zero_grad(set_to_none=True)
        with autocast():
            out = model(ims)
            loss = loss_fn(out, labs)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        scheduler.step()

model.eval()
with ch.no_grad():
    total_correct, total_num = 0., 0.
    for ims, labs in tqdm(loaders['test']):
        with autocast():
            out = (model(ims) + model(ch.fliplr(ims))) / 2.  # Test-time augmentation
            total_correct += out.argmax(1).eq(labs).sum().cpu().item()
            total_num += ims.shape[0]

print(f'Accuracy: {total_correct / total_num * 100:.1f}%')
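If you want to keep the trained weights, saving the state dict is a single extra line (the output path below is just a placeholder):

ch.save(model.state_dict(), '/tmp/cifar_resnet9.pt')  # placeholder path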
Wrapping up¶
It’s that simple! In this tutorial, we used FFCV to train a CIFAR-10 classifier to 92.6% accuracy in 36 seconds.
For a different example using FFCV to speed up training, see Large-Scale Linear Regression.