# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch

from mmengine import HistoryBuffer


class TestLoggerBuffer:
    """Unit tests exercising the public interface of ``HistoryBuffer``."""

    def test_init(self):
        """Check default construction, truncation and length validation."""
        buffer = HistoryBuffer()
        assert buffer.max_length == 1000000
        stored_logs, stored_counts = buffer.data
        assert len(stored_logs) == 0
        assert len(stored_counts) == 0

        # Feed arrays that are one element longer than `max_length`; the
        # buffer is expected to keep only the trailing `max_length` entries,
        # so its first stored entry is the second input entry.
        oversized = buffer.max_length + 1
        logs = np.random.randint(1, 10, oversized)
        counts = np.random.randint(1, 10, oversized)
        buffer = HistoryBuffer(logs, counts)
        stored_logs, stored_counts = buffer.data
        assert len(stored_logs) == buffer.max_length
        assert len(stored_counts) == buffer.max_length
        assert logs[1] == stored_logs[0]
        assert counts[1] == stored_counts[0]

        # A log history and a count history of different lengths must be
        # rejected with an AssertionError.
        with pytest.raises(AssertionError):
            HistoryBuffer([1, 2], [1])

    @pytest.mark.parametrize('array_method',
                             [torch.tensor, np.array, lambda x: x])
    def test_update(self, array_method):
        """Check ``update`` for tensor, ndarray and plain-list inputs.

        Args:
            array_method: Callable that turns a Python list into the
                container type under test.
        """
        buffer = HistoryBuffer()
        values = array_method([1, 2, 3, 4, 5])
        weights = array_method([5, 5, 5, 5, 5])
        # Push the entries one at a time as plain floats.
        for value, weight in zip(values, weights):
            buffer.update(float(value), float(weight))

        stored_values, stored_weights = buffer.data
        for expected, actual in zip(values, stored_values):
            assert float(expected) == float(actual)
        for expected, actual in zip(weights, stored_weights):
            assert float(expected) == float(actual)

        # Start from `max_length` values whose first entry is -1 and push
        # one more value: the -1 must be gone and the length must stay at
        # `max_length`.
        # NOTE(review): the values are nested inside an extra outer list;
        # presumably unintentional, but kept as-is -- confirm.
        tail_length = buffer.max_length - 1
        max_array = array_method([[-1] + [1] * tail_length])
        max_count = array_method([[-1] + [1] * tail_length])
        buffer = HistoryBuffer(max_array, max_count)
        buffer.update(1)
        stored_values, stored_weights = buffer.data
        assert stored_values[0] == 1
        assert stored_weights[0] == 1
        assert len(stored_values) == buffer.max_length
        assert len(stored_weights) == buffer.max_length

        # `update` accepts single values only: passing an iterable as the
        # log value must raise a TypeError.
        with pytest.raises(TypeError):
            buffer.update(array_method([1, 2]))

    @pytest.mark.parametrize('statistics_method, log_buffer_type',
                             [(np.min, 'min'), (np.max, 'max')])
    def test_max_min(self, statistics_method, log_buffer_type):
        """Compare the buffer's ``min``/``max`` with the numpy reference.

        Args:
            statistics_method: Numpy reduction giving the expected value.
            log_buffer_type: Name of the ``HistoryBuffer`` method to call.
        """
        values = np.random.randint(1, 5, 20)
        weights = np.ones(20)
        buffer = HistoryBuffer(values, weights)

        # Windowed over the last 10 entries.
        expected_window = statistics_method(values[-10:])
        assert expected_window == getattr(buffer, log_buffer_type)(10)
        # Over the whole history.
        expected_full = statistics_method(values)
        assert expected_full == getattr(buffer, log_buffer_type)()

    def test_mean(self):
        """Check ``mean`` equals sum of logs over sum of counts."""
        values = np.random.randint(1, 5, 20)
        weights = np.ones(20)
        buffer = HistoryBuffer(values, weights)

        # Windowed over the last 10 entries.
        expected_window = np.sum(values[-10:]) / np.sum(weights[-10:])
        assert expected_window == buffer.mean(10)
        # Over the whole history.
        expected_full = np.sum(values) / np.sum(weights)
        assert expected_full == buffer.mean()

    def test_current(self):
        """Check ``current`` returns the last entry and rejects empties."""
        values = np.random.randint(1, 5, 20)
        weights = np.ones(20)
        buffer = HistoryBuffer(values, weights)
        assert values[-1] == buffer.current()

        # An empty buffer has no current value.
        empty_buffer = HistoryBuffer()
        with pytest.raises(ValueError):
            empty_buffer.current()

    def test_statistics(self):
        """Check ``statistics`` dispatches by name to built-in methods."""
        values = np.array([1, 2, 3, 4, 5])
        weights = np.array([1, 1, 1, 1, 1])
        buffer = HistoryBuffer(values, weights)

        expectations = (
            ('mean', 3),
            ('min', 1),
            ('max', 5),
            ('current', 5),
        )
        for method_name, expected in expectations:
            assert buffer.statistics(method_name) == expected

        # Asking for an unregistered statistic must raise a KeyError.
        with pytest.raises(KeyError):
            buffer.statistics('unknown')

    def test_register_statistics(self):
        """Check a registered custom statistic is reachable by name."""

        @HistoryBuffer.register_statistics
        def custom_statistics(self):
            return -1

        buffer = HistoryBuffer()
        assert buffer.statistics('custom_statistics') == -1