mirror of https://github.com/open-mmlab/mmcv.git
60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os
|
|
import platform
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
from mmcv.cnn.bricks import ConvModule
|
|
from mmcv.cnn.utils import revert_sync_batchnorm
|
|
|
|
if platform.system() == 'Windows':
|
|
import regex as re
|
|
else:
|
|
import re
|
|
|
|
|
|
def test_revert_syncbn():
    """Check that a SyncBN ConvModule runs on a CPU tensor once reverted."""
    module = ConvModule(3, 8, 2, norm_cfg={'type': 'SyncBN'})
    inputs = torch.randn(1, 3, 10, 10)

    # Before the revert, the forward pass is expected to fail with a
    # ValueError, since SyncBN is not supported on CPU.
    with pytest.raises(ValueError):
        module(inputs)

    # After the revert, the same module must accept the same input.
    reverted = revert_sync_batchnorm(module)
    output = reverted(inputs)
    assert output.shape == (1, 8, 9, 9)
|
|
|
|
|
|
def test_revert_mmsyncbn():
    """Check that reverting MMSyncBN keeps eval-mode outputs unchanged.

    The test needs a SLURM launch with at least two tasks, e.g.
    ``srun -p test --gres=gpu:2 -n2``. In any other environment it is
    reported as skipped instead of silently passing.

    Raises:
        ValueError: If ``SLURM_NODELIST`` does not contain enough numeric
            groups to build the master address.
    """
    if int(os.environ.get('SLURM_NTASKS', '0')) < 2:
        pytest.skip('Must run on slurm with more than 1 process!\n'
                    'srun -p test --gres=gpu:2 -n2')

    rank = int(os.environ['SLURM_PROCID'])
    world_size = int(os.environ['SLURM_NTASKS'])
    local_rank = int(os.environ['SLURM_LOCALID'])
    node_list = str(os.environ['SLURM_NODELIST'])

    # NOTE(review): numeric groups 1-4 of the node list are used as the four
    # parts of the master address; this presumably matches a cluster-specific
    # node naming scheme -- confirm against the target cluster.
    node_parts = re.findall(r'[0-9]+', node_list)
    if len(node_parts) < 5:
        raise ValueError(
            f'Cannot derive MASTER_ADDR from SLURM_NODELIST={node_list!r}')
    os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
                                 f'.{node_parts[3]}.{node_parts[4]}')
    os.environ['MASTER_PORT'] = '12341'
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)

    dist.init_process_group('nccl')
    try:
        torch.cuda.set_device(local_rank)
        x = torch.randn(1, 3, 10, 10).cuda()
        # Every process uses the tensor generated on rank 0.
        dist.broadcast(x, src=0)

        conv = ConvModule(3, 8, 2, norm_cfg=dict(type='MMSyncBN')).cuda()
        conv.eval()
        y_mmsyncbn = conv(x).detach().cpu().numpy()

        # The reverted module must reproduce the MMSyncBN output on GPU.
        conv = revert_sync_batchnorm(conv)
        y_bn = conv(x).detach().cpu().numpy()
        assert np.allclose(y_bn, y_mmsyncbn, rtol=1e-3)

        # The reverted module must give the same result on CPU.
        conv, x = conv.to('cpu'), x.to('cpu')
        y_bn_cpu = conv(x).detach().numpy()
        assert np.allclose(y_bn, y_bn_cpu, rtol=1e-3)
    finally:
        # Release the process group even when an assertion fails, so it does
        # not leak into tests that run afterwards in the same process.
        dist.destroy_process_group()
|