[Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error (#1845)
* [Fix] Fix SegLocalVisualizer gt_sem_seg cuda tensor error * fix ut error and add config visualizer dict * fix ut errorpull/1850/head
parent
5d9650838e
commit
7369d50049
|
@ -4,6 +4,9 @@ env_cfg = dict(
|
||||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||||
dist_cfg=dict(backend='nccl'),
|
dist_cfg=dict(backend='nccl'),
|
||||||
)
|
)
|
||||||
|
vis_backends = [dict(type='LocalVisBackend')]
|
||||||
|
visualizer = dict(
|
||||||
|
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||||
log_level = 'INFO'
|
log_level = 'INFO'
|
||||||
load_from = None
|
load_from = None
|
||||||
resume = False
|
resume = False
|
||||||
|
|
|
@ -81,7 +81,7 @@ class SegLocalVisualizer(Visualizer):
|
||||||
"""
|
"""
|
||||||
num_classes = len(classes)
|
num_classes = len(classes)
|
||||||
|
|
||||||
sem_seg = sem_seg.data
|
sem_seg = sem_seg.cpu().data
|
||||||
ids = np.unique(sem_seg)[::-1]
|
ids = np.unique(sem_seg)[::-1]
|
||||||
legal_indices = ids < num_classes
|
legal_indices = ids < num_classes
|
||||||
ids = ids[legal_indices]
|
ids = ids[legal_indices]
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
import tempfile
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from mmengine.data import PixelData
|
from mmengine.data import PixelData
|
||||||
|
|
||||||
|
@ -27,66 +29,55 @@ class TestSegLocalVisualizer(TestCase):
|
||||||
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
||||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||||
|
|
||||||
gt_seg_data_sample = SegDataSample()
|
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
|
||||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
def test_add_datasample_forward(gt_sem_seg):
|
||||||
|
gt_seg_data_sample = SegDataSample()
|
||||||
|
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||||
|
|
||||||
seg_local_visualizer = SegLocalVisualizer(
|
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
|
||||||
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
|
seg_local_visualizer = SegLocalVisualizer(
|
||||||
seg_local_visualizer.dataset_meta = dict(
|
vis_backends=[dict(type='LocalVisBackend')],
|
||||||
classes=('background', 'foreground'),
|
save_dir=tmp_dir)
|
||||||
palette=[[120, 120, 120], [6, 230, 230]])
|
seg_local_visualizer.dataset_meta = dict(
|
||||||
seg_local_visualizer.add_datasample(out_file, image,
|
classes=('background', 'foreground'),
|
||||||
gt_seg_data_sample)
|
palette=[[120, 120, 120], [6, 230, 230]])
|
||||||
|
|
||||||
# test out_file
|
# test out_file
|
||||||
seg_local_visualizer.add_datasample(out_file, image,
|
seg_local_visualizer.add_datasample(out_file, image,
|
||||||
gt_seg_data_sample)
|
gt_seg_data_sample)
|
||||||
|
|
||||||
assert os.path.exists(
|
assert os.path.exists(
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
drawn_img = cv2.imread(
|
out_file + '_0.png'))
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
drawn_img = cv2.imread(
|
||||||
assert drawn_img.shape == (h, w, 3)
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
|
out_file + '_0.png'))
|
||||||
|
assert drawn_img.shape == (h, w, 3)
|
||||||
|
|
||||||
os.remove(
|
# test gt_instances and pred_instances
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
pred_sem_seg_data = dict(
|
||||||
os.rmdir('temp_dir' + '/vis_data/vis_image')
|
data=torch.randint(0, num_class, (1, h, w)))
|
||||||
|
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||||
|
|
||||||
# test gt_instances and pred_instances
|
pred_seg_data_sample = SegDataSample()
|
||||||
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
|
||||||
|
|
||||||
pred_seg_data_sample = SegDataSample()
|
seg_local_visualizer.add_datasample(out_file, image,
|
||||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
gt_seg_data_sample,
|
||||||
|
pred_seg_data_sample)
|
||||||
|
self._assert_image_and_shape(
|
||||||
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
|
out_file + '_0.png'), (h, w * 2, 3))
|
||||||
|
|
||||||
seg_local_visualizer.add_datasample(out_file, image,
|
seg_local_visualizer.add_datasample(
|
||||||
gt_seg_data_sample,
|
out_file,
|
||||||
pred_seg_data_sample)
|
image,
|
||||||
self._assert_image_and_shape(
|
gt_seg_data_sample,
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
pred_seg_data_sample,
|
||||||
(h, w * 2, 3))
|
draw_gt=False)
|
||||||
|
self._assert_image_and_shape(
|
||||||
seg_local_visualizer.add_datasample(
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
out_file,
|
out_file + '_0.png'), (h, w, 3))
|
||||||
image,
|
|
||||||
gt_seg_data_sample,
|
|
||||||
pred_seg_data_sample,
|
|
||||||
draw_gt=False)
|
|
||||||
self._assert_image_and_shape(
|
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
|
||||||
(h, w, 3))
|
|
||||||
|
|
||||||
seg_local_visualizer.add_datasample(
|
|
||||||
out_file,
|
|
||||||
image,
|
|
||||||
gt_seg_data_sample,
|
|
||||||
pred_seg_data_sample,
|
|
||||||
draw_pred=False)
|
|
||||||
self._assert_image_and_shape(
|
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
|
||||||
(h, w, 3))
|
|
||||||
os.rmdir('temp_dir/vis_data')
|
|
||||||
os.rmdir('temp_dir')
|
|
||||||
|
|
||||||
def test_cityscapes_add_datasample(self):
|
def test_cityscapes_add_datasample(self):
|
||||||
h = 128
|
h = 128
|
||||||
|
@ -110,78 +101,67 @@ class TestSegLocalVisualizer(TestCase):
|
||||||
gt_sem_seg_data = dict(data=sem_seg)
|
gt_sem_seg_data = dict(data=sem_seg)
|
||||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||||
|
|
||||||
gt_seg_data_sample = SegDataSample()
|
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
|
||||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
def test_cityscapes_add_datasample_forward(gt_sem_seg):
|
||||||
|
gt_seg_data_sample = SegDataSample()
|
||||||
|
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||||
|
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
|
||||||
|
seg_local_visualizer = SegLocalVisualizer(
|
||||||
|
vis_backends=[dict(type='LocalVisBackend')],
|
||||||
|
save_dir='temp_dir')
|
||||||
|
seg_local_visualizer.dataset_meta = dict(
|
||||||
|
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
|
||||||
|
'pole', 'traffic light', 'traffic sign',
|
||||||
|
'vegetation', 'terrain', 'sky', 'person', 'rider',
|
||||||
|
'car', 'truck', 'bus', 'train', 'motorcycle',
|
||||||
|
'bicycle'),
|
||||||
|
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
|
||||||
|
[102, 102, 156], [190, 153, 153], [153, 153, 153],
|
||||||
|
[250, 170, 30], [220, 220, 0], [107, 142, 35],
|
||||||
|
[152, 251, 152], [70, 130, 180], [220, 20, 60],
|
||||||
|
[255, 0, 0], [0, 0, 142], [0, 0, 70],
|
||||||
|
[0, 60, 100], [0, 80, 100], [0, 0, 230],
|
||||||
|
[119, 11, 32]])
|
||||||
|
seg_local_visualizer.add_datasample(out_file, image,
|
||||||
|
gt_seg_data_sample)
|
||||||
|
|
||||||
seg_local_visualizer = SegLocalVisualizer(
|
# test out_file
|
||||||
vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
|
seg_local_visualizer.add_datasample(out_file, image,
|
||||||
seg_local_visualizer.dataset_meta = dict(
|
gt_seg_data_sample)
|
||||||
classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
assert os.path.exists(
|
||||||
'traffic light', 'traffic sign', 'vegetation', 'terrain',
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
|
out_file + '_0.png'))
|
||||||
'motorcycle', 'bicycle'),
|
drawn_img = cv2.imread(
|
||||||
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
[102, 102, 156], [190, 153, 153], [153, 153, 153],
|
out_file + '_0.png'))
|
||||||
[250, 170, 30], [220, 220, 0], [107, 142, 35],
|
assert drawn_img.shape == (h, w, 3)
|
||||||
[152, 251, 152], [70, 130, 180], [220, 20, 60],
|
|
||||||
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
|
|
||||||
[0, 80, 100], [0, 0, 230], [119, 11, 32]])
|
|
||||||
seg_local_visualizer.add_datasample(out_file, image,
|
|
||||||
gt_seg_data_sample)
|
|
||||||
|
|
||||||
# test out_file
|
# test gt_instances and pred_instances
|
||||||
seg_local_visualizer.add_datasample(out_file, image,
|
pred_sem_seg_data = dict(
|
||||||
gt_seg_data_sample)
|
data=torch.randint(0, num_class, (1, h, w)))
|
||||||
assert os.path.exists(
|
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
|
||||||
drawn_img = cv2.imread(
|
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
|
||||||
assert drawn_img.shape == (h, w, 3)
|
|
||||||
|
|
||||||
os.remove(
|
pred_seg_data_sample = SegDataSample()
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'))
|
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||||
os.rmdir('temp_dir/vis_data/vis_image')
|
|
||||||
|
|
||||||
# test gt_instances and pred_instances
|
seg_local_visualizer.add_datasample(out_file, image,
|
||||||
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
gt_seg_data_sample,
|
||||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
pred_seg_data_sample)
|
||||||
|
self._assert_image_and_shape(
|
||||||
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
|
out_file + '_0.png'), (h, w * 2, 3))
|
||||||
|
|
||||||
pred_seg_data_sample = SegDataSample()
|
seg_local_visualizer.add_datasample(
|
||||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
out_file,
|
||||||
|
image,
|
||||||
seg_local_visualizer.add_datasample(out_file, image,
|
gt_seg_data_sample,
|
||||||
gt_seg_data_sample,
|
pred_seg_data_sample,
|
||||||
pred_seg_data_sample)
|
draw_gt=False)
|
||||||
self._assert_image_and_shape(
|
self._assert_image_and_shape(
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
(h, w * 2, 3))
|
out_file + '_0.png'), (h, w, 3))
|
||||||
|
|
||||||
seg_local_visualizer.add_datasample(
|
|
||||||
out_file,
|
|
||||||
image,
|
|
||||||
gt_seg_data_sample,
|
|
||||||
pred_seg_data_sample,
|
|
||||||
draw_gt=False)
|
|
||||||
self._assert_image_and_shape(
|
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
|
||||||
(h, w, 3))
|
|
||||||
|
|
||||||
seg_local_visualizer.add_datasample(
|
|
||||||
out_file,
|
|
||||||
image,
|
|
||||||
gt_seg_data_sample,
|
|
||||||
pred_seg_data_sample,
|
|
||||||
draw_pred=False)
|
|
||||||
self._assert_image_and_shape(
|
|
||||||
osp.join('temp_dir' + '/vis_data/vis_image', out_file + '_0.png'),
|
|
||||||
(h, w, 3))
|
|
||||||
|
|
||||||
os.rmdir('temp_dir/vis_data')
|
|
||||||
os.rmdir('temp_dir')
|
|
||||||
|
|
||||||
def _assert_image_and_shape(self, out_file, out_shape):
|
def _assert_image_and_shape(self, out_file, out_shape):
|
||||||
assert os.path.exists(out_file)
|
assert os.path.exists(out_file)
|
||||||
drawn_img = cv2.imread(out_file)
|
drawn_img = cv2.imread(out_file)
|
||||||
assert drawn_img.shape == out_shape
|
assert drawn_img.shape == out_shape
|
||||||
os.remove(out_file)
|
|
||||||
os.rmdir('temp_dir/vis_data/vis_image')
|
|
||||||
|
|
Loading…
Reference in New Issue