[Feature] Add conversion script from PyTorch to mlmodel (#597)
parent 5232965b17
commit c07f8326e3
@@ -0,0 +1,160 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from functools import partial

import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from torch import nn

from mmcls.models import build_classifier

torch.manual_seed(3)

try:
    import coremltools as ct
except ImportError:
    raise ImportError('Please install coremltools to export mlmodel files.')


def _demo_mm_inputs(input_shape: tuple, num_classes: int):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    gt_labels = rng.randint(
        low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(False),
        'gt_labels': torch.LongTensor(gt_labels),
    }
    return mm_inputs


def pytorch2mlmodel(model: nn.Module, input_shape: tuple, output_file: str,
                    add_norm: bool, norm: dict):
    """Export a PyTorch model to the mlmodel format deployable on Apple
    devices, via torch.jit.trace and the coremltools library.

    Optionally, embed the normalization step as a layer in the model.

    Args:
        model (nn.Module): The PyTorch model to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        output_file (str): The path where the output Core ML
            model is stored.
        add_norm (bool): Whether to embed the normalization layer in the
            output model.
        norm (dict): Image normalization config to embed as a layer
            in the output model.
    """
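    # Note: the conversion below relies on torch.jit.trace with a fixed dummy
    # input, so the model must be traceable at this input shape; any
    # input-dependent control flow is frozen into the traced graph.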
    model.cpu().eval()

    num_classes = model.head.num_classes
    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_list = [img[None, :] for img in imgs]
    # Bind the test-time keyword arguments so that torch.jit.trace only sees
    # a single image tensor as input.
    model.forward = partial(model.forward, img_metas={}, return_loss=False)

    with torch.no_grad():
        trace_model = torch.jit.trace(model, img_list[0])
        save_dir, _ = osp.split(output_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        if add_norm:
            means, stds = norm.mean, norm.std
            if stds.count(stds[0]) != len(stds):
                warnings.warn(f'Image std from config is {stds}. However, '
                              'the current version of coremltools (5.1) uses '
                              'a global std rather than the channel-specific '
                              'values that torchvision uses. A mean will be '
                              'taken but this might tamper with the resulting '
                              'model\'s predictions. For more details refer '
                              'to the coreml docs on ImageType pre-processing')
                scale = np.mean(stds)
            else:
                scale = stds[0]

            # Core ML ImageType pre-processing computes x * scale + bias, so
            # torchvision-style (x - mean) / std normalization maps to
            # scale = 1 / std and bias = -mean / std.
            bias = [-mean / scale for mean in means]
            image_input = ct.ImageType(
                name='input_1',
                shape=input_shape,
                scale=1 / scale,
                bias=bias,
                color_layout='RGB',
                channel_first=True)

            coreml_model = ct.convert(trace_model, inputs=[image_input])
            coreml_model.save(output_file)
        else:
            coreml_model = ct.convert(
                trace_model, inputs=[ct.TensorType(shape=input_shape)])
            coreml_model.save(output_file)

    print(f'Successfully exported coreml model: {output_file}')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMCls to mlmodel format for Apple devices')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', type=str)
    parser.add_argument(
        '--output-file',
        type=str,
        default='model.mlmodel',
        help='output mlmodel file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    parser.add_argument(
        '--add-norm-layer',
        action='store_true',
        help='embed the normalization layer in the deployed model')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load checkpoint
    classifier = build_classifier(cfg.model)

    if args.checkpoint:
        load_checkpoint(classifier, args.checkpoint, map_location='cpu')

    # convert model to mlmodel file
    pytorch2mlmodel(
        classifier,
        input_shape,
        output_file=args.output_file,
        add_norm=args.add_norm_layer,
        norm=cfg.img_norm_cfg)
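To sanity-check the exported file, it can be loaded back with coremltools and run on a test image. The snippet below is an illustrative sketch, not part of the script: it assumes the model was exported with --add-norm-layer (so the input is the ImageType named 'input_1' defined above), that the output path is the default model.mlmodel, and that a local test image test.jpg exists; Core ML prediction itself generally requires macOS.

import coremltools as ct
import numpy as np
from PIL import Image

# Load the exported Core ML model and run a single prediction (macOS only).
mlmodel = ct.models.MLModel('model.mlmodel')
img = Image.open('test.jpg').convert('RGB').resize((224, 224))
outputs = mlmodel.predict({'input_1': img})

# Report the shape of every returned output.
print({name: np.asarray(value).shape for name, value in outputs.items()})

Without --add-norm-layer the converted model takes a TensorType input instead, so a NumPy array of shape input_shape would be passed under the traced graph's auto-generated input name rather than a PIL image.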