[Feature] add nlc2nchw2nlc and nchw2nlc2nchw (#1249)
* [Feature] add nlc2nchw2nlc and nchw2nlc2nchw
* add example
* add test, add **kwargs to make it more universal
parent d06a84d54d
commit 42df28c7bd
mmseg/models/utils/__init__.py
@@ -5,11 +5,12 @@ from .make_divisible import make_divisible
 from .res_layer import ResLayer
 from .se_layer import SELayer
 from .self_attention_block import SelfAttentionBlock
-from .shape_convert import nchw_to_nlc, nlc_to_nchw
+from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
+                            nlc_to_nchw)
 from .up_conv_block import UpConvBlock
 
 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
     'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
-    'nchw_to_nlc', 'nlc_to_nchw'
+    'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc'
 ]
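After this change the new helpers are importable next to the existing converters. A minimal import check, assuming an mmsegmentation checkout that includes this commit:

>>> from mmseg.models.utils import (nchw2nlc2nchw, nchw_to_nlc,
...                                 nlc2nchw2nlc, nlc_to_nchw)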
mmseg/models/utils/shape_convert.py
@@ -27,3 +27,81 @@ def nchw_to_nlc(x):
     """
     assert len(x.shape) == 4
     return x.flatten(2).transpose(1, 2).contiguous()
+
+
+def nchw2nlc2nchw(module, x, contiguous=False, **kwargs):
+    """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the
+    reshaped tensor as the input of `module`, and convert the output of
+    `module`, whose shape is [N, L, C], to [N, C, H, W].
+
+    Args:
+        module (Callable): A callable object that takes a tensor
+            with shape [N, L, C] as input.
+        x (Tensor): The input tensor of shape [N, C, H, W].
+        contiguous (bool): Whether to make the tensor contiguous
+            after each shape transform.
+
+    Returns:
+        Tensor: The output tensor of shape [N, C, H, W].
+
+    Example:
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> norm = nn.LayerNorm(4)
+        >>> feature_map = torch.rand(4, 4, 5, 5)
+        >>> output = nchw2nlc2nchw(norm, feature_map)
+    """
+    B, C, H, W = x.shape
+    if not contiguous:
+        x = x.flatten(2).transpose(1, 2)
+        x = module(x, **kwargs)
+        x = x.transpose(1, 2).reshape(B, C, H, W)
+    else:
+        x = x.flatten(2).transpose(1, 2).contiguous()
+        x = module(x, **kwargs)
+        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+    return x
+
+
+def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
+    """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the
+    reshaped tensor as the input of `module`, and convert the output of
+    `module`, whose shape is [N, C, H, W], to [N, L, C].
+
+    Args:
+        module (Callable): A callable object that takes a tensor
+            with shape [N, C, H, W] as input.
+        x (Tensor): The input tensor of shape [N, L, C].
+        hw_shape (Sequence[int]): The height and width of the
+            feature map with shape [N, C, H, W].
+        contiguous (bool): Whether to make the tensor contiguous
+            after each shape transform.
+
+    Returns:
+        Tensor: The output tensor of shape [N, L, C].
+
+    Example:
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> conv = nn.Conv2d(16, 16, 3, 1, 1)
+        >>> feature_map = torch.rand(4, 25, 16)
+        >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))
+    """
+    H, W = hw_shape
+    assert len(x.shape) == 3
+    B, L, C = x.shape
+    assert L == H * W, 'The seq_len doesn\'t match H, W'
+    if not contiguous:
+        x = x.transpose(1, 2).reshape(B, C, H, W)
+        x = module(x, **kwargs)
+        x = x.flatten(2).transpose(1, 2)
+    else:
+        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+        x = module(x, **kwargs)
+        x = x.flatten(2).transpose(1, 2).contiguous()
+    return x
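As a quick orientation before the tests, here is a hedged usage sketch of the two helpers: wrapping a channel-last nn.LayerNorm around an NCHW feature map, wrapping a depthwise nn.Conv2d around an NLC sequence, and forwarding an extra keyword argument through **kwargs. The tensor sizes and the scaled_identity callable are illustrative, not part of this commit:

>>> import torch
>>> import torch.nn as nn
>>> from mmseg.models.utils import nchw2nlc2nchw, nlc2nchw2nlc
>>> # Apply a channel-last LayerNorm to an NCHW feature map.
>>> feature_map = torch.rand(2, 64, 8, 8)
>>> norm = nn.LayerNorm(64)
>>> nchw2nlc2nchw(norm, feature_map).shape
torch.Size([2, 64, 8, 8])
>>> # Apply a depthwise 3x3 conv to an NLC sequence with known H, W.
>>> seq = torch.rand(2, 64, 32)  # L == 8 * 8
>>> dwconv = nn.Conv2d(32, 32, 3, padding=1, groups=32)
>>> nlc2nchw2nlc(dwconv, seq, (8, 8), contiguous=True).shape
torch.Size([2, 64, 32])
>>> # Extra keyword arguments are forwarded to the wrapped callable.
>>> def scaled_identity(x, scale=1.0):
...     return x * scale
>>> nchw2nlc2nchw(scaled_identity, feature_map, scale=0.5).shape
torch.Size([2, 64, 8, 8])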
tests/test_models/test_utils/test_shape_convert.py (new file, 89 lines added)
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmseg.models.utils import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
+                                nlc_to_nchw)
+
+
+def test_nchw2nlc2nchw():
+    # Test nchw2nlc2nchw function
+    shape_nchw = (4, 2, 5, 5)
+    shape_nlc = (4, 25, 2)
+
+    def test_func(x):
+        assert x.shape == torch.Size(shape_nlc)
+        return x
+
+    x = torch.rand(*shape_nchw)
+    output = nchw2nlc2nchw(test_func, x)
+    assert output.shape == torch.Size(shape_nchw)
+
+    def test_func2(x, arg):
+        assert x.shape == torch.Size(shape_nlc)
+        assert arg == 100
+        return x
+
+    x = torch.rand(*shape_nchw)
+    output = nchw2nlc2nchw(test_func2, x, arg=100)
+    assert output.shape == torch.Size(shape_nchw)
+
+    def test_func3(x):
+        assert x.is_contiguous()
+        assert x.shape == torch.Size(shape_nlc)
+        return x
+
+    x = torch.rand(*shape_nchw)
+    output = nchw2nlc2nchw(test_func3, x, contiguous=True)
+    assert output.shape == torch.Size(shape_nchw)
+    assert output.is_contiguous()
+
+
+def test_nlc2nchw2nlc():
+    # Test nlc2nchw2nlc function
+    shape_nchw = (4, 2, 5, 5)
+    shape_nlc = (4, 25, 2)
+
+    def test_func(x):
+        assert x.shape == torch.Size(shape_nchw)
+        return x
+
+    x = torch.rand(*shape_nlc)
+    output = nlc2nchw2nlc(test_func, x, shape_nchw[2:])
+    assert output.shape == torch.Size(shape_nlc)
+
+    def test_func2(x, arg):
+        assert x.shape == torch.Size(shape_nchw)
+        assert arg == 100
+        return x
+
+    x = torch.rand(*shape_nlc)
+    output = nlc2nchw2nlc(test_func2, x, shape_nchw[2:], arg=100)
+    assert output.shape == torch.Size(shape_nlc)
+
+    def test_func3(x):
+        assert x.is_contiguous()
+        assert x.shape == torch.Size(shape_nchw)
+        return x
+
+    x = torch.rand(*shape_nlc)
+    output = nlc2nchw2nlc(test_func3, x, shape_nchw[2:], contiguous=True)
+    assert output.shape == torch.Size(shape_nlc)
+    assert output.is_contiguous()
+
+
+def test_nchw_to_nlc():
+    # Test nchw_to_nlc function
+    shape_nchw = (4, 2, 5, 5)
+    shape_nlc = (4, 25, 2)
+    x = torch.rand(*shape_nchw)
+    y = nchw_to_nlc(x)
+    assert y.shape == torch.Size(shape_nlc)
+
+
+def test_nlc_to_nchw():
+    # Test nlc_to_nchw function
+    shape_nchw = (4, 2, 5, 5)
+    shape_nlc = (4, 25, 2)
+    x = torch.rand(*shape_nlc)
+    y = nlc_to_nchw(x, (5, 5))
+    assert y.shape == torch.Size(shape_nchw)