[Refactor] Refactor the unit tests of SyncBuffersHook (#813)

This commit is contained in:
Mashiro 2023-04-28 17:32:30 +08:00 committed by GitHub
parent 69b563dc3b
commit 3715fea15b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 9 deletions

View File

@ -167,5 +167,5 @@ jobs:
if: ${{ matrix.platform == 'cpu' }}
- name: Run GPU unittests
# Skip distributed-related unit tests because memory on the Windows CI runner is limited
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
if: ${{ matrix.platform == 'cu111' }}

View File

@ -67,6 +67,9 @@ class MultiProcessTestCase(TestCase):
def _should_stop_test_suite(self) -> bool:
    """Report whether the test suite should stop; always ``False`` here.

    NOTE(review): the caller is not visible in this excerpt. Judging by the
    name, the test machinery queries this to decide whether to abort the
    remaining tests -- confirm before relying on that.
    """
    return False
def prepare_subprocess(self) -> None:
    """Hook for per-subprocess preparation; does nothing by default.

    ``_run`` calls this on the freshly constructed test case, before it
    assigns ``rank`` and ``file_name`` and before it runs the test, so an
    override must not depend on those attributes.
    """
    pass
@property
def world_size(self) -> int:
    """World size used by the test case; fixed at 2.

    NOTE(review): presumably this is also the number of subprocesses that
    get spawned -- confirm against ``_spawn_processes``.
    """
    return 2
@ -169,7 +172,7 @@ class MultiProcessTestCase(TestCase):
def _run(cls, rank: int, test_name: str, file_name: str,
         parent_pipe) -> None:
    """Build the test case for one rank and run the named test.

    NOTE(review): the first parameter is ``cls``, so a ``@classmethod``
    decorator is presumably just above this excerpt -- confirm.

    Args:
        rank: Stored on the instance as ``self.rank``.
        test_name: Passed to the class constructor and to ``run_test``.
        file_name: Stored on the instance as ``self.file_name``.
        parent_pipe: Forwarded unchanged to ``run_test``.
    """
    self = cls(test_name)
    # Runs before ``rank`` and ``file_name`` are set, so the hook cannot
    # read them.
    self.prepare_subprocess()
    self.rank = rank
    self.file_name = file_name
    self.run_test(test_name, parent_pipe)

View File

@ -1,13 +1,74 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
import os
from unittest.mock import MagicMock
import torch
import torch.distributed as torch_dist
import torch.nn as nn
from mmengine.dist import all_gather
from mmengine.hooks import SyncBuffersHook
from mmengine.registry import MODELS
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel
class TestSyncBuffersHook:
class ToyModuleWithNorm(ToyModel):
    """``ToyModel`` whose ``linear1`` is followed by a ``BatchNorm1d`` layer.

    The batch-norm layer contributes buffers to the model, and
    ``init_weights`` fills every buffer with the rank read from the ``RANK``
    environment variable.
    """

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        # Chain the inherited layer with a 2-feature batch norm.
        self.linear1 = nn.Sequential(self.linear1, nn.BatchNorm1d(2))

    def init_weights(self):
        """Fill each buffer with the ``RANK`` value, then defer to the parent.

        Raises:
            KeyError: If ``RANK`` is not set in the environment.
        """
        for buf in self.buffers():
            rank = int(os.environ['RANK'])
            buf.fill_(torch.tensor(rank, dtype=torch.float32))
        return super().init_weights()
class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase):
    """Multi-process tests for ``SyncBuffersHook``.

    Each test initialises a ``gloo`` process group, then checks via
    ``all_gather`` that the buffers of a ``ToyModuleWithNorm`` are equal
    across ranks once the hook has run. The two-way unpacking of the
    ``all_gather`` result relies on a world size of 2.
    """

    def setUp(self) -> None:
        """Run the inherited set-up, then call ``_spawn_processes()``."""
        super().setUp()
        self._spawn_processes()

    def prepare_subprocess(self):
        """Register the toy model and run the set-up after
        ``MultiProcessTestCase`` in the MRO.

        ``super(MultiProcessTestCase, self)`` starts the lookup after
        ``MultiProcessTestCase``, so that class's own ``setUp`` is skipped.
        NOTE(review): presumably this avoids re-running the process-spawning
        set-up in a subprocess while still getting the ``RunnerTestCase``
        fixtures -- confirm.
        """
        # ``force=True`` overwrites any previous registration of the same
        # name.
        MODELS.register_module(module=ToyModuleWithNorm, force=True)
        super(MultiProcessTestCase, self).setUp()

    def test_sync_buffers_hook(self):
        """Buffers differ per rank before the hook and match after it."""
        self.setup_dist_env()
        runner = MagicMock()
        runner.model = ToyModuleWithNorm()
        # Fills every buffer with this process's rank, so ranks disagree.
        runner.model.init_weights()

        for buffer in runner.model.buffers():
            buffer1, buffer2 = all_gather(buffer)
            self.assertFalse(torch.allclose(buffer1, buffer2))

        hook = SyncBuffersHook()
        hook.after_train_epoch(runner)

        for buffer in runner.model.buffers():
            buffer1, buffer2 = all_gather(buffer)
            self.assertTrue(torch.allclose(buffer1, buffer2))

    def test_with_runner(self):
        """Buffers match across ranks after a full ``runner.train()``."""
        self.setup_dist_env()
        cfg = self.epoch_based_cfg
        cfg.model = dict(type='ToyModuleWithNorm')
        cfg.launch = 'pytorch'
        cfg.custom_hooks = [dict(type='SyncBuffersHook')]
        runner = self.build_runner(cfg)
        runner.train()

        for buffer in runner.model.buffers():
            buffer1, buffer2 = all_gather(buffer)
            self.assertTrue(torch.allclose(buffer1, buffer2))

    def setup_dist_env(self):
        """Export ``RANK`` and initialise the ``gloo`` process group.

        NOTE(review): the parent ``setup_dist_env`` is presumably what sets
        the rendezvous variables such as the master address and port --
        confirm.
        """
        super().setup_dist_env()
        # ``ToyModuleWithNorm.init_weights`` reads ``RANK`` from the
        # environment.
        os.environ['RANK'] = str(self.rank)
        torch_dist.init_process_group(
            backend='gloo', rank=self.rank, world_size=self.world_size)

View File

@ -75,7 +75,7 @@ class ComplexModel(BaseModel):
class TestDistributedDataParallel(MultiProcessTestCase):
def setUp(self) -> None:
    """Run the inherited set-up, then call ``_spawn_processes()``."""
    super().setUp()
    self._spawn_processes()

View File

@ -774,7 +774,7 @@ class TestBuilder(TestCase):
reason='ZeRO requires pytorch>=1.8 with torch.distributed.rpc available.')
class TestZeroOptimizer(MultiProcessTestCase):
def setUp(self) -> None:
    """Run the inherited set-up, then call ``_spawn_processes()``."""
    super().setUp()
    self._spawn_processes()