mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Onnx upsample (#100)
* add customized Upsample which can convert to ONNX * support multiply decode head for hrnet * support size for Upsample
This commit is contained in:
parent
b8f42c70fa
commit
0c04f52c42
@ -4,7 +4,7 @@ from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
|
|||||||
from mmcv.runner import load_checkpoint
|
from mmcv.runner import load_checkpoint
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import Upsample, resize
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import get_root_logger
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
from .resnet import BasicBlock, Bottleneck
|
from .resnet import BasicBlock, Bottleneck
|
||||||
@ -141,7 +141,7 @@ class HRModule(nn.Module):
|
|||||||
bias=False),
|
bias=False),
|
||||||
build_norm_layer(self.norm_cfg, in_channels[i])[1],
|
build_norm_layer(self.norm_cfg, in_channels[i])[1],
|
||||||
# we set align_corners=False for HRNet
|
# we set align_corners=False for HRNet
|
||||||
nn.Upsample(
|
Upsample(
|
||||||
scale_factor=2**(j - i),
|
scale_factor=2**(j - i),
|
||||||
mode='bilinear',
|
mode='bilinear',
|
||||||
align_corners=False)))
|
align_corners=False)))
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from .encoding import Encoding
|
from .encoding import Encoding
|
||||||
from .separable_conv_module import DepthwiseSeparableConvModule
|
from .separable_conv_module import DepthwiseSeparableConvModule
|
||||||
from .wrappers import resize
|
from .wrappers import Upsample, resize
|
||||||
|
|
||||||
__all__ = ['resize', 'DepthwiseSeparableConvModule', 'Encoding']
|
__all__ = ['Upsample', 'resize', 'DepthwiseSeparableConvModule', 'Encoding']
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
@ -11,8 +13,8 @@ def resize(input,
|
|||||||
warning=True):
|
warning=True):
|
||||||
if warning:
|
if warning:
|
||||||
if size is not None and align_corners:
|
if size is not None and align_corners:
|
||||||
input_h, input_w = input.shape[2:]
|
input_h, input_w = tuple(int(x) for x in input.shape[2:])
|
||||||
output_h, output_w = size
|
output_h, output_w = tuple(int(x) for x in size)
|
||||||
if output_h > input_h or output_w > output_h:
|
if output_h > input_h or output_w > output_h:
|
||||||
if ((output_h > 1 and output_w > 1 and input_h > 1
|
if ((output_h > 1 and output_w > 1 and input_h > 1
|
||||||
and input_w > 1) and (output_h - 1) % (input_h - 1)
|
and input_w > 1) and (output_h - 1) % (input_h - 1)
|
||||||
@ -22,4 +24,30 @@ def resize(input,
|
|||||||
'the output would more aligned if '
|
'the output would more aligned if '
|
||||||
f'input size {(input_h, input_w)} is `x+1` and '
|
f'input size {(input_h, input_w)} is `x+1` and '
|
||||||
f'out size {(output_h, output_w)} is `nx+1`')
|
f'out size {(output_h, output_w)} is `nx+1`')
|
||||||
|
if isinstance(size, torch.Size):
|
||||||
|
size = tuple(int(x) for x in size)
|
||||||
return F.interpolate(input, size, scale_factor, mode, align_corners)
|
return F.interpolate(input, size, scale_factor, mode, align_corners)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
size=None,
|
||||||
|
scale_factor=None,
|
||||||
|
mode='nearest',
|
||||||
|
align_corners=None):
|
||||||
|
super(Upsample, self).__init__()
|
||||||
|
self.size = size
|
||||||
|
if isinstance(scale_factor, tuple):
|
||||||
|
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||||
|
else:
|
||||||
|
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||||
|
self.mode = mode
|
||||||
|
self.align_corners = align_corners
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.size:
|
||||||
|
size = [int(t * self.scale_factor) for t in x.shape[-2:]]
|
||||||
|
else:
|
||||||
|
size = self.size
|
||||||
|
return resize(x, size, None, self.mode, self.align_corners)
|
||||||
|
@ -5,6 +5,7 @@ import mmcv
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as rt
|
import onnxruntime as rt
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
import torch._C
|
import torch._C
|
||||||
import torch.serialization
|
import torch.serialization
|
||||||
from mmcv.onnx import register_extra_symbolics
|
from mmcv.onnx import register_extra_symbolics
|
||||||
@ -88,7 +89,10 @@ def pytorch2onnx(model,
|
|||||||
"""
|
"""
|
||||||
model.cpu().eval()
|
model.cpu().eval()
|
||||||
|
|
||||||
num_classes = model.decode_head.num_classes
|
if isinstance(model.decode_head, nn.ModuleList):
|
||||||
|
num_classes = model.decode_head[-1].num_classes
|
||||||
|
else:
|
||||||
|
num_classes = model.decode_head.num_classes
|
||||||
|
|
||||||
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
|
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
|
||||||
|
|
||||||
@ -142,7 +146,7 @@ def pytorch2onnx(model,
|
|||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Convert MMDet to ONNX')
|
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
|
||||||
parser.add_argument('config', help='test config file path')
|
parser.add_argument('config', help='test config file path')
|
||||||
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
|
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
|
||||||
parser.add_argument('--show', action='store_true', help='show onnx graph')
|
parser.add_argument('--show', action='store_true', help='show onnx graph')
|
||||||
@ -182,11 +186,13 @@ if __name__ == '__main__':
|
|||||||
# convert SyncBN to BN
|
# convert SyncBN to BN
|
||||||
segmentor = _convert_batchnorm(segmentor)
|
segmentor = _convert_batchnorm(segmentor)
|
||||||
|
|
||||||
num_classes = segmentor.decode_head.num_classes
|
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||||
|
num_classes = segmentor.decode_head[-1].num_classes
|
||||||
|
else:
|
||||||
|
num_classes = segmentor.decode_head.num_classes
|
||||||
|
|
||||||
if args.checkpoint:
|
if args.checkpoint:
|
||||||
checkpoint = load_checkpoint(
|
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
|
||||||
segmentor, args.checkpoint, map_location='cpu')
|
|
||||||
|
|
||||||
# conver model to onnx file
|
# conver model to onnx file
|
||||||
pytorch2onnx(
|
pytorch2onnx(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user