[Feature] Add torchscript deployment (#279)

* add torchscript deploy

* fix lint

* add check and delete \
pull/301/head
AllentDan 2021-06-12 21:50:48 +08:00 committed by GitHub
parent dbddde52ef
commit c2f01e0dcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 199 additions and 6 deletions

View File

@ -0,0 +1,55 @@
# Tutorial 5: Pytorch to TorchScript (Experimental)
<!-- TOC -->
- [Tutorial 5: Pytorch to TorchScript (Experimental)](#tutorial-5-pytorch-to-torchscript-experimental)
- [How to convert models from Pytorch to TorchScript](#how-to-convert-models-from-pytorch-to-torchscript)
- [Usage](#usage)
- [Description of all arguments](#description-of-all-arguments)
- [Reminders](#reminders)
- [FAQs](#faqs)
<!-- TOC -->
## How to convert models from Pytorch to TorchScript
### Usage
```bash
python tools/deployment/pytorch2torchscript.py \
${CONFIG_FILE} \
--checkpoint ${CHECKPOINT_FILE} \
--output-file ${OUTPUT_FILE} \
--shape ${IMAGE_SHAPE} \
--verify \
```
### Description of all arguments:
- `config` : The path of a model config file.
- `--checkpoint` : The path of a model checkpoint file.
- `--output-file`: The path of output TorchScript model. If not specified, it will be set to `tmp.pt`.
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `224 224`.
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
Example:
```bash
python tools/deployment/pytorch2onnx.py \
configs/resnet/resnet18_b16x8_cifar10.py \
--checkpoint checkpoints/resnet/resnet18_b16x8_cifar10.pth \
--output-file checkpoints/resnet/resnet18_b16x8_cifar10.pt \
--verify \
```
Notes:
- *All models above are tested with Pytorch==1.8.1*
## Reminders
- If you meet any problem with the models in this repo, please create an issue and it would be taken care of soon.
## FAQs
- None

View File

@ -62,7 +62,7 @@ class ClsHead(BaseHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -43,7 +43,7 @@ class LinearClsHead(ClsHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -47,7 +47,7 @@ class MultiLabelClsHead(BaseHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -56,7 +56,7 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -68,7 +68,7 @@ class VisionTransformerClsHead(ClsHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -26,7 +26,7 @@ def _demo_mm_inputs(input_shape, num_classes):
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
gt_labels = rng.randint(
low=0, high=num_classes - 1, size=(N, 1)).astype(np.uint8)
low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
mm_inputs = {
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
'gt_labels': torch.LongTensor(gt_labels),

View File

@ -0,0 +1,138 @@
import os
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from torch import nn
from mmcls.models import build_classifier
torch.manual_seed(3)
def _demo_mm_inputs(input_shape: tuple, num_classes: int):
"""Create a superset of inputs needed to run test or train batches.
Args:
input_shape (tuple):
input batch dimensions
num_classes (int):
number of semantic classes
"""
(N, C, H, W) = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
gt_labels = rng.randint(
low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
mm_inputs = {
'imgs': torch.FloatTensor(imgs).requires_grad_(False),
'gt_labels': torch.LongTensor(gt_labels),
}
return mm_inputs
def pytorch2torchscript(model: nn.Module, input_shape: tuple, output_file: str,
verify: bool):
"""Export Pytorch model to TorchScript model through torch.jit.trace and
verify the outputs are same between Pytorch and TorchScript.
Args:
model (nn.Module): Pytorch model we want to export.
input_shape (tuple): Use this input shape to construct
the corresponding dummy input and execute the model.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output
TorchScript model.
verify (bool): Whether compare the outputs between Pytorch
and TorchScript through loading generated output_file.
"""
model.cpu().eval()
num_classes = model.head.num_classes
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
imgs = mm_inputs.pop('imgs')
img_list = [img[None, :] for img in imgs]
# replace original forward function
origin_forward = model.forward
model.forward = partial(model.forward, img_metas={}, return_loss=False)
with torch.no_grad():
trace_model = torch.jit.trace(model, img_list[0])
save_dir, _ = osp.split(output_file)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
trace_model.save(output_file)
print(f'Successfully exported TorchScript model: {output_file}')
model.forward = origin_forward
if verify:
# load by torch.jit
jit_model = torch.jit.load(output_file)
# check the numerical value
# get pytorch output
pytorch_result = model(img_list, img_metas={}, return_loss=False)[0]
# get jit output
jit_result = jit_model(img_list[0])[0].detach().numpy()
if not np.allclose(pytorch_result, jit_result):
raise ValueError(
'The outputs are different between Pytorch and TorchScript')
print('The outputs are same between Pytorch and TorchScript')
def parse_args():
parser = argparse.ArgumentParser(
description='Convert MMCls to TorchScript')
parser.add_argument('config', help='test config file path')
parser.add_argument('--checkpoint', help='checkpoint file', type=str)
parser.add_argument(
'--verify',
action='store_true',
help='verify the TorchScript model',
default=False)
parser.add_argument('--output-file', type=str, default='tmp.pt')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[224, 224],
help='input image size')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if len(args.shape) == 1:
input_shape = (1, 3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (
1,
3,
) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
cfg = mmcv.Config.fromfile(args.config)
cfg.model.pretrained = None
# build the model and load checkpoint
classifier = build_classifier(cfg.model)
if args.checkpoint:
load_checkpoint(classifier, args.checkpoint, map_location='cpu')
# conver model to TorchScript file
pytorch2torchscript(
classifier,
input_shape,
output_file=args.output_file,
verify=args.verify)