[Feature] Add torchscript deployment (#279)
* add torchscript deploy
* fix lint
* add check and delete

parent dbddde52ef
commit c2f01e0dcd

@@ -0,0 +1,55 @@
# Tutorial 5: PyTorch to TorchScript (Experimental)

<!-- TOC -->

- [Tutorial 5: PyTorch to TorchScript (Experimental)](#tutorial-5-pytorch-to-torchscript-experimental)
  - [How to convert models from PyTorch to TorchScript](#how-to-convert-models-from-pytorch-to-torchscript)
    - [Usage](#usage)
    - [Description of all arguments](#description-of-all-arguments)
  - [Reminders](#reminders)
  - [FAQs](#faqs)

<!-- TOC -->

## How to convert models from PyTorch to TorchScript

### Usage

```bash
python tools/deployment/pytorch2torchscript.py \
    ${CONFIG_FILE} \
    --checkpoint ${CHECKPOINT_FILE} \
    --output-file ${OUTPUT_FILE} \
    --shape ${IMAGE_SHAPE} \
    --verify
```

### Description of all arguments

- `config`: The path of a model config file.
- `--checkpoint`: The path of a model checkpoint file.
- `--output-file`: The path of the output TorchScript model. If not specified, it will be set to `tmp.pt`.
- `--shape`: The height and width of the input tensor to the model. If not specified, it will be set to `224 224`.
- `--verify`: Whether to verify the correctness of the exported model. If not specified, it will be set to `False`.

Example:

```bash
python tools/deployment/pytorch2torchscript.py \
    configs/resnet/resnet18_b16x8_cifar10.py \
    --checkpoint checkpoints/resnet/resnet18_b16x8_cifar10.pth \
    --output-file checkpoints/resnet/resnet18_b16x8_cifar10.pt \
    --verify
```

Notes:
- *All models above are tested with PyTorch==1.8.1*
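
After export, the TorchScript model can be loaded and run without the mmcls codebase. A minimal sketch, assuming the output file from the example above and the default `224 224` input shape used at export time:

```python
import torch

# Load the exported TorchScript model (path from the example above).
jit_model = torch.jit.load('checkpoints/resnet/resnet18_b16x8_cifar10.pt')
jit_model.eval()

# Run inference on a dummy batch matching the export-time input shape.
dummy_input = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    scores = jit_model(dummy_input)
print(scores.shape)  # (1, num_classes) class scores
```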

## Reminders

- If you encounter any problem with the models in this repo, please create an issue and it will be taken care of soon.

## FAQs

- None
@@ -62,7 +62,7 @@ class ClsHead(BaseHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export():
+        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -43,7 +43,7 @@ class LinearClsHead(ClsHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export():
+        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -47,7 +47,7 @@ class MultiLabelClsHead(BaseHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.sigmoid(cls_score) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export():
+        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -56,7 +56,7 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.sigmoid(cls_score) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export():
+        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -68,7 +68,7 @@ class VisionTransformerClsHead(ClsHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export():
+        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
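
All five head changes follow the same pattern: the numpy post-processing in `simple_test` cannot be traced, so the head must return the raw tensor whenever the model runs under ONNX export or TorchScript tracing. A minimal standalone sketch of the pattern (the `post_process` helper is hypothetical, not code from the repo):

```python
import torch
import torch.nn.functional as F

def post_process(cls_score: torch.Tensor):
    pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
    # Under ONNX export or TorchScript tracing, keep the graph pure-tensor;
    # the numpy conversion below would break the trace.
    if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
        return pred
    # Eager-mode inference returns a plain Python list of per-image scores.
    return list(pred.detach().cpu().numpy())
```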
@@ -26,7 +26,7 @@ def _demo_mm_inputs(input_shape, num_classes):
     rng = np.random.RandomState(0)
     imgs = rng.rand(*input_shape)
     gt_labels = rng.randint(
-        low=0, high=num_classes - 1, size=(N, 1)).astype(np.uint8)
+        low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
     mm_inputs = {
         'imgs': torch.FloatTensor(imgs).requires_grad_(True),
         'gt_labels': torch.LongTensor(gt_labels),
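
This fix works because `numpy.random.randint` treats `high` as exclusive, so the old `high=num_classes - 1` could never generate the last class label. A quick illustration:

```python
import numpy as np

rng = np.random.RandomState(0)
num_classes = 10
# `high` is exclusive: the old call sampled labels in [0, 9), never
# producing class 9; the fixed call samples the full range [0, 10).
old_labels = rng.randint(low=0, high=num_classes - 1, size=(4, 1))
new_labels = rng.randint(low=0, high=num_classes, size=(4, 1))
```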
@@ -0,0 +1,138 @@
import argparse
import os
import os.path as osp
from functools import partial

import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from torch import nn

from mmcls.models import build_classifier

torch.manual_seed(3)


def _demo_mm_inputs(input_shape: tuple, num_classes: int):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    gt_labels = rng.randint(
        low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(False),
        'gt_labels': torch.LongTensor(gt_labels),
    }
    return mm_inputs


def pytorch2torchscript(model: nn.Module, input_shape: tuple, output_file: str,
                        verify: bool):
    """Export a PyTorch model to a TorchScript model through torch.jit.trace
    and verify that the outputs are the same between PyTorch and TorchScript.

    Args:
        model (nn.Module): PyTorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        output_file (str): The path to where we store the output
            TorchScript model.
        verify (bool): Whether to compare the outputs between PyTorch
            and TorchScript through loading the generated output_file.
    """
    model.cpu().eval()

    num_classes = model.head.num_classes
    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_list = [img[None, :] for img in imgs]

    # replace the original forward function so tracing only needs the image
    origin_forward = model.forward
    model.forward = partial(model.forward, img_metas={}, return_loss=False)

    with torch.no_grad():
        trace_model = torch.jit.trace(model, img_list[0])
        save_dir, _ = osp.split(output_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        trace_model.save(output_file)
        print(f'Successfully exported TorchScript model: {output_file}')
    model.forward = origin_forward

    if verify:
        # load by torch.jit
        jit_model = torch.jit.load(output_file)

        # check the numerical values
        # get pytorch output
        pytorch_result = model(img_list, img_metas={}, return_loss=False)[0]

        # get jit output
        jit_result = jit_model(img_list[0])[0].detach().numpy()
        if not np.allclose(pytorch_result, jit_result):
            raise ValueError(
                'The outputs are different between PyTorch and TorchScript')
        print('The outputs are the same between PyTorch and TorchScript')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMCls models to TorchScript')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', type=str)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the TorchScript model',
        default=False)
    parser.add_argument('--output-file', type=str, default='tmp.pt')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load the checkpoint
    classifier = build_classifier(cfg.model)

    if args.checkpoint:
        load_checkpoint(classifier, args.checkpoint, map_location='cpu')

    # convert the model to a TorchScript file
    pytorch2torchscript(
        classifier,
        input_shape,
        output_file=args.output_file,
        verify=args.verify)
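
One caveat about `torch.jit.trace`, which this script relies on: tracing records only the operations executed for the given example input, so data-dependent Python control flow is frozen at trace time (PyTorch emits a `TracerWarning` in such cases). A minimal standalone sketch of the pitfall:

```python
import torch

def branchy(x):
    # Data-dependent branch: tracing records only the path taken for the
    # example input and bakes that path into the exported graph.
    if x.sum() > 0:
        return x * 2
    return x - 1

traced = torch.jit.trace(branchy, torch.ones(3))
print(traced(torch.ones(3)))   # tensor([2., 2., 2.]) -- the traced path
print(traced(-torch.ones(3)))  # tensor([-2., -2., -2.]) -- branch frozen at x * 2
```

This is also why the heads above short-circuit their post-processing under `torch.jit.is_tracing()`: everything the trace records must stay within tensor operations.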