diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index b24536678..f22d4d378 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -5,11 +5,11 @@ import mmcv import numpy as np import onnxruntime as rt import torch -from torch import nn import torch._C import torch.serialization from mmcv.onnx import register_extra_symbolics from mmcv.runner import load_checkpoint +from torch import nn from mmseg.models import build_segmentor @@ -186,11 +186,6 @@ if __name__ == '__main__': # convert SyncBN to BN segmentor = _convert_batchnorm(segmentor) - if isinstance(segmentor.decode_head, nn.ModuleList): - num_classes = segmentor.decode_head[-1].num_classes - else: - num_classes = segmentor.decode_head.num_classes - if args.checkpoint: load_checkpoint(segmentor, args.checkpoint, map_location='cpu')