Merge branch 'github_master'

pull/200/head
jiangnana.jnn 2022-09-23 15:10:22 +08:00
commit 893d879b44
5 changed files with 284 additions and 361 deletions

View File

@ -89,6 +89,7 @@ class LoadImage:
results['img'] = img results['img'] = img
results['img_shape'] = img.shape results['img_shape'] = img.shape
results['ori_shape'] = img.shape results['ori_shape'] = img.shape
results['ori_img_shape'] = img.shape
results['img_fields'] = ['img'] results['img_fields'] = ['img']
return results return results

View File

@ -274,12 +274,12 @@ class PredictorV2(object):
else: else:
out_i[k] = None out_i[k] = None
out_i = self.postprocess_single(out_i) out_i = self.postprocess_single(out_i, *args, **kwargs)
outputs.append(out_i) outputs.append(out_i)
return outputs return outputs
def postprocess_single(self, inputs): def postprocess_single(self, inputs, *args, **kwargs):
"""Process outputs of single sample. """Process outputs of single sample.
If you need add some processing ops, you need to reimplement it. If you need add some processing ops, you need to reimplement it.
""" """

View File

@ -125,113 +125,133 @@ class DetrPredictor(DetectionPredictor):
"""""" """"""
class _JitProcessorWrapper:
def __init__(self, processor, device) -> None:
self.processor = processor
self.device = device
def __call__(self, results):
if self.processor is not None:
from mmcv.parallel import DataContainer as DC
outputs = {}
img = results['img']
img = torch.from_numpy(img).to(self.device)
img, img_meta = self.processor(img.unsqueeze(0)) # process batch
outputs['img'] = DC(
img.squeeze(0),
stack=True) # DC wrapper for collate batch and to device
outputs['img_metas'] = DC(img_meta, cpu_only=True)
return outputs
return results
@PREDICTORS.register_module() @PREDICTORS.register_module()
class TorchYoloXPredictor(PredictorInterface): class YoloXPredictor(DetectionPredictor):
"""Detection predictor for Yolox."""
def __init__(self, def __init__(self,
model_path, model_path,
config_file=None,
batch_size=1,
use_trt_efficientnms=False,
device=None,
save_results=False,
save_path=None,
pipelines=None,
max_det=100, max_det=100,
score_thresh=0.5, score_thresh=0.5,
use_trt_efficientnms=False, nms_thresh=None,
model_config=None): test_conf=None,
""" *arg,
init model **kwargs):
Args:
model_path: model file path
max_det: maximum number of detection
score_thresh: score_thresh to filter box
model_config: config string for model to init, in json format
"""
self.model_path = model_path
self.max_det = max_det self.max_det = max_det
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.use_trt_efficientnms = use_trt_efficientnms
# set type
self.model_type = 'raw'
if model_path.endswith('jit'): if model_path.endswith('jit'):
self.model_type = 'jit' self.model_type = 'jit'
if model_path.endswith('blade'): elif model_path.endswith('blade'):
self.model_type = 'blade' self.model_type = 'blade'
else:
self.use_trt_efficientnms = use_trt_efficientnms self.model_type = 'raw'
if self.model_type == 'blade' or self.use_trt_efficientnms: if self.model_type == 'blade' or self.use_trt_efficientnms:
import torch_blade import torch_blade
if model_config: if self.model_type != 'raw' and config_file is None:
model_config = json.loads(model_config) config_file = model_path + '.config.json'
else:
model_config = {}
self.score_thresh = model_config[ super(YoloXPredictor, self).__init__(
'score_thresh'] if 'score_thresh' in model_config else score_thresh model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
score_threshold=score_thresh)
self.test_conf = test_conf or self.cfg['model'].get('test_conf', 0.01)
self.nms_thre = nms_thresh or self.cfg['model'].get('nms_thre', 0.65)
self.CLASSES = self.cfg.get('CLASSES', None) or self.cfg.get(
'classes', None)
assert self.CLASSES is not None
def _build_model(self):
if self.model_type != 'raw': if self.model_type != 'raw':
with io.open(self.model_path, 'rb') as infile:
model = torch.jit.load(infile, self.device)
else:
model = super()._build_model()
model = reparameterize_models(model)
return model
def prepare_model(self):
"""Build model from config file by default.
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
model = self._build_model()
model.to(self.device)
model.eval()
if self.model_type == 'raw':
load_checkpoint(model, self.model_path, map_location='cpu')
return model
def build_processor(self):
self.jit_preprocess = False
if self.model_type != 'raw':
if hasattr(self.cfg, 'export'):
self.jit_preprocess = self.cfg['export'].get(
'preprocess_jit', False)
if self.model_type != 'raw' and self.jit_preprocess:
# jit or blade model # jit or blade model
processor = None
preprocess_path = '.'.join( preprocess_path = '.'.join(
model_path.split('.')[:-1] + ['preprocess']) self.model_path.split('.')[:-1] + ['preprocess'])
if os.path.exists(preprocess_path): if os.path.exists(preprocess_path):
# use a preprocess jit model to speed up # use a preprocess jit model to speed up
with io.open(preprocess_path, 'rb') as infile: with io.open(preprocess_path, 'rb') as infile:
map_location = 'cpu' if self.device == 'cpu' else 'cuda' processor = torch.jit.load(infile, self.device)
self.preprocess = torch.jit.load(infile, map_location) return _JitProcessorWrapper(processor, self.device)
with io.open(model_path, 'rb') as infile:
map_location = 'cpu' if self.device == 'cpu' else 'cuda'
self.model = torch.jit.load(infile, map_location)
with io.open(model_path + '.config.json', 'r') as infile:
self.cfg = json.load(infile)
test_pipeline = self.cfg['test_pipeline']
self.CLASSES = self.cfg['classes']
self.preprocess_jit = self.cfg['export']['preprocess_jit']
self.traceable = True
else: else:
self.preprocess_jit = False return super().build_processor()
with io.open(self.model_path, 'rb') as infile:
checkpoint = torch.load(infile, map_location='cpu')
assert 'meta' in checkpoint and 'config' in checkpoint[ def forward(self, inputs):
'meta'], 'meta.config is missing from checkpoint' """Model forward.
If you need refactor model forward, you need to reimplement it.
"""
if self.model_type != 'raw':
with torch.no_grad():
outputs = self.model(inputs['img'])
outputs = {'results': outputs} # convert to dict format
else:
outputs = super().forward(inputs)
config_str = checkpoint['meta']['config'] if 'img_metas' not in outputs:
# get config outputs['img_metas'] = inputs['img_metas']
basename = os.path.basename(self.model_path)
fname, _ = os.path.splitext(basename)
self.local_config_file = os.path.join(CACHE_DIR,
f'{fname}_config.json')
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
with open(self.local_config_file, 'w') as ofile:
ofile.write(config_str)
self.cfg = mmcv_config_fromfile(self.local_config_file) return outputs
# build model
self.model = build_model(self.cfg.model)
self.traceable = getattr(self.model, 'trace_able', False)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
map_location = 'cpu' if self.device == 'cpu' else 'cuda'
self.ckpt = load_checkpoint(
self.model, self.model_path, map_location=map_location)
self.model = reparameterize_models(self.model)
self.model.to(self.device)
self.model.eval()
test_pipeline = self.cfg.test_pipeline
self.CLASSES = self.cfg.CLASSES
# build pipeline
pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
self.pipeline = Compose(pipeline)
self.test_conf = self.cfg['model'].get('test_conf', 0.01)
self.nms_thre = self.cfg['model'].get('nms_thre', 0.65)
self.num_classes = len(self.CLASSES)
def post_assign(self, outputs, img_metas): def post_assign(self, outputs, img_metas):
detection_boxes = [] detection_boxes = []
@ -267,101 +287,73 @@ class TorchYoloXPredictor(PredictorInterface):
} }
return test_outputs return test_outputs
def predict(self, input_data_list, batch_size=-1, to_numpy=True): def postprocess_single(self, inputs):
det_out = inputs
img_meta = det_out['img_metas']
if self.model_type != 'raw':
results = det_out['results']
if self.use_trt_efficientnms:
det_out = {}
det_out['detection_boxes'] = results[1] / img_meta[
'scale_factor'][0]
det_out['detection_scores'] = results[2]
det_out['detection_classes'] = results[3]
else:
det_out = self.post_assign(
postprocess(
results.unsqueeze(0), len(self.CLASSES),
self.test_conf, self.nms_thre),
img_metas=[img_meta])
det_out['detection_scores'] = det_out['detection_scores'][0]
det_out['detection_boxes'] = det_out['detection_boxes'][0]
det_out['detection_classes'] = det_out['detection_classes'][0]
resuts = super().postprocess_single(det_out)
resuts['ori_img_shape'] = list(img_meta['ori_img_shape'][:2])
return resuts
@deprecated(reason='Please use YoloXPredictor.')
@PREDICTORS.register_module()
class TorchYoloXPredictor(YoloXPredictor):
def __init__(self,
model_path,
max_det=100,
score_thresh=0.5,
use_trt_efficientnms=False,
model_config=None):
""" """
using session run predict a number of samples using batch_size Args:
model_path: model file path
max_det: maximum number of detection
score_thresh: score_thresh to filter box
model_config: config string for model to init, in json format
"""
if model_config:
model_config = json.loads(model_config)
else:
model_config = {}
Args: score_thresh = model_config[
input_data_list: a list of numpy array(in rgb order), each array is a sample 'score_thresh'] if 'score_thresh' in model_config else score_thresh
to be predicted super().__init__(
batch_size: batch_size passed by the caller, you can also ignore this param and model_path,
use a fixed number if you do not want to adjust batch_size in runtime config_file=None,
Return: batch_size=1,
result: a list of dict, each dict is the prediction result of one sample use_trt_efficientnms=use_trt_efficientnms,
eg, {"output1": value1, "output2": value2}, the value type can be device=None,
python int str float, and numpy array save_results=False,
""" save_path=None,
output_list = [] pipelines=None,
for idx, img in enumerate(input_data_list): max_det=max_det,
if type(img) is not np.ndarray: score_thresh=score_thresh,
img = np.asarray(img) nms_thresh=None,
test_conf=None)
ori_img_shape = img.shape[:2] def predict(self, input_data_list, batch_size=-1, to_numpy=True):
if self.preprocess_jit: return super().__call__(input_data_list)
# the input should also be as the type of uint8 as mmcv
img = torch.from_numpy(img).to(self.device)
img = img.unsqueeze(0)
if hasattr(self, 'preprocess'):
img, img_info = self.preprocess(img)
else:
data_dict = {'img': img}
data_dict = self.pipeline(data_dict)
img = data_dict['img']
img = torch.unsqueeze(img._data, 0).to(self.device)
data_dict.pop('img')
img_info = data_dict['img_metas']._data
if self.traceable:
if self.use_trt_efficientnms:
with torch.no_grad():
tmp_out = self.model(img)
det_out = {}
det_out['detection_boxes'] = tmp_out[1] / img_info[
'scale_factor'][0]
det_out['detection_scores'] = tmp_out[2]
det_out['detection_classes'] = tmp_out[3]
else:
with torch.no_grad():
det_out = self.post_assign(
postprocess(
self.model(img), self.num_classes,
self.test_conf, self.nms_thre),
img_metas=[img_info])
else:
with torch.no_grad():
det_out = self.model(
img, mode='test', img_metas=[img_info])
# print(det_out)
# det_out = det_out[:self.max_det]
# scale box to original image scale, this logic has some operation
# that can not be traced, see
# https://discuss.pytorch.org/t/windows-libtorch-c-load-cuda-module-with-std-runtime-error-message-shape-4-is-invalid-for-input-if-size-40/63073/4
# det_out = scale_coords(img.shape[2:], det_out, ori_img_shape, (scale_factor, pad))
detection_scores = det_out['detection_scores'][0]
if detection_scores is not None:
sel_ids = detection_scores > self.score_thresh
detection_scores = detection_scores[sel_ids]
detection_boxes = det_out['detection_boxes'][0][sel_ids]
detection_classes = det_out['detection_classes'][0][sel_ids]
else:
detection_boxes = None
detection_classes = None
num_boxes = detection_classes.shape[
0] if detection_classes is not None else 0
detection_classes_names = [
self.CLASSES[detection_classes[idx]]
for idx in range(num_boxes)
]
out = {
'ori_img_shape': list(ori_img_shape),
'detection_boxes': detection_boxes,
'detection_scores': detection_scores,
'detection_classes': detection_classes,
'detection_class_names': detection_classes_names,
}
output_list.append(out)
return output_list
@PREDICTORS.register_module() @PREDICTORS.register_module()

View File

@ -7,152 +7,99 @@ import unittest
import tempfile import tempfile
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from easycv.predictors.detector import DetectionPredictor, YoloXPredictor, TorchYoloXPredictor
from easycv.predictors.detector import TorchYoloXPredictor, DetectionPredictor
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT, from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
PRETRAINED_MODEL_YOLOXS_EXPORT_OLD,
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT, PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT,
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT, PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT,
DET_DATA_SMALL_COCO_LOCAL) DET_DATA_SMALL_COCO_LOCAL)
from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_almost_equal
class DetectorTest(unittest.TestCase): class YoloXPredictorTest(unittest.TestCase):
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017/000000522713.jpg')
def setUp(self): def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_yolox_old_detector(self): def _assert_results(self, results):
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT_OLD self.assertEqual(results['ori_img_shape'], [480, 640])
self.assertListEqual(results['detection_classes'].tolist(),
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
'val2017/000000522713.jpg')
input_data_list = [np.asarray(Image.open(img))]
predictor = TorchYoloXPredictor(
model_path=detection_model_path, score_thresh=0.5)
output = predictor.predict(input_data_list)[0]
def test_yolox_detector(self):
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
'val2017/000000522713.jpg')
input_data_list = [np.asarray(Image.open(img))]
predictor = TorchYoloXPredictor(
model_path=detection_model_path, score_thresh=0.5)
output = predictor.predict(input_data_list)[0]
self.assertIn('detection_boxes', output)
self.assertIn('detection_scores', output)
self.assertIn('detection_classes', output)
self.assertIn('detection_class_names', output)
self.assertIn('ori_img_shape', output)
self.assertEqual(len(output['detection_boxes']), 4)
self.assertEqual(output['ori_img_shape'], [480, 640])
self.assertListEqual(output['detection_classes'].tolist(),
np.array([13, 8, 8, 8], dtype=np.int32).tolist()) np.array([13, 8, 8, 8], dtype=np.int32).tolist())
self.assertListEqual(results['detection_class_names'],
self.assertListEqual(output['detection_class_names'],
['bench', 'boat', 'boat', 'boat']) ['bench', 'boat', 'boat', 'boat'])
assert_array_almost_equal( assert_array_almost_equal(
output['detection_scores'], results['detection_scores'],
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004], np.array([0.92335737, 0.59416807, 0.5567955, 0.55368793],
dtype=np.float32), dtype=np.float32),
decimal=2) decimal=2)
assert_array_almost_equal( assert_array_almost_equal(
output['detection_boxes'], results['detection_boxes'],
np.array([[407.89523, 284.62598, 561.4984, 356.7296], np.array([[408.1708, 285.11456, 561.84924, 356.42285],
[439.37653, 263.42395, 467.01526, 271.79144], [438.88098, 264.46606, 467.07275, 271.76355],
[480.8597, 269.64435, 502.18765, 274.80127], [510.19467, 268.46664, 528.26935, 273.37192],
[510.37033, 268.4982, 527.67017, 273.04935]]), [480.9472, 269.74115, 502.00842, 274.85553]]),
decimal=1) decimal=1)
def test_yolox_detector_jit_nopre_notrt(self): def _base_test_single(self, model_path, inputs):
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, predictor = YoloXPredictor(model_path=model_path, score_thresh=0.5)
'val2017/000000522713.jpg')
input_data_list = [np.asarray(Image.open(img))] outputs = predictor(inputs)
self.assertEqual(len(outputs), 1)
output = outputs[0]
self._assert_results(output)
def _base_test_batch(self, model_path, inputs, num_samples, batch_size):
assert isinstance(inputs, list) and len(inputs) == 1
predictor = YoloXPredictor(
model_path=model_path, score_thresh=0.5, batch_size=batch_size)
outputs = predictor(inputs * num_samples)
self.assertEqual(len(outputs), num_samples)
for output in outputs:
self._assert_results(output)
def test_single_raw(self):
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
inputs = [np.asarray(Image.open(self.img))]
self._base_test_single(model_path, inputs)
def test_batch_raw(self):
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
inputs = [np.asarray(Image.open(self.img))]
self._base_test_batch(model_path, inputs, num_samples=3, batch_size=2)
def test_single_jit_nopre_notrt(self):
jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT
predictor_jit = TorchYoloXPredictor( self._base_test_single(jit_path, self.img)
model_path=jit_path, score_thresh=0.5)
output = predictor_jit.predict(input_data_list)[0] def test_batch_jit_nopre_notrt(self):
self.assertIn('detection_boxes', output) jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT
self.assertIn('detection_scores', output) self._base_test_batch(
self.assertIn('detection_classes', output) jit_path, [self.img], num_samples=2, batch_size=1)
self.assertIn('detection_class_names', output)
self.assertIn('ori_img_shape', output)
self.assertEqual(len(output['detection_boxes']), 4)
self.assertEqual(output['ori_img_shape'], [480, 640])
self.assertListEqual(output['detection_classes'].tolist(),
np.array([13, 8, 8, 8], dtype=np.int32).tolist())
self.assertListEqual(output['detection_class_names'],
['bench', 'boat', 'boat', 'boat'])
assert_array_almost_equal(
output['detection_scores'],
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004],
dtype=np.float32),
decimal=2)
assert_array_almost_equal(
output['detection_boxes'],
np.array([[407.89523, 284.62598, 561.4984, 356.7296],
[439.37653, 263.42395, 467.01526, 271.79144],
[480.8597, 269.64435, 502.18765, 274.80127],
[510.37033, 268.4982, 527.67017, 273.04935]]),
decimal=1)
def test_yolox_detector_jit_pre_trt(self):
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
'val2017/000000522713.jpg')
input_data_list = [np.asarray(Image.open(img))]
def test_single_jit_pre_trt(self):
jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT
predictor_jit = TorchYoloXPredictor( self._base_test_single(jit_path, [self.img])
model_path=jit_path, score_thresh=0.5)
output = predictor_jit.predict(input_data_list)[0] def test_batch_jit_pre_trt(self):
self.assertIn('detection_boxes', output) jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT
self.assertIn('detection_scores', output) self._base_test_batch(
self.assertIn('detection_classes', output) jit_path, [self.img], num_samples=4, batch_size=2)
self.assertIn('detection_class_names', output)
self.assertIn('ori_img_shape', output)
self.assertEqual(len(output['detection_boxes']), 4) def test_single_raw_TorchYoloXPredictor(self):
self.assertEqual(output['ori_img_shape'], [480, 640]) detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
input_data_list = [np.asarray(Image.open(self.img))]
predictor = TorchYoloXPredictor(
model_path=detection_model_path, score_thresh=0.5)
output = predictor(input_data_list)[0]
self._assert_results(output)
self.assertListEqual(output['detection_classes'].tolist(),
np.array([13, 8, 8, 8], dtype=np.int32).tolist())
self.assertListEqual(output['detection_class_names'], class DetectionPredictorTest(unittest.TestCase):
['bench', 'boat', 'boat', 'boat'])
assert_array_almost_equal( def setUp(self):
output['detection_scores'], print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004],
dtype=np.float32),
decimal=2)
assert_array_almost_equal(
output['detection_boxes'],
np.array([[407.89523, 284.62598, 561.4984, 356.7296],
[439.37653, 263.42395, 467.01526, 271.79144],
[480.8597, 269.64435, 502.18765, 274.80127],
[510.37033, 268.4982, 527.67017, 273.04935]]),
decimal=1)
def _detection_detector_assert(self, output): def _detection_detector_assert(self, output):
self.assertIn('detection_boxes', output) self.assertIn('detection_boxes', output)

View File

@ -6,7 +6,7 @@ import os
import unittest import unittest
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from easycv.predictors.detector import TorchYoloXPredictor from easycv.predictors.detector import YoloXPredictor
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE, from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE,
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE, PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE,
DET_DATA_SMALL_COCO_LOCAL) DET_DATA_SMALL_COCO_LOCAL)
@ -17,90 +17,73 @@ from numpy.testing import assert_array_almost_equal
@unittest.skipIf(torch.__version__ != '1.8.1+cu102', @unittest.skipIf(torch.__version__ != '1.8.1+cu102',
'Blade need another environment') 'Blade need another environment')
class DetectorTest(unittest.TestCase): class YoloXPredictorBladeTest(unittest.TestCase):
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017/000000522713.jpg')
def setUp(self): def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_yolox_detector_blade_nopre_notrt(self): def _assert_results(self, results):
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, self.assertEqual(results['ori_img_shape'], [480, 640])
'val2017/000000522713.jpg') self.assertListEqual(results['detection_classes'].tolist(),
input_data_list = [np.asarray(Image.open(img))]
blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE
predictor_blade = TorchYoloXPredictor(
model_path=blade_path, score_thresh=0.5)
output = predictor_blade.predict(input_data_list)[0]
self.assertIn('detection_boxes', output)
self.assertIn('detection_scores', output)
self.assertIn('detection_classes', output)
self.assertIn('detection_class_names', output)
self.assertIn('ori_img_shape', output)
self.assertEqual(len(output['detection_boxes']), 4)
self.assertEqual(output['ori_img_shape'], [480, 640])
self.assertListEqual(output['detection_classes'].tolist(),
np.array([13, 8, 8, 8], dtype=np.int32).tolist()) np.array([13, 8, 8, 8], dtype=np.int32).tolist())
self.assertListEqual(results['detection_class_names'],
self.assertListEqual(output['detection_class_names'],
['bench', 'boat', 'boat', 'boat']) ['bench', 'boat', 'boat', 'boat'])
assert_array_almost_equal( assert_array_almost_equal(
output['detection_scores'], results['detection_scores'],
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004], np.array([0.92335737, 0.59416807, 0.5567955, 0.55368793],
dtype=np.float32), dtype=np.float32),
decimal=2) decimal=2)
assert_array_almost_equal( assert_array_almost_equal(
output['detection_boxes'], results['detection_boxes'],
np.array([[407.89523, 284.62598, 561.4984, 356.7296], np.array([[408.1708, 285.11456, 561.84924, 356.42285],
[439.37653, 263.42395, 467.01526, 271.79144], [438.88098, 264.46606, 467.07275, 271.76355],
[480.8597, 269.64435, 502.18765, 274.80127], [510.19467, 268.46664, 528.26935, 273.37192],
[510.37033, 268.4982, 527.67017, 273.04935]]), [480.9472, 269.74115, 502.00842, 274.85553]]),
decimal=1) decimal=1)
def _base_test_single(self, model_path, inputs):
predictor = YoloXPredictor(model_path=model_path, score_thresh=0.5)
outputs = predictor(inputs)
self.assertEqual(len(outputs), 1)
output = outputs[0]
self._assert_results(output)
def _base_test_batch(self, model_path, inputs, num_samples, batch_size):
assert isinstance(inputs, list) and len(inputs) == 1
predictor = YoloXPredictor(
model_path=model_path, score_thresh=0.5, batch_size=batch_size)
outputs = predictor(inputs * num_samples)
self.assertEqual(len(outputs), num_samples)
for output in outputs:
self._assert_results(output)
def test_blade_nopre_notrt(self):
inputs = [np.asarray(Image.open(self.img))]
blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE
self._base_test_single(blade_path, inputs)
@unittest.skipIf(True,
'Need export blade model to support dynamic batch size')
def test_blade_nopre_notrt_batch(self):
inputs = [np.asarray(Image.open(self.img))]
blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE
self._base_test_batch(blade_path, inputs, num_samples=4, batch_size=2)
def test_yolox_detector_blade_pre_notrt(self): def test_yolox_detector_blade_pre_notrt(self):
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, inputs = [self.img]
'val2017/000000522713.jpg')
input_data_list = [np.asarray(Image.open(img))]
blade_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE blade_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE
predictor_blade = TorchYoloXPredictor( self._base_test_single(blade_path, inputs)
model_path=blade_path, score_thresh=0.5)
output = predictor_blade.predict(input_data_list)[0] @unittest.skipIf(True,
self.assertIn('detection_boxes', output) 'Need export blade model to support dynamic batch size')
self.assertIn('detection_scores', output) def test_yolox_detector_blade_pre_notrt_batch(self):
self.assertIn('detection_classes', output) inputs = [self.img]
self.assertIn('detection_class_names', output) blade_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE
self.assertIn('ori_img_shape', output) self._base_test_batch(blade_path, inputs, num_samples=3, batch_size=2)
self.assertEqual(len(output['detection_boxes']), 4)
self.assertEqual(output['ori_img_shape'], [480, 640])
self.assertListEqual(output['detection_classes'].tolist(),
np.array([13, 8, 8, 8], dtype=np.int32).tolist())
self.assertListEqual(output['detection_class_names'],
['bench', 'boat', 'boat', 'boat'])
assert_array_almost_equal(
output['detection_scores'],
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004],
dtype=np.float32),
decimal=2)
assert_array_almost_equal(
output['detection_boxes'],
np.array([[407.89523, 284.62598, 561.4984, 356.7296],
[439.37653, 263.42395, 467.01526, 271.79144],
[480.8597, 269.64435, 502.18765, 274.80127],
[510.37033, 268.4982, 527.67017, 273.04935]]),
decimal=1)
if __name__ == '__main__': if __name__ == '__main__':