diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml index 320262d4..e5578ab9 100644 --- a/.github/workflows/pr_stage_test.yml +++ b/.github/workflows/pr_stage_test.yml @@ -167,5 +167,5 @@ jobs: if: ${{ matrix.platform == 'cpu' }} - name: Run GPU unittests # Skip testing distributed related unit tests since the memory of windows CI is limited - run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py + run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py if: ${{ matrix.platform == 'cu111' }} diff --git a/mmengine/testing/_internal/distributed.py b/mmengine/testing/_internal/distributed.py index 5e5020fa..20e3b601 100644 --- a/mmengine/testing/_internal/distributed.py +++ b/mmengine/testing/_internal/distributed.py @@ -67,6 +67,9 @@ class MultiProcessTestCase(TestCase): def _should_stop_test_suite(self) -> bool: return False + def prepare_subprocess(self): + pass + @property def world_size(self) -> int: return 2 @@ -169,7 +172,7 @@ class MultiProcessTestCase(TestCase): def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None: self = cls(test_name) - + self.prepare_subprocess() self.rank = rank self.file_name = file_name self.run_test(test_name, parent_pipe) diff --git a/tests/test_hooks/test_sync_buffers_hook.py b/tests/test_hooks/test_sync_buffers_hook.py index c7c64287..6d4019dc 100644 --- a/tests/test_hooks/test_sync_buffers_hook.py +++ b/tests/test_hooks/test_sync_buffers_hook.py @@ -1,13 +1,74 @@ # Copyright (c) OpenMMLab. All rights reserved. -from unittest.mock import Mock +import os +from unittest.mock import MagicMock +import torch +import torch.distributed as torch_dist +import torch.nn as nn + +from mmengine.dist import all_gather from mmengine.hooks import SyncBuffersHook +from mmengine.registry import MODELS +from mmengine.testing._internal import MultiProcessTestCase +from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel -class TestSyncBuffersHook: +class ToyModuleWithNorm(ToyModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + bn = nn.BatchNorm1d(2) + self.linear1 = nn.Sequential(self.linear1, bn) + + def init_weights(self): + for buffer in self.buffers(): + buffer.fill_( + torch.tensor(int(os.environ['RANK']), dtype=torch.float32)) + return super().init_weights() + + +class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase): + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def prepare_subprocess(self): + MODELS.register_module(module=ToyModuleWithNorm, force=True) + super(MultiProcessTestCase, self).setUp() def test_sync_buffers_hook(self): - runner = Mock() - runner.model = Mock() + self.setup_dist_env() + runner = MagicMock() + runner.model = ToyModuleWithNorm() + runner.model.init_weights() + + for buffer in runner.model.buffers(): + buffer1, buffer2 = all_gather(buffer) + self.assertFalse(torch.allclose(buffer1, buffer2)) + hook = SyncBuffersHook() - hook._after_epoch(runner) + hook.after_train_epoch(runner) + + for buffer in runner.model.buffers(): + buffer1, buffer2 = all_gather(buffer) + self.assertTrue(torch.allclose(buffer1, buffer2)) + + def test_with_runner(self): + self.setup_dist_env() + cfg = self.epoch_based_cfg + cfg.model = dict(type='ToyModuleWithNorm') + cfg.launch = 'pytorch' + cfg.custom_hooks = [dict(type='SyncBuffersHook')] + runner = self.build_runner(cfg) + runner.train() + + for buffer in runner.model.buffers(): + buffer1, buffer2 = all_gather(buffer) + self.assertTrue(torch.allclose(buffer1, buffer2)) + + def setup_dist_env(self): + super().setup_dist_env() + os.environ['RANK'] = str(self.rank) + torch_dist.init_process_group( + backend='gloo', rank=self.rank, world_size=self.world_size) diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index 7a2d3bd7..eabe10ea 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -75,7 +75,7 @@ class ComplexModel(BaseModel): class TestDistributedDataParallel(MultiProcessTestCase): - def setUp(self) -> None: + def setUp(self): super().setUp() self._spawn_processes() diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index 35aabaea..e4089a4e 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -774,7 +774,7 @@ class TestBuilder(TestCase): reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.') class TestZeroOptimizer(MultiProcessTestCase): - def setUp(self) -> None: + def setUp(self): super().setUp() self._spawn_processes()