diff --git a/easycv/datasets/shared/pipelines/transforms.py b/easycv/datasets/shared/pipelines/transforms.py index 31e4a966..d9803fa1 100644 --- a/easycv/datasets/shared/pipelines/transforms.py +++ b/easycv/datasets/shared/pipelines/transforms.py @@ -89,6 +89,7 @@ class LoadImage: results['img'] = img results['img_shape'] = img.shape results['ori_shape'] = img.shape + results['ori_img_shape'] = img.shape results['img_fields'] = ['img'] return results diff --git a/easycv/predictors/base.py b/easycv/predictors/base.py index 5b36f2fd..ef2be922 100644 --- a/easycv/predictors/base.py +++ b/easycv/predictors/base.py @@ -274,12 +274,12 @@ class PredictorV2(object): else: out_i[k] = None - out_i = self.postprocess_single(out_i) + out_i = self.postprocess_single(out_i, *args, **kwargs) outputs.append(out_i) return outputs - def postprocess_single(self, inputs): + def postprocess_single(self, inputs, *args, **kwargs): """Process outputs of single sample. If you need add some processing ops, you need to reimplement it. """ diff --git a/easycv/predictors/detector.py b/easycv/predictors/detector.py index 38fd262f..070d3268 100644 --- a/easycv/predictors/detector.py +++ b/easycv/predictors/detector.py @@ -125,113 +125,133 @@ class DetrPredictor(DetectionPredictor): """""" +class _JitProcessorWrapper: + + def __init__(self, processor, device) -> None: + self.processor = processor + self.device = device + + def __call__(self, results): + if self.processor is not None: + from mmcv.parallel import DataContainer as DC + outputs = {} + img = results['img'] + img = torch.from_numpy(img).to(self.device) + img, img_meta = self.processor(img.unsqueeze(0)) # process batch + outputs['img'] = DC( + img.squeeze(0), + stack=True) # DC wrapper for collate batch and to device + outputs['img_metas'] = DC(img_meta, cpu_only=True) + return outputs + return results + + @PREDICTORS.register_module() -class TorchYoloXPredictor(PredictorInterface): +class YoloXPredictor(DetectionPredictor): + """Detection predictor for Yolox.""" def __init__(self, model_path, + config_file=None, + batch_size=1, + use_trt_efficientnms=False, + device=None, + save_results=False, + save_path=None, + pipelines=None, max_det=100, score_thresh=0.5, - use_trt_efficientnms=False, - model_config=None): - """ - init model - - Args: - model_path: model file path - max_det: maximum number of detection - score_thresh: score_thresh to filter box - model_config: config string for model to init, in json format - """ - self.model_path = model_path + nms_thresh=None, + test_conf=None, + *arg, + **kwargs): self.max_det = max_det - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - # set type - self.model_type = 'raw' + self.use_trt_efficientnms = use_trt_efficientnms + if model_path.endswith('jit'): self.model_type = 'jit' - if model_path.endswith('blade'): + elif model_path.endswith('blade'): self.model_type = 'blade' - - self.use_trt_efficientnms = use_trt_efficientnms + else: + self.model_type = 'raw' if self.model_type == 'blade' or self.use_trt_efficientnms: import torch_blade - if model_config: - model_config = json.loads(model_config) - else: - model_config = {} + if self.model_type != 'raw' and config_file is None: + config_file = model_path + '.config.json' - self.score_thresh = model_config[ - 'score_thresh'] if 'score_thresh' in model_config else score_thresh + super(YoloXPredictor, self).__init__( + model_path, + config_file=config_file, + batch_size=batch_size, + device=device, + save_results=save_results, + save_path=save_path, + pipelines=pipelines, + score_threshold=score_thresh) + self.test_conf = test_conf or self.cfg['model'].get('test_conf', 0.01) + self.nms_thre = nms_thresh or self.cfg['model'].get('nms_thre', 0.65) + self.CLASSES = self.cfg.get('CLASSES', None) or self.cfg.get( + 'classes', None) + assert self.CLASSES is not None + + def _build_model(self): if self.model_type != 'raw': + with io.open(self.model_path, 'rb') as infile: + model = torch.jit.load(infile, self.device) + else: + model = super()._build_model() + model = reparameterize_models(model) + return model + + def prepare_model(self): + """Build model from config file by default. + If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it. + """ + model = self._build_model() + model.to(self.device) + model.eval() + if self.model_type == 'raw': + load_checkpoint(model, self.model_path, map_location='cpu') + return model + + def build_processor(self): + self.jit_preprocess = False + if self.model_type != 'raw': + if hasattr(self.cfg, 'export'): + self.jit_preprocess = self.cfg['export'].get( + 'preprocess_jit', False) + + if self.model_type != 'raw' and self.jit_preprocess: # jit or blade model + processor = None preprocess_path = '.'.join( - model_path.split('.')[:-1] + ['preprocess']) + self.model_path.split('.')[:-1] + ['preprocess']) if os.path.exists(preprocess_path): # use a preprocess jit model to speed up with io.open(preprocess_path, 'rb') as infile: - map_location = 'cpu' if self.device == 'cpu' else 'cuda' - self.preprocess = torch.jit.load(infile, map_location) - - with io.open(model_path, 'rb') as infile: - map_location = 'cpu' if self.device == 'cpu' else 'cuda' - self.model = torch.jit.load(infile, map_location) - with io.open(model_path + '.config.json', 'r') as infile: - self.cfg = json.load(infile) - test_pipeline = self.cfg['test_pipeline'] - self.CLASSES = self.cfg['classes'] - self.preprocess_jit = self.cfg['export']['preprocess_jit'] - - self.traceable = True - + processor = torch.jit.load(infile, self.device) + return _JitProcessorWrapper(processor, self.device) else: - self.preprocess_jit = False - with io.open(self.model_path, 'rb') as infile: - checkpoint = torch.load(infile, map_location='cpu') + return super().build_processor() - assert 'meta' in checkpoint and 'config' in checkpoint[ - 'meta'], 'meta.config is missing from checkpoint' + def forward(self, inputs): + """Model forward. + If you need refactor model forward, you need to reimplement it. + """ + if self.model_type != 'raw': + with torch.no_grad(): + outputs = self.model(inputs['img']) + outputs = {'results': outputs} # convert to dict format + else: + outputs = super().forward(inputs) - config_str = checkpoint['meta']['config'] - # get config - basename = os.path.basename(self.model_path) - fname, _ = os.path.splitext(basename) - self.local_config_file = os.path.join(CACHE_DIR, - f'{fname}_config.json') - if not os.path.exists(CACHE_DIR): - os.makedirs(CACHE_DIR) - with open(self.local_config_file, 'w') as ofile: - ofile.write(config_str) + if 'img_metas' not in outputs: + outputs['img_metas'] = inputs['img_metas'] - self.cfg = mmcv_config_fromfile(self.local_config_file) - - # build model - self.model = build_model(self.cfg.model) - - self.traceable = getattr(self.model, 'trace_able', False) - - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - map_location = 'cpu' if self.device == 'cpu' else 'cuda' - self.ckpt = load_checkpoint( - self.model, self.model_path, map_location=map_location) - - self.model = reparameterize_models(self.model) - - self.model.to(self.device) - self.model.eval() - test_pipeline = self.cfg.test_pipeline - self.CLASSES = self.cfg.CLASSES - - # build pipeline - pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline] - self.pipeline = Compose(pipeline) - - self.test_conf = self.cfg['model'].get('test_conf', 0.01) - self.nms_thre = self.cfg['model'].get('nms_thre', 0.65) - self.num_classes = len(self.CLASSES) + return outputs def post_assign(self, outputs, img_metas): detection_boxes = [] @@ -267,101 +287,73 @@ class TorchYoloXPredictor(PredictorInterface): } return test_outputs - def predict(self, input_data_list, batch_size=-1, to_numpy=True): + def postprocess_single(self, inputs): + det_out = inputs + img_meta = det_out['img_metas'] + + if self.model_type != 'raw': + results = det_out['results'] + if self.use_trt_efficientnms: + det_out = {} + det_out['detection_boxes'] = results[1] / img_meta[ + 'scale_factor'][0] + det_out['detection_scores'] = results[2] + det_out['detection_classes'] = results[3] + else: + det_out = self.post_assign( + postprocess( + results.unsqueeze(0), len(self.CLASSES), + self.test_conf, self.nms_thre), + img_metas=[img_meta]) + det_out['detection_scores'] = det_out['detection_scores'][0] + det_out['detection_boxes'] = det_out['detection_boxes'][0] + det_out['detection_classes'] = det_out['detection_classes'][0] + + resuts = super().postprocess_single(det_out) + resuts['ori_img_shape'] = list(img_meta['ori_img_shape'][:2]) + return resuts + + +@deprecated(reason='Please use YoloXPredictor.') +@PREDICTORS.register_module() +class TorchYoloXPredictor(YoloXPredictor): + + def __init__(self, + model_path, + max_det=100, + score_thresh=0.5, + use_trt_efficientnms=False, + model_config=None): """ - using session run predict a number of samples using batch_size + Args: + model_path: model file path + max_det: maximum number of detection + score_thresh: score_thresh to filter box + model_config: config string for model to init, in json format + """ + if model_config: + model_config = json.loads(model_config) + else: + model_config = {} - Args: - input_data_list: a list of numpy array(in rgb order), each array is a sample - to be predicted - batch_size: batch_size passed by the caller, you can also ignore this param and - use a fixed number if you do not want to adjust batch_size in runtime - Return: - result: a list of dict, each dict is the prediction result of one sample - eg, {"output1": value1, "output2": value2}, the value type can be - python int str float, and numpy array - """ - output_list = [] - for idx, img in enumerate(input_data_list): - if type(img) is not np.ndarray: - img = np.asarray(img) + score_thresh = model_config[ + 'score_thresh'] if 'score_thresh' in model_config else score_thresh + super().__init__( + model_path, + config_file=None, + batch_size=1, + use_trt_efficientnms=use_trt_efficientnms, + device=None, + save_results=False, + save_path=None, + pipelines=None, + max_det=max_det, + score_thresh=score_thresh, + nms_thresh=None, + test_conf=None) - ori_img_shape = img.shape[:2] - if self.preprocess_jit: - # the input should also be as the type of uint8 as mmcv - img = torch.from_numpy(img).to(self.device) - img = img.unsqueeze(0) - - if hasattr(self, 'preprocess'): - img, img_info = self.preprocess(img) - - else: - data_dict = {'img': img} - data_dict = self.pipeline(data_dict) - img = data_dict['img'] - img = torch.unsqueeze(img._data, 0).to(self.device) - data_dict.pop('img') - img_info = data_dict['img_metas']._data - - if self.traceable: - if self.use_trt_efficientnms: - with torch.no_grad(): - tmp_out = self.model(img) - det_out = {} - det_out['detection_boxes'] = tmp_out[1] / img_info[ - 'scale_factor'][0] - det_out['detection_scores'] = tmp_out[2] - det_out['detection_classes'] = tmp_out[3] - - else: - with torch.no_grad(): - det_out = self.post_assign( - postprocess( - self.model(img), self.num_classes, - self.test_conf, self.nms_thre), - img_metas=[img_info]) - else: - with torch.no_grad(): - det_out = self.model( - img, mode='test', img_metas=[img_info]) - - # print(det_out) - # det_out = det_out[:self.max_det] - # scale box to original image scale, this logic has some operation - # that can not be traced, see - # https://discuss.pytorch.org/t/windows-libtorch-c-load-cuda-module-with-std-runtime-error-message-shape-4-is-invalid-for-input-if-size-40/63073/4 - # det_out = scale_coords(img.shape[2:], det_out, ori_img_shape, (scale_factor, pad)) - - detection_scores = det_out['detection_scores'][0] - - if detection_scores is not None: - sel_ids = detection_scores > self.score_thresh - detection_scores = detection_scores[sel_ids] - detection_boxes = det_out['detection_boxes'][0][sel_ids] - detection_classes = det_out['detection_classes'][0][sel_ids] - else: - detection_boxes = None - detection_classes = None - - num_boxes = detection_classes.shape[ - 0] if detection_classes is not None else 0 - - detection_classes_names = [ - self.CLASSES[detection_classes[idx]] - for idx in range(num_boxes) - ] - - out = { - 'ori_img_shape': list(ori_img_shape), - 'detection_boxes': detection_boxes, - 'detection_scores': detection_scores, - 'detection_classes': detection_classes, - 'detection_class_names': detection_classes_names, - } - - output_list.append(out) - - return output_list + def predict(self, input_data_list, batch_size=-1, to_numpy=True): + return super().__call__(input_data_list) @PREDICTORS.register_module() diff --git a/tests/predictors/test_detector.py b/tests/predictors/test_detector.py index 1b160a01..3e46a0a0 100644 --- a/tests/predictors/test_detector.py +++ b/tests/predictors/test_detector.py @@ -7,152 +7,99 @@ import unittest import tempfile import numpy as np from PIL import Image - -from easycv.predictors.detector import TorchYoloXPredictor, DetectionPredictor +from easycv.predictors.detector import DetectionPredictor, YoloXPredictor, TorchYoloXPredictor from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT, - PRETRAINED_MODEL_YOLOXS_EXPORT_OLD, PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT, PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT, DET_DATA_SMALL_COCO_LOCAL) from numpy.testing import assert_array_almost_equal -class DetectorTest(unittest.TestCase): +class YoloXPredictorTest(unittest.TestCase): + img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017/000000522713.jpg') def setUp(self): print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) - def test_yolox_old_detector(self): - detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT_OLD - - img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, - 'val2017/000000522713.jpg') - - input_data_list = [np.asarray(Image.open(img))] - predictor = TorchYoloXPredictor( - model_path=detection_model_path, score_thresh=0.5) - - output = predictor.predict(input_data_list)[0] - - def test_yolox_detector(self): - detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT - - img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, - 'val2017/000000522713.jpg') - - input_data_list = [np.asarray(Image.open(img))] - predictor = TorchYoloXPredictor( - model_path=detection_model_path, score_thresh=0.5) - - output = predictor.predict(input_data_list)[0] - self.assertIn('detection_boxes', output) - self.assertIn('detection_scores', output) - self.assertIn('detection_classes', output) - self.assertIn('detection_class_names', output) - self.assertIn('ori_img_shape', output) - - self.assertEqual(len(output['detection_boxes']), 4) - self.assertEqual(output['ori_img_shape'], [480, 640]) - - self.assertListEqual(output['detection_classes'].tolist(), + def _assert_results(self, results): + self.assertEqual(results['ori_img_shape'], [480, 640]) + self.assertListEqual(results['detection_classes'].tolist(), np.array([13, 8, 8, 8], dtype=np.int32).tolist()) - - self.assertListEqual(output['detection_class_names'], + self.assertListEqual(results['detection_class_names'], ['bench', 'boat', 'boat', 'boat']) - assert_array_almost_equal( - output['detection_scores'], - np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004], + results['detection_scores'], + np.array([0.92335737, 0.59416807, 0.5567955, 0.55368793], dtype=np.float32), decimal=2) - assert_array_almost_equal( - output['detection_boxes'], - np.array([[407.89523, 284.62598, 561.4984, 356.7296], - [439.37653, 263.42395, 467.01526, 271.79144], - [480.8597, 269.64435, 502.18765, 274.80127], - [510.37033, 268.4982, 527.67017, 273.04935]]), + results['detection_boxes'], + np.array([[408.1708, 285.11456, 561.84924, 356.42285], + [438.88098, 264.46606, 467.07275, 271.76355], + [510.19467, 268.46664, 528.26935, 273.37192], + [480.9472, 269.74115, 502.00842, 274.85553]]), decimal=1) - def test_yolox_detector_jit_nopre_notrt(self): - img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, - 'val2017/000000522713.jpg') + def _base_test_single(self, model_path, inputs): + predictor = YoloXPredictor(model_path=model_path, score_thresh=0.5) - input_data_list = [np.asarray(Image.open(img))] + outputs = predictor(inputs) + self.assertEqual(len(outputs), 1) + output = outputs[0] + self._assert_results(output) + def _base_test_batch(self, model_path, inputs, num_samples, batch_size): + assert isinstance(inputs, list) and len(inputs) == 1 + + predictor = YoloXPredictor( + model_path=model_path, score_thresh=0.5, batch_size=batch_size) + outputs = predictor(inputs * num_samples) + + self.assertEqual(len(outputs), num_samples) + for output in outputs: + self._assert_results(output) + + def test_single_raw(self): + model_path = PRETRAINED_MODEL_YOLOXS_EXPORT + inputs = [np.asarray(Image.open(self.img))] + self._base_test_single(model_path, inputs) + + def test_batch_raw(self): + model_path = PRETRAINED_MODEL_YOLOXS_EXPORT + inputs = [np.asarray(Image.open(self.img))] + self._base_test_batch(model_path, inputs, num_samples=3, batch_size=2) + + def test_single_jit_nopre_notrt(self): jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT - predictor_jit = TorchYoloXPredictor( - model_path=jit_path, score_thresh=0.5) + self._base_test_single(jit_path, self.img) - output = predictor_jit.predict(input_data_list)[0] - self.assertIn('detection_boxes', output) - self.assertIn('detection_scores', output) - self.assertIn('detection_classes', output) - self.assertIn('detection_class_names', output) - self.assertIn('ori_img_shape', output) - - self.assertEqual(len(output['detection_boxes']), 4) - self.assertEqual(output['ori_img_shape'], [480, 640]) - - self.assertListEqual(output['detection_classes'].tolist(), - np.array([13, 8, 8, 8], dtype=np.int32).tolist()) - - self.assertListEqual(output['detection_class_names'], - ['bench', 'boat', 'boat', 'boat']) - - assert_array_almost_equal( - output['detection_scores'], - np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004], - dtype=np.float32), - decimal=2) - - assert_array_almost_equal( - output['detection_boxes'], - np.array([[407.89523, 284.62598, 561.4984, 356.7296], - [439.37653, 263.42395, 467.01526, 271.79144], - [480.8597, 269.64435, 502.18765, 274.80127], - [510.37033, 268.4982, 527.67017, 273.04935]]), - decimal=1) - - def test_yolox_detector_jit_pre_trt(self): - img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, - 'val2017/000000522713.jpg') - - input_data_list = [np.asarray(Image.open(img))] + def test_batch_jit_nopre_notrt(self): + jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT + self._base_test_batch( + jit_path, [self.img], num_samples=2, batch_size=1) + def test_single_jit_pre_trt(self): jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT - predictor_jit = TorchYoloXPredictor( - model_path=jit_path, score_thresh=0.5) + self._base_test_single(jit_path, [self.img]) - output = predictor_jit.predict(input_data_list)[0] - self.assertIn('detection_boxes', output) - self.assertIn('detection_scores', output) - self.assertIn('detection_classes', output) - self.assertIn('detection_class_names', output) - self.assertIn('ori_img_shape', output) + def test_batch_jit_pre_trt(self): + jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT + self._base_test_batch( + jit_path, [self.img], num_samples=4, batch_size=2) - self.assertEqual(len(output['detection_boxes']), 4) - self.assertEqual(output['ori_img_shape'], [480, 640]) + def test_single_raw_TorchYoloXPredictor(self): + detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT + input_data_list = [np.asarray(Image.open(self.img))] + predictor = TorchYoloXPredictor( + model_path=detection_model_path, score_thresh=0.5) + output = predictor(input_data_list)[0] + self._assert_results(output) - self.assertListEqual(output['detection_classes'].tolist(), - np.array([13, 8, 8, 8], dtype=np.int32).tolist()) - self.assertListEqual(output['detection_class_names'], - ['bench', 'boat', 'boat', 'boat']) +class DetectionPredictorTest(unittest.TestCase): - assert_array_almost_equal( - output['detection_scores'], - np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004], - dtype=np.float32), - decimal=2) - - assert_array_almost_equal( - output['detection_boxes'], - np.array([[407.89523, 284.62598, 561.4984, 356.7296], - [439.37653, 263.42395, 467.01526, 271.79144], - [480.8597, 269.64435, 502.18765, 274.80127], - [510.37033, 268.4982, 527.67017, 273.04935]]), - decimal=1) + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) def _detection_detector_assert(self, output): self.assertIn('detection_boxes', output) diff --git a/tests/predictors/test_detector_blade.py b/tests/predictors/test_detector_blade.py index 143425a3..28c90b38 100644 --- a/tests/predictors/test_detector_blade.py +++ b/tests/predictors/test_detector_blade.py @@ -6,7 +6,7 @@ import os import unittest import numpy as np from PIL import Image -from easycv.predictors.detector import TorchYoloXPredictor +from easycv.predictors.detector import YoloXPredictor from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE, PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE, DET_DATA_SMALL_COCO_LOCAL) @@ -17,90 +17,73 @@ from numpy.testing import assert_array_almost_equal @unittest.skipIf(torch.__version__ != '1.8.1+cu102', 'Blade need another environment') -class DetectorTest(unittest.TestCase): +class YoloXPredictorBladeTest(unittest.TestCase): + img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017/000000522713.jpg') def setUp(self): print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) - def test_yolox_detector_blade_nopre_notrt(self): - img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, - 'val2017/000000522713.jpg') - - input_data_list = [np.asarray(Image.open(img))] - - blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE - predictor_blade = TorchYoloXPredictor( - model_path=blade_path, score_thresh=0.5) - - output = predictor_blade.predict(input_data_list)[0] - self.assertIn('detection_boxes', output) - self.assertIn('detection_scores', output) - self.assertIn('detection_classes', output) - self.assertIn('detection_class_names', output) - self.assertIn('ori_img_shape', output) - - self.assertEqual(len(output['detection_boxes']), 4) - self.assertEqual(output['ori_img_shape'], [480, 640]) - - self.assertListEqual(output['detection_classes'].tolist(), + def _assert_results(self, results): + self.assertEqual(results['ori_img_shape'], [480, 640]) + self.assertListEqual(results['detection_classes'].tolist(), np.array([13, 8, 8, 8], dtype=np.int32).tolist()) - - self.assertListEqual(output['detection_class_names'], + self.assertListEqual(results['detection_class_names'], ['bench', 'boat', 'boat', 'boat']) - assert_array_almost_equal( - output['detection_scores'], - np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004], + results['detection_scores'], + np.array([0.92335737, 0.59416807, 0.5567955, 0.55368793], dtype=np.float32), decimal=2) - assert_array_almost_equal( - output['detection_boxes'], - np.array([[407.89523, 284.62598, 561.4984, 356.7296], - [439.37653, 263.42395, 467.01526, 271.79144], - [480.8597, 269.64435, 502.18765, 274.80127], - [510.37033, 268.4982, 527.67017, 273.04935]]), + results['detection_boxes'], + np.array([[408.1708, 285.11456, 561.84924, 356.42285], + [438.88098, 264.46606, 467.07275, 271.76355], + [510.19467, 268.46664, 528.26935, 273.37192], + [480.9472, 269.74115, 502.00842, 274.85553]]), decimal=1) + def _base_test_single(self, model_path, inputs): + predictor = YoloXPredictor(model_path=model_path, score_thresh=0.5) + + outputs = predictor(inputs) + self.assertEqual(len(outputs), 1) + output = outputs[0] + self._assert_results(output) + + def _base_test_batch(self, model_path, inputs, num_samples, batch_size): + assert isinstance(inputs, list) and len(inputs) == 1 + + predictor = YoloXPredictor( + model_path=model_path, score_thresh=0.5, batch_size=batch_size) + outputs = predictor(inputs * num_samples) + + self.assertEqual(len(outputs), num_samples) + for output in outputs: + self._assert_results(output) + + def test_blade_nopre_notrt(self): + inputs = [np.asarray(Image.open(self.img))] + blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE + self._base_test_single(blade_path, inputs) + + @unittest.skipIf(True, + 'Need export blade model to support dynamic batch size') + def test_blade_nopre_notrt_batch(self): + inputs = [np.asarray(Image.open(self.img))] + blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE + self._base_test_batch(blade_path, inputs, num_samples=4, batch_size=2) + def test_yolox_detector_blade_pre_notrt(self): - img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, - 'val2017/000000522713.jpg') - - input_data_list = [np.asarray(Image.open(img))] - + inputs = [self.img] blade_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE - predictor_blade = TorchYoloXPredictor( - model_path=blade_path, score_thresh=0.5) + self._base_test_single(blade_path, inputs) - output = predictor_blade.predict(input_data_list)[0] - self.assertIn('detection_boxes', output) - self.assertIn('detection_scores', output) - self.assertIn('detection_classes', output) - self.assertIn('detection_class_names', output) - self.assertIn('ori_img_shape', output) - - self.assertEqual(len(output['detection_boxes']), 4) - self.assertEqual(output['ori_img_shape'], [480, 640]) - - self.assertListEqual(output['detection_classes'].tolist(), - np.array([13, 8, 8, 8], dtype=np.int32).tolist()) - - self.assertListEqual(output['detection_class_names'], - ['bench', 'boat', 'boat', 'boat']) - - assert_array_almost_equal( - output['detection_scores'], - np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004], - dtype=np.float32), - decimal=2) - - assert_array_almost_equal( - output['detection_boxes'], - np.array([[407.89523, 284.62598, 561.4984, 356.7296], - [439.37653, 263.42395, 467.01526, 271.79144], - [480.8597, 269.64435, 502.18765, 274.80127], - [510.37033, 268.4982, 527.67017, 273.04935]]), - decimal=1) + @unittest.skipIf(True, + 'Need export blade model to support dynamic batch size') + def test_yolox_detector_blade_pre_notrt_batch(self): + inputs = [self.img] + blade_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE + self._base_test_batch(blade_path, inputs, num_samples=3, batch_size=2) if __name__ == '__main__':