From 7369d5004958fa7ded78022de6a2482fd8a2bf51 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Mon, 1 Aug 2022 15:03:01 +0800 Subject: [PATCH] [Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error (#1845) * [Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error * fix ut error and add config visualizer dict * fix ut error --- configs/_base_/default_runtime.py | 3 + mmseg/visualization/local_visualizer.py | 2 +- .../test_local_visualizer.py | 216 ++++++++---------- 3 files changed, 102 insertions(+), 119 deletions(-) diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 5925c6926..921175671 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -4,6 +4,9 @@ env_cfg = dict( mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), dist_cfg=dict(backend='nccl'), ) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') log_level = 'INFO' load_from = None resume = False diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index ea966fa5b..df7ba9a4c 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -81,7 +81,7 @@ class SegLocalVisualizer(Visualizer): """ num_classes = len(classes) - sem_seg = sem_seg.data + sem_seg = sem_seg.cpu().data ids = np.unique(sem_seg)[::-1] legal_indices = ids < num_classes ids = ids[legal_indices] diff --git a/tests/test_visualization/test_local_visualizer.py b/tests/test_visualization/test_local_visualizer.py index 1eb091de7..042e96b8a 100644 --- a/tests/test_visualization/test_local_visualizer.py +++ b/tests/test_visualization/test_local_visualizer.py @@ -1,11 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp +import tempfile from unittest import TestCase import cv2 import mmcv import numpy as np +import pytest import torch from mmengine.data import PixelData @@ -27,66 +29,55 @@ class TestSegLocalVisualizer(TestCase): gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) gt_sem_seg = PixelData(**gt_sem_seg_data) - gt_seg_data_sample = SegDataSample() - gt_seg_data_sample.gt_sem_seg = gt_sem_seg + @pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda())) + def test_add_datasample_forward(gt_sem_seg): + gt_seg_data_sample = SegDataSample() + gt_seg_data_sample.gt_sem_seg = gt_sem_seg - seg_local_visualizer = SegLocalVisualizer( - vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir') - seg_local_visualizer.dataset_meta = dict( - classes=('background', 'foreground'), - palette=[[120, 120, 120], [6, 230, 230]]) - seg_local_visualizer.add_datasample(out_file, image, - gt_seg_data_sample) + with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir: + seg_local_visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], + save_dir=tmp_dir) + seg_local_visualizer.dataset_meta = dict( + classes=('background', 'foreground'), + palette=[[120, 120, 120], [6, 230, 230]]) - # test out_file - seg_local_visualizer.add_datasample(out_file, image, - gt_seg_data_sample) + # test out_file + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) - assert os.path.exists( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) - drawn_img = cv2.imread( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) - assert drawn_img.shape == (h, w, 3) + assert os.path.exists( + osp.join(tmp_dir, 'vis_data', 'vis_image', + out_file + '_0.png')) + drawn_img = cv2.imread( + osp.join(tmp_dir, 'vis_data', 'vis_image', + out_file + '_0.png')) + assert drawn_img.shape == (h, w, 3) - os.remove( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) - os.rmdir('temp_dir' + '/vis_data/vis_image') + # test gt_instances and pred_instances + pred_sem_seg_data = dict( + data=torch.randint(0, num_class, (1, h, w))) + pred_sem_seg = PixelData(**pred_sem_seg_data) - # test gt_instances and pred_instances - pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) - pred_sem_seg = PixelData(**pred_sem_seg_data) + pred_seg_data_sample = SegDataSample() + pred_seg_data_sample.pred_sem_seg = pred_sem_seg - pred_seg_data_sample = SegDataSample() - pred_seg_data_sample.pred_sem_seg = pred_sem_seg + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample, + pred_seg_data_sample) + self._assert_image_and_shape( + osp.join(tmp_dir, 'vis_data', 'vis_image', + out_file + '_0.png'), (h, w * 2, 3)) - seg_local_visualizer.add_datasample(out_file, image, - gt_seg_data_sample, - pred_seg_data_sample) - self._assert_image_and_shape( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), - (h, w * 2, 3)) - - seg_local_visualizer.add_datasample( - out_file, - image, - gt_seg_data_sample, - pred_seg_data_sample, - draw_gt=False) - self._assert_image_and_shape( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), - (h, w, 3)) - - seg_local_visualizer.add_datasample( - out_file, - image, - gt_seg_data_sample, - pred_seg_data_sample, - draw_pred=False) - self._assert_image_and_shape( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), - (h, w, 3)) - os.rmdir('temp_dir/vis_data') - os.rmdir('temp_dir') + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_gt=False) + self._assert_image_and_shape( + osp.join(tmp_dir, 'vis_data', 'vis_image', + out_file + '_0.png'), (h, w, 3)) def test_cityscapes_add_datasample(self): h = 128 @@ -110,78 +101,67 @@ class TestSegLocalVisualizer(TestCase): gt_sem_seg_data = dict(data=sem_seg) gt_sem_seg = PixelData(**gt_sem_seg_data) - gt_seg_data_sample = SegDataSample() - gt_seg_data_sample.gt_sem_seg = gt_sem_seg + @pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda())) + def test_cityscapes_add_datasample_forward(gt_sem_seg): + gt_seg_data_sample = SegDataSample() + gt_seg_data_sample.gt_sem_seg = gt_sem_seg + with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir: + seg_local_visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], + save_dir='temp_dir') + seg_local_visualizer.dataset_meta = dict( + classes=('road', 'sidewalk', 'building', 'wall', 'fence', + 'pole', 'traffic light', 'traffic sign', + 'vegetation', 'terrain', 'sky', 'person', 'rider', + 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle'), + palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], + [102, 102, 156], [190, 153, 153], [153, 153, 153], + [250, 170, 30], [220, 220, 0], [107, 142, 35], + [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], + [0, 60, 100], [0, 80, 100], [0, 0, 230], + [119, 11, 32]]) + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) - seg_local_visualizer = SegLocalVisualizer( - vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir') - seg_local_visualizer.dataset_meta = dict( - classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', - 'traffic light', 'traffic sign', 'vegetation', 'terrain', - 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', - 'motorcycle', 'bicycle'), - palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], - [102, 102, 156], [190, 153, 153], [153, 153, 153], - [250, 170, 30], [220, 220, 0], [107, 142, 35], - [152, 251, 152], [70, 130, 180], [220, 20, 60], - [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], - [0, 80, 100], [0, 0, 230], [119, 11, 32]]) - seg_local_visualizer.add_datasample(out_file, image, - gt_seg_data_sample) + # test out_file + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample) + assert os.path.exists( + osp.join(tmp_dir, 'vis_data', 'vis_image', + out_file + '_0.png')) + drawn_img = cv2.imread( + osp.join(tmp_dir, 'vis_data', 'vis_image', + out_file + '_0.png')) + assert drawn_img.shape == (h, w, 3) - # test out_file - seg_local_visualizer.add_datasample(out_file, image, - gt_seg_data_sample) - assert os.path.exists( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) - drawn_img = cv2.imread( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) - assert drawn_img.shape == (h, w, 3) + # test gt_instances and pred_instances + pred_sem_seg_data = dict( + data=torch.randint(0, num_class, (1, h, w))) + pred_sem_seg = PixelData(**pred_sem_seg_data) - os.remove( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) - os.rmdir('temp_dir/vis_data/vis_image') + pred_seg_data_sample = SegDataSample() + pred_seg_data_sample.pred_sem_seg = pred_sem_seg - # test gt_instances and pred_instances - pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) - pred_sem_seg = PixelData(**pred_sem_seg_data) + seg_local_visualizer.add_datasample(out_file, image, + gt_seg_data_sample, + pred_seg_data_sample) + self._assert_image_and_shape( + osp.join(tmp_dir, 'vis_data', 'vis_image', + out_file + '_0.png'), (h, w * 2, 3)) - pred_seg_data_sample = SegDataSample() - pred_seg_data_sample.pred_sem_seg = pred_sem_seg - - seg_local_visualizer.add_datasample(out_file, image, - gt_seg_data_sample, - pred_seg_data_sample) - self._assert_image_and_shape( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), - (h, w * 2, 3)) - - seg_local_visualizer.add_datasample( - out_file, - image, - gt_seg_data_sample, - pred_seg_data_sample, - draw_gt=False) - self._assert_image_and_shape( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), - (h, w, 3)) - - seg_local_visualizer.add_datasample( - out_file, - image, - gt_seg_data_sample, - pred_seg_data_sample, - draw_pred=False) - self._assert_image_and_shape( - osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), - (h, w, 3)) - - os.rmdir('temp_dir/vis_data') - os.rmdir('temp_dir') + seg_local_visualizer.add_datasample( + out_file, + image, + gt_seg_data_sample, + pred_seg_data_sample, + draw_gt=False) + self._assert_image_and_shape( + osp.join(tmp_dir, 'vis_data', 'vis_image', + out_file + '_0.png'), (h, w, 3)) def _assert_image_and_shape(self, out_file, out_shape): assert os.path.exists(out_file) drawn_img = cv2.imread(out_file) assert drawn_img.shape == out_shape - os.remove(out_file) - os.rmdir('temp_dir/vis_data/vis_image')