mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix LoacalVisualizer UT (#1851)
This commit is contained in:
parent
1479d0a87b
commit
85ef7b905a
@ -7,7 +7,6 @@ from unittest import TestCase
|
|||||||
import cv2
|
import cv2
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
from mmengine.data import PixelData
|
from mmengine.data import PixelData
|
||||||
|
|
||||||
@ -29,14 +28,11 @@ class TestSegLocalVisualizer(TestCase):
|
|||||||
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
gt_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
||||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not torch.cuda.is_available(), reason='CUDA not available')
|
|
||||||
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
|
|
||||||
def test_add_datasample_forward(gt_sem_seg):
|
def test_add_datasample_forward(gt_sem_seg):
|
||||||
gt_seg_data_sample = SegDataSample()
|
gt_seg_data_sample = SegDataSample()
|
||||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
seg_local_visualizer = SegLocalVisualizer(
|
seg_local_visualizer = SegLocalVisualizer(
|
||||||
vis_backends=[dict(type='LocalVisBackend')],
|
vis_backends=[dict(type='LocalVisBackend')],
|
||||||
save_dir=tmp_dir)
|
save_dir=tmp_dir)
|
||||||
@ -81,6 +77,10 @@ class TestSegLocalVisualizer(TestCase):
|
|||||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
out_file + '_0.png'), (h, w, 3))
|
out_file + '_0.png'), (h, w, 3))
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
test_add_datasample_forward(gt_sem_seg.cuda())
|
||||||
|
test_add_datasample_forward(gt_sem_seg)
|
||||||
|
|
||||||
def test_cityscapes_add_datasample(self):
|
def test_cityscapes_add_datasample(self):
|
||||||
h = 128
|
h = 128
|
||||||
w = 256
|
w = 256
|
||||||
@ -103,16 +103,14 @@ class TestSegLocalVisualizer(TestCase):
|
|||||||
gt_sem_seg_data = dict(data=sem_seg)
|
gt_sem_seg_data = dict(data=sem_seg)
|
||||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not torch.cuda.is_available(), reason='CUDA not available')
|
|
||||||
@pytest.mark.parametrize('gt_sem_seg', (gt_sem_seg, gt_sem_seg.cuda()))
|
|
||||||
def test_cityscapes_add_datasample_forward(gt_sem_seg):
|
def test_cityscapes_add_datasample_forward(gt_sem_seg):
|
||||||
gt_seg_data_sample = SegDataSample()
|
gt_seg_data_sample = SegDataSample()
|
||||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||||
with tempfile.TemporaryDirectory(dir='temp_dir') as tmp_dir:
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
seg_local_visualizer = SegLocalVisualizer(
|
seg_local_visualizer = SegLocalVisualizer(
|
||||||
vis_backends=[dict(type='LocalVisBackend')],
|
vis_backends=[dict(type='LocalVisBackend')],
|
||||||
save_dir='temp_dir')
|
save_dir=tmp_dir)
|
||||||
seg_local_visualizer.dataset_meta = dict(
|
seg_local_visualizer.dataset_meta = dict(
|
||||||
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
|
classes=('road', 'sidewalk', 'building', 'wall', 'fence',
|
||||||
'pole', 'traffic light', 'traffic sign',
|
'pole', 'traffic light', 'traffic sign',
|
||||||
@ -165,6 +163,10 @@ class TestSegLocalVisualizer(TestCase):
|
|||||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||||
out_file + '_0.png'), (h, w, 3))
|
out_file + '_0.png'), (h, w, 3))
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
test_cityscapes_add_datasample_forward(gt_sem_seg.cuda())
|
||||||
|
test_cityscapes_add_datasample_forward(gt_sem_seg)
|
||||||
|
|
||||||
def _assert_image_and_shape(self, out_file, out_shape):
|
def _assert_image_and_shape(self, out_file, out_shape):
|
||||||
assert os.path.exists(out_file)
|
assert os.path.exists(out_file)
|
||||||
drawn_img = cv2.imread(out_file)
|
drawn_img = cv2.imread(out_file)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user