# EasyCV/tests/predictors/test_segmentor.py
#
# (source listing metadata: 58 lines, 1.9 KiB, Python)

# Copyright (c) Alibaba, Inc. and its affiliates.
"""
isort:skip_file
"""
import os
import tempfile
import unittest
import cv2
import numpy as np
from PIL import Image
from easycv.predictors.detector import TorchYoloXPredictor, TorchViTDetPredictor
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
PRETRAINED_MODEL_YOLOXS_EXPORT_JIT,
PRETRAINED_MODEL_YOLOXS_END2END_JIT,
DET_DATA_SMALL_COCO_LOCAL)
from tests.ut_config import (PRETRAINED_MODEL_SEGFORMER,
MODEL_CONFIG_SEGFORMER)
from easycv.predictors.segmentation import (SegFormerPredictor)
from numpy.testing import assert_array_almost_equal
class SegmentorTest(unittest.TestCase):
    """Regression tests for the SegFormer semantic-segmentation predictor."""

    def setUp(self):
        # Print the test being run so it is easy to locate in long CI logs.
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_segformer_detector(self):
        """Run SegFormer on one COCO image and check the predicted mask.

        Verifies that:
          * the first prediction dict contains a ``'seg_pred'`` entry;
          * ``seg_pred[0]`` has the same height/width as the input image;
          * a few fixed pixel runs carry the class ids pinned for the
            pretrained checkpoint (regression values, not derived here).
        """
        segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
        segmentation_model_config = MODEL_CONFIG_SEGFORMER

        img_path = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
                                'val2017/000000289059.jpg')
        if not os.path.exists(img_path):
            # Fall back to the copy of the image kept under ./data/test.
            img_path = './data/test/segmentation/coco_stuff_164k/val2017/000000289059.jpg'

        # Use a context manager so the underlying file handle is closed;
        # np.asarray forces PIL to read the pixel data before the block exits.
        with Image.open(img_path) as pil_img:
            input_data_list = [np.asarray(pil_img)]

        predictor = SegFormerPredictor(
            model_path=segmentation_model_path,
            model_config=segmentation_model_config)
        output = predictor.predict(input_data_list)[0]

        self.assertIn('seg_pred', output)
        seg_pred = output['seg_pred'][0]

        # The mask's spatial size must match the input image's (H, W).
        self.assertListEqual(
            list(input_data_list[0].shape)[:2], list(seg_pred.shape))

        # Pinned class ids at the start of row 1 and the end of the last row.
        self.assertListEqual(seg_pred[1, :10].tolist(), [161] * 10)
        self.assertListEqual(seg_pred[-1, -10:].tolist(), [133] * 10)
if __name__ == '__main__':
    # Allow running this module directly (e.g. `python test_segmentor.py`)
    # in addition to being collected by a test runner.
    unittest.main()