mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] add DPT head (#605)
* add DPT head * [fix] fix init error * use mmcv function * delete code * remove transpose clas * support NLC output shape * Delete post_process_layer.py * add unittest and docstring * rename variables * fix project error and add unittest * match dpt weights * add configs * fix vit pos_embed bug and dpt feature fusion bug * match vit output * fix gelu * minor change * update unitest * fix configs error * inference test * remove auxilary * use local pretrain * update training results * update yml * update fps and memory test * update doc * update readme * add yml * update doc * remove with_cp * update config * update docstring * remove dpt-l * add init_cfg and modify readme.md * Update dpt_vit-b16.py * zh-n README * use constructor instead of build function * prevent tensor being modified by ConvModule * fix unittest Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
This commit is contained in:
parent
0cf838f294
commit
ef4b30038f
31
configs/_base_/models/dpt_vit-b16.py
Normal file
31
configs/_base_/models/dpt_vit-b16.py
Normal file
@ -0,0 +1,31 @@
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='pretrain/vit-b16_p16_224-80ecf9dd.pth', # noqa
|
||||
backbone=dict(
|
||||
type='VisionTransformer',
|
||||
img_size=224,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
out_indices=(2, 5, 8, 11),
|
||||
final_norm=False,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True),
|
||||
decode_head=dict(
|
||||
type='DPTHead',
|
||||
in_channels=(768, 768, 768, 768),
|
||||
channels=256,
|
||||
embed_dims=768,
|
||||
post_process_channels=[96, 192, 384, 768],
|
||||
num_classes=150,
|
||||
readout_type='project',
|
||||
input_transform='multiple_select',
|
||||
in_index=(0, 1, 2, 3),
|
||||
norm_cfg=norm_cfg,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
auxiliary_head=None,
|
||||
# model training and testing settings
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='whole')) # yapf: disable
|
47
configs/dpt/README.md
Normal file
47
configs/dpt/README.md
Normal file
@ -0,0 +1,47 @@
|
||||
# Vision Transformer for Dense Prediction
|
||||
|
||||
## Introduction
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
```latex
|
||||
@article{dosoViTskiy2020,
|
||||
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
|
||||
author={DosoViTskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
|
||||
journal={arXiv preprint arXiv:2010.11929},
|
||||
year={2020}
|
||||
}
|
||||
|
||||
@article{Ranftl2021,
|
||||
author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun},
|
||||
title = {Vision Transformers for Dense Prediction},
|
||||
journal = {ArXiv preprint},
|
||||
year = {2021},
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use other repositories' pre-trained models, it is necessary to convert keys.
|
||||
|
||||
We provide a script [`vit2mmseg.py`](../../tools/model_converters/vit2mmseg.py) in the tools directory to convert the key of models from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to MMSegmentation style.
|
||||
|
||||
```shell
|
||||
python tools/model_converters/vit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
|
||||
```
|
||||
|
||||
E.g.
|
||||
|
||||
```shell
|
||||
python tools/model_converters/vit2mmseg.py https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth pretrain/jx_vit_base_p16_224-80ecf9dd.pth
|
||||
```
|
||||
|
||||
This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.
|
||||
|
||||
## Results and models
|
||||
|
||||
### ADE20K
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
|
||||
| ------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | ---------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| DPT | ViT-B | 512x512 | 160000 | 8.09 | 10.41 | 46.97 | 48.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-20210809_172025.log.json) |
|
28
configs/dpt/dpt.yml
Normal file
28
configs/dpt/dpt.yml
Normal file
@ -0,0 +1,28 @@
|
||||
Collections:
|
||||
- Metadata:
|
||||
Training Data:
|
||||
- ADE20K
|
||||
Name: dpt
|
||||
Models:
|
||||
- Config: configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py
|
||||
In Collection: dpt
|
||||
Metadata:
|
||||
backbone: ViT-B
|
||||
crop size: (512,512)
|
||||
inference time (ms/im):
|
||||
- backend: PyTorch
|
||||
batch size: 1
|
||||
hardware: V100
|
||||
mode: FP32
|
||||
resolution: (512,512)
|
||||
value: 96.06
|
||||
lr schd: 160000
|
||||
memory (GB): 8.09
|
||||
Name: dpt_vit-b16_512x512_160k_ade20k
|
||||
Results:
|
||||
Dataset: ADE20K
|
||||
Metrics:
|
||||
mIoU: 46.97
|
||||
mIoU(ms+flip): 48.34
|
||||
Task: Semantic Segmentation
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth
|
32
configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py
Normal file
32
configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py
Normal file
@ -0,0 +1,32 @@
|
||||
_base_ = [
|
||||
'../_base_/models/dpt_vit-b16.py', '../_base_/datasets/ade20k.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
|
||||
# AdamW optimizer, no weight decay for position embedding & layer norm
|
||||
# in backbone
|
||||
optimizer = dict(
|
||||
_delete_=True,
|
||||
type='AdamW',
|
||||
lr=0.00006,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.01,
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'cls_token': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.)
|
||||
}))
|
||||
|
||||
lr_config = dict(
|
||||
_delete_=True,
|
||||
policy='poly',
|
||||
warmup='linear',
|
||||
warmup_iters=1500,
|
||||
warmup_ratio=1e-6,
|
||||
power=1.0,
|
||||
min_lr=0.0,
|
||||
by_epoch=False)
|
||||
|
||||
# By default, models are trained on 8 GPUs with 2 images per GPU
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
@ -6,6 +6,7 @@ from .cc_head import CCHead
|
||||
from .da_head import DAHead
|
||||
from .dm_head import DMHead
|
||||
from .dnl_head import DNLHead
|
||||
from .dpt_head import DPTHead
|
||||
from .ema_head import EMAHead
|
||||
from .enc_head import EncHead
|
||||
from .fcn_head import FCNHead
|
||||
@ -29,5 +30,5 @@ __all__ = [
|
||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
|
||||
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
|
||||
'SETRMLAHead', 'SegformerHead'
|
||||
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead'
|
||||
]
|
||||
|
293
mmseg/models/decode_heads/dpt_head.py
Normal file
293
mmseg/models/decode_heads/dpt_head.py
Normal file
@ -0,0 +1,293 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, Linear, build_activation_layer
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class ReassembleBlocks(BaseModule):
|
||||
"""ViTPostProcessBlock, process cls_token in ViT backbone output and
|
||||
rearrange the feature vector to feature map.
|
||||
|
||||
Args:
|
||||
in_channels (int): ViT feature channels. Default: 768.
|
||||
out_channels (List): output channels of each stage.
|
||||
Default: [96, 192, 384, 768].
|
||||
readout_type (str): Type of readout operation. Default: 'ignore'.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=768,
|
||||
out_channels=[96, 192, 384, 768],
|
||||
readout_type='ignore',
|
||||
patch_size=16,
|
||||
init_cfg=None):
|
||||
super(ReassembleBlocks, self).__init__(init_cfg)
|
||||
|
||||
assert readout_type in ['ignore', 'add', 'project']
|
||||
self.readout_type = readout_type
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.projects = nn.ModuleList([
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
act_cfg=None,
|
||||
) for out_channel in out_channels
|
||||
])
|
||||
|
||||
self.resize_layers = nn.ModuleList([
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[0],
|
||||
out_channels=out_channels[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[1],
|
||||
out_channels=out_channels[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0),
|
||||
nn.Identity(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels[3],
|
||||
out_channels=out_channels[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1)
|
||||
])
|
||||
if self.readout_type == 'project':
|
||||
self.readout_projects = nn.ModuleList()
|
||||
for _ in range(len(self.projects)):
|
||||
self.readout_projects.append(
|
||||
nn.Sequential(
|
||||
Linear(2 * in_channels, in_channels),
|
||||
build_activation_layer(dict(type='GELU'))))
|
||||
|
||||
def forward(self, inputs):
|
||||
assert isinstance(inputs, list)
|
||||
out = []
|
||||
for i, x in enumerate(inputs):
|
||||
assert len(x) == 2
|
||||
x, cls_token = x[0], x[1]
|
||||
feature_shape = x.shape
|
||||
if self.readout_type == 'project':
|
||||
x = x.flatten(2).permute((0, 2, 1))
|
||||
readout = cls_token.unsqueeze(1).expand_as(x)
|
||||
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
||||
x = x.permute(0, 2, 1).reshape(feature_shape)
|
||||
elif self.readout_type == 'add':
|
||||
x = x.flatten(2) + cls_token.unsqueeze(-1)
|
||||
x = x.reshape(feature_shape)
|
||||
else:
|
||||
pass
|
||||
x = self.projects[i](x)
|
||||
x = self.resize_layers[i](x)
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
|
||||
class PreActResidualConvUnit(BaseModule):
|
||||
"""ResidualConvUnit, pre-activate residual unit.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of channels in the input feature map.
|
||||
act_cfg (dict): dictionary to construct and config activation layer.
|
||||
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||
stride (int): stride of the first block. Default: 1
|
||||
dilation (int): dilation rate for convs layers. Default: 1.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
act_cfg,
|
||||
norm_cfg,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
init_cfg=None):
|
||||
super(PreActResidualConvUnit, self).__init__(init_cfg)
|
||||
|
||||
self.conv1 = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
bias=False,
|
||||
order=('act', 'conv', 'norm'))
|
||||
|
||||
self.conv2 = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
bias=False,
|
||||
order=('act', 'conv', 'norm'))
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs_ = inputs.clone()
|
||||
x = self.conv1(inputs)
|
||||
x = self.conv2(x)
|
||||
return x + inputs_
|
||||
|
||||
|
||||
class FeatureFusionBlock(BaseModule):
|
||||
"""FeatureFusionBlock, merge feature map from different stages.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels.
|
||||
act_cfg (dict): The activation config for ResidualConvUnit.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
expand (bool): Whether expand the channels in post process block.
|
||||
Default: False.
|
||||
align_corners (bool): align_corner setting for bilinear upsample.
|
||||
Default: True.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
act_cfg,
|
||||
norm_cfg,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
init_cfg=None):
|
||||
super(FeatureFusionBlock, self).__init__(init_cfg)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.expand = expand
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.out_channels = in_channels
|
||||
if self.expand:
|
||||
self.out_channels = in_channels // 2
|
||||
|
||||
self.project = ConvModule(
|
||||
self.in_channels,
|
||||
self.out_channels,
|
||||
kernel_size=1,
|
||||
act_cfg=None,
|
||||
bias=True)
|
||||
|
||||
self.res_conv_unit1 = PreActResidualConvUnit(
|
||||
in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
|
||||
self.res_conv_unit2 = PreActResidualConvUnit(
|
||||
in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
|
||||
|
||||
def forward(self, *inputs):
|
||||
x = inputs[0]
|
||||
if len(inputs) == 2:
|
||||
if x.shape != inputs[1].shape:
|
||||
res = resize(
|
||||
inputs[1],
|
||||
size=(x.shape[2], x.shape[3]),
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
else:
|
||||
res = inputs[1]
|
||||
x = x + self.res_conv_unit1(res)
|
||||
x = self.res_conv_unit2(x)
|
||||
x = resize(
|
||||
x,
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.project(x)
|
||||
return x
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DPTHead(BaseDecodeHead):
|
||||
"""Vision Transformers for Dense Prediction.
|
||||
|
||||
This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The embed dimension of the ViT backbone.
|
||||
Default: 768.
|
||||
post_process_channels (List): Out channels of post process conv
|
||||
layers. Default: [96, 192, 384, 768].
|
||||
readout_type (str): Type of readout operation. Default: 'ignore'.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
expand_channels (bool): Whether expand the channels in post process
|
||||
block. Default: False.
|
||||
act_cfg (dict): The activation config for residual conv unit.
|
||||
Defalut dict(type='ReLU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims=768,
|
||||
post_process_channels=[96, 192, 384, 768],
|
||||
readout_type='ignore',
|
||||
patch_size=16,
|
||||
expand_channels=False,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
norm_cfg=dict(type='BN'),
|
||||
**kwargs):
|
||||
super(DPTHead, self).__init__(**kwargs)
|
||||
|
||||
self.in_channels = self.in_channels
|
||||
self.expand_channels = expand_channels
|
||||
self.reassemble_blocks = ReassembleBlocks(embed_dims,
|
||||
post_process_channels,
|
||||
readout_type, patch_size)
|
||||
|
||||
self.post_process_channels = [
|
||||
channel * math.pow(2, i) if expand_channels else channel
|
||||
for i, channel in enumerate(post_process_channels)
|
||||
]
|
||||
self.convs = nn.ModuleList()
|
||||
for channel in self.post_process_channels:
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
channel,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
act_cfg=None,
|
||||
bias=False))
|
||||
self.fusion_blocks = nn.ModuleList()
|
||||
for _ in range(len(self.convs)):
|
||||
self.fusion_blocks.append(
|
||||
FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
|
||||
self.fusion_blocks[0].res_conv_unit1 = None
|
||||
self.project = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg)
|
||||
self.num_fusion_blocks = len(self.fusion_blocks)
|
||||
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
|
||||
self.num_post_process_channels = len(self.post_process_channels)
|
||||
assert self.num_fusion_blocks == self.num_reassemble_blocks
|
||||
assert self.num_reassemble_blocks == self.num_post_process_channels
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == self.num_reassemble_blocks
|
||||
x = self._transform_inputs(inputs)
|
||||
x = self.reassemble_blocks(x)
|
||||
x = [self.convs[i](feature) for i, feature in enumerate(x)]
|
||||
out = self.fusion_blocks[0](x[-1])
|
||||
for i in range(1, len(self.fusion_blocks)):
|
||||
out = self.fusion_blocks[i](out, x[-(i + 1)])
|
||||
out = self.project(out)
|
||||
out = self.cls_seg(out)
|
||||
return out
|
@ -8,6 +8,7 @@ Import:
|
||||
- configs/deeplabv3plus/deeplabv3plus.yml
|
||||
- configs/dmnet/dmnet.yml
|
||||
- configs/dnlnet/dnlnet.yml
|
||||
- configs/dpt/dpt.yml
|
||||
- configs/emanet/emanet.yml
|
||||
- configs/encnet/encnet.yml
|
||||
- configs/fastscnn/fastscnn.yml
|
||||
|
48
tests/test_models/test_heads/test_dpt_head.py
Normal file
48
tests/test_models/test_heads/test_dpt_head.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import DPTHead
|
||||
|
||||
|
||||
def test_dpt_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# input_transform must be 'multiple_select'
|
||||
head = DPTHead(
|
||||
in_channels=[768, 768, 768, 768],
|
||||
channels=256,
|
||||
num_classes=19,
|
||||
in_index=[0, 1, 2, 3])
|
||||
|
||||
head = DPTHead(
|
||||
in_channels=[768, 768, 768, 768],
|
||||
channels=256,
|
||||
num_classes=19,
|
||||
in_index=[0, 1, 2, 3],
|
||||
input_transform='multiple_select')
|
||||
|
||||
inputs = [[torch.randn(4, 768, 2, 2),
|
||||
torch.randn(4, 768)] for _ in range(4)]
|
||||
output = head(inputs)
|
||||
assert output.shape == torch.Size((4, 19, 16, 16))
|
||||
|
||||
# test readout operation
|
||||
head = DPTHead(
|
||||
in_channels=[768, 768, 768, 768],
|
||||
channels=256,
|
||||
num_classes=19,
|
||||
in_index=[0, 1, 2, 3],
|
||||
input_transform='multiple_select',
|
||||
readout_type='add')
|
||||
output = head(inputs)
|
||||
assert output.shape == torch.Size((4, 19, 16, 16))
|
||||
|
||||
head = DPTHead(
|
||||
in_channels=[768, 768, 768, 768],
|
||||
channels=256,
|
||||
num_classes=19,
|
||||
in_index=[0, 1, 2, 3],
|
||||
input_transform='multiple_select',
|
||||
readout_type='project')
|
||||
output = head(inputs)
|
||||
assert output.shape == torch.Size((4, 19, 16, 16))
|
Loading…
x
Reference in New Issue
Block a user