mirror of https://github.com/alibaba/EasyCV.git
81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import copy
|
|
import glob
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
|
|
import torch
|
|
from mmcv import Config
|
|
from tests.ut_config import (PRETRAINED_MODEL_SEGFORMER,
|
|
PRETRAINED_MODEL_YOLOXS_EXPORT, TEST_IMAGES_DIR)
|
|
|
|
from easycv.file import io
|
|
from easycv.utils.test_util import run_in_subprocess
|
|
|
|
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
class PredictTest(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
|
|
|
def tearDown(self):
|
|
super().tearDown()
|
|
|
|
def _base_predict(self, model_type, model_path, dist=False):
|
|
input_file = tempfile.NamedTemporaryFile('w').name
|
|
input_line_num = 10
|
|
with open(input_file, 'w') as ofile:
|
|
for _ in range(input_line_num):
|
|
ofile.write(
|
|
os.path.join(TEST_IMAGES_DIR, '000000289059.jpg') + '\n')
|
|
output_file = tempfile.NamedTemporaryFile('w').name
|
|
|
|
if dist:
|
|
cmd = f'PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 --master_port=29527 \
|
|
tools/predict.py \
|
|
--input_file {input_file} \
|
|
--output_file {output_file} \
|
|
--model_type {model_type} \
|
|
--model_path {model_path} \
|
|
--launcher pytorch'
|
|
|
|
else:
|
|
cmd = f'PYTHONPATH=. python tools/predict.py \
|
|
--input_file {input_file} \
|
|
--output_file {output_file} \
|
|
--model_type {model_type} \
|
|
--model_path {model_path} '
|
|
|
|
logging.info('run command: %s' % cmd)
|
|
run_in_subprocess(cmd)
|
|
|
|
with open(output_file, 'r') as infile:
|
|
output_line_num = len(infile.readlines())
|
|
self.assertEqual(input_line_num, output_line_num)
|
|
|
|
io.remove(input_file)
|
|
io.remove(output_file)
|
|
|
|
def test_predict(self):
|
|
model_type = 'YoloXPredictor'
|
|
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
|
self._base_predict(model_type, model_path)
|
|
|
|
@unittest.skipIf(torch.cuda.device_count() <= 1, 'distributed unittest')
|
|
def test_predict_dist(self):
|
|
model_type = 'YoloXPredictor'
|
|
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
|
self._base_predict(model_type, model_path, dist=True)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|