[Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error (#1845)

* [Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error

* fix ut error and add config visualizer dict

* fix ut error
pull/1850/head
MengzhangLI 2022-08-01 15:03:01 +08:00 committed by GitHub
parent 5d9650838e
commit 7369d50049
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 102 additions and 119 deletions

View File

@ -4,6 +4,9 @@ env_cfg = dict(
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_level = 'INFO'
load_from = None
resume = False

View File

@ -81,7 +81,7 @@ class SegLocalVisualizer(Visualizer):
"""
num_classes = len(classes)
sem_seg = sem_seg.data
sem_seg = sem_seg.cpu().data
ids = np.unique(sem_seg)[::-1]
legal_indices = ids < num_classes
ids = ids[legal_indices]

View File

@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest import TestCase
import cv2
import mmcv
import numpy as np
import pytest
import torch
from mmengine.data import PixelData
@ -27,66 +29,55 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
gt_sem_seg = PixelData(**gt_sem_seg_data)
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
def test_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('background', 'foreground'),
palette=[[120, 120, 120], [6, 230, 230]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir=tmp_dir)
seg_local_visualizer.dataset_meta = dict(
classes=('background', 'foreground'),
palette=[[120, 120, 120], [6, 230, 230]])
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
assert os.path.exists(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
os.remove(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
os.rmdir('temp_dir' + '/vis_data/vis_image')
# test gt_instances and pred_instances
pred_sem_seg_data = dict(
data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
# test gt_instances and pred_instances
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_pred=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
os.rmdir('temp_dir/vis_data')
os.rmdir('temp_dir')
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3))
def test_cityscapes_add_datasample(self):
h = 128
@ -110,78 +101,67 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=sem_seg)
gt_sem_seg = PixelData(**gt_sem_seg_data)
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
def test_cityscapes_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
'pole', 'traffic light', 'traffic sign',
'vegetation', 'terrain', 'sky', 'person', 'rider',
'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle'),
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
[102, 102, 156], [190, 153, 153], [153, 153, 153],
[250, 170, 30], [220, 220, 0], [107, 142, 35],
[152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 60, 100], [0, 80, 100], [0, 0, 230],
[119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
seg_local_visualizer.dataset_meta = dict(
classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain',
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
'motorcycle', 'bicycle'),
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
[102, 102, 156], [190, 153, 153], [153, 153, 153],
[250, 170, 30], [220, 220, 0], [107, 142, 35],
[152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
[0, 80, 100], [0, 0, 230], [119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
assert os.path.exists(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
drawn_img = cv2.imread(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
assert drawn_img.shape == (h, w, 3)
# test gt_instances and pred_instances
pred_sem_seg_data = dict(
data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
os.remove(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
os.rmdir('temp_dir/vis_data/vis_image')
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
# test gt_instances and pred_instances
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_pred=False)
self._assert_image_and_shape(
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
(h, w, 3))
os.rmdir('temp_dir/vis_data')
os.rmdir('temp_dir')
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3))
def _assert_image_and_shape(self, out_file, out_shape):
assert os.path.exists(out_file)
drawn_img = cv2.imread(out_file)
assert drawn_img.shape == out_shape
os.remove(out_file)
os.rmdir('temp_dir/vis_data/vis_image')