# Mirror of https://github.com/alibaba/EasyCV.git, synced 2025-06-03 14:49:00 +08:00.
# Original file: 49 lines, 1.5 KiB, Python.
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
"""
|
|
isort:skip_file
|
|
"""
|
|
import os
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from tests.ut_config import TEST_IMAGES_DIR
|
|
from tests.ut_config import (PRETRAINED_MODEL_SEGFORMER,
|
|
MODEL_CONFIG_SEGFORMER)
|
|
from easycv.predictors.segmentation import SegFormerPredictor
|
|
|
|
|
|
class SegmentorTest(unittest.TestCase):
    """Smoke/regression test for ``SegFormerPredictor`` on one COCO-Stuff image."""

    def setUp(self):
        """Print the name of the test case that is about to run."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def test_segformer_detector(self):
        """Run SegFormer on a single image and check shape and sample labels.

        Asserts that the prediction dict contains ``'seg_pred'``, that the
        first predicted mask has the input image's height and width, and that
        two fixed pixel runs carry specific label ids (161 at the start of
        row 1, 133 at the end of the last row). NOTE(review): these label ids
        are presumably regression values for the checkpoint referenced by
        ``PRETRAINED_MODEL_SEGFORMER`` — confirm if the checkpoint changes.
        """
        segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
        segmentation_model_config = MODEL_CONFIG_SEGFORMER

        img = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
        if not os.path.exists(img):
            # Fall back to the repository-relative copy of the test image.
            img = './data/test/segmentation/coco_stuff_164k/val2017/000000289059.jpg'

        # Close the underlying file handle once the pixels are copied out;
        # np.asarray materializes the image data, so the array remains valid
        # after the context manager closes the file.
        with Image.open(img) as pil_img:
            input_data_list = [np.asarray(pil_img)]

        predictor = SegFormerPredictor(
            model_path=segmentation_model_path,
            model_config=segmentation_model_config)

        output = predictor.predict(input_data_list)[0]
        self.assertIn('seg_pred', output)

        seg_pred = output['seg_pred'][0]
        # The predicted mask must match the input image's (height, width).
        self.assertListEqual(
            list(input_data_list[0].shape)[:2], list(seg_pred.shape))
        self.assertListEqual(seg_pred[1, :10].tolist(), [161] * 10)
        self.assertListEqual(seg_pred[-1, -10:].tolist(), [133] * 10)
|
|
|
|
|
|
if __name__ == '__main__':
    # Executing this file directly discovers and runs the TestCase classes
    # defined in it (``'__main__'`` is unittest.main's default module).
    unittest.main(module='__main__')
|