mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] MessageHub.get_info() supports returning a default value (#991)
This commit is contained in:
parent
1db55358fc
commit
0687b377b2
@ -288,22 +288,24 @@ class MessageHub(ManagerMixin):
|
||||
f'instance name is: {MessageHub.instance_name}')
|
||||
return self.log_scalars[key]
|
||||
|
||||
def get_info(self, key: str) -> Any:
|
||||
"""Get runtime information by key.
|
||||
def get_info(self, key: str, default: Optional[Any] = None) -> Any:
|
||||
"""Get runtime information by key. if the key does not exist, this
|
||||
method will return default information.
|
||||
|
||||
Args:
|
||||
key (str): Key of runtime information.
|
||||
default (Any, optional): The default returned value for the
|
||||
given key.
|
||||
|
||||
Returns:
|
||||
Any: A copy of corresponding runtime information if the key exists.
|
||||
"""
|
||||
if key not in self.runtime_info:
|
||||
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
|
||||
f'instance name is: {MessageHub.instance_name}')
|
||||
|
||||
# TODO: There are restrictions on objects that can be saved
|
||||
# return copy.deepcopy(self._runtime_info[key])
|
||||
return self._runtime_info[key]
|
||||
return default
|
||||
else:
|
||||
# TODO: There are restrictions on objects that can be saved
|
||||
# return copy.deepcopy(self._runtime_info[key])
|
||||
return self._runtime_info[key]
|
||||
|
||||
def _get_valid_value(
|
||||
self,
|
||||
|
@ -46,9 +46,6 @@ class TestRuntimeInfoHook(RunnerTestCase):
|
||||
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
|
||||
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
|
||||
|
||||
with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
|
||||
runner.message_hub.get_info('dataset_meta')
|
||||
|
||||
cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
|
||||
runner = self.build_runner(cfg)
|
||||
hook.before_train(runner)
|
||||
|
@ -82,8 +82,7 @@ class TestMessageHub:
|
||||
|
||||
def test_get_runtime(self):
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
with pytest.raises(KeyError):
|
||||
message_hub.get_info('unknown')
|
||||
assert message_hub.get_info('unknown') is None
|
||||
recorded_dict = dict(a=1, b=2)
|
||||
message_hub.update_info('test_value', recorded_dict)
|
||||
assert message_hub.get_info('test_value') == recorded_dict
|
||||
@ -186,10 +185,8 @@ class TestMessageHub:
|
||||
obj = pickle.dumps(message_hub)
|
||||
instance = pickle.loads(obj)
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
instance.get_info('feat')
|
||||
with pytest.raises(KeyError):
|
||||
instance.get_info('lr')
|
||||
assert instance.get_info('feat') is None
|
||||
assert instance.get_info('lr') is None
|
||||
|
||||
instance.get_info('iter')
|
||||
instance.get_scalar('loss')
|
||||
|
Loading…
x
Reference in New Issue
Block a user