Features/self test onnx (#330)

add yolox onnx export method
gulou 2023-10-31 14:56:08 +08:00 committed by GitHub
parent db33ced143
commit 8c3ba59aaf
5 changed files with 101 additions and 52 deletions


@@ -49,14 +49,14 @@ img_norm_cfg = dict(
     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
 train_pipeline = [
-    dict(type='MMMosaic', img_scale='${img_scale}', pad_val=114.0),
+    dict(type='MMMosaic', img_scale=tuple(img_scale), pad_val=114.0),
     dict(
         type='MMRandomAffine',
-        scaling_ratio_range='${scale_ratio}',
-        border=['-${img_scale}[0] // 2', '-${img_scale}[1] // 2']),
+        scaling_ratio_range=scale_ratio,
+        border=[img_scale[0] // 2, img_scale[1] // 2]),
     dict(
         type='MMMixUp',  # s m x l; tiny nano will detele
-        img_scale='${img_scale}',
+        img_scale=tuple(img_scale),
         ratio_range=(0.8, 1.6),
         pad_val=114.0),
     dict(
@@ -70,45 +70,43 @@ train_pipeline = [
     dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
     dict(
         type='MMNormalize',
-        mean='${img_norm_cfg.mean}',
-        std='${img_norm_cfg.std}',
-        to_rgb='${img_norm_cfg.to_rgb}'),
+        mean=img_norm_cfg['mean'],
+        std=img_norm_cfg['std'],
+        to_rgb=img_norm_cfg['to_rgb']),
     dict(type='DefaultFormatBundle'),
     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
 ]
 test_pipeline = [
-    dict(type='MMResize', img_scale='${img_scale}', keep_ratio=True),
+    dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
     dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
     dict(
         type='MMNormalize',
-        mean='${img_norm_cfg.mean}',
-        std='${img_norm_cfg.std}',
-        to_rgb='${img_norm_cfg.to_rgb}'),
+        mean=img_norm_cfg['mean'],
+        std=img_norm_cfg['std'],
+        to_rgb=img_norm_cfg['to_rgb']),
     dict(type='DefaultFormatBundle'),
     dict(type='Collect', keys=['img'])
 ]

-data = dict(
-    imgs_per_gpu=16,
-    workers_per_gpu=4,
-    train=dict(
-        type='DetImagesMixDataset',
-        data_source=dict(
-            type='DetSourcePAI',
-            path='data/coco/train2017.manifest',
-            classes='${CLASSES}'),
-        pipeline='${train_pipeline}',
-        dynamic_scale='${img_scale}'),
-    val=dict(
-        type='DetImagesMixDataset',
-        imgs_per_gpu=2,
-        data_source=dict(
-            type='DetSourcePAI',
-            path='data/coco/val2017.manifest',
-            classes='${CLASSES}'),
-        pipeline='${test_pipeline}',
-        dynamic_scale=None,
-        label_padding=False))
+train_path = 'data/coco/train2017.manifest'
+val_path = 'data/coco/val2017.manifest'
+
+train_dataset = dict(
+    type='DetImagesMixDataset',
+    data_source=dict(type='DetSourcePAI', path=train_path, classes=CLASSES),
+    pipeline=train_pipeline,
+    dynamic_scale=tuple(img_scale))
+
+val_dataset = dict(
+    type='DetImagesMixDataset',
+    imgs_per_gpu=2,
+    data_source=dict(type='DetSourcePAI', path=val_path, classes=CLASSES),
+    pipeline=test_pipeline,
+    dynamic_scale=None,
+    label_padding=False)
+
+data = dict(
+    imgs_per_gpu=16, workers_per_gpu=4, train=train_dataset, val=val_dataset)

 # additional hooks
 interval = 10
@@ -120,14 +118,14 @@ custom_hooks = [
         priority=48),
     dict(
         type='SyncRandomSizeHook',
-        ratio_range='${random_size}',
-        img_scale='${img_scale}',
-        interval='${interval}',
+        ratio_range=random_size,
+        img_scale=img_scale,
+        interval=interval,
         priority=48),
     dict(
         type='SyncNormHook',
         num_last_epochs=15,
-        interval='${interval}',
+        interval=interval,
         priority=48)
 ]
@@ -135,23 +133,23 @@ custom_hooks = [
 vis_num = 20
 score_thr = 0.5
 eval_config = dict(
-    interval='${interval}',
+    interval=interval,
     gpu_collect=False,
     visualization_config=dict(
-        vis_num='${vis_num}',
-        score_thr='${score_thr}',
+        vis_num=vis_num,
+        score_thr=score_thr,
     )  # show by TensorboardLoggerHookV2
 )
 eval_pipelines = [
     dict(
         mode='test',
-        data='${data.val}',
+        data=val_dataset,
         evaluators=[dict(type='CocoDetectionEvaluator', classes=CLASSES)],
     )
 ]

-checkpoint_config = dict(interval='${interval}')
+checkpoint_config = dict(interval=interval)

 # optimizer
 # basic_lr_per_img = 0.01 / 64.0
 optimizer = dict(
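
The config hunks above drop EasyCV's '${...}' string placeholders in favor of plain Python references, so values such as img_scale, CLASSES and interval are resolved when the file is parsed, and objects like val_dataset can be reused directly (data=val_dataset instead of data='${data.val}'). A minimal sketch of the pattern, using hypothetical stand-in values rather than the real config:

# Sketch only: placeholder-free config style; the values are illustrative.
CLASSES = ['cat', 'dog']
img_scale = (640, 640)
interval = 10

val_dataset = dict(
    type='DetImagesMixDataset',
    imgs_per_gpu=2,
    data_source=dict(
        type='DetSourcePAI',
        path='data/coco/val2017.manifest',
        classes=CLASSES),
    dynamic_scale=None)

data = dict(imgs_per_gpu=16, workers_per_gpu=4, val=val_dataset)
eval_pipelines = [
    dict(
        mode='test',
        data=val_dataset,  # was '${data.val}'
        evaluators=[dict(type='CocoDetectionEvaluator', classes=CLASSES)])
]
checkpoint_config = dict(interval=interval)  # was '${interval}'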


@@ -247,10 +247,10 @@ def _export_yolox(model, cfg, filename):
     if hasattr(cfg, 'export'):
         export_type = getattr(cfg.export, 'export_type', 'raw')
-        default_export_type_list = ['raw', 'jit', 'blade']
+        default_export_type_list = ['raw', 'jit', 'blade', 'onnx']
         if export_type not in default_export_type_list:
             logging.warning(
-                'YOLOX-PAI only supports the export type as [raw,jit,blade], otherwise we use raw as default'
+                'YOLOX-PAI only supports the export type as [raw,jit,blade,onnx], otherwise we use raw as default'
             )
             export_type = 'raw'
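
With 'onnx' added to default_export_type_list, the export branch is selected from the config's export section. A hedged sketch of such a section follows; only export_type is confirmed by this diff (it is read via getattr(cfg.export, 'export_type', 'raw')), the other keys are assumptions for illustration:

# Hedged sketch: choosing the new ONNX branch from a config.
export = dict(
    export_type='onnx',  # one of 'raw', 'jit', 'blade', 'onnx'
    batch_size=1,        # assumed: batch size of the dummy export input
    static_opt=True,     # assumed: static-shape optimization switch
)
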
@@ -276,7 +276,7 @@ def _export_yolox(model, cfg, filename):
             len(img_scale) == 2
         ), 'Export YoloX predictor config contains img_scale must be (int, int) tuple!'
-        input = 255 * torch.rand((batch_size, 3) + img_scale)
+        input = 255 * torch.rand((batch_size, 3) + tuple(img_scale))

         # assert use_trt_efficientnms only happens when static_opt=True
         if static_opt is not True:
@@ -355,6 +355,31 @@ def _export_yolox(model, cfg, filename):
             json.dump(config, ofile)

+    if export_type == 'onnx':
+        with io.open(
+                filename + '.config.json' if filename.endswith('onnx')
+                else filename + '.onnx.config.json', 'w') as ofile:
+            config = dict(
+                model=cfg.model,
+                export=cfg.export,
+                test_pipeline=cfg.test_pipeline,
+                classes=cfg.CLASSES)
+            json.dump(config, ofile)
+
+        torch.onnx.export(
+            model,
+            input.to(device),
+            filename if filename.endswith('onnx') else filename +
+            '.onnx',
+            export_params=True,
+            opset_version=12,
+            do_constant_folding=True,
+            input_names=['input'],
+            output_names=['output'],
+        )
+
     if export_type == 'jit':
         with io.open(filename + '.jit', 'wb') as ofile:
             torch.jit.save(yolox_trace, ofile)
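
The exported file can be loaded back with onnxruntime as a quick sanity check. This is not part of the commit; it assumes the export above wrote 'model.onnx' with the single input named 'input', and 640x640 is only an example input size:

# Hedged sketch (not in the diff): verify an exported YOLOX ONNX file.
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    'model.onnx', providers=['CPUExecutionProvider'])
inp = sess.get_inputs()[0]
print(inp.name, inp.shape)  # expected: 'input', [batch_size, 3, H, W]

# random image-like data, mirroring the dummy tensor used for export
dummy = (255 * np.random.rand(1, 3, 640, 640)).astype(np.float32)
outputs = sess.run(None, {inp.name: dummy})
print([o.shape for o in outputs])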


@@ -23,6 +23,12 @@ except Exception:
 from .interface import PredictorInterface


+# convert a tensor to an ndarray
+def onnx_to_numpy(tensor):
+    return tensor.detach().cpu().numpy(
+    ) if tensor.requires_grad else tensor.cpu().numpy()
+
+
 class DetInputProcessor(InputProcessor):

     def build_processor(self):
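
The onnx_to_numpy helper exists because calling .numpy() on a tensor that still tracks gradients raises a RuntimeError, so such tensors must be detached first. A small standalone illustration, separate from the diff:

# Illustration only: why the requires_grad branch is needed.
import torch

t = torch.rand(1, 3, requires_grad=True)
# t.cpu().numpy() would raise: "Can't call numpy() on Tensor that requires grad"
arr = t.detach().cpu().numpy() if t.requires_grad else t.cpu().numpy()
print(arr.shape)  # (1, 3)
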
@@ -349,9 +355,11 @@ class YoloXPredictor(DetectionPredictor):
             self.model_type = 'jit'
         elif model_path.endswith('blade'):
             self.model_type = 'blade'
+        elif model_path.endswith('onnx'):
+            self.model_type = 'onnx'
         else:
             self.model_type = 'raw'
-        assert self.model_type in ['raw', 'jit', 'blade']
+        assert self.model_type in ['raw', 'jit', 'blade', 'onnx']

         if self.model_type == 'blade' or self.use_trt_efficientnms:
             import torch_blade
@@ -381,8 +389,16 @@ class YoloXPredictor(DetectionPredictor):

     def _build_model(self):
         if self.model_type != 'raw':
-            with io.open(self.model_path, 'rb') as infile:
-                model = torch.jit.load(infile, self.device)
+            if self.model_type != 'onnx':
+                with io.open(self.model_path, 'rb') as infile:
+                    model = torch.jit.load(infile, self.device)
+            else:
+                import onnxruntime
+                if onnxruntime.get_device() == 'GPU':
+                    model = onnxruntime.InferenceSession(
+                        self.model_path, providers=['CUDAExecutionProvider'])
+                else:
+                    model = onnxruntime.InferenceSession(self.model_path)
         else:
             from easycv.utils.misc import reparameterize_models
             model = super()._build_model()
@@ -394,6 +410,7 @@ class YoloXPredictor(DetectionPredictor):
         If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
         """
         model = self._build_model()
-        model.to(self.device)
-        model.eval()
+        if self.model_type != 'onnx':
+            model.to(self.device)
+            model.eval()
         if self.model_type == 'raw':
@@ -406,7 +423,15 @@ class YoloXPredictor(DetectionPredictor):
         """
         if self.model_type != 'raw':
             with torch.no_grad():
-                outputs = self.model(inputs['img'])
+                if self.model_type != 'onnx':
+                    outputs = self.model(inputs['img'])
+                else:
+                    outputs = self.model.run(
+                        None, {
+                            self.model.get_inputs()[0].name:
+                            onnx_to_numpy(inputs['img'])
+                        })[0]
+                    outputs = torch.from_numpy(outputs)
                 outputs = {'results': outputs}  # convert to dict format
         else:
             outputs = super().model_forward(inputs)
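
With the predictor changes above, a model file ending in '.onnx' is served through onnxruntime while jit/blade/raw models keep their existing paths. A hedged usage sketch; the constructor and call signature follow the DetectionPredictor interface as I read it, and the file paths are placeholders:

# Hedged usage sketch (not part of the commit).
from easycv.predictors.detector import YoloXPredictor

# a model_path ending in 'onnx' makes the predictor pick model_type = 'onnx'
predictor = YoloXPredictor(model_path='export/yolox_s.onnx')
results = predictor(['demo/demo_image.jpg'])  # list of image paths
print(results[0])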


@@ -13,6 +13,7 @@ lmdb
 numba
 numpy
 nuscenes-devkit
+onnxruntime
 opencv-python
 oss2
 packaging


@@ -83,12 +83,12 @@ class PredictTest(unittest.TestCase):
         oss_config = get_oss_config()
         ak_id = oss_config['ak_id']
         ak_secret = oss_config['ak_secret']
-        hosts = oss_config['hosts'] + ['oss-cn-hangzhou.aliyuncs.com']
+        hosts = oss_config['hosts']
         hosts = ','.join(_ for _ in hosts)
-        buckets = oss_config['buckets'] + ['easycv']
+        buckets = oss_config['buckets']
         buckets = ','.join(_ for _ in buckets)

-        input_file = 'oss://easycv/data/small_test_data/test_images/http_image_list.txt'
+        input_file = 'oss://pai-vision-data-hz/unittest/local_backup/easycv_nfs/data/test_images/http_image_list.txt'
         output_file = tempfile.NamedTemporaryFile('w').name
         cmd = f'PYTHONPATH=. python tools/predict.py \
             --input_file {input_file} \