[Enhancement] MessageHub.get_info() supports returning a default value (#991)

This commit is contained in:
cyberslack_lee 2023-04-23 17:35:35 +08:00 committed by GitHub
parent 1db55358fc
commit 0687b377b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 17 deletions

View File

@ -288,22 +288,24 @@ class MessageHub(ManagerMixin):
f'instance name is: {MessageHub.instance_name}')
return self.log_scalars[key]
def get_info(self, key: str) -> Any:
"""Get runtime information by key.
def get_info(self, key: str, default: Optional[Any] = None) -> Any:
"""Get runtime information by key. if the key does not exist, this
method will return default information.
Args:
key (str): Key of runtime information.
default (Any, optional): The default returned value for the
given key.
Returns:
Any: A copy of corresponding runtime information if the key exists.
"""
if key not in self.runtime_info:
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')
# TODO There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]
return default
else:
# TODO There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]
def _get_valid_value(
self,

View File

@ -46,9 +46,6 @@ class TestRuntimeInfoHook(RunnerTestCase):
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
runner.message_hub.get_info('dataset_meta')
cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
runner = self.build_runner(cfg)
hook.before_train(runner)

View File

@ -82,8 +82,7 @@ class TestMessageHub:
def test_get_runtime(self):
message_hub = MessageHub.get_instance('mmengine')
with pytest.raises(KeyError):
message_hub.get_info('unknown')
assert message_hub.get_info('unknown') is None
recorded_dict = dict(a=1, b=2)
message_hub.update_info('test_value', recorded_dict)
assert message_hub.get_info('test_value') == recorded_dict
@ -186,10 +185,8 @@ class TestMessageHub:
obj = pickle.dumps(message_hub)
instance = pickle.loads(obj)
with pytest.raises(KeyError):
instance.get_info('feat')
with pytest.raises(KeyError):
instance.get_info('lr')
assert instance.get_info('feat') is None
assert instance.get_info('lr') is None
instance.get_info('iter')
instance.get_scalar('loss')