From 0687b377b27d6c824d4e14b38be164b6bd2d14ec Mon Sep 17 00:00:00 2001 From: cyberslack_lee Date: Sun, 23 Apr 2023 17:35:35 +0800 Subject: [PATCH] [Enhancement] MessageHub.get_info() supports returning a default value (#991) --- mmengine/logging/message_hub.py | 18 ++++++++++-------- tests/test_hooks/test_runtime_info_hook.py | 3 --- tests/test_logging/test_message_hub.py | 9 +++------ 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index c538ecc7..363a3e1f 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -288,22 +288,24 @@ class MessageHub(ManagerMixin): f'instance name is: {MessageHub.instance_name}') return self.log_scalars[key] - def get_info(self, key: str) -> Any: - """Get runtime information by key. + def get_info(self, key: str, default: Optional[Any] = None) -> Any: + """Get runtime information by key. if the key does not exist, this + method will return default information. Args: key (str): Key of runtime information. + default (Any, optional): The default returned value for the + given key. Returns: Any: A copy of corresponding runtime information if the key exists. """ if key not in self.runtime_info: - raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' - f'instance name is: {MessageHub.instance_name}') - - # TODO: There are restrictions on objects that can be saved - # return copy.deepcopy(self._runtime_info[key]) - return self._runtime_info[key] + return default + else: + # TODO: There are restrictions on objects that can be saved + # return copy.deepcopy(self._runtime_info[key]) + return self._runtime_info[key] def _get_valid_value( self, diff --git a/tests/test_hooks/test_runtime_info_hook.py b/tests/test_hooks/test_runtime_info_hook.py index 7593f845..6e3adb8c 100644 --- a/tests/test_hooks/test_runtime_info_hook.py +++ b/tests/test_hooks/test_runtime_info_hook.py @@ -46,9 +46,6 @@ class TestRuntimeInfoHook(RunnerTestCase): self.assertEqual(runner.message_hub.get_info('max_epochs'), 2) self.assertEqual(runner.message_hub.get_info('max_iters'), 8) - with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'): - runner.message_hub.get_info('dataset_meta') - cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo' runner = self.build_runner(cfg) hook.before_train(runner) diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index 4dffdec0..a78c7728 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -82,8 +82,7 @@ class TestMessageHub: def test_get_runtime(self): message_hub = MessageHub.get_instance('mmengine') - with pytest.raises(KeyError): - message_hub.get_info('unknown') + assert message_hub.get_info('unknown') is None recorded_dict = dict(a=1, b=2) message_hub.update_info('test_value', recorded_dict) assert message_hub.get_info('test_value') == recorded_dict @@ -186,10 +185,8 @@ class TestMessageHub: obj = pickle.dumps(message_hub) instance = pickle.loads(obj) - with pytest.raises(KeyError): - instance.get_info('feat') - with pytest.raises(KeyError): - instance.get_info('lr') + assert instance.get_info('feat') is None + assert instance.get_info('lr') is None instance.get_info('iter') instance.get_scalar('loss')