Export ONNX for model only ()

* Support exporting ONNX for the model only

* Fix

* Fix
tripleMu 2022-12-13 15:06:59 +08:00 committed by GitHub
parent e7ff6fcbf0
commit 6fd50af6ab
5 changed files with 126 additions and 111 deletions

View File

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .backendwrapper import BackendWrapper, EngineBuilder
+from .backendwrapper import BackendWrapper
 from .model import DeployModel
 
-__all__ = ['DeployModel', 'BackendWrapper', 'EngineBuilder']
+__all__ = ['DeployModel', 'BackendWrapper']
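With this change, EngineBuilder is no longer part of the package's public API; only the deploy wrapper and the runtime wrapper are re-exported. A minimal sketch of the resulting import surface (the absolute package path is assumed for illustration, not taken from this commit):

# EngineBuilder is gone from the model package; it now lives in the
# engine-building tool changed later in this commit.
from easydeploy.model import BackendWrapper, DeployModel  # assumed path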

View File

@@ -2,11 +2,15 @@ import warnings
 from collections import OrderedDict, namedtuple
 from functools import partial
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import onnxruntime
-import tensorrt as trt
+
+try:
+    import tensorrt as trt
+except Exception:
+    trt = None
 import torch
 from numpy import ndarray
 from torch import Tensor
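The guarded import keeps this wrapper importable on machines without TensorRT, leaving trt = None as a sentinel. A minimal sketch of how downstream code can fail fast on that sentinel (the helper name is illustrative, not part of this commit):

def _require_trt() -> None:
    # The guarded import above sets trt to None when TensorRT is missing.
    if trt is None:
        raise RuntimeError('TensorRT is not installed, so engine '
                           'inference is unavailable on this machine.')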
@@ -172,85 +176,3 @@ class BackendWrapper:
             inputs = [inputs]
         outputs = self.__infer(inputs)
         return outputs
-
-
-class EngineBuilder:
-
-    def __init__(
-            self,
-            checkpoint: Union[str, Path],
-            opt_shape: Union[Tuple, List] = (1, 3, 640, 640),
-            device: Optional[Union[str, int, torch.device]] = None) -> None:
-        checkpoint = Path(checkpoint) if isinstance(checkpoint,
-                                                    str) else checkpoint
-        assert checkpoint.exists() and checkpoint.suffix == '.onnx'
-        if isinstance(device, str):
-            device = torch.device(device)
-        elif isinstance(device, int):
-            device = torch.device(f'cuda:{device}')
-
-        self.checkpoint = checkpoint
-        self.opt_shape = np.array(opt_shape, dtype=np.float32)
-        self.device = device
-
-    def __build_engine(self,
-                       scale: Optional[List[List]] = None,
-                       fp16: bool = True,
-                       with_profiling: bool = True) -> None:
-        logger = trt.Logger(trt.Logger.WARNING)
-        trt.init_libnvinfer_plugins(logger, namespace='')
-        builder = trt.Builder(logger)
-        config = builder.create_builder_config()
-        config.max_workspace_size = torch.cuda.get_device_properties(
-            self.device).total_memory
-        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-        network = builder.create_network(flag)
-        parser = trt.OnnxParser(network, logger)
-        if not parser.parse_from_file(str(self.checkpoint)):
-            raise RuntimeError(
-                f'failed to load ONNX file: {str(self.checkpoint)}')
-        inputs = [network.get_input(i) for i in range(network.num_inputs)]
-        outputs = [network.get_output(i) for i in range(network.num_outputs)]
-        profile = None
-        dshape = -1 in network.get_input(0).shape
-        if dshape:
-            profile = builder.create_optimization_profile()
-            if scale is None:
-                scale = np.array(
-                    [[1, 1, 0.5, 0.5], [1, 1, 1, 1], [4, 1, 1.5, 1.5]],
-                    dtype=np.float32)
-                scale = (self.opt_shape * scale).astype(np.int32)
-            elif isinstance(scale, List):
-                scale = np.array(scale, dtype=np.int32)
-                assert scale.shape[0] == 3, 'Input a wrong scale list'
-            else:
-                raise NotImplementedError
-
-        for inp in inputs:
-            logger.log(
-                trt.Logger.WARNING,
-                f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
-            if dshape:
-                profile.set_shape(inp.name, *scale)
-        for out in outputs:
-            logger.log(
-                trt.Logger.WARNING,
-                f'output "{out.name}" with shape{out.shape} {out.dtype}')
-        if fp16 and builder.platform_has_fast_fp16:
-            config.set_flag(trt.BuilderFlag.FP16)
-        self.weight = self.checkpoint.with_suffix('.engine')
-        if dshape:
-            config.add_optimization_profile(profile)
-        if with_profiling:
-            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
-        with builder.build_engine(network, config) as engine:
-            self.weight.write_bytes(engine.serialize())
-        logger.log(
-            trt.Logger.WARNING, f'Build tensorrt engine finish.\n'
-            f'Save in {str(self.weight.absolute())}')
-
-    def build(self,
-              scale: Optional[List[List]] = None,
-              fp16: bool = True,
-              with_profiling=True):
-        self.__build_engine(scale, fp16, with_profiling)

View File

@@ -22,23 +22,19 @@ class DeployModel(nn.Module):
                  postprocess_cfg: Optional[ConfigDict] = None):
         super().__init__()
         self.baseModel = baseModel
-        self.baseHead = baseModel.bbox_head
-        self.__init_sub_attributes()
-        detector_type = type(self.baseHead)
         if postprocess_cfg is None:
-            pre_top_k = 1000
-            keep_top_k = 100
-            iou_threshold = 0.65
-            score_threshold = 0.25
-            backend = 1
+            self.with_postprocess = False
         else:
-            pre_top_k = postprocess_cfg.get('pre_top_k', 1000)
-            keep_top_k = postprocess_cfg.get('keep_top_k', 100)
-            iou_threshold = postprocess_cfg.get('iou_threshold', 0.65)
-            score_threshold = postprocess_cfg.get('score_threshold', 0.25)
-            backend = postprocess_cfg.get('backend', 1)
+            self.with_postprocess = True
+            self.baseHead = baseModel.bbox_head
+            self.__init_sub_attributes()
+            self.detector_type = type(self.baseHead)
+            self.pre_top_k = postprocess_cfg.get('pre_top_k', 1000)
+            self.keep_top_k = postprocess_cfg.get('keep_top_k', 100)
+            self.iou_threshold = postprocess_cfg.get('iou_threshold', 0.65)
+            self.score_threshold = postprocess_cfg.get('score_threshold', 0.25)
+            self.backend = postprocess_cfg.get('backend', 1)
         self.__switch_deploy()
-        self.__dict__.update(locals())
 
     def __init_sub_attributes(self):
         self.bbox_decoder = self.baseHead.bbox_coder.decode
@@ -140,5 +136,7 @@ class DeployModel(nn.Module):
 
     def forward(self, inputs: Tensor):
         neck_outputs = self.baseModel(inputs)
-        outputs = self.pred_by_feat(*neck_outputs)
-        return outputs
+        if self.with_postprocess:
+            return self.pred_by_feat(*neck_outputs)
+        else:
+            return neck_outputs
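postprocess_cfg now doubles as the switch between the two export modes: None exports the bare network, while a config appends decoding and NMS. A short sketch of both constructions (assuming baseModel is an already-built detector and ConfigDict comes from mmengine; values mirror the defaults above):

from mmengine.config import ConfigDict

# Model-only export: forward() returns the raw multi-scale neck outputs.
model_only = DeployModel(baseModel=baseModel, postprocess_cfg=None)

# End-to-end export: forward() decodes boxes and runs NMS via pred_by_feat().
end2end = DeployModel(
    baseModel=baseModel,
    postprocess_cfg=ConfigDict(
        pre_top_k=1000,
        keep_top_k=100,
        iou_threshold=0.65,
        score_threshold=0.25,
        backend=1))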

View File

@@ -1,6 +1,95 @@
 import argparse
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
 
-from ..model import EngineBuilder
+try:
+    import tensorrt as trt
+except Exception:
+    trt = None
+
+import numpy as np
+import torch
+
+
+class EngineBuilder:
+
+    def __init__(
+            self,
+            checkpoint: Union[str, Path],
+            opt_shape: Union[Tuple, List] = (1, 3, 640, 640),
+            device: Optional[Union[str, int, torch.device]] = None) -> None:
+        checkpoint = Path(checkpoint) if isinstance(checkpoint,
+                                                    str) else checkpoint
+        assert checkpoint.exists() and checkpoint.suffix == '.onnx'
+        if isinstance(device, str):
+            device = torch.device(device)
+        elif isinstance(device, int):
+            device = torch.device(f'cuda:{device}')
+
+        self.checkpoint = checkpoint
+        self.opt_shape = np.array(opt_shape, dtype=np.float32)
+        self.device = device
+
+    def __build_engine(self,
+                       scale: Optional[List[List]] = None,
+                       fp16: bool = True,
+                       with_profiling: bool = True) -> None:
+        logger = trt.Logger(trt.Logger.WARNING)
+        trt.init_libnvinfer_plugins(logger, namespace='')
+        builder = trt.Builder(logger)
+        config = builder.create_builder_config()
+        config.max_workspace_size = torch.cuda.get_device_properties(
+            self.device).total_memory
+        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+        network = builder.create_network(flag)
+        parser = trt.OnnxParser(network, logger)
+        if not parser.parse_from_file(str(self.checkpoint)):
+            raise RuntimeError(
+                f'failed to load ONNX file: {str(self.checkpoint)}')
+        inputs = [network.get_input(i) for i in range(network.num_inputs)]
+        outputs = [network.get_output(i) for i in range(network.num_outputs)]
+        profile = None
+        dshape = -1 in network.get_input(0).shape
+        if dshape:
+            profile = builder.create_optimization_profile()
+            if scale is None:
+                scale = np.array(
+                    [[1, 1, 0.5, 0.5], [1, 1, 1, 1], [4, 1, 1.5, 1.5]],
+                    dtype=np.float32)
+                scale = (self.opt_shape * scale).astype(np.int32)
+            elif isinstance(scale, List):
+                scale = np.array(scale, dtype=np.int32)
+                assert scale.shape[0] == 3, 'Input a wrong scale list'
+            else:
+                raise NotImplementedError
+
+        for inp in inputs:
+            logger.log(
+                trt.Logger.WARNING,
+                f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
+            if dshape:
+                profile.set_shape(inp.name, *scale)
+        for out in outputs:
+            logger.log(
+                trt.Logger.WARNING,
+                f'output "{out.name}" with shape{out.shape} {out.dtype}')
+        if fp16 and builder.platform_has_fast_fp16:
+            config.set_flag(trt.BuilderFlag.FP16)
+        self.weight = self.checkpoint.with_suffix('.engine')
+        if dshape:
+            config.add_optimization_profile(profile)
+        if with_profiling:
+            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
+        with builder.build_engine(network, config) as engine:
+            self.weight.write_bytes(engine.serialize())
+        logger.log(
+            trt.Logger.WARNING, f'Build tensorrt engine finish.\n'
+            f'Save in {str(self.weight.absolute())}')
+
+    def build(self,
+              scale: Optional[List[List]] = None,
+              fp16: bool = True,
+              with_profiling=True):
+        self.__build_engine(scale, fp16, with_profiling)
 
 
 def parse_args():
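A usage sketch of the relocated builder (the ONNX file name and device index are illustrative; building requires TensorRT and a CUDA device, since the workspace size is read from the GPU's total memory):

builder = EngineBuilder('yolov5s.onnx', opt_shape=(1, 3, 640, 640), device=0)
# Serializes the engine to yolov5s.engine next to the ONNX file;
# FP16 is enabled when the platform supports it.
builder.build(fp16=True)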

View File

@@ -21,6 +21,8 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('config', help='Config file')
     parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument(
+        '--model-only', action='store_true', help='Export model only')
     parser.add_argument(
         '--work-dir', default='./work_dir', help='Path to save export model')
     parser.add_argument(
@@ -78,13 +80,17 @@ def main():
     if not os.path.exists(args.work_dir):
         os.mkdir(args.work_dir)
 
-    postprocess_cfg = ConfigDict(
-        pre_top_k=args.pre_topk,
-        keep_top_k=args.keep_topk,
-        iou_threshold=args.iou_threshold,
-        score_threshold=args.score_threshold,
-        backend=args.backend)
+    if args.model_only:
+        postprocess_cfg = None
+        output_names = None
+    else:
+        postprocess_cfg = ConfigDict(
+            pre_top_k=args.pre_topk,
+            keep_top_k=args.keep_topk,
+            iou_threshold=args.iou_threshold,
+            score_threshold=args.score_threshold,
+            backend=args.backend)
+        output_names = ['num_det', 'det_boxes', 'det_scores', 'det_classes']
 
     baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)
     deploy_model = DeployModel(
@@ -104,7 +110,7 @@ def main():
             fake_input,
             f,
             input_names=['images'],
-            output_names=['num_det', 'det_boxes', 'det_scores', 'det_classes'],
+            output_names=output_names,
             opset_version=args.opset)
         f.seek(0)
         onnx_model = onnx.load(f)
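Putting the pieces together, an example invocation of the export script with the new flag (script and config paths are illustrative):

python export.py path/to/config.py path/to/checkpoint.pth \
    --model-only --work-dir ./work_dir

With --model-only, postprocess_cfg is None, so the exported graph stops at the raw network outputs and torch.onnx.export falls back to its default output names; without the flag, the graph includes the head's postprocessing and the four named detection outputs.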