add resnet export of onnx (#341)

* add checkpoint_sync_export for resnet config
pull/343/head
gulou 2024-07-02 19:39:28 +08:00 committed by GitHub
parent 1d1ac8aa5e
commit 8c90ceaf84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 166 additions and 14 deletions

View File

@ -6,3 +6,6 @@ model = dict(
depth=50, depth=50,
out_indices=[4], # 0: conv-1, x: stage-x out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='BN'))) norm_cfg=dict(type='BN')))
checkpoint_sync_export = True
export = dict(export_type='raw', export_neck=True)

View File

@ -157,6 +157,37 @@ def _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=False):
torch.jit.save(blade_model, ofile) torch.jit.save(blade_model, ofile)
def _export_onnx_cls(model, model_config, cfg, filename, meta):
if model_config['backbone'].get(
'type', None) == 'ResNet' and model_config['backbone'].get(
'depth', None) == 50:
# save json config for test_pipline and class
with io.open(
filename +
'.config.json' if filename.endswith('onnx') else filename +
'.onnx.config.json', 'w') as ofile:
json.dump(meta, ofile)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)
img_size = int(cfg.image_size2)
x_input = torch.randn((1, 3, img_size, img_size)).to(device)
torch.onnx.export(
model,
(x_input, 'onnx'),
filename if filename.endswith('onnx') else filename + '.onnx',
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
)
else:
raise ValueError('Only support export onnx model for ResNet now!')
def _export_cls(model, cfg, filename): def _export_cls(model, cfg, filename):
""" export cls (cls & metric learning)model and preprocess config """ export cls (cls & metric learning)model and preprocess config
@ -170,6 +201,7 @@ def _export_cls(model, cfg, filename):
else: else:
export_cfg = dict(export_neck=False) export_cfg = dict(export_neck=False)
export_type = export_cfg.get('export_type', 'raw')
export_neck = export_cfg.get('export_neck', True) export_neck = export_cfg.get('export_neck', True)
label_map_path = cfg.get('label_map_path', None) label_map_path = cfg.get('label_map_path', None)
class_list = None class_list = None
@ -232,9 +264,14 @@ def _export_cls(model, cfg, filename):
if export_neck and (k.startswith('neck') or k.startswith('head')): if export_neck and (k.startswith('neck') or k.startswith('head')):
state_dict[k] = v state_dict[k] = v
checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV') if export_type == 'raw':
with io.open(filename, 'wb') as ofile: checkpoint = dict(state_dict=state_dict, meta=meta, author='EasyCV')
torch.save(checkpoint, ofile) with io.open(filename, 'wb') as ofile:
torch.save(checkpoint, ofile)
elif export_type == 'onnx':
_export_onnx_cls(model, model_config, cfg, filename, config)
else:
raise ValueError('Only support export onnx/raw model!')
def _export_yolox(model, cfg, filename): def _export_yolox(model, cfg, filename):

View File

@ -151,6 +151,20 @@ class Classification(BaseModel):
x = self.backbone(img) x = self.backbone(img)
return x return x
def forward_onnx(self, img: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
forward_onnx means generate prob from image only support one neck + one head
"""
x = self.forward_backbone(img) # tuple
# if self.neck_num > 0:
if hasattr(self, 'neck_0'):
x = self.neck_0([i for i in x])
out = self.head_0(x)[0].cpu()
out = self.activate_fn(out)
return out
@torch.jit.unused @torch.jit.unused
def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]: def forward_train(self, img, gt_labels) -> Dict[str, torch.Tensor]:
""" """
@ -290,6 +304,9 @@ class Classification(BaseModel):
return self.forward_test_label(img, gt_labels) return self.forward_test_label(img, gt_labels)
else: else:
return self.forward_test(img) return self.forward_test(img)
elif mode == 'onnx':
return self.forward_onnx(img)
elif mode == 'extract': elif mode == 'extract':
rd = self.forward_feature(img) rd = self.forward_feature(img)
rv = {} rv = {}

View File

@ -1,5 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import math import math
import os
import numpy as np import numpy as np
import torch import torch
@ -7,11 +9,18 @@ from PIL import Image
from easycv.file import io from easycv.file import io
from easycv.framework.errors import ValueError from easycv.framework.errors import ValueError
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.misc import deprecated from easycv.utils.misc import deprecated
from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2 from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
from .builder import PREDICTORS from .builder import PREDICTORS
# onnx specific
def onnx_to_numpy(tensor):
return tensor.detach().cpu().numpy(
) if tensor.requires_grad else tensor.cpu().numpy()
class ClsInputProcessor(InputProcessor): class ClsInputProcessor(InputProcessor):
"""Process inputs for classification models. """Process inputs for classification models.
@ -146,6 +155,20 @@ class ClassificationPredictor(PredictorV2):
self.pil_input = pil_input self.pil_input = pil_input
self.label_map_path = label_map_path self.label_map_path = label_map_path
if model_path.endswith('onnx'):
self.model_type = 'onnx'
pwd_model = os.path.dirname(model_path)
raw_model = glob.glob(
os.path.join(pwd_model, '*.onnx.config.json'))
if len(raw_model) != 0:
config_file = raw_model[0]
else:
assert len(
raw_model
) == 0, 'Please have a file with the .onnx.config.json extension in your directory'
else:
self.model_type = 'raw'
if self.pil_input: if self.pil_input:
mode = 'RGB' mode = 'RGB'
super(ClassificationPredictor, self).__init__( super(ClassificationPredictor, self).__init__(
@ -186,6 +209,41 @@ class ClassificationPredictor(PredictorV2):
return ClsOutputProcessor(topk=self.topk, label_map=self.label_map) return ClsOutputProcessor(topk=self.topk, label_map=self.label_map)
def prepare_model(self):
"""Build model from config file by default.
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
if self.model_type == 'raw':
model = self._build_model()
model.to(self.device)
model.eval()
load_checkpoint(model, self.model_path, map_location='cpu')
return model
else:
import onnxruntime
if onnxruntime.get_device() == 'GPU':
onnx_model = onnxruntime.InferenceSession(
self.model_path, providers=['CUDAExecutionProvider'])
else:
onnx_model = onnxruntime.InferenceSession(self.model_path)
return onnx_model
def model_forward(self, inputs):
"""Model forward.
If you need refactor model forward, you need to reimplement it.
"""
with torch.no_grad():
if self.model_type == 'raw':
outputs = self.model(**inputs, mode='test')
else:
outputs = self.model.run(None, {
self.model.get_inputs()[0].name:
onnx_to_numpy(inputs['img'])
})[0]
outputs = dict(prob=torch.from_numpy(outputs))
return outputs
try: try:
from easy_vision.python.inference.predictor import PredictorInterface from easy_vision.python.inference.predictor import PredictorInterface

View File

@ -2,5 +2,5 @@
# GENERATED VERSION FILE # GENERATED VERSION FILE
# TIME: Thu Nov 5 14:17:50 2020 # TIME: Thu Nov 5 14:17:50 2020
__version__ = '0.11.6' __version__ = '0.11.7'
short_version = '0.11.6' short_version = '0.11.7'

View File

@ -116,6 +116,7 @@ class ModelExportTest(unittest.TestCase):
cfg = mmcv_config_fromfile(config_file) cfg = mmcv_config_fromfile(config_file)
cfg_options = { cfg_options = {
'model.backbone.norm_cfg.type': 'SyncBN', 'model.backbone.norm_cfg.type': 'SyncBN',
'export.export_type': 'raw'
} }
if cfg_options is not None: if cfg_options is not None:
cfg.merge_from_dict(cfg_options) cfg.merge_from_dict(cfg_options)
@ -210,6 +211,27 @@ class ModelExportTest(unittest.TestCase):
self.assertTrue(os.path.exists(filename + '.jit')) self.assertTrue(os.path.exists(filename + '.jit'))
def test_export_resnet_onnx(self):
ckpt_path = PRETRAINED_MODEL_RESNET50
easycv_dir = os.path.dirname(easycv.__file__)
if os.path.exists(os.path.join(easycv_dir, 'configs')):
config_dir = os.path.join(easycv_dir, 'configs')
else:
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
config_file = os.path.join(
config_dir,
'classification/imagenet/resnet/imagenet_resnet50_jpg.py')
with tempfile.TemporaryDirectory() as tmpdir:
cfg = mmcv_config_fromfile(config_file)
cfg.export.export_type = 'onnx'
filename = os.path.join(tmpdir, 'imagenet_resnet50')
export(cfg, ckpt_path, filename)
self.assertTrue(os.path.exists(filename + '.onnx'))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -11,7 +11,8 @@ import torch
from easycv.predictors.classifier import ClassificationPredictor from easycv.predictors.classifier import ClassificationPredictor
from easycv.utils.test_util import clean_up, get_tmp_dir from easycv.utils.test_util import clean_up, get_tmp_dir
from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD, from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD,
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR) IMAGENET_LABEL_TXT, TEST_IMAGES_DIR,
PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD)
class ClassificationPredictorTest(unittest.TestCase): class ClassificationPredictorTest(unittest.TestCase):
@ -33,6 +34,17 @@ class ClassificationPredictorTest(unittest.TestCase):
self.assertListEqual(results['class_name'], ['"Persian cat",']) self.assertListEqual(results['class_name'], ['"Persian cat",'])
self.assertEqual(len(results['class_probs']), 1000) self.assertEqual(len(results['class_probs']), 1000)
def test_onnx_single(self):
checkpoint = PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD
predict_op = ClassificationPredictor(model_path=checkpoint)
img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')
results = predict_op([img_path])[0]
self.assertListEqual(results['class'], [578])
self.assertListEqual(results['class_name'], ['gown'])
self.assertEqual(len(results['class_probs']), 1000)
def test_batch(self): def test_batch(self):
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'

View File

@ -54,10 +54,10 @@ class PoseTopDownPredictorTest(unittest.TestCase):
assert_array_almost_equal( assert_array_almost_equal(
result0['bbox'], result0['bbox'],
np.array([[352.3085, 59.00325, 691.4247, 511.15814, 1.], np.array([[438.9, 59., 604.8, 511.2, 0.9],
[10.511196, 177.74883, 101.824326, 299.49966, 1.], [10.5, 179.6, 101.8, 297.7, 0.9],
[224.82036, 114.439865, 312.51306, 231.36348, 1.], [229.6, 114.4, 307.8, 231.4, 0.6],
[200.71407, 114.716736, 337.17535, 296.6651, 1.]], [229.4, 114.7, 308.5, 296.7, 0.6]],
dtype=np.float32), dtype=np.float32),
decimal=1) decimal=1)
vis_result = predictor.show_result(img1, result0) vis_result = predictor.show_result(img1, result0)
@ -92,10 +92,10 @@ class PoseTopDownPredictorTest(unittest.TestCase):
assert_array_almost_equal( assert_array_almost_equal(
result1['bbox'][:4], result1['bbox'][:4],
np.array([[436.23096, 214.72766, 584.26013, 412.09985, 1.], np.array([[470.6, 214.7, 549.9, 412.1, 0.9],
[43.990044, 91.04126, 164.28406, 251.43329, 1.], [71.6, 91., 136.7, 251.4, 0.9],
[127.44148, 100.38604, 254.219, 269.42273, 1.], [159.7, 100.4, 221.9, 269.4, 0.9],
[190.08075, 117.31801, 311.22394, 278.8423, 1.]], [219.4, 117.3, 281.9, 278.8, 0.9]],
dtype=np.float32), dtype=np.float32),
decimal=1) decimal=1)
vis_result = predictor.show_result(img2, result1) vis_result = predictor.show_result(img2, result1)

View File

@ -179,6 +179,9 @@ PRETRAINED_MODEL_RESNET50 = os.path.join(
PRETRAINED_MODEL_RESNET50_WITHOUTHEAD = os.path.join( PRETRAINED_MODEL_RESNET50_WITHOUTHEAD = os.path.join(
BASE_LOCAL_PATH, BASE_LOCAL_PATH,
'pretrained_models/classification/resnet/resnet50_withhead.pth') 'pretrained_models/classification/resnet/resnet50_withhead.pth')
PRETRAINED_MODEL_RESNET50_ONNX_WITHOUTHEAD = os.path.join(
BASE_LOCAL_PATH,
'pretrained_models/classification/resnet/imagenet_resnet50.onnx')
PRETRAINED_MODEL_FACEID = os.path.join(BASE_LOCAL_PATH, PRETRAINED_MODEL_FACEID = os.path.join(BASE_LOCAL_PATH,
'pretrained_models/faceid') 'pretrained_models/faceid')
PRETRAINED_MODEL_YOLOXS_EXPORT = os.path.join( PRETRAINED_MODEL_YOLOXS_EXPORT = os.path.join(