[Fix] Build compatible with low pytorch versions (#301)
* add version compatible for torchscript
* doc
* doc again
* fix lint
* fix lint isort
This commit is contained in:
parent b99bd4fa88
commit a24a9f6faa
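A minimal, self-contained sketch of the compatibility pattern this commit applies across the classification heads: the `torch.jit.is_tracing()` call is guarded with `hasattr`, so PyTorch versions that lack the public API (the commit reports anything before v1.6) skip the check instead of raising an `AttributeError`. The `post_process` helper below is illustrative only, not code from this repository.

```python
import torch
import torch.nn.functional as F


def post_process(cls_score: torch.Tensor):
    """Return a tensor while exporting/tracing, a list of numpy arrays otherwise."""
    pred = F.softmax(cls_score, dim=1)
    # hasattr() short-circuits on older PyTorch where torch.jit.is_tracing
    # does not exist, so the guarded call is never evaluated there.
    on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
    if torch.onnx.is_in_onnx_export() or on_trace:
        return pred
    return list(pred.detach().cpu().numpy())
```

During `torch.jit.trace` or ONNX export the helper returns the tensor directly, so the trace never sees the numpy conversion.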
@@ -24,7 +24,7 @@ python tools/deployment/pytorch2torchscript.py \
     --verify \
 ```
 
-### Description of all arguments:
+### Description of all arguments
 
 - `config` : The path of a model config file.
 - `--checkpoint` : The path of a model checkpoint file.
@@ -48,6 +48,7 @@ Notes:
 
 ## Reminders
 
+- `torch.jit.is_tracing()` is only supported after v1.6. For users with PyTorch v1.3-v1.5, we suggest returning tensors manually with an early return.
 - If you meet any problem with the models in this repo, please create an issue and it would be taken care of soon.
 
 ## FAQs
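The reminder added in the hunk above suggests "returning tensors manually" for PyTorch v1.3-v1.5, where `torch.jit.is_tracing()` is unavailable. A sketch of what such an early return might look like, using a hypothetical `ToyHead` module that is not part of mmclassification:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyHead(nn.Module):
    """Hypothetical head whose forward pass returns tensors only."""

    def __init__(self, in_channels=16, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        cls_score = self.fc(x)
        # Return the tensor directly (no .numpy()/list conversion), so
        # torch.jit.trace works without any is_tracing() check.
        return F.softmax(cls_score, dim=1)


# Tracing stays tensor-only end to end, which also works on older PyTorch.
traced = torch.jit.trace(ToyHead(), torch.randn(1, 16))
```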
@@ -62,7 +62,9 @@ class ClsHead(BaseHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
+
+        on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
+        if torch.onnx.is_in_onnx_export() or on_trace:
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -43,7 +43,9 @@ class LinearClsHead(ClsHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
+
+        on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
+        if torch.onnx.is_in_onnx_export() or on_trace:
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -47,7 +47,9 @@ class MultiLabelClsHead(BaseHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.sigmoid(cls_score) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
+
+        on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
+        if torch.onnx.is_in_onnx_export() or on_trace:
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -56,7 +56,9 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.sigmoid(cls_score) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
+
+        on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
+        if torch.onnx.is_in_onnx_export() or on_trace:
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -68,7 +68,9 @@ class VisionTransformerClsHead(ClsHead):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
-        if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
+
+        on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
+        if torch.onnx.is_in_onnx_export() or on_trace:
             return pred
         pred = list(pred.detach().cpu().numpy())
         return pred
@@ -1,5 +1,5 @@
-import os
 import argparse
+import os
 import os.path as osp
 from functools import partial