mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* [Feature]Segformer re-implementation
* Using act_cfg and norm_cfg to control activation and normalization
* Split this PR into several little PRs
* Fix lint error
* Remove SegFormerHead
* parameters init refactor
* 1. Refactor segformer backbone parameters init; 2. Remove rebundant functions and unit tests;
* Remove rebundant codes
* 1. Remove rebundant codes; 2. Modify module name;
* Refactor the backbone of segformer using mmcv.cnn.bricks.transformer.py
* Fix some code logic bugs.
* Add mit_convert.py to match pretrain keys of segformer.
* Resolve some comments.
* 1. Add some assert to ensure right params; 2. Support flexible peconv position;
* Add pe_index assert and fix unit test.
* 1. Add doc string for MixVisionTransformer; 2. Add some unit tests for MixVisionTransformer;
* Use hw_shape to pass shape of feature map.
* 1. Fix doc string of MixVisionTransformer; 2. Simplify MixFFN; 3. Modify H, W to hw_shape;
* Add more unit tests.
* Add doc string for shape convertion functions.
* Add some unit tests to improve code coverage.
* Fix Segformer backbone pretrain weights match bug.
* resolve the shape convertion functions doc string.
* Add pad_to_patch_size arg.
* Modify default value of pad_to_patch_size arg.
29 lines
889 B
Python
def nlc_to_nchw(x, hw_shape):
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width (H, W) of the output
            feature map. H * W must equal the sequence length L.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len doesn\'t match H, W'
    # [N, L, C] -> [N, C, L] -> [N, C, H, W]; L is split row-major into (H, W).
    return x.transpose(1, 2).reshape(B, C, H, W)


def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor, where L = H * W.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    # [N, C, H, W] -> [N, C, H*W] -> [N, H*W, C]; contiguous() copies the
    # transposed data into a dense memory layout.
    return x.flatten(2).transpose(1, 2).contiguous()
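The two helpers are inverses of each other. `nchw_to_nlc` flattens H and W in row-major order, so sequence position `l` holds pixel `(l // W, l % W)`. `nlc_to_nchw` undoes this only when `hw_shape` matches the original feature map. The check below is a minimal sketch that assumes PyTorch is installed. It repeats the helpers' tensor operations inline rather than importing them, and the sizes `N, C, H, W` are arbitrary illustration values.

```python
import torch

# Arbitrary sizes chosen for illustration only.
N, C, H, W = 2, 3, 4, 5

# Distinct values make every element traceable after reshaping.
x_nchw = torch.arange(N * C * H * W).reshape(N, C, H, W)

# Same ops as nchw_to_nlc: flatten H and W into L = H * W, then move C last.
x_nlc = x_nchw.flatten(2).transpose(1, 2).contiguous()

# Same ops as nlc_to_nchw: move C back to dim 1 and split L into (H, W).
x_back = x_nlc.transpose(1, 2).reshape(N, C, H, W)

# Token l in the sequence holds pixel (l // W, l % W) of the feature map.
for l in range(H * W):
    h, w = divmod(l, W)
    assert torch.equal(x_nlc[:, l, :], x_nchw[:, :, h, w])
```

The same reasoning explains the `L == H * W` assertion in `nlc_to_nchw`. A sequence whose length is not exactly H * W cannot be folded back into an H x W grid. Because `reshape` here only splits the L dimension, it can return a view of the transposed tensor instead of a copy. As a result, the output of `nlc_to_nchw` is generally not contiguous when C > 1.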