import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize


@HEADS.register_module()
class SegformerHead(BaseDecodeHead):
    """The all-MLP head of SegFormer.

    This head is the implementation of
    `SegFormer <https://arxiv.org/abs/2105.15203>`_.

    Args:
        interpolate_mode (str): The interpolation mode used to upsample the
            projected features of each backbone stage before fusion.
            Default: 'bilinear'.
    """
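    # Usage note (illustrative, not part of the original file): because this
    # head is registered in the HEADS registry, it is normally built from a
    # config dict rather than instantiated directly. A typical entry, with
    # values roughly following the MiT-B0 ADE20K config, looks like:
    #
    #     decode_head=dict(
    #         type='SegformerHead',
    #         in_channels=[32, 64, 160, 256],
    #         in_index=[0, 1, 2, 3],
    #         channels=256,
    #         dropout_ratio=0.1,
    #         num_classes=150,
    #         norm_cfg=dict(type='SyncBN', requires_grad=True),
    #         align_corners=False)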
    def __init__(self, interpolate_mode='bilinear', **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)

        assert num_inputs == len(self.in_index)

        # One 1x1 ConvModule per backbone stage projects each feature map to
        # a common embedding dimension (self.channels).
        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

        # Fuse the concatenated multi-scale features back to self.channels
        # with another 1x1 conv.
        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg)
    def forward(self, inputs):
        # Receive the four backbone stage feature maps at 1/4, 1/8, 1/16 and
        # 1/32 of the input resolution.
        inputs = self._transform_inputs(inputs)
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            # Project each stage to self.channels and upsample it to the
            # resolution of the largest (1/4 scale) feature map.
            outs.append(
                resize(
                    input=conv(x),
                    size=inputs[0].shape[2:],
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        # Concatenate along the channel dimension and fuse with a 1x1 conv.
        out = self.fusion_conv(torch.cat(outs, dim=1))

        # Final per-pixel classification (dropout + conv_seg from the base
        # class).
        out = self.cls_seg(out)

        return out
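

# ---------------------------------------------------------------------------
# Minimal forward-pass sketch (not part of the original module). It assumes an
# MiT-B0-like backbone producing four feature maps at 1/4, 1/8, 1/16 and 1/32
# of a 512x512 input; adjust channel dims and num_classes for other setups.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    head = SegformerHead(
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        channels=256,
        num_classes=19,
        dropout_ratio=0.1,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False)

    # Dummy multi-scale features for a single 512x512 image.
    feats = [
        torch.randn(1, 32, 128, 128),
        torch.randn(1, 64, 64, 64),
        torch.randn(1, 160, 32, 32),
        torch.randn(1, 256, 16, 16),
    ]
    seg_logits = head(feats)
    # Expected output shape: (1, 19, 128, 128), i.e. logits at 1/4 resolution.
    print(seg_logits.shape)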