# encoding: utf-8
"""
Export a fastreid PyTorch model to ONNX.

The script builds the model described by a config file, loads its weights,
traces it to ONNX, runs a few onnxoptimizer passes and onnx-simplifier over
the result, and saves the final model as ``<output>/<name>.onnx``.

@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import logging
import os
import argparse
import io
import sys

import onnx
import onnxoptimizer
import torch
from onnxsim import simplify
from torch.onnx import OperatorExportTypes

# Make the repository root importable when the script is run from it.
sys.path.append('.')

from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger

# import some modules added in project like this below
# sys.path.append("projects/FastDistill")
# from fastdistill import *

setup_logger(name="fastreid")
logger = logging.getLogger("fastreid.onnx_export")


def setup_cfg(args):
    """Build a frozen config from the parsed command-line arguments.

    Args:
        args: namespace providing ``config_file`` (path merged first) and
            ``opts`` (list of ``KEY VALUE`` overrides merged afterwards).

    Returns:
        The frozen config object created by ``get_cfg()``.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def get_parser():
    """Create the command-line argument parser for the export script.

    Returns:
        argparse.ArgumentParser: parser with ``--config-file``, ``--name``,
        ``--output``, ``--batch-size`` and ``--opts`` options.
    """
    parser = argparse.ArgumentParser(description="Convert Pytorch to ONNX model")

    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--name",
        default="baseline",
        help="name for converted model"
    )
    parser.add_argument(
        "--output",
        default='onnx_model',
        help='path to save converted onnx model'
    )
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help="the maximum batch size of onnx runtime"
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


def remove_initializer_from_input(model):
    """Drop graph inputs that are also listed as graph initializers.

    Args:
        model: ONNX model exposing ``ir_version``, ``graph.input`` and
            ``graph.initializer``. It is modified in place.

    Returns:
        The same ``model`` object. When ``ir_version`` is below 4 the model
        is returned untouched, because such models need their initializers
        listed among the graph inputs.
    """
    if model.ir_version < 4:
        logger.warning(
            'Model with ir_version below 4 requires to include initilizer in graph input'
        )
        # Hand the model back unchanged so callers that re-assign the result
        # (as the script below does) do not end up holding ``None``.
        return model

    graph_inputs = model.graph.input
    name_to_input = {graph_input.name: graph_input for graph_input in graph_inputs}

    for initializer in model.graph.initializer:
        if initializer.name in name_to_input:
            graph_inputs.remove(name_to_input[initializer.name])

    return model


def export_onnx_model(model, inputs):
    """
    Trace and export a model to onnx format.

    Args:
        model (nn.Module): the model to export; every submodule must already
            be in eval mode.
        inputs (torch.Tensor): example input handed to ``torch.onnx.export``
            as the tracing arguments.

    Returns:
        an onnx model, after the onnxoptimizer passes listed below were applied

    Raises:
        AssertionError: if ``model`` is not an ``nn.Module``, if any submodule
            is in training mode, or if a required optimizer pass is missing.
    """
    assert isinstance(model, torch.nn.Module), "model must be a torch.nn.Module"

    # make sure all modules are in eval mode, onnx may change the training state
    # of the module if the states are not consistent
    def _check_eval(module):
        assert not module.training, "all modules must be in eval mode before export"

    model.apply(_check_eval)

    logger.info("Beginning ONNX file converting")
    # Export the model to ONNX; the serialized graph is kept in memory and
    # parsed back into an onnx model object.
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

    logger.info("Completed convert of ONNX model")

    # Apply ONNX's Optimization
    logger.info("Beginning ONNX model path optimization")
    all_passes = onnxoptimizer.get_available_passes()
    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
    assert all(p in all_passes for p in passes), "required onnxoptimizer pass is not available"
    onnx_model = onnxoptimizer.optimize(onnx_model, passes)
    logger.info("Completed ONNX model path optimization")
    return onnx_model


if __name__ == '__main__':
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    # setup_cfg() returns a frozen config; unfreeze it for export-only tweaks.
    cfg.defrost()
    cfg.MODEL.BACKBONE.PRETRAIN = False
    # NOTE(review): presumably 'FastGlobalAvgPool' does not export cleanly to
    # ONNX, hence the swap to 'GlobalAvgPool' -- confirm against the layer code.
    if cfg.MODEL.HEADS.POOL_LAYER == 'FastGlobalAvgPool':
        cfg.MODEL.HEADS.POOL_LAYER = 'GlobalAvgPool'
    model = build_model(cfg)
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
    if hasattr(model.backbone, 'deploy'):
        model.backbone.deploy(True)
    model.eval()
    logger.info(model)

    # Random example input: (batch, 3, SIZE_TEST[0], SIZE_TEST[1]) on the model's device.
    inputs = torch.randn(args.batch_size, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(model.device)
    onnx_model = export_onnx_model(model, inputs)

    model_simp, check = simplify(onnx_model)
    # Validate right after simplification and raise explicitly: an ``assert``
    # is stripped under ``python -O`` and would let an unvalidated model be saved.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")

    model_simp = remove_initializer_from_input(model_simp)

    PathManager.mkdirs(args.output)

    save_path = os.path.join(args.output, args.name+'.onnx')
    onnx.save_model(model_simp, save_path)
    logger.info("ONNX model file has already saved to {}!".format(save_path))