diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 492af34e..c538ecc7 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -27,14 +27,14 @@ class MessageHub(ManagerMixin): Args: name (str): Name of message hub used to get corresponding instance globally. - log_scalars (OrderedDict, optional): Each key-value pair in the + log_scalars (dict, optional): Each key-value pair in the dictionary is the name of the log information such as "loss", "lr", "metric" and their corresponding values. The type of value must be HistoryBuffer. Defaults to None. - runtime_info (OrderedDict, optional): Each key-value pair in the + runtime_info (dict, optional): Each key-value pair in the dictionary is the name of the runtime information and their corresponding values. Defaults to None. - resumed_keys (OrderedDict, optional): Each key-value pair in the + resumed_keys (dict, optional): Each key-value pair in the dictionary decides whether the key in :attr:`_log_scalars` and :attr:`_runtime_info` will be serialized. @@ -45,9 +45,9 @@ class MessageHub(ManagerMixin): Examples: >>> # create empty `MessageHub`. - >>> message_hub1 = MessageHub() - >>> log_scalars = OrderedDict(loss=HistoryBuffer()) - >>> runtime_info = OrderedDict(task='task') + >>> message_hub1 = MessageHub('name') + >>> log_scalars = dict(loss=HistoryBuffer()) + >>> runtime_info = dict(task='task') >>> resumed_keys = dict(loss=True) >>> # create `MessageHub` from data. >>> message_hub2 = MessageHub( @@ -59,20 +59,13 @@ class MessageHub(ManagerMixin): def __init__(self, name: str, - log_scalars: Optional[OrderedDict] = None, - runtime_info: Optional[OrderedDict] = None, - resumed_keys: Optional[OrderedDict] = None): + log_scalars: Optional[dict] = None, + runtime_info: Optional[dict] = None, + resumed_keys: Optional[dict] = None): super().__init__(name) - self._log_scalars = log_scalars if log_scalars is not None else \ - OrderedDict() - self._runtime_info = runtime_info if runtime_info is not None else \ - OrderedDict() - self._resumed_keys = resumed_keys if resumed_keys is not None else \ - OrderedDict() - - assert isinstance(self._log_scalars, OrderedDict) - assert isinstance(self._runtime_info, OrderedDict) - assert isinstance(self._resumed_keys, OrderedDict) + self._log_scalars = self._parse_input('log_scalars', log_scalars) + self._runtime_info = self._parse_input('runtime_info', runtime_info) + self._resumed_keys = self._parse_input('resumed_keys', resumed_keys) for value in self._log_scalars.values(): assert isinstance(value, HistoryBuffer), \ @@ -113,7 +106,7 @@ class MessageHub(ManagerMixin): constructor of ``HistoryBuffer``. Examples: - >>> message_hub = MessageHub + >>> message_hub = MessageHub(name='name') >>> # create loss `HistoryBuffer` with value=1, count=1 >>> message_hub.update_scalar('loss', 1) >>> # update loss `HistoryBuffer` with value @@ -197,7 +190,7 @@ class MessageHub(ManagerMixin): ``key``. Examples: - >>> message_hub = MessageHub() + >>> message_hub = MessageHub(name='name') >>> message_hub.update_info('iter', 100) Args: @@ -220,7 +213,7 @@ class MessageHub(ManagerMixin): ``info_dict``. Examples: - >>> message_hub = MessageHub() + >>> message_hub = MessageHub(name='name') >>> message_hub.update_info({'iter': 100}) Args: @@ -273,7 +266,6 @@ class MessageHub(ManagerMixin): Returns: OrderedDict: A copy of all runtime information. """ - # return copy.deepcopy(self._runtime_info) return self._runtime_info def get_scalar(self, key: str) -> HistoryBuffer: @@ -442,3 +434,21 @@ class MessageHub(ManagerMixin): self._log_scalars = copy.deepcopy(state_dict._log_scalars) self._runtime_info = copy.deepcopy(state_dict._runtime_info) self._resumed_keys = copy.deepcopy(state_dict._resumed_keys) + + def _parse_input(self, name: str, value: Any) -> OrderedDict: + """Parse input value. + + Args: + name (str): name of input value. + value (Any): Input value. + + Returns: + dict: Parsed input value. + """ + if value is None: + return OrderedDict() + elif isinstance(value, dict): + return OrderedDict(value) + else: + raise TypeError(f'{name} should be a dict or `None`, but ' + f'got {type(name)}')