# Copyright (c) OpenMMLab. All rights reserved.
import os
from unittest import TestCase

import cv2
import numpy as np
import torch
from mmengine.data import InstanceData

from mmselfsup.data import SelfSupDataSample
from mmselfsup.visualization import SelfSupLocalVisualizer


def _rand_patch_box(num_boxes, h, w):
    """Generate random boxes clipped to an image of size ``h`` x ``w``.

    Box centres and relative sizes are drawn with ``torch.rand``; the
    resulting corners are clipped to ``[0, w]`` (x) and ``[0, h]`` (y).

    Args:
        num_boxes (int): Number of boxes to generate.
        h (int): Image height used to scale and clip the y coordinates.
        w (int): Image width used to scale and clip the x coordinates.

    Returns:
        torch.Tensor: Tensor of shape ``(1, num_boxes, 4)`` whose last
        dimension is ``(tl_x, tl_y, br_x, br_y)``.
    """
    cx, cy, bw, bh = torch.rand(num_boxes, 4).T

    tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
    tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
    br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
    br_y = ((cy * h) + (h * bh / 2)).clip(0, h)

    patch_box = torch.vstack([tl_x, tl_y, br_x, br_y]).T
    # Add a leading dimension of size one in front of the box dimension.
    return patch_box.unsqueeze(0)


class TestSelfSupLocalVisualizer(TestCase):
    """Tests for ``SelfSupLocalVisualizer.add_datasample``."""

    def test_add_datasample(self):
        """Draw data samples for four kinds of input and check the output.

        Covered inputs: ``pseudo_label`` with ``patch_box`` /
        ``unpatched_img`` (relative_loc), ``pseudo_label`` with
        ``rot_label`` (rotation_pred), ``mask`` (mask image modeling) and
        empty data samples (contrastive learning).
        """
        h = 12
        w = 12
        out_file = 'out_file.jpg'

        # ======= test relative_loc =======
        num_patch_box = 5
        image = np.random.randint(0, 256, (h, w, 3))
        image = np.expand_dims(image, 0)

        # gt_instances
        pseudo_label = InstanceData()
        pseudo_label.patch_box = _rand_patch_box(num_patch_box, h, w)
        pseudo_label.unpatched_img = torch.tensor(image)
        gt_selfsup_data_sample = SelfSupDataSample()
        gt_selfsup_data_sample.pseudo_label = pseudo_label

        # pred_instances
        pseudo_label = InstanceData()
        pseudo_label.patch_box = _rand_patch_box(num_patch_box, h, w)
        pseudo_label.unpatched_img = torch.tensor(image)
        pred_selfsup_data_sample = SelfSupDataSample()
        pred_selfsup_data_sample.pseudo_label = pseudo_label

        self._check_add_datasample(image, gt_selfsup_data_sample,
                                   pred_selfsup_data_sample, out_file,
                                   (h, w, 3))

        # ======= test rotation_pred =======
        # Four random images concatenated along the width axis.
        image = [np.random.randint(0, 256, (h, w, 3)) for _ in range(4)]
        image = np.concatenate(image, axis=1)

        # gt_instances
        pseudo_label = InstanceData()
        pseudo_label.rot_label = torch.tensor([0, 1, 2, 3])
        gt_selfsup_data_sample = SelfSupDataSample()
        gt_selfsup_data_sample.pseudo_label = pseudo_label

        # pred_instances
        pseudo_label = InstanceData()
        pseudo_label.rot_label = torch.tensor([0, 1, 2, 3])
        pred_selfsup_data_sample = SelfSupDataSample()
        pred_selfsup_data_sample.pseudo_label = pseudo_label

        self._check_add_datasample(image, gt_selfsup_data_sample,
                                   pred_selfsup_data_sample, out_file,
                                   (h, w * 4, 3))

        # ======= test mask image modeling =======
        image = np.random.randint(0, 256, (h, w, 3))

        # gt_instances
        mask = InstanceData()
        mask.value = torch.tensor([[1, 0], [0, 1]])
        gt_selfsup_data_sample = SelfSupDataSample()
        gt_selfsup_data_sample.mask = mask

        # pred_instances
        mask = InstanceData()
        mask.value = torch.tensor([[1, 0], [0, 1]])
        pred_selfsup_data_sample = SelfSupDataSample()
        pred_selfsup_data_sample.mask = mask

        self._check_add_datasample(image, gt_selfsup_data_sample,
                                   pred_selfsup_data_sample, out_file,
                                   (h, w, 3))

        # ======= test contrastive learning =======
        # Two random images concatenated along the width axis.
        image = [np.random.randint(0, 256, (h, w, 3)) for _ in range(2)]
        image = np.concatenate(image, axis=1)

        # gt_instances
        gt_selfsup_data_sample = SelfSupDataSample()

        # pred_instances
        pred_selfsup_data_sample = SelfSupDataSample()

        self._check_add_datasample(image, gt_selfsup_data_sample,
                                   pred_selfsup_data_sample, out_file,
                                   (h, w * 2, 3))

    def _check_add_datasample(self, image, gt_sample, pred_sample, out_file,
                              out_shape):
        """Run the shared ``add_datasample`` call sequence on one input.

        A fresh ``SelfSupLocalVisualizer`` is created, then four calls are
        made: gt only without an output file, gt only written to
        ``out_file``, pred only (``draw_gt=False``), and gt together with
        pred. The first two written images are expected to have shape
        ``out_shape``; the last one is expected to be twice as wide.

        Args:
            image (np.ndarray): Image passed to ``add_datasample``.
            gt_sample (SelfSupDataSample): Ground-truth data sample.
            pred_sample (SelfSupDataSample): Prediction data sample.
            out_file (str): Path the visualizer writes the drawn image to.
            out_shape (tuple[int, int, int]): Expected ``(h, w, c)`` shape
                of the image drawn from a single data sample.
        """
        visualizer = SelfSupLocalVisualizer()

        # test gt_instances
        visualizer.add_datasample('image', image, gt_sample)

        # test out_file
        visualizer.add_datasample(
            'image', image, gt_sample, out_file=out_file)
        self._assert_image_and_shape(out_file, out_shape)

        # test pred_instance
        visualizer.add_datasample(
            'image',
            image,
            gt_sample,
            pred_sample,
            draw_gt=False,
            out_file=out_file)
        self._assert_image_and_shape(out_file, out_shape)

        # test gt_instances and pred_instances
        visualizer.add_datasample(
            'image', image, gt_sample, pred_sample, out_file=out_file)
        height, width, channels = out_shape
        self._assert_image_and_shape(out_file, (height, width * 2, channels))

    def _assert_image_and_shape(self, out_file, out_shape):
        """Check that ``out_file`` exists with the given shape, then delete it.

        Args:
            out_file (str): Path of the image expected on disk.
            out_shape (tuple[int, int, int]): Expected shape of the image
                as returned by ``cv2.imread``.
        """
        self.assertTrue(
            os.path.exists(out_file), f'{out_file} was not written')
        try:
            drawn_img = cv2.imread(out_file)
            # ``cv2.imread`` returns None instead of raising when the file
            # cannot be decoded.
            self.assertIsNotNone(
                drawn_img, f'{out_file} could not be read as an image')
            self.assertEqual(drawn_img.shape, out_shape)
        finally:
            # Always clean up so a failed check does not leave a stale file
            # behind for the next existence check.
            os.remove(out_file)