mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Fix: avoid modification of scalar_dict in LocalVisBackend (#377)
This commit is contained in:
parent
5b648c119f
commit
ea61bf6bb7
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import functools
|
||||
import os
|
||||
import os.path as osp
|
||||
@ -291,6 +292,7 @@ class LocalVisBackend(BaseVisBackend):
|
||||
Default to None.
|
||||
"""
|
||||
assert isinstance(scalar_dict, dict)
|
||||
scalar_dict = copy.deepcopy(scalar_dict)
|
||||
scalar_dict.setdefault('step', step)
|
||||
|
||||
if file_path is not None:
|
||||
@ -542,7 +544,8 @@ class TensorboardVisBackend(BaseVisBackend):
|
||||
value (int, float, torch.Tensor, np.ndarray): Value to save.
|
||||
step (int): Global step value to record. Default to 0.
|
||||
"""
|
||||
if isinstance(value, (int, float, torch.Tensor, np.ndarray)):
|
||||
if isinstance(value,
|
||||
(int, float, torch.Tensor, np.ndarray, np.number)):
|
||||
self._tensorboard.add_scalar(name, value, step)
|
||||
else:
|
||||
warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, '
|
||||
|
@ -86,6 +86,7 @@ class TestLocalVisBackend:
|
||||
input_dict = {'map': 0.7, 'acc': 0.9}
|
||||
local_vis_backend = LocalVisBackend('temp_dir')
|
||||
local_vis_backend.add_scalars(input_dict)
|
||||
assert input_dict == {'map': 0.7, 'acc': 0.9}
|
||||
out_dict = load(local_vis_backend._scalar_save_file, 'json')
|
||||
assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0}
|
||||
|
||||
@ -143,7 +144,19 @@ class TestTensorboardVisBackend:
|
||||
# test append mode
|
||||
tensorboard_vis_backend.add_scalar('map', 0.9, step=0)
|
||||
tensorboard_vis_backend.add_scalar('map', 0.95, step=1)
|
||||
|
||||
# test with numpy
|
||||
with pytest.warns(None) as record:
|
||||
tensorboard_vis_backend.add_scalar('map', np.array(0.9), step=0)
|
||||
tensorboard_vis_backend.add_scalar('map', np.array(0.95), step=1)
|
||||
tensorboard_vis_backend.add_scalar('map', np.array(9), step=0)
|
||||
tensorboard_vis_backend.add_scalar('map', np.array(95), step=1)
|
||||
tensorboard_vis_backend.add_scalar('map', np.array([9])[0], step=0)
|
||||
tensorboard_vis_backend.add_scalar(
|
||||
'map', np.array([95])[0], step=1)
|
||||
assert len(record) == 0
|
||||
# test with tensor
|
||||
tensorboard_vis_backend.add_scalar('map', torch.tensor(0.9), step=0)
|
||||
tensorboard_vis_backend.add_scalar('map', torch.tensor(0.95), step=1)
|
||||
# Unprocessable data will output a warning message
|
||||
with pytest.warns(Warning):
|
||||
tensorboard_vis_backend.add_scalar('map', [0.95])
|
||||
|
Loading…
x
Reference in New Issue
Block a user