[Fix] Replace interpolate with resize (#731)

* Replace interpolate with resize

* Replace nn.Upsample with ops.Upsample

* Fix test
This commit is contained in:
Miguel Méndez 2021-07-28 10:56:22 +02:00 committed by GitHub
parent b5ae7a7f69
commit 50461efe85
11 changed files with 27 additions and 24 deletions

View File

@ -13,6 +13,7 @@ from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.utils import _pair as to_2tuple from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from ...utils import get_root_logger from ...utils import get_root_logger
from ..builder import ATTENTION, BACKBONES from ..builder import ATTENTION, BACKBONES
from ..utils import PatchEmbed, swin_convert from ..utils import PatchEmbed, swin_convert
@ -745,7 +746,7 @@ class SwinTransformer(BaseModule):
if L1 != L2: if L1 != L2:
S1 = int(L1**0.5) S1 = int(L1**0.5)
S2 = int(L2**0.5) S2 = int(L2**0.5)
table_pretrained_resized = F.interpolate( table_pretrained_resized = resize(
table_pretrained.permute(1, 0).reshape( table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1), 1, nH1, S1, S1),
size=(S2, S2), size=(S2, S2),

View File

@ -7,6 +7,7 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample
from ..builder import BACKBONES from ..builder import BACKBONES
from ..utils import UpConvBlock from ..utils import UpConvBlock
@ -203,7 +204,7 @@ class InterpConv(nn.Module):
conv_cfg=conv_cfg, conv_cfg=conv_cfg,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg) act_cfg=act_cfg)
upsample = nn.Upsample(**upsample_cfg) upsample = Upsample(**upsample_cfg)
if conv_first: if conv_first:
self.interp_upsample = nn.Sequential(conv, upsample) self.interp_upsample = nn.Sequential(conv, upsample)
else: else:

View File

@ -3,7 +3,6 @@ import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init, from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
normal_init, trunc_normal_init) normal_init, trunc_normal_init)
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
@ -11,6 +10,7 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
from ..builder import BACKBONES from ..builder import BACKBONES
from ..utils import PatchEmbed, vit_convert from ..utils import PatchEmbed, vit_convert
@ -373,7 +373,7 @@ class VisionTransformer(BaseModule):
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape( pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = F.interpolate( pos_embed_weight = resize(
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
cls_token_weight = cls_token_weight.unsqueeze(1) cls_token_weight = cls_token_weight.unsqueeze(1)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)

View File

@ -2,7 +2,7 @@ import numpy as np
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import resize from mmseg.ops import Upsample, resize
from ..builder import HEADS from ..builder import HEADS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -45,7 +45,7 @@ class FPNHead(BaseDecodeHead):
act_cfg=self.act_cfg)) act_cfg=self.act_cfg))
if feature_strides[i] != feature_strides[0]: if feature_strides[i] != feature_strides[0]:
scale_head.append( scale_head.append(
nn.Upsample( Upsample(
scale_factor=2, scale_factor=2,
mode='bilinear', mode='bilinear',
align_corners=self.align_corners)) align_corners=self.align_corners))

View File

@ -2,6 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import Upsample
from ..builder import HEADS from ..builder import HEADS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -46,7 +47,7 @@ class SETRMLAHead(BaseDecodeHead):
padding=1, padding=1,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg), act_cfg=self.act_cfg),
nn.Upsample( Upsample(
scale_factor=up_scale, scale_factor=up_scale,
mode='bilinear', mode='bilinear',
align_corners=self.align_corners))) align_corners=self.align_corners)))

View File

@ -1,6 +1,7 @@
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.ops import Upsample
from ..builder import HEADS from ..builder import HEADS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -59,7 +60,7 @@ class SETRUPHead(BaseDecodeHead):
padding=int(kernel_size - 1) // 2, padding=int(kernel_size - 1) // 2,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg), act_cfg=self.act_cfg),
nn.Upsample( Upsample(
scale_factor=up_scale, scale_factor=up_scale,
mode='bilinear', mode='bilinear',
align_corners=self.align_corners))) align_corners=self.align_corners)))

View File

@ -3,6 +3,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16 from mmcv.runner import BaseModule, auto_fp16
from mmseg.ops import resize
from ..builder import NECKS from ..builder import NECKS
@ -173,11 +174,10 @@ class FPN(BaseModule):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`. # it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg: if 'scale_factor' in self.upsample_cfg:
laterals[i - 1] += F.interpolate(laterals[i], laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
**self.upsample_cfg)
else: else:
prev_shape = laterals[i - 1].shape[2:] prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += F.interpolate( laterals[i - 1] += resize(
laterals[i], size=prev_shape, **self.upsample_cfg) laterals[i], size=prev_shape, **self.upsample_cfg)
# build outputs # build outputs

View File

@ -1,7 +1,7 @@
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init from mmcv.cnn import ConvModule, xavier_init
from mmseg.ops import resize
from ..builder import NECKS from ..builder import NECKS
@ -70,7 +70,7 @@ class MultiLevelNeck(nn.Module):
inputs = [inputs[0] for _ in range(self.num_outs)] inputs = [inputs[0] for _ in range(self.num_outs)]
outs = [] outs = []
for i in range(self.num_outs): for i in range(self.num_outs):
x_resize = F.interpolate( x_resize = resize(
inputs[i], scale_factor=self.scales[i], mode='bilinear') inputs[i], scale_factor=self.scales[i], mode='bilinear')
outs.append(self.convs[i](x_resize)) outs.append(self.convs[i](x_resize))
return tuple(outs) return tuple(outs)

View File

@ -1,10 +1,10 @@
import pytest import pytest
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from torch import nn
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule, from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock) InterpConv, UNet, UpConvBlock)
from mmseg.ops import Upsample
from .utils import check_norm_state from .utils import check_norm_state
@ -145,7 +145,7 @@ def test_interp_conv():
block = InterpConv(64, 32, conv_first=False) block = InterpConv(64, 32, conv_first=False)
x = torch.randn(1, 64, 128, 128) x = torch.randn(1, 64, 128, 128)
x_out = block(x) x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample) assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule) assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256]) assert x_out.shape == torch.Size([1, 32, 256, 256])
@ -154,7 +154,7 @@ def test_interp_conv():
x = torch.randn(1, 64, 128, 128) x = torch.randn(1, 64, 128, 128)
x_out = block(x) x_out = block(x)
assert isinstance(block.interp_upsample[0], ConvModule) assert isinstance(block.interp_upsample[0], ConvModule)
assert isinstance(block.interp_upsample[1], nn.Upsample) assert isinstance(block.interp_upsample[1], Upsample)
assert x_out.shape == torch.Size([1, 32, 256, 256]) assert x_out.shape == torch.Size([1, 32, 256, 256])
# test InterpConv with bilinear upsample for upsample 2X. # test InterpConv with bilinear upsample for upsample 2X.
@ -166,7 +166,7 @@ def test_interp_conv():
scale_factor=2, mode='bilinear', align_corners=False)) scale_factor=2, mode='bilinear', align_corners=False))
x = torch.randn(1, 64, 128, 128) x = torch.randn(1, 64, 128, 128)
x_out = block(x) x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample) assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule) assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256]) assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'bilinear' assert block.interp_upsample[0].mode == 'bilinear'
@ -179,7 +179,7 @@ def test_interp_conv():
upsample_cfg=dict(scale_factor=2, mode='nearest')) upsample_cfg=dict(scale_factor=2, mode='nearest'))
x = torch.randn(1, 64, 128, 128) x = torch.randn(1, 64, 128, 128)
x_out = block(x) x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample) assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule) assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256]) assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'nearest' assert block.interp_upsample[0].mode == 'nearest'

View File

@ -14,6 +14,7 @@ from mmcv.utils import DictAction
from mmseg.apis import single_gpu_test from mmseg.apis import single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models.segmentors.base import BaseSegmentor from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize
class ONNXRuntimeSegmentor(BaseSegmentor): class ONNXRuntimeSegmentor(BaseSegmentor):
@ -79,7 +80,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
if not (ori_shape[0] == seg_pred.shape[-2] if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]): and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float() seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = torch.nn.functional.interpolate( seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest') seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy() seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0] seg_pred = seg_pred[0]
@ -127,7 +128,7 @@ class TensorRTSegmentor(BaseSegmentor):
if not (ori_shape[0] == seg_pred.shape[-2] if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]): and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float() seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = torch.nn.functional.interpolate( seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest') seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy() seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0] seg_pred = seg_pred[0]

View File

@ -16,6 +16,7 @@ from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
from mmseg.ops import resize
torch.manual_seed(3) torch.manual_seed(3)
@ -210,10 +211,7 @@ def pytorch2onnx(model,
if dynamic_export and test_mode == 'whole': if dynamic_export and test_mode == 'whole':
# scale image for dynamic shape test # scale image for dynamic shape test
img_list = [ img_list = [resize(_, scale_factor=1.5) for _ in img_list]
nn.functional.interpolate(_, scale_factor=1.5)
for _ in img_list
]
# concate flip image for batch test # concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list] flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [ img_list = [