From c5374854aa13c0e4fa08526f3499bcb5eb2a5b4a Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Sat, 25 Sep 2021 09:25:58 +0800 Subject: [PATCH] [Fix] Fix num_classes bug in pytorch2onnx.py (#458) * Fix num_classes bug in pytorch2onnx.py * Fix `num_classes=-1`. --- tools/deployment/pytorch2onnx.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tools/deployment/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py index 28578b2a8..e6704172c 100644 --- a/tools/deployment/pytorch2onnx.py +++ b/tools/deployment/pytorch2onnx.py @@ -59,7 +59,15 @@ def pytorch2onnx(model, """ model.cpu().eval() - num_classes = model.head.num_classes + if hasattr(model.head, 'num_classes'): + num_classes = model.head.num_classes + # Some backbones use `num_classes=-1` to disable top classifier. + elif getattr(model.backbone, 'num_classes', -1) > 0: + num_classes = model.backbone.num_classes + else: + raise AttributeError('Cannot find "num_classes" in both head and ' + 'backbone, please check the config file.') + mm_inputs = _demo_mm_inputs(input_shape, num_classes) imgs = mm_inputs.pop('imgs')