# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from functools import partial

import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from torch import nn

from mmcls.models import build_classifier

# Fixed seed so repeated conversions trace the model on identical dummy data.
torch.manual_seed(3)

try:
    import coremltools as ct
except ImportError:
    raise ImportError('Please install coremltools to enable output file.')


def _demo_mm_inputs(input_shape: tuple, num_classes: int):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions, as ``(N, C, H, W)``.
        num_classes (int): Number of semantic classes.

    Returns:
        dict: With two keys — ``imgs``, a float tensor of shape
        ``input_shape`` with gradients disabled, and ``gt_labels``, a long
        tensor of shape ``(N, 1)`` with random labels in
        ``[0, num_classes)``.
    """
    (N, C, H, W) = input_shape
    # Seeded RNG keeps the dummy batch deterministic across runs.
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    gt_labels = rng.randint(
        low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(False),
        'gt_labels': torch.LongTensor(gt_labels),
    }
    return mm_inputs


def pytorch2mlmodel(model: nn.Module, input_shape: tuple, output_file: str,
                    add_norm: bool, norm: dict):
    """Export Pytorch model to mlmodel format that can be deployed in apple
    devices through torch.jit.trace and the coremltools library.

    Optionally, embed the normalization step as a layer to the model.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        output_file (string): The path to where we store the output
            mlmodel file.
        add_norm (bool): Whether to embed the normalization layer to the
            output model.
        norm (dict): Image normalization config for embedding it as a layer
            to the output model. Read via attribute access
            (``norm.mean`` / ``norm.std``), so it is expected to be an
            mmcv ``ConfigDict``-like object, not a plain dict.
    """
    # Conversion happens on CPU in eval mode; coremltools traces inference
    # behavior only.
    model.cpu().eval()

    num_classes = model.head.num_classes
    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_list = [img[None, :] for img in imgs]
    # Bind the test-time forward signature so torch.jit.trace only needs the
    # image tensor as input.
    model.forward = partial(model.forward, img_metas={}, return_loss=False)

    with torch.no_grad():
        trace_model = torch.jit.trace(model, img_list[0])
        save_dir, _ = osp.split(output_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        if add_norm:
            means, stds = norm.mean, norm.std
            if stds.count(stds[0]) != len(stds):
                # coremltools 5.1 only supports a single global std, so
                # channel-specific stds are averaged (with a warning).
                warnings.warn(f'Image std from config is {stds}. However, '
                              'current version of coremltools (5.1) uses a '
                              'global std rather than the channel-specific '
                              'values that torchvision uses. A mean will be '
                              'taken but this might tamper with the resulting '
                              'model\'s predictions. For more details refer '
                              'to the coreml docs on ImageType pre-processing')
                scale = np.mean(stds)
            else:
                scale = stds[0]

            # Core ML applies y = x * scale + bias, so bias = -mean / std
            # reproduces the usual (x - mean) / std normalization.
            bias = [-mean / scale for mean in means]
            image_input = ct.ImageType(
                name='input_1',
                shape=input_shape,
                scale=1 / scale,
                bias=bias,
                color_layout='RGB',  # assumes the model expects RGB input
                channel_first=True)

            coreml_model = ct.convert(trace_model, inputs=[image_input])
            coreml_model.save(output_file)
        else:
            # No embedded pre-processing: expose a raw tensor input instead.
            coreml_model = ct.convert(
                trace_model, inputs=[ct.TensorType(shape=input_shape)])
            coreml_model.save(output_file)

    print(f'Successfully exported coreml model: {output_file}')


def parse_args():
    """Parse command-line arguments for the conversion script.

    Returns:
        argparse.Namespace: Parsed arguments (config, checkpoint,
        output_file, shape, add_norm_layer).
    """
    parser = argparse.ArgumentParser(
        description='Convert MMCls to MlModel format for apple devices')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', type=str)
    parser.add_argument('--output-file', type=str, default='model.mlmodel')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    parser.add_argument(
        '--add-norm-layer',
        action='store_true',
        help='embed normalization layer to deployed model')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    # --shape accepts either one side (square input) or (H, W); batch size
    # and channel count are fixed at 1 and 3.
    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    # Drop any pretrained-weight URL: weights come from --checkpoint (if
    # given), not from the config's init source.
    cfg.model.pretrained = None

    # build the model and load checkpoint
    classifier = build_classifier(cfg.model)

    if args.checkpoint:
        load_checkpoint(classifier, args.checkpoint, map_location='cpu')

    # convert model to mlmodel file
    pytorch2mlmodel(
        classifier,
        input_shape,
        output_file=args.output_file,
        add_norm=args.add_norm_layer,
        norm=cfg.img_norm_cfg)