From 6dc1d7082a1e2bc32049e65f283c4fd20de1c026 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 1 Feb 2023 23:53:28 +0800 Subject: [PATCH] [Fix] Support update np.ScalarType data in message_hub (#898) * Clean the commit history * Update message_hub.py --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/logging/message_hub.py | 10 ++++++---- tests/test_logging/test_message_hub.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 79fc131a..492af34e 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -314,17 +314,19 @@ class MessageHub(ManagerMixin): return self._runtime_info[key] def _get_valid_value( - self, value: Union['torch.Tensor', np.ndarray, int, float]) \ - -> Union[int, float]: + self, + value: Union['torch.Tensor', np.ndarray, np.number, int, float], + ) -> Union[int, float]: """Convert value to python built-in type. Args: - value (torch.Tensor or np.ndarray or int or float): value of log. + value (torch.Tensor or np.ndarray or np.number or int or float): + value of log. Returns: float or int: python built-in type value. """ - if isinstance(value, np.ndarray): + if isinstance(value, (np.ndarray, np.number)): assert value.size == 1 value = value.item() elif isinstance(value, (int, float)): diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index b6061f82..4dffdec0 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -34,14 +34,18 @@ class TestMessageHub: def test_update_scalar(self): message_hub = MessageHub.get_instance('mmengine') - # test create target `HistoryBuffer` by name + # Update scalar with int. message_hub.update_scalar('name', 1) log_buffer = message_hub.log_scalars['name'] assert (log_buffer._log_history == np.array([1])).all() - # test update target `HistoryBuffer` by name - message_hub.update_scalar('name', 1) + + # Update scalar with np.ndarray. + message_hub.update_scalar('name', np.array(1)) assert (log_buffer._log_history == np.array([1, 1])).all() - # unmatched string will raise a key error + + # Update scalar with np.int + message_hub.update_scalar('name', np.int32(1)) + assert (log_buffer._log_history == np.array([1, 1, 1])).all() def test_update_info(self): message_hub = MessageHub.get_instance('mmengine')