diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py index 2ea6cc61..3bf7fe46 100644 --- a/mmengine/hooks/runtime_info_hook.py +++ b/mmengine/hooks/runtime_info_hook.py @@ -34,7 +34,7 @@ class RuntimeInfoHook(Hook): runner.message_hub.update_info('iter', runner.iter) runner.message_hub.update_info('max_epochs', runner.max_epochs) runner.message_hub.update_info('max_iters', runner.max_iters) - if hasattr(runner.train_dataloader.dataset, 'dataset_meta'): + if hasattr(runner.train_dataloader.dataset, 'metainfo'): runner.message_hub.update_info( 'dataset_meta', runner.train_dataloader.dataset.metainfo) diff --git a/tests/test_hooks/test_runtime_info_hook.py b/tests/test_hooks/test_runtime_info_hook.py index 56935547..8b47782d 100644 --- a/tests/test_hooks/test_runtime_info_hook.py +++ b/tests/test_hooks/test_runtime_info_hook.py @@ -15,18 +15,32 @@ class TestRuntimeInfoHook(TestCase): def test_before_train(self): message_hub = MessageHub.get_instance( 'runtime_info_hook_test_before_train') + + class ToyDataset: + ... + runner = Mock() runner.epoch = 7 runner.iter = 71 runner.max_epochs = 4 runner.max_iters = 40 runner.message_hub = message_hub + runner.train_dataloader.dataset = ToyDataset() hook = RuntimeInfoHook() hook.before_train(runner) self.assertEqual(message_hub.get_info('epoch'), 7) self.assertEqual(message_hub.get_info('iter'), 71) self.assertEqual(message_hub.get_info('max_epochs'), 4) self.assertEqual(message_hub.get_info('max_iters'), 40) + with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'): + message_hub.get_info('dataset_meta') + + class ToyDatasetWithMeta: + metainfo = dict() + + runner.train_dataloader.dataset = ToyDatasetWithMeta() + hook.before_train(runner) + self.assertEqual(message_hub.get_info('dataset_meta'), dict()) def test_before_train_epoch(self): message_hub = MessageHub.get_instance(