EasyCV/tests/datasets/segmentation/test_seg_raw_dataset.py

63 lines
2.2 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from tests.ut_config import (IMG_NORM_CFG_255, SEG_DATA_SMALL_RAW_LOCAL,
VOC_CLASSES)
from easycv.core.evaluation.builder import build_evaluator
from easycv.datasets.builder import build_datasource
from easycv.datasets.segmentation.raw import SegDataset
class SegDatasetTest(unittest.TestCase):
    """Checks SegDataset evaluation on the small raw segmentation fixture."""

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def test_default(self):
        """Feed ground-truth maps back as predictions; every metric is 1.0."""
        root_dir = SEG_DATA_SMALL_RAW_LOCAL
        # A single loader process keeps the sample order of the dataset's
        # source identical to the standalone source built below, so the
        # "predictions" line up one-to-one with the ground truth.
        source_cfg = dict(
            type='SegSourceRaw',
            img_root=os.path.join(root_dir, 'images'),
            label_root=os.path.join(root_dir, 'labels'),
            classes=VOC_CLASSES,
            num_processes=1,
        )
        crop = (512, 512)
        transforms = [
            dict(type='SegRandomCrop', crop_size=crop, cat_max_ratio=0.75),
            dict(type='MMNormalize', **IMG_NORM_CFG_255),
            dict(type='MMPad', size=crop),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg']),
        ]
        dataset = SegDataset(source_cfg, transforms)

        # Take the ground-truth maps straight from a data source (no pipeline
        # applied) and use them as perfect predictions.
        gt_source = build_datasource(source_cfg)
        perfect_preds = [
            gt_source[idx]['gt_semantic_seg']
            for idx in range(len(gt_source))
        ]

        seg_evaluator = build_evaluator(
            dict(
                type='SegmentationEvaluator',
                classes=VOC_CLASSES,
                metric_names=['mIoU'],
            ))
        metrics = dataset.evaluate(
            {'seg_pred': perfect_preds}, evaluators=seg_evaluator)

        # Checked in this order; the first mismatch fails the test.
        for key in ('aAcc', 'mIoU', 'mAcc', 'IoU.aeroplane', 'IoU.tvmonitor',
                    'Acc.tvmonitor'):
            self.assertEqual(metrics[key], 1.0)
if __name__ == '__main__':
    # Allow running this file directly; discovers the TestCase classes
    # defined in this module (unittest.main's default module is '__main__').
    unittest.main(module='__main__')