[Refactor] Replace torch distributed with mmengine dist module (#196)

* [Fix] Replace torch distributed with mmengine dist module
* minor refinement
* move all_reduce_params to dist.py
* add unit tests
* update unit tests
* fix test_logger.py
* add examples

parent e37f1f905b
commit 98c85529b1
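The refactor swaps direct `torch.distributed` calls for the `mmengine.dist` wrappers, which fall back to sensible single-process defaults when no process group is initialized. A minimal sketch of the intended usage pattern (illustrative, not part of the patch):

    import mmengine.dist as dist

    # Works with or without an initialized process group.
    rank, world_size = dist.get_dist_info()  # (0, 1) in a non-distributed run
    if dist.is_main_process():
        print(f'running on rank {rank} of {world_size}')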
@@ -2,7 +2,7 @@
 from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict,
                    collect_results, gather, broadcast, gather_object,
                    sync_random_seed, broadcast_object_list,
-                   collect_results_cpu, collect_results_gpu)
+                   collect_results_cpu, collect_results_gpu, all_reduce_params)
 from .utils import (get_dist_info, init_dist, init_local_group, get_backend,
                     get_world_size, get_rank, get_local_size, get_local_rank,
                     is_main_process, master_only, barrier, get_local_group,
@@ -16,6 +16,6 @@ __all__ = [
     'get_dist_info', 'init_dist', 'init_local_group', 'get_backend',
     'get_world_size', 'get_rank', 'get_local_size', 'get_local_group',
     'get_local_rank', 'is_main_process', 'master_only', 'barrier',
-    'is_distributed', 'get_default_group', 'get_data_device',
-    'get_comm_device', 'cast_data_device'
+    'is_distributed', 'get_default_group', 'all_reduce_params',
+    'get_data_device', 'get_comm_device', 'cast_data_device'
 ]
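With the export added above, the helper is reachable from the package namespace. A quick sanity-check sketch (illustrative only):

    import torch
    from mmengine.dist import all_reduce_params  # newly exported here

    data = [torch.arange(2), torch.arange(3)]
    all_reduce_params(data)  # no-op outside a distributed run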
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, List, Optional, Tuple, Dict
+from typing import Any, List, Optional, Tuple, Dict, Generator, Union
+from collections import OrderedDict
 import shutil
 import pickle
 import numpy as np
@@ -7,6 +8,8 @@ import tempfile
 import torch
 import os.path as osp
 from torch import Tensor
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
@@ -805,7 +808,7 @@ def gather_object(data: Any,
     the object must be picklable in order to be gathered.
 
     Note:
-        ``NCCL backend`` dost not support ``gather_object``.
+        ``NCCL backend`` does not support ``gather_object``.
 
     Note:
         Unlike PyTorch ``torch.distributed.gather_object``,
@@ -1036,3 +1039,92 @@ def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
         return ordered_results
     else:
         return None
+
+
+def _all_reduce_coalesced(tensors: List[torch.Tensor],
+                          bucket_size_mb: int = -1,
+                          op: str = 'sum',
+                          group: Optional[ProcessGroup] = None) -> None:
+    """All-reduce a sequence of tensors as a whole.
+
+    Args:
+        tensors (List[torch.Tensor]): A sequence of tensors to be
+            all-reduced.
+        bucket_size_mb (int): The limit of each chunk in megabytes
+            for grouping tensors into chunks. Defaults to -1.
+        op (str): Operation to reduce data. Defaults to 'sum'. Optional values
+            are 'sum', 'mean', 'product', 'min', 'max', 'band', 'bor' and
+            'bxor'.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+    """
+    if bucket_size_mb > 0:
+        bucket_size_bytes = bucket_size_mb * 1024 * 1024
+        buckets = _take_tensors(tensors, bucket_size_bytes)
+    else:
+        buckets = OrderedDict()
+        for tensor in tensors:
+            tp = tensor.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(tensor)
+        buckets = buckets.values()
+
+    for bucket in buckets:
+        flat_tensors = _flatten_dense_tensors(bucket)
+        all_reduce(flat_tensors, op=op, group=group)
+        for tensor, synced in zip(
+                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+            tensor.copy_(synced)
+
+
+def all_reduce_params(params: Union[List, Generator[torch.Tensor, None, None]],
+                      coalesce: bool = True,
+                      bucket_size_mb: int = -1,
+                      op: str = 'sum',
+                      group: Optional[ProcessGroup] = None) -> None:
+    """All-reduce parameters.
+
+    Args:
+        params (List or Generator[torch.Tensor, None, None]): List of
+            parameters or buffers of a model.
+        coalesce (bool, optional): Whether to reduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+        op (str): Operation to reduce data. Defaults to 'sum'. Optional values
+            are 'sum', 'mean', 'product', 'min', 'max', 'band', 'bor' and
+            'bxor'.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+
+    Examples:
+        >>> import torch
+        >>> import mmengine.dist as dist
+
+        >>> # non-distributed environment
+        >>> data = [torch.arange(2), torch.arange(3)]
+        >>> dist.all_reduce_params(data)
+        >>> data
+        [tensor([0, 1]), tensor([0, 1, 2])]
+
+        >>> # distributed environment
+        >>> # We have 2 process groups, 2 ranks.
+        >>> if dist.get_rank() == 0:
+        ...     data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
+        ... else:
+        ...     data = [torch.tensor([2, 3]), torch.tensor([4, 5])]
+
+        >>> dist.all_reduce_params(data)
+        >>> data
+        [tensor([3, 5]), tensor([7, 9])]
+    """
+    world_size = get_world_size(group)
+    if world_size == 1:
+        return
+    params_data = [param.data for param in params]
+    if coalesce:
+        _all_reduce_coalesced(params_data, bucket_size_mb, op=op, group=group)
+    else:
+        for tensor in params_data:
+            all_reduce(tensor, op=op, group=group)
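For reference, a single-process sketch of the bucketing that `_all_reduce_coalesced` relies on: `_take_tensors` groups same-dtype tensors into size-limited chunks, and the flatten/unflatten pair round-trips each bucket. The sizes below are arbitrary:

    import torch
    from torch._utils import (_flatten_dense_tensors, _take_tensors,
                              _unflatten_dense_tensors)

    # Eight 1 KB float32 tensors with a 1 KB limit -> one tensor per bucket.
    tensors = [torch.ones(256, dtype=torch.float32) for _ in range(8)]
    for bucket in _take_tensors(tensors, 1024):
        flat = _flatten_dense_tensors(bucket)              # one contiguous tensor
        restored = _unflatten_dense_tensors(flat, bucket)  # original shapes again
        assert all(torch.equal(a, b) for a, b in zip(bucket, restored))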
@@ -1,83 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-# from mmengine.dist import get_dist_info, all_reduce
-from collections import OrderedDict
-from typing import Generator, List
-from unittest.mock import MagicMock, Mock
-
-import torch
-from torch._utils import (_flatten_dense_tensors, _take_tensors,
-                          _unflatten_dense_tensors)
-
+from mmengine import dist
 from mmengine.registry import HOOKS
 from .hook import Hook
 
-# TODO, replace with import mmengine.dist as dist
-dist = Mock()
-dist.IS_DIST = MagicMock(return_value=True)
-
-# TODO, replace with mmengine.dist.get_dist_info
-get_dist_info = MagicMock(return_value=(0, 1))
-# TODO, replace with mmengine.dist.all_reduce
-all_reduce = MagicMock()
-
-
-# TODO, may need to move to dist.utils after implementing dist module
-def _allreduce_coalesced(tensors: List[torch.Tensor],
-                         world_size: int,
-                         bucket_size_mb: int = -1) -> None:
-    """All-reduce a sequence of tensors as a whole.
-
-    Args:
-        tensors (List[torch.Tensor]): A sequence of tensors to be
-            all-reduced.
-        world_size (int): The world size of the process group.
-        bucket_size_mb (int): The limit of each chunk in megabytes
-            for grouping tensors into chunks. Defaults to -1.
-    """
-    if bucket_size_mb > 0:
-        bucket_size_bytes = bucket_size_mb * 1024 * 1024
-        buckets = _take_tensors(tensors, bucket_size_bytes)
-    else:
-        buckets = OrderedDict()
-        for tensor in tensors:
-            tp = tensor.type()
-            if tp not in buckets:
-                buckets[tp] = []
-            buckets[tp].append(tensor)
-        buckets = buckets.values()
-
-    for bucket in buckets:
-        flat_tensors = _flatten_dense_tensors(bucket)
-        all_reduce(flat_tensors)
-        flat_tensors.div_(world_size)
-        for tensor, synced in zip(
-                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
-            tensor.copy_(synced)
-
-
-def allreduce_params(params: Generator[torch.Tensor, None, None],
-                     coalesce: bool = True,
-                     bucket_size_mb: int = -1) -> None:
-    """All-reduce parameters.
-
-    Args:
-        params (Generator[torch.Tensor, None, None]): List of parameters or
-            buffers of a model.
-        coalesce (bool, optional): Whether to reduce parameters as a whole.
-            Defaults to True.
-        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
-            Defaults to -1.
-    """
-    _, world_size = get_dist_info()
-    if world_size == 1:
-        return
-    params_data = [param.data for param in params]
-    if coalesce:
-        _allreduce_coalesced(params_data, world_size, bucket_size_mb)
-    else:
-        for tensor in params_data:
-            all_reduce(tensor.div_(world_size))
-
-
 @HOOKS.register_module()
 class SyncBuffersHook(Hook):
@@ -87,7 +12,7 @@ class SyncBuffersHook(Hook):
     priority = 'NORMAL'
 
     def __init__(self) -> None:
-        self.distributed = dist.IS_DIST
+        self.distributed = dist.is_distributed()
 
     def after_train_epoch(self, runner) -> None:
         """All-reduce model buffers at the end of each epoch.
@@ -96,4 +21,4 @@ class SyncBuffersHook(Hook):
             runner (Runner): The runner of the training process.
         """
         if self.distributed:
-            allreduce_params(runner.model.buffers())
+            dist.all_reduce_params(runner.model.buffers(), op='mean')
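A sketch of what `SyncBuffersHook` now delegates to `mmengine.dist`: averaging a model's buffers (for example BatchNorm running statistics) across ranks. In a single-process run `all_reduce_params` returns early, so the call is safe either way:

    import torch.nn as nn
    import mmengine.dist as dist

    # BatchNorm keeps its running statistics in buffers, which is what the hook syncs.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    dist.all_reduce_params(model.buffers(), op='mean')  # no-op when world_size == 1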
@@ -5,9 +5,9 @@ import sys
 from logging import Logger, LogRecord
 from typing import Optional, Union
 
-import torch.distributed as dist
 from termcolor import colored
 
+from mmengine import dist
 from mmengine.utils import ManagerMixin
 
@@ -144,10 +144,8 @@ class MMLogger(Logger, ManagerMixin):
         Logger.__init__(self, logger_name)
         ManagerMixin.__init__(self, name)
         # Get rank in DDP mode.
-        if dist.is_available() and dist.is_initialized():
-            rank = dist.get_rank()
-        else:
-            rank = 0
+        rank = dist.get_rank()
 
         # Config stream_handler. If `rank != 0`, stream_handler can only
        # export ERROR logs.
         stream_handler = logging.StreamHandler(stream=sys.stdout)
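The availability/initialization guard can be dropped because `mmengine.dist.get_rank` already falls back to rank 0 when no default process group exists, so `MMLogger` can call it unconditionally. A quick check outside any distributed launch (illustrative):

    from mmengine.dist import get_rank, is_distributed

    assert not is_distributed()
    assert get_rank() == 0  # same value the logger now uses on a single process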
@@ -579,9 +579,6 @@
         self._rank, self._world_size = get_dist_info()
 
         timestamp = torch.tensor(time.time(), dtype=torch.float64)
-        # TODO: handled by broadcast
-        if self._world_size > 1 and torch.cuda.is_available():
-            timestamp = timestamp.cuda()
         # broadcast timestamp from 0 process to other processes
         broadcast(timestamp)
         self._timestamp = time.strftime('%Y%m%d_%H%M%S',
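The manual `.cuda()` move is no longer needed here: `mmengine.dist.broadcast` is expected to pick a suitable communication device for the active backend and is a no-op in a non-distributed run. A hedged sketch of the timestamp broadcast on its own:

    import time
    import torch
    from mmengine.dist import broadcast

    timestamp = torch.tensor(time.time(), dtype=torch.float64)
    broadcast(timestamp)  # every rank ends up with rank 0's value (src defaults to 0)
    print(time.strftime('%Y%m%d_%H%M%S', time.localtime(timestamp.item())))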
@@ -99,6 +99,22 @@ class TestDist(TestCase):
         output = dist.collect_results(data, size, device='gpu')
         self.assertEqual(output, expected)
 
+    def test_all_reduce_params(self):
+        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
+                                          ['sum', 'mean']):
+            data = [
+                torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
+            ]
+            data_gen = (item for item in data)
+            expected = [
+                torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
+            ]
+
+            dist.all_reduce_params(data_gen, op=reduce_op)
+
+            for item1, item2 in zip(data, expected):
+                self.assertTrue(torch.allclose(item1, item2))
+
 
 class TestDistWithGLOOBackend(MultiProcessTestCase):
@@ -300,6 +316,39 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
         else:
             self.assertIsNone(output)
 
+    def test_all_reduce_params(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        tensor_types = [torch.int64, torch.float32]
+        reduce_ops = ['sum', 'mean']
+        coalesces = [True, False]
+        for tensor_type, reduce_op, coalesce in zip(tensor_types, reduce_ops,
+                                                    coalesces):
+            if dist.get_rank() == 0:
+                data = [
+                    torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
+                ]
+            else:
+                data = (
+                    torch.tensor([2, 3], dtype=tensor_type)
+                    for _ in range(100))
+
+            data_gen = (item for item in data)
+
+            if reduce_op == 'sum':
+                expected = (
+                    torch.tensor([2, 4], dtype=tensor_type)
+                    for _ in range(100))
+            else:
+                expected = (
+                    torch.tensor([1, 2], dtype=tensor_type)
+                    for _ in range(100))
+
+            dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)
+
+            for item1, item2 in zip(data, expected):
+                self.assertTrue(torch.allclose(item1, item2))
+
 
 @unittest.skipIf(
     torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
@@ -568,3 +617,37 @@ class TestDistWithNCCLBackend(MultiProcessTestCase):
             self.assertEqual(output, expected)
         else:
             self.assertIsNone(output)
+
+    def test_all_reduce_params(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        tensor_types = [torch.int64, torch.float32]
+        reduce_ops = ['sum', 'mean']
+        coalesces = [True, False]
+        device_types = ['cpu', 'cuda']
+        for tensor_type, reduce_op, coalesce, device_type in zip(
+                tensor_types, reduce_ops, coalesces, device_types):
+            if dist.get_rank() == 0:
+                data = [
+                    torch.tensor([0, 1], dtype=tensor_type).to(device_type)
+                    for _ in range(100)
+                ]
+            else:
+                data = [
+                    torch.tensor([2, 3], dtype=tensor_type).to(device_type)
+                    for _ in range(100)
+                ]
+
+            data_gen = (item for item in data)
+
+            if reduce_op == 'sum':
+                expected = (
+                    torch.tensor([2, 4], dtype=tensor_type).to(device_type)
+                    for _ in range(100))
+            else:
+                expected = (
+                    torch.tensor([1, 2], dtype=tensor_type).to(device_type)
+                    for _ in range(100))
+
+            dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)
+
+            for item1, item2 in zip(data, expected):
+                self.assertTrue(torch.allclose(item1, item2))
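Outside the `MultiProcessTestCase` harness, the same behaviour can be exercised with a small two-process script. A self-contained sketch using the gloo backend (the address, port and printed values are assumptions, not part of the test suite):

    import os

    import torch
    import torch.distributed as torch_dist
    import torch.multiprocessing as mp

    import mmengine.dist as dist


    def worker(rank: int, world_size: int) -> None:
        # Minimal gloo setup; assumes 127.0.0.1:29515 is free on this machine.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29515'
        torch_dist.init_process_group('gloo', rank=rank, world_size=world_size)

        data = [torch.tensor([rank, rank + 1], dtype=torch.float32)]
        dist.all_reduce_params(data, op='mean')
        print(rank, data)  # both ranks end up with [tensor([0.5000, 1.5000])]

        torch_dist.destroy_process_group()


    if __name__ == '__main__':
        mp.spawn(worker, args=(2, ), nprocs=2)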
@@ -15,9 +15,7 @@ class TestLogger:
     stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
     file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
 
-    @patch('torch.distributed.get_rank', lambda: 0)
-    @patch('torch.distributed.is_initialized', lambda: True)
-    @patch('torch.distributed.is_available', lambda: True)
+    @patch('mmengine.dist.get_rank', lambda: 0)
     def test_init_rank0(self, tmp_path):
         logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
         assert logger.name == 'mmengine'
@@ -47,9 +45,7 @@ class TestLogger:
         assert logger.instance_name == 'rank0.pkg3'
         logging.shutdown()
 
-    @patch('torch.distributed.get_rank', lambda: 1)
-    @patch('torch.distributed.is_initialized', lambda: True)
-    @patch('torch.distributed.is_available', lambda: True)
+    @patch('mmengine.dist.get_rank', lambda: 1)
     def test_init_rank1(self, tmp_path):
         # If `rank != 0`, the `loglevel` of file_handler is `logging.ERROR`.
         tmp_file = tmp_path / 'tmp_file.log'
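Because `MMLogger` now resolves the rank through `mmengine.dist` at call time, patching a single helper is enough to simulate a non-zero rank. A minimal sketch outside pytest (the instance name is arbitrary):

    from unittest.mock import patch

    from mmengine.logging import MMLogger

    with patch('mmengine.dist.get_rank', lambda: 1):
        logger = MMLogger.get_instance('patched_rank1', log_level='INFO')
    # With rank != 0, the stream handler is configured to emit only ERROR records.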