diff --git a/tools/deployment/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py index 28578b2a8..e6704172c 100644 --- a/tools/deployment/pytorch2onnx.py +++ b/tools/deployment/pytorch2onnx.py @@ -59,7 +59,15 @@ def pytorch2onnx(model, """ model.cpu().eval() - num_classes = model.head.num_classes + if hasattr(model.head, 'num_classes'): + num_classes = model.head.num_classes + # Some backbones use `num_classes=-1` to disable top classifier. + elif getattr(model.backbone, 'num_classes', -1) > 0: + num_classes = model.backbone.num_classes + else: + raise AttributeError('Cannot find "num_classes" in both head and ' + 'backbone, please check the config file.') + mm_inputs = _demo_mm_inputs(input_shape, num_classes) imgs = mm_inputs.pop('imgs')