mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
108 lines
3.8 KiB
Python
108 lines
3.8 KiB
Python
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
import os
|
||
|
import pickle
|
||
|
import shutil
|
||
|
import tempfile
|
||
|
import unittest
|
||
|
|
||
|
import numpy as np
|
||
|
from PIL import Image
|
||
|
from tests.ut_config import (MODEL_CONFIG_SEGFORMER,
|
||
|
PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR)
|
||
|
|
||
|
from easycv.predictors.segmentation import SegmentationPredictor
|
||
|
|
||
|
|
||
|
class SegmentationPredictorTest(unittest.TestCase):
    """Exercise ``SegmentationPredictor`` with the SegFormer test model.

    Every test feeds the same sample image from ``TEST_IMAGES_DIR`` to a
    predictor built from ``PRETRAINED_MODEL_SEGFORMER`` and
    ``MODEL_CONFIG_SEGFORMER``, then checks the structure of what comes back
    (single input, batched input, and results dumped to a pickle file).
    """

    # File name of the sample image (inside TEST_IMAGES_DIR) used by all tests.
    _SAMPLE_IMAGE_NAME = '000000289059.jpg'

    def setUp(self):
        """Print the name of the test that is about to run."""
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def _sample_image_path(self):
        """Return the full path of the sample image."""
        return os.path.join(TEST_IMAGES_DIR, self._SAMPLE_IMAGE_NAME)

    @staticmethod
    def _load_image(img_path):
        """Load ``img_path`` as a numpy array, closing the file afterwards."""
        with Image.open(img_path) as pil_img:
            return np.asarray(pil_img)

    def _assert_seg_pred(self, results, img):
        """Check the first prediction in ``results`` against ``img``.

        Args:
            results: A ``'results'`` entry of the predictor output; its
                ``'seg_pred'`` item is indexed as a sequence of 2-D arrays.
            img: The input image as an array; only its first two dimensions
                are compared with the prediction's shape.
        """
        seg_pred = results['seg_pred'][0]
        # The prediction must have the same height/width as the input image.
        self.assertListEqual(list(img.shape)[:2], list(seg_pred.shape))
        # Reference values for this image/model pair: first ten entries of
        # row 1 and last ten entries of the last row.
        self.assertListEqual(seg_pred[1, :10].tolist(), [161] * 10)
        self.assertListEqual(seg_pred[-1, -10:].tolist(), [133] * 10)

    def test_single(self):
        """A single image path yields one output holding that input."""
        img_path = self._sample_image_path()
        img = self._load_image(img_path)

        predict_pipeline = SegmentationPredictor(
            model_path=PRETRAINED_MODEL_SEGFORMER,
            config_file=MODEL_CONFIG_SEGFORMER)

        outputs = predict_pipeline(img_path, keep_inputs=True)
        self.assertEqual(len(outputs), 1)
        self.assertEqual(outputs[0]['inputs'], [img_path])

        self._assert_seg_pred(outputs[0]['results'], img)

    def test_batch(self):
        """Three inputs with ``batch_size=2`` come back as batches of 2 and 1."""
        img_path = self._sample_image_path()
        img = self._load_image(img_path)

        predict_pipeline = SegmentationPredictor(
            model_path=PRETRAINED_MODEL_SEGFORMER,
            config_file=MODEL_CONFIG_SEGFORMER,
            batch_size=2)

        total_samples = 3
        outputs = predict_pipeline(
            [img_path] * total_samples, keep_inputs=True)
        self.assertEqual(len(outputs), 2)

        self.assertEqual(outputs[0]['inputs'], [img_path] * 2)
        self.assertEqual(outputs[1]['inputs'], [img_path] * 1)
        self.assertEqual(len(outputs[0]['results']['seg_pred']), 2)
        self.assertEqual(len(outputs[1]['results']['seg_pred']), 1)

        for output in outputs:
            self._assert_seg_pred(output['results'], img)

    def test_dump(self):
        """With ``save_results=True`` the results go to a pickle file."""
        img_path = self._sample_image_path()

        # mkdtemp() returns a directory that really exists and stays alive
        # until we remove it. The cleanup is registered up front so the
        # directory is deleted even when an assertion below fails.
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        tmp_path = os.path.join(temp_dir, 'results.pkl')

        predict_pipeline = SegmentationPredictor(
            model_path=PRETRAINED_MODEL_SEGFORMER,
            config_file=MODEL_CONFIG_SEGFORMER,
            batch_size=2,
            save_results=True,
            save_path=tmp_path)

        total_samples = 3
        outputs = predict_pipeline(
            [img_path] * total_samples, keep_inputs=True)
        # Nothing is returned in memory when results are saved to disk.
        self.assertEqual(outputs, [])

        # The pickle was written by the predictor within this test, so
        # loading it here does not involve untrusted data.
        with open(tmp_path, 'rb') as f:
            results = pickle.load(f)

        self.assertIn('inputs', results[0])
        self.assertIn('results', results[0])
|
||
|
|
||
|
|
||
|
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|