[Fix] Support update np.ScalarType data in message_hub (#898)
* Clean the commit history * Update message_hub.py --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/887/head
parent
5753cd98e0
commit
6dc1d7082a
|
@ -314,17 +314,19 @@ class MessageHub(ManagerMixin):
|
||||||
return self._runtime_info[key]
|
return self._runtime_info[key]
|
||||||
|
|
||||||
def _get_valid_value(
|
def _get_valid_value(
|
||||||
self, value: Union['torch.Tensor', np.ndarray, int, float]) \
|
self,
|
||||||
-> Union[int, float]:
|
value: Union['torch.Tensor', np.ndarray, np.number, int, float],
|
||||||
|
) -> Union[int, float]:
|
||||||
"""Convert value to python built-in type.
|
"""Convert value to python built-in type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value (torch.Tensor or np.ndarray or int or float): value of log.
|
value (torch.Tensor or np.ndarray or np.number or int or float):
|
||||||
|
value of log.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
float or int: python built-in type value.
|
float or int: python built-in type value.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, (np.ndarray, np.number)):
|
||||||
assert value.size == 1
|
assert value.size == 1
|
||||||
value = value.item()
|
value = value.item()
|
||||||
elif isinstance(value, (int, float)):
|
elif isinstance(value, (int, float)):
|
||||||
|
|
|
@ -34,14 +34,18 @@ class TestMessageHub:
|
||||||
|
|
||||||
def test_update_scalar(self):
|
def test_update_scalar(self):
|
||||||
message_hub = MessageHub.get_instance('mmengine')
|
message_hub = MessageHub.get_instance('mmengine')
|
||||||
# test create target `HistoryBuffer` by name
|
# Update scalar with int.
|
||||||
message_hub.update_scalar('name', 1)
|
message_hub.update_scalar('name', 1)
|
||||||
log_buffer = message_hub.log_scalars['name']
|
log_buffer = message_hub.log_scalars['name']
|
||||||
assert (log_buffer._log_history == np.array([1])).all()
|
assert (log_buffer._log_history == np.array([1])).all()
|
||||||
# test update target `HistoryBuffer` by name
|
|
||||||
message_hub.update_scalar('name', 1)
|
# Update scalar with np.ndarray.
|
||||||
|
message_hub.update_scalar('name', np.array(1))
|
||||||
assert (log_buffer._log_history == np.array([1, 1])).all()
|
assert (log_buffer._log_history == np.array([1, 1])).all()
|
||||||
# unmatched string will raise a key error
|
|
||||||
|
# Update scalar with np.int
|
||||||
|
message_hub.update_scalar('name', np.int32(1))
|
||||||
|
assert (log_buffer._log_history == np.array([1, 1, 1])).all()
|
||||||
|
|
||||||
def test_update_info(self):
|
def test_update_info(self):
|
||||||
message_hub = MessageHub.get_instance('mmengine')
|
message_hub = MessageHub.get_instance('mmengine')
|
||||||
|
|
Loading…
Reference in New Issue