mmpretrain/tools/deployment/pytorch2torchscript.py

139 lines
4.2 KiB
Python
Raw Normal View History

import os
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from torch import nn
from mmcls.models import build_classifier
torch.manual_seed(3)
def _demo_mm_inputs(input_shape: tuple, num_classes: int):
"""Create a superset of inputs needed to run test or train batches.
Args:
input_shape (tuple):
input batch dimensions
num_classes (int):
number of semantic classes
"""
(N, C, H, W) = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
gt_labels = rng.randint(
low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
mm_inputs = {
'imgs': torch.FloatTensor(imgs).requires_grad_(False),
'gt_labels': torch.LongTensor(gt_labels),
}
return mm_inputs
def pytorch2torchscript(model: nn.Module, input_shape: tuple, output_file: str,
verify: bool):
"""Export Pytorch model to TorchScript model through torch.jit.trace and
verify the outputs are same between Pytorch and TorchScript.
Args:
model (nn.Module): Pytorch model we want to export.
input_shape (tuple): Use this input shape to construct
the corresponding dummy input and execute the model.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output
TorchScript model.
verify (bool): Whether compare the outputs between Pytorch
and TorchScript through loading generated output_file.
"""
model.cpu().eval()
num_classes = model.head.num_classes
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
imgs = mm_inputs.pop('imgs')
img_list = [img[None, :] for img in imgs]
# replace original forward function
origin_forward = model.forward
model.forward = partial(model.forward, img_metas={}, return_loss=False)
with torch.no_grad():
trace_model = torch.jit.trace(model, img_list[0])
save_dir, _ = osp.split(output_file)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
trace_model.save(output_file)
print(f'Successfully exported TorchScript model: {output_file}')
model.forward = origin_forward
if verify:
# load by torch.jit
jit_model = torch.jit.load(output_file)
# check the numerical value
# get pytorch output
pytorch_result = model(img_list, img_metas={}, return_loss=False)[0]
# get jit output
jit_result = jit_model(img_list[0])[0].detach().numpy()
if not np.allclose(pytorch_result, jit_result):
raise ValueError(
'The outputs are different between Pytorch and TorchScript')
print('The outputs are same between Pytorch and TorchScript')
def parse_args():
parser = argparse.ArgumentParser(
description='Convert MMCls to TorchScript')
parser.add_argument('config', help='test config file path')
parser.add_argument('--checkpoint', help='checkpoint file', type=str)
parser.add_argument(
'--verify',
action='store_true',
help='verify the TorchScript model',
default=False)
parser.add_argument('--output-file', type=str, default='tmp.pt')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[224, 224],
help='input image size')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if len(args.shape) == 1:
input_shape = (1, 3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (
1,
3,
) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
cfg = mmcv.Config.fromfile(args.config)
cfg.model.pretrained = None
# build the model and load checkpoint
classifier = build_classifier(cfg.model)
if args.checkpoint:
load_checkpoint(classifier, args.checkpoint, map_location='cpu')
# conver model to TorchScript file
pytorch2torchscript(
classifier,
input_shape,
output_file=args.output_file,
verify=args.verify)