mirror of https://github.com/open-mmlab/mmengine.git
[Refactor] Refactor the unit tests of SyncBuffersHook (#813)
parent 69b563dc3b
commit 3715fea15b
.github/workflows/pr_stage_test.yml
@@ -167,5 +167,5 @@ jobs:
         if: ${{ matrix.platform == 'cpu' }}
       - name: Run GPU unittests
        # Skip testing distributed related unit tests since the memory of windows CI is limited
-        run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py
+        run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
         if: ${{ matrix.platform == 'cu111' }}
@@ -67,6 +67,9 @@ class MultiProcessTestCase(TestCase):
     def _should_stop_test_suite(self) -> bool:
         return False
 
+    def prepare_subprocess(self):
+        pass
+
     @property
     def world_size(self) -> int:
         return 2
@@ -169,7 +172,7 @@ class MultiProcessTestCase(TestCase):
     def _run(cls, rank: int, test_name: str, file_name: str,
              parent_pipe) -> None:
         self = cls(test_name)
+        self.prepare_subprocess()
         self.rank = rank
         self.file_name = file_name
         self.run_test(test_name, parent_pipe)
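The hunks above add a prepare_subprocess() extension point to MultiProcessTestCase: _run() calls it inside every spawned worker, right after the test case object is constructed and before the test body executes, giving subclasses a place for per-process setup. Below is a minimal sketch of how a subclass might use it; the subclass name, the test method, and the environment-variable setup are illustrative assumptions, not part of this commit.

import os

from mmengine.testing._internal import MultiProcessTestCase


class MyDistributedTest(MultiProcessTestCase):  # hypothetical subclass for illustration

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()  # launch one worker process per rank (world_size is 2 by default)

    def prepare_subprocess(self):
        # Called by _run() in every spawned worker before the test method, so
        # anything registered or configured here exists in the subprocess
        # itself, not only in the parent process that collected the tests.
        os.environ.setdefault('ILLUSTRATIVE_FLAG', '1')  # hypothetical per-process setup

    def test_flag_visible_in_worker(self):  # hypothetical test body
        assert os.environ['ILLUSTRATIVE_FLAG'] == '1'

This is exactly the pattern the refactored SyncBuffersHook test relies on in the next file: its prepare_subprocess() registers a test-only model class in each worker so the runner can build it from a config string.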
@@ -1,13 +1,74 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from unittest.mock import Mock
+import os
+from unittest.mock import MagicMock
 
+import torch
+import torch.distributed as torch_dist
+import torch.nn as nn
+
+from mmengine.dist import all_gather
 from mmengine.hooks import SyncBuffersHook
+from mmengine.registry import MODELS
+from mmengine.testing._internal import MultiProcessTestCase
+from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel
 
 
-class TestSyncBuffersHook:
+class ToyModuleWithNorm(ToyModel):
 
+    def __init__(self, data_preprocessor=None):
+        super().__init__(data_preprocessor=data_preprocessor)
+        bn = nn.BatchNorm1d(2)
+        self.linear1 = nn.Sequential(self.linear1, bn)
+
+    def init_weights(self):
+        for buffer in self.buffers():
+            buffer.fill_(
+                torch.tensor(int(os.environ['RANK']), dtype=torch.float32))
+        return super().init_weights()
+
+
+class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def prepare_subprocess(self):
+        MODELS.register_module(module=ToyModuleWithNorm, force=True)
+        super(MultiProcessTestCase, self).setUp()
+
     def test_sync_buffers_hook(self):
-        runner = Mock()
-        runner.model = Mock()
+        self.setup_dist_env()
+        runner = MagicMock()
+        runner.model = ToyModuleWithNorm()
+        runner.model.init_weights()
+
+        for buffer in runner.model.buffers():
+            buffer1, buffer2 = all_gather(buffer)
+            self.assertFalse(torch.allclose(buffer1, buffer2))
+
         hook = SyncBuffersHook()
-        hook._after_epoch(runner)
+        hook.after_train_epoch(runner)
+
+        for buffer in runner.model.buffers():
+            buffer1, buffer2 = all_gather(buffer)
+            self.assertTrue(torch.allclose(buffer1, buffer2))
+
+    def test_with_runner(self):
+        self.setup_dist_env()
+        cfg = self.epoch_based_cfg
+        cfg.model = dict(type='ToyModuleWithNorm')
+        cfg.launch = 'pytorch'
+        cfg.custom_hooks = [dict(type='SyncBuffersHook')]
+        runner = self.build_runner(cfg)
+        runner.train()
+
+        for buffer in runner.model.buffers():
+            buffer1, buffer2 = all_gather(buffer)
+            self.assertTrue(torch.allclose(buffer1, buffer2))
+
+    def setup_dist_env(self):
+        super().setup_dist_env()
+        os.environ['RANK'] = str(self.rank)
+        torch_dist.init_process_group(
+            backend='gloo', rank=self.rank, world_size=self.world_size)
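For context on what the refactored test asserts: ToyModuleWithNorm.init_weights() fills every buffer with the process rank, so before the hook runs all_gather() returns different tensors on each rank, and after hook.after_train_epoch(runner) the buffers must be identical everywhere. Below is a minimal sketch of that buffer-averaging idea using mmengine.dist.all_reduce; it is an assumption about the mechanism, not a copy of the shipped SyncBuffersHook implementation.

from mmengine.dist import all_reduce, is_distributed
from mmengine.hooks import Hook


class NaiveSyncBuffersHook(Hook):  # illustrative sketch, not the real hook
    """Average every model buffer (e.g. BatchNorm running_mean/running_var)
    across ranks at the end of each training epoch."""

    def after_train_epoch(self, runner) -> None:
        if not is_distributed():
            return  # nothing to synchronise in a single-process run
        for buffer in runner.model.buffers():
            # In-place all-reduce with op='mean' leaves every rank holding the
            # average of that buffer over all processes, which is what the
            # assertTrue(torch.allclose(...)) checks above expect.
            all_reduce(buffer, op='mean')

In the runner-based test (test_with_runner), the same behaviour is enabled simply by listing dict(type='SyncBuffersHook') in custom_hooks.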
@@ -75,7 +75,7 @@ class ComplexModel(BaseModel):
 
 class TestDistributedDataParallel(MultiProcessTestCase):
 
-    def setUp(self) -> None:
+    def setUp(self):
         super().setUp()
         self._spawn_processes()
 
@@ -774,7 +774,7 @@ class TestBuilder(TestCase):
     reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.')
 class TestZeroOptimizer(MultiProcessTestCase):
 
-    def setUp(self) -> None:
+    def setUp(self):
         super().setUp()
         self._spawn_processes()
 