[Fix] Support update np.ScalarType data in message_hub (#898)

* Clean the commit history

* Update message_hub.py

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/887/head
Mashiro 2023-02-01 23:53:28 +08:00 committed by GitHub
parent 5753cd98e0
commit 6dc1d7082a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 8 deletions

View File

@ -314,17 +314,19 @@ class MessageHub(ManagerMixin):
return self._runtime_info[key] return self._runtime_info[key]
def _get_valid_value( def _get_valid_value(
self, value: Union['torch.Tensor', np.ndarray, int, float]) \ self,
-> Union[int, float]: value: Union['torch.Tensor', np.ndarray, np.number, int, float],
) -> Union[int, float]:
"""Convert value to python built-in type. """Convert value to python built-in type.
Args: Args:
value (torch.Tensor or np.ndarray or int or float): value of log. value (torch.Tensor or np.ndarray or np.number or int or float):
value of log.
Returns: Returns:
float or int: python built-in type value. float or int: python built-in type value.
""" """
if isinstance(value, np.ndarray): if isinstance(value, (np.ndarray, np.number)):
assert value.size == 1 assert value.size == 1
value = value.item() value = value.item()
elif isinstance(value, (int, float)): elif isinstance(value, (int, float)):

View File

@ -34,14 +34,18 @@ class TestMessageHub:
def test_update_scalar(self): def test_update_scalar(self):
message_hub = MessageHub.get_instance('mmengine') message_hub = MessageHub.get_instance('mmengine')
# test create target `HistoryBuffer` by name # Update scalar with int.
message_hub.update_scalar('name', 1) message_hub.update_scalar('name', 1)
log_buffer = message_hub.log_scalars['name'] log_buffer = message_hub.log_scalars['name']
assert (log_buffer._log_history == np.array([1])).all() assert (log_buffer._log_history == np.array([1])).all()
# test update target `HistoryBuffer` by name
message_hub.update_scalar('name', 1) # Update scalar with np.ndarray.
message_hub.update_scalar('name', np.array(1))
assert (log_buffer._log_history == np.array([1, 1])).all() assert (log_buffer._log_history == np.array([1, 1])).all()
# unmatched string will raise a key error
# Update scalar with np.int
message_hub.update_scalar('name', np.int32(1))
assert (log_buffer._log_history == np.array([1, 1, 1])).all()
def test_update_info(self): def test_update_info(self):
message_hub = MessageHub.get_instance('mmengine') message_hub = MessageHub.get_instance('mmengine')