mirror of https://github.com/alibaba/EasyCV.git
parent
1d1ac8aa5e
commit
8c90ceaf84
|
@ -6,3 +6,6 @@ model = dict(
|
||||||
depth=50,
|
depth=50,
|
||||||
out_indices=[4], # 0: conv-1, x: stage-x
|
out_indices=[4], # 0: conv-1, x: stage-x
|
||||||
norm_cfg=dict(type='BN')))
|
norm_cfg=dict(type='BN')))
|
||||||
|
|
||||||
|
checkpoint_sync_export = True
|
||||||
|
export = dict(export_type='raw', export_neck=True)
|
||||||
|
|
|
@ -157,6 +157,37 @@ def _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=False):
|
||||||
torch.jit.save(blade_model, ofile)
|
torch.jit.save(blade_model, ofile)
|
||||||
|
|
||||||
|
|
||||||
|
def _export_onnx_cls(model, model_config, cfg, filename, meta):
|
||||||
|
|
||||||
|
if model_config['backbone'].get(
|
||||||
|
'type', None) == 'ResNet' and model_config['backbone'].get(
|
||||||
|
'depth', None) == 50:
|
||||||
|
# save json config for test_pipline and class
|
||||||
|
with io.open(
|
||||||
|
filename +
|
||||||
|
'.config.json' if filename.endswith('onnx') else filename +
|
||||||
|
'.onnx.config.json', 'w') as ofile:
|
||||||
|
json.dump(meta, ofile)
|
||||||
|
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
model.eval()
|
||||||
|
model.to(device)
|
||||||
|
img_size = int(cfg.image_size2)
|
||||||
|
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
(x_input, 'onnx'),
|
||||||
|
filename if filename.endswith('onnx') else filename + '.onnx',
|
||||||
|
export_params=True,
|
||||||
|
opset_version=12,
|
||||||
|
do_constant_folding=True,
|
||||||
|
input_names=['input'],
|
||||||
|
output_names=['output'],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('Only support export onnx model for ResNet now!')
|
||||||
|
|
||||||
|
|
||||||
def _export_cls(model, cfg, filename):
|
def _export_cls(model, cfg, filename):
|
||||||
""" export cls (cls & metric learning)model and preprocess config
|
""" export cls (cls & metric learning)model and preprocess config
|
||||||
|
|
||||||
|
@ -170,6 +201,7 @@ def _export_cls(model, cfg, filename):
|
||||||
else:
|
else:
|
||||||
export_cfg = dict(export_neck=False)
|
export_cfg = dict(export_neck=False)
|
||||||
|
|
||||||
|
export_type = export_cfg.get('export_type', 'raw')
|
||||||
export_neck = export_cfg.get('export_neck', True)
|
export_neck = export_cfg.get('export_neck', True)
|
||||||
label_map_path = cfg.get('label_map_path', None)
|
label_map_path = cfg.get('label_map_path', None)
|
||||||
class_list = None
|
class_list = None
|
||||||
|
@ -232,9 +264,14 @@ def _export_cls(model, cfg, filename):
|
||||||
if export_neck and (k.startswith('neck') or k.startswith('head')):
|
if export_neck and (k.startswith('neck') or k.startswith('head')):
|
||||||
state_dict[k] = v
|
state_dict[k] = v
|
||||||
|
|
||||||
checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
|
if export_type == 'raw':
|
||||||
with io.open(filename, 'wb') as ofile:
|
checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
|
||||||
torch.save(checkpoint, ofile)
|
with io.open(filename, 'wb') as ofile:
|
||||||
|
torch.save(checkpoint, ofile)
|
||||||
|
elif export_type == 'onnx':
|
||||||
|
_export_onnx_cls(model, model_config, cfg, filename, config)
|
||||||
|
else:
|
||||||
|
raise ValueError('Only support export onnx/raw model!')
|
||||||
|
|
||||||
|
|
||||||
def _export_yolox(model, cfg, filename):
|
def _export_yolox(model, cfg, filename):
|
||||||
|
|
|
@ -151,6 +151,20 @@ class Classification(BaseModel):
|
||||||
x = self.backbone(img)
|
x = self.backbone(img)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def forward_onnx(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
forward_onnx means generate prob from image only support one neck + one head
|
||||||
|
"""
|
||||||
|
x = self.forward_backbone(img) # tuple
|
||||||
|
|
||||||
|
# if self.neck_num > 0:
|
||||||
|
if hasattr(self, 'neck_0'):
|
||||||
|
x = self.neck_0([i for i in x])
|
||||||
|
|
||||||
|
out = self.head_0(x)[0].cpu()
|
||||||
|
out = self.activate_fn(out)
|
||||||
|
return out
|
||||||
|
|
||||||
@torch.jit.unused
|
@torch.jit.unused
|
||||||
def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
|
def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
@ -290,6 +304,9 @@ class Classification(BaseModel):
|
||||||
return self.forward_test_label(img, gt_labels)
|
return self.forward_test_label(img, gt_labels)
|
||||||
else:
|
else:
|
||||||
return self.forward_test(img)
|
return self.forward_test(img)
|
||||||
|
elif mode == 'onnx':
|
||||||
|
return self.forward_onnx(img)
|
||||||
|
|
||||||
elif mode == 'extract':
|
elif mode == 'extract':
|
||||||
rd = self.forward_feature(img)
|
rd = self.forward_feature(img)
|
||||||
rv = {}
|
rv = {}
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import glob
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -7,11 +9,18 @@ from PIL import Image
|
||||||
|
|
||||||
from easycv.file import io
|
from easycv.file import io
|
||||||
from easycv.framework.errors import ValueError
|
from easycv.framework.errors import ValueError
|
||||||
|
from easycv.utils.checkpoint import load_checkpoint
|
||||||
from easycv.utils.misc import deprecated
|
from easycv.utils.misc import deprecated
|
||||||
from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
|
from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
|
||||||
from .builder import PREDICTORS
|
from .builder import PREDICTORS
|
||||||
|
|
||||||
|
|
||||||
|
# onnx specific
|
||||||
|
def onnx_to_numpy(tensor):
|
||||||
|
return tensor.detach().cpu().numpy(
|
||||||
|
) if tensor.requires_grad else tensor.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
class ClsInputProcessor(InputProcessor):
|
class ClsInputProcessor(InputProcessor):
|
||||||
"""Process inputs for classification models.
|
"""Process inputs for classification models.
|
||||||
|
|
||||||
|
@ -146,6 +155,20 @@ class ClassificationPredictor(PredictorV2):
|
||||||
self.pil_input = pil_input
|
self.pil_input = pil_input
|
||||||
self.label_map_path = label_map_path
|
self.label_map_path = label_map_path
|
||||||
|
|
||||||
|
if model_path.endswith('onnx'):
|
||||||
|
self.model_type = 'onnx'
|
||||||
|
pwd_model = os.path.dirname(model_path)
|
||||||
|
raw_model = glob.glob(
|
||||||
|
os.path.join(pwd_model, '*.onnx.config.json'))
|
||||||
|
if len(raw_model) != 0:
|
||||||
|
config_file = raw_model[0]
|
||||||
|
else:
|
||||||
|
assert len(
|
||||||
|
raw_model
|
||||||
|
) == 0, 'Please have a file with the .onnx.config.json extension in your directory'
|
||||||
|
else:
|
||||||
|
self.model_type = 'raw'
|
||||||
|
|
||||||
if self.pil_input:
|
if self.pil_input:
|
||||||
mode = 'RGB'
|
mode = 'RGB'
|
||||||
super(ClassificationPredictor, self).__init__(
|
super(ClassificationPredictor, self).__init__(
|
||||||
|
@ -186,6 +209,41 @@ class ClassificationPredictor(PredictorV2):
|
||||||
|
|
||||||
return ClsOutputProcessor(topk=self.topk, label_map=self.label_map)
|
return ClsOutputProcessor(topk=self.topk, label_map=self.label_map)
|
||||||
|
|
||||||
|
def prepare_model(self):
|
||||||
|
"""Build model from config file by default.
|
||||||
|
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
|
||||||
|
"""
|
||||||
|
if self.model_type == 'raw':
|
||||||
|
model = self._build_model()
|
||||||
|
model.to(self.device)
|
||||||
|
model.eval()
|
||||||
|
load_checkpoint(model, self.model_path, map_location='cpu')
|
||||||
|
return model
|
||||||
|
else:
|
||||||
|
import onnxruntime
|
||||||
|
if onnxruntime.get_device() == 'GPU':
|
||||||
|
onnx_model = onnxruntime.InferenceSession(
|
||||||
|
self.model_path, providers=['CUDAExecutionProvider'])
|
||||||
|
else:
|
||||||
|
onnx_model = onnxruntime.InferenceSession(self.model_path)
|
||||||
|
|
||||||
|
return onnx_model
|
||||||
|
|
||||||
|
def model_forward(self, inputs):
|
||||||
|
"""Model forward.
|
||||||
|
If you need refactor model forward, you need to reimplement it.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.model_type == 'raw':
|
||||||
|
outputs = self.model(**inputs, mode='test')
|
||||||
|
else:
|
||||||
|
outputs = self.model.run(None, {
|
||||||
|
self.model.get_inputs()[0].name:
|
||||||
|
onnx_to_numpy(inputs['img'])
|
||||||
|
})[0]
|
||||||
|
outputs = dict(prob=torch.from_numpy(outputs))
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from easy_vision.python.inference.predictor import PredictorInterface
|
from easy_vision.python.inference.predictor import PredictorInterface
|
||||||
|
|
|
@ -2,5 +2,5 @@
|
||||||
# GENERATED VERSION FILE
|
# GENERATED VERSION FILE
|
||||||
# TIME: Thu Nov 5 14:17:50 2020
|
# TIME: Thu Nov 5 14:17:50 2020
|
||||||
|
|
||||||
__version__ = '0.11.6'
|
__version__ = '0.11.7'
|
||||||
short_version = '0.11.6'
|
short_version = '0.11.7'
|
||||||
|
|
|
@ -116,6 +116,7 @@ class ModelExportTest(unittest.TestCase):
|
||||||
cfg = mmcv_config_fromfile(config_file)
|
cfg = mmcv_config_fromfile(config_file)
|
||||||
cfg_options = {
|
cfg_options = {
|
||||||
'model.backbone.norm_cfg.type': 'SyncBN',
|
'model.backbone.norm_cfg.type': 'SyncBN',
|
||||||
|
'export.export_type': 'raw'
|
||||||
}
|
}
|
||||||
if cfg_options is not None:
|
if cfg_options is not None:
|
||||||
cfg.merge_from_dict(cfg_options)
|
cfg.merge_from_dict(cfg_options)
|
||||||
|
@ -210,6 +211,27 @@ class ModelExportTest(unittest.TestCase):
|
||||||
|
|
||||||
self.assertTrue(os.path.exists(filename + '.jit'))
|
self.assertTrue(os.path.exists(filename + '.jit'))
|
||||||
|
|
||||||
|
def test_export_resnet_onnx(self):
|
||||||
|
|
||||||
|
ckpt_path = PRETRAINED_MODEL_RESNET50
|
||||||
|
|
||||||
|
easycv_dir = os.path.dirname(easycv.__file__)
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(easycv_dir, 'configs')):
|
||||||
|
config_dir = os.path.join(easycv_dir, 'configs')
|
||||||
|
else:
|
||||||
|
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
|
||||||
|
config_file = os.path.join(
|
||||||
|
config_dir,
|
||||||
|
'classification/imagenet/resnet/imagenet_resnet50_jpg.py')
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
cfg = mmcv_config_fromfile(config_file)
|
||||||
|
cfg.export.export_type = 'onnx'
|
||||||
|
filename = os.path.join(tmpdir, 'imagenet_resnet50')
|
||||||
|
export(cfg, ckpt_path, filename)
|
||||||
|
self.assertTrue(os.path.exists(filename + '.onnx'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -11,7 +11,8 @@ import torch
|
||||||
from easycv.predictors.classifier import ClassificationPredictor
|
from easycv.predictors.classifier import ClassificationPredictor
|
||||||
from easycv.utils.test_util import clean_up, get_tmp_dir
|
from easycv.utils.test_util import clean_up, get_tmp_dir
|
||||||
from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD,
|
from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD,
|
||||||
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR)
|
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR,
|
||||||
|
PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD)
|
||||||
|
|
||||||
|
|
||||||
class ClassificationPredictorTest(unittest.TestCase):
|
class ClassificationPredictorTest(unittest.TestCase):
|
||||||
|
@ -33,6 +34,17 @@ class ClassificationPredictorTest(unittest.TestCase):
|
||||||
self.assertListEqual(results['class_name'], ['"Persian cat",'])
|
self.assertListEqual(results['class_name'], ['"Persian cat",'])
|
||||||
self.assertEqual(len(results['class_probs']), 1000)
|
self.assertEqual(len(results['class_probs']), 1000)
|
||||||
|
|
||||||
|
def test_onnx_single(self):
|
||||||
|
checkpoint = PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD
|
||||||
|
predict_op = ClassificationPredictor(model_path=checkpoint)
|
||||||
|
|
||||||
|
img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')
|
||||||
|
|
||||||
|
results = predict_op([img_path])[0]
|
||||||
|
self.assertListEqual(results['class'], [578])
|
||||||
|
self.assertListEqual(results['class_name'], ['gown'])
|
||||||
|
self.assertEqual(len(results['class_probs']), 1000)
|
||||||
|
|
||||||
def test_batch(self):
|
def test_batch(self):
|
||||||
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
|
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
|
||||||
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
|
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
|
||||||
|
|
|
@ -54,10 +54,10 @@ class PoseTopDownPredictorTest(unittest.TestCase):
|
||||||
|
|
||||||
assert_array_almost_equal(
|
assert_array_almost_equal(
|
||||||
result0['bbox'],
|
result0['bbox'],
|
||||||
np.array([[352.3085, 59.00325, 691.4247, 511.15814, 1.],
|
np.array([[438.9, 59., 604.8, 511.2, 0.9],
|
||||||
[10.511196, 177.74883, 101.824326, 299.49966, 1.],
|
[10.5, 179.6, 101.8, 297.7, 0.9],
|
||||||
[224.82036, 114.439865, 312.51306, 231.36348, 1.],
|
[229.6, 114.4, 307.8, 231.4, 0.6],
|
||||||
[200.71407, 114.716736, 337.17535, 296.6651, 1.]],
|
[229.4, 114.7, 308.5, 296.7, 0.6]],
|
||||||
dtype=np.float32),
|
dtype=np.float32),
|
||||||
decimal=1)
|
decimal=1)
|
||||||
vis_result = predictor.show_result(img1, result0)
|
vis_result = predictor.show_result(img1, result0)
|
||||||
|
@ -92,10 +92,10 @@ class PoseTopDownPredictorTest(unittest.TestCase):
|
||||||
|
|
||||||
assert_array_almost_equal(
|
assert_array_almost_equal(
|
||||||
result1['bbox'][:4],
|
result1['bbox'][:4],
|
||||||
np.array([[436.23096, 214.72766, 584.26013, 412.09985, 1.],
|
np.array([[470.6, 214.7, 549.9, 412.1, 0.9],
|
||||||
[43.990044, 91.04126, 164.28406, 251.43329, 1.],
|
[71.6, 91., 136.7, 251.4, 0.9],
|
||||||
[127.44148, 100.38604, 254.219, 269.42273, 1.],
|
[159.7, 100.4, 221.9, 269.4, 0.9],
|
||||||
[190.08075, 117.31801, 311.22394, 278.8423, 1.]],
|
[219.4, 117.3, 281.9, 278.8, 0.9]],
|
||||||
dtype=np.float32),
|
dtype=np.float32),
|
||||||
decimal=1)
|
decimal=1)
|
||||||
vis_result = predictor.show_result(img2, result1)
|
vis_result = predictor.show_result(img2, result1)
|
||||||
|
|
|
@ -179,6 +179,9 @@ PRETRAINED_MODEL_RESNET50 = os.path.join(
|
||||||
PRETRAINED_MODEL_RESNET50_WITHOUTHEAD = os.path.join(
|
PRETRAINED_MODEL_RESNET50_WITHOUTHEAD = os.path.join(
|
||||||
BASE_LOCAL_PATH,
|
BASE_LOCAL_PATH,
|
||||||
'pretrained_models/classification/resnet/resnet50_withhead.pth')
|
'pretrained_models/classification/resnet/resnet50_withhead.pth')
|
||||||
|
PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD = os.path.join(
|
||||||
|
BASE_LOCAL_PATH,
|
||||||
|
'pretrained_models/classification/resnet/imagenet_resnet50.onnx')
|
||||||
PRETRAINED_MODEL_FACEID = os.path.join(BASE_LOCAL_PATH,
|
PRETRAINED_MODEL_FACEID = os.path.join(BASE_LOCAL_PATH,
|
||||||
'pretrained_models/faceid')
|
'pretrained_models/faceid')
|
||||||
PRETRAINED_MODEL_YOLOXS_EXPORT = os.path.join(
|
PRETRAINED_MODEL_YOLOXS_EXPORT = os.path.join(
|
||||||
|
|
Loading…
Reference in New Issue