[Fix] fix build train_loop during test (#295)
* fix build train_loop during test * fix build train_loop during test * fix build train_loop during test * fix build train_loop during test * Fix as commentpull/302/head
parent
819e10c24c
commit
8b0c9c5f6f
|
@ -18,17 +18,12 @@ class RuntimeInfoHook(Hook):
|
||||||
|
|
||||||
priority = 'VERY_HIGH'
|
priority = 'VERY_HIGH'
|
||||||
|
|
||||||
def before_run(self, runner) -> None:
|
|
||||||
"""Initialize runtime information."""
|
|
||||||
runner.message_hub.update_info('epoch', runner.epoch)
|
|
||||||
runner.message_hub.update_info('iter', runner.iter)
|
|
||||||
runner.message_hub.update_info('max_epochs', runner.max_epochs)
|
|
||||||
runner.message_hub.update_info('max_iters', runner.max_iters)
|
|
||||||
|
|
||||||
def before_train(self, runner) -> None:
|
def before_train(self, runner) -> None:
|
||||||
"""Update resumed training state."""
|
"""Update resumed training state."""
|
||||||
runner.message_hub.update_info('epoch', runner.epoch)
|
runner.message_hub.update_info('epoch', runner.epoch)
|
||||||
runner.message_hub.update_info('iter', runner.iter)
|
runner.message_hub.update_info('iter', runner.iter)
|
||||||
|
runner.message_hub.update_info('max_epochs', runner.max_epochs)
|
||||||
|
runner.message_hub.update_info('max_iters', runner.max_iters)
|
||||||
|
|
||||||
def before_train_epoch(self, runner) -> None:
|
def before_train_epoch(self, runner) -> None:
|
||||||
"""Update current epoch information before every epoch."""
|
"""Update current epoch information before every epoch."""
|
||||||
|
|
|
@ -12,33 +12,21 @@ from mmengine.optim import OptimWrapper, OptimWrapperDict
|
||||||
|
|
||||||
class TestRuntimeInfoHook(TestCase):
|
class TestRuntimeInfoHook(TestCase):
|
||||||
|
|
||||||
def test_before_run(self):
|
|
||||||
message_hub = MessageHub.get_instance(
|
|
||||||
'runtime_info_hook_test_before_run')
|
|
||||||
runner = Mock()
|
|
||||||
runner.epoch = 3
|
|
||||||
runner.iter = 30
|
|
||||||
runner.max_epochs = 4
|
|
||||||
runner.max_iters = 40
|
|
||||||
runner.message_hub = message_hub
|
|
||||||
hook = RuntimeInfoHook()
|
|
||||||
hook.before_run(runner)
|
|
||||||
self.assertEqual(message_hub.get_info('epoch'), 3)
|
|
||||||
self.assertEqual(message_hub.get_info('iter'), 30)
|
|
||||||
self.assertEqual(message_hub.get_info('max_epochs'), 4)
|
|
||||||
self.assertEqual(message_hub.get_info('max_iters'), 40)
|
|
||||||
|
|
||||||
def test_before_train(self):
|
def test_before_train(self):
|
||||||
message_hub = MessageHub.get_instance(
|
message_hub = MessageHub.get_instance(
|
||||||
'runtime_info_hook_test_before_train')
|
'runtime_info_hook_test_before_train')
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.epoch = 7
|
runner.epoch = 7
|
||||||
runner.iter = 71
|
runner.iter = 71
|
||||||
|
runner.max_epochs = 4
|
||||||
|
runner.max_iters = 40
|
||||||
runner.message_hub = message_hub
|
runner.message_hub = message_hub
|
||||||
hook = RuntimeInfoHook()
|
hook = RuntimeInfoHook()
|
||||||
hook.before_train(runner)
|
hook.before_train(runner)
|
||||||
self.assertEqual(message_hub.get_info('epoch'), 7)
|
self.assertEqual(message_hub.get_info('epoch'), 7)
|
||||||
self.assertEqual(message_hub.get_info('iter'), 71)
|
self.assertEqual(message_hub.get_info('iter'), 71)
|
||||||
|
self.assertEqual(message_hub.get_info('max_epochs'), 4)
|
||||||
|
self.assertEqual(message_hub.get_info('max_iters'), 40)
|
||||||
|
|
||||||
def test_before_train_epoch(self):
|
def test_before_train_epoch(self):
|
||||||
message_hub = MessageHub.get_instance(
|
message_hub = MessageHub.get_instance(
|
||||||
|
|
|
@ -1241,6 +1241,8 @@ class TestRunner(TestCase):
|
||||||
cfg.experiment_name = 'test_test2'
|
cfg.experiment_name = 'test_test2'
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.test()
|
runner.test()
|
||||||
|
# Test run test without building train loop.
|
||||||
|
self.assertIsInstance(runner._train_loop, dict)
|
||||||
|
|
||||||
# test run test without train and test components
|
# test run test without train and test components
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
|
Loading…
Reference in New Issue