mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Fix] Build compatible with low pytorch versions (#301)
* add version compatible for torchscript * doc * doc again * fix lint * fix lint isort
This commit is contained in:
parent
b99bd4fa88
commit
a24a9f6faa
@ -24,7 +24,7 @@ python tools/deployment/pytorch2torchscript.py \
|
|||||||
--verify \
|
--verify \
|
||||||
```
|
```
|
||||||
|
|
||||||
### Description of all arguments:
|
### Description of all arguments
|
||||||
|
|
||||||
- `config` : The path of a model config file.
|
- `config` : The path of a model config file.
|
||||||
- `--checkpoint` : The path of a model checkpoint file.
|
- `--checkpoint` : The path of a model checkpoint file.
|
||||||
@ -48,6 +48,7 @@ Notes:
|
|||||||
|
|
||||||
## Reminders
|
## Reminders
|
||||||
|
|
||||||
|
- For torch.jit.is_tracing() is only supported after v1.6. For users with pytorch v1.3-v1.5, we suggest early returning tensors manually.
|
||||||
- If you meet any problem with the models in this repo, please create an issue and it would be taken care of soon.
|
- If you meet any problem with the models in this repo, please create an issue and it would be taken care of soon.
|
||||||
|
|
||||||
## FAQs
|
## FAQs
|
||||||
|
@ -62,7 +62,9 @@ class ClsHead(BaseHead):
|
|||||||
if isinstance(cls_score, list):
|
if isinstance(cls_score, list):
|
||||||
cls_score = sum(cls_score) / float(len(cls_score))
|
cls_score = sum(cls_score) / float(len(cls_score))
|
||||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||||
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
|
|
||||||
|
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
|
||||||
|
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||||
return pred
|
return pred
|
||||||
pred = list(pred.detach().cpu().numpy())
|
pred = list(pred.detach().cpu().numpy())
|
||||||
return pred
|
return pred
|
||||||
|
@ -43,7 +43,9 @@ class LinearClsHead(ClsHead):
|
|||||||
if isinstance(cls_score, list):
|
if isinstance(cls_score, list):
|
||||||
cls_score = sum(cls_score) / float(len(cls_score))
|
cls_score = sum(cls_score) / float(len(cls_score))
|
||||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||||
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
|
|
||||||
|
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
|
||||||
|
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||||
return pred
|
return pred
|
||||||
pred = list(pred.detach().cpu().numpy())
|
pred = list(pred.detach().cpu().numpy())
|
||||||
return pred
|
return pred
|
||||||
|
@ -47,7 +47,9 @@ class MultiLabelClsHead(BaseHead):
|
|||||||
if isinstance(cls_score, list):
|
if isinstance(cls_score, list):
|
||||||
cls_score = sum(cls_score) / float(len(cls_score))
|
cls_score = sum(cls_score) / float(len(cls_score))
|
||||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||||
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
|
|
||||||
|
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
|
||||||
|
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||||
return pred
|
return pred
|
||||||
pred = list(pred.detach().cpu().numpy())
|
pred = list(pred.detach().cpu().numpy())
|
||||||
return pred
|
return pred
|
||||||
|
@ -56,7 +56,9 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
|
|||||||
if isinstance(cls_score, list):
|
if isinstance(cls_score, list):
|
||||||
cls_score = sum(cls_score) / float(len(cls_score))
|
cls_score = sum(cls_score) / float(len(cls_score))
|
||||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||||
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
|
|
||||||
|
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
|
||||||
|
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||||
return pred
|
return pred
|
||||||
pred = list(pred.detach().cpu().numpy())
|
pred = list(pred.detach().cpu().numpy())
|
||||||
return pred
|
return pred
|
||||||
|
@ -68,7 +68,9 @@ class VisionTransformerClsHead(ClsHead):
|
|||||||
if isinstance(cls_score, list):
|
if isinstance(cls_score, list):
|
||||||
cls_score = sum(cls_score) / float(len(cls_score))
|
cls_score = sum(cls_score) / float(len(cls_score))
|
||||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||||
if torch.onnx.is_in_onnx_export() or torch.jit.is_tracing():
|
|
||||||
|
on_trace = hasattr(torch.jit, 'is_tracing') and torch.jit.is_tracing()
|
||||||
|
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||||
return pred
|
return pred
|
||||||
pred = list(pred.detach().cpu().numpy())
|
pred = list(pred.detach().cpu().numpy())
|
||||||
return pred
|
return pred
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user