diff --git a/tests/test_visualization/test_local_visualizer.py b/tests/test_visualization/test_local_visualizer.py index 4be96d3ba..7b94ec223 100644 --- a/tests/test_visualization/test_local_visualizer.py +++ b/tests/test_visualization/test_local_visualizer.py @@ -7,7 +7,6 @@ from unittest import TestCase import cv2 import mmcv import numpy as np -import pytest import torch from mmengine.data import PixelData @@ -29,14 +28,11 @@ class TestSegLocalVisualizer(TestCase): gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w))) gt_sem_seg = PixelData(**gt_sem_seg_data) - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='CUDA not available') - @pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda())) def test_add_datasample_forward(gt_sem_seg): gt_seg_data_sample = SegDataSample() gt_seg_data_sample.gt_sem_seg = gt_sem_seg - with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir: + with tempfile.TemporaryDirectory() as tmp_dir: seg_local_visualizer = SegLocalVisualizer( vis_backends=[dict(type='LocalVisBackend')], save_dir=tmp_dir) @@ -81,6 +77,10 @@ class TestSegLocalVisualizer(TestCase): osp.join(tmp_dir, 'vis_data', 'vis_image', out_file + '_0.png'), (h, w, 3)) + if torch.cuda.is_available(): + test_add_datasample_forward(gt_sem_seg.cuda()) + test_add_datasample_forward(gt_sem_seg) + def test_cityscapes_add_datasample(self): h = 128 w = 256 @@ -103,16 +103,14 @@ class TestSegLocalVisualizer(TestCase): gt_sem_seg_data = dict(data=sem_seg) gt_sem_seg = PixelData(**gt_sem_seg_data) - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='CUDA not available') - @pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda())) def test_cityscapes_add_datasample_forward(gt_sem_seg): gt_seg_data_sample = SegDataSample() gt_seg_data_sample.gt_sem_seg = gt_sem_seg - with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir: + + with tempfile.TemporaryDirectory() as tmp_dir: seg_local_visualizer = SegLocalVisualizer( vis_backends=[dict(type='LocalVisBackend')], - save_dir='temp_dir') + save_dir=tmp_dir) seg_local_visualizer.dataset_meta = dict( classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', @@ -165,6 +163,10 @@ class TestSegLocalVisualizer(TestCase): osp.join(tmp_dir, 'vis_data', 'vis_image', out_file + '_0.png'), (h, w, 3)) + if torch.cuda.is_available(): + test_cityscapes_add_datasample_forward(gt_sem_seg.cuda()) + test_cityscapes_add_datasample_forward(gt_sem_seg) + def _assert_image_and_shape(self, out_file, out_shape): assert os.path.exists(out_file) drawn_img = cv2.imread(out_file)