[Fix] Build compatible with low pytorch versions (#301)

* add version compatible for torchscript

* doc

* doc again

* fix lint

* fix lint isort
This commit is contained in:
AllentDan 2021-06-14 23:25:35 +08:00 committed by GitHub
parent b99bd4fa88
commit a24a9f6faa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 18 additions and 7 deletions

View File

@ -24,7 +24,7 @@ python tools/deployment/pytorch2torchscript.py \
--verify \
```
### Description of all arguments:
### Description of all arguments
- `config` : The path of a model config file.
- `--checkpoint` : The path of a model checkpoint file.
@ -48,6 +48,7 @@ Notes:
## Reminders
- For torch.jit.is_tracing() is only supported after v1.6. For users with pytorch v1.3-v1.5, we suggest early returning tensors manually.
- If you meet any problem with the models in this repo, please create an issue and it would be taken care of soon.
## FAQs

View File

@ -62,7 +62,9 @@ class ClsHead(BaseHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -43,7 +43,9 @@ class LinearClsHead(ClsHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -47,7 +47,9 @@ class MultiLabelClsHead(BaseHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -56,7 +56,9 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -68,7 +68,9 @@ class VisionTransformerClsHead(ClsHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred

View File

@ -1,5 +1,5 @@
import os
import argparse
import os
import os.path as osp
from functools import partial