[Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error (#1845)

* [Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error

* fix ut error and add config visualizer dict

* fix ut error
pull/1850/head
MengzhangLI 2022-08-01 15:03:01 +08:00 committed by GitHub
parent 5d9650838e
commit 7369d50049
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 102 additions and 119 deletions

View File

@ -4,6 +4,9 @@ env_cfg = dict(
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'), dist_cfg=dict(backend='nccl'),
) )
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_level = 'INFO' log_level = 'INFO'
load_from = None load_from = None
resume = False resume = False

View File

@ -81,7 +81,7 @@ class SegLocalVisualizer(Visualizer):
""" """
num_classes = len(classes) num_classes = len(classes)
sem_seg = sem_seg.data sem_seg = sem_seg.cpu().data
ids = np.unique(sem_seg)[::-1] ids = np.unique(sem_seg)[::-1]
legal_indices = ids < num_classes legal_indices = ids < num_classes
ids = ids[legal_indices] ids = ids[legal_indices]

View File

@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os import os
import os.path as osp import os.path as osp
import tempfile
from unittest import TestCase from unittest import TestCase
import cv2 import cv2
import mmcv import mmcv
import numpy as np import numpy as np
import pytest
import torch import torch
from mmengine.data import PixelData from mmengine.data import PixelData
@ -27,66 +29,55 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
gt_sem_seg = PixelData(**gt_sem_seg_data) gt_sem_seg = PixelData(**gt_sem_seg_data)
gt_seg_data_sample = SegDataSample() @pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
gt_seg_data_sample.gt_sem_seg = gt_sem_seg def test_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
seg_local_visualizer = SegLocalVisualizer( with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir') seg_local_visualizer = SegLocalVisualizer(
seg_local_visualizer.dataset_meta = dict( vis_backends=[dict(type='LocalVisBackend')],
classes=('background', 'foreground'), save_dir=tmp_dir)
palette=[[120, 120, 120], [6, 230, 230]]) seg_local_visualizer.dataset_meta = dict(
seg_local_visualizer.add_datasample(out_file, image, classes=('background', 'foreground'),
gt_seg_data_sample) palette=[[120, 120, 120], [6, 230, 230]])
# test out_file # test out_file
seg_local_visualizer.add_datasample(out_file, image, seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample) gt_seg_data_sample)
assert os.path.exists( assert os.path.exists(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) osp.join(tmp_dir, 'vis_data', 'vis_image',
drawn_img = cv2.imread( out_file + '_0.png'))
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) drawn_img = cv2.imread(
assert drawn_img.shape == (h, w, 3) osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
os.remove( # test gt_instances and pred_instances
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) pred_sem_seg_data = dict(
os.rmdir('temp_dir' + '/vis_data/vis_image') data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
# test gt_instances and pred_instances pred_seg_data_sample = SegDataSample()
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) pred_seg_data_sample.pred_sem_seg = pred_sem_seg
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample() seg_local_visualizer.add_datasample(out_file, image,
pred_seg_data_sample.pred_sem_seg = pred_sem_seg gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))
seg_local_visualizer.add_datasample(out_file, image, seg_local_visualizer.add_datasample(
gt_seg_data_sample, out_file,
pred_seg_data_sample) image,
self._assert_image_and_shape( gt_seg_data_sample,
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), pred_seg_data_sample,
(h, w * 2, 3)) draw_gt=False)
self._assert_image_and_shape(
seg_local_visualizer.add_datasample( osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file, out_file + '_0.png'), (h, w, 3))
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_pred=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
os.rmdir('temp_dir/vis_data')
os.rmdir('temp_dir')
def test_cityscapes_add_datasample(self): def test_cityscapes_add_datasample(self):
h = 128 h = 128
@ -110,78 +101,67 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=sem_seg) gt_sem_seg_data = dict(data=sem_seg)
gt_sem_seg = PixelData(**gt_sem_seg_data) gt_sem_seg = PixelData(**gt_sem_seg_data)
gt_seg_data_sample = SegDataSample() @pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
gt_seg_data_sample.gt_sem_seg = gt_sem_seg def test_cityscapes_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
'pole', 'traffic light', 'traffic sign',
'vegetation', 'terrain', 'sky', 'person', 'rider',
'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle'),
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
[102, 102, 156], [190, 153, 153], [153, 153, 153],
[250, 170, 30], [220, 220, 0], [107, 142, 35],
[152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 60, 100], [0, 80, 100], [0, 0, 230],
[119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
seg_local_visualizer = SegLocalVisualizer( # test out_file
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir') seg_local_visualizer.add_datasample(out_file, image,
seg_local_visualizer.dataset_meta = dict( gt_seg_data_sample)
classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', assert os.path.exists(
'traffic light', 'traffic sign', 'vegetation', 'terrain', osp.join(tmp_dir, 'vis_data', 'vis_image',
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', out_file + '_0.png'))
'motorcycle', 'bicycle'), drawn_img = cv2.imread(
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], osp.join(tmp_dir, 'vis_data', 'vis_image',
[102, 102, 156], [190, 153, 153], [153, 153, 153], out_file + '_0.png'))
[250, 170, 30], [220, 220, 0], [107, 142, 35], assert drawn_img.shape == (h, w, 3)
[152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
[0, 80, 100], [0, 0, 230], [119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
# test out_file # test gt_instances and pred_instances
seg_local_visualizer.add_datasample(out_file, image, pred_sem_seg_data = dict(
gt_seg_data_sample) data=torch.randint(0, num_class, (1, h, w)))
assert os.path.exists( pred_sem_seg = PixelData(**pred_sem_seg_data)
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
os.remove( pred_seg_data_sample = SegDataSample()
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png')) pred_seg_data_sample.pred_sem_seg = pred_sem_seg
os.rmdir('temp_dir/vis_data/vis_image')
# test gt_instances and pred_instances seg_local_visualizer.add_datasample(out_file, image,
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) gt_seg_data_sample,
pred_sem_seg = PixelData(**pred_sem_seg_data) pred_seg_data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))
pred_seg_data_sample = SegDataSample() seg_local_visualizer.add_datasample(
pred_seg_data_sample.pred_sem_seg = pred_sem_seg out_file,
image,
seg_local_visualizer.add_datasample(out_file, image, gt_seg_data_sample,
gt_seg_data_sample, pred_seg_data_sample,
pred_seg_data_sample) draw_gt=False)
self._assert_image_and_shape( self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'), osp.join(tmp_dir, 'vis_data', 'vis_image',
(h, w * 2, 3)) out_file + '_0.png'), (h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_pred=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
os.rmdir('temp_dir/vis_data')
os.rmdir('temp_dir')
def _assert_image_and_shape(self, out_file, out_shape): def _assert_image_and_shape(self, out_file, out_shape):
assert os.path.exists(out_file) assert os.path.exists(out_file)
drawn_img = cv2.imread(out_file) drawn_img = cv2.imread(out_file)
assert drawn_img.shape == out_shape assert drawn_img.shape == out_shape
os.remove(out_file)
os.rmdir('temp_dir/vis_data/vis_image')