diff --git a/docs/tutorials/pytorch2torchscript.md b/docs/tutorials/pytorch2torchscript.md index d9f110575..13ea2ccf5 100644 --- a/docs/tutorials/pytorch2torchscript.md +++ b/docs/tutorials/pytorch2torchscript.md @@ -24,7 +24,7 @@ python tools/deployment/pytorch2torchscript.py \ --verify \ ``` -### Description of all arguments: +### Description of all arguments - `config` : The path of a model config file. - `--checkpoint` : The path of a model checkpoint file. @@ -48,6 +48,7 @@ Notes: ## Reminders +- For torch.jit.is_tracing() is only supported after v1.6. For users with pytorch v1.3-v1.5, we suggest early returning tensors manually. - If you meet any problem with the models in this repo, please create an issue and it would be taken care of soon. ## FAQs diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py index c4ac12963..5a5457e9f 100644 --- a/mmcls/models/heads/cls_head.py +++ b/mmcls/models/heads/cls_head.py @@ -62,7 +62,9 @@ class ClsHead(BaseHead): if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) pred = F.softmax(cls_score, dim=1) if cls_score is not None else None - if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing(): + + on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing() + if torch.onnx.is_in_onnx_export() or on_trace: return pred pred = list(pred.detach().cpu().numpy()) return pred diff --git a/mmcls/models/heads/linear_head.py b/mmcls/models/heads/linear_head.py index 1be0991ae..7672f7cbd 100644 --- a/mmcls/models/heads/linear_head.py +++ b/mmcls/models/heads/linear_head.py @@ -43,7 +43,9 @@ class LinearClsHead(ClsHead): if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) pred = F.softmax(cls_score, dim=1) if cls_score is not None else None - if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing(): + + on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing() + if torch.onnx.is_in_onnx_export() or on_trace: return pred pred = list(pred.detach().cpu().numpy()) return pred diff --git a/mmcls/models/heads/multi_label_head.py b/mmcls/models/heads/multi_label_head.py index 3e087c6a9..3635d0d27 100644 --- a/mmcls/models/heads/multi_label_head.py +++ b/mmcls/models/heads/multi_label_head.py @@ -47,7 +47,9 @@ class MultiLabelClsHead(BaseHead): if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) pred = F.sigmoid(cls_score) if cls_score is not None else None - if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing(): + + on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing() + if torch.onnx.is_in_onnx_export() or on_trace: return pred pred = list(pred.detach().cpu().numpy()) return pred diff --git a/mmcls/models/heads/multi_label_linear_head.py b/mmcls/models/heads/multi_label_linear_head.py index 07d23703b..118fe747b 100644 --- a/mmcls/models/heads/multi_label_linear_head.py +++ b/mmcls/models/heads/multi_label_linear_head.py @@ -56,7 +56,9 @@ class MultiLabelLinearClsHead(MultiLabelClsHead): if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) pred = F.sigmoid(cls_score) if cls_score is not None else None - if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing(): + + on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing() + if torch.onnx.is_in_onnx_export() or on_trace: return pred pred = list(pred.detach().cpu().numpy()) return pred diff --git a/mmcls/models/heads/vision_transformer_head.py b/mmcls/models/heads/vision_transformer_head.py index 8b5dfc37a..606217d4f 100644 --- a/mmcls/models/heads/vision_transformer_head.py +++ b/mmcls/models/heads/vision_transformer_head.py @@ -68,7 +68,9 @@ class VisionTransformerClsHead(ClsHead): if isinstance(cls_score, list): cls_score = sum(cls_score) / float(len(cls_score)) pred = F.softmax(cls_score, dim=1) if cls_score is not None else None - if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing(): + + on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing() + if torch.onnx.is_in_onnx_export() or on_trace: return pred pred = list(pred.detach().cpu().numpy()) return pred diff --git a/tools/deployment/pytorch2torchscript.py b/tools/deployment/pytorch2torchscript.py index d1764a22b..edaca2d2d 100644 --- a/tools/deployment/pytorch2torchscript.py +++ b/tools/deployment/pytorch2torchscript.py @@ -1,5 +1,5 @@ -import os import argparse +import os import os.path as osp from functools import partial