mmsegmentation/mmseg/models/utils/shape_convert.py

30 lines
937 B
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
def nlc_to_nchw(x, hw_shape):
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.

    Raises:
        AssertionError: If ``x`` is not 3-D, or if its sequence length L
            is not equal to ``H * W``.
    """
    H, W = hw_shape
    assert len(x.shape) == 3, (
        f'Expected a 3-D [N, L, C] tensor, got shape {tuple(x.shape)}')
    B, L, C = x.shape
    # The sequence dimension must split exactly into an H x W grid. The
    # original message is kept as a prefix so existing matches still work.
    assert L == H * W, (
        f"The seq_len doesn't match H, W: got L={L}, H={H}, W={W} "
        f'(H * W = {H * W})')
    # transpose -> [N, C, L]; reshape then splits L into (H, W).
    return x.transpose(1, 2).reshape(B, C, H, W)
def nchw_to_nlc(x):
    """Flatten a [N, C, H, W] feature map into a [N, L, C] token sequence.

    The two spatial dimensions are merged into a single sequence dimension
    of length ``L = H * W``, and channels are moved to the last axis.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: A contiguous tensor of shape [N, L, C] after conversion.
    """
    assert x.dim() == 4
    # [N, C, H, W] -> [N, C, H * W]
    spatial_merged = x.flatten(start_dim=2)
    # [N, C, L] -> [N, L, C], then materialize a dense memory layout.
    return spatial_merged.permute(0, 2, 1).contiguous()