mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Enhance docstring and error cathing in MessageHub (#1098)
This commit is contained in:
parent
fafb476e58
commit
1db55358fc
@ -27,14 +27,14 @@ class MessageHub(ManagerMixin):
|
||||
Args:
|
||||
name (str): Name of message hub used to get corresponding instance
|
||||
globally.
|
||||
log_scalars (OrderedDict, optional): Each key-value pair in the
|
||||
log_scalars (dict, optional): Each key-value pair in the
|
||||
dictionary is the name of the log information such as "loss", "lr",
|
||||
"metric" and their corresponding values. The type of value must be
|
||||
HistoryBuffer. Defaults to None.
|
||||
runtime_info (OrderedDict, optional): Each key-value pair in the
|
||||
runtime_info (dict, optional): Each key-value pair in the
|
||||
dictionary is the name of the runtime information and their
|
||||
corresponding values. Defaults to None.
|
||||
resumed_keys (OrderedDict, optional): Each key-value pair in the
|
||||
resumed_keys (dict, optional): Each key-value pair in the
|
||||
dictionary decides whether the key in :attr:`_log_scalars` and
|
||||
:attr:`_runtime_info` will be serialized.
|
||||
|
||||
@ -45,9 +45,9 @@ class MessageHub(ManagerMixin):
|
||||
|
||||
Examples:
|
||||
>>> # create empty `MessageHub`.
|
||||
>>> message_hub1 = MessageHub()
|
||||
>>> log_scalars = OrderedDict(loss=HistoryBuffer())
|
||||
>>> runtime_info = OrderedDict(task='task')
|
||||
>>> message_hub1 = MessageHub('name')
|
||||
>>> log_scalars = dict(loss=HistoryBuffer())
|
||||
>>> runtime_info = dict(task='task')
|
||||
>>> resumed_keys = dict(loss=True)
|
||||
>>> # create `MessageHub` from data.
|
||||
>>> message_hub2 = MessageHub(
|
||||
@ -59,20 +59,13 @@ class MessageHub(ManagerMixin):
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
log_scalars: Optional[OrderedDict] = None,
|
||||
runtime_info: Optional[OrderedDict] = None,
|
||||
resumed_keys: Optional[OrderedDict] = None):
|
||||
log_scalars: Optional[dict] = None,
|
||||
runtime_info: Optional[dict] = None,
|
||||
resumed_keys: Optional[dict] = None):
|
||||
super().__init__(name)
|
||||
self._log_scalars = log_scalars if log_scalars is not None else \
|
||||
OrderedDict()
|
||||
self._runtime_info = runtime_info if runtime_info is not None else \
|
||||
OrderedDict()
|
||||
self._resumed_keys = resumed_keys if resumed_keys is not None else \
|
||||
OrderedDict()
|
||||
|
||||
assert isinstance(self._log_scalars, OrderedDict)
|
||||
assert isinstance(self._runtime_info, OrderedDict)
|
||||
assert isinstance(self._resumed_keys, OrderedDict)
|
||||
self._log_scalars = self._parse_input('log_scalars', log_scalars)
|
||||
self._runtime_info = self._parse_input('runtime_info', runtime_info)
|
||||
self._resumed_keys = self._parse_input('resumed_keys', resumed_keys)
|
||||
|
||||
for value in self._log_scalars.values():
|
||||
assert isinstance(value, HistoryBuffer), \
|
||||
@ -113,7 +106,7 @@ class MessageHub(ManagerMixin):
|
||||
constructor of ``HistoryBuffer``.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub
|
||||
>>> message_hub = MessageHub(name='name')
|
||||
>>> # create loss `HistoryBuffer` with value=1, count=1
|
||||
>>> message_hub.update_scalar('loss', 1)
|
||||
>>> # update loss `HistoryBuffer` with value
|
||||
@ -197,7 +190,7 @@ class MessageHub(ManagerMixin):
|
||||
``key``.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub()
|
||||
>>> message_hub = MessageHub(name='name')
|
||||
>>> message_hub.update_info('iter', 100)
|
||||
|
||||
Args:
|
||||
@ -220,7 +213,7 @@ class MessageHub(ManagerMixin):
|
||||
``info_dict``.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub()
|
||||
>>> message_hub = MessageHub(name='name')
|
||||
>>> message_hub.update_info({'iter': 100})
|
||||
|
||||
Args:
|
||||
@ -273,7 +266,6 @@ class MessageHub(ManagerMixin):
|
||||
Returns:
|
||||
OrderedDict: A copy of all runtime information.
|
||||
"""
|
||||
# return copy.deepcopy(self._runtime_info)
|
||||
return self._runtime_info
|
||||
|
||||
def get_scalar(self, key: str) -> HistoryBuffer:
|
||||
@ -442,3 +434,21 @@ class MessageHub(ManagerMixin):
|
||||
self._log_scalars = copy.deepcopy(state_dict._log_scalars)
|
||||
self._runtime_info = copy.deepcopy(state_dict._runtime_info)
|
||||
self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)
|
||||
|
||||
def _parse_input(self, name: str, value: Any) -> OrderedDict:
|
||||
"""Parse input value.
|
||||
|
||||
Args:
|
||||
name (str): name of input value.
|
||||
value (Any): Input value.
|
||||
|
||||
Returns:
|
||||
dict: Parsed input value.
|
||||
"""
|
||||
if value is None:
|
||||
return OrderedDict()
|
||||
elif isinstance(value, dict):
|
||||
return OrderedDict(value)
|
||||
else:
|
||||
raise TypeError(f'{name} should be a dict or `None`, but '
|
||||
f'got {type(name)}')
|
||||
|
Loading…
x
Reference in New Issue
Block a user