mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix LoacalVisualizer UT (#1851)
This commit is contained in:
parent
1479d0a87b
commit
85ef7b905a
@ -7,7 +7,6 @@ from unittest import TestCase
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.data import PixelData
|
||||
|
||||
@ -29,14 +28,11 @@ class TestSegLocalVisualizer(TestCase):
|
||||
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA not available')
|
||||
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
|
||||
def test_add_datasample_forward(gt_sem_seg):
|
||||
gt_seg_data_sample = SegDataSample()
|
||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||
|
||||
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
seg_local_visualizer = SegLocalVisualizer(
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
save_dir=tmp_dir)
|
||||
@ -81,6 +77,10 @@ class TestSegLocalVisualizer(TestCase):
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w, 3))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
test_add_datasample_forward(gt_sem_seg.cuda())
|
||||
test_add_datasample_forward(gt_sem_seg)
|
||||
|
||||
def test_cityscapes_add_datasample(self):
|
||||
h = 128
|
||||
w = 256
|
||||
@ -103,16 +103,14 @@ class TestSegLocalVisualizer(TestCase):
|
||||
gt_sem_seg_data = dict(data=sem_seg)
|
||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA not available')
|
||||
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
|
||||
def test_cityscapes_add_datasample_forward(gt_sem_seg):
|
||||
gt_seg_data_sample = SegDataSample()
|
||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
seg_local_visualizer = SegLocalVisualizer(
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
save_dir='temp_dir')
|
||||
save_dir=tmp_dir)
|
||||
seg_local_visualizer.dataset_meta = dict(
|
||||
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
|
||||
'pole', 'traffic light', 'traffic sign',
|
||||
@ -165,6 +163,10 @@ class TestSegLocalVisualizer(TestCase):
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w, 3))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
test_cityscapes_add_datasample_forward(gt_sem_seg.cuda())
|
||||
test_cityscapes_add_datasample_forward(gt_sem_seg)
|
||||
|
||||
def _assert_image_and_shape(self, out_file, out_shape):
|
||||
assert os.path.exists(out_file)
|
||||
drawn_img = cv2.imread(out_file)
|
||||
|
Loading…
x
Reference in New Issue
Block a user