[Fix] Fix LoacalVisualizer UT (#1851)

This commit is contained in:
Miao Zheng 2022-08-01 19:53:13 +08:00 committed by GitHub
parent 1479d0a87b
commit 85ef7b905a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,7 +7,6 @@ from unittest import TestCase
import cv2
import mmcv
import numpy as np
import pytest
import torch
from mmengine.data import PixelData
@ -29,14 +28,11 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
gt_sem_seg = PixelData(**gt_sem_seg_data)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
def test_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir=tmp_dir)
@ -81,6 +77,10 @@ class TestSegLocalVisualizer(TestCase):
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3))
if torch.cuda.is_available():
test_add_datasample_forward(gt_sem_seg.cuda())
test_add_datasample_forward(gt_sem_seg)
def test_cityscapes_add_datasample(self):
h = 128
w = 256
@ -103,16 +103,14 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg_data = dict(data=sem_seg)
gt_sem_seg = PixelData(**gt_sem_seg_data)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
def test_cityscapes_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
vis_backends=[dict(type='LocalVisBackend')],
save_dir='temp_dir')
save_dir=tmp_dir)
seg_local_visualizer.dataset_meta = dict(
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
'pole', 'traffic light', 'traffic sign',
@ -165,6 +163,10 @@ class TestSegLocalVisualizer(TestCase):
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3))
if torch.cuda.is_available():
test_cityscapes_add_datasample_forward(gt_sem_seg.cuda())
test_cityscapes_add_datasample_forward(gt_sem_seg)
def _assert_image_and_shape(self, out_file, out_shape):
assert os.path.exists(out_file)
drawn_img = cv2.imread(out_file)