mmengine/tests/test_testing/test_runner_test_case.py
Mashiro c478bdca27
[Enhance] enhance runner test case (#631)
* Add runner test case

* Fix unit test

* fix unit test

* pop None if key does not exist

* Fix is_model_wrapper and force register class in test_runner

* [Fix] Fix is_model_wrapper

* destroy group after ut

* register module in testcase

* fix as comment

* minor refine

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix lint

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2022-11-21 11:54:05 +08:00

59 lines
2.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os
from mmengine import Config
from mmengine.logging import MessageHub, MMLogger
from mmengine.registry import DefaultScope
from mmengine.testing import RunnerTestCase
from mmengine.visualization import Visualizer
class TestRunnerTestCase(RunnerTestCase):
    """Unit tests for the ``RunnerTestCase`` helper base class.

    Each test checks one facility the base class provides: the configs
    prepared in ``setUp``, the global-state cleanup done by ``tearDown``,
    ``build_runner`` and ``setup_dist_env``.
    """

    def test_setup(self):
        """``setUp`` should prepare both runner configs and a dist config."""
        self.assertIsInstance(self.epoch_based_cfg, Config)
        self.assertIsInstance(self.iter_based_cfg, Config)
        # ``dist_cfg`` must hold every key that ``test_init_dist`` expects
        # ``setup_dist_env`` to export to the environment.
        for key in ('MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE',
                    'LOCAL_RANK'):
            self.assertIn(key, self.dist_cfg)

    def test_tearDown(self):
        """``tearDown`` should empty the global instance dictionaries.

        Checked for ``MMLogger``, ``MessageHub``, ``Visualizer`` and
        ``DefaultScope``.
        """
        self.tearDown()
        self.assertEqual(MMLogger._instance_dict, {})
        self.assertEqual(MessageHub._instance_dict, {})
        self.assertEqual(Visualizer._instance_dict, {})
        self.assertEqual(DefaultScope._instance_dict, {})
        # tearDown has already run above. unittest will call it again after
        # this test, so rebind it to the next tearDown in the MRO (skipping
        # RunnerTestCase.tearDown) to avoid running the cleanup twice.
        self.tearDown = super(RunnerTestCase, self).tearDown

    def test_build_runner(self):
        """Runners built from both configs can train, validate and test."""
        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
            runner = self.build_runner(cfg)
            runner.train()
            runner.val()
            runner.test()

    def test_experiment_name(self):
        """Two runners from the same config get distinct experiment names."""
        runner1 = self.build_runner(self.epoch_based_cfg)
        runner2 = self.build_runner(self.epoch_based_cfg)
        self.assertNotEqual(runner1.experiment_name, runner2.experiment_name)

    def test_init_dist(self):
        """``setup_dist_env`` should export ``dist_cfg`` to ``os.environ``.

        A second call is expected to change ``MASTER_PORT``.
        """
        self.setup_dist_env()
        # Environment variables are always strings, while ``dist_cfg`` may
        # store the port as a non-string value, hence the ``str()``.
        self.assertEqual(
            str(self.dist_cfg['MASTER_PORT']), os.environ['MASTER_PORT'])
        self.assertEqual(self.dist_cfg['MASTER_ADDR'],
                         os.environ['MASTER_ADDR'])
        self.assertEqual(self.dist_cfg['RANK'], os.environ['RANK'])
        self.assertEqual(self.dist_cfg['LOCAL_RANK'], os.environ['LOCAL_RANK'])
        self.assertEqual(self.dist_cfg['WORLD_SIZE'], os.environ['WORLD_SIZE'])
        # Snapshot the *port* (not the address) before the second call.
        # Comparing an address with a port would make the assertion below
        # pass vacuously.
        first_port = os.environ['MASTER_PORT']
        self.setup_dist_env()
        self.assertNotEqual(first_port, os.environ['MASTER_PORT'])