# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn

if platform.system() == 'Windows':
    import regex as re
else:
    import re


class TestSyncBN:

    def dist_init(self):
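        # Derive the process layout from the environment variables set by
        # the SLURM scheduler (one process per GPU).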
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        local_rank = int(os.environ['SLURM_LOCALID'])
        node_list = str(os.environ['SLURM_NODELIST'])
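
        # Recover the master node's IP address from the numeric fields of
        # the SLURM node list and export the rendezvous variables read by
        # `dist.init_process_group`.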
        node_parts = re.findall('[0-9]+', node_list)
        os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
                                     f'.{node_parts[3]}.{node_parts[4]}')
        os.environ['MASTER_PORT'] = '12341'
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)

        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)

    def _test_syncbn_train(self, size=1, half=False):
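        # Requires exactly 4 SLURM tasks (one per GPU); otherwise print a
        # hint and skip.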
        if 'SLURM_NTASKS' not in os.environ or int(
                os.environ['SLURM_NTASKS']) != 4:
            print('must be run under slurm with 4 processes!\n'
                  'srun -p test --gres=gpu:4 -n4')
            return
        else:
            print('Running syncbn test')
        from mmcv.ops import SyncBatchNorm

        assert size in (1, 2, 4)
        if not dist.is_initialized():
            self.dist_init()
        rank = dist.get_rank()

        torch.manual_seed(9)
        torch.cuda.manual_seed(9)

        self.x = torch.rand(16, 3, 2, 3).cuda()
        self.y_bp = torch.rand(16, 3, 2, 3).cuda()

        if half:
            self.x = self.x.half()
            self.y_bp = self.y_bp.half()
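        # Broadcast from rank 0 so every process works on identical inputs
        # and upstream gradients.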
        dist.broadcast(self.x, src=0)
        dist.broadcast(self.y_bp, src=0)

        torch.cuda.synchronize()
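        # Build process groups of the requested size; SyncBN only
        # synchronizes statistics within its own group.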
        if size == 1:
            groups = [None, None, None, None]
            groups[0] = dist.new_group([0])
            groups[1] = dist.new_group([1])
            groups[2] = dist.new_group([2])
            groups[3] = dist.new_group([3])
            group = groups[rank]
        elif size == 2:
            groups = [None, None, None, None]
            groups[0] = groups[1] = dist.new_group([0, 1])
            groups[2] = groups[3] = dist.new_group([2, 3])
            group = groups[rank]
        elif size == 4:
            group = dist.group.WORLD
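
        # SyncBN under test, with distinctive per-channel weights.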
        syncbn = SyncBatchNorm(3, group=group).cuda()
        syncbn.weight.data[0] = 0.2
        syncbn.weight.data[1] = 0.5
        syncbn.weight.data[2] = 0.7
        syncbn.train()
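
        # Reference: a plain BatchNorm2d with identical parameters, later
        # fed the full batch of this rank's group.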
        bn = nn.BatchNorm2d(3).cuda()
        bn.weight.data[0] = 0.2
        bn.weight.data[1] = 0.5
        bn.weight.data[2] = 0.7
        bn.train()
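
        # Forward/backward this rank's 4-sample shard through SyncBN.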
        sx = self.x[rank * 4:rank * 4 + 4]
        sx.requires_grad_()
        sy = syncbn(sx)
        sy.backward(self.y_bp[rank * 4:rank * 4 + 4])

        smean = syncbn.running_mean
        svar = syncbn.running_var
        sx_grad = sx.grad
        sw_grad = syncbn.weight.grad
        sb_grad = syncbn.bias.grad
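
        # Run the reference BN on exactly the samples this rank's group saw,
        # so its statistics and outputs should match SyncBN's.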
        if size == 1:
            x = self.x[rank * 4:rank * 4 + 4]
            y_bp = self.y_bp[rank * 4:rank * 4 + 4]
        elif size == 2:
            x = self.x[rank // 2 * 8:rank // 2 * 8 + 8]
            y_bp = self.y_bp[rank // 2 * 8:rank // 2 * 8 + 8]
        elif size == 4:
            x = self.x
            y_bp = self.y_bp
        x.requires_grad_()
        y = bn(x)
        y.backward(y_bp)
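
        # Compare only the output slice corresponding to this rank's shard.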
        if size == 2:
            y = y[rank % 2 * 4:rank % 2 * 4 + 4]
        elif size == 4:
            y = y[rank * 4:rank * 4 + 4]

        mean = bn.running_mean
        var = bn.running_var
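        # The reference weight/bias gradients are divided by the group size
        # to match SyncBN's gradients, which are averaged across the group.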
        if size == 1:
            x_grad = x.grad
            w_grad = bn.weight.grad
            b_grad = bn.bias.grad
        elif size == 2:
            x_grad = x.grad[rank % 2 * 4:rank % 2 * 4 + 4]
            w_grad = bn.weight.grad / 2
            b_grad = bn.bias.grad / 2
        elif size == 4:
            x_grad = x.grad[rank * 4:rank * 4 + 4]
            w_grad = bn.weight.grad / 4
            b_grad = bn.bias.grad / 4
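
        # The third positional argument of np.allclose is rtol; input
        # gradients get a looser 1e-2 tolerance.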
        assert np.allclose(mean.data.cpu().numpy(),
                           smean.data.cpu().numpy(), 1e-3)
        assert np.allclose(var.data.cpu().numpy(),
                           svar.data.cpu().numpy(), 1e-3)
        assert np.allclose(y.data.cpu().numpy(), sy.data.cpu().numpy(), 1e-3)
        assert np.allclose(w_grad.data.cpu().numpy(),
                           sw_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(b_grad.data.cpu().numpy(),
                           sb_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(x_grad.data.cpu().numpy(),
                           sx_grad.data.cpu().numpy(), 1e-2)

    def _test_syncbn_empty_train(self, size=1, half=False):
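        # Same flow as `_test_syncbn_train`, but every rank receives an
        # empty (zero-sample) batch. With `stats_mode='N'` each rank's
        # contribution is weighted by its sample count, so empty shards
        # contribute nothing to the synchronized statistics.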
        if 'SLURM_NTASKS' not in os.environ or int(
                os.environ['SLURM_NTASKS']) != 4:
            print('must be run under slurm with 4 processes!\n'
                  'srun -p test --gres=gpu:4 -n4')
            return
        else:
            print('Running syncbn test')
        from mmcv.ops import SyncBatchNorm

        assert size in (1, 2, 4)
        if not dist.is_initialized():
            self.dist_init()
        rank = dist.get_rank()

        torch.manual_seed(9)
        torch.cuda.manual_seed(9)

        self.x = torch.rand(0, 3, 2, 3).cuda()
        self.y_bp = torch.rand(0, 3, 2, 3).cuda()

        if half:
            self.x = self.x.half()
            self.y_bp = self.y_bp.half()
        dist.broadcast(self.x, src=0)
        dist.broadcast(self.y_bp, src=0)

        torch.cuda.synchronize()
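        # Process-group construction mirrors `_test_syncbn_train`.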
        if size == 1:
            groups = [None, None, None, None]
            groups[0] = dist.new_group([0])
            groups[1] = dist.new_group([1])
            groups[2] = dist.new_group([2])
            groups[3] = dist.new_group([3])
            group = groups[rank]
        elif size == 2:
            groups = [None, None, None, None]
            groups[0] = groups[1] = dist.new_group([0, 1])
            groups[2] = groups[3] = dist.new_group([2, 3])
            group = groups[rank]
        elif size == 4:
            group = dist.group.WORLD

        syncbn = SyncBatchNorm(3, group=group, stats_mode='N').cuda()
        syncbn.weight.data[0] = 0.2
        syncbn.weight.data[1] = 0.5
        syncbn.weight.data[2] = 0.7
        syncbn.train()

        bn = nn.BatchNorm2d(3).cuda()
        bn.weight.data[0] = 0.2
        bn.weight.data[1] = 0.5
        bn.weight.data[2] = 0.7
        bn.train()

        sx = self.x[rank * 4:rank * 4 + 4]
        sx.requires_grad_()
        sy = syncbn(sx)
        sy.backward(self.y_bp[rank * 4:rank * 4 + 4])
        smean = syncbn.running_mean
        svar = syncbn.running_var
        sx_grad = sx.grad
        sw_grad = syncbn.weight.grad
        sb_grad = syncbn.bias.grad
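
        # As above, run the reference BN on the samples covered by this
        # rank's group (all empty here).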
        if size == 1:
            x = self.x[rank * 4:rank * 4 + 4]
            y_bp = self.y_bp[rank * 4:rank * 4 + 4]
        elif size == 2:
            x = self.x[rank // 2 * 8:rank // 2 * 8 + 8]
            y_bp = self.y_bp[rank // 2 * 8:rank // 2 * 8 + 8]
        elif size == 4:
            x = self.x
            y_bp = self.y_bp
        x.requires_grad_()
        y = bn(x)
        y.backward(y_bp)

        if size == 2:
            y = y[rank % 2 * 4:rank % 2 * 4 + 4]
        elif size == 4:
            y = y[rank * 4:rank * 4 + 4]

        mean = bn.running_mean
        var = bn.running_var
        if size == 1:
            x_grad = x.grad
            w_grad = bn.weight.grad
            b_grad = bn.bias.grad
        elif size == 2:
            x_grad = x.grad[rank % 2 * 4:rank % 2 * 4 + 4]
            w_grad = bn.weight.grad / 2
            b_grad = bn.bias.grad / 2
        elif size == 4:
            x_grad = x.grad[rank * 4:rank * 4 + 4]
            w_grad = bn.weight.grad / 4
            b_grad = bn.bias.grad / 4
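
        # With zero samples the outputs and input gradients are empty
        # tensors, so the meaningful checks are on the running statistics
        # and parameter gradients.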
        assert np.allclose(mean.data.cpu().numpy(),
                           smean.data.cpu().numpy(), 1e-3)
        assert np.allclose(var.data.cpu().numpy(),
                           svar.data.cpu().numpy(), 1e-3)
        assert np.allclose(y.data.cpu().numpy(), sy.data.cpu().numpy(), 1e-3)
        assert np.allclose(w_grad.data.cpu().numpy(),
                           sw_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(b_grad.data.cpu().numpy(),
                           sb_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(x_grad.data.cpu().numpy(),
                           sx_grad.data.cpu().numpy(), 1e-2)

        # 'stats_mode' only allows 'default' and 'N'
        with pytest.raises(AssertionError):
            SyncBatchNorm(3, group=group, stats_mode='X')

    def test_syncbn_1(self):
        self._test_syncbn_train(size=1)

    def test_syncbn_2(self):
        self._test_syncbn_train(size=2)

    def test_syncbn_4(self):
        self._test_syncbn_train(size=4)

    def test_syncbn_1_half(self):
        self._test_syncbn_train(size=1, half=True)

    def test_syncbn_2_half(self):
        self._test_syncbn_train(size=2, half=True)

    def test_syncbn_4_half(self):
        self._test_syncbn_train(size=4, half=True)

    def test_syncbn_empty_1(self):
        self._test_syncbn_empty_train(size=1)

    def test_syncbn_empty_2(self):
        self._test_syncbn_empty_train(size=2)

    def test_syncbn_empty_4(self):
        self._test_syncbn_empty_train(size=4)

    def test_syncbn_empty_1_half(self):
        self._test_syncbn_empty_train(size=1, half=True)

    def test_syncbn_empty_2_half(self):
        self._test_syncbn_empty_train(size=2, half=True)

    def test_syncbn_empty_4_half(self):
        self._test_syncbn_empty_train(size=4, half=True)