mirror of https://github.com/alibaba/EasyCV.git
parent
5dfe7b2898
commit
7a89d1b7b8
|
@ -89,6 +89,7 @@ class LoadImage:
|
|||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
results['ori_img_shape'] = img.shape
|
||||
results['img_fields'] = ['img']
|
||||
return results
|
||||
|
||||
|
|
|
@ -274,12 +274,12 @@ class PredictorV2(object):
|
|||
else:
|
||||
out_i[k] = None
|
||||
|
||||
out_i = self.postprocess_single(out_i)
|
||||
out_i = self.postprocess_single(out_i, *args, **kwargs)
|
||||
outputs.append(out_i)
|
||||
|
||||
return outputs
|
||||
|
||||
def postprocess_single(self, inputs):
|
||||
def postprocess_single(self, inputs, *args, **kwargs):
|
||||
"""Process outputs of single sample.
|
||||
If you need add some processing ops, you need to reimplement it.
|
||||
"""
|
||||
|
|
|
@ -125,113 +125,133 @@ class DetrPredictor(DetectionPredictor):
|
|||
""""""
|
||||
|
||||
|
||||
class _JitProcessorWrapper:
|
||||
|
||||
def __init__(self, processor, device) -> None:
|
||||
self.processor = processor
|
||||
self.device = device
|
||||
|
||||
def __call__(self, results):
|
||||
if self.processor is not None:
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
outputs = {}
|
||||
img = results['img']
|
||||
img = torch.from_numpy(img).to(self.device)
|
||||
img, img_meta = self.processor(img.unsqueeze(0)) # process batch
|
||||
outputs['img'] = DC(
|
||||
img.squeeze(0),
|
||||
stack=True) # DC wrapper for collate batch and to device
|
||||
outputs['img_metas'] = DC(img_meta, cpu_only=True)
|
||||
return outputs
|
||||
return results
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
class TorchYoloXPredictor(PredictorInterface):
|
||||
class YoloXPredictor(DetectionPredictor):
|
||||
"""Detection predictor for Yolox."""
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
config_file=None,
|
||||
batch_size=1,
|
||||
use_trt_efficientnms=False,
|
||||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
pipelines=None,
|
||||
max_det=100,
|
||||
score_thresh=0.5,
|
||||
use_trt_efficientnms=False,
|
||||
model_config=None):
|
||||
"""
|
||||
init model
|
||||
|
||||
Args:
|
||||
model_path: model file path
|
||||
max_det: maximum number of detection
|
||||
score_thresh: score_thresh to filter box
|
||||
model_config: config string for model to init, in json format
|
||||
"""
|
||||
self.model_path = model_path
|
||||
nms_thresh=None,
|
||||
test_conf=None,
|
||||
*arg,
|
||||
**kwargs):
|
||||
self.max_det = max_det
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
# set type
|
||||
self.model_type = 'raw'
|
||||
self.use_trt_efficientnms = use_trt_efficientnms
|
||||
|
||||
if model_path.endswith('jit'):
|
||||
self.model_type = 'jit'
|
||||
if model_path.endswith('blade'):
|
||||
elif model_path.endswith('blade'):
|
||||
self.model_type = 'blade'
|
||||
|
||||
self.use_trt_efficientnms = use_trt_efficientnms
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
|
||||
if self.model_type == 'blade' or self.use_trt_efficientnms:
|
||||
import torch_blade
|
||||
|
||||
if model_config:
|
||||
model_config = json.loads(model_config)
|
||||
else:
|
||||
model_config = {}
|
||||
if self.model_type != 'raw' and config_file is None:
|
||||
config_file = model_path + '.config.json'
|
||||
|
||||
self.score_thresh = model_config[
|
||||
'score_thresh'] if 'score_thresh' in model_config else score_thresh
|
||||
super(YoloXPredictor, self).__init__(
|
||||
model_path,
|
||||
config_file=config_file,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
save_results=save_results,
|
||||
save_path=save_path,
|
||||
pipelines=pipelines,
|
||||
score_threshold=score_thresh)
|
||||
|
||||
self.test_conf = test_conf or self.cfg['model'].get('test_conf', 0.01)
|
||||
self.nms_thre = nms_thresh or self.cfg['model'].get('nms_thre', 0.65)
|
||||
self.CLASSES = self.cfg.get('CLASSES', None) or self.cfg.get(
|
||||
'classes', None)
|
||||
assert self.CLASSES is not None
|
||||
|
||||
def _build_model(self):
|
||||
if self.model_type != 'raw':
|
||||
with io.open(self.model_path, 'rb') as infile:
|
||||
model = torch.jit.load(infile, self.device)
|
||||
else:
|
||||
model = super()._build_model()
|
||||
model = reparameterize_models(model)
|
||||
return model
|
||||
|
||||
def prepare_model(self):
|
||||
"""Build model from config file by default.
|
||||
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
|
||||
"""
|
||||
model = self._build_model()
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
if self.model_type == 'raw':
|
||||
load_checkpoint(model, self.model_path, map_location='cpu')
|
||||
return model
|
||||
|
||||
def build_processor(self):
|
||||
self.jit_preprocess = False
|
||||
if self.model_type != 'raw':
|
||||
if hasattr(self.cfg, 'export'):
|
||||
self.jit_preprocess = self.cfg['export'].get(
|
||||
'preprocess_jit', False)
|
||||
|
||||
if self.model_type != 'raw' and self.jit_preprocess:
|
||||
# jit or blade model
|
||||
processor = None
|
||||
preprocess_path = '.'.join(
|
||||
model_path.split('.')[:-1] + ['preprocess'])
|
||||
self.model_path.split('.')[:-1] + ['preprocess'])
|
||||
if os.path.exists(preprocess_path):
|
||||
# use a preprocess jit model to speed up
|
||||
with io.open(preprocess_path, 'rb') as infile:
|
||||
map_location = 'cpu' if self.device == 'cpu' else 'cuda'
|
||||
self.preprocess = torch.jit.load(infile, map_location)
|
||||
|
||||
with io.open(model_path, 'rb') as infile:
|
||||
map_location = 'cpu' if self.device == 'cpu' else 'cuda'
|
||||
self.model = torch.jit.load(infile, map_location)
|
||||
with io.open(model_path + '.config.json', 'r') as infile:
|
||||
self.cfg = json.load(infile)
|
||||
test_pipeline = self.cfg['test_pipeline']
|
||||
self.CLASSES = self.cfg['classes']
|
||||
self.preprocess_jit = self.cfg['export']['preprocess_jit']
|
||||
|
||||
self.traceable = True
|
||||
|
||||
processor = torch.jit.load(infile, self.device)
|
||||
return _JitProcessorWrapper(processor, self.device)
|
||||
else:
|
||||
self.preprocess_jit = False
|
||||
with io.open(self.model_path, 'rb') as infile:
|
||||
checkpoint = torch.load(infile, map_location='cpu')
|
||||
return super().build_processor()
|
||||
|
||||
assert 'meta' in checkpoint and 'config' in checkpoint[
|
||||
'meta'], 'meta.config is missing from checkpoint'
|
||||
def forward(self, inputs):
|
||||
"""Model forward.
|
||||
If you need refactor model forward, you need to reimplement it.
|
||||
"""
|
||||
if self.model_type != 'raw':
|
||||
with torch.no_grad():
|
||||
outputs = self.model(inputs['img'])
|
||||
outputs = {'results': outputs} # convert to dict format
|
||||
else:
|
||||
outputs = super().forward(inputs)
|
||||
|
||||
config_str = checkpoint['meta']['config']
|
||||
# get config
|
||||
basename = os.path.basename(self.model_path)
|
||||
fname, _ = os.path.splitext(basename)
|
||||
self.local_config_file = os.path.join(CACHE_DIR,
|
||||
f'{fname}_config.json')
|
||||
if not os.path.exists(CACHE_DIR):
|
||||
os.makedirs(CACHE_DIR)
|
||||
with open(self.local_config_file, 'w') as ofile:
|
||||
ofile.write(config_str)
|
||||
if 'img_metas' not in outputs:
|
||||
outputs['img_metas'] = inputs['img_metas']
|
||||
|
||||
self.cfg = mmcv_config_fromfile(self.local_config_file)
|
||||
|
||||
# build model
|
||||
self.model = build_model(self.cfg.model)
|
||||
|
||||
self.traceable = getattr(self.model, 'trace_able', False)
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
map_location = 'cpu' if self.device == 'cpu' else 'cuda'
|
||||
self.ckpt = load_checkpoint(
|
||||
self.model, self.model_path, map_location=map_location)
|
||||
|
||||
self.model = reparameterize_models(self.model)
|
||||
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
test_pipeline = self.cfg.test_pipeline
|
||||
self.CLASSES = self.cfg.CLASSES
|
||||
|
||||
# build pipeline
|
||||
pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
|
||||
self.pipeline = Compose(pipeline)
|
||||
|
||||
self.test_conf = self.cfg['model'].get('test_conf', 0.01)
|
||||
self.nms_thre = self.cfg['model'].get('nms_thre', 0.65)
|
||||
self.num_classes = len(self.CLASSES)
|
||||
return outputs
|
||||
|
||||
def post_assign(self, outputs, img_metas):
|
||||
detection_boxes = []
|
||||
|
@ -267,101 +287,73 @@ class TorchYoloXPredictor(PredictorInterface):
|
|||
}
|
||||
return test_outputs
|
||||
|
||||
def predict(self, input_data_list, batch_size=-1, to_numpy=True):
|
||||
def postprocess_single(self, inputs):
|
||||
det_out = inputs
|
||||
img_meta = det_out['img_metas']
|
||||
|
||||
if self.model_type != 'raw':
|
||||
results = det_out['results']
|
||||
if self.use_trt_efficientnms:
|
||||
det_out = {}
|
||||
det_out['detection_boxes'] = results[1] / img_meta[
|
||||
'scale_factor'][0]
|
||||
det_out['detection_scores'] = results[2]
|
||||
det_out['detection_classes'] = results[3]
|
||||
else:
|
||||
det_out = self.post_assign(
|
||||
postprocess(
|
||||
results.unsqueeze(0), len(self.CLASSES),
|
||||
self.test_conf, self.nms_thre),
|
||||
img_metas=[img_meta])
|
||||
det_out['detection_scores'] = det_out['detection_scores'][0]
|
||||
det_out['detection_boxes'] = det_out['detection_boxes'][0]
|
||||
det_out['detection_classes'] = det_out['detection_classes'][0]
|
||||
|
||||
resuts = super().postprocess_single(det_out)
|
||||
resuts['ori_img_shape'] = list(img_meta['ori_img_shape'][:2])
|
||||
return resuts
|
||||
|
||||
|
||||
@deprecated(reason='Please use YoloXPredictor.')
|
||||
@PREDICTORS.register_module()
|
||||
class TorchYoloXPredictor(YoloXPredictor):
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
max_det=100,
|
||||
score_thresh=0.5,
|
||||
use_trt_efficientnms=False,
|
||||
model_config=None):
|
||||
"""
|
||||
using session run predict a number of samples using batch_size
|
||||
Args:
|
||||
model_path: model file path
|
||||
max_det: maximum number of detection
|
||||
score_thresh: score_thresh to filter box
|
||||
model_config: config string for model to init, in json format
|
||||
"""
|
||||
if model_config:
|
||||
model_config = json.loads(model_config)
|
||||
else:
|
||||
model_config = {}
|
||||
|
||||
Args:
|
||||
input_data_list: a list of numpy array(in rgb order), each array is a sample
|
||||
to be predicted
|
||||
batch_size: batch_size passed by the caller, you can also ignore this param and
|
||||
use a fixed number if you do not want to adjust batch_size in runtime
|
||||
Return:
|
||||
result: a list of dict, each dict is the prediction result of one sample
|
||||
eg, {"output1": value1, "output2": value2}, the value type can be
|
||||
python int str float, and numpy array
|
||||
"""
|
||||
output_list = []
|
||||
for idx, img in enumerate(input_data_list):
|
||||
if type(img) is not np.ndarray:
|
||||
img = np.asarray(img)
|
||||
score_thresh = model_config[
|
||||
'score_thresh'] if 'score_thresh' in model_config else score_thresh
|
||||
super().__init__(
|
||||
model_path,
|
||||
config_file=None,
|
||||
batch_size=1,
|
||||
use_trt_efficientnms=use_trt_efficientnms,
|
||||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
pipelines=None,
|
||||
max_det=max_det,
|
||||
score_thresh=score_thresh,
|
||||
nms_thresh=None,
|
||||
test_conf=None)
|
||||
|
||||
ori_img_shape = img.shape[:2]
|
||||
if self.preprocess_jit:
|
||||
# the input should also be as the type of uint8 as mmcv
|
||||
img = torch.from_numpy(img).to(self.device)
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
if hasattr(self, 'preprocess'):
|
||||
img, img_info = self.preprocess(img)
|
||||
|
||||
else:
|
||||
data_dict = {'img': img}
|
||||
data_dict = self.pipeline(data_dict)
|
||||
img = data_dict['img']
|
||||
img = torch.unsqueeze(img._data, 0).to(self.device)
|
||||
data_dict.pop('img')
|
||||
img_info = data_dict['img_metas']._data
|
||||
|
||||
if self.traceable:
|
||||
if self.use_trt_efficientnms:
|
||||
with torch.no_grad():
|
||||
tmp_out = self.model(img)
|
||||
det_out = {}
|
||||
det_out['detection_boxes'] = tmp_out[1] / img_info[
|
||||
'scale_factor'][0]
|
||||
det_out['detection_scores'] = tmp_out[2]
|
||||
det_out['detection_classes'] = tmp_out[3]
|
||||
|
||||
else:
|
||||
with torch.no_grad():
|
||||
det_out = self.post_assign(
|
||||
postprocess(
|
||||
self.model(img), self.num_classes,
|
||||
self.test_conf, self.nms_thre),
|
||||
img_metas=[img_info])
|
||||
else:
|
||||
with torch.no_grad():
|
||||
det_out = self.model(
|
||||
img, mode='test', img_metas=[img_info])
|
||||
|
||||
# print(det_out)
|
||||
# det_out = det_out[:self.max_det]
|
||||
# scale box to original image scale, this logic has some operation
|
||||
# that can not be traced, see
|
||||
# https://discuss.pytorch.org/t/windows-libtorch-c-load-cuda-module-with-std-runtime-error-message-shape-4-is-invalid-for-input-if-size-40/63073/4
|
||||
# det_out = scale_coords(img.shape[2:], det_out, ori_img_shape, (scale_factor, pad))
|
||||
|
||||
detection_scores = det_out['detection_scores'][0]
|
||||
|
||||
if detection_scores is not None:
|
||||
sel_ids = detection_scores > self.score_thresh
|
||||
detection_scores = detection_scores[sel_ids]
|
||||
detection_boxes = det_out['detection_boxes'][0][sel_ids]
|
||||
detection_classes = det_out['detection_classes'][0][sel_ids]
|
||||
else:
|
||||
detection_boxes = None
|
||||
detection_classes = None
|
||||
|
||||
num_boxes = detection_classes.shape[
|
||||
0] if detection_classes is not None else 0
|
||||
|
||||
detection_classes_names = [
|
||||
self.CLASSES[detection_classes[idx]]
|
||||
for idx in range(num_boxes)
|
||||
]
|
||||
|
||||
out = {
|
||||
'ori_img_shape': list(ori_img_shape),
|
||||
'detection_boxes': detection_boxes,
|
||||
'detection_scores': detection_scores,
|
||||
'detection_classes': detection_classes,
|
||||
'detection_class_names': detection_classes_names,
|
||||
}
|
||||
|
||||
output_list.append(out)
|
||||
|
||||
return output_list
|
||||
def predict(self, input_data_list, batch_size=-1, to_numpy=True):
|
||||
return super().__call__(input_data_list)
|
||||
|
||||
|
||||
@PREDICTORS.register_module()
|
||||
|
|
|
@ -7,152 +7,99 @@ import unittest
|
|||
import tempfile
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from easycv.predictors.detector import TorchYoloXPredictor, DetectionPredictor
|
||||
from easycv.predictors.detector import DetectionPredictor, YoloXPredictor, TorchYoloXPredictor
|
||||
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
|
||||
PRETRAINED_MODEL_YOLOXS_EXPORT_OLD,
|
||||
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT,
|
||||
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT,
|
||||
DET_DATA_SMALL_COCO_LOCAL)
|
||||
from numpy.testing import assert_array_almost_equal
|
||||
|
||||
|
||||
class DetectorTest(unittest.TestCase):
|
||||
class YoloXPredictorTest(unittest.TestCase):
|
||||
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017/000000522713.jpg')
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def test_yolox_old_detector(self):
|
||||
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT_OLD
|
||||
|
||||
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
||||
'val2017/000000522713.jpg')
|
||||
|
||||
input_data_list = [np.asarray(Image.open(img))]
|
||||
predictor = TorchYoloXPredictor(
|
||||
model_path=detection_model_path, score_thresh=0.5)
|
||||
|
||||
output = predictor.predict(input_data_list)[0]
|
||||
|
||||
def test_yolox_detector(self):
|
||||
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
|
||||
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
||||
'val2017/000000522713.jpg')
|
||||
|
||||
input_data_list = [np.asarray(Image.open(img))]
|
||||
predictor = TorchYoloXPredictor(
|
||||
model_path=detection_model_path, score_thresh=0.5)
|
||||
|
||||
output = predictor.predict(input_data_list)[0]
|
||||
self.assertIn('detection_boxes', output)
|
||||
self.assertIn('detection_scores', output)
|
||||
self.assertIn('detection_classes', output)
|
||||
self.assertIn('detection_class_names', output)
|
||||
self.assertIn('ori_img_shape', output)
|
||||
|
||||
self.assertEqual(len(output['detection_boxes']), 4)
|
||||
self.assertEqual(output['ori_img_shape'], [480, 640])
|
||||
|
||||
self.assertListEqual(output['detection_classes'].tolist(),
|
||||
def _assert_results(self, results):
|
||||
self.assertEqual(results['ori_img_shape'], [480, 640])
|
||||
self.assertListEqual(results['detection_classes'].tolist(),
|
||||
np.array([13, 8, 8, 8], dtype=np.int32).tolist())
|
||||
|
||||
self.assertListEqual(output['detection_class_names'],
|
||||
self.assertListEqual(results['detection_class_names'],
|
||||
['bench', 'boat', 'boat', 'boat'])
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_scores'],
|
||||
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004],
|
||||
results['detection_scores'],
|
||||
np.array([0.92335737, 0.59416807, 0.5567955, 0.55368793],
|
||||
dtype=np.float32),
|
||||
decimal=2)
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_boxes'],
|
||||
np.array([[407.89523, 284.62598, 561.4984, 356.7296],
|
||||
[439.37653, 263.42395, 467.01526, 271.79144],
|
||||
[480.8597, 269.64435, 502.18765, 274.80127],
|
||||
[510.37033, 268.4982, 527.67017, 273.04935]]),
|
||||
results['detection_boxes'],
|
||||
np.array([[408.1708, 285.11456, 561.84924, 356.42285],
|
||||
[438.88098, 264.46606, 467.07275, 271.76355],
|
||||
[510.19467, 268.46664, 528.26935, 273.37192],
|
||||
[480.9472, 269.74115, 502.00842, 274.85553]]),
|
||||
decimal=1)
|
||||
|
||||
def test_yolox_detector_jit_nopre_notrt(self):
|
||||
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
||||
'val2017/000000522713.jpg')
|
||||
def _base_test_single(self, model_path, inputs):
|
||||
predictor = YoloXPredictor(model_path=model_path, score_thresh=0.5)
|
||||
|
||||
input_data_list = [np.asarray(Image.open(img))]
|
||||
outputs = predictor(inputs)
|
||||
self.assertEqual(len(outputs), 1)
|
||||
output = outputs[0]
|
||||
self._assert_results(output)
|
||||
|
||||
def _base_test_batch(self, model_path, inputs, num_samples, batch_size):
|
||||
assert isinstance(inputs, list) and len(inputs) == 1
|
||||
|
||||
predictor = YoloXPredictor(
|
||||
model_path=model_path, score_thresh=0.5, batch_size=batch_size)
|
||||
outputs = predictor(inputs * num_samples)
|
||||
|
||||
self.assertEqual(len(outputs), num_samples)
|
||||
for output in outputs:
|
||||
self._assert_results(output)
|
||||
|
||||
def test_single_raw(self):
|
||||
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
inputs = [np.asarray(Image.open(self.img))]
|
||||
self._base_test_single(model_path, inputs)
|
||||
|
||||
def test_batch_raw(self):
|
||||
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
inputs = [np.asarray(Image.open(self.img))]
|
||||
self._base_test_batch(model_path, inputs, num_samples=3, batch_size=2)
|
||||
|
||||
def test_single_jit_nopre_notrt(self):
|
||||
jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT
|
||||
predictor_jit = TorchYoloXPredictor(
|
||||
model_path=jit_path, score_thresh=0.5)
|
||||
self._base_test_single(jit_path, self.img)
|
||||
|
||||
output = predictor_jit.predict(input_data_list)[0]
|
||||
self.assertIn('detection_boxes', output)
|
||||
self.assertIn('detection_scores', output)
|
||||
self.assertIn('detection_classes', output)
|
||||
self.assertIn('detection_class_names', output)
|
||||
self.assertIn('ori_img_shape', output)
|
||||
|
||||
self.assertEqual(len(output['detection_boxes']), 4)
|
||||
self.assertEqual(output['ori_img_shape'], [480, 640])
|
||||
|
||||
self.assertListEqual(output['detection_classes'].tolist(),
|
||||
np.array([13, 8, 8, 8], dtype=np.int32).tolist())
|
||||
|
||||
self.assertListEqual(output['detection_class_names'],
|
||||
['bench', 'boat', 'boat', 'boat'])
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_scores'],
|
||||
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004],
|
||||
dtype=np.float32),
|
||||
decimal=2)
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_boxes'],
|
||||
np.array([[407.89523, 284.62598, 561.4984, 356.7296],
|
||||
[439.37653, 263.42395, 467.01526, 271.79144],
|
||||
[480.8597, 269.64435, 502.18765, 274.80127],
|
||||
[510.37033, 268.4982, 527.67017, 273.04935]]),
|
||||
decimal=1)
|
||||
|
||||
def test_yolox_detector_jit_pre_trt(self):
|
||||
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
||||
'val2017/000000522713.jpg')
|
||||
|
||||
input_data_list = [np.asarray(Image.open(img))]
|
||||
def test_batch_jit_nopre_notrt(self):
|
||||
jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT
|
||||
self._base_test_batch(
|
||||
jit_path, [self.img], num_samples=2, batch_size=1)
|
||||
|
||||
def test_single_jit_pre_trt(self):
|
||||
jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT
|
||||
predictor_jit = TorchYoloXPredictor(
|
||||
model_path=jit_path, score_thresh=0.5)
|
||||
self._base_test_single(jit_path, [self.img])
|
||||
|
||||
output = predictor_jit.predict(input_data_list)[0]
|
||||
self.assertIn('detection_boxes', output)
|
||||
self.assertIn('detection_scores', output)
|
||||
self.assertIn('detection_classes', output)
|
||||
self.assertIn('detection_class_names', output)
|
||||
self.assertIn('ori_img_shape', output)
|
||||
def test_batch_jit_pre_trt(self):
|
||||
jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT
|
||||
self._base_test_batch(
|
||||
jit_path, [self.img], num_samples=4, batch_size=2)
|
||||
|
||||
self.assertEqual(len(output['detection_boxes']), 4)
|
||||
self.assertEqual(output['ori_img_shape'], [480, 640])
|
||||
def test_single_raw_TorchYoloXPredictor(self):
|
||||
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
input_data_list = [np.asarray(Image.open(self.img))]
|
||||
predictor = TorchYoloXPredictor(
|
||||
model_path=detection_model_path, score_thresh=0.5)
|
||||
output = predictor(input_data_list)[0]
|
||||
self._assert_results(output)
|
||||
|
||||
self.assertListEqual(output['detection_classes'].tolist(),
|
||||
np.array([13, 8, 8, 8], dtype=np.int32).tolist())
|
||||
|
||||
self.assertListEqual(output['detection_class_names'],
|
||||
['bench', 'boat', 'boat', 'boat'])
|
||||
class DetectionPredictorTest(unittest.TestCase):
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_scores'],
|
||||
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004],
|
||||
dtype=np.float32),
|
||||
decimal=2)
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_boxes'],
|
||||
np.array([[407.89523, 284.62598, 561.4984, 356.7296],
|
||||
[439.37653, 263.42395, 467.01526, 271.79144],
|
||||
[480.8597, 269.64435, 502.18765, 274.80127],
|
||||
[510.37033, 268.4982, 527.67017, 273.04935]]),
|
||||
decimal=1)
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def _detection_detector_assert(self, output):
|
||||
self.assertIn('detection_boxes', output)
|
||||
|
|
|
@ -6,7 +6,7 @@ import os
|
|||
import unittest
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from easycv.predictors.detector import TorchYoloXPredictor
|
||||
from easycv.predictors.detector import YoloXPredictor
|
||||
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE,
|
||||
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE,
|
||||
DET_DATA_SMALL_COCO_LOCAL)
|
||||
|
@ -17,90 +17,73 @@ from numpy.testing import assert_array_almost_equal
|
|||
|
||||
@unittest.skipIf(torch.__version__ != '1.8.1+cu102',
|
||||
'Blade need another environment')
|
||||
class DetectorTest(unittest.TestCase):
|
||||
class YoloXPredictorBladeTest(unittest.TestCase):
|
||||
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017/000000522713.jpg')
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def test_yolox_detector_blade_nopre_notrt(self):
|
||||
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
||||
'val2017/000000522713.jpg')
|
||||
|
||||
input_data_list = [np.asarray(Image.open(img))]
|
||||
|
||||
blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE
|
||||
predictor_blade = TorchYoloXPredictor(
|
||||
model_path=blade_path, score_thresh=0.5)
|
||||
|
||||
output = predictor_blade.predict(input_data_list)[0]
|
||||
self.assertIn('detection_boxes', output)
|
||||
self.assertIn('detection_scores', output)
|
||||
self.assertIn('detection_classes', output)
|
||||
self.assertIn('detection_class_names', output)
|
||||
self.assertIn('ori_img_shape', output)
|
||||
|
||||
self.assertEqual(len(output['detection_boxes']), 4)
|
||||
self.assertEqual(output['ori_img_shape'], [480, 640])
|
||||
|
||||
self.assertListEqual(output['detection_classes'].tolist(),
|
||||
def _assert_results(self, results):
|
||||
self.assertEqual(results['ori_img_shape'], [480, 640])
|
||||
self.assertListEqual(results['detection_classes'].tolist(),
|
||||
np.array([13, 8, 8, 8], dtype=np.int32).tolist())
|
||||
|
||||
self.assertListEqual(output['detection_class_names'],
|
||||
self.assertListEqual(results['detection_class_names'],
|
||||
['bench', 'boat', 'boat', 'boat'])
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_scores'],
|
||||
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004],
|
||||
results['detection_scores'],
|
||||
np.array([0.92335737, 0.59416807, 0.5567955, 0.55368793],
|
||||
dtype=np.float32),
|
||||
decimal=2)
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_boxes'],
|
||||
np.array([[407.89523, 284.62598, 561.4984, 356.7296],
|
||||
[439.37653, 263.42395, 467.01526, 271.79144],
|
||||
[480.8597, 269.64435, 502.18765, 274.80127],
|
||||
[510.37033, 268.4982, 527.67017, 273.04935]]),
|
||||
results['detection_boxes'],
|
||||
np.array([[408.1708, 285.11456, 561.84924, 356.42285],
|
||||
[438.88098, 264.46606, 467.07275, 271.76355],
|
||||
[510.19467, 268.46664, 528.26935, 273.37192],
|
||||
[480.9472, 269.74115, 502.00842, 274.85553]]),
|
||||
decimal=1)
|
||||
|
||||
def _base_test_single(self, model_path, inputs):
|
||||
predictor = YoloXPredictor(model_path=model_path, score_thresh=0.5)
|
||||
|
||||
outputs = predictor(inputs)
|
||||
self.assertEqual(len(outputs), 1)
|
||||
output = outputs[0]
|
||||
self._assert_results(output)
|
||||
|
||||
def _base_test_batch(self, model_path, inputs, num_samples, batch_size):
|
||||
assert isinstance(inputs, list) and len(inputs) == 1
|
||||
|
||||
predictor = YoloXPredictor(
|
||||
model_path=model_path, score_thresh=0.5, batch_size=batch_size)
|
||||
outputs = predictor(inputs * num_samples)
|
||||
|
||||
self.assertEqual(len(outputs), num_samples)
|
||||
for output in outputs:
|
||||
self._assert_results(output)
|
||||
|
||||
def test_blade_nopre_notrt(self):
|
||||
inputs = [np.asarray(Image.open(self.img))]
|
||||
blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE
|
||||
self._base_test_single(blade_path, inputs)
|
||||
|
||||
@unittest.skipIf(True,
|
||||
'Need export blade model to support dynamic batch size')
|
||||
def test_blade_nopre_notrt_batch(self):
|
||||
inputs = [np.asarray(Image.open(self.img))]
|
||||
blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE
|
||||
self._base_test_batch(blade_path, inputs, num_samples=4, batch_size=2)
|
||||
|
||||
def test_yolox_detector_blade_pre_notrt(self):
|
||||
img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
||||
'val2017/000000522713.jpg')
|
||||
|
||||
input_data_list = [np.asarray(Image.open(img))]
|
||||
|
||||
inputs = [self.img]
|
||||
blade_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE
|
||||
predictor_blade = TorchYoloXPredictor(
|
||||
model_path=blade_path, score_thresh=0.5)
|
||||
self._base_test_single(blade_path, inputs)
|
||||
|
||||
output = predictor_blade.predict(input_data_list)[0]
|
||||
self.assertIn('detection_boxes', output)
|
||||
self.assertIn('detection_scores', output)
|
||||
self.assertIn('detection_classes', output)
|
||||
self.assertIn('detection_class_names', output)
|
||||
self.assertIn('ori_img_shape', output)
|
||||
|
||||
self.assertEqual(len(output['detection_boxes']), 4)
|
||||
self.assertEqual(output['ori_img_shape'], [480, 640])
|
||||
|
||||
self.assertListEqual(output['detection_classes'].tolist(),
|
||||
np.array([13, 8, 8, 8], dtype=np.int32).tolist())
|
||||
|
||||
self.assertListEqual(output['detection_class_names'],
|
||||
['bench', 'boat', 'boat', 'boat'])
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_scores'],
|
||||
np.array([0.92593855, 0.60268813, 0.57775956, 0.5750004],
|
||||
dtype=np.float32),
|
||||
decimal=2)
|
||||
|
||||
assert_array_almost_equal(
|
||||
output['detection_boxes'],
|
||||
np.array([[407.89523, 284.62598, 561.4984, 356.7296],
|
||||
[439.37653, 263.42395, 467.01526, 271.79144],
|
||||
[480.8597, 269.64435, 502.18765, 274.80127],
|
||||
[510.37033, 268.4982, 527.67017, 273.04935]]),
|
||||
decimal=1)
|
||||
@unittest.skipIf(True,
|
||||
'Need export blade model to support dynamic batch size')
|
||||
def test_yolox_detector_blade_pre_notrt_batch(self):
|
||||
inputs = [self.img]
|
||||
blade_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE
|
||||
self._base_test_batch(blade_path, inputs, num_samples=3, batch_size=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue