[Enhance] Enhance docstring and error cathing in MessageHub (#1098)

This commit is contained in:
Mashiro 2023-04-23 17:16:52 +08:00 committed by GitHub
parent fafb476e58
commit 1db55358fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,14 +27,14 @@ class MessageHub(ManagerMixin):
Args:
name (str): Name of message hub used to get corresponding instance
globally.
log_scalars (OrderedDict, optional): Each key-value pair in the
log_scalars (dict, optional): Each key-value pair in the
dictionary is the name of the log information such as "loss", "lr",
"metric" and their corresponding values. The type of value must be
HistoryBuffer. Defaults to None.
runtime_info (OrderedDict, optional): Each key-value pair in the
runtime_info (dict, optional): Each key-value pair in the
dictionary is the name of the runtime information and their
corresponding values. Defaults to None.
resumed_keys (OrderedDict, optional): Each key-value pair in the
resumed_keys (dict, optional): Each key-value pair in the
dictionary decides whether the key in :attr:`_log_scalars` and
:attr:`_runtime_info` will be serialized.
@ -45,9 +45,9 @@ class MessageHub(ManagerMixin):
Examples:
>>> # create empty `MessageHub`.
>>> message_hub1 = MessageHub()
>>> log_scalars = OrderedDict(loss=HistoryBuffer())
>>> runtime_info = OrderedDict(task='task')
>>> message_hub1 = MessageHub('name')
>>> log_scalars = dict(loss=HistoryBuffer())
>>> runtime_info = dict(task='task')
>>> resumed_keys = dict(loss=True)
>>> # create `MessageHub` from data.
>>> message_hub2 = MessageHub(
@ -59,20 +59,13 @@ class MessageHub(ManagerMixin):
def __init__(self,
name: str,
log_scalars: Optional[OrderedDict] = None,
runtime_info: Optional[OrderedDict] = None,
resumed_keys: Optional[OrderedDict] = None):
log_scalars: Optional[dict] = None,
runtime_info: Optional[dict] = None,
resumed_keys: Optional[dict] = None):
super().__init__(name)
self._log_scalars = log_scalars if log_scalars is not None else \
OrderedDict()
self._runtime_info = runtime_info if runtime_info is not None else \
OrderedDict()
self._resumed_keys = resumed_keys if resumed_keys is not None else \
OrderedDict()
assert isinstance(self._log_scalars, OrderedDict)
assert isinstance(self._runtime_info, OrderedDict)
assert isinstance(self._resumed_keys, OrderedDict)
self._log_scalars = self._parse_input('log_scalars', log_scalars)
self._runtime_info = self._parse_input('runtime_info', runtime_info)
self._resumed_keys = self._parse_input('resumed_keys', resumed_keys)
for value in self._log_scalars.values():
assert isinstance(value, HistoryBuffer), \
@ -113,7 +106,7 @@ class MessageHub(ManagerMixin):
constructor of ``HistoryBuffer``.
Examples:
>>> message_hub = MessageHub
>>> message_hub = MessageHub(name='name')
>>> # create loss `HistoryBuffer` with value=1, count=1
>>> message_hub.update_scalar('loss', 1)
>>> # update loss `HistoryBuffer` with value
@ -197,7 +190,7 @@ class MessageHub(ManagerMixin):
``key``.
Examples:
>>> message_hub = MessageHub()
>>> message_hub = MessageHub(name='name')
>>> message_hub.update_info('iter', 100)
Args:
@ -220,7 +213,7 @@ class MessageHub(ManagerMixin):
``info_dict``.
Examples:
>>> message_hub = MessageHub()
>>> message_hub = MessageHub(name='name')
>>> message_hub.update_info({'iter': 100})
Args:
@ -273,7 +266,6 @@ class MessageHub(ManagerMixin):
Returns:
OrderedDict: A copy of all runtime information.
"""
# return copy.deepcopy(self._runtime_info)
return self._runtime_info
def get_scalar(self, key: str) -> HistoryBuffer:
@ -442,3 +434,21 @@ class MessageHub(ManagerMixin):
self._log_scalars = copy.deepcopy(state_dict._log_scalars)
self._runtime_info = copy.deepcopy(state_dict._runtime_info)
self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)
def _parse_input(self, name: str, value: Any) -> OrderedDict:
"""Parse input value.
Args:
name (str): name of input value.
value (Any): Input value.
Returns:
dict: Parsed input value.
"""
if value is None:
return OrderedDict()
elif isinstance(value, dict):
return OrderedDict(value)
else:
raise TypeError(f'{name} should be a dict or `None`, but '
f'got {type(name)}')