mirror of https://github.com/alibaba/EasyCV.git
parent
1d1ac8aa5e
commit
8c90ceaf84
|
@ -6,3 +6,6 @@ model = dict(
|
|||
depth=50,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN')))
|
||||
|
||||
checkpoint_sync_export = True
|
||||
export = dict(export_type='raw', export_neck=True)
|
||||
|
|
|
@ -157,6 +157,37 @@ def _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=False):
|
|||
torch.jit.save(blade_model, ofile)
|
||||
|
||||
|
||||
def _export_onnx_cls(model, model_config, cfg, filename, meta):
|
||||
|
||||
if model_config['backbone'].get(
|
||||
'type', None) == 'ResNet' and model_config['backbone'].get(
|
||||
'depth', None) == 50:
|
||||
# save json config for test_pipline and class
|
||||
with io.open(
|
||||
filename +
|
||||
'.config.json' if filename.endswith('onnx') else filename +
|
||||
'.onnx.config.json', 'w') as ofile:
|
||||
json.dump(meta, ofile)
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model.eval()
|
||||
model.to(device)
|
||||
img_size = int(cfg.image_size2)
|
||||
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(x_input, 'onnx'),
|
||||
filename if filename.endswith('onnx') else filename + '.onnx',
|
||||
export_params=True,
|
||||
opset_version=12,
|
||||
do_constant_folding=True,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
)
|
||||
else:
|
||||
raise ValueError('Only support export onnx model for ResNet now!')
|
||||
|
||||
|
||||
def _export_cls(model, cfg, filename):
|
||||
""" export cls (cls & metric learning)model and preprocess config
|
||||
|
||||
|
@ -170,6 +201,7 @@ def _export_cls(model, cfg, filename):
|
|||
else:
|
||||
export_cfg = dict(export_neck=False)
|
||||
|
||||
export_type = export_cfg.get('export_type', 'raw')
|
||||
export_neck = export_cfg.get('export_neck', True)
|
||||
label_map_path = cfg.get('label_map_path', None)
|
||||
class_list = None
|
||||
|
@ -232,9 +264,14 @@ def _export_cls(model, cfg, filename):
|
|||
if export_neck and (k.startswith('neck') or k.startswith('head')):
|
||||
state_dict[k] = v
|
||||
|
||||
checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
|
||||
with io.open(filename, 'wb') as ofile:
|
||||
torch.save(checkpoint, ofile)
|
||||
if export_type == 'raw':
|
||||
checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
|
||||
with io.open(filename, 'wb') as ofile:
|
||||
torch.save(checkpoint, ofile)
|
||||
elif export_type == 'onnx':
|
||||
_export_onnx_cls(model, model_config, cfg, filename, config)
|
||||
else:
|
||||
raise ValueError('Only support export onnx/raw model!')
|
||||
|
||||
|
||||
def _export_yolox(model, cfg, filename):
|
||||
|
|
|
@ -151,6 +151,20 @@ class Classification(BaseModel):
|
|||
x = self.backbone(img)
|
||||
return x
|
||||
|
||||
def forward_onnx(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
forward_onnx means generate prob from image only support one neck + one head
|
||||
"""
|
||||
x = self.forward_backbone(img) # tuple
|
||||
|
||||
# if self.neck_num > 0:
|
||||
if hasattr(self, 'neck_0'):
|
||||
x = self.neck_0([i for i in x])
|
||||
|
||||
out = self.head_0(x)[0].cpu()
|
||||
out = self.activate_fn(out)
|
||||
return out
|
||||
|
||||
@torch.jit.unused
|
||||
def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
|
@ -290,6 +304,9 @@ class Classification(BaseModel):
|
|||
return self.forward_test_label(img, gt_labels)
|
||||
else:
|
||||
return self.forward_test(img)
|
||||
elif mode == 'onnx':
|
||||
return self.forward_onnx(img)
|
||||
|
||||
elif mode == 'extract':
|
||||
rd = self.forward_feature(img)
|
||||
rv = {}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -7,11 +9,18 @@ from PIL import Image
|
|||
|
||||
from easycv.file import io
|
||||
from easycv.framework.errors import ValueError
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.misc import deprecated
|
||||
from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
|
||||
from .builder import PREDICTORS
|
||||
|
||||
|
||||
# onnx specific
|
||||
def onnx_to_numpy(tensor):
|
||||
return tensor.detach().cpu().numpy(
|
||||
) if tensor.requires_grad else tensor.cpu().numpy()
|
||||
|
||||
|
||||
class ClsInputProcessor(InputProcessor):
|
||||
"""Process inputs for classification models.
|
||||
|
||||
|
@ -146,6 +155,20 @@ class ClassificationPredictor(PredictorV2):
|
|||
self.pil_input = pil_input
|
||||
self.label_map_path = label_map_path
|
||||
|
||||
if model_path.endswith('onnx'):
|
||||
self.model_type = 'onnx'
|
||||
pwd_model = os.path.dirname(model_path)
|
||||
raw_model = glob.glob(
|
||||
os.path.join(pwd_model, '*.onnx.config.json'))
|
||||
if len(raw_model) != 0:
|
||||
config_file = raw_model[0]
|
||||
else:
|
||||
assert len(
|
||||
raw_model
|
||||
) == 0, 'Please have a file with the .onnx.config.json extension in your directory'
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
|
||||
if self.pil_input:
|
||||
mode = 'RGB'
|
||||
super(ClassificationPredictor, self).__init__(
|
||||
|
@ -186,6 +209,41 @@ class ClassificationPredictor(PredictorV2):
|
|||
|
||||
return ClsOutputProcessor(topk=self.topk, label_map=self.label_map)
|
||||
|
||||
def prepare_model(self):
|
||||
"""Build model from config file by default.
|
||||
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
|
||||
"""
|
||||
if self.model_type == 'raw':
|
||||
model = self._build_model()
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
load_checkpoint(model, self.model_path, map_location='cpu')
|
||||
return model
|
||||
else:
|
||||
import onnxruntime
|
||||
if onnxruntime.get_device() == 'GPU':
|
||||
onnx_model = onnxruntime.InferenceSession(
|
||||
self.model_path, providers=['CUDAExecutionProvider'])
|
||||
else:
|
||||
onnx_model = onnxruntime.InferenceSession(self.model_path)
|
||||
|
||||
return onnx_model
|
||||
|
||||
def model_forward(self, inputs):
|
||||
"""Model forward.
|
||||
If you need refactor model forward, you need to reimplement it.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
if self.model_type == 'raw':
|
||||
outputs = self.model(**inputs, mode='test')
|
||||
else:
|
||||
outputs = self.model.run(None, {
|
||||
self.model.get_inputs()[0].name:
|
||||
onnx_to_numpy(inputs['img'])
|
||||
})[0]
|
||||
outputs = dict(prob=torch.from_numpy(outputs))
|
||||
return outputs
|
||||
|
||||
|
||||
try:
|
||||
from easy_vision.python.inference.predictor import PredictorInterface
|
||||
|
|
|
@ -2,5 +2,5 @@
|
|||
# GENERATED VERSION FILE
|
||||
# TIME: Thu Nov 5 14:17:50 2020
|
||||
|
||||
__version__ = '0.11.6'
|
||||
short_version = '0.11.6'
|
||||
__version__ = '0.11.7'
|
||||
short_version = '0.11.7'
|
||||
|
|
|
@ -116,6 +116,7 @@ class ModelExportTest(unittest.TestCase):
|
|||
cfg = mmcv_config_fromfile(config_file)
|
||||
cfg_options = {
|
||||
'model.backbone.norm_cfg.type': 'SyncBN',
|
||||
'export.export_type': 'raw'
|
||||
}
|
||||
if cfg_options is not None:
|
||||
cfg.merge_from_dict(cfg_options)
|
||||
|
@ -210,6 +211,27 @@ class ModelExportTest(unittest.TestCase):
|
|||
|
||||
self.assertTrue(os.path.exists(filename + '.jit'))
|
||||
|
||||
def test_export_resnet_onnx(self):
|
||||
|
||||
ckpt_path = PRETRAINED_MODEL_RESNET50
|
||||
|
||||
easycv_dir = os.path.dirname(easycv.__file__)
|
||||
|
||||
if os.path.exists(os.path.join(easycv_dir, 'configs')):
|
||||
config_dir = os.path.join(easycv_dir, 'configs')
|
||||
else:
|
||||
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
|
||||
config_file = os.path.join(
|
||||
config_dir,
|
||||
'classification/imagenet/resnet/imagenet_resnet50_jpg.py')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cfg = mmcv_config_fromfile(config_file)
|
||||
cfg.export.export_type = 'onnx'
|
||||
filename = os.path.join(tmpdir, 'imagenet_resnet50')
|
||||
export(cfg, ckpt_path, filename)
|
||||
self.assertTrue(os.path.exists(filename + '.onnx'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -11,7 +11,8 @@ import torch
|
|||
from easycv.predictors.classifier import ClassificationPredictor
|
||||
from easycv.utils.test_util import clean_up, get_tmp_dir
|
||||
from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD,
|
||||
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR)
|
||||
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR,
|
||||
PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD)
|
||||
|
||||
|
||||
class ClassificationPredictorTest(unittest.TestCase):
|
||||
|
@ -33,6 +34,17 @@ class ClassificationPredictorTest(unittest.TestCase):
|
|||
self.assertListEqual(results['class_name'], ['"Persian cat",'])
|
||||
self.assertEqual(len(results['class_probs']), 1000)
|
||||
|
||||
def test_onnx_single(self):
|
||||
checkpoint = PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD
|
||||
predict_op = ClassificationPredictor(model_path=checkpoint)
|
||||
|
||||
img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')
|
||||
|
||||
results = predict_op([img_path])[0]
|
||||
self.assertListEqual(results['class'], [578])
|
||||
self.assertListEqual(results['class_name'], ['gown'])
|
||||
self.assertEqual(len(results['class_probs']), 1000)
|
||||
|
||||
def test_batch(self):
|
||||
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
|
||||
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
|
||||
|
|
|
@ -54,10 +54,10 @@ class PoseTopDownPredictorTest(unittest.TestCase):
|
|||
|
||||
assert_array_almost_equal(
|
||||
result0['bbox'],
|
||||
np.array([[352.3085, 59.00325, 691.4247, 511.15814, 1.],
|
||||
[10.511196, 177.74883, 101.824326, 299.49966, 1.],
|
||||
[224.82036, 114.439865, 312.51306, 231.36348, 1.],
|
||||
[200.71407, 114.716736, 337.17535, 296.6651, 1.]],
|
||||
np.array([[438.9, 59., 604.8, 511.2, 0.9],
|
||||
[10.5, 179.6, 101.8, 297.7, 0.9],
|
||||
[229.6, 114.4, 307.8, 231.4, 0.6],
|
||||
[229.4, 114.7, 308.5, 296.7, 0.6]],
|
||||
dtype=np.float32),
|
||||
decimal=1)
|
||||
vis_result = predictor.show_result(img1, result0)
|
||||
|
@ -92,10 +92,10 @@ class PoseTopDownPredictorTest(unittest.TestCase):
|
|||
|
||||
assert_array_almost_equal(
|
||||
result1['bbox'][:4],
|
||||
np.array([[436.23096, 214.72766, 584.26013, 412.09985, 1.],
|
||||
[43.990044, 91.04126, 164.28406, 251.43329, 1.],
|
||||
[127.44148, 100.38604, 254.219, 269.42273, 1.],
|
||||
[190.08075, 117.31801, 311.22394, 278.8423, 1.]],
|
||||
np.array([[470.6, 214.7, 549.9, 412.1, 0.9],
|
||||
[71.6, 91., 136.7, 251.4, 0.9],
|
||||
[159.7, 100.4, 221.9, 269.4, 0.9],
|
||||
[219.4, 117.3, 281.9, 278.8, 0.9]],
|
||||
dtype=np.float32),
|
||||
decimal=1)
|
||||
vis_result = predictor.show_result(img2, result1)
|
||||
|
|
|
@ -179,6 +179,9 @@ PRETRAINED_MODEL_RESNET50 = os.path.join(
|
|||
PRETRAINED_MODEL_RESNET50_WITHOUTHEAD = os.path.join(
|
||||
BASE_LOCAL_PATH,
|
||||
'pretrained_models/classification/resnet/resnet50_withhead.pth')
|
||||
PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD = os.path.join(
|
||||
BASE_LOCAL_PATH,
|
||||
'pretrained_models/classification/resnet/imagenet_resnet50.onnx')
|
||||
PRETRAINED_MODEL_FACEID = os.path.join(BASE_LOCAL_PATH,
|
||||
'pretrained_models/faceid')
|
||||
PRETRAINED_MODEL_YOLOXS_EXPORT = os.path.join(
|
||||
|
|
Loading…
Reference in New Issue