[Fix] Fix LoacalVisualizer UT (#1851)

This commit is contained in:
Miao Zheng 2022-08-01 19:53:13 +08:00 committed by GitHub
parent 1479d0a87b
commit 85ef7b905a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,7 +7,6 @@ from unittest import TestCase
import cv2 import cv2
import mmcv import mmcv
import numpy as np import numpy as np
import pytest
import torch import torch
from mmengine.data import PixelData from mmengine.data import PixelData
@ -29,14 +28,11 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
gt_sem_seg = PixelData(**gt_sem_seg_data) gt_sem_seg = PixelData(**gt_sem_seg_data)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
def test_add_datasample_forward(gt_sem_seg): def test_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample() gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg gt_seg_data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
seg_local_visualizer = SegLocalVisualizer( seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')], vis_backends=[dict(type='LocalVisBackend')],
save_dir=tmp_dir) save_dir=tmp_dir)
@ -81,6 +77,10 @@ class TestSegLocalVisualizer(TestCase):
osp.join(tmp_dir, 'vis_data', 'vis_image', osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3)) out_file + '_0.png'), (h, w, 3))
if torch.cuda.is_available():
test_add_datasample_forward(gt_sem_seg.cuda())
test_add_datasample_forward(gt_sem_seg)
def test_cityscapes_add_datasample(self): def test_cityscapes_add_datasample(self):
h = 128 h = 128
w = 256 w = 256
@ -103,16 +103,14 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=sem_seg) gt_sem_seg_data = dict(data=sem_seg)
gt_sem_seg = PixelData(**gt_sem_seg_data) gt_sem_seg = PixelData(**gt_sem_seg_data)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
def test_cityscapes_add_datasample_forward(gt_sem_seg): def test_cityscapes_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample() gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg gt_seg_data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir:
seg_local_visualizer = SegLocalVisualizer( seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')], vis_backends=[dict(type='LocalVisBackend')],
save_dir='temp_dir') save_dir=tmp_dir)
seg_local_visualizer.dataset_meta = dict( seg_local_visualizer.dataset_meta = dict(
classes=('road', 'sidewalk', 'building', 'wall', 'fence', classes=('road', 'sidewalk', 'building', 'wall', 'fence',
'pole', 'traffic light', 'traffic sign', 'pole', 'traffic light', 'traffic sign',
@ -165,6 +163,10 @@ class TestSegLocalVisualizer(TestCase):
osp.join(tmp_dir, 'vis_data', 'vis_image', osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3)) out_file + '_0.png'), (h, w, 3))
if torch.cuda.is_available():
test_cityscapes_add_datasample_forward(gt_sem_seg.cuda())
test_cityscapes_add_datasample_forward(gt_sem_seg)
def _assert_image_and_shape(self, out_file, out_shape): def _assert_image_and_shape(self, out_file, out_shape):
assert os.path.exists(out_file) assert os.path.exists(out_file)
drawn_img = cv2.imread(out_file) drawn_img = cv2.imread(out_file)