# Source: fast-reid/tools/deploy/onnx_export.py (168 lines, 4.6 KiB, Python)

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import logging
import os
import argparse
import io
import sys
import onnx
import onnxoptimizer
import torch
from onnxsim import simplify
from torch.onnx import OperatorExportTypes
sys.path.append('.')
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
# To use modules added by a project, import them as shown below:
# sys.path.append("projects/FastDistill")
# from fastdistill import *
# Set up the "fastreid" logger through the project's setup_logger helper.
# This runs at import time, before any function in this file logs.
setup_logger(name="fastreid")
# Logger used by this script; by its dotted name it is a child of "fastreid".
logger = logging.getLogger("fastreid.onnx_export")
def setup_cfg(args):
    """
    Build the frozen config used by the export script.

    Args:
        args: parsed command-line arguments; `config_file` and `opts` are read.
    Returns:
        the config obtained from `get_cfg()`, merged and frozen
    """
    config = get_cfg()
    # Values from the config file are merged first, command-line overrides second.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def get_parser():
    """
    Create the command-line parser of the ONNX export script.

    Returns:
        argparse.ArgumentParser: parser that accepts `--config-file`, `--name`,
        `--output`, `--batch-size` and `--opts`
    """
    # Each entry is (flag, keyword arguments for `add_argument`), in registration order.
    option_specs = (
        ("--config-file", dict(metavar="FILE", help="path to config file")),
        ("--name", dict(default="baseline", help="name for converted model")),
        ("--output", dict(default='onnx_model', help='path to save converted onnx model')),
        ('--batch-size', dict(default=1, type=int, help="the maximum batch size of onnx runtime")),
        # REMAINDER collects every token that follows `--opts` into a list.
        ("--opts", dict(
            help="Modify config options using the command-line 'KEY VALUE' pairs",
            default=[],
            nargs=argparse.REMAINDER,
        )),
    )

    arg_parser = argparse.ArgumentParser(description="Convert Pytorch to ONNX model")
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)
    return arg_parser
def remove_initializer_from_input(model):
    """
    Drop every graph input whose name matches one of the graph initializers.

    The graph is edited in place and the same model object is returned.

    Args:
        model: an onnx model; `ir_version`, `graph.input` and
            `graph.initializer` are read.
    Returns:
        the given model. When `ir_version` is below 4 a notice is printed and
        the model is returned unchanged (never `None`), so callers can always
        reassign the result.
    """
    if model.ir_version < 4:
        print(
            'Model with ir_version below 4 requires to include initilizer in graph input'
        )
        # Return the untouched model instead of None: callers reassign the result.
        return model

    graph_inputs = model.graph.input
    input_by_name = {graph_input.name: graph_input for graph_input in graph_inputs}

    for initializer in model.graph.initializer:
        # pop() so that a repeated initializer name cannot trigger a second
        # remove() of an entry that is already gone.
        matched = input_by_name.pop(initializer.name, None)
        if matched is not None:
            graph_inputs.remove(matched)

    return model
def export_onnx_model(model, inputs):
    """
    Export a model to onnx format and run a fixed list of optimizer passes on it.

    Args:
        model (nn.Module): module to export; it and all of its submodules
            must already be in eval mode
        inputs (torch.Tensor): handed to `torch.onnx.export` as the export args
    Returns:
        the onnx model after `onnxoptimizer.optimize`
    """
    assert isinstance(model, torch.nn.Module)

    # Every module has to be in eval mode beforehand: onnx may change the
    # training state of the module if the states are not consistent.
    def _assert_eval_mode(submodule):
        assert not submodule.training

    model.apply(_assert_eval_mode)

    logger.info("Beginning ONNX file converting")
    # Export into an in-memory buffer, then parse the bytes back into an onnx model.
    with torch.no_grad(), io.BytesIO() as buffer:
        torch.onnx.export(
            model,
            inputs,
            buffer,
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
            # verbose=True,  # NOTE: uncomment this for debugging
            # export_params=True,
        )
        serialized = buffer.getvalue()
        onnx_model = onnx.load_from_string(serialized)
    logger.info("Completed convert of ONNX model")

    # Apply ONNX's optimization passes; all of them must be available.
    logger.info("Beginning ONNX model path optimization")
    available = onnxoptimizer.get_available_passes()
    requested = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
    unavailable = [name for name in requested if name not in available]
    assert not unavailable
    onnx_model = onnxoptimizer.optimize(onnx_model, requested)
    logger.info("Completed ONNX model path optimization")
    return onnx_model
# Script entry point: build the model from the config, export it to ONNX,
# simplify it and write the result to <output>/<name>.onnx.
if __name__ == '__main__':
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    # setup_cfg() returns a frozen config; unfreeze it for the export-only overrides.
    cfg.defrost()
    cfg.MODEL.BACKBONE.PRETRAIN = False
    # Use 'GlobalAvgPool' in place of 'FastGlobalAvgPool' for the exported model.
    if cfg.MODEL.HEADS.POOL_LAYER == 'FastGlobalAvgPool':
        cfg.MODEL.HEADS.POOL_LAYER = 'GlobalAvgPool'

    model = build_model(cfg)
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
    # Backbones that expose a `deploy` method are switched with deploy(True).
    if hasattr(model.backbone, 'deploy'):
        model.backbone.deploy(True)
    model.eval()
    logger.info(model)

    # Random example input: (batch_size, 3, SIZE_TEST[0], SIZE_TEST[1]) on the model's device.
    inputs = torch.randn(args.batch_size, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(model.device)
    onnx_model = export_onnx_model(model, inputs)

    model_simp, check = simplify(onnx_model)
    # Check the simplifier's result right away, before touching the graph again.
    # An explicit raise is used because `assert` is stripped under `python -O`.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")

    # The helper edits the graph in place; if it hands back None (its
    # ir_version < 4 early exit), keep the simplified model instead of
    # passing None on to onnx.save_model.
    stripped = remove_initializer_from_input(model_simp)
    if stripped is not None:
        model_simp = stripped

    PathManager.mkdirs(args.output)
    save_path = os.path.join(args.output, args.name+'.onnx')
    onnx.save_model(model_simp, save_path)
    logger.info("ONNX model file has already saved to {}!".format(save_path))