[Enhance] enhance runner test case (#631)
* Add runner test cast * Fix unit test * fix unit test * pop None if key does not exist * Fix is_model_wrapper and force register class in test_runner * [Fix] Fix is_model_wrapper * destroy group after ut * register module in testcase * fix as comment * minor refine Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/756/head
parent
b7aa4dd885
commit
c478bdca27
|
@ -3,9 +3,10 @@ from .compare import (assert_allclose, assert_attrs_equal,
|
|||
assert_dict_contains_subset, assert_dict_has_keys,
|
||||
assert_is_norm_layer, assert_keys_equal,
|
||||
assert_params_all_zeros, check_python_script)
|
||||
from .runner_test_case import RunnerTestCase
|
||||
|
||||
__all__ = [
|
||||
'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal',
|
||||
'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer',
|
||||
'assert_params_all_zeros', 'check_python_script'
|
||||
'assert_params_all_zeros', 'check_python_script', 'RunnerTestCase'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,186 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from unittest import TestCase
|
||||
from uuid import uuid4
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed import destroy_process_group
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import mmengine.hooks # noqa F401
|
||||
import mmengine.optim # noqa F401
|
||||
from mmengine.config import Config
|
||||
from mmengine.dist import is_distributed
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MessageHub, MMLogger
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
|
||||
class ToyModel(BaseModel):
|
||||
|
||||
def __init__(self, data_preprocessor=None):
|
||||
super().__init__(data_preprocessor=data_preprocessor)
|
||||
self.linear1 = nn.Linear(2, 2)
|
||||
self.linear2 = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, inputs, data_samples, mode='tensor'):
|
||||
if isinstance(inputs, list):
|
||||
inputs = torch.stack(inputs)
|
||||
if isinstance(data_samples, list):
|
||||
data_sample = torch.stack(data_samples)
|
||||
outputs = self.linear1(inputs)
|
||||
outputs = self.linear2(outputs)
|
||||
|
||||
if mode == 'tensor':
|
||||
return outputs
|
||||
elif mode == 'loss':
|
||||
loss = (data_sample - outputs).sum()
|
||||
outputs = dict(loss=loss)
|
||||
return outputs
|
||||
elif mode == 'predict':
|
||||
return outputs
|
||||
|
||||
|
||||
class ToyDataset(Dataset):
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
||||
@property
|
||||
def metainfo(self):
|
||||
return self.METAINFO
|
||||
|
||||
def __len__(self):
|
||||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return dict(inputs=self.data[index], data_samples=self.label[index])
|
||||
|
||||
|
||||
class ToyMetric(BaseMetric):
|
||||
|
||||
def __init__(self, collect_device='cpu', dummy_metrics=None):
|
||||
super().__init__(collect_device=collect_device)
|
||||
self.dummy_metrics = dummy_metrics
|
||||
|
||||
def process(self, data_batch, predictions):
|
||||
result = {'acc': 1}
|
||||
self.results.append(result)
|
||||
|
||||
def compute_metrics(self, results):
|
||||
return dict(acc=1)
|
||||
|
||||
|
||||
class RunnerTestCase(TestCase):
|
||||
"""A test case to build runner easily.
|
||||
|
||||
`RunnerTestCase` will do the following things:
|
||||
|
||||
1. Registers a toy model, a toy metric, and a toy dataset, which can be
|
||||
used to run the `Runner` successfully.
|
||||
2. Provides epoch based and iteration based cfg to build runner.
|
||||
3. Provides `build_runner` method to build runner easily.
|
||||
4. Clean the global variable used by the runner.
|
||||
"""
|
||||
dist_cfg = dict(
|
||||
MASTER_ADDR='127.0.0.1',
|
||||
MASTER_PORT=29600,
|
||||
RANK='0',
|
||||
WORLD_SIZE='1',
|
||||
LOCAL_RANK='0')
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
# Prevent from registering module with the same name by other unit
|
||||
# test. These registries will be cleared in `tearDown`
|
||||
MODELS.register_module(module=ToyModel, force=True)
|
||||
METRICS.register_module(module=ToyMetric, force=True)
|
||||
DATASETS.register_module(module=ToyDataset, force=True)
|
||||
epoch_based_cfg = dict(
|
||||
work_dir=self.temp_dir.name,
|
||||
model=dict(type='ToyModel'),
|
||||
train_dataloader=dict(
|
||||
dataset=dict(type='ToyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_dataloader=dict(
|
||||
dataset=dict(type='ToyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_evaluator=[dict(type='ToyMetric')],
|
||||
test_dataloader=dict(
|
||||
dataset=dict(type='ToyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
test_evaluator=[dict(type='ToyMetric')],
|
||||
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
|
||||
val_cfg=dict(),
|
||||
test_cfg=dict(),
|
||||
default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
|
||||
custom_hooks=[],
|
||||
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||
experiment_name='test1')
|
||||
self.epoch_based_cfg = Config(epoch_based_cfg)
|
||||
|
||||
# prepare iter based cfg.
|
||||
self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg)
|
||||
self.iter_based_cfg.train_dataloader = dict(
|
||||
dataset=dict(type='ToyDataset'),
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0)
|
||||
self.iter_based_cfg.log_processor = dict(by_epoch=False)
|
||||
|
||||
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
|
||||
self.iter_based_cfg.default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', interval=1),
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook', interval=12, by_epoch=False))
|
||||
|
||||
def tearDown(self):
|
||||
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||
# delete the temporary directory
|
||||
logging.shutdown()
|
||||
MMLogger._instance_dict.clear()
|
||||
Visualizer._instance_dict.clear()
|
||||
DefaultScope._instance_dict.clear()
|
||||
MessageHub._instance_dict.clear()
|
||||
MODELS.module_dict.pop('ToyModel', None)
|
||||
METRICS.module_dict.pop('ToyMetric', None)
|
||||
DATASETS.module_dict.pop('ToyDataset', None)
|
||||
self.temp_dir.cleanup()
|
||||
if is_distributed():
|
||||
destroy_process_group()
|
||||
|
||||
def build_runner(self, cfg: Config):
|
||||
cfg.experiment_name = self.experiment_name
|
||||
runner = Runner.from_cfg(cfg)
|
||||
return runner
|
||||
|
||||
@property
|
||||
def experiment_name(self):
|
||||
# Since runners could be built too fast to have a unique experiment
|
||||
# name(timestamp is the same), here we use uuid to make sure each
|
||||
# runner has the unique experiment name.
|
||||
return f'{self._testMethodName}_{time.time()} + ' \
|
||||
f'{uuid4()}'
|
||||
|
||||
def setup_dist_env(self):
|
||||
self.dist_cfg['MASTER_PORT'] += 1
|
||||
os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT'])
|
||||
os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR']
|
||||
os.environ['RANK'] = self.dist_cfg['RANK']
|
||||
os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
|
||||
os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']
|
|
@ -79,7 +79,7 @@ def test_is_model_wrapper():
|
|||
|
||||
pass
|
||||
|
||||
CHILD_REGISTRY.register_module(module=CustomModelWrapper)
|
||||
CHILD_REGISTRY.register_module(module=CustomModelWrapper, force=True)
|
||||
|
||||
for wrapper in [
|
||||
DistributedDataParallel, MMDistributedDataParallel,
|
||||
|
|
|
@ -37,7 +37,6 @@ from mmengine.utils.dl_utils import TORCH_VERSION
|
|||
from mmengine.visualization import Visualizer
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ToyModel(BaseModel):
|
||||
|
||||
def __init__(self, data_preprocessor=None):
|
||||
|
@ -63,14 +62,12 @@ class ToyModel(BaseModel):
|
|||
return outputs
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ToyModel1(ToyModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ToySyncBNModel(BaseModel):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -95,7 +92,6 @@ class ToySyncBNModel(BaseModel):
|
|||
return outputs
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ToyGANModel(BaseModel):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -127,7 +123,6 @@ class ToyGANModel(BaseModel):
|
|||
return loss
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
class CustomModelWrapper(nn.Module):
|
||||
|
||||
def __init__(self, module):
|
||||
|
@ -135,7 +130,6 @@ class CustomModelWrapper(nn.Module):
|
|||
self.model = module
|
||||
|
||||
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
|
||||
class ToyMultipleOptimizerConstructor:
|
||||
|
||||
def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
|
||||
|
@ -163,7 +157,6 @@ class ToyMultipleOptimizerConstructor:
|
|||
return OptimWrapperDict(**optimizers)
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ToyDataset(Dataset):
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
|
@ -180,7 +173,6 @@ class ToyDataset(Dataset):
|
|||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ToyDatasetNoMeta(Dataset):
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
@ -192,7 +184,6 @@ class ToyDatasetNoMeta(Dataset):
|
|||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class ToyMetric1(BaseMetric):
|
||||
|
||||
def __init__(self, collect_device='cpu', dummy_metrics=None):
|
||||
|
@ -207,7 +198,6 @@ class ToyMetric1(BaseMetric):
|
|||
return dict(acc=1)
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class ToyMetric2(BaseMetric):
|
||||
|
||||
def __init__(self, collect_device='cpu', dummy_metrics=None):
|
||||
|
@ -222,12 +212,10 @@ class ToyMetric2(BaseMetric):
|
|||
return dict(acc=1)
|
||||
|
||||
|
||||
@OPTIM_WRAPPERS.register_module()
|
||||
class ToyOptimWrapper(OptimWrapper):
|
||||
...
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ToyHook(Hook):
|
||||
priority = 'Lowest'
|
||||
|
||||
|
@ -235,7 +223,6 @@ class ToyHook(Hook):
|
|||
pass
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ToyHook2(Hook):
|
||||
priority = 'Lowest'
|
||||
|
||||
|
@ -243,7 +230,6 @@ class ToyHook2(Hook):
|
|||
pass
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
class CustomTrainLoop(BaseLoop):
|
||||
|
||||
def __init__(self, runner, dataloader, max_epochs):
|
||||
|
@ -254,7 +240,6 @@ class CustomTrainLoop(BaseLoop):
|
|||
pass
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
class CustomValLoop(BaseLoop):
|
||||
|
||||
def __init__(self, runner, dataloader, evaluator):
|
||||
|
@ -270,7 +255,6 @@ class CustomValLoop(BaseLoop):
|
|||
pass
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
class CustomTestLoop(BaseLoop):
|
||||
|
||||
def __init__(self, runner, dataloader, evaluator):
|
||||
|
@ -286,7 +270,6 @@ class CustomTestLoop(BaseLoop):
|
|||
pass
|
||||
|
||||
|
||||
@LOG_PROCESSORS.register_module()
|
||||
class CustomLogProcessor(LogProcessor):
|
||||
|
||||
def __init__(self, window_size=10, by_epoch=True, custom_cfg=None):
|
||||
|
@ -296,7 +279,6 @@ class CustomLogProcessor(LogProcessor):
|
|||
self._check_custom_cfg()
|
||||
|
||||
|
||||
@RUNNERS.register_module()
|
||||
class CustomRunner(Runner):
|
||||
|
||||
def __init__(self,
|
||||
|
@ -333,7 +315,6 @@ class CustomRunner(Runner):
|
|||
pass
|
||||
|
||||
|
||||
@EVALUATOR.register_module()
|
||||
class ToyEvaluator(Evaluator):
|
||||
|
||||
def __init__(self, metrics):
|
||||
|
@ -344,7 +325,6 @@ def collate_fn(data_batch):
|
|||
return pseudo_collate(data_batch)
|
||||
|
||||
|
||||
@COLLATE_FUNCTIONS.register_module()
|
||||
def custom_collate(data_batch, pad_value):
|
||||
return pseudo_collate(data_batch)
|
||||
|
||||
|
@ -352,6 +332,28 @@ def custom_collate(data_batch, pad_value):
|
|||
class TestRunner(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
MODELS.register_module(module=ToyModel, force=True)
|
||||
MODELS.register_module(module=ToyModel1, force=True)
|
||||
MODELS.register_module(module=ToySyncBNModel, force=True)
|
||||
MODELS.register_module(module=ToyGANModel, force=True)
|
||||
MODEL_WRAPPERS.register_module(module=CustomModelWrapper, force=True)
|
||||
OPTIM_WRAPPER_CONSTRUCTORS.register_module(
|
||||
module=ToyMultipleOptimizerConstructor, force=True)
|
||||
DATASETS.register_module(module=ToyDataset, force=True)
|
||||
DATASETS.register_module(module=ToyDatasetNoMeta, force=True)
|
||||
METRICS.register_module(module=ToyMetric1, force=True)
|
||||
METRICS.register_module(module=ToyMetric2, force=True)
|
||||
OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True)
|
||||
HOOKS.register_module(module=ToyHook, force=True)
|
||||
HOOKS.register_module(module=ToyHook2, force=True)
|
||||
LOOPS.register_module(module=CustomTrainLoop, force=True)
|
||||
LOOPS.register_module(module=CustomValLoop, force=True)
|
||||
LOOPS.register_module(module=CustomTestLoop, force=True)
|
||||
LOG_PROCESSORS.register_module(module=CustomLogProcessor, force=True)
|
||||
RUNNERS.register_module(module=CustomRunner, force=True)
|
||||
EVALUATOR.register_module(module=ToyEvaluator, force=True)
|
||||
COLLATE_FUNCTIONS.register_module(module=custom_collate, force=True)
|
||||
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
epoch_based_cfg = dict(
|
||||
model=dict(type='ToyModel'),
|
||||
|
@ -413,6 +415,28 @@ class TestRunner(TestCase):
|
|||
def tearDown(self):
|
||||
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||
# delete the temporary directory
|
||||
MODELS.module_dict.pop('ToyModel')
|
||||
MODELS.module_dict.pop('ToyModel1')
|
||||
MODELS.module_dict.pop('ToySyncBNModel')
|
||||
MODELS.module_dict.pop('ToyGANModel')
|
||||
MODEL_WRAPPERS.module_dict.pop('CustomModelWrapper')
|
||||
OPTIM_WRAPPER_CONSTRUCTORS.module_dict.pop(
|
||||
'ToyMultipleOptimizerConstructor')
|
||||
OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper')
|
||||
DATASETS.module_dict.pop('ToyDataset')
|
||||
DATASETS.module_dict.pop('ToyDatasetNoMeta')
|
||||
METRICS.module_dict.pop('ToyMetric1')
|
||||
METRICS.module_dict.pop('ToyMetric2')
|
||||
HOOKS.module_dict.pop('ToyHook')
|
||||
HOOKS.module_dict.pop('ToyHook2')
|
||||
LOOPS.module_dict.pop('CustomTrainLoop')
|
||||
LOOPS.module_dict.pop('CustomValLoop')
|
||||
LOOPS.module_dict.pop('CustomTestLoop')
|
||||
LOG_PROCESSORS.module_dict.pop('CustomLogProcessor')
|
||||
RUNNERS.module_dict.pop('CustomRunner')
|
||||
EVALUATOR.module_dict.pop('ToyEvaluator')
|
||||
COLLATE_FUNCTIONS.module_dict.pop('custom_collate')
|
||||
|
||||
logging.shutdown()
|
||||
MMLogger._instance_dict.clear()
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
@ -782,7 +806,7 @@ class TestRunner(TestCase):
|
|||
TOY_SCHEDULERS = Registry(
|
||||
'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy')
|
||||
|
||||
@TOY_SCHEDULERS.register_module()
|
||||
@TOY_SCHEDULERS.register_module(force=True)
|
||||
class ToyScheduler(MultiStepLR):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -863,7 +887,7 @@ class TestRunner(TestCase):
|
|||
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
|
||||
runner.from_cfg(cfg)
|
||||
|
||||
@MODELS.register_module()
|
||||
@MODELS.register_module(force=True)
|
||||
class ToyBN(BaseModel):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1349,7 +1373,7 @@ class TestRunner(TestCase):
|
|||
val_epoch_results = []
|
||||
val_epoch_targets = [i for i in range(2, 4)]
|
||||
|
||||
@HOOKS.register_module()
|
||||
@HOOKS.register_module(force=True)
|
||||
class TestEpochHook(Hook):
|
||||
|
||||
def before_train_epoch(self, runner):
|
||||
|
@ -1394,7 +1418,7 @@ class TestRunner(TestCase):
|
|||
val_iter_targets = [i for i in range(4, 12)]
|
||||
val_batch_idx_targets = [i for i in range(4)] * 2
|
||||
|
||||
@HOOKS.register_module()
|
||||
@HOOKS.register_module(force=True)
|
||||
class TestIterHook(Hook):
|
||||
|
||||
def before_train_epoch(self, runner):
|
||||
|
@ -1487,7 +1511,7 @@ class TestRunner(TestCase):
|
|||
val_interval_results = []
|
||||
val_interval_targets = [5] * 10 + [2] * 2
|
||||
|
||||
@HOOKS.register_module()
|
||||
@HOOKS.register_module(force=True)
|
||||
class TestIterDynamicIntervalHook(Hook):
|
||||
|
||||
def before_val(self, runner):
|
||||
|
@ -1524,7 +1548,7 @@ class TestRunner(TestCase):
|
|||
val_interval_results = []
|
||||
val_interval_targets = [5] * 10 + [2] * 2
|
||||
|
||||
@HOOKS.register_module()
|
||||
@HOOKS.register_module(force=True)
|
||||
class TestEpochDynamicIntervalHook(Hook):
|
||||
|
||||
def before_val_epoch(self, runner):
|
||||
|
@ -1553,7 +1577,7 @@ class TestRunner(TestCase):
|
|||
self.assertEqual(result, target)
|
||||
|
||||
# 7. test init weights
|
||||
@MODELS.register_module()
|
||||
@MODELS.register_module(force=True)
|
||||
class ToyModel2(ToyModel):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1654,7 +1678,7 @@ class TestRunner(TestCase):
|
|||
runner.train()
|
||||
|
||||
# 12.1 Test train with model, which does not inherit from BaseModel
|
||||
@MODELS.register_module()
|
||||
@MODELS.register_module(force=True)
|
||||
class ToyModel3(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1901,7 +1925,7 @@ class TestRunner(TestCase):
|
|||
|
||||
def test_custom_loop(self):
|
||||
# test custom loop with additional hook
|
||||
@LOOPS.register_module()
|
||||
@LOOPS.register_module(force=True)
|
||||
class CustomTrainLoop2(IterBasedTrainLoop):
|
||||
"""Custom train loop with additional warmup stage."""
|
||||
|
||||
|
@ -1942,7 +1966,7 @@ class TestRunner(TestCase):
|
|||
before_warmup_iter_results = []
|
||||
after_warmup_iter_results = []
|
||||
|
||||
@HOOKS.register_module()
|
||||
@HOOKS.register_module(force=True)
|
||||
class TestWarmupHook(Hook):
|
||||
"""test custom train loop."""
|
||||
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
|
||||
from mmengine import Config
|
||||
from mmengine.logging import MessageHub, MMLogger
|
||||
from mmengine.registry import DefaultScope
|
||||
from mmengine.testing import RunnerTestCase
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
|
||||
class TestRunnerTestCase(RunnerTestCase):
|
||||
|
||||
def test_setup(self):
|
||||
self.assertIsInstance(self.epoch_based_cfg, Config)
|
||||
self.assertIsInstance(self.iter_based_cfg, Config)
|
||||
self.assertIn('MASTER_ADDR', self.dist_cfg)
|
||||
self.assertIn('MASTER_PORT', self.dist_cfg)
|
||||
self.assertIn('RANK', self.dist_cfg)
|
||||
self.assertIn('WORLD_SIZE', self.dist_cfg)
|
||||
self.assertIn('LOCAL_RANK', self.dist_cfg)
|
||||
|
||||
def test_tearDown(self):
|
||||
self.tearDown()
|
||||
self.assertEqual(MMLogger._instance_dict, {})
|
||||
self.assertEqual(MessageHub._instance_dict, {})
|
||||
self.assertEqual(Visualizer._instance_dict, {})
|
||||
self.assertEqual(DefaultScope._instance_dict, {})
|
||||
# tearDown should not be called twice.
|
||||
self.tearDown = super(RunnerTestCase, self).tearDown
|
||||
|
||||
def test_build_runner(self):
|
||||
runner = self.build_runner(self.epoch_based_cfg)
|
||||
runner.train()
|
||||
runner.val()
|
||||
runner.test()
|
||||
|
||||
runner = self.build_runner(self.iter_based_cfg)
|
||||
runner.train()
|
||||
runner.val()
|
||||
runner.test()
|
||||
|
||||
def test_experiment_name(self):
|
||||
runner1 = self.build_runner(self.epoch_based_cfg)
|
||||
runner2 = self.build_runner(self.epoch_based_cfg)
|
||||
self.assertNotEqual(runner1.experiment_name, runner2.experiment_name)
|
||||
|
||||
def test_init_dist(self):
|
||||
self.setup_dist_env()
|
||||
self.assertEqual(
|
||||
str(self.dist_cfg['MASTER_PORT']), os.environ['MASTER_PORT'])
|
||||
self.assertEqual(self.dist_cfg['MASTER_ADDR'],
|
||||
os.environ['MASTER_ADDR'])
|
||||
self.assertEqual(self.dist_cfg['RANK'], os.environ['RANK'])
|
||||
self.assertEqual(self.dist_cfg['LOCAL_RANK'], os.environ['LOCAL_RANK'])
|
||||
self.assertEqual(self.dist_cfg['WORLD_SIZE'], os.environ['WORLD_SIZE'])
|
||||
fisrt_port = os.environ['MASTER_ADDR']
|
||||
self.setup_dist_env()
|
||||
self.assertNotEqual(fisrt_port, os.environ['MASTER_PORT'])
|
Loading…
Reference in New Issue