# mmcv/tests/test_cnn/test_revert_syncbn.py
# (59 lines, 1.9 KiB, Python)

import os
import platform
import numpy as np
import pytest
import torch
import torch.distributed as dist
from mmcv.cnn.bricks import ConvModule
from mmcv.cnn.utils import revert_sync_batchnorm
# Bind ``re`` to the third-party ``regex`` package on Windows and to the
# standard-library ``re`` module on every other platform.
if platform.system() != 'Windows':
    import re
else:
    import regex as re
def test_revert_syncbn():
    """Check that ``revert_sync_batchnorm`` makes a SyncBN module usable.

    A ``ConvModule`` built with ``SyncBN`` is first called on a CPU tensor,
    which is expected to raise ``ValueError``. After the module is passed
    through ``revert_sync_batchnorm`` the same input must yield an output of
    shape ``(1, 8, 9, 9)``.
    """
    module = ConvModule(3, 8, 2, norm_cfg={'type': 'SyncBN'})
    inputs = torch.randn(1, 3, 10, 10)

    # Expect a ValueError prompting that SyncBN is not supported on CPU
    with pytest.raises(ValueError):
        module(inputs)

    reverted = revert_sync_batchnorm(module)
    output = reverted(inputs)
    assert output.shape == (1, 8, 9, 9)
def test_revert_mmsyncbn():
    """Check that reverting MMSyncBN leaves eval-mode outputs unchanged.

    Must be launched under SLURM with at least two tasks, e.g.
    ``srun -p test --gres=gpu:2 -n2``. When ``SLURM_NTASKS`` is missing or
    below 2, a hint is printed and the function returns without checking
    anything.

    Otherwise the test initialises an NCCL process group from the SLURM
    environment, builds a ``ConvModule`` with ``MMSyncBN`` on the GPU, records
    its eval-mode output, applies ``revert_sync_batchnorm`` and asserts that
    the output is unchanged (rtol 1e-3) on the GPU and again after the module
    and input are moved to the CPU.

    Raises:
        RuntimeError: If ``SLURM_NODELIST`` contains fewer than five numeric
            groups, so that no master address can be derived from it.
    """
    if int(os.environ.get('SLURM_NTASKS', 0)) < 2:
        print('Must run on slurm with more than 1 process!\n'
              'srun -p test --gres=gpu:2 -n2')
        return

    rank = int(os.environ['SLURM_PROCID'])
    world_size = int(os.environ['SLURM_NTASKS'])
    local_rank = int(os.environ['SLURM_LOCALID'])
    node_list = str(os.environ['SLURM_NODELIST'])

    # The master address is the dotted quad formed by the 2nd to 5th numeric
    # groups found in the node list.
    # NOTE(review): this presumes node names embed an IPv4 address after one
    # leading number -- confirm against the target cluster's naming scheme.
    node_parts = re.findall('[0-9]+', node_list)
    if len(node_parts) < 5:
        raise RuntimeError(
            'Cannot derive MASTER_ADDR from '
            f'SLURM_NODELIST={node_list!r}: expected at least 5 numeric '
            f'groups, found {len(node_parts)}')
    os.environ['MASTER_ADDR'] = '.'.join(node_parts[1:5])
    os.environ['MASTER_PORT'] = '12341'
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)

    dist.init_process_group('nccl')
    try:
        torch.cuda.set_device(local_rank)
        x = torch.randn(1, 3, 10, 10).cuda()
        # Every rank must feed the same input, so take rank 0's tensor.
        dist.broadcast(x, src=0)

        conv = ConvModule(3, 8, 2, norm_cfg=dict(type='MMSyncBN')).cuda()
        conv.eval()
        y_mmsyncbn = conv(x).detach().cpu().numpy()

        conv = revert_sync_batchnorm(conv)
        y_bn = conv(x).detach().cpu().numpy()
        assert np.allclose(y_bn, y_mmsyncbn, rtol=1e-3)

        # The reverted module must give the same result on the CPU as well.
        conv, x = conv.to('cpu'), x.to('cpu')
        y_bn_cpu = conv(x).detach().numpy()
        assert np.allclose(y_bn, y_bn_cpu, rtol=1e-3)
    finally:
        # Release the process group even when an assertion above fails.
        dist.destroy_process_group()