[Feature] Add revert_sync_batchnorm (#1253)

* [Feature] Add revert_sync_batchnorm

* support mmsyncbn (to be tested)

* Test passed

* Update docstring, rename the test file

* remove test_sync_bn

* add comment

* add mmcv.ops check

* Add a comment

* Add notes and relax test req

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
Tong Gao 2021-09-08 10:59:39 +08:00 committed by GitHub
parent 99088c81a8
commit 642d281823
3 changed files with 119 additions and 1 deletion


@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .flops_counter import get_model_complexity_info
 from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
 from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
                           KaimingInit, NormalInit, PretrainedInit,
                           TruncNormalInit, UniformInit, XavierInit,
@@ -14,5 +15,5 @@ __all__ = [
     'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
     'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
     'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
-    'Caffe2XavierInit'
+    'Caffe2XavierInit', 'revert_sync_batchnorm'
 ]
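With this change `revert_sync_batchnorm` is re-exported from `mmcv.cnn.utils`. A minimal usage sketch (not part of this diff) that round-trips a model through PyTorch's stock `convert_sync_batchnorm` helper and back, so the result can run on CPU:

import torch

from mmcv.cnn.utils import revert_sync_batchnorm

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU())
# Forward direction: the stock PyTorch helper this PR reverses.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# Backward direction: every SyncBN layer becomes a CPU-friendly _BatchNormXd.
model = revert_sync_batchnorm(model)
y = model(torch.randn(1, 3, 10, 10))  # works on CPU, no process group needed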


@@ -0,0 +1,59 @@
import torch

import mmcv


class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    """A general BatchNorm layer without input dimension check.

    Reproduced from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
    is `_check_input_dim`, which performs tensor sanity checks. The check
    has been bypassed in this class for the convenience of converting
    SyncBatchNorm.
    """

    def _check_input_dim(self, input):
        return


def revert_sync_batchnorm(module):
    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
    `BatchNormXd` layers.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        module_output: The converted module with `BatchNormXd` layers.
    """
    module_output = module
    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
    # MMSyncBN is only available when mmcv is built with ops support.
    if hasattr(mmcv, 'ops'):
        module_checklist.append(mmcv.ops.SyncBatchNorm)
    if isinstance(module, tuple(module_checklist)):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            # no_grad() may not be needed here but
            # just to be consistent with `convert_sync_batchnorm()`
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        module_output.training = module.training
        # qconfig exists in quantized models
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    # Recurse into children so nested SyncBN layers are converted as well.
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    del module
    return module_output
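To see why bypassing `_check_input_dim` matters: each per-rank BatchNorm class refuses input of the wrong dimensionality, so none of them can serve as a universal target when converting a dimension-agnostic `SyncBatchNorm`. A short illustration (plain PyTorch, independent of this diff):

import torch

bn2d = torch.nn.BatchNorm2d(4)
try:
    # Feeding 3D input into BatchNorm2d trips its _check_input_dim.
    bn2d(torch.randn(2, 4, 8))
except ValueError as e:
    print(e)  # expected 4D input (got 3D input)

`_BatchNormXd` accepts any rank, so the reverted layer remains a drop-in replacement whatever the dimensionality of the original data.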


@@ -0,0 +1,58 @@
import os
import platform

import numpy as np
import pytest
import torch
import torch.distributed as dist

from mmcv.cnn.bricks import ConvModule
from mmcv.cnn.utils import revert_sync_batchnorm

if platform.system() == 'Windows':
    import regex as re
else:
    import re


def test_revert_syncbn():
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
    x = torch.randn(1, 3, 10, 10)
    # Expect a ValueError prompting that SyncBN is not supported on CPU
    with pytest.raises(ValueError):
        y = conv(x)
    conv = revert_sync_batchnorm(conv)
    y = conv(x)
    assert y.shape == (1, 8, 9, 9)


def test_revert_mmsyncbn():
    if 'SLURM_NTASKS' not in os.environ or int(os.environ['SLURM_NTASKS']) < 2:
        print('Must run on slurm with more than 1 process!\n'
              'srun -p test --gres=gpu:2 -n2')
        return
    rank = int(os.environ['SLURM_PROCID'])
    world_size = int(os.environ['SLURM_NTASKS'])
    local_rank = int(os.environ['SLURM_LOCALID'])
    node_list = str(os.environ['SLURM_NODELIST'])
    node_parts = re.findall('[0-9]+', node_list)
    os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
                                 f'.{node_parts[3]}.{node_parts[4]}')
    os.environ['MASTER_PORT'] = '12341'
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    dist.init_process_group('nccl')
    torch.cuda.set_device(local_rank)
    x = torch.randn(1, 3, 10, 10).cuda()
    dist.broadcast(x, src=0)
    # Compare MMSyncBN output with the reverted BN output on the same input.
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='MMSyncBN')).cuda()
    conv.eval()
    y_mmsyncbn = conv(x).detach().cpu().numpy()
    conv = revert_sync_batchnorm(conv)
    y_bn = conv(x).detach().cpu().numpy()
    assert np.all(np.isclose(y_bn, y_mmsyncbn, 1e-3))
    # The reverted module must also produce the same result on CPU.
    conv, x = conv.to('cpu'), x.to('cpu')
    y_bn_cpu = conv(x).detach().numpy()
    assert np.all(np.isclose(y_bn, y_bn_cpu, 1e-3))
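The MMSyncBN test above needs two GPUs under slurm, but the property it checks, that reverting preserves the affine parameters and running statistics, can be sanity-checked on CPU against a plain `BatchNorm2d`. A minimal sketch (an illustration, not part of this PR's test suite):

import torch

from mmcv.cnn.utils import revert_sync_batchnorm

sync_bn = torch.nn.SyncBatchNorm(3)
# Give the running statistics non-default values so the comparison is
# meaningful.
sync_bn.running_mean.uniform_(-1, 1)
sync_bn.running_var.uniform_(0.5, 2.0)

# Reference BatchNorm2d loaded with exactly the same weights and stats.
ref_bn = torch.nn.BatchNorm2d(3)
ref_bn.load_state_dict(sync_bn.state_dict())
ref_bn.eval()

reverted = revert_sync_batchnorm(sync_bn)
reverted.eval()

x = torch.randn(2, 3, 4, 4)
assert torch.allclose(reverted(x), ref_bn(x))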