mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Enhance docstring and error cathing in MessageHub (#1098)
This commit is contained in:
parent
fafb476e58
commit
1db55358fc
@ -27,14 +27,14 @@ class MessageHub(ManagerMixin):
|
|||||||
Args:
|
Args:
|
||||||
name (str): Name of message hub used to get corresponding instance
|
name (str): Name of message hub used to get corresponding instance
|
||||||
globally.
|
globally.
|
||||||
log_scalars (OrderedDict, optional): Each key-value pair in the
|
log_scalars (dict, optional): Each key-value pair in the
|
||||||
dictionary is the name of the log information such as "loss", "lr",
|
dictionary is the name of the log information such as "loss", "lr",
|
||||||
"metric" and their corresponding values. The type of value must be
|
"metric" and their corresponding values. The type of value must be
|
||||||
HistoryBuffer. Defaults to None.
|
HistoryBuffer. Defaults to None.
|
||||||
runtime_info (OrderedDict, optional): Each key-value pair in the
|
runtime_info (dict, optional): Each key-value pair in the
|
||||||
dictionary is the name of the runtime information and their
|
dictionary is the name of the runtime information and their
|
||||||
corresponding values. Defaults to None.
|
corresponding values. Defaults to None.
|
||||||
resumed_keys (OrderedDict, optional): Each key-value pair in the
|
resumed_keys (dict, optional): Each key-value pair in the
|
||||||
dictionary decides whether the key in :attr:`_log_scalars` and
|
dictionary decides whether the key in :attr:`_log_scalars` and
|
||||||
:attr:`_runtime_info` will be serialized.
|
:attr:`_runtime_info` will be serialized.
|
||||||
|
|
||||||
@ -45,9 +45,9 @@ class MessageHub(ManagerMixin):
|
|||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # create empty `MessageHub`.
|
>>> # create empty `MessageHub`.
|
||||||
>>> message_hub1 = MessageHub()
|
>>> message_hub1 = MessageHub('name')
|
||||||
>>> log_scalars = OrderedDict(loss=HistoryBuffer())
|
>>> log_scalars = dict(loss=HistoryBuffer())
|
||||||
>>> runtime_info = OrderedDict(task='task')
|
>>> runtime_info = dict(task='task')
|
||||||
>>> resumed_keys = dict(loss=True)
|
>>> resumed_keys = dict(loss=True)
|
||||||
>>> # create `MessageHub` from data.
|
>>> # create `MessageHub` from data.
|
||||||
>>> message_hub2 = MessageHub(
|
>>> message_hub2 = MessageHub(
|
||||||
@ -59,20 +59,13 @@ class MessageHub(ManagerMixin):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
name: str,
|
name: str,
|
||||||
log_scalars: Optional[OrderedDict] = None,
|
log_scalars: Optional[dict] = None,
|
||||||
runtime_info: Optional[OrderedDict] = None,
|
runtime_info: Optional[dict] = None,
|
||||||
resumed_keys: Optional[OrderedDict] = None):
|
resumed_keys: Optional[dict] = None):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self._log_scalars = log_scalars if log_scalars is not None else \
|
self._log_scalars = self._parse_input('log_scalars', log_scalars)
|
||||||
OrderedDict()
|
self._runtime_info = self._parse_input('runtime_info', runtime_info)
|
||||||
self._runtime_info = runtime_info if runtime_info is not None else \
|
self._resumed_keys = self._parse_input('resumed_keys', resumed_keys)
|
||||||
OrderedDict()
|
|
||||||
self._resumed_keys = resumed_keys if resumed_keys is not None else \
|
|
||||||
OrderedDict()
|
|
||||||
|
|
||||||
assert isinstance(self._log_scalars, OrderedDict)
|
|
||||||
assert isinstance(self._runtime_info, OrderedDict)
|
|
||||||
assert isinstance(self._resumed_keys, OrderedDict)
|
|
||||||
|
|
||||||
for value in self._log_scalars.values():
|
for value in self._log_scalars.values():
|
||||||
assert isinstance(value, HistoryBuffer), \
|
assert isinstance(value, HistoryBuffer), \
|
||||||
@ -113,7 +106,7 @@ class MessageHub(ManagerMixin):
|
|||||||
constructor of ``HistoryBuffer``.
|
constructor of ``HistoryBuffer``.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> message_hub = MessageHub
|
>>> message_hub = MessageHub(name='name')
|
||||||
>>> # create loss `HistoryBuffer` with value=1, count=1
|
>>> # create loss `HistoryBuffer` with value=1, count=1
|
||||||
>>> message_hub.update_scalar('loss', 1)
|
>>> message_hub.update_scalar('loss', 1)
|
||||||
>>> # update loss `HistoryBuffer` with value
|
>>> # update loss `HistoryBuffer` with value
|
||||||
@ -197,7 +190,7 @@ class MessageHub(ManagerMixin):
|
|||||||
``key``.
|
``key``.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> message_hub = MessageHub()
|
>>> message_hub = MessageHub(name='name')
|
||||||
>>> message_hub.update_info('iter', 100)
|
>>> message_hub.update_info('iter', 100)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -220,7 +213,7 @@ class MessageHub(ManagerMixin):
|
|||||||
``info_dict``.
|
``info_dict``.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> message_hub = MessageHub()
|
>>> message_hub = MessageHub(name='name')
|
||||||
>>> message_hub.update_info({'iter': 100})
|
>>> message_hub.update_info({'iter': 100})
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -273,7 +266,6 @@ class MessageHub(ManagerMixin):
|
|||||||
Returns:
|
Returns:
|
||||||
OrderedDict: A copy of all runtime information.
|
OrderedDict: A copy of all runtime information.
|
||||||
"""
|
"""
|
||||||
# return copy.deepcopy(self._runtime_info)
|
|
||||||
return self._runtime_info
|
return self._runtime_info
|
||||||
|
|
||||||
def get_scalar(self, key: str) -> HistoryBuffer:
|
def get_scalar(self, key: str) -> HistoryBuffer:
|
||||||
@ -442,3 +434,21 @@ class MessageHub(ManagerMixin):
|
|||||||
self._log_scalars = copy.deepcopy(state_dict._log_scalars)
|
self._log_scalars = copy.deepcopy(state_dict._log_scalars)
|
||||||
self._runtime_info = copy.deepcopy(state_dict._runtime_info)
|
self._runtime_info = copy.deepcopy(state_dict._runtime_info)
|
||||||
self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)
|
self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)
|
||||||
|
|
||||||
|
def _parse_input(self, name: str, value: Any) -> OrderedDict:
|
||||||
|
"""Parse input value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): name of input value.
|
||||||
|
value (Any): Input value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Parsed input value.
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return OrderedDict()
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
return OrderedDict(value)
|
||||||
|
else:
|
||||||
|
raise TypeError(f'{name} should be a dict or `None`, but '
|
||||||
|
f'got {type(name)}')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user