mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
142 lines
4.9 KiB
Python
142 lines
4.9 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import random
|
||
|
import sys
|
||
|
from unittest.mock import MagicMock
|
||
|
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from mmengine.data import BaseDataElement, BaseDataSample
|
||
|
from mmengine.visualizer import (VISUALIZERS, LocalWriter, TensorboardWriter,
|
||
|
WandbWriter)
|
||
|
|
||
|
|
||
|
def get_demo_datasample():
|
||
|
metainfo = dict(
|
||
|
img_id=random.randint(0, 100),
|
||
|
img_shape=(random.randint(400, 600), random.randint(400, 600)))
|
||
|
gt_instances = BaseDataElement(
|
||
|
data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))))
|
||
|
pred_instances = BaseDataElement(
|
||
|
data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))))
|
||
|
data = dict(gt_instances=gt_instances, pred_instances=pred_instances)
|
||
|
data_sample = BaseDataSample(data=data, metainfo=metainfo)
|
||
|
return data_sample
|
||
|
|
||
|
|
||
|
class TestLocalWriter:
|
||
|
|
||
|
def test_add_image(self):
|
||
|
image = np.random.randint(0, 256, size=(10, 10, 3))
|
||
|
data_sample = get_demo_datasample()
|
||
|
|
||
|
local_writer = LocalWriter(visuailzer=dict(type='Visualizer'))
|
||
|
local_writer.add_image('img', image)
|
||
|
local_writer.add_image('img', image, data_sample)
|
||
|
|
||
|
bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
|
||
|
local_writer.visualizer.draw_bboxes(bboxes)
|
||
|
local_writer.add_image('img', local_writer.visualizer.get_image())
|
||
|
|
||
|
visuailzer = VISUALIZERS.build(dict(type='Visualizer'))
|
||
|
local_writer = LocalWriter(visuailzer=visuailzer)
|
||
|
local_writer.add_image('img', image)
|
||
|
local_writer.add_image('img', image, data_sample)
|
||
|
|
||
|
# test `visuailzer` parameter
|
||
|
# `visuailzer` parameter which must be either Visualizer instance
|
||
|
# or valid dictionary.
|
||
|
with pytest.raises(AssertionError):
|
||
|
|
||
|
class A:
|
||
|
pass
|
||
|
|
||
|
LocalWriter(visuailzer=A())
|
||
|
with pytest.raises(AssertionError):
|
||
|
LocalWriter(visuailzer=dict(a='Visualizer'))
|
||
|
|
||
|
# test not visuailzer
|
||
|
# The visuailzer parameter must be set when
|
||
|
# the local_writer object is instantiated and
|
||
|
# the `add_image` method is called.
|
||
|
with pytest.raises(AssertionError):
|
||
|
local_writer = LocalWriter()
|
||
|
local_writer.add_image('img', image)
|
||
|
|
||
|
def test_add_scaler(self):
|
||
|
local_writer = LocalWriter()
|
||
|
local_writer.add_scaler('map', 0.9)
|
||
|
|
||
|
def test_add_hyperparams(self):
|
||
|
local_writer = LocalWriter()
|
||
|
local_writer.add_hyperparams('hyper', dict(lr=0.01))
|
||
|
|
||
|
|
||
|
class TestWandbWriter:
|
||
|
sys.modules['wandb'] = MagicMock()
|
||
|
|
||
|
def test_add_image(self):
|
||
|
image = np.random.randint(0, 256, size=(10, 10, 3))
|
||
|
data_sample = get_demo_datasample()
|
||
|
|
||
|
wandb_writer = WandbWriter()
|
||
|
assert not wandb_writer.visualizer
|
||
|
wandb_writer.add_image('img', image, data_sample)
|
||
|
|
||
|
wandb_writer = WandbWriter(visuailzer=dict(type='Visualizer'))
|
||
|
assert wandb_writer.visualizer
|
||
|
wandb_writer.add_image('img', image)
|
||
|
wandb_writer.add_image('img', image, data_sample)
|
||
|
|
||
|
wandb_writer.visualizer.set_image(image)
|
||
|
wandb_writer.add_image('img', wandb_writer.visualizer.get_image())
|
||
|
|
||
|
# TODO test file exist
|
||
|
|
||
|
def test_add_scaler(self):
|
||
|
wandb_writer = WandbWriter()
|
||
|
wandb_writer.add_scaler('map', 0.9)
|
||
|
|
||
|
def test_add_hyperparams(self):
|
||
|
wandb_writer = WandbWriter()
|
||
|
wandb_writer.add_hyperparams('hyper', dict(lr=0.01))
|
||
|
|
||
|
|
||
|
class TestTensorboardWriter:
|
||
|
sys.modules['torch.utils.tensorboard.SummaryWriter'] = MagicMock()
|
||
|
|
||
|
def test_add_image(self):
|
||
|
image = np.random.randint(0, 256, size=(10, 10, 3))
|
||
|
data_sample = get_demo_datasample()
|
||
|
|
||
|
tensorboard_writer = TensorboardWriter()
|
||
|
assert not tensorboard_writer.visualizer
|
||
|
tensorboard_writer.add_image('img', image, data_sample)
|
||
|
|
||
|
tensorboard_writer = TensorboardWriter(
|
||
|
visuailzer=dict(type='Visualizer'))
|
||
|
assert tensorboard_writer.visualizer
|
||
|
tensorboard_writer.add_image('img', image)
|
||
|
tensorboard_writer.add_image('img', image, data_sample)
|
||
|
|
||
|
tensorboard_writer.visualizer.set_image(image)
|
||
|
tensorboard_writer.add_image('img',
|
||
|
tensorboard_writer.visualizer.get_image())
|
||
|
|
||
|
# test no visualizer
|
||
|
# The visuailzer parameter must be set when
|
||
|
# the tensorboard_writer object is instantiated and
|
||
|
# the `add_image` method is called.
|
||
|
with pytest.raises(AssertionError):
|
||
|
tensorboard_writer = TensorboardWriter()
|
||
|
tensorboard_writer.add_image('img', image)
|
||
|
|
||
|
def test_add_scaler(self):
|
||
|
tensorboard_writer = TensorboardWriter()
|
||
|
tensorboard_writer.add_scaler('map', 0.9)
|
||
|
|
||
|
def test_add_hyperparams(self):
|
||
|
tensorboard_writer = TensorboardWriter()
|
||
|
tensorboard_writer.add_hyperparams('hyper', dict(lr=0.01))
|