Migrate Runner from MMCV to MMEngine
Introduction
As MMCV supports more and more deep learning tasks and users' needs become more complicated, we have higher requirements for the flexibility and versatility of the existing Runner in MMCV. Therefore, MMEngine implements a more general and flexible Runner based on MMCV to support more complicated training processes.

The Runner in MMEngine expands its scope and takes on more functions. We abstracted the training loop controllers (`EpochBasedTrainLoop`/`IterBasedTrainLoop`), the validation loop controller (`ValLoop`), and the testing loop controller (`TestLoop`) to make it more convenient for users to customize their training processes.

First, we will introduce how to migrate the entry point of training from MMCV to MMEngine, to simplify and unify the training script. Then, we'll introduce the differences in the instantiation of Runner between MMCV and MMEngine in detail.
Migrate the entry point
Take MMDet as an example, the differences between training scripts in MMCV and MMEngine are as follows:
Migrate the configuration file
| Configuration file based on MMCV Runner | Configuration file based on MMEngine Runner |
| --- | --- |
The Runner in MMEngine provides more customizable components, including the training/validation/testing processes and the DataLoader. Therefore, the configuration file is a bit longer compared to MMCV. MMEngine follows the WYSIWYG principle and reorganizes the hierarchy of each component in the configuration, so that most first-level fields of the configuration correspond to core components of the Runner, such as the DataLoader, Evaluator, and Hooks. The new-format configuration file helps users read and understand the core components of the Runner, and ignore the relatively unimportant parts.
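For illustration, the reorganized hierarchy can be sketched as a plain Python dict whose first-level keys mirror the Runner's constructor arguments. All values below are hypothetical placeholders, not a working config:

```python
# Hypothetical MMEngine-style config: each first-level key corresponds
# to one core component of the Runner.
config = dict(
    model=dict(type='ToyModel'),                         # built into the model
    train_dataloader=dict(batch_size=32, num_workers=2), # built into a DataLoader
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    param_scheduler=dict(type='MultiStepLR', milestones=[2, 3], gamma=0.1),
    train_cfg=dict(by_epoch=True, max_epochs=4, val_interval=1),  # training loop
    default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),
    env_cfg=dict(dist_cfg=dict(backend='nccl')),
    log_level='INFO',
)
```

Reading the first-level keys alone already tells you which components the Runner will build.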
Migrate the training script
Compared with the Runner in MMCV, the Runner in MMEngine takes on more functions, such as building the DataLoader and the distributed model. Therefore, we no longer need to build components like the DataLoader and the distributed model manually. We can configure them during the instantiation of the Runner, and they will be built in the training/validation/testing processes. Take the training script of MMDet as an example:
| Training script based on MMCV Runner | Training script based on MMEngine Runner |
| --- | --- |
The table above shows the differences between training scripts of the MMEngine Runner and the MMCV Runner. Repositories in OpenMMLab 1.x each organize their own process to build the Runner, which leads to a large amount of redundant code. MMEngine unifies and formalizes the building process, such as setting the random seed, initializing the distributed environment, building the DataLoader, and building the Optimizer. This helps downstream repositories simplify the preparation of the runner: they only need to configure the parameters of the Runner.

For downstream repositories, a training script based on the MMEngine Runner not only simplifies `tools/train.py`, but also makes it possible to omit `apis/train.py` entirely. Similarly, we can set the random seed and initialize the distributed environment by configuring the parameters of the Runner, without implementing the corresponding code ourselves.
Migrate Runner
This section describes the differences in the training, validation, and testing processes between the MMCV Runner and the MMEngine Runner, as follows.
- Prepare logger
- Set random seed
- Initialize environment variables
- Prepare data
- Prepare model
- Prepare optimizer
- Prepare hooks
- Prepare testing/validation components
- Build runner
- Load checkpoint
- Training process
- Testing process
- Custom training process
The following tutorial will describe the difference above in detail.
Prepare logger
Prepare logger in MMCV
MMCV needs to call `get_logger` to obtain a formatted logger and uses it to output and log the training information.
```python
logger = get_logger(name='custom', log_file=log_file, log_level=cfg.log_level)
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
```
The instantiation of the Runner also relies on the logger:
```python
runner = Runner(
    ...,
    logger=logger,
    ...)
```
Prepare logger in MMEngine
Configure the `log_level` for the Runner, and it will build the logger automatically.

```python
log_level = 'INFO'
```
Set random seed
Set random seed in MMCV
Set random seed manually in training script:
```python
...
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
            f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
...
```
Set random seed in MMEngine
Configure `randomness` for the Runner; see more information in `Runner.set_randomness`.
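Conceptually, this amounts to seeding every random number generator the training process touches. A stdlib-only sketch of the idea (the real `Runner.set_randomness` also seeds NumPy and PyTorch and can enable deterministic cuDNN kernels; the seed value below is arbitrary):

```python
import random

# The randomness config passed to Runner later in this tutorial:
randomness = dict(seed=2022)


def set_randomness(seed: int, deterministic: bool = False) -> None:
    """Sketch of seeding: the same seed always yields the same stream of
    pseudo-random numbers, which makes training runs reproducible."""
    random.seed(seed)
    # The real helper additionally seeds numpy/torch and, if `deterministic`,
    # forces deterministic cuDNN kernels.


set_randomness(randomness['seed'])
first = [random.random() for _ in range(3)]
set_randomness(randomness['seed'])
second = [random.random() for _ in range(3)]
assert first == second  # identical seed, identical sequence
```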
Configuration changes
| Configuration of MMCV | Configuration of MMEngine |
| --- | --- |
Initialize environment variables
MMCV needs to set up the launcher of distributed training, set environment variables for multi-process communication, initialize the distributed environment, and wrap the model with the distributed wrapper, like this:
```python
...
setup_multi_processes(cfg)
init_dist(cfg.launcher, **cfg.dist_params)
model = MMDistributedDataParallel(
    model,
    device_ids=[int(os.environ['LOCAL_RANK'])],
    broadcast_buffers=False,
    find_unused_parameters=find_unused_parameters)
```
As for MMEngine, you can set up the launcher by configuring the `launcher` argument of the Runner, and configure the other items mentioned above in `env_cfg`. See more information in the table below:
Configuration changes
| MMCV configuration | MMEngine configuration |
| --- | --- |
In this tutorial, we set `env_cfg` to:

```python
env_cfg = dict(dist_cfg=dict(backend='nccl'))
```
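A fuller `env_cfg` can also control cuDNN and multiprocessing behavior. A sketch with the commonly seen fields (the values shown are illustrative defaults to adapt per project):

```python
env_cfg = dict(
    cudnn_benchmark=False,                # True trades reproducibility for speed
    mp_cfg=dict(
        mp_start_method='fork',           # multiprocessing start method
        opencv_num_threads=0),            # avoid thread oversubscription in workers
    dist_cfg=dict(backend='nccl'),        # process-group backend for distributed training
)
```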
Prepare data
Both the MMEngine and MMCV Runner can accept a built DataLoader:
```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(
    root='data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2)

val_dataset = CIFAR10(
    root='data', train=False, download=True, transform=transform)
val_dataloader = DataLoader(
    val_dataset, batch_size=128, shuffle=False, num_workers=2)
```
Configuration changes
| Configuration of MMCV | Configuration of MMEngine |
| --- | --- |
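Beyond a prebuilt DataLoader, the MMEngine Runner also accepts a dict-style dataloader config and builds the DataLoader itself. A hypothetical sketch (field names follow the Runner's dataloader arguments; the dataset values are placeholders):

```python
# Dict-style dataloader config: the Runner builds the DataLoader from it.
train_dataloader = dict(
    batch_size=128,
    num_workers=2,
    sampler=dict(type='DefaultSampler', shuffle=True),  # handles shuffling and
                                                        # per-rank sharding in DDP
    dataset=dict(type='CIFAR10', root='data', train=True),
)
```

Using a dict keeps the whole data pipeline declarable in the configuration file.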
Prepare model
See Migrate model from mmcv for more information
```python
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModel


class Model(BaseModel):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, img, label, mode):
        feat = self.pool(F.relu(self.conv1(img)))
        feat = self.pool(F.relu(self.conv2(feat)))
        feat = feat.view(-1, 16 * 5 * 5)
        feat = F.relu(self.fc1(feat))
        feat = F.relu(self.fc2(feat))
        feat = self.fc3(feat)
        if mode == 'loss':
            loss = self.loss_fn(feat, label)
            return dict(loss=loss)
        else:
            return [feat.argmax(1)]


model = Model()
```
Prepare optimizer
Prepare optimizer in MMCV
The MMCV Runner can accept a built optimizer:

```python
from torch.optim import SGD

optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
```
For complicated configurations of optimizers, MMCV needs to build optimizers based on the optimizer constructors.
```python
optimizer_cfg = dict(
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(norm_decay_mult=0))


def build_optimizer_constructor(cfg):
    constructor_type = cfg.get('type')
    if constructor_type in OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
    else:
        raise KeyError(f'{constructor_type} is not registered '
                       'in the optimizer builder registry.')


def build_optimizer(model, cfg):
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer


optimizer = build_optimizer(model, optimizer_cfg)
```
Prepare optimizer in MMEngine
MMEngine needs to configure `optim_wrapper` for the Runner. For more complicated cases, you can also configure the `optim_wrapper` in more detail. See more information in the API documents.
Configuration changes
| Configuration in MMCV | Configuration in MMEngine |
| --- | --- |
For high-level tasks like detection and classification, MMCV needs to configure `optim_config` to build the `OptimizerHook`, while this is not necessary for MMEngine.
The `optim_wrapper` used in this tutorial is as follows:

```python
from torch.optim import SGD

optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
optim_wrapper = dict(optimizer=optimizer)
```
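For comparison, the MMCV `paramwise_cfg` example from the previous section would be expressed roughly like this in MMEngine (a sketch; see the optimizer wrapper documentation for the full set of options):

```python
# MMEngine builds the optimizer and its constructor from optim_wrapper
# directly, so no manual build_optimizer helper is needed.
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(norm_decay_mult=0),  # disable weight decay on norm layers
)
```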
Prepare hooks
Prepare hooks in MMCV
The commonly used hooks configuration in MMCV is as follows:
```python
# learning rate scheduler config
lr_config = dict(policy='step', step=[2, 3])
# configuration of optimizer
optimizer_config = dict(grad_clip=None)
# configuration of saving checkpoints periodically
checkpoint_config = dict(interval=1)
# save log periodically; multiple hooks can be used simultaneously
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])

# register hooks to the runner; they will be invoked automatically
runner.register_training_hooks(
    lr_config=lr_config,
    optimizer_config=optimizer_config,
    checkpoint_config=checkpoint_config,
    log_config=log_config)
```
Among them:
- `lr_config` is used for `LrUpdaterHook`
- `optimizer_config` is used for `OptimizerHook`
- `checkpoint_config` is used for `CheckpointHook`
- `log_config` is used for `LoggerHook`
Besides the hooks mentioned above, the MMCV Runner builds an `IterTimerHook` automatically. The MMCV Runner registers the training hooks after instantiation, while the MMEngine Runner initializes the hooks during instantiation.
Prepare hooks in MMEngine
The MMEngine Runner takes some commonly used hooks in MMCV as default hooks.

Compared with the example of MMCV:

- `LrUpdaterHook` corresponds to the `ParamSchedulerHook`; find more details in migrate scheduler
- MMEngine optimizes the model in `train_step`, therefore we do not need `OptimizerHook` in MMEngine anymore
- MMEngine takes `CheckpointHook` as a default hook
- MMEngine takes `LoggerHook` as a default hook
Therefore, we can achieve the same effect as the MMCV example as long as we configure the `param_scheduler` correctly. We can also register custom hooks in the MMEngine runner; find more details in the runner tutorial and migrate hook.
| Commonly used hooks in MMCV | Default hooks in MMEngine |
| --- | --- |
The parameter scheduler used in this tutorial is as follows:

```python
param_scheduler = dict(type='MultiStepLR', milestones=[2, 3], gamma=0.1)
```
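This schedule multiplies the learning rate by `gamma` at each milestone. A quick stdlib sketch of the arithmetic (the helper below is illustrative, not MMEngine API):

```python
def multistep_lr(base_lr, milestones, gamma, epoch):
    """LR after `epoch` epochs under a MultiStepLR-style policy: the base
    rate is multiplied by `gamma` once for every milestone already passed."""
    return base_lr * gamma ** sum(1 for m in milestones if epoch >= m)


# base lr 0.1, decayed 10x after epochs 2 and 3:
schedule = [round(multistep_lr(0.1, [2, 3], 0.1, e), 4) for e in range(5)]
assert schedule == [0.1, 0.1, 0.01, 0.001, 0.001]
```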
Prepare testing/validation components
MMCV implements the validation process with `EvalHook`, which we will not discuss in detail here. Given that validation is a common step in training, MMEngine abstracts it into two independent modules: Evaluator and ValLoop. We can customize the metric or the validation process by defining a new loop or a new metric.
```python
import torch
from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS


@METRICS.register_module(force=True)
class ToyAccuracyMetric(BaseMetric):

    def process(self, label, pred) -> None:
        self.results.append((label[1], pred, len(label[1])))

    def compute_metrics(self, results: list) -> dict:
        num_sample = 0
        acc = 0
        for label, pred, batch_size in results:
            acc += (label == torch.stack(pred)).sum()
            num_sample += batch_size
        return dict(Accuracy=acc / num_sample)
```
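The `process`/`compute_metrics` contract can be exercised without a Runner. A stdlib-only sketch of the accumulate-then-reduce pattern (`ToyMetric` below is a stand-in, not MMEngine's `BaseMetric`):

```python
class ToyMetric:
    """Stand-in for the BaseMetric contract: `process` accumulates per-batch
    results into `self.results`; `compute_metrics` reduces them at the end."""

    def __init__(self):
        self.results = []

    def process(self, labels, preds):
        correct = sum(int(l == p) for l, p in zip(labels, preds))
        self.results.append((correct, len(labels)))  # keep only what is needed

    def compute_metrics(self):
        correct = sum(c for c, _ in self.results)
        total = sum(n for _, n in self.results)
        return dict(Accuracy=correct / total)


metric = ToyMetric()
metric.process([0, 1, 2], [0, 1, 1])  # 2 of 3 correct
metric.process([1, 0], [1, 0])        # 2 of 2 correct
assert metric.compute_metrics() == dict(Accuracy=0.8)
```

Accumulating only the sufficient statistics per batch keeps memory bounded, which is why `BaseMetric` splits the two steps.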
After defining the metric, we also need to configure the evaluator and the loop for the Runner. The example used in this tutorial is as follows:

```python
val_evaluator = dict(type='ToyAccuracyMetric')
val_cfg = dict(type='ValLoop')
```
| Configure validation in MMCV | Configure validation in MMEngine |
| --- | --- |
Build Runner
Building Runner in MMCV
```python
runner = EpochBasedRunner(
    model=model,
    optimizer=optimizer,
    work_dir=work_dir,
    logger=logger,
    max_epochs=4)
```
Building Runner in MMEngine
The choice of `EpochBasedRunner` and the `max_epochs` argument in MMCV are moved into `train_cfg` in MMEngine. All parameters configurable in `train_cfg` are listed below:

- `by_epoch`: `True` is equivalent to `EpochBasedRunner`, `False` to `IterBasedRunner`
- `max_epochs`/`max_iters`: equivalent to `max_epochs` and `max_iters` in MMCV
- `val_interval`: equivalent to `interval` of `EvalHook` in MMCV
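As a quick reference, the two loop flavors can be configured like this (the values are illustrative):

```python
# Epoch-based training, validating once per epoch (EpochBasedTrainLoop):
train_cfg_epoch = dict(by_epoch=True, max_epochs=4, val_interval=1)

# Iteration-based training, validating every 1000 iterations (IterBasedTrainLoop):
train_cfg_iter = dict(by_epoch=False, max_iters=10000, val_interval=1000)
```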
```python
from mmengine.runner import Runner

runner = Runner(
    model=model,  # model to be optimized
    work_dir='./work_dir',  # working directory
    randomness=randomness,  # random seed
    env_cfg=env_cfg,  # environment config
    launcher='none',  # launcher for distributed training
    optim_wrapper=optim_wrapper,  # optimizer wrapper
    param_scheduler=param_scheduler,  # parameter scheduler
    train_dataloader=train_dataloader,  # train dataloader
    train_cfg=dict(by_epoch=True, max_epochs=4, val_interval=1),  # training loop
    val_dataloader=val_dataloader,  # validation dataloader
    val_evaluator=val_evaluator,  # evaluator and metrics
    val_cfg=val_cfg)  # validation loop
```
Load checkpoint
Loading checkpoint in MMCV
```python
if cfg.resume_from:
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)
```
Loading checkpoint in MMEngine
```python
runner = Runner(
    ...,
    load_from='/path/to/checkpoint',
    resume=True)
```
| Configuration of loading checkpoint in MMCV | Configuration of loading checkpoint in MMEngine |
| --- | --- |
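To summarize the `load_from`/`resume` combinations, a sketch of the common cases (see the Runner documentation for the exact semantics; the checkpoint path is a placeholder):

```python
# Common load/resume configurations for the MMEngine Runner.
resume_training = dict(load_from='/path/to/checkpoint', resume=True)  # weights plus
                                                                      # optimizer/scheduler/epoch state
finetune_only = dict(load_from='/path/to/checkpoint', resume=False)   # weights only,
                                                                      # training starts fresh
from_scratch = dict(load_from=None, resume=False)                     # no checkpoint at all
```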
Training process
Training process in MMCV
Resume or load the checkpoint first, and then start training:

```python
if cfg.resume_from:
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow)
```
Training process in MMEngine
The MMEngine Runner completes the process mentioned above in `Runner.__init__`, so we only need to call `Runner.train`:

```python
runner.train()
```
Testing process
Since the MMCV Runner does not integrate the test function, we need to implement the test scripts ourselves. For the MMEngine Runner, as long as we have configured the `test_dataloader`, `test_cfg`, and `test_evaluator` for the Runner, we can call `Runner.test` to start the testing process.

If `work_dir` is the same as the one used for training:
```python
runner = Runner(
    model=model,
    work_dir='./work_dir',
    randomness=randomness,
    env_cfg=env_cfg,
    launcher='none',  # disable distributed training
    optim_wrapper=optim_wrapper,
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_evaluator=val_evaluator,
    val_cfg=val_cfg,
    test_dataloader=val_dataloader,  # assume testing and validation share the same data and evaluator
    test_evaluator=val_evaluator,
    test_cfg=dict(type='TestLoop'),
)
runner.test()
```
If `work_dir` differs from the one used for training, configure `load_from` manually:
```python
runner = Runner(
    model=model,
    work_dir='./test_work_dir',
    load_from='./work_dir/epoch_5.pth',  # set load_from additionally
    randomness=randomness,
    env_cfg=env_cfg,
    launcher='none',
    optim_wrapper=optim_wrapper,
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_evaluator=val_evaluator,
    val_cfg=val_cfg,
    test_dataloader=val_dataloader,
    test_evaluator=val_evaluator,
    test_cfg=dict(type='TestLoop'),
)
runner.test()
```
Customize training process
If we want to customize a training/validation process in MMCV, we need to override `Runner.val` or `Runner.train` in a custom Runner. Take overriding `Runner.train` as an example: suppose we need to train with the same batch twice in each iteration; we can override `Runner.train` like this:
```python
class CustomRunner(EpochBasedRunner):

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            for _ in range(2):
                self.call_hook('before_train_iter')
                self.run_iter(data_batch, train_mode=True, **kwargs)
                self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1
```
In MMEngine, we need to customize a train loop.
```python
from mmengine.registry import LOOPS
from mmengine.runner import EpochBasedTrainLoop


@LOOPS.register_module()
class CustomEpochBasedTrainLoop(EpochBasedTrainLoop):

    def run_iter(self, idx, data_batch) -> None:
        for _ in range(2):
            super().run_iter(idx, data_batch)
```
Then, we need to set `type` to `CustomEpochBasedTrainLoop` in `train_cfg`. Note that `by_epoch` and `type` cannot be configured at the same time: once `by_epoch` is configured, the type of the training loop will be inferred as `EpochBasedTrainLoop`.
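The behavior of the custom loop can be illustrated without MMEngine at all. A minimal stand-in for the loop classes (all names below are hypothetical) that runs each batch twice:

```python
class TrainLoop:
    """Minimal stand-in for an epoch-based train loop."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.steps = 0  # counts simulated forward/backward passes

    def run_iter(self, idx, data_batch):
        self.steps += 1  # stands in for one training step on `data_batch`

    def run_epoch(self):
        for idx, batch in enumerate(self.dataloader):
            self.run_iter(idx, batch)


class CustomTrainLoop(TrainLoop):
    """Override the per-iteration hook to train on the same batch twice."""

    def run_iter(self, idx, data_batch):
        for _ in range(2):
            super().run_iter(idx, data_batch)


loop = CustomTrainLoop([('img', 'label')] * 5)  # 5 dummy batches
loop.run_epoch()
assert loop.steps == 10  # each batch processed twice
```

Only the per-iteration behavior is overridden; the epoch bookkeeping stays in the base class, which is exactly the split MMEngine's loop classes rely on.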
runner = Runner(
model=model,
work_dir='./test_work_dir',
randomness=randomness,
env_cfg=env_cfg,
launcher='none',
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
train_dataloader=train_dataloader,
train_cfg=dict(
type='CustomEpochBasedTrainLoop',
max_epochs=5,
val_interval=1),
val_dataloader=val_dataloader,
val_evaluator=val_evaluator,
val_cfg=val_cfg,
test_dataloader=val_dataloader,
test_evaluator=val_evaluator,
test_cfg=dict(type='TestLoop'),
)
runner.train()
For more complicated migration needs of the Runner, you can refer to the runner tutorials and runner design.