# Copyright (c) Alibaba, Inc. and its affiliates.
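"""Unit tests for EasyCV's torchacc conversions of torch.distributed ops.

Covers dist info (get_rank/get_world_size), reduce, all_gather and
all_reduce, comparing the native torch.distributed backend against the
torchacc backend on 2 GPUs.
"""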
import unittest

import numpy as np
import torch
from mmcv.runner import init_dist
from torch import distributed as dist
from torch.distributed import ReduceOp

from easycv.utils.test_util import DistributedTestCase

# A ReduceOp argument cannot be passed through a shell command, so map it by string name
op_map = {
    'MAX': ReduceOp.MAX,
    'MIN': ReduceOp.MIN,
    'SUM': ReduceOp.SUM,
    'PRODUCT': ReduceOp.PRODUCT
}


def _init_dist(torchacc_enabled):
    """Initialize distributed mode with either torchacc or native PyTorch."""
    if torchacc_enabled:
        from easycv.toolkit.torchacc import torchacc_init
        torchacc_init()
    else:
        init_dist(launcher='pytorch')


def _check_type(module,
                is_raw_module=True,
                torchacc_enabled=False,
                value=None):
    """Check that ``module`` is dispatched to the expected backend and that
    ``value`` (if given) lives on the expected device."""
    rank = dist.get_rank()

    if torchacc_enabled:
        if value is not None:
            import torchacc.torch_xla.core.xla_model as xm
            assert value.device == xm.xla_device()
        if is_raw_module:
            assert module.__module__ == 'torchacc.torch_xla.core.xla_model'
        else:
            assert module.__module__ == 'easycv.toolkit.torchacc.convert_ops'
    else:
        if value is not None:
            cur_device = torch.device('cuda:{}'.format(rank))
            assert value.device == cur_device
        assert module.__module__ == 'torch.distributed.distributed_c10d'


def _create_value(base_value):
    if isinstance(base_value, (int, float)):
        base_value = torch.tensor(base_value).cuda()
    else:
        base_value = torch.Tensor(base_value).cuda()

    # scale by rank + 1 so every rank contributes a distinct value
    rank = dist.get_rank()
    return base_value * (rank + 1)


def _dist_info(torchacc_enabled=False):
    _init_dist(torchacc_enabled)

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    _check_type(dist.get_rank, torchacc_enabled=torchacc_enabled)
    _check_type(dist.get_world_size, torchacc_enabled=torchacc_enabled)

    return rank, world_size


def _reduce(base_value, op_str, dst, torchacc_enabled=False):
    _init_dist(torchacc_enabled)

    value = _create_value(base_value)
    dist.reduce(value, int(dst), op=op_map[op_str])
    _check_type(
        dist.reduce,
        is_raw_module=False,
        torchacc_enabled=torchacc_enabled,
        value=value)

    return value.cpu().numpy().tolist()


def _all_gather(base_value, torchacc_enabled=False):
    _init_dist(torchacc_enabled)

    value = _create_value(base_value)

    world_size = dist.get_world_size()
    tensor_list = [
        torch.zeros(value.size(), dtype=value.dtype, device=value.device)
        for _ in range(world_size)
    ]
    dist.all_gather(tensor_list, value)

    _check_type(
        dist.all_gather,
        is_raw_module=False,
        torchacc_enabled=torchacc_enabled,
        value=value)

    return [i.cpu().numpy().tolist() for i in tensor_list]


def _all_reduce(base_value, op_str, torchacc_enabled=False):
    _init_dist(torchacc_enabled)

    value = _create_value(base_value)
    dist.all_reduce(value, op=op_map[op_str])

    _check_type(
        dist.all_reduce,
        is_raw_module=False,
        torchacc_enabled=torchacc_enabled,
        value=value)

    return value.cpu().numpy().tolist()


def _skip():
    # skip unless torchacc is importable and at least 2 GPUs are present
    torchacc_enabled = True
    try:
        import torchacc  # noqa: F401
    except ImportError:
        torchacc_enabled = False

    return not (torchacc_enabled and torch.cuda.device_count() > 1)


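# The tests below run each helper on 2 GPUs, once with the native
# torch.distributed backend and once with the torchacc backend, then
# compare the per-rank results collected by DistributedTestCase.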
@unittest.skipIf(_skip(), 'distributed unittest for torchacc')
class ConvertOpsTest(DistributedTestCase):

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def test_dist_info(self):

        def assert_callback(res):
            self.assertEqual(list(res[0]), [0, 2])  # rank 0
            self.assertEqual(list(res[1]), [1, 2])  # rank 1

        # test original dist info
        self.start_with_torch(
            _dist_info,
            num_gpus=2,
            assert_callback=assert_callback,
            save_all_ranks=True)

        # test torchacc dist info
        kwargs = {'torchacc_enabled': True}
        self.start_with_torchacc(
            _dist_info,
            num_gpus=2,
            assert_callback=assert_callback,
            save_all_ranks=True,
            **kwargs)

    @unittest.skipIf(True, 'skipped: dist.reduce hangs, pending fix')
    def test_reduce(self):
        # base value is multiplied by rank + 1; results of all ranks are returned
        cases = [
            {
                'assert_callback':
                lambda res: self.assertListEqual(res, [3.0, 2.0]),
                'base_value': 1.0,
                'op_str': 'SUM',
                'dst': 0
            },
            {
                'assert_callback': lambda res: self.assertEqual(res, [2, 4]),
                'base_value': 2,
                'op_str': 'MIN',
                'dst': 0
            },
            {
                'assert_callback':
                lambda res: self.assertListEqual(res, [[1.0], [2.0]]),
                'base_value': [1.0],
                'op_str': 'MAX',
                'dst': 1
            },
            {
                'assert_callback':
                lambda res: self.assertListEqual(res, [[1.0, 2.0], [2.0, 8.0]]),
                'base_value': [1.0, 2.0],
                'op_str': 'PRODUCT',
                'dst': 1
            },
            {
                'assert_callback':
                lambda res: self.assertListEqual(
                    res,
                    np.asarray([[[[3.0, 6.0], [9.0, 12.0]],
                                 [[12.0, 9.0], [6.0, 3.0]]],
                                [[[2.0, 4.0], [6.0, 8.0]],
                                 [[8.0, 6.0], [4.0, 2.0]]]]).tolist()),
                'base_value': [[[1.0, 2.0], [3.0, 4.0]],
                               [[4.0, 3.0], [2.0, 1.0]]],
                'op_str': 'SUM',
                'dst': 0
            },
        ]

        for case in cases:
            # test original reduce
            self.start_with_torch(
                _reduce,
                num_gpus=2,
                assert_callback=case['assert_callback'],
                save_all_ranks=True,
                base_value=case['base_value'],
                op_str=case['op_str'],
                dst=case['dst'])

            # test torchacc reduce
            kwargs = {'torchacc_enabled': True}
            self.start_with_torchacc(
                _reduce,
                num_gpus=2,
                assert_callback=case['assert_callback'],
                save_all_ranks=True,
                base_value=case['base_value'],
                op_str=case['op_str'],
                dst=case['dst'],
                **kwargs)

    def test_broadcast(self):
        # Not implemented for torchacc yet
        pass

    def test_all_gather(self):
        # base value is multiplied by rank + 1
        cases = [
            # scalar types are not supported
            # {
            #     'assert_callback': lambda res: self.assertEqual(res, [1.0, 2.0]),
            #     'base_value': 1.0,
            # },
            # {
            #     'assert_callback': lambda res: self.assertEqual(res, [2, 4]),
            #     'base_value': 2,
            # },
            {
                'assert_callback':
                lambda res: self.assertListEqual(res, [[1.0], [2.0]]),
                'base_value': [1.0],
            },
            {
                'assert_callback':
                lambda res: self.assertListEqual(res, [[1.0, 2.0], [2.0, 4.0]]),
                'base_value': [1.0, 2.0],
            },
            {
                'assert_callback':
                lambda res: self.assertListEqual(
                    res,
                    np.asarray([[[[1.0, 2.0], [3.0, 4.0]],
                                 [[4.0, 3.0], [2.0, 1.0]]],
                                [[[2.0, 4.0], [6.0, 8.0]],
                                 [[8.0, 6.0], [4.0, 2.0]]]]).tolist()),
                'base_value': [[[1.0, 2.0], [3.0, 4.0]],
                               [[4.0, 3.0], [2.0, 1.0]]],
            },
        ]

        for case in cases:
            # test original all_gather
            self.start_with_torch(
                _all_gather,
                num_gpus=2,
                assert_callback=case['assert_callback'],
                base_value=case['base_value'])

            # test torchacc all_gather
            kwargs = {'torchacc_enabled': True}
            self.start_with_torchacc(
                _all_gather,
                num_gpus=2,
                assert_callback=case['assert_callback'],
                base_value=case['base_value'],
                **kwargs)

    def test_all_reduce(self):
        # base value is multiplied by rank + 1
        cases = [
            {
                'assert_callback': lambda res: self.assertEqual(res, 3.0),
                'base_value': 1.0,
                'op_str': 'SUM'
            },
            {
                'assert_callback': lambda res: self.assertEqual(res, 2),
                'base_value': 2,
                'op_str': 'MIN'
            },
            {
                'assert_callback':
                lambda res: self.assertListEqual(res, [2.0]),
                'base_value': [1.0],
                'op_str': 'MAX'
            },
            {
                'assert_callback':
                lambda res: self.assertListEqual(res, [2.0, 8.0]),
                'base_value': [1.0, 2.0],
                'op_str': 'PRODUCT'
            },
            {
                'assert_callback':
                lambda res: self.assertListEqual(
                    res,
                    np.asarray([[[3.0, 6.0], [9.0, 12.0]],
                                [[12.0, 9.0], [6.0, 3.0]]]).tolist()),
                'base_value': [[[1.0, 2.0], [3.0, 4.0]],
                               [[4.0, 3.0], [2.0, 1.0]]],
                'op_str': 'SUM'
            },
        ]

        for case in cases:
            # test original all_reduce
            self.start_with_torch(
                _all_reduce,
                num_gpus=2,
                assert_callback=case['assert_callback'],
                base_value=case['base_value'],
                op_str=case['op_str'])

            # test torchacc all_reduce
            kwargs = {'torchacc_enabled': True}
            self.start_with_torchacc(
                _all_reduce,
                num_gpus=2,
                assert_callback=case['assert_callback'],
                base_value=case['base_value'],
                op_str=case['op_str'],
                **kwargs)


if __name__ == '__main__':
    unittest.main()