mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Replace interpolate with resize (#731)
* Replace interpolate with resize * Replace nn.Upsample with ops.Upsample * Fix test
This commit is contained in:
parent
b5ae7a7f69
commit
50461efe85
@ -13,6 +13,7 @@ from torch.nn.modules.linear import Linear
|
|||||||
from torch.nn.modules.normalization import LayerNorm
|
from torch.nn.modules.normalization import LayerNorm
|
||||||
from torch.nn.modules.utils import _pair as to_2tuple
|
from torch.nn.modules.utils import _pair as to_2tuple
|
||||||
|
|
||||||
|
from mmseg.ops import resize
|
||||||
from ...utils import get_root_logger
|
from ...utils import get_root_logger
|
||||||
from ..builder import ATTENTION, BACKBONES
|
from ..builder import ATTENTION, BACKBONES
|
||||||
from ..utils import PatchEmbed, swin_convert
|
from ..utils import PatchEmbed, swin_convert
|
||||||
@ -745,7 +746,7 @@ class SwinTransformer(BaseModule):
|
|||||||
if L1 != L2:
|
if L1 != L2:
|
||||||
S1 = int(L1**0.5)
|
S1 = int(L1**0.5)
|
||||||
S2 = int(L2**0.5)
|
S2 = int(L2**0.5)
|
||||||
table_pretrained_resized = F.interpolate(
|
table_pretrained_resized = resize(
|
||||||
table_pretrained.permute(1, 0).reshape(
|
table_pretrained.permute(1, 0).reshape(
|
||||||
1, nH1, S1, S1),
|
1, nH1, S1, S1),
|
||||||
size=(S2, S2),
|
size=(S2, S2),
|
||||||
|
@ -7,6 +7,7 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
|||||||
from mmcv.runner import BaseModule
|
from mmcv.runner import BaseModule
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
|
from mmseg.ops import Upsample
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
from ..utils import UpConvBlock
|
from ..utils import UpConvBlock
|
||||||
|
|
||||||
@ -203,7 +204,7 @@ class InterpConv(nn.Module):
|
|||||||
conv_cfg=conv_cfg,
|
conv_cfg=conv_cfg,
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
act_cfg=act_cfg)
|
act_cfg=act_cfg)
|
||||||
upsample = nn.Upsample(**upsample_cfg)
|
upsample = Upsample(**upsample_cfg)
|
||||||
if conv_first:
|
if conv_first:
|
||||||
self.interp_upsample = nn.Sequential(conv, upsample)
|
self.interp_upsample = nn.Sequential(conv, upsample)
|
||||||
else:
|
else:
|
||||||
|
@ -3,7 +3,6 @@ import warnings
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
|
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
|
||||||
normal_init, trunc_normal_init)
|
normal_init, trunc_normal_init)
|
||||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||||
@ -11,6 +10,7 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
|
|||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
from torch.nn.modules.utils import _pair as to_2tuple
|
from torch.nn.modules.utils import _pair as to_2tuple
|
||||||
|
|
||||||
|
from mmseg.ops import resize
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import get_root_logger
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
from ..utils import PatchEmbed, vit_convert
|
from ..utils import PatchEmbed, vit_convert
|
||||||
@ -373,7 +373,7 @@ class VisionTransformer(BaseModule):
|
|||||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||||
pos_embed_weight = pos_embed_weight.reshape(
|
pos_embed_weight = pos_embed_weight.reshape(
|
||||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||||
pos_embed_weight = F.interpolate(
|
pos_embed_weight = resize(
|
||||||
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
||||||
cls_token_weight = cls_token_weight.unsqueeze(1)
|
cls_token_weight = cls_token_weight.unsqueeze(1)
|
||||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||||
|
@ -2,7 +2,7 @@ import numpy as np
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
|
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import Upsample, resize
|
||||||
from ..builder import HEADS
|
from ..builder import HEADS
|
||||||
from .decode_head import BaseDecodeHead
|
from .decode_head import BaseDecodeHead
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class FPNHead(BaseDecodeHead):
|
|||||||
act_cfg=self.act_cfg))
|
act_cfg=self.act_cfg))
|
||||||
if feature_strides[i] != feature_strides[0]:
|
if feature_strides[i] != feature_strides[0]:
|
||||||
scale_head.append(
|
scale_head.append(
|
||||||
nn.Upsample(
|
Upsample(
|
||||||
scale_factor=2,
|
scale_factor=2,
|
||||||
mode='bilinear',
|
mode='bilinear',
|
||||||
align_corners=self.align_corners))
|
align_corners=self.align_corners))
|
||||||
|
@ -2,6 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
|
|
||||||
|
from mmseg.ops import Upsample
|
||||||
from ..builder import HEADS
|
from ..builder import HEADS
|
||||||
from .decode_head import BaseDecodeHead
|
from .decode_head import BaseDecodeHead
|
||||||
|
|
||||||
@ -46,7 +47,7 @@ class SETRMLAHead(BaseDecodeHead):
|
|||||||
padding=1,
|
padding=1,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
act_cfg=self.act_cfg),
|
act_cfg=self.act_cfg),
|
||||||
nn.Upsample(
|
Upsample(
|
||||||
scale_factor=up_scale,
|
scale_factor=up_scale,
|
||||||
mode='bilinear',
|
mode='bilinear',
|
||||||
align_corners=self.align_corners)))
|
align_corners=self.align_corners)))
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import ConvModule, build_norm_layer
|
from mmcv.cnn import ConvModule, build_norm_layer
|
||||||
|
|
||||||
|
from mmseg.ops import Upsample
|
||||||
from ..builder import HEADS
|
from ..builder import HEADS
|
||||||
from .decode_head import BaseDecodeHead
|
from .decode_head import BaseDecodeHead
|
||||||
|
|
||||||
@ -59,7 +60,7 @@ class SETRUPHead(BaseDecodeHead):
|
|||||||
padding=int(kernel_size - 1) // 2,
|
padding=int(kernel_size - 1) // 2,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
act_cfg=self.act_cfg),
|
act_cfg=self.act_cfg),
|
||||||
nn.Upsample(
|
Upsample(
|
||||||
scale_factor=up_scale,
|
scale_factor=up_scale,
|
||||||
mode='bilinear',
|
mode='bilinear',
|
||||||
align_corners=self.align_corners)))
|
align_corners=self.align_corners)))
|
||||||
|
@ -3,6 +3,7 @@ import torch.nn.functional as F
|
|||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.runner import BaseModule, auto_fp16
|
from mmcv.runner import BaseModule, auto_fp16
|
||||||
|
|
||||||
|
from mmseg.ops import resize
|
||||||
from ..builder import NECKS
|
from ..builder import NECKS
|
||||||
|
|
||||||
|
|
||||||
@ -173,11 +174,10 @@ class FPN(BaseModule):
|
|||||||
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
|
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
|
||||||
# it cannot co-exist with `size` in `F.interpolate`.
|
# it cannot co-exist with `size` in `F.interpolate`.
|
||||||
if 'scale_factor' in self.upsample_cfg:
|
if 'scale_factor' in self.upsample_cfg:
|
||||||
laterals[i - 1] += F.interpolate(laterals[i],
|
laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
|
||||||
**self.upsample_cfg)
|
|
||||||
else:
|
else:
|
||||||
prev_shape = laterals[i - 1].shape[2:]
|
prev_shape = laterals[i - 1].shape[2:]
|
||||||
laterals[i - 1] += F.interpolate(
|
laterals[i - 1] += resize(
|
||||||
laterals[i], size=prev_shape, **self.upsample_cfg)
|
laterals[i], size=prev_shape, **self.upsample_cfg)
|
||||||
|
|
||||||
# build outputs
|
# build outputs
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from mmcv.cnn import ConvModule, xavier_init
|
from mmcv.cnn import ConvModule, xavier_init
|
||||||
|
|
||||||
|
from mmseg.ops import resize
|
||||||
from ..builder import NECKS
|
from ..builder import NECKS
|
||||||
|
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ class MultiLevelNeck(nn.Module):
|
|||||||
inputs = [inputs[0] for _ in range(self.num_outs)]
|
inputs = [inputs[0] for _ in range(self.num_outs)]
|
||||||
outs = []
|
outs = []
|
||||||
for i in range(self.num_outs):
|
for i in range(self.num_outs):
|
||||||
x_resize = F.interpolate(
|
x_resize = resize(
|
||||||
inputs[i], scale_factor=self.scales[i], mode='bilinear')
|
inputs[i], scale_factor=self.scales[i], mode='bilinear')
|
||||||
outs.append(self.convs[i](x_resize))
|
outs.append(self.convs[i](x_resize))
|
||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
|
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
|
||||||
InterpConv, UNet, UpConvBlock)
|
InterpConv, UNet, UpConvBlock)
|
||||||
|
from mmseg.ops import Upsample
|
||||||
from .utils import check_norm_state
|
from .utils import check_norm_state
|
||||||
|
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ def test_interp_conv():
|
|||||||
block = InterpConv(64, 32, conv_first=False)
|
block = InterpConv(64, 32, conv_first=False)
|
||||||
x = torch.randn(1, 64, 128, 128)
|
x = torch.randn(1, 64, 128, 128)
|
||||||
x_out = block(x)
|
x_out = block(x)
|
||||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
assert isinstance(block.interp_upsample[0], Upsample)
|
||||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||||
|
|
||||||
@ -154,7 +154,7 @@ def test_interp_conv():
|
|||||||
x = torch.randn(1, 64, 128, 128)
|
x = torch.randn(1, 64, 128, 128)
|
||||||
x_out = block(x)
|
x_out = block(x)
|
||||||
assert isinstance(block.interp_upsample[0], ConvModule)
|
assert isinstance(block.interp_upsample[0], ConvModule)
|
||||||
assert isinstance(block.interp_upsample[1], nn.Upsample)
|
assert isinstance(block.interp_upsample[1], Upsample)
|
||||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||||
|
|
||||||
# test InterpConv with bilinear upsample for upsample 2X.
|
# test InterpConv with bilinear upsample for upsample 2X.
|
||||||
@ -166,7 +166,7 @@ def test_interp_conv():
|
|||||||
scale_factor=2, mode='bilinear', align_corners=False))
|
scale_factor=2, mode='bilinear', align_corners=False))
|
||||||
x = torch.randn(1, 64, 128, 128)
|
x = torch.randn(1, 64, 128, 128)
|
||||||
x_out = block(x)
|
x_out = block(x)
|
||||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
assert isinstance(block.interp_upsample[0], Upsample)
|
||||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||||
assert block.interp_upsample[0].mode == 'bilinear'
|
assert block.interp_upsample[0].mode == 'bilinear'
|
||||||
@ -179,7 +179,7 @@ def test_interp_conv():
|
|||||||
upsample_cfg=dict(scale_factor=2, mode='nearest'))
|
upsample_cfg=dict(scale_factor=2, mode='nearest'))
|
||||||
x = torch.randn(1, 64, 128, 128)
|
x = torch.randn(1, 64, 128, 128)
|
||||||
x_out = block(x)
|
x_out = block(x)
|
||||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
assert isinstance(block.interp_upsample[0], Upsample)
|
||||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||||
assert block.interp_upsample[0].mode == 'nearest'
|
assert block.interp_upsample[0].mode == 'nearest'
|
||||||
|
@ -14,6 +14,7 @@ from mmcv.utils import DictAction
|
|||||||
from mmseg.apis import single_gpu_test
|
from mmseg.apis import single_gpu_test
|
||||||
from mmseg.datasets import build_dataloader, build_dataset
|
from mmseg.datasets import build_dataloader, build_dataset
|
||||||
from mmseg.models.segmentors.base import BaseSegmentor
|
from mmseg.models.segmentors.base import BaseSegmentor
|
||||||
|
from mmseg.ops import resize
|
||||||
|
|
||||||
|
|
||||||
class ONNXRuntimeSegmentor(BaseSegmentor):
|
class ONNXRuntimeSegmentor(BaseSegmentor):
|
||||||
@ -79,7 +80,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
|
|||||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
and ori_shape[1] == seg_pred.shape[-1]):
|
and ori_shape[1] == seg_pred.shape[-1]):
|
||||||
seg_pred = torch.from_numpy(seg_pred).float()
|
seg_pred = torch.from_numpy(seg_pred).float()
|
||||||
seg_pred = torch.nn.functional.interpolate(
|
seg_pred = resize(
|
||||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||||
seg_pred = seg_pred[0]
|
seg_pred = seg_pred[0]
|
||||||
@ -127,7 +128,7 @@ class TensorRTSegmentor(BaseSegmentor):
|
|||||||
if not (ori_shape[0] == seg_pred.shape[-2]
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
and ori_shape[1] == seg_pred.shape[-1]):
|
and ori_shape[1] == seg_pred.shape[-1]):
|
||||||
seg_pred = torch.from_numpy(seg_pred).float()
|
seg_pred = torch.from_numpy(seg_pred).float()
|
||||||
seg_pred = torch.nn.functional.interpolate(
|
seg_pred = resize(
|
||||||
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||||
seg_pred = seg_pred.long().detach().cpu().numpy()
|
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||||
seg_pred = seg_pred[0]
|
seg_pred = seg_pred[0]
|
||||||
|
@ -16,6 +16,7 @@ from mmseg.apis import show_result_pyplot
|
|||||||
from mmseg.apis.inference import LoadImage
|
from mmseg.apis.inference import LoadImage
|
||||||
from mmseg.datasets.pipelines import Compose
|
from mmseg.datasets.pipelines import Compose
|
||||||
from mmseg.models import build_segmentor
|
from mmseg.models import build_segmentor
|
||||||
|
from mmseg.ops import resize
|
||||||
|
|
||||||
torch.manual_seed(3)
|
torch.manual_seed(3)
|
||||||
|
|
||||||
@ -210,10 +211,7 @@ def pytorch2onnx(model,
|
|||||||
|
|
||||||
if dynamic_export and test_mode == 'whole':
|
if dynamic_export and test_mode == 'whole':
|
||||||
# scale image for dynamic shape test
|
# scale image for dynamic shape test
|
||||||
img_list = [
|
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
|
||||||
nn.functional.interpolate(_, scale_factor=1.5)
|
|
||||||
for _ in img_list
|
|
||||||
]
|
|
||||||
# concate flip image for batch test
|
# concate flip image for batch test
|
||||||
flip_img_list = [_.flip(-1) for _ in img_list]
|
flip_img_list = [_.flip(-1) for _ in img_list]
|
||||||
img_list = [
|
img_list = [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user