[Fix] Replace interpolate with resize (#731)

* Replace interpolate with resize

* Replace nn.Upsample with ops.Upsample

* Fix test
This commit is contained in:
Miguel Méndez 2021-07-28 10:56:22 +02:00 committed by GitHub
parent b5ae7a7f69
commit 50461efe85
11 changed files with 27 additions and 24 deletions

View File

@ -13,6 +13,7 @@ from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from ...utils import get_root_logger
from ..builder import ATTENTION, BACKBONES
from ..utils import PatchEmbed, swin_convert
@ -745,7 +746,7 @@ class SwinTransformer(BaseModule):
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = F.interpolate(
table_pretrained_resized = resize(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),

View File

@ -7,6 +7,7 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample
from ..builder import BACKBONES
from ..utils import UpConvBlock
@ -203,7 +204,7 @@ class InterpConv(nn.Module):
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
upsample = nn.Upsample(**upsample_cfg)
upsample = Upsample(**upsample_cfg)
if conv_first:
self.interp_upsample = nn.Sequential(conv, upsample)
else:

View File

@ -3,7 +3,6 @@ import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
normal_init, trunc_normal_init)
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
@ -11,6 +10,7 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed, vit_convert
@ -373,7 +373,7 @@ class VisionTransformer(BaseModule):
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = F.interpolate(
pos_embed_weight = resize(
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
cls_token_weight = cls_token_weight.unsqueeze(1)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)

View File

@ -2,7 +2,7 @@ import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.ops import Upsample, resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@ -45,7 +45,7 @@ class FPNHead(BaseDecodeHead):
act_cfg=self.act_cfg))
if feature_strides[i] != feature_strides[0]:
scale_head.append(
nn.Upsample(
Upsample(
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners))

View File

@ -2,6 +2,7 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import Upsample
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@ -46,7 +47,7 @@ class SETRMLAHead(BaseDecodeHead):
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Upsample(
Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))

View File

@ -1,6 +1,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.ops import Upsample
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@ -59,7 +60,7 @@ class SETRUPHead(BaseDecodeHead):
padding=int(kernel_size - 1) // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Upsample(
Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))

View File

@ -3,6 +3,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from mmseg.ops import resize
from ..builder import NECKS
@ -173,11 +174,10 @@ class FPN(BaseModule):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
laterals[i - 1] += F.interpolate(laterals[i],
**self.upsample_cfg)
laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
else:
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += F.interpolate(
laterals[i - 1] += resize(
laterals[i], size=prev_shape, **self.upsample_cfg)
# build outputs

View File

@ -1,7 +1,7 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from mmseg.ops import resize
from ..builder import NECKS
@ -70,7 +70,7 @@ class MultiLevelNeck(nn.Module):
inputs = [inputs[0] for _ in range(self.num_outs)]
outs = []
for i in range(self.num_outs):
x_resize = F.interpolate(
x_resize = resize(
inputs[i], scale_factor=self.scales[i], mode='bilinear')
outs.append(self.convs[i](x_resize))
return tuple(outs)

View File

@ -1,10 +1,10 @@
import pytest
import torch
from mmcv.cnn import ConvModule
from torch import nn
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock)
from mmseg.ops import Upsample
from .utils import check_norm_state
@ -145,7 +145,7 @@ def test_interp_conv():
block = InterpConv(64, 32, conv_first=False)
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
@ -154,7 +154,7 @@ def test_interp_conv():
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], ConvModule)
assert isinstance(block.interp_upsample[1], nn.Upsample)
assert isinstance(block.interp_upsample[1], Upsample)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test InterpConv with bilinear upsample for upsample 2X.
@ -166,7 +166,7 @@ def test_interp_conv():
scale_factor=2, mode='bilinear', align_corners=False))
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'bilinear'
@ -179,7 +179,7 @@ def test_interp_conv():
upsample_cfg=dict(scale_factor=2, mode='nearest'))
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'nearest'

View File

@ -14,6 +14,7 @@ from mmcv.utils import DictAction
from mmseg.apis import single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize
class ONNXRuntimeSegmentor(BaseSegmentor):
@ -79,7 +80,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = torch.nn.functional.interpolate(
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
@ -127,7 +128,7 @@ class TensorRTSegmentor(BaseSegmentor):
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = torch.nn.functional.interpolate(
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]

View File

@ -16,6 +16,7 @@ from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.ops import resize
torch.manual_seed(3)
@ -210,10 +211,7 @@ def pytorch2onnx(model,
if dynamic_export and test_mode == 'whole':
# scale image for dynamic shape test
img_list = [
nn.functional.interpolate(_, scale_factor=1.5)
for _ in img_list
]
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
# concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [