58 lines
2.2 KiB
Python
58 lines
2.2 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
from mmengine import MessageHub
|
||
|
|
||
|
|
||
|
class TestMessageHub:
|
||
|
|
||
|
def test_init(self):
|
||
|
message_hub = MessageHub('name')
|
||
|
assert message_hub.instance_name == 'name'
|
||
|
assert len(message_hub.log_buffers) == 0
|
||
|
assert len(message_hub.log_buffers) == 0
|
||
|
|
||
|
def test_update_log(self):
|
||
|
message_hub = MessageHub.create_instance()
|
||
|
# test create target `LogBuffer` by name
|
||
|
message_hub.update_log('name', 1)
|
||
|
log_buffer = message_hub.log_buffers['name']
|
||
|
assert (log_buffer._log_history == np.array([1])).all()
|
||
|
# test update target `LogBuffer` by name
|
||
|
message_hub.update_log('name', 1)
|
||
|
assert (log_buffer._log_history == np.array([1, 1])).all()
|
||
|
# unmatched string will raise a key error
|
||
|
|
||
|
def test_update_info(self):
|
||
|
message_hub = MessageHub.create_instance()
|
||
|
# test runtime value can be overwritten.
|
||
|
message_hub.update_info('key', 2)
|
||
|
assert message_hub.runtime_info['key'] == 2
|
||
|
message_hub.update_info('key', 1)
|
||
|
assert message_hub.runtime_info['key'] == 1
|
||
|
|
||
|
def test_get_log_buffers(self):
|
||
|
message_hub = MessageHub.create_instance()
|
||
|
# Get undefined key will raise error
|
||
|
with pytest.raises(KeyError):
|
||
|
message_hub.get_log('unknown')
|
||
|
# test get log_buffer as wished
|
||
|
log_history = np.array([1, 2, 3, 4, 5])
|
||
|
count = np.array([1, 1, 1, 1, 1])
|
||
|
for i in range(len(log_history)):
|
||
|
message_hub.update_log('test_value', float(log_history[i]),
|
||
|
int(count[i]))
|
||
|
recorded_history, recorded_count = \
|
||
|
message_hub.get_log('test_value').data
|
||
|
assert (log_history == recorded_history).all()
|
||
|
assert (recorded_count == count).all()
|
||
|
|
||
|
def test_get_runtime(self):
|
||
|
message_hub = MessageHub.create_instance()
|
||
|
with pytest.raises(KeyError):
|
||
|
message_hub.get_info('unknown')
|
||
|
recorded_dict = dict(a=1, b=2)
|
||
|
message_hub.update_info('test_value', recorded_dict)
|
||
|
assert message_hub.get_info('test_value') == recorded_dict
|