diff --git a/mmengine/dist/__init__.py b/mmengine/dist/__init__.py
index f59129f9..48d7d4f9 100644
--- a/mmengine/dist/__init__.py
+++ b/mmengine/dist/__init__.py
@@ -2,7 +2,7 @@
 from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict,
                    collect_results, gather, broadcast, gather_object,
                    sync_random_seed, broadcast_object_list,
-                   collect_results_cpu, collect_results_gpu)
+                   collect_results_cpu, collect_results_gpu, all_reduce_params)
 from .utils import (get_dist_info, init_dist, init_local_group, get_backend,
                     get_world_size, get_rank, get_local_size, get_local_rank,
                     is_main_process, master_only, barrier, get_local_group,
@@ -16,6 +16,6 @@ __all__ = [
     'get_dist_info', 'init_dist', 'init_local_group', 'get_backend',
     'get_world_size', 'get_rank', 'get_local_size', 'get_local_group',
     'get_local_rank', 'is_main_process', 'master_only', 'barrier',
-    'is_distributed', 'get_default_group', 'get_data_device',
-    'get_comm_device', 'cast_data_device'
+    'is_distributed', 'get_default_group', 'all_reduce_params',
+    'get_data_device', 'get_comm_device', 'cast_data_device'
 ]
diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py
index 73eb88c0..690338cb 100644
--- a/mmengine/dist/dist.py
+++ b/mmengine/dist/dist.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, List, Optional, Tuple, Dict
+from typing import Any, List, Optional, Tuple, Dict, Generator, Union
+from collections import OrderedDict
 import shutil
 import pickle
 import numpy as np
@@ -7,6 +8,8 @@
 import tempfile
 import torch
 import os.path as osp
 from torch import Tensor
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
@@ -805,7 +808,7 @@ def gather_object(data: Any,
         the object must be picklable in order to be gathered.
 
     Note:
-        ``NCCL backend`` dost not support ``gather_object``.
+        ``NCCL backend`` does not support ``gather_object``.
 
     Note:
         Unlike PyTorch ``torch.distributed.gather_object``,
@@ -1036,3 +1039,92 @@ def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
         return ordered_results
     else:
         return None
+
+
+def _all_reduce_coalesced(tensors: List[torch.Tensor],
+                          bucket_size_mb: int = -1,
+                          op: str = 'sum',
+                          group: Optional[ProcessGroup] = None) -> None:
+    """All-reduce a sequence of tensors as a whole.
+
+    Args:
+        tensors (List[torch.Tensor]): A sequence of tensors to be
+            all-reduced.
+        bucket_size_mb (int): The limit of each chunk in megabytes
+            for grouping tensors into chunks. Defaults to -1.
+        op (str): Operation to reduce data. Defaults to 'sum'. Optional values
+            are 'sum', 'mean', 'product', 'min', 'max', 'band', 'bor' and
+            'bxor'.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Defaults to None.
+ """ + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + all_reduce(flat_tensors, op=op, group=group) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) + + +def all_reduce_params(params: Union[List, Generator[torch.Tensor, None, None]], + coalesce: bool = True, + bucket_size_mb: int = -1, + op: str = 'sum', + group: Optional[ProcessGroup] = None) -> None: + """All-reduce parameters. + + Args: + params (List or Generator[torch.Tensor, None, None]): List of + parameters or buffers of a model. + coalesce (bool, optional): Whether to reduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + op (str): Operation to reduce data. Defaults to 'sum'. Optional values + are 'sum', 'mean' and 'produce', 'min', 'max', 'band', 'bor' and + 'bxor'. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = [torch.arange(2), torch.arange(3)] + >>> dist.all_reduce_params(data) + >>> data + [tensor([0, 1]), tensor([0, 1, 2])] + + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> if dist.get_rank() == 0: + ... data = [torch.tensor([1, 2]), torch.tensor([3, 4])] + ... else: + ... data = [torch.tensor([2, 3]), torch.tensor([4, 5])] + + >>> dist.all_reduce_params(data) + >>> data + [torch.tensor([3, 5]), torch.tensor([7, 9])] + """ + world_size = get_world_size(group) + if world_size == 1: + return + params_data = [param.data for param in params] + if coalesce: + _all_reduce_coalesced(params_data, bucket_size_mb, op=op, group=group) + else: + for tensor in params_data: + all_reduce(tensor, op=op, group=group) diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py index 37b62f98..bd2a5673 100644 --- a/mmengine/hooks/sync_buffer_hook.py +++ b/mmengine/hooks/sync_buffer_hook.py @@ -1,83 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -# from mmengine.dist import get_dist_info, all_reduce -from collections import OrderedDict -from typing import Generator, List -from unittest.mock import MagicMock, Mock - -import torch -from torch._utils import (_flatten_dense_tensors, _take_tensors, - _unflatten_dense_tensors) - +from mmengine import dist from mmengine.registry import HOOKS from .hook import Hook -# TODO, replace with import mmengine.dist as dist -dist = Mock() -dist.IS_DIST = MagicMock(return_value=True) - -# TODO, replace with mmengine.dist.get_dist_info -get_dist_info = MagicMock(return_value=(0, 1)) -# TODO, replace with mmengine.dist.all_reduce -all_reduce = MagicMock() - - -# TODO, may need to move to dist.utils after implementing dist module -def _allreduce_coalesced(tensors: List[torch.Tensor], - world_size: int, - bucket_size_mb: int = -1) -> None: - """All-reduce a sequence of tensors as a whole. - - Args: - tensors (List[torch.Tensor]): A sequence of tensors to be - all-reduced. - world_size (int): The world size of the process group. 
-        bucket_size_mb (int): The limit of each chunk in megabytes
-            for grouping tensors into chunks. Defaults to -1.
-    """
-    if bucket_size_mb > 0:
-        bucket_size_bytes = bucket_size_mb * 1024 * 1024
-        buckets = _take_tensors(tensors, bucket_size_bytes)
-    else:
-        buckets = OrderedDict()
-        for tensor in tensors:
-            tp = tensor.type()
-            if tp not in buckets:
-                buckets[tp] = []
-            buckets[tp].append(tensor)
-        buckets = buckets.values()
-
-    for bucket in buckets:
-        flat_tensors = _flatten_dense_tensors(bucket)
-        all_reduce(flat_tensors)
-        flat_tensors.div_(world_size)
-        for tensor, synced in zip(
-                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
-            tensor.copy_(synced)
-
-
-def allreduce_params(params: Generator[torch.Tensor, None, None],
-                     coalesce: bool = True,
-                     bucket_size_mb: int = -1) -> None:
-    """All-reduce parameters.
-
-    Args:
-        params (Generator[torch.Tensor, None, None]): List of parameters or
-            buffers of a model.
-        coalesce (bool, optional): Whether to reduce parameters as a whole.
-            Defaults to True.
-        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
-            Defaults to -1.
-    """
-    _, world_size = get_dist_info()
-    if world_size == 1:
-        return
-    params_data = [param.data for param in params]
-    if coalesce:
-        _allreduce_coalesced(params_data, world_size, bucket_size_mb)
-    else:
-        for tensor in params_data:
-            all_reduce(tensor.div_(world_size))
-
 
 @HOOKS.register_module()
 class SyncBuffersHook(Hook):
@@ -87,7 +12,7 @@ class SyncBuffersHook(Hook):
     priority = 'NORMAL'
 
     def __init__(self) -> None:
-        self.distributed = dist.IS_DIST
+        self.distributed = dist.is_distributed()
 
     def after_train_epoch(self, runner) -> None:
         """All-reduce model buffers at the end of each epoch.
@@ -96,4 +21,4 @@
             runner (Runner): The runner of the training process.
         """
         if self.distributed:
-            allreduce_params(runner.model.buffers())
+            dist.all_reduce_params(runner.model.buffers(), op='mean')
diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py
index f10a77ed..78bd0598 100644
--- a/mmengine/logging/logger.py
+++ b/mmengine/logging/logger.py
@@ -5,9 +5,9 @@
 import sys
 from logging import Logger, LogRecord
 from typing import Optional, Union
 
-import torch.distributed as dist
 from termcolor import colored
 
+from mmengine import dist
 from mmengine.utils import ManagerMixin
@@ -144,10 +144,8 @@ class MMLogger(Logger, ManagerMixin):
         Logger.__init__(self, logger_name)
         ManagerMixin.__init__(self, name)
         # Get rank in DDP mode.
-        if dist.is_available() and dist.is_initialized():
-            rank = dist.get_rank()
-        else:
-            rank = 0
+        rank = dist.get_rank()
+
         # Config stream_handler. If `rank != 0`. stream_handler can only
         # export ERROR logs.
         stream_handler = logging.StreamHandler(stream=sys.stdout)
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 4f4a1296..ead5c6e1 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -579,9 +579,6 @@ class Runner:
         self._rank, self._world_size = get_dist_info()
 
         timestamp = torch.tensor(time.time(), dtype=torch.float64)
-        # TODO: handled by broadcast
-        if self._world_size > 1 and torch.cuda.is_available():
-            timestamp = timestamp.cuda()
         # broadcast timestamp from 0 process to other processes
         broadcast(timestamp)
         self._timestamp = time.strftime('%Y%m%d_%H%M%S',
diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index bf8646fb..ae4beec3 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -99,6 +99,22 @@ class TestDist(TestCase):
         output = dist.collect_results(data, size, device='gpu')
         self.assertEqual(output, expected)
 
+    def test_all_reduce_params(self):
+        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
+                                          ['sum', 'mean']):
+            data = [
+                torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
+            ]
+            data_gen = (item for item in data)
+            expected = [
+                torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
+            ]
+
+            dist.all_reduce_params(data_gen, op=reduce_op)
+
+            for item1, item2 in zip(data, expected):
+                self.assertTrue(torch.allclose(item1, item2))
+
 
 class TestDistWithGLOOBackend(MultiProcessTestCase):
@@ -300,6 +316,39 @@
         else:
             self.assertIsNone(output)
 
+    def test_all_reduce_params(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        tensor_types = [torch.int64, torch.float32]
+        reduce_ops = ['sum', 'mean']
+        coalesces = [True, False]
+        for tensor_type, reduce_op, coalesce in zip(tensor_types, reduce_ops,
+                                                    coalesces):
+            if dist.get_rank() == 0:
+                data = [
+                    torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
+                ]
+            else:
+                data = (
+                    torch.tensor([2, 3], dtype=tensor_type)
+                    for _ in range(100))
+
+            data_gen = (item for item in data)
+
+            if reduce_op == 'sum':
+                expected = (
+                    torch.tensor([2, 4], dtype=tensor_type)
+                    for _ in range(100))
+            else:
+                expected = (
+                    torch.tensor([1, 2], dtype=tensor_type)
+                    for _ in range(100))
+
+            dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)
+
+            for item1, item2 in zip(data, expected):
+                self.assertTrue(torch.allclose(item1, item2))
+
 
 @unittest.skipIf(
     torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
@@ -568,3 +617,37 @@
             self.assertEqual(output, expected)
         else:
             self.assertIsNone(output)
+
+    def test_all_reduce_params(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        tensor_types = [torch.int64, torch.float32]
+        reduce_ops = ['sum', 'mean']
+        coalesces = [True, False]
+        device_types = ['cpu', 'cuda']
+        for tensor_type, reduce_op, coalesce, device_type in zip(
+                tensor_types, reduce_ops, coalesces, device_types):
+            if dist.get_rank() == 0:
+                data = [
+                    torch.tensor([0, 1], dtype=tensor_type).to(device_type)
+                    for _ in range(100)
+                ]
+            else:
+                data = [
+                    torch.tensor([2, 3], dtype=tensor_type).to(device_type)
+                    for _ in range(100)
+                ]
+
+            data_gen = (item for item in data)
+
+            if reduce_op == 'sum':
+                expected = (
+                    torch.tensor([2, 4], dtype=tensor_type).to(device_type)
+                    for _ in range(100))
+            else:
+                expected = (
+                    torch.tensor([1, 2], dtype=tensor_type).to(device_type)
+                    for _ in range(100))
+
+            dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)
+
+            for item1, item2 in zip(data, expected):
+                self.assertTrue(torch.allclose(item1, item2))
diff --git a/tests/test_logging/test_logger.py b/tests/test_logging/test_logger.py
index b7d5260c..dac3100f 100644
--- a/tests/test_logging/test_logger.py
+++ b/tests/test_logging/test_logger.py
@@ -15,9 +15,7 @@ class TestLogger:
     stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
     file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
 
-    @patch('torch.distributed.get_rank', lambda: 0)
-    @patch('torch.distributed.is_initialized', lambda: True)
-    @patch('torch.distributed.is_available', lambda: True)
+    @patch('mmengine.dist.get_rank', lambda: 0)
     def test_init_rank0(self, tmp_path):
         logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
         assert logger.name == 'mmengine'
@@ -47,9 +45,7 @@
         assert logger.instance_name == 'rank0.pkg3'
         logging.shutdown()
 
-    @patch('torch.distributed.get_rank', lambda: 1)
-    @patch('torch.distributed.is_initialized', lambda: True)
-    @patch('torch.distributed.is_available', lambda: True)
+    @patch('mmengine.dist.get_rank', lambda: 1)
    def test_init_rank1(self, tmp_path):
         # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
         tmp_file = tmp_path / 'tmp_file.log'
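A minimal usage sketch of the new `all_reduce_params` helper (the `nn.BatchNorm1d` module below is only an illustration; the call itself mirrors what `SyncBuffersHook.after_train_epoch` does in this patch). In a single-process run the call returns early because `get_world_size()` is 1; in a launched distributed job it all-reduces every buffer across ranks:

    import torch.nn as nn
    import mmengine.dist as dist

    # Any module with buffers works; BatchNorm keeps running_mean / running_var.
    model = nn.BatchNorm1d(4)

    # No-op when world_size == 1; otherwise reduces each buffer across ranks
    # with op='mean', the same call SyncBuffersHook issues after each epoch.
    dist.all_reduce_params(model.buffers(), op='mean')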