mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Remove redundancies in pytorch2onnx (#160)
* rm redundancies * re-add some packages
This commit is contained in:
parent
e1f4f51dc6
commit
f70507168f
@ -5,11 +5,11 @@ import mmcv
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as rt
|
import onnxruntime as rt
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
import torch._C
|
import torch._C
|
||||||
import torch.serialization
|
import torch.serialization
|
||||||
from mmcv.onnx import register_extra_symbolics
|
from mmcv.onnx import register_extra_symbolics
|
||||||
from mmcv.runner import load_checkpoint
|
from mmcv.runner import load_checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from mmseg.models import build_segmentor
|
from mmseg.models import build_segmentor
|
||||||
|
|
||||||
@ -186,11 +186,6 @@ if __name__ == '__main__':
|
|||||||
# convert SyncBN to BN
|
# convert SyncBN to BN
|
||||||
segmentor = _convert_batchnorm(segmentor)
|
segmentor = _convert_batchnorm(segmentor)
|
||||||
|
|
||||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
|
||||||
num_classes = segmentor.decode_head[-1].num_classes
|
|
||||||
else:
|
|
||||||
num_classes = segmentor.decode_head.num_classes
|
|
||||||
|
|
||||||
if args.checkpoint:
|
if args.checkpoint:
|
||||||
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
|
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user