mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Replace interpolate with resize (#731)
* Replace interpolate with resize * Replace nn.Upsample with ops.Upsample * Fix test
This commit is contained in:
parent
b5ae7a7f69
commit
50461efe85
@ -13,6 +13,7 @@ from torch.nn.modules.linear import Linear
|
||||
from torch.nn.modules.normalization import LayerNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ...utils import get_root_logger
|
||||
from ..builder import ATTENTION, BACKBONES
|
||||
from ..utils import PatchEmbed, swin_convert
|
||||
@ -745,7 +746,7 @@ class SwinTransformer(BaseModule):
|
||||
if L1 != L2:
|
||||
S1 = int(L1**0.5)
|
||||
S2 = int(L2**0.5)
|
||||
table_pretrained_resized = F.interpolate(
|
||||
table_pretrained_resized = resize(
|
||||
table_pretrained.permute(1, 0).reshape(
|
||||
1, nH1, S1, S1),
|
||||
size=(S2, S2),
|
||||
|
@ -7,6 +7,7 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
||||
from mmcv.runner import BaseModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import UpConvBlock
|
||||
|
||||
@ -203,7 +204,7 @@ class InterpConv(nn.Module):
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
upsample = nn.Upsample(**upsample_cfg)
|
||||
upsample = Upsample(**upsample_cfg)
|
||||
if conv_first:
|
||||
self.interp_upsample = nn.Sequential(conv, upsample)
|
||||
else:
|
||||
|
@ -3,7 +3,6 @@ import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
|
||||
normal_init, trunc_normal_init)
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
@ -11,6 +10,7 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import PatchEmbed, vit_convert
|
||||
@ -373,7 +373,7 @@ class VisionTransformer(BaseModule):
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = F.interpolate(
|
||||
pos_embed_weight = resize(
|
||||
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
||||
cls_token_weight = cls_token_weight.unsqueeze(1)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
|
@ -2,7 +2,7 @@ import numpy as np
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.ops import Upsample, resize
|
||||
from ..builder import HEADS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
@ -45,7 +45,7 @@ class FPNHead(BaseDecodeHead):
|
||||
act_cfg=self.act_cfg))
|
||||
if feature_strides[i] != feature_strides[0]:
|
||||
scale_head.append(
|
||||
nn.Upsample(
|
||||
Upsample(
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners))
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from ..builder import HEADS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
@ -46,7 +47,7 @@ class SETRMLAHead(BaseDecodeHead):
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Upsample(
|
||||
Upsample(
|
||||
scale_factor=up_scale,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)))
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from ..builder import HEADS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
@ -59,7 +60,7 @@ class SETRUPHead(BaseDecodeHead):
|
||||
padding=int(kernel_size - 1) // 2,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Upsample(
|
||||
Upsample(
|
||||
scale_factor=up_scale,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)))
|
||||
|
@ -3,6 +3,7 @@ import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.runner import BaseModule, auto_fp16
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import NECKS
|
||||
|
||||
|
||||
@ -173,11 +174,10 @@ class FPN(BaseModule):
|
||||
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
|
||||
# it cannot co-exist with `size` in `F.interpolate`.
|
||||
if 'scale_factor' in self.upsample_cfg:
|
||||
laterals[i - 1] += F.interpolate(laterals[i],
|
||||
**self.upsample_cfg)
|
||||
laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
|
||||
else:
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i - 1] += F.interpolate(
|
||||
laterals[i - 1] += resize(
|
||||
laterals[i], size=prev_shape, **self.upsample_cfg)
|
||||
|
||||
# build outputs
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, xavier_init
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import NECKS
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ class MultiLevelNeck(nn.Module):
|
||||
inputs = [inputs[0] for _ in range(self.num_outs)]
|
||||
outs = []
|
||||
for i in range(self.num_outs):
|
||||
x_resize = F.interpolate(
|
||||
x_resize = resize(
|
||||
inputs[i], scale_factor=self.scales[i], mode='bilinear')
|
||||
outs.append(self.convs[i](x_resize))
|
||||
return tuple(outs)
|
||||
|
@ -1,10 +1,10 @@
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule
|
||||
from torch import nn
|
||||
|
||||
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
|
||||
InterpConv, UNet, UpConvBlock)
|
||||
from mmseg.ops import Upsample
|
||||
from .utils import check_norm_state
|
||||
|
||||
|
||||
@ -145,7 +145,7 @@ def test_interp_conv():
|
||||
block = InterpConv(64, 32, conv_first=False)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
||||
assert isinstance(block.interp_upsample[0], Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
@ -154,7 +154,7 @@ def test_interp_conv():
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], ConvModule)
|
||||
assert isinstance(block.interp_upsample[1], nn.Upsample)
|
||||
assert isinstance(block.interp_upsample[1], Upsample)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test InterpConv with bilinear upsample for upsample 2X.
|
||||
@ -166,7 +166,7 @@ def test_interp_conv():
|
||||
scale_factor=2, mode='bilinear', align_corners=False))
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
||||
assert isinstance(block.interp_upsample[0], Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
assert block.interp_upsample[0].mode == 'bilinear'
|
||||
@ -179,7 +179,7 @@ def test_interp_conv():
|
||||
upsample_cfg=dict(scale_factor=2, mode='nearest'))
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
||||
assert isinstance(block.interp_upsample[0], Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
assert block.interp_upsample[0].mode == 'nearest'
|
||||
|
@ -14,6 +14,7 @@ from mmcv.utils import DictAction
|
||||
from mmseg.apis import single_gpu_test
|
||||
from mmseg.datasets import build_dataloader, build_dataset
|
||||
from mmseg.models.segmentors.base import BaseSegmentor
|
||||
from mmseg.ops import resize
|
||||
|
||||
|
||||
class ONNXRuntimeSegmentor(BaseSegmentor):
|
||||
@ -79,7 +80,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
|
||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||
and ori_shape[1] == seg_pred.shape[-1]):
|
||||
seg_pred = torch.from_numpy(seg_pred).float()
|
||||
seg_pred = torch.nn.functional.interpolate(
|
||||
seg_pred = resize(
|
||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||
seg_pred = seg_pred[0]
|
||||
@ -127,7 +128,7 @@ class TensorRTSegmentor(BaseSegmentor):
|
||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||
and ori_shape[1] == seg_pred.shape[-1]):
|
||||
seg_pred = torch.from_numpy(seg_pred).float()
|
||||
seg_pred = torch.nn.functional.interpolate(
|
||||
seg_pred = resize(
|
||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||
seg_pred = seg_pred[0]
|
||||
|
@ -16,6 +16,7 @@ from mmseg.apis import show_result_pyplot
|
||||
from mmseg.apis.inference import LoadImage
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.ops import resize
|
||||
|
||||
torch.manual_seed(3)
|
||||
|
||||
@ -210,10 +211,7 @@ def pytorch2onnx(model,
|
||||
|
||||
if dynamic_export and test_mode == 'whole':
|
||||
# scale image for dynamic shape test
|
||||
img_list = [
|
||||
nn.functional.interpolate(_, scale_factor=1.5)
|
||||
for _ in img_list
|
||||
]
|
||||
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
|
||||
# concate flip image for batch test
|
||||
flip_img_list = [_.flip(-1) for _ in img_list]
|
||||
img_list = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user