[Fix]: fix error and add unit test (#429)
parent
f5cb45dc33
commit
a706bbc018
|
@ -34,7 +34,7 @@ class RuntimeInfoHook(Hook):
|
|||
runner.message_hub.update_info('iter', runner.iter)
|
||||
runner.message_hub.update_info('max_epochs', runner.max_epochs)
|
||||
runner.message_hub.update_info('max_iters', runner.max_iters)
|
||||
if hasattr(runner.train_dataloader.dataset, 'dataset_meta'):
|
||||
if hasattr(runner.train_dataloader.dataset, 'metainfo'):
|
||||
runner.message_hub.update_info(
|
||||
'dataset_meta', runner.train_dataloader.dataset.metainfo)
|
||||
|
||||
|
|
|
@ -15,18 +15,32 @@ class TestRuntimeInfoHook(TestCase):
|
|||
def test_before_train(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
'runtime_info_hook_test_before_train')
|
||||
|
||||
class ToyDataset:
|
||||
...
|
||||
|
||||
runner = Mock()
|
||||
runner.epoch = 7
|
||||
runner.iter = 71
|
||||
runner.max_epochs = 4
|
||||
runner.max_iters = 40
|
||||
runner.message_hub = message_hub
|
||||
runner.train_dataloader.dataset = ToyDataset()
|
||||
hook = RuntimeInfoHook()
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(message_hub.get_info('epoch'), 7)
|
||||
self.assertEqual(message_hub.get_info('iter'), 71)
|
||||
self.assertEqual(message_hub.get_info('max_epochs'), 4)
|
||||
self.assertEqual(message_hub.get_info('max_iters'), 40)
|
||||
with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
|
||||
message_hub.get_info('dataset_meta')
|
||||
|
||||
class ToyDatasetWithMeta:
|
||||
metainfo = dict()
|
||||
|
||||
runner.train_dataloader.dataset = ToyDatasetWithMeta()
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(message_hub.get_info('dataset_meta'), dict())
|
||||
|
||||
def test_before_train_epoch(self):
|
||||
message_hub = MessageHub.get_instance(
|
||||
|
|
Loading…
Reference in New Issue