mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor the unit tests of SyncBuffersHook (#813)
This commit is contained in:
parent
69b563dc3b
commit
3715fea15b
2
.github/workflows/pr_stage_test.yml
vendored
2
.github/workflows/pr_stage_test.yml
vendored
@@ -167,5 +167,5 @@ jobs:
|
||||
if: ${{ matrix.platform == 'cpu' }}
|
||||
- name: Run GPU unittests
|
||||
# Skip testing distributed related unit tests since the memory of windows CI is limited
|
||||
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py
|
||||
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
|
||||
if: ${{ matrix.platform == 'cu111' }}
|
||||
|
@@ -67,6 +67,9 @@ class MultiProcessTestCase(TestCase):
|
||||
def _should_stop_test_suite(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare_subprocess(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
@@ -169,7 +172,7 @@ class MultiProcessTestCase(TestCase):
|
||||
def _run(cls, rank: int, test_name: str, file_name: str,
|
||||
parent_pipe) -> None:
|
||||
self = cls(test_name)
|
||||
|
||||
self.prepare_subprocess()
|
||||
self.rank = rank
|
||||
self.file_name = file_name
|
||||
self.run_test(test_name, parent_pipe)
|
||||
|
@@ -1,13 +1,74 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import Mock
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.distributed as torch_dist
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.dist import all_gather
|
||||
from mmengine.hooks import SyncBuffersHook
|
||||
from mmengine.registry import MODELS
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel
|
||||
|
||||
|
||||
class TestSyncBuffersHook:
|
||||
class ToyModuleWithNorm(ToyModel):
|
||||
|
||||
def __init__(self, data_preprocessor=None):
|
||||
super().__init__(data_preprocessor=data_preprocessor)
|
||||
bn = nn.BatchNorm1d(2)
|
||||
self.linear1 = nn.Sequential(self.linear1, bn)
|
||||
|
||||
def init_weights(self):
|
||||
for buffer in self.buffers():
|
||||
buffer.fill_(
|
||||
torch.tensor(int(os.environ['RANK']), dtype=torch.float32))
|
||||
return super().init_weights()
|
||||
|
||||
|
||||
class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
def prepare_subprocess(self):
|
||||
MODELS.register_module(module=ToyModuleWithNorm, force=True)
|
||||
super(MultiProcessTestCase, self).setUp()
|
||||
|
||||
def test_sync_buffers_hook(self):
|
||||
runner = Mock()
|
||||
runner.model = Mock()
|
||||
self.setup_dist_env()
|
||||
runner = MagicMock()
|
||||
runner.model = ToyModuleWithNorm()
|
||||
runner.model.init_weights()
|
||||
|
||||
for buffer in runner.model.buffers():
|
||||
buffer1, buffer2 = all_gather(buffer)
|
||||
self.assertFalse(torch.allclose(buffer1, buffer2))
|
||||
|
||||
hook = SyncBuffersHook()
|
||||
hook._after_epoch(runner)
|
||||
hook.after_train_epoch(runner)
|
||||
|
||||
for buffer in runner.model.buffers():
|
||||
buffer1, buffer2 = all_gather(buffer)
|
||||
self.assertTrue(torch.allclose(buffer1, buffer2))
|
||||
|
||||
def test_with_runner(self):
|
||||
self.setup_dist_env()
|
||||
cfg = self.epoch_based_cfg
|
||||
cfg.model = dict(type='ToyModuleWithNorm')
|
||||
cfg.launch = 'pytorch'
|
||||
cfg.custom_hooks = [dict(type='SyncBuffersHook')]
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
|
||||
for buffer in runner.model.buffers():
|
||||
buffer1, buffer2 = all_gather(buffer)
|
||||
self.assertTrue(torch.allclose(buffer1, buffer2))
|
||||
|
||||
def setup_dist_env(self):
|
||||
super().setup_dist_env()
|
||||
os.environ['RANK'] = str(self.rank)
|
||||
torch_dist.init_process_group(
|
||||
backend='gloo', rank=self.rank, world_size=self.world_size)
|
||||
|
@@ -75,7 +75,7 @@ class ComplexModel(BaseModel):
|
||||
|
||||
class TestDistributedDataParallel(MultiProcessTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
|
@@ -774,7 +774,7 @@ class TestBuilder(TestCase):
|
||||
reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.')
|
||||
class TestZeroOptimizer(MultiProcessTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user