mmsegmentation/mmseg/models/utils/shape_convert.py
Junjun2016 67f1420472 [Enhancement] Add codespell pre-commit hook and fix typos (#920)
* add codespell pre-commit hook and fix typos

* Update mmseg/models/decode_heads/dpt_head.py

* Update mmseg/models/backbones/vit.py

* Update mmseg/models/backbones/vit.py

* fix typos

* skip formating typo

* deprecate formating

* skip ipynb

* unstage ipynb changes

* unstage ipynb changes

* fix typos in ipynb

* unstage ipynb changes
2021-10-13 06:21:17 -07:00

30 lines
937 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
def nlc_to_nchw(x, hw_shape):
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.

    Raises:
        AssertionError: If ``x`` is not 3-dimensional, or if its sequence
            length ``L`` is not equal to ``H * W``.
    """
    H, W = hw_shape
    # Validate explicitly instead of with ``assert`` so the checks are not
    # stripped under ``python -O``. AssertionError is kept so existing
    # callers that expect it are unaffected.
    if len(x.shape) != 3:
        raise AssertionError(
            f'Expected a 3-D [N, L, C] tensor, got shape {tuple(x.shape)}')
    B, L, C = x.shape
    if L != H * W:
        raise AssertionError(
            f"The seq_len doesn't match H, W: seq_len={L}, H={H}, W={W}")
    # transpose(1, 2) gives [N, C, L] as a non-contiguous view. reshape
    # (unlike view) accepts that layout and splits L into H, W row-major.
    return x.transpose(1, 2).reshape(B, C, H, W)
def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion, where
        ``L = H * W``. The returned tensor is contiguous.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional.
    """
    # Explicit check instead of ``assert`` so it is not stripped under
    # ``python -O``. AssertionError is kept for backward compatibility.
    if len(x.shape) != 4:
        raise AssertionError(
            f'Expected a 4-D [N, C, H, W] tensor, got shape {tuple(x.shape)}')
    # flatten(2) merges H and W into L -> [N, C, L]; transpose -> [N, L, C].
    # contiguous() materializes the transposed layout for downstream ops.
    return x.flatten(2).transpose(1, 2).contiguous()