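# File: mmsegmentation/tests/test_datasets/test_loading.py
#
# Repository-viewer metadata captured with this copy of the file:
# 164 lines, 5.3 KiB, Python
# (viewer links: Raw / Normal View / History)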

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import tempfile
import mmcv
import numpy as np
from mmcv.transforms import LoadImageFromFile
from mmseg.datasets.pipelines import LoadAnnotations
class TestLoading:
    """Tests for the ``LoadImageFromFile`` and ``LoadAnnotations`` transforms.

    Fixture images are read from the ``../data`` directory located relative
    to this test file (see :meth:`setup_class`).
    """

    @classmethod
    def setup_class(cls):
        """Store the fixture directory path on the class for every test."""
        cls.data_prefix = osp.join(osp.dirname(__file__), '../data')

    @staticmethod
    def _load_img_and_anns(results):
        """Run fresh image and annotation loaders over ``results``.

        New transform instances are created on every call and each one is
        given a deep copy of its input, so the dict passed in by the caller
        is never mutated.

        Args:
            results (dict): Pipeline input dict (``img_path``,
                ``seg_map_path``, ...).

        Returns:
            dict: The dict returned by ``LoadAnnotations``.
        """
        load_imgs = LoadImageFromFile()
        results = load_imgs(copy.deepcopy(results))
        load_anns = LoadAnnotations()
        return load_anns(copy.deepcopy(results))

    def test_load_img(self):
        """Check shape, dtype and repr of images loaded from file."""
        results = dict(img_path=osp.join(self.data_prefix, 'color.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['img_path'] == osp.join(self.data_prefix, 'color.jpg')
        assert results['img'].shape == (288, 512, 3)
        assert results['img'].dtype == np.uint8
        assert results['ori_shape'] == results['img'].shape[:2]
        assert repr(transform) == transform.__class__.__name__ + \
            "(to_float32=False, color_type='color'," + \
            " imdecode_backend='cv2', file_client_args={'backend': 'disk'})"

        # to_float32
        transform = LoadImageFromFile(to_float32=True)
        results = transform(copy.deepcopy(results))
        assert results['img'].dtype == np.float32

        # gray image: the default color type is expected to yield 3 channels
        results = dict(img_path=osp.join(self.data_prefix, 'gray.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['img'].shape == (288, 512, 3)
        assert results['img'].dtype == np.uint8

        # gray image: 'unchanged' is expected to yield a 2-D array
        transform = LoadImageFromFile(color_type='unchanged')
        results = transform(copy.deepcopy(results))
        assert results['img'].shape == (288, 512)
        assert results['img'].dtype == np.uint8

    def test_load_seg(self):
        """Check shape, dtype and repr of a segmentation map loaded from file."""
        seg_path = osp.join(self.data_prefix, 'seg.png')
        results = dict(
            seg_map_path=seg_path, reduce_zero_label=True, seg_fields=[])
        transform = LoadAnnotations()
        results = transform(copy.deepcopy(results))
        assert results['gt_seg_map'].shape == (288, 512)
        assert results['gt_seg_map'].dtype == np.uint8
        # NOTE(review): the expected repr below has unbalanced parentheses and
        # no separator between fields; presumably it mirrors what
        # LoadAnnotations.__repr__ emits at this revision -- confirm against
        # the transform before "fixing" the string.
        assert repr(transform) == transform.__class__.__name__ + \
            "(reduce_zero_label=True,imdecode_backend='pillow')" + \
            "file_client_args={'backend': 'disk'})"

        # reduce_zero_label passed explicitly to the constructor
        transform = LoadAnnotations(reduce_zero_label=True)
        results = transform(copy.deepcopy(results))
        assert results['gt_seg_map'].shape == (288, 512)
        assert results['gt_seg_map'].dtype == np.uint8

    def test_load_seg_custom_classes(self):
        """Check that ``label_map`` remaps the ids of a loaded seg map."""
        test_img = np.random.rand(10, 10)
        # Ground truth: background 0 plus four 2x2 patches labelled 1..4.
        test_gt = np.zeros_like(test_img)
        test_gt[2:4, 2:4] = 1
        test_gt[2:4, 6:8] = 2
        test_gt[6:8, 2:4] = 3
        test_gt[6:8, 6:8] = 4

        # The context manager removes the directory even when an assertion
        # below fails; a trailing cleanup() call would be skipped in that case.
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = osp.join(tmp_dir, 'img.jpg')
            gt_path = osp.join(tmp_dir, 'gt.png')
            mmcv.imwrite(test_img, img_path)
            mmcv.imwrite(test_gt, gt_path)

            # test only train with label with id 3
            results = dict(
                img_path=img_path,
                seg_map_path=gt_path,
                label_map={
                    0: 0,
                    1: 0,
                    2: 0,
                    3: 1,
                    4: 0
                },
                reduce_zero_label=False,
                seg_fields=[])
            results = self._load_img_and_anns(results)
            gt_array = results['gt_seg_map']

            true_mask = np.zeros_like(gt_array)
            true_mask[6:8, 2:4] = 1

            assert results['seg_fields'] == ['gt_seg_map']
            assert gt_array.shape == (10, 10)
            assert gt_array.dtype == np.uint8
            np.testing.assert_array_equal(gt_array, true_mask)

            # test only train with label with id 4 and 3
            # NOTE(review): this case loads color.jpg from the fixture
            # directory instead of the temporary image; the assertions below
            # only inspect the segmentation map.
            results = dict(
                img_path=osp.join(self.data_prefix, 'color.jpg'),
                seg_map_path=gt_path,
                label_map={
                    0: 0,
                    1: 0,
                    2: 0,
                    3: 2,
                    4: 1
                },
                reduce_zero_label=False,
                seg_fields=[])
            results = self._load_img_and_anns(results)
            gt_array = results['gt_seg_map']

            true_mask = np.zeros_like(gt_array)
            true_mask[6:8, 2:4] = 2
            true_mask[6:8, 6:8] = 1

            assert results['seg_fields'] == ['gt_seg_map']
            assert gt_array.shape == (10, 10)
            assert gt_array.dtype == np.uint8
            np.testing.assert_array_equal(gt_array, true_mask)

            # test no custom classes
            results = dict(
                img_path=img_path,
                seg_map_path=gt_path,
                reduce_zero_label=False,
                seg_fields=[])
            results = self._load_img_and_anns(results)
            gt_array = results['gt_seg_map']

            assert results['seg_fields'] == ['gt_seg_map']
            assert gt_array.shape == (10, 10)
            assert gt_array.dtype == np.uint8
            np.testing.assert_array_equal(gt_array, test_gt)