mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] enhance runner test case (#631)
* Add runner test cast * Fix unit test * fix unit test * pop None if key does not exist * Fix is_model_wrapper and force register class in test_runner * [Fix] Fix is_model_wrapper * destroy group after ut * register module in testcase * fix as comment * minor refine Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
b7aa4dd885
commit
c478bdca27
@ -3,9 +3,10 @@ from .compare import (assert_allclose, assert_attrs_equal,
|
|||||||
assert_dict_contains_subset, assert_dict_has_keys,
|
assert_dict_contains_subset, assert_dict_has_keys,
|
||||||
assert_is_norm_layer, assert_keys_equal,
|
assert_is_norm_layer, assert_keys_equal,
|
||||||
assert_params_all_zeros, check_python_script)
|
assert_params_all_zeros, check_python_script)
|
||||||
|
from .runner_test_case import RunnerTestCase
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal',
|
'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal',
|
||||||
'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer',
|
'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer',
|
||||||
'assert_params_all_zeros', 'check_python_script'
|
'assert_params_all_zeros', 'check_python_script', 'RunnerTestCase'
|
||||||
]
|
]
|
||||||
|
186
mmengine/testing/runner_test_case.py
Normal file
186
mmengine/testing/runner_test_case.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from unittest import TestCase
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.distributed import destroy_process_group
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
import mmengine.hooks # noqa F401
|
||||||
|
import mmengine.optim # noqa F401
|
||||||
|
from mmengine.config import Config
|
||||||
|
from mmengine.dist import is_distributed
|
||||||
|
from mmengine.evaluator import BaseMetric
|
||||||
|
from mmengine.logging import MessageHub, MMLogger
|
||||||
|
from mmengine.model import BaseModel
|
||||||
|
from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope
|
||||||
|
from mmengine.runner import Runner
|
||||||
|
from mmengine.visualization import Visualizer
|
||||||
|
|
||||||
|
|
||||||
|
class ToyModel(BaseModel):
|
||||||
|
|
||||||
|
def __init__(self, data_preprocessor=None):
|
||||||
|
super().__init__(data_preprocessor=data_preprocessor)
|
||||||
|
self.linear1 = nn.Linear(2, 2)
|
||||||
|
self.linear2 = nn.Linear(2, 1)
|
||||||
|
|
||||||
|
def forward(self, inputs, data_samples, mode='tensor'):
|
||||||
|
if isinstance(inputs, list):
|
||||||
|
inputs = torch.stack(inputs)
|
||||||
|
if isinstance(data_samples, list):
|
||||||
|
data_sample = torch.stack(data_samples)
|
||||||
|
outputs = self.linear1(inputs)
|
||||||
|
outputs = self.linear2(outputs)
|
||||||
|
|
||||||
|
if mode == 'tensor':
|
||||||
|
return outputs
|
||||||
|
elif mode == 'loss':
|
||||||
|
loss = (data_sample - outputs).sum()
|
||||||
|
outputs = dict(loss=loss)
|
||||||
|
return outputs
|
||||||
|
elif mode == 'predict':
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class ToyDataset(Dataset):
|
||||||
|
METAINFO = dict() # type: ignore
|
||||||
|
data = torch.randn(12, 2)
|
||||||
|
label = torch.ones(12)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metainfo(self):
|
||||||
|
return self.METAINFO
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.data.size(0)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return dict(inputs=self.data[index], data_samples=self.label[index])
|
||||||
|
|
||||||
|
|
||||||
|
class ToyMetric(BaseMetric):
|
||||||
|
|
||||||
|
def __init__(self, collect_device='cpu', dummy_metrics=None):
|
||||||
|
super().__init__(collect_device=collect_device)
|
||||||
|
self.dummy_metrics = dummy_metrics
|
||||||
|
|
||||||
|
def process(self, data_batch, predictions):
|
||||||
|
result = {'acc': 1}
|
||||||
|
self.results.append(result)
|
||||||
|
|
||||||
|
def compute_metrics(self, results):
|
||||||
|
return dict(acc=1)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerTestCase(TestCase):
|
||||||
|
"""A test case to build runner easily.
|
||||||
|
|
||||||
|
`RunnerTestCase` will do the following things:
|
||||||
|
|
||||||
|
1. Registers a toy model, a toy metric, and a toy dataset, which can be
|
||||||
|
used to run the `Runner` successfully.
|
||||||
|
2. Provides epoch based and iteration based cfg to build runner.
|
||||||
|
3. Provides `build_runner` method to build runner easily.
|
||||||
|
4. Clean the global variable used by the runner.
|
||||||
|
"""
|
||||||
|
dist_cfg = dict(
|
||||||
|
MASTER_ADDR='127.0.0.1',
|
||||||
|
MASTER_PORT=29600,
|
||||||
|
RANK='0',
|
||||||
|
WORLD_SIZE='1',
|
||||||
|
LOCAL_RANK='0')
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
# Prevent from registering module with the same name by other unit
|
||||||
|
# test. These registries will be cleared in `tearDown`
|
||||||
|
MODELS.register_module(module=ToyModel, force=True)
|
||||||
|
METRICS.register_module(module=ToyMetric, force=True)
|
||||||
|
DATASETS.register_module(module=ToyDataset, force=True)
|
||||||
|
epoch_based_cfg = dict(
|
||||||
|
work_dir=self.temp_dir.name,
|
||||||
|
model=dict(type='ToyModel'),
|
||||||
|
train_dataloader=dict(
|
||||||
|
dataset=dict(type='ToyDataset'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0),
|
||||||
|
val_dataloader=dict(
|
||||||
|
dataset=dict(type='ToyDataset'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0),
|
||||||
|
val_evaluator=[dict(type='ToyMetric')],
|
||||||
|
test_dataloader=dict(
|
||||||
|
dataset=dict(type='ToyDataset'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0),
|
||||||
|
test_evaluator=[dict(type='ToyMetric')],
|
||||||
|
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
|
||||||
|
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
|
||||||
|
val_cfg=dict(),
|
||||||
|
test_cfg=dict(),
|
||||||
|
default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
|
||||||
|
custom_hooks=[],
|
||||||
|
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||||
|
experiment_name='test1')
|
||||||
|
self.epoch_based_cfg = Config(epoch_based_cfg)
|
||||||
|
|
||||||
|
# prepare iter based cfg.
|
||||||
|
self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
self.iter_based_cfg.train_dataloader = dict(
|
||||||
|
dataset=dict(type='ToyDataset'),
|
||||||
|
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0)
|
||||||
|
self.iter_based_cfg.log_processor = dict(by_epoch=False)
|
||||||
|
|
||||||
|
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
|
||||||
|
self.iter_based_cfg.default_hooks = dict(
|
||||||
|
logger=dict(type='LoggerHook', interval=1),
|
||||||
|
checkpoint=dict(
|
||||||
|
type='CheckpointHook', interval=12, by_epoch=False))
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||||
|
# delete the temporary directory
|
||||||
|
logging.shutdown()
|
||||||
|
MMLogger._instance_dict.clear()
|
||||||
|
Visualizer._instance_dict.clear()
|
||||||
|
DefaultScope._instance_dict.clear()
|
||||||
|
MessageHub._instance_dict.clear()
|
||||||
|
MODELS.module_dict.pop('ToyModel', None)
|
||||||
|
METRICS.module_dict.pop('ToyMetric', None)
|
||||||
|
DATASETS.module_dict.pop('ToyDataset', None)
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
if is_distributed():
|
||||||
|
destroy_process_group()
|
||||||
|
|
||||||
|
def build_runner(self, cfg: Config):
|
||||||
|
cfg.experiment_name = self.experiment_name
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
return runner
|
||||||
|
|
||||||
|
@property
|
||||||
|
def experiment_name(self):
|
||||||
|
# Since runners could be built too fast to have a unique experiment
|
||||||
|
# name(timestamp is the same), here we use uuid to make sure each
|
||||||
|
# runner has the unique experiment name.
|
||||||
|
return f'{self._testMethodName}_{time.time()} + ' \
|
||||||
|
f'{uuid4()}'
|
||||||
|
|
||||||
|
def setup_dist_env(self):
|
||||||
|
self.dist_cfg['MASTER_PORT'] += 1
|
||||||
|
os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT'])
|
||||||
|
os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR']
|
||||||
|
os.environ['RANK'] = self.dist_cfg['RANK']
|
||||||
|
os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
|
||||||
|
os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']
|
@ -79,7 +79,7 @@ def test_is_model_wrapper():
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
CHILD_REGISTRY.register_module(module=CustomModelWrapper)
|
CHILD_REGISTRY.register_module(module=CustomModelWrapper, force=True)
|
||||||
|
|
||||||
for wrapper in [
|
for wrapper in [
|
||||||
DistributedDataParallel, MMDistributedDataParallel,
|
DistributedDataParallel, MMDistributedDataParallel,
|
||||||
|
@ -37,7 +37,6 @@ from mmengine.utils.dl_utils import TORCH_VERSION
|
|||||||
from mmengine.visualization import Visualizer
|
from mmengine.visualization import Visualizer
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
|
||||||
class ToyModel(BaseModel):
|
class ToyModel(BaseModel):
|
||||||
|
|
||||||
def __init__(self, data_preprocessor=None):
|
def __init__(self, data_preprocessor=None):
|
||||||
@ -63,14 +62,12 @@ class ToyModel(BaseModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
|
||||||
class ToyModel1(ToyModel):
|
class ToyModel1(ToyModel):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
|
||||||
class ToySyncBNModel(BaseModel):
|
class ToySyncBNModel(BaseModel):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -95,7 +92,6 @@ class ToySyncBNModel(BaseModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
|
||||||
class ToyGANModel(BaseModel):
|
class ToyGANModel(BaseModel):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -127,7 +123,6 @@ class ToyGANModel(BaseModel):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
@MODEL_WRAPPERS.register_module()
|
|
||||||
class CustomModelWrapper(nn.Module):
|
class CustomModelWrapper(nn.Module):
|
||||||
|
|
||||||
def __init__(self, module):
|
def __init__(self, module):
|
||||||
@ -135,7 +130,6 @@ class CustomModelWrapper(nn.Module):
|
|||||||
self.model = module
|
self.model = module
|
||||||
|
|
||||||
|
|
||||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
|
|
||||||
class ToyMultipleOptimizerConstructor:
|
class ToyMultipleOptimizerConstructor:
|
||||||
|
|
||||||
def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
|
def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
|
||||||
@ -163,7 +157,6 @@ class ToyMultipleOptimizerConstructor:
|
|||||||
return OptimWrapperDict(**optimizers)
|
return OptimWrapperDict(**optimizers)
|
||||||
|
|
||||||
|
|
||||||
@DATASETS.register_module()
|
|
||||||
class ToyDataset(Dataset):
|
class ToyDataset(Dataset):
|
||||||
METAINFO = dict() # type: ignore
|
METAINFO = dict() # type: ignore
|
||||||
data = torch.randn(12, 2)
|
data = torch.randn(12, 2)
|
||||||
@ -180,7 +173,6 @@ class ToyDataset(Dataset):
|
|||||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||||
|
|
||||||
|
|
||||||
@DATASETS.register_module()
|
|
||||||
class ToyDatasetNoMeta(Dataset):
|
class ToyDatasetNoMeta(Dataset):
|
||||||
data = torch.randn(12, 2)
|
data = torch.randn(12, 2)
|
||||||
label = torch.ones(12)
|
label = torch.ones(12)
|
||||||
@ -192,7 +184,6 @@ class ToyDatasetNoMeta(Dataset):
|
|||||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||||
|
|
||||||
|
|
||||||
@METRICS.register_module()
|
|
||||||
class ToyMetric1(BaseMetric):
|
class ToyMetric1(BaseMetric):
|
||||||
|
|
||||||
def __init__(self, collect_device='cpu', dummy_metrics=None):
|
def __init__(self, collect_device='cpu', dummy_metrics=None):
|
||||||
@ -207,7 +198,6 @@ class ToyMetric1(BaseMetric):
|
|||||||
return dict(acc=1)
|
return dict(acc=1)
|
||||||
|
|
||||||
|
|
||||||
@METRICS.register_module()
|
|
||||||
class ToyMetric2(BaseMetric):
|
class ToyMetric2(BaseMetric):
|
||||||
|
|
||||||
def __init__(self, collect_device='cpu', dummy_metrics=None):
|
def __init__(self, collect_device='cpu', dummy_metrics=None):
|
||||||
@ -222,12 +212,10 @@ class ToyMetric2(BaseMetric):
|
|||||||
return dict(acc=1)
|
return dict(acc=1)
|
||||||
|
|
||||||
|
|
||||||
@OPTIM_WRAPPERS.register_module()
|
|
||||||
class ToyOptimWrapper(OptimWrapper):
|
class ToyOptimWrapper(OptimWrapper):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module()
|
|
||||||
class ToyHook(Hook):
|
class ToyHook(Hook):
|
||||||
priority = 'Lowest'
|
priority = 'Lowest'
|
||||||
|
|
||||||
@ -235,7 +223,6 @@ class ToyHook(Hook):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module()
|
|
||||||
class ToyHook2(Hook):
|
class ToyHook2(Hook):
|
||||||
priority = 'Lowest'
|
priority = 'Lowest'
|
||||||
|
|
||||||
@ -243,7 +230,6 @@ class ToyHook2(Hook):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
|
||||||
class CustomTrainLoop(BaseLoop):
|
class CustomTrainLoop(BaseLoop):
|
||||||
|
|
||||||
def __init__(self, runner, dataloader, max_epochs):
|
def __init__(self, runner, dataloader, max_epochs):
|
||||||
@ -254,7 +240,6 @@ class CustomTrainLoop(BaseLoop):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
|
||||||
class CustomValLoop(BaseLoop):
|
class CustomValLoop(BaseLoop):
|
||||||
|
|
||||||
def __init__(self, runner, dataloader, evaluator):
|
def __init__(self, runner, dataloader, evaluator):
|
||||||
@ -270,7 +255,6 @@ class CustomValLoop(BaseLoop):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
|
||||||
class CustomTestLoop(BaseLoop):
|
class CustomTestLoop(BaseLoop):
|
||||||
|
|
||||||
def __init__(self, runner, dataloader, evaluator):
|
def __init__(self, runner, dataloader, evaluator):
|
||||||
@ -286,7 +270,6 @@ class CustomTestLoop(BaseLoop):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@LOG_PROCESSORS.register_module()
|
|
||||||
class CustomLogProcessor(LogProcessor):
|
class CustomLogProcessor(LogProcessor):
|
||||||
|
|
||||||
def __init__(self, window_size=10, by_epoch=True, custom_cfg=None):
|
def __init__(self, window_size=10, by_epoch=True, custom_cfg=None):
|
||||||
@ -296,7 +279,6 @@ class CustomLogProcessor(LogProcessor):
|
|||||||
self._check_custom_cfg()
|
self._check_custom_cfg()
|
||||||
|
|
||||||
|
|
||||||
@RUNNERS.register_module()
|
|
||||||
class CustomRunner(Runner):
|
class CustomRunner(Runner):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -333,7 +315,6 @@ class CustomRunner(Runner):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@EVALUATOR.register_module()
|
|
||||||
class ToyEvaluator(Evaluator):
|
class ToyEvaluator(Evaluator):
|
||||||
|
|
||||||
def __init__(self, metrics):
|
def __init__(self, metrics):
|
||||||
@ -344,7 +325,6 @@ def collate_fn(data_batch):
|
|||||||
return pseudo_collate(data_batch)
|
return pseudo_collate(data_batch)
|
||||||
|
|
||||||
|
|
||||||
@COLLATE_FUNCTIONS.register_module()
|
|
||||||
def custom_collate(data_batch, pad_value):
|
def custom_collate(data_batch, pad_value):
|
||||||
return pseudo_collate(data_batch)
|
return pseudo_collate(data_batch)
|
||||||
|
|
||||||
@ -352,6 +332,28 @@ def custom_collate(data_batch, pad_value):
|
|||||||
class TestRunner(TestCase):
|
class TestRunner(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
MODELS.register_module(module=ToyModel, force=True)
|
||||||
|
MODELS.register_module(module=ToyModel1, force=True)
|
||||||
|
MODELS.register_module(module=ToySyncBNModel, force=True)
|
||||||
|
MODELS.register_module(module=ToyGANModel, force=True)
|
||||||
|
MODEL_WRAPPERS.register_module(module=CustomModelWrapper, force=True)
|
||||||
|
OPTIM_WRAPPER_CONSTRUCTORS.register_module(
|
||||||
|
module=ToyMultipleOptimizerConstructor, force=True)
|
||||||
|
DATASETS.register_module(module=ToyDataset, force=True)
|
||||||
|
DATASETS.register_module(module=ToyDatasetNoMeta, force=True)
|
||||||
|
METRICS.register_module(module=ToyMetric1, force=True)
|
||||||
|
METRICS.register_module(module=ToyMetric2, force=True)
|
||||||
|
OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True)
|
||||||
|
HOOKS.register_module(module=ToyHook, force=True)
|
||||||
|
HOOKS.register_module(module=ToyHook2, force=True)
|
||||||
|
LOOPS.register_module(module=CustomTrainLoop, force=True)
|
||||||
|
LOOPS.register_module(module=CustomValLoop, force=True)
|
||||||
|
LOOPS.register_module(module=CustomTestLoop, force=True)
|
||||||
|
LOG_PROCESSORS.register_module(module=CustomLogProcessor, force=True)
|
||||||
|
RUNNERS.register_module(module=CustomRunner, force=True)
|
||||||
|
EVALUATOR.register_module(module=ToyEvaluator, force=True)
|
||||||
|
COLLATE_FUNCTIONS.register_module(module=custom_collate, force=True)
|
||||||
|
|
||||||
self.temp_dir = tempfile.mkdtemp()
|
self.temp_dir = tempfile.mkdtemp()
|
||||||
epoch_based_cfg = dict(
|
epoch_based_cfg = dict(
|
||||||
model=dict(type='ToyModel'),
|
model=dict(type='ToyModel'),
|
||||||
@ -413,6 +415,28 @@ class TestRunner(TestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# `FileHandler` should be closed in Windows, otherwise we cannot
|
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||||
# delete the temporary directory
|
# delete the temporary directory
|
||||||
|
MODELS.module_dict.pop('ToyModel')
|
||||||
|
MODELS.module_dict.pop('ToyModel1')
|
||||||
|
MODELS.module_dict.pop('ToySyncBNModel')
|
||||||
|
MODELS.module_dict.pop('ToyGANModel')
|
||||||
|
MODEL_WRAPPERS.module_dict.pop('CustomModelWrapper')
|
||||||
|
OPTIM_WRAPPER_CONSTRUCTORS.module_dict.pop(
|
||||||
|
'ToyMultipleOptimizerConstructor')
|
||||||
|
OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper')
|
||||||
|
DATASETS.module_dict.pop('ToyDataset')
|
||||||
|
DATASETS.module_dict.pop('ToyDatasetNoMeta')
|
||||||
|
METRICS.module_dict.pop('ToyMetric1')
|
||||||
|
METRICS.module_dict.pop('ToyMetric2')
|
||||||
|
HOOKS.module_dict.pop('ToyHook')
|
||||||
|
HOOKS.module_dict.pop('ToyHook2')
|
||||||
|
LOOPS.module_dict.pop('CustomTrainLoop')
|
||||||
|
LOOPS.module_dict.pop('CustomValLoop')
|
||||||
|
LOOPS.module_dict.pop('CustomTestLoop')
|
||||||
|
LOG_PROCESSORS.module_dict.pop('CustomLogProcessor')
|
||||||
|
RUNNERS.module_dict.pop('CustomRunner')
|
||||||
|
EVALUATOR.module_dict.pop('ToyEvaluator')
|
||||||
|
COLLATE_FUNCTIONS.module_dict.pop('custom_collate')
|
||||||
|
|
||||||
logging.shutdown()
|
logging.shutdown()
|
||||||
MMLogger._instance_dict.clear()
|
MMLogger._instance_dict.clear()
|
||||||
shutil.rmtree(self.temp_dir)
|
shutil.rmtree(self.temp_dir)
|
||||||
@ -782,7 +806,7 @@ class TestRunner(TestCase):
|
|||||||
TOY_SCHEDULERS = Registry(
|
TOY_SCHEDULERS = Registry(
|
||||||
'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy')
|
'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy')
|
||||||
|
|
||||||
@TOY_SCHEDULERS.register_module()
|
@TOY_SCHEDULERS.register_module(force=True)
|
||||||
class ToyScheduler(MultiStepLR):
|
class ToyScheduler(MultiStepLR):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@ -863,7 +887,7 @@ class TestRunner(TestCase):
|
|||||||
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
|
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
|
||||||
runner.from_cfg(cfg)
|
runner.from_cfg(cfg)
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module(force=True)
|
||||||
class ToyBN(BaseModel):
|
class ToyBN(BaseModel):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -1349,7 +1373,7 @@ class TestRunner(TestCase):
|
|||||||
val_epoch_results = []
|
val_epoch_results = []
|
||||||
val_epoch_targets = [i for i in range(2, 4)]
|
val_epoch_targets = [i for i in range(2, 4)]
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module(force=True)
|
||||||
class TestEpochHook(Hook):
|
class TestEpochHook(Hook):
|
||||||
|
|
||||||
def before_train_epoch(self, runner):
|
def before_train_epoch(self, runner):
|
||||||
@ -1394,7 +1418,7 @@ class TestRunner(TestCase):
|
|||||||
val_iter_targets = [i for i in range(4, 12)]
|
val_iter_targets = [i for i in range(4, 12)]
|
||||||
val_batch_idx_targets = [i for i in range(4)] * 2
|
val_batch_idx_targets = [i for i in range(4)] * 2
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module(force=True)
|
||||||
class TestIterHook(Hook):
|
class TestIterHook(Hook):
|
||||||
|
|
||||||
def before_train_epoch(self, runner):
|
def before_train_epoch(self, runner):
|
||||||
@ -1487,7 +1511,7 @@ class TestRunner(TestCase):
|
|||||||
val_interval_results = []
|
val_interval_results = []
|
||||||
val_interval_targets = [5] * 10 + [2] * 2
|
val_interval_targets = [5] * 10 + [2] * 2
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module(force=True)
|
||||||
class TestIterDynamicIntervalHook(Hook):
|
class TestIterDynamicIntervalHook(Hook):
|
||||||
|
|
||||||
def before_val(self, runner):
|
def before_val(self, runner):
|
||||||
@ -1524,7 +1548,7 @@ class TestRunner(TestCase):
|
|||||||
val_interval_results = []
|
val_interval_results = []
|
||||||
val_interval_targets = [5] * 10 + [2] * 2
|
val_interval_targets = [5] * 10 + [2] * 2
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module(force=True)
|
||||||
class TestEpochDynamicIntervalHook(Hook):
|
class TestEpochDynamicIntervalHook(Hook):
|
||||||
|
|
||||||
def before_val_epoch(self, runner):
|
def before_val_epoch(self, runner):
|
||||||
@ -1553,7 +1577,7 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
# 7. test init weights
|
# 7. test init weights
|
||||||
@MODELS.register_module()
|
@MODELS.register_module(force=True)
|
||||||
class ToyModel2(ToyModel):
|
class ToyModel2(ToyModel):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -1654,7 +1678,7 @@ class TestRunner(TestCase):
|
|||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
# 12.1 Test train with model, which does not inherit from BaseModel
|
# 12.1 Test train with model, which does not inherit from BaseModel
|
||||||
@MODELS.register_module()
|
@MODELS.register_module(force=True)
|
||||||
class ToyModel3(nn.Module):
|
class ToyModel3(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -1901,7 +1925,7 @@ class TestRunner(TestCase):
|
|||||||
|
|
||||||
def test_custom_loop(self):
|
def test_custom_loop(self):
|
||||||
# test custom loop with additional hook
|
# test custom loop with additional hook
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module(force=True)
|
||||||
class CustomTrainLoop2(IterBasedTrainLoop):
|
class CustomTrainLoop2(IterBasedTrainLoop):
|
||||||
"""Custom train loop with additional warmup stage."""
|
"""Custom train loop with additional warmup stage."""
|
||||||
|
|
||||||
@ -1942,7 +1966,7 @@ class TestRunner(TestCase):
|
|||||||
before_warmup_iter_results = []
|
before_warmup_iter_results = []
|
||||||
after_warmup_iter_results = []
|
after_warmup_iter_results = []
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module(force=True)
|
||||||
class TestWarmupHook(Hook):
|
class TestWarmupHook(Hook):
|
||||||
"""test custom train loop."""
|
"""test custom train loop."""
|
||||||
|
|
||||||
|
58
tests/test_testing/test_runner_test_case.py
Normal file
58
tests/test_testing/test_runner_test_case.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
|
||||||
|
from mmengine import Config
|
||||||
|
from mmengine.logging import MessageHub, MMLogger
|
||||||
|
from mmengine.registry import DefaultScope
|
||||||
|
from mmengine.testing import RunnerTestCase
|
||||||
|
from mmengine.visualization import Visualizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunnerTestCase(RunnerTestCase):
|
||||||
|
|
||||||
|
def test_setup(self):
|
||||||
|
self.assertIsInstance(self.epoch_based_cfg, Config)
|
||||||
|
self.assertIsInstance(self.iter_based_cfg, Config)
|
||||||
|
self.assertIn('MASTER_ADDR', self.dist_cfg)
|
||||||
|
self.assertIn('MASTER_PORT', self.dist_cfg)
|
||||||
|
self.assertIn('RANK', self.dist_cfg)
|
||||||
|
self.assertIn('WORLD_SIZE', self.dist_cfg)
|
||||||
|
self.assertIn('LOCAL_RANK', self.dist_cfg)
|
||||||
|
|
||||||
|
def test_tearDown(self):
|
||||||
|
self.tearDown()
|
||||||
|
self.assertEqual(MMLogger._instance_dict, {})
|
||||||
|
self.assertEqual(MessageHub._instance_dict, {})
|
||||||
|
self.assertEqual(Visualizer._instance_dict, {})
|
||||||
|
self.assertEqual(DefaultScope._instance_dict, {})
|
||||||
|
# tearDown should not be called twice.
|
||||||
|
self.tearDown = super(RunnerTestCase, self).tearDown
|
||||||
|
|
||||||
|
def test_build_runner(self):
|
||||||
|
runner = self.build_runner(self.epoch_based_cfg)
|
||||||
|
runner.train()
|
||||||
|
runner.val()
|
||||||
|
runner.test()
|
||||||
|
|
||||||
|
runner = self.build_runner(self.iter_based_cfg)
|
||||||
|
runner.train()
|
||||||
|
runner.val()
|
||||||
|
runner.test()
|
||||||
|
|
||||||
|
def test_experiment_name(self):
|
||||||
|
runner1 = self.build_runner(self.epoch_based_cfg)
|
||||||
|
runner2 = self.build_runner(self.epoch_based_cfg)
|
||||||
|
self.assertNotEqual(runner1.experiment_name, runner2.experiment_name)
|
||||||
|
|
||||||
|
def test_init_dist(self):
|
||||||
|
self.setup_dist_env()
|
||||||
|
self.assertEqual(
|
||||||
|
str(self.dist_cfg['MASTER_PORT']), os.environ['MASTER_PORT'])
|
||||||
|
self.assertEqual(self.dist_cfg['MASTER_ADDR'],
|
||||||
|
os.environ['MASTER_ADDR'])
|
||||||
|
self.assertEqual(self.dist_cfg['RANK'], os.environ['RANK'])
|
||||||
|
self.assertEqual(self.dist_cfg['LOCAL_RANK'], os.environ['LOCAL_RANK'])
|
||||||
|
self.assertEqual(self.dist_cfg['WORLD_SIZE'], os.environ['WORLD_SIZE'])
|
||||||
|
fisrt_port = os.environ['MASTER_ADDR']
|
||||||
|
self.setup_dist_env()
|
||||||
|
self.assertNotEqual(fisrt_port, os.environ['MASTER_PORT'])
|
Loading…
x
Reference in New Issue
Block a user