[Fix] Fix num_classes bug in pytorch2onnx.py (#458)

* Fix num_classes bug in pytorch2onnx.py

* Fix `num_classes=-1`.
pull/471/head
Ma Zerun 2021-09-25 09:25:58 +08:00 committed by GitHub
parent 1d6d142b42
commit c5374854aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 1 deletions

View File

@ -59,7 +59,15 @@ def pytorch2onnx(model,
"""
model.cpu().eval()
num_classes = model.head.num_classes
if hasattr(model.head, 'num_classes'):
num_classes = model.head.num_classes
# Some backbones use `num_classes=-1` to disable top classifier.
elif getattr(model.backbone, 'num_classes', -1) > 0:
num_classes = model.backbone.num_classes
else:
raise AttributeError('Cannot find "num_classes" in both head and '
'backbone, please check the config file.')
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
imgs = mm_inputs.pop('imgs')