EasyCV/tests/predictors/test_segmentation.py

108 lines
3.8 KiB
Python
Raw Normal View History

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle
import shutil
import tempfile
import unittest
import numpy as np
from PIL import Image
from tests.ut_config import (MODEL_CONFIG_SEGFORMER,
PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR)
from easycv.predictors.segmentation import SegmentationPredictor
class SegmentationPredictorTest(unittest.TestCase):
    """End-to-end tests for ``SegmentationPredictor`` with SegFormer weights.

    Every test runs the pretrained model/config referenced by
    ``tests.ut_config`` on the same test image and checks the output
    structure, batching behaviour and result dumping.
    """

    # Image (inside TEST_IMAGES_DIR) used by every test.
    IMG_NAME = '000000289059.jpg'
    # Reference predicted label values for that image: the first 10 pixels
    # of row 1 and the last 10 pixels of the last row of the mask.
    EXPECTED_ROW1_HEAD = [161] * 10
    EXPECTED_LAST_ROW_TAIL = [133] * 10

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _img_path(self):
        """Return the absolute path of the shared test image."""
        return os.path.join(TEST_IMAGES_DIR, self.IMG_NAME)

    def _load_img(self, img_path):
        """Load ``img_path`` as a numpy array and close the file afterwards."""
        # np.asarray copies the decoded pixels, so the array stays valid
        # after the PIL file handle is closed by the context manager.
        with Image.open(img_path) as pil_img:
            return np.asarray(pil_img)

    def _build_predictor(self, **kwargs):
        """Create a SegFormer ``SegmentationPredictor``.

        Extra keyword arguments (batch_size, save_results, ...) are
        forwarded unchanged to the predictor constructor.
        """
        return SegmentationPredictor(
            model_path=PRETRAINED_MODEL_SEGFORMER,
            config_file=MODEL_CONFIG_SEGFORMER,
            **kwargs)

    def _assert_seg_pred(self, seg_pred, img_shape):
        """Check one predicted mask against the reference image.

        Args:
            seg_pred: a single predicted mask (indexed as ``[row, col]``).
            img_shape: shape of the input image; only its first two
                dimensions (height, width) are compared.
        """
        self.assertListEqual(list(img_shape)[:2], list(seg_pred.shape))
        self.assertListEqual(seg_pred[1, :10].tolist(),
                             self.EXPECTED_ROW1_HEAD)
        self.assertListEqual(seg_pred[-1, -10:].tolist(),
                             self.EXPECTED_LAST_ROW_TAIL)

    def test_single(self):
        img_path = self._img_path()
        img = self._load_img(img_path)
        predict_pipeline = self._build_predictor()

        outputs = predict_pipeline(img_path, keep_inputs=True)

        self.assertEqual(len(outputs), 1)
        self.assertEqual(outputs[0]['inputs'], [img_path])
        self._assert_seg_pred(outputs[0]['results']['seg_pred'][0],
                              img.shape)

    def test_batch(self):
        img_path = self._img_path()
        img = self._load_img(img_path)
        batch_size = 2
        total_samples = 3
        predict_pipeline = self._build_predictor(batch_size=batch_size)

        outputs = predict_pipeline(
            [img_path] * total_samples, keep_inputs=True)

        # 3 samples with batch_size=2 -> one full batch plus one partial.
        self.assertEqual(len(outputs), 2)
        self.assertEqual(outputs[0]['inputs'], [img_path] * 2)
        self.assertEqual(outputs[1]['inputs'], [img_path] * 1)
        self.assertEqual(len(outputs[0]['results']['seg_pred']), 2)
        self.assertEqual(len(outputs[1]['results']['seg_pred']), 1)
        # Every sample is the same image, so every mask in every batch
        # must match the reference values (not just the first per batch).
        for output in outputs:
            for seg_pred in output['results']['seg_pred']:
                self._assert_seg_pred(seg_pred, img.shape)

    def test_dump(self):
        img_path = self._img_path()
        # mkdtemp creates the directory and keeps it until we remove it;
        # addCleanup guarantees removal even if an assertion below fails.
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        tmp_path = os.path.join(temp_dir, 'results.pkl')
        predict_pipeline = self._build_predictor(
            batch_size=2, save_results=True, save_path=tmp_path)
        total_samples = 3

        outputs = predict_pipeline(
            [img_path] * total_samples, keep_inputs=True)

        # With save_results=True the test expects an empty return value and
        # the results pickled to save_path instead.
        self.assertEqual(outputs, [])
        with open(tmp_path, 'rb') as f:
            # The pickle was produced by this test run, so loading is safe.
            results = pickle.load(f)
        self.assertIn('inputs', results[0])
        self.assertIn('results', results[0])
if __name__ == '__main__':
    # Executed directly (not imported by a runner): discover and run every
    # TestCase defined in this module.
    unittest.main()