mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Remove redundancies in pytorch2onnx (#160)
* rm redundancies * re-add some packages
This commit is contained in:
parent
e1f4f51dc6
commit
f70507168f
@ -5,11 +5,11 @@ import mmcv
|
||||
import numpy as np
|
||||
import onnxruntime as rt
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch._C
|
||||
import torch.serialization
|
||||
from mmcv.onnx import register_extra_symbolics
|
||||
from mmcv.runner import load_checkpoint
|
||||
from torch import nn
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
|
||||
@ -186,11 +186,6 @@ if __name__ == '__main__':
|
||||
# convert SyncBN to BN
|
||||
segmentor = _convert_batchnorm(segmentor)
|
||||
|
||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||
num_classes = segmentor.decode_head[-1].num_classes
|
||||
else:
|
||||
num_classes = segmentor.decode_head.num_classes
|
||||
|
||||
if args.checkpoint:
|
||||
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user