[Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error (#1845)
* [Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error * fix ut error and add config visualizer dict * fix ut errorpull/1850/head
parent
5d9650838e
commit
7369d50049
|
@ -4,6 +4,9 @@ env_cfg = dict(
|
|||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
)
|
||||
vis_backends = [dict(type='LocalVisBackend')]
|
||||
visualizer = dict(
|
||||
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||
log_level = 'INFO'
|
||||
load_from = None
|
||||
resume = False
|
||||
|
|
|
@ -81,7 +81,7 @@ class SegLocalVisualizer(Visualizer):
|
|||
"""
|
||||
num_classes = len(classes)
|
||||
|
||||
sem_seg = sem_seg.data
|
||||
sem_seg = sem_seg.cpu().data
|
||||
ids = np.unique(sem_seg)[::-1]
|
||||
legal_indices = ids < num_classes
|
||||
ids = ids[legal_indices]
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.data import PixelData
|
||||
|
||||
|
@ -27,66 +29,55 @@ class TestSegLocalVisualizer(TestCase):
|
|||
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
gt_seg_data_sample = SegDataSample()
|
||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
|
||||
def test_add_datasample_forward(gt_sem_seg):
|
||||
gt_seg_data_sample = SegDataSample()
|
||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||
|
||||
seg_local_visualizer = SegLocalVisualizer(
|
||||
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
|
||||
seg_local_visualizer.dataset_meta = dict(
|
||||
classes=('background', 'foreground'),
|
||||
palette=[[120, 120, 120], [6, 230, 230]])
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
|
||||
seg_local_visualizer = SegLocalVisualizer(
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
save_dir=tmp_dir)
|
||||
seg_local_visualizer.dataset_meta = dict(
|
||||
classes=('background', 'foreground'),
|
||||
palette=[[120, 120, 120], [6, 230, 230]])
|
||||
|
||||
# test out_file
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
# test out_file
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
|
||||
assert os.path.exists(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
||||
drawn_img = cv2.imread(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
||||
assert drawn_img.shape == (h, w, 3)
|
||||
assert os.path.exists(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'))
|
||||
drawn_img = cv2.imread(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'))
|
||||
assert drawn_img.shape == (h, w, 3)
|
||||
|
||||
os.remove(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
||||
os.rmdir('temp_dir' + '/vis_data/vis_image')
|
||||
# test gt_instances and pred_instances
|
||||
pred_sem_seg_data = dict(
|
||||
data=torch.randint(0, num_class, (1, h, w)))
|
||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||
|
||||
# test gt_instances and pred_instances
|
||||
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||
pred_seg_data_sample = SegDataSample()
|
||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||
|
||||
pred_seg_data_sample = SegDataSample()
|
||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample)
|
||||
self._assert_image_and_shape(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w * 2, 3))
|
||||
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample)
|
||||
self._assert_image_and_shape(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
||||
(h, w * 2, 3))
|
||||
|
||||
seg_local_visualizer.add_datasample(
|
||||
out_file,
|
||||
image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample,
|
||||
draw_gt=False)
|
||||
self._assert_image_and_shape(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
||||
(h, w, 3))
|
||||
|
||||
seg_local_visualizer.add_datasample(
|
||||
out_file,
|
||||
image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample,
|
||||
draw_pred=False)
|
||||
self._assert_image_and_shape(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
||||
(h, w, 3))
|
||||
os.rmdir('temp_dir/vis_data')
|
||||
os.rmdir('temp_dir')
|
||||
seg_local_visualizer.add_datasample(
|
||||
out_file,
|
||||
image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample,
|
||||
draw_gt=False)
|
||||
self._assert_image_and_shape(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w, 3))
|
||||
|
||||
def test_cityscapes_add_datasample(self):
|
||||
h = 128
|
||||
|
@ -110,78 +101,67 @@ class TestSegLocalVisualizer(TestCase):
|
|||
gt_sem_seg_data = dict(data=sem_seg)
|
||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
gt_seg_data_sample = SegDataSample()
|
||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
|
||||
def test_cityscapes_add_datasample_forward(gt_sem_seg):
|
||||
gt_seg_data_sample = SegDataSample()
|
||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
|
||||
seg_local_visualizer = SegLocalVisualizer(
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
save_dir='temp_dir')
|
||||
seg_local_visualizer.dataset_meta = dict(
|
||||
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
|
||||
'pole', 'traffic light', 'traffic sign',
|
||||
'vegetation', 'terrain', 'sky', 'person', 'rider',
|
||||
'car', 'truck', 'bus', 'train', 'motorcycle',
|
||||
'bicycle'),
|
||||
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
|
||||
[102, 102, 156], [190, 153, 153], [153, 153, 153],
|
||||
[250, 170, 30], [220, 220, 0], [107, 142, 35],
|
||||
[152, 251, 152], [70, 130, 180], [220, 20, 60],
|
||||
[255, 0, 0], [0, 0, 142], [0, 0, 70],
|
||||
[0, 60, 100], [0, 80, 100], [0, 0, 230],
|
||||
[119, 11, 32]])
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
|
||||
seg_local_visualizer = SegLocalVisualizer(
|
||||
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
|
||||
seg_local_visualizer.dataset_meta = dict(
|
||||
classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
||||
'traffic light', 'traffic sign', 'vegetation', 'terrain',
|
||||
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
|
||||
'motorcycle', 'bicycle'),
|
||||
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
|
||||
[102, 102, 156], [190, 153, 153], [153, 153, 153],
|
||||
[250, 170, 30], [220, 220, 0], [107, 142, 35],
|
||||
[152, 251, 152], [70, 130, 180], [220, 20, 60],
|
||||
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
|
||||
[0, 80, 100], [0, 0, 230], [119, 11, 32]])
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
# test out_file
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
assert os.path.exists(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'))
|
||||
drawn_img = cv2.imread(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'))
|
||||
assert drawn_img.shape == (h, w, 3)
|
||||
|
||||
# test out_file
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
assert os.path.exists(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
||||
drawn_img = cv2.imread(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
||||
assert drawn_img.shape == (h, w, 3)
|
||||
# test gt_instances and pred_instances
|
||||
pred_sem_seg_data = dict(
|
||||
data=torch.randint(0, num_class, (1, h, w)))
|
||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||
|
||||
os.remove(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
||||
os.rmdir('temp_dir/vis_data/vis_image')
|
||||
pred_seg_data_sample = SegDataSample()
|
||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||
|
||||
# test gt_instances and pred_instances
|
||||
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample)
|
||||
self._assert_image_and_shape(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w * 2, 3))
|
||||
|
||||
pred_seg_data_sample = SegDataSample()
|
||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample)
|
||||
self._assert_image_and_shape(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
||||
(h, w * 2, 3))
|
||||
|
||||
seg_local_visualizer.add_datasample(
|
||||
out_file,
|
||||
image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample,
|
||||
draw_gt=False)
|
||||
self._assert_image_and_shape(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
||||
(h, w, 3))
|
||||
|
||||
seg_local_visualizer.add_datasample(
|
||||
out_file,
|
||||
image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample,
|
||||
draw_pred=False)
|
||||
self._assert_image_and_shape(
|
||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
||||
(h, w, 3))
|
||||
|
||||
os.rmdir('temp_dir/vis_data')
|
||||
os.rmdir('temp_dir')
|
||||
seg_local_visualizer.add_datasample(
|
||||
out_file,
|
||||
image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample,
|
||||
draw_gt=False)
|
||||
self._assert_image_and_shape(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w, 3))
|
||||
|
||||
def _assert_image_and_shape(self, out_file, out_shape):
|
||||
assert os.path.exists(out_file)
|
||||
drawn_img = cv2.imread(out_file)
|
||||
assert drawn_img.shape == out_shape
|
||||
os.remove(out_file)
|
||||
os.rmdir('temp_dir/vis_data/vis_image')
|
||||
|
|
Loading…
Reference in New Issue