mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add revert_sync_batchnorm (#1253)
* [Feature] Add revert_sync_batchnorm
* support mmsyncbn (to be tested)
* Test passed
* Update docstring, rename the test file
* remove test_sync_bn
* add comment
* add mmcv.ops check
* Add a comment
* Add notes and relax test req

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
parent 99088c81a8
commit 642d281823
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .flops_counter import get_model_complexity_info
 from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
 from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
                           KaimingInit, NormalInit, PretrainedInit,
                           TruncNormalInit, UniformInit, XavierInit,
@@ -14,5 +15,5 @@ __all__ = [
     'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
     'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
     'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
-    'Caffe2XavierInit'
+    'Caffe2XavierInit', 'revert_sync_batchnorm'
 ]
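Adding the name to `__all__` makes the helper part of the public `mmcv.cnn.utils` interface, so downstream code can import it directly (a minimal sketch, not part of the diff):

    # import path taken from the test file added below
    from mmcv.cnn.utils import revert_sync_batchnorm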
@@ -0,0 +1,59 @@
+import torch
+
+import mmcv
+
+
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+    """A general BatchNorm layer without input dimension check.
+
+    Reproduced from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference among BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
+    is `_check_input_dim`, which is designed for tensor sanity checks.
+    The check is bypassed in this class for the convenience of converting
+    SyncBatchNorm.
+    """
+
+    def _check_input_dim(self, input):
+        return
+
+
+def revert_sync_batchnorm(module):
+    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
+    `BatchNormXd` layers.
+
+    Adapted from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+    Args:
+        module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+    Returns:
+        module_output: The converted module with `BatchNormXd` layers.
+    """
+    module_output = module
+    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+    if hasattr(mmcv, 'ops'):
+        module_checklist.append(mmcv.ops.SyncBatchNorm)
+    if isinstance(module, tuple(module_checklist)):
+        module_output = _BatchNormXd(module.num_features, module.eps,
+                                     module.momentum, module.affine,
+                                     module.track_running_stats)
+        if module.affine:
+            # no_grad() may not be needed here, but it is kept for
+            # consistency with `convert_sync_batchnorm()`
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        module_output.training = module.training
+        # qconfig exists in quantized models
+        if hasattr(module, 'qconfig'):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, revert_sync_batchnorm(child))
+    del module
+    return module_output
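For context, a minimal usage sketch of the new helper (not part of the commit; the toy model and shapes are illustrative assumptions):

    import torch
    import torch.nn as nn

    from mmcv.cnn.utils import revert_sync_batchnorm

    # A toy model whose BN layers were converted to SyncBN, as is
    # typically done before distributed training.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # SyncBN cannot run without a distributed process group; reverting to
    # _BatchNormXd makes the model usable for single-device inference.
    model = revert_sync_batchnorm(model)
    out = model(torch.randn(1, 3, 10, 10))  # now runs on CPU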
@@ -0,0 +1,58 @@
+import os
+import platform
+
+import numpy as np
+import pytest
+import torch
+import torch.distributed as dist
+
+from mmcv.cnn.bricks import ConvModule
+from mmcv.cnn.utils import revert_sync_batchnorm
+
+if platform.system() == 'Windows':
+    import regex as re
+else:
+    import re
+
+
+def test_revert_syncbn():
+    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
+    x = torch.randn(1, 3, 10, 10)
+    # Expect a ValueError: SyncBN is not supported on CPU
+    with pytest.raises(ValueError):
+        y = conv(x)
+    conv = revert_sync_batchnorm(conv)
+    y = conv(x)
+    assert y.shape == (1, 8, 9, 9)
+
+
+def test_revert_mmsyncbn():
+    if 'SLURM_NTASKS' not in os.environ or int(os.environ['SLURM_NTASKS']) < 2:
+        print('Must run on SLURM with more than 1 process!\n'
+              'srun -p test --gres=gpu:2 -n2')
+        return
+    rank = int(os.environ['SLURM_PROCID'])
+    world_size = int(os.environ['SLURM_NTASKS'])
+    local_rank = int(os.environ['SLURM_LOCALID'])
+    node_list = str(os.environ['SLURM_NODELIST'])
+
+    node_parts = re.findall('[0-9]+', node_list)
+    os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
+                                 f'.{node_parts[3]}.{node_parts[4]}')
+    os.environ['MASTER_PORT'] = '12341'
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['RANK'] = str(rank)
+
+    dist.init_process_group('nccl')
+    torch.cuda.set_device(local_rank)
+    x = torch.randn(1, 3, 10, 10).cuda()
+    dist.broadcast(x, src=0)
+    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='MMSyncBN')).cuda()
+    conv.eval()
+    y_mmsyncbn = conv(x).detach().cpu().numpy()
+    conv = revert_sync_batchnorm(conv)
+    y_bn = conv(x).detach().cpu().numpy()
+    assert np.all(np.isclose(y_bn, y_mmsyncbn, 1e-3))
+    conv, x = conv.to('cpu'), x.to('cpu')
+    y_bn_cpu = conv(x).detach().numpy()
+    assert np.all(np.isclose(y_bn, y_bn_cpu, 1e-3))
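The distributed test above needs SLURM, but the key invariant, that reverting shares the original parameters and running statistics rather than re-initializing them, can be checked on CPU alone (a sketch, assuming `nn.SyncBatchNorm` is constructed without a process group, which PyTorch allows):

    import torch
    import torch.nn as nn

    from mmcv.cnn.utils import revert_sync_batchnorm

    sync_bn = nn.SyncBatchNorm(8)
    with torch.no_grad():
        sync_bn.weight += 1.0          # perturb away from the defaults
        sync_bn.running_mean += 0.5

    bn = revert_sync_batchnorm(sync_bn)
    # The reverted layer reuses the same tensors, so learned state survives.
    assert bn.weight is sync_bn.weight
    assert torch.equal(bn.running_mean, sync_bn.running_mean)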