[Feature] Add conversion script from PyTorch to mlmodel (#597)

pull/612/head
Rejnald Lleshi 2022-01-13 05:32:52 +01:00 committed by GitHub
parent 5232965b17
commit c07f8326e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 160 additions and 0 deletions

View File

@ -0,0 +1,160 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from functools import partial
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from torch import nn
from mmcls.models import build_classifier
torch.manual_seed(3)
try:
import coremltools as ct
except ImportError:
raise ImportError('Please install coremltools to enable output file.')
def _demo_mm_inputs(input_shape: tuple, num_classes: int):
    """Create a superset of inputs needed to run test or train batches.

    The values come from a ``numpy`` generator seeded with 0, so repeated
    calls with the same arguments return identical tensors.

    Args:
        input_shape (tuple): Input batch dimensions. The first entry is
            used as the batch size.
        num_classes (int): Number of semantic classes. Labels are drawn
            uniformly from ``[0, num_classes)``.

    Returns:
        dict: ``'imgs'`` is a float tensor of shape ``input_shape`` with
        values in ``[0, 1)`` and gradients disabled; ``'gt_labels'`` is a
        long tensor of shape ``(batch_size, 1)``.
    """
    batch_size = input_shape[0]
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    # Keep labels as 64-bit integers. Casting to uint8 would wrap every
    # label >= 256 (e.g. with a 1000-class head) to a wrong class index.
    gt_labels = rng.randint(
        low=0, high=num_classes, size=(batch_size, 1)).astype(np.int64)
    return {
        'imgs': torch.FloatTensor(imgs).requires_grad_(False),
        'gt_labels': torch.LongTensor(gt_labels),
    }
def pytorch2mlmodel(model: nn.Module, input_shape: tuple, output_file: str,
                    add_norm: bool, norm: dict):
    """Export a PyTorch model to the mlmodel format for Apple devices.

    The model is traced with ``torch.jit.trace`` on a random dummy input and
    the trace is converted with the coremltools library. Optionally, the
    image normalization step is embedded as input pre-processing of the
    converted model.

    Note:
        ``model.forward`` is replaced in place by a partial that fixes
        ``img_metas={}`` and ``return_loss=False``, and the model is moved
        to CPU and put in eval mode.

    Args:
        model (nn.Module): PyTorch model we want to export. It must expose
            ``model.head.num_classes``.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        output_file (str): The path where the converted mlmodel is saved.
            Missing parent directories are created.
        add_norm (bool): Whether to embed the normalization layer to the
            output model.
        norm (dict): Image normalization config with ``'mean'`` and
            ``'std'`` sequences. Only read when ``add_norm`` is True.
    """
    model.cpu().eval()

    num_classes = model.head.num_classes
    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
    # A single image, with its batch axis kept, is enough for tracing.
    dummy_input = mm_inputs['imgs'][:1]

    model.forward = partial(model.forward, img_metas={}, return_loss=False)
    with torch.no_grad():
        trace_model = torch.jit.trace(model, dummy_input)

    save_dir = osp.dirname(output_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if add_norm:
        # Key access works for both plain dicts and mmcv config dicts.
        means, stds = norm['mean'], norm['std']
        if any(std != stds[0] for std in stds):
            warnings.warn(f'Image std from config is {stds}. However, '
                          'current version of coremltools (5.1) uses a '
                          'global std rather than the channel-specific '
                          'values that torchvision uses. A mean will be '
                          'taken but this might tamper with the resulting '
                          'model\'s predictions. For more details refer '
                          'to the coreml docs on ImageType pre-processing')
            scale = float(np.mean(stds))
        else:
            scale = stds[0]

        # (x - mean) / std is expressed as x * (1 / std) + (-mean / std),
        # passed below as the ``scale`` and per-channel ``bias`` arguments.
        bias = [-mean / scale for mean in means]
        model_input = ct.ImageType(
            name='input_1',
            shape=input_shape,
            scale=1 / scale,
            bias=bias,
            color_layout='RGB',
            channel_first=True)
    else:
        model_input = ct.TensorType(shape=input_shape)

    coreml_model = ct.convert(trace_model, inputs=[model_input])
    coreml_model.save(output_file)
    print(f'Successfully exported coreml model: {output_file}')
def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace: Holds ``config``, ``checkpoint``,
        ``output_file``, ``shape`` and ``add_norm_layer``.
    """
    cli = argparse.ArgumentParser(
        description='Convert MMCls to MlModel format for apple devices')

    # The config path is the only required (positional) argument.
    cli.add_argument('config', help='test config file path')
    cli.add_argument('--checkpoint', type=str, help='checkpoint file')
    cli.add_argument('--output-file', default='model.mlmodel', type=str)
    # One value means a square input; two values are taken as given.
    cli.add_argument(
        '--shape',
        nargs='+',
        type=int,
        default=[224, 224],
        help='input image size')
    cli.add_argument(
        '--add-norm-layer',
        action='store_true',
        help='embed normalization layer to deployed model')

    return cli.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Batch size 1 and 3 channels are fixed; --shape supplies the spatial
    # size, either as one side length (square) or as two explicit values.
    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # The normalization config is only needed when it gets embedded into
    # the exported model, so configs without it still convert otherwise.
    img_norm_cfg = cfg.get('img_norm_cfg', None)
    if args.add_norm_layer and img_norm_cfg is None:
        raise ValueError('--add-norm-layer requires `img_norm_cfg` to be '
                         f'defined in the config file: {args.config}')

    # build the model and load checkpoint
    classifier = build_classifier(cfg.model)

    if args.checkpoint:
        load_checkpoint(classifier, args.checkpoint, map_location='cpu')

    # convert model to mlmodel file
    pytorch2mlmodel(
        classifier,
        input_shape,
        output_file=args.output_file,
        add_norm=args.add_norm_layer,
        norm=img_norm_cfg)