[Enhance] enhance runner test case (#631)

* Add runner test case

* Fix unit test

* fix unit test

* pop None if key does not exist

* Fix is_model_wrapper and force register class in test_runner

* [Fix] Fix is_model_wrapper

* destroy group after ut

* register module in testcase

* fix as per review comments

* minor refine

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix lint

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Mashiro 2022-11-21 11:54:05 +08:00 committed by GitHub
parent b7aa4dd885
commit c478bdca27
5 changed files with 301 additions and 32 deletions
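The recurring pattern in these commits (force-register test-only modules in `setUp`, pop them with a default in `tearDown`) is sketched below. `MODELS` and `ToyModel` stand in for any registry and any test-only class; the stand-in class here is illustrative and not part of the diff:

    from mmengine.registry import MODELS

    class ToyModel:  # illustrative stand-in; the real one subclasses BaseModel
        pass

    # setUp: force=True overwrites an entry of the same name that another test
    # module may have registered, so test collection order does not matter.
    MODELS.register_module(module=ToyModel, force=True)

    # tearDown: the default argument makes pop a no-op when the key is already
    # gone, instead of raising KeyError.
    MODELS.module_dict.pop('ToyModel', None)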


@ -3,9 +3,10 @@ from .compare import (assert_allclose, assert_attrs_equal,
assert_dict_contains_subset, assert_dict_has_keys,
assert_is_norm_layer, assert_keys_equal,
assert_params_all_zeros, check_python_script)
from .runner_test_case import RunnerTestCase
__all__ = [
'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal',
'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script'
'assert_params_all_zeros', 'check_python_script', 'RunnerTestCase'
]


@ -0,0 +1,186 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os
import tempfile
import time
from unittest import TestCase
from uuid import uuid4
import torch
import torch.nn as nn
from torch.distributed import destroy_process_group
from torch.utils.data import Dataset
import mmengine.hooks # noqa F401
import mmengine.optim # noqa F401
from mmengine.config import Config
from mmengine.dist import is_distributed
from mmengine.evaluator import BaseMetric
from mmengine.logging import MessageHub, MMLogger
from mmengine.model import BaseModel
from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
class ToyModel(BaseModel):
def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor)
self.linear1 = nn.Linear(2, 2)
self.linear2 = nn.Linear(2, 1)
def forward(self, inputs, data_samples, mode='tensor'):
if isinstance(inputs, list):
inputs = torch.stack(inputs)
if isinstance(data_samples, list):
data_samples = torch.stack(data_samples)
outputs = self.linear1(inputs)
outputs = self.linear2(outputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (data_samples - outputs).sum()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
return outputs
class ToyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)
@property
def metainfo(self):
return self.METAINFO
def __len__(self):
return self.data.size(0)
def __getitem__(self, index):
return dict(inputs=self.data[index], data_samples=self.label[index])
class ToyMetric(BaseMetric):
def __init__(self, collect_device='cpu', dummy_metrics=None):
super().__init__(collect_device=collect_device)
self.dummy_metrics = dummy_metrics
def process(self, data_batch, predictions):
result = {'acc': 1}
self.results.append(result)
def compute_metrics(self, results):
return dict(acc=1)
class RunnerTestCase(TestCase):
"""A test case to build runner easily.
`RunnerTestCase` will do the following things:
1. Registers a toy model, a toy metric, and a toy dataset, which can be
used to run the `Runner` successfully.
2. Provides epoch based and iteration based cfg to build runner.
3. Provides `build_runner` method to build runner easily.
4. Clean the global variable used by the runner.
"""
dist_cfg = dict(
MASTER_ADDR='127.0.0.1',
MASTER_PORT=29600,
RANK='0',
WORLD_SIZE='1',
LOCAL_RANK='0')
def setUp(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory()
# Force registration to avoid name clashes with modules registered by
# other unit tests. These entries will be removed in `tearDown`.
MODELS.register_module(module=ToyModel, force=True)
METRICS.register_module(module=ToyMetric, force=True)
DATASETS.register_module(module=ToyDataset, force=True)
epoch_based_cfg = dict(
work_dir=self.temp_dir.name,
model=dict(type='ToyModel'),
train_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=[dict(type='ToyMetric')],
test_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
test_evaluator=[dict(type='ToyMetric')],
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
val_cfg=dict(),
test_cfg=dict(),
default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
custom_hooks=[],
env_cfg=dict(dist_cfg=dict(backend='nccl')),
experiment_name='test1')
self.epoch_based_cfg = Config(epoch_based_cfg)
# prepare iter based cfg.
self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg)
self.iter_based_cfg.train_dataloader = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='InfiniteSampler', shuffle=True),
batch_size=3,
num_workers=0)
self.iter_based_cfg.log_processor = dict(by_epoch=False)
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
self.iter_based_cfg.default_hooks = dict(
logger=dict(type='LoggerHook', interval=1),
checkpoint=dict(
type='CheckpointHook', interval=12, by_epoch=False))
def tearDown(self):
# `FileHandler` must be closed on Windows, otherwise the temporary
# directory cannot be deleted.
logging.shutdown()
MMLogger._instance_dict.clear()
Visualizer._instance_dict.clear()
DefaultScope._instance_dict.clear()
MessageHub._instance_dict.clear()
MODELS.module_dict.pop('ToyModel', None)
METRICS.module_dict.pop('ToyMetric', None)
DATASETS.module_dict.pop('ToyDataset', None)
self.temp_dir.cleanup()
if is_distributed():
destroy_process_group()
def build_runner(self, cfg: Config):
cfg.experiment_name = self.experiment_name
runner = Runner.from_cfg(cfg)
return runner
@property
def experiment_name(self):
# Runners may be built so quickly that the timestamp alone is not
# unique, so a uuid is appended to make sure each runner gets a unique
# experiment name.
return f'{self._testMethodName}_{time.time()} + ' \
f'{uuid4()}'
def setup_dist_env(self):
self.dist_cfg['MASTER_PORT'] += 1
os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT'])
os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR']
os.environ['RANK'] = self.dist_cfg['RANK']
os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']
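For illustration (not part of this commit), a downstream test could subclass `RunnerTestCase`, tweak the prepared config, and build a runner. A minimal sketch, assuming only the APIs shown above; the test class name is hypothetical:

    import copy

    from mmengine.testing import RunnerTestCase

    class TestMyTrainingSetup(RunnerTestCase):  # hypothetical downstream test

        def test_epoch_based_training(self):
            # Copy the prepared config so other tests are unaffected.
            cfg = copy.deepcopy(self.epoch_based_cfg)
            cfg.train_cfg.max_epochs = 1  # shorten the toy run
            runner = self.build_runner(cfg)
            runner.train()
            runner.val()

Each run gets a unique `experiment_name` from the property above, and `tearDown` clears the logger, visualizer, scope, and message hub singletons afterwards.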


@ -79,7 +79,7 @@ def test_is_model_wrapper():
pass
CHILD_REGISTRY.register_module(module=CustomModelWrapper)
CHILD_REGISTRY.register_module(module=CustomModelWrapper, force=True)
for wrapper in [
DistributedDataParallel, MMDistributedDataParallel,


@ -37,7 +37,6 @@ from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.visualization import Visualizer
@MODELS.register_module()
class ToyModel(BaseModel):
def __init__(self, data_preprocessor=None):
@ -63,14 +62,12 @@ class ToyModel(BaseModel):
return outputs
@MODELS.register_module()
class ToyModel1(ToyModel):
def __init__(self):
super().__init__()
@MODELS.register_module()
class ToySyncBNModel(BaseModel):
def __init__(self):
@ -95,7 +92,6 @@ class ToySyncBNModel(BaseModel):
return outputs
@MODELS.register_module()
class ToyGANModel(BaseModel):
def __init__(self):
@ -127,7 +123,6 @@ class ToyGANModel(BaseModel):
return loss
@MODEL_WRAPPERS.register_module()
class CustomModelWrapper(nn.Module):
def __init__(self, module):
@ -135,7 +130,6 @@ class CustomModelWrapper(nn.Module):
self.model = module
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ToyMultipleOptimizerConstructor:
def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
@ -163,7 +157,6 @@ class ToyMultipleOptimizerConstructor:
return OptimWrapperDict(**optimizers)
@DATASETS.register_module()
class ToyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
@ -180,7 +173,6 @@ class ToyDataset(Dataset):
return dict(inputs=self.data[index], data_sample=self.label[index])
@DATASETS.register_module()
class ToyDatasetNoMeta(Dataset):
data = torch.randn(12, 2)
label = torch.ones(12)
@ -192,7 +184,6 @@ class ToyDatasetNoMeta(Dataset):
return dict(inputs=self.data[index], data_sample=self.label[index])
@METRICS.register_module()
class ToyMetric1(BaseMetric):
def __init__(self, collect_device='cpu', dummy_metrics=None):
@ -207,7 +198,6 @@ class ToyMetric1(BaseMetric):
return dict(acc=1)
@METRICS.register_module()
class ToyMetric2(BaseMetric):
def __init__(self, collect_device='cpu', dummy_metrics=None):
@ -222,12 +212,10 @@ class ToyMetric2(BaseMetric):
return dict(acc=1)
@OPTIM_WRAPPERS.register_module()
class ToyOptimWrapper(OptimWrapper):
...
@HOOKS.register_module()
class ToyHook(Hook):
priority = 'Lowest'
@ -235,7 +223,6 @@ class ToyHook(Hook):
pass
@HOOKS.register_module()
class ToyHook2(Hook):
priority = 'Lowest'
@ -243,7 +230,6 @@ class ToyHook2(Hook):
pass
@LOOPS.register_module()
class CustomTrainLoop(BaseLoop):
def __init__(self, runner, dataloader, max_epochs):
@ -254,7 +240,6 @@ class CustomTrainLoop(BaseLoop):
pass
@LOOPS.register_module()
class CustomValLoop(BaseLoop):
def __init__(self, runner, dataloader, evaluator):
@ -270,7 +255,6 @@ class CustomValLoop(BaseLoop):
pass
@LOOPS.register_module()
class CustomTestLoop(BaseLoop):
def __init__(self, runner, dataloader, evaluator):
@ -286,7 +270,6 @@ class CustomTestLoop(BaseLoop):
pass
@LOG_PROCESSORS.register_module()
class CustomLogProcessor(LogProcessor):
def __init__(self, window_size=10, by_epoch=True, custom_cfg=None):
@ -296,7 +279,6 @@ class CustomLogProcessor(LogProcessor):
self._check_custom_cfg()
@RUNNERS.register_module()
class CustomRunner(Runner):
def __init__(self,
@ -333,7 +315,6 @@ class CustomRunner(Runner):
pass
@EVALUATOR.register_module()
class ToyEvaluator(Evaluator):
def __init__(self, metrics):
@ -344,7 +325,6 @@ def collate_fn(data_batch):
return pseudo_collate(data_batch)
@COLLATE_FUNCTIONS.register_module()
def custom_collate(data_batch, pad_value):
return pseudo_collate(data_batch)
@ -352,6 +332,28 @@ def custom_collate(data_batch, pad_value):
class TestRunner(TestCase):
def setUp(self):
MODELS.register_module(module=ToyModel, force=True)
MODELS.register_module(module=ToyModel1, force=True)
MODELS.register_module(module=ToySyncBNModel, force=True)
MODELS.register_module(module=ToyGANModel, force=True)
MODEL_WRAPPERS.register_module(module=CustomModelWrapper, force=True)
OPTIM_WRAPPER_CONSTRUCTORS.register_module(
module=ToyMultipleOptimizerConstructor, force=True)
DATASETS.register_module(module=ToyDataset, force=True)
DATASETS.register_module(module=ToyDatasetNoMeta, force=True)
METRICS.register_module(module=ToyMetric1, force=True)
METRICS.register_module(module=ToyMetric2, force=True)
OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True)
HOOKS.register_module(module=ToyHook, force=True)
HOOKS.register_module(module=ToyHook2, force=True)
LOOPS.register_module(module=CustomTrainLoop, force=True)
LOOPS.register_module(module=CustomValLoop, force=True)
LOOPS.register_module(module=CustomTestLoop, force=True)
LOG_PROCESSORS.register_module(module=CustomLogProcessor, force=True)
RUNNERS.register_module(module=CustomRunner, force=True)
EVALUATOR.register_module(module=ToyEvaluator, force=True)
COLLATE_FUNCTIONS.register_module(module=custom_collate, force=True)
self.temp_dir = tempfile.mkdtemp()
epoch_based_cfg = dict(
model=dict(type='ToyModel'),
@ -413,6 +415,28 @@ class TestRunner(TestCase):
def tearDown(self):
# `FileHandler` should be closed in Windows, otherwise we cannot
# delete the temporary directory
MODELS.module_dict.pop('ToyModel')
MODELS.module_dict.pop('ToyModel1')
MODELS.module_dict.pop('ToySyncBNModel')
MODELS.module_dict.pop('ToyGANModel')
MODEL_WRAPPERS.module_dict.pop('CustomModelWrapper')
OPTIM_WRAPPER_CONSTRUCTORS.module_dict.pop(
'ToyMultipleOptimizerConstructor')
OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper')
DATASETS.module_dict.pop('ToyDataset')
DATASETS.module_dict.pop('ToyDatasetNoMeta')
METRICS.module_dict.pop('ToyMetric1')
METRICS.module_dict.pop('ToyMetric2')
HOOKS.module_dict.pop('ToyHook')
HOOKS.module_dict.pop('ToyHook2')
LOOPS.module_dict.pop('CustomTrainLoop')
LOOPS.module_dict.pop('CustomValLoop')
LOOPS.module_dict.pop('CustomTestLoop')
LOG_PROCESSORS.module_dict.pop('CustomLogProcessor')
RUNNERS.module_dict.pop('CustomRunner')
EVALUATOR.module_dict.pop('ToyEvaluator')
COLLATE_FUNCTIONS.module_dict.pop('custom_collate')
logging.shutdown()
MMLogger._instance_dict.clear()
shutil.rmtree(self.temp_dir)
@ -782,7 +806,7 @@ class TestRunner(TestCase):
TOY_SCHEDULERS = Registry(
'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy')
@TOY_SCHEDULERS.register_module()
@TOY_SCHEDULERS.register_module(force=True)
class ToyScheduler(MultiStepLR):
def __init__(self, *args, **kwargs):
@ -863,7 +887,7 @@ class TestRunner(TestCase):
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
runner.from_cfg(cfg)
@MODELS.register_module()
@MODELS.register_module(force=True)
class ToyBN(BaseModel):
def __init__(self):
@ -1349,7 +1373,7 @@ class TestRunner(TestCase):
val_epoch_results = []
val_epoch_targets = [i for i in range(2, 4)]
@HOOKS.register_module()
@HOOKS.register_module(force=True)
class TestEpochHook(Hook):
def before_train_epoch(self, runner):
@ -1394,7 +1418,7 @@ class TestRunner(TestCase):
val_iter_targets = [i for i in range(4, 12)]
val_batch_idx_targets = [i for i in range(4)] * 2
@HOOKS.register_module()
@HOOKS.register_module(force=True)
class TestIterHook(Hook):
def before_train_epoch(self, runner):
@ -1487,7 +1511,7 @@ class TestRunner(TestCase):
val_interval_results = []
val_interval_targets = [5] * 10 + [2] * 2
@HOOKS.register_module()
@HOOKS.register_module(force=True)
class TestIterDynamicIntervalHook(Hook):
def before_val(self, runner):
@ -1524,7 +1548,7 @@ class TestRunner(TestCase):
val_interval_results = []
val_interval_targets = [5] * 10 + [2] * 2
@HOOKS.register_module()
@HOOKS.register_module(force=True)
class TestEpochDynamicIntervalHook(Hook):
def before_val_epoch(self, runner):
@ -1553,7 +1577,7 @@ class TestRunner(TestCase):
self.assertEqual(result, target)
# 7. test init weights
@MODELS.register_module()
@MODELS.register_module(force=True)
class ToyModel2(ToyModel):
def __init__(self):
@ -1654,7 +1678,7 @@ class TestRunner(TestCase):
runner.train()
# 12.1 Test train with model, which does not inherit from BaseModel
@MODELS.register_module()
@MODELS.register_module(force=True)
class ToyModel3(nn.Module):
def __init__(self):
@ -1901,7 +1925,7 @@ class TestRunner(TestCase):
def test_custom_loop(self):
# test custom loop with additional hook
@LOOPS.register_module()
@LOOPS.register_module(force=True)
class CustomTrainLoop2(IterBasedTrainLoop):
"""Custom train loop with additional warmup stage."""
@ -1942,7 +1966,7 @@ class TestRunner(TestCase):
before_warmup_iter_results = []
after_warmup_iter_results = []
@HOOKS.register_module()
@HOOKS.register_module(force=True)
class TestWarmupHook(Hook):
"""test custom train loop."""


@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from mmengine import Config
from mmengine.logging import MessageHub, MMLogger
from mmengine.registry import DefaultScope
from mmengine.testing import RunnerTestCase
from mmengine.visualization import Visualizer
class TestRunnerTestCase(RunnerTestCase):
def test_setup(self):
self.assertIsInstance(self.epoch_based_cfg, Config)
self.assertIsInstance(self.iter_based_cfg, Config)
self.assertIn('MASTER_ADDR', self.dist_cfg)
self.assertIn('MASTER_PORT', self.dist_cfg)
self.assertIn('RANK', self.dist_cfg)
self.assertIn('WORLD_SIZE', self.dist_cfg)
self.assertIn('LOCAL_RANK', self.dist_cfg)
def test_tearDown(self):
self.tearDown()
self.assertEqual(MMLogger._instance_dict, {})
self.assertEqual(MessageHub._instance_dict, {})
self.assertEqual(Visualizer._instance_dict, {})
self.assertEqual(DefaultScope._instance_dict, {})
# tearDown should not be called twice.
self.tearDown = super(RunnerTestCase, self).tearDown
def test_build_runner(self):
runner = self.build_runner(self.epoch_based_cfg)
runner.train()
runner.val()
runner.test()
runner = self.build_runner(self.iter_based_cfg)
runner.train()
runner.val()
runner.test()
def test_experiment_name(self):
runner1 = self.build_runner(self.epoch_based_cfg)
runner2 = self.build_runner(self.epoch_based_cfg)
self.assertNotEqual(runner1.experiment_name, runner2.experiment_name)
def test_init_dist(self):
self.setup_dist_env()
self.assertEqual(
str(self.dist_cfg['MASTER_PORT']), os.environ['MASTER_PORT'])
self.assertEqual(self.dist_cfg['MASTER_ADDR'],
os.environ['MASTER_ADDR'])
self.assertEqual(self.dist_cfg['RANK'], os.environ['RANK'])
self.assertEqual(self.dist_cfg['LOCAL_RANK'], os.environ['LOCAL_RANK'])
self.assertEqual(self.dist_cfg['WORLD_SIZE'], os.environ['WORLD_SIZE'])
first_port = os.environ['MASTER_PORT']
self.setup_dist_env()
self.assertNotEqual(first_port, os.environ['MASTER_PORT'])
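As a final illustration (assumption: a single-process group over the gloo backend, with a hypothetical test name), `setup_dist_env` can be paired with a real process group; `RunnerTestCase.tearDown` then calls `destroy_process_group()` because `is_distributed()` is true:

    import torch.distributed as dist

    from mmengine.testing import RunnerTestCase

    class TestDistEnv(RunnerTestCase):  # hypothetical downstream test

        def test_single_process_group(self):
            # Exports MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE and LOCAL_RANK.
            self.setup_dist_env()
            # The default env:// init reads the variables exported above.
            dist.init_process_group(backend='gloo')
            self.assertTrue(dist.is_initialized())
            # No manual cleanup: tearDown destroys the process group.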