15 minutes to get started with MMEngine
In this tutorial, we take training a ResNet-50 model on the CIFAR-10 dataset as an example. We will build a complete, configurable pipeline for both training and validation in only 80 lines of code with MMEngine.
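The whole process includes the following steps:
- Build a Model
- Build a Dataset and DataLoader
- Build an Evaluation Metric
- Build a Runner and Run the Task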
Build a Model
First, we need to build a model. In MMEngine, the model should inherit from BaseModel. In addition to the parameters that represent inputs from the dataset, its forward method needs to accept an extra argument called mode:
- for training, the value of mode is "loss", and the forward method should return a dict containing the key "loss".
- for validation, the value of mode is "predict", and the forward method should return results containing both predictions and labels.
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
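To see why forward takes (imgs, labels, mode), it helps to picture how a batch reaches it. The following is a minimal, dependency-free sketch, assuming each batch is unpacked into the arguments of forward and mode is added by the caller. ToyModel and run_forward are hypothetical names used only for illustration; they are not MMEngine APIs, and the arithmetic stands in for a real network and loss.

```python
# Simplified illustration of the `mode` contract. ToyModel and run_forward
# are hypothetical stand-ins and do not depend on MMEngine or PyTorch.


class ToyModel:

    def forward(self, imgs, labels, mode):
        scores = [2 * x for x in imgs]  # stand-in for a real network
        if mode == 'loss':
            # Training: return a dict that contains the key 'loss'.
            loss = sum(abs(s - y) for s, y in zip(scores, labels)) / len(labels)
            return {'loss': loss}
        elif mode == 'predict':
            # Validation: return both predictions and labels.
            return scores, labels
        raise ValueError(f'unsupported mode: {mode}')


def run_forward(model, batch, mode):
    # An (imgs, labels) batch is unpacked into positional arguments;
    # a dict batch is unpacked into keyword arguments.
    if isinstance(batch, dict):
        return model.forward(**batch, mode=mode)
    return model.forward(*batch, mode=mode)


batch = ([1.0, 2.0], [2.0, 5.0])
train_out = run_forward(ToyModel(), batch, mode='loss')
val_out = run_forward(ToyModel(), batch, mode='predict')
```

The same model object serves both phases; only the value of mode changes what forward returns.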
Build a Dataset and DataLoader
Next, we need to create a Dataset and a DataLoader for training and validation. For basic training and validation, we can simply use the built-in datasets provided by TorchVision.
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))
Build an Evaluation Metric
To validate and test the model, we need to define a metric called Accuracy. This metric needs to inherit from BaseMetric and implement the process and compute_metrics methods.

The process method accepts a batch of data from the dataloader together with the outputs the model produces when mode="predict". After processing the batch, it saves the information it needs to the self.results property.

The compute_metrics method accepts a results parameter, which holds all the information saved by process (in a distributed environment, results contains the information collected from the process calls in all processes). It uses this information to calculate and return a dict that holds the values of the evaluation metrics.
from mmengine.evaluator import BaseMetric


class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # save the intermediate result of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # return the dict containing the eval results
        # the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)
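The split between process and compute_metrics can be illustrated without MMEngine. The sketch below is a plain-Python stand-in, assuming predictions and labels are simple lists of class indices. ToyAccuracy is a hypothetical class that only mirrors the two-step pattern; it is not part of the library.

```python
# Plain-Python stand-in for the process / compute_metrics split.
# ToyAccuracy is hypothetical and does not depend on MMEngine.


class ToyAccuracy:

    def __init__(self):
        self.results = []

    def process(self, predictions, labels):
        # Called once per batch: store only what the final metric needs.
        correct = sum(int(p == y) for p, y in zip(predictions, labels))
        self.results.append({'batch_size': len(labels), 'correct': correct})

    def compute_metrics(self, results):
        # Called once after the last batch: aggregate over all batches.
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return {'accuracy': 100 * total_correct / total_size}


metric = ToyAccuracy()
metric.process([1, 0, 2, 2], [1, 0, 2, 1])  # 3 of 4 correct
metric.process([0, 1], [1, 1])              # 1 of 2 correct
final = metric.compute_metrics(metric.results)
```

Storing counts rather than per-batch accuracies keeps the result correct when batches have different sizes: here 4 of 6 samples are correct, whereas averaging the two batch accuracies (75 and 50) would give 62.5.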
Build a Runner and Run the Task
Now we can build a Runner with the previously defined Model, DataLoader, and Metric, along with some other configs, as shown below:
from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # the model used for training and validation.
    # Needs to meet specific interface requirements
    model=MMResNet50(),
    # working directory which saves training logs and weight files
    work_dir='./work_dir',
    # train dataloader needs to meet the PyTorch data loader protocol
    train_dataloader=train_dataloader,
    # optimizer wrapper for optimization with additional features like
    # AMP, gradient accumulation, etc.
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # training configs for specifying training epochs, validation intervals, etc.
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    # validation dataloader also needs to meet the PyTorch data loader protocol
    val_dataloader=val_dataloader,
    # validation configs for specifying additional parameters required for validation
    val_cfg=dict(),
    # validation evaluator. The default evaluator is used here with the Accuracy metric
    val_evaluator=dict(type=Accuracy),
)

runner.train()
Finally, let's put all the code above together into a complete script that uses the MMEngine Runner for training and validation:
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()
The training log should look similar to this:
2022/08/22 15:51:53 - mmengine - INFO -
------------------------------------------------------------
System environment:
sys.platform: linux
Python: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0]
CUDA available: True
numpy_random_seed: 1513128759
GPU 0: NVIDIA GeForce GTX 1660 SUPER
CUDA_HOME: /usr/local/cuda
...
2022/08/22 15:51:54 - mmengine - INFO - Checkpoints will be saved to /home/mazerun/work_dir by HardDiskBackend.
2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][10/1563] lr: 1.0000e-03 eta: 0:18:23 time: 0.1414 data_time: 0.0077 memory: 392 loss: 5.3465
2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][20/1563] lr: 1.0000e-03 eta: 0:11:29 time: 0.0354 data_time: 0.0077 memory: 392 loss: 2.7734
2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][30/1563] lr: 1.0000e-03 eta: 0:09:10 time: 0.0352 data_time: 0.0076 memory: 392 loss: 2.7789
2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][40/1563] lr: 1.0000e-03 eta: 0:08:00 time: 0.0353 data_time: 0.0073 memory: 392 loss: 2.5725
2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][50/1563] lr: 1.0000e-03 eta: 0:07:17 time: 0.0347 data_time: 0.0073 memory: 392 loss: 2.7382
2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][60/1563] lr: 1.0000e-03 eta: 0:06:49 time: 0.0347 data_time: 0.0072 memory: 392 loss: 2.5956
2022/08/22 15:51:58 - mmengine - INFO - Epoch(train) [1][70/1563] lr: 1.0000e-03 eta: 0:06:28 time: 0.0348 data_time: 0.0072 memory: 392 loss: 2.7351
...
2022/08/22 15:52:50 - mmengine - INFO - Saving checkpoint at 1 epochs
2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][10/313] eta: 0:00:03 time: 0.0122 data_time: 0.0047 memory: 392
2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][20/313] eta: 0:00:03 time: 0.0122 data_time: 0.0047 memory: 308
2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][30/313] eta: 0:00:03 time: 0.0123 data_time: 0.0047 memory: 308
...
2022/08/22 15:52:54 - mmengine - INFO - Epoch(val) [1][313/313] accuracy: 35.7000
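The iteration counts in the log (1563 for training, 313 for validation) follow from the dataset sizes and the batch size. The short check below assumes the standard CIFAR-10 split of 50,000 training and 10,000 test images and a DataLoader that keeps the final partial batch, which is the PyTorch default.

```python
import math

batch_size = 32
train_images = 50_000  # size of the CIFAR-10 training split
val_images = 10_000    # size of the CIFAR-10 test split

# The last, smaller batch is kept, so the number of iterations is rounded up.
train_iters = math.ceil(train_images / batch_size)
val_iters = math.ceil(val_images / batch_size)

# Size of the final validation batch.
last_val_batch = val_images - (val_iters - 1) * batch_size
```

The final validation batch holds 16 images instead of 32, which is one reason the Accuracy metric sums counts across batches rather than averaging per-batch accuracies.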
[Figure: the corresponding implementation in PyTorch and MMEngine]
In addition to these basic components, you can also use the Runner to easily combine and configure various training techniques, such as enabling mixed-precision training and gradient accumulation (see OptimWrapper), configuring the learning rate decay curve (see Parameter Scheduler), and more.
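As a rough sketch of what such a configuration might look like, the fragment below defines an optimizer wrapper with mixed precision and gradient accumulation, plus a step-wise learning rate schedule. The specific values (accumulative_counts=4, milestones=[2, 4], gamma=0.1) are illustrative assumptions rather than recommended settings, and AmpOptimWrapper is assumed to run on a device that supports mixed precision, such as a CUDA GPU.

```python
# Illustrative config fragment; the numeric values are assumptions chosen
# for a 5-epoch run, not tuned settings.

# Mixed-precision training with gradient accumulation over 4 iterations.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    accumulative_counts=4,
    optimizer=dict(type='SGD', lr=0.001, momentum=0.9),
)

# Multiply the learning rate by 0.1 after epochs 2 and 4.
param_scheduler = dict(
    type='MultiStepLR',
    by_epoch=True,
    milestones=[2, 4],
    gamma=0.1,
)
```

These dicts would be passed to the Runner as optim_wrapper=optim_wrapper and param_scheduler=param_scheduler, in place of the plain optim_wrapper used above.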