Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Feature]: Add MAE (#1307)
* [Fix]: Fix lint
* [WIP]: Add mae seg config
* [Feature]: Add MAE seg
* [Fix]: Fix mae dataset img scale bug
* [Fix]: Fix lint
* [Feature]: Change mae config to mae_segmentation's config
* [Feature]: Add interpolate pe when loading
* [Fix]: Fix pos_embed not used bug
* [Fix]: Fix lint
* [Fix]: Init rel pos embed with zeros
* [Fix]: Fix lint
* [Fix]: Change the type name of backbone to MAE
* [Fix]: Delete ade20k_512x512.py
* [Fix]: Use mmseg provided ade20k.py
* [Fix]: Change 1 sample per gpu to 2 samples per gpu
* [Fix]: Fix conflict
* [Refactor]: Use the TransformerEncoderLayer of BEiT
* [Feature]: Add UT
* [Fix]: Change the default value of qv bias to False
* [Fix]: Initialize relative pos table with zeros
* [Fix]: Delete redundant code in mae
* [Fix]: Fix lint
* [Fix]: Rename qkv_bias to qv_bias
* [Fix]: Add docstring to weight_init of MAEAttention
* [Refactor]: Delete qv_bias param
* [Fix]: Add reference to fix_init_weight
* [Fix]: Fix lint
* [Fix]: Delete extra crop size
* [Refactor]: Rename mae
* [Fix]: Set bias to True
* [Fix]: Delete redundant params
* [Fix]: Fix lint
* [Fix]: Fix UT
* [Fix]: Add resize abs pos embed
* [Fix]: Fix UT
* [Refactor]: Use build layer
* [Fix]: Add license and fix docstring
* [Fix]: Fix docstring
* [Feature]: Add README metafile
* [Fix]: Change 640 to 512
* [Fix]: Fix README
* fix readme of MAE

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
This commit is contained in:
parent 69b28e0b59
commit 6563cb513e
configs/_base_/models/upernet_mae.py (new file, 49 lines)
@@ -0,0 +1,49 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='MAE',
        img_size=(640, 640),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=(3, 5, 7, 11),
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_cfg=dict(type='LN', eps=1e-6),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        init_values=0.1),
    neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[384, 384, 384, 384],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
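For a quick sanity check, a minimal sketch of instantiating this base definition through the standard mmcv 1.x / mmseg 0.x Python API; the import paths and the `build_segmentor` call are ordinary library usage and an assumption here, not part of the commit. Note that the derived configs under `configs/mae/` override the head `in_channels` and `num_classes` before this model is actually trained.

```python
from mmcv import Config
from mmseg.models import build_segmentor

# Load the base model definition and build the segmentor from its registry
# types ('EncoderDecoder', 'MAE', 'UPerHead', 'FCNHead', ...).
cfg = Config.fromfile('configs/_base_/models/upernet_mae.py')
model = build_segmentor(cfg.model)
model.init_weights()
print(type(model).__name__)  # EncoderDecoder
```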
configs/mae/README.md (new file, 81 lines)
@@ -0,0 +1,81 @@
# MAE

[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)

## Introduction

<!-- [BACKBONE] -->

<a href="https://github.com/facebookresearch/mae">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.24.0/mmseg/models/backbones/mae.py#46">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

This paper shows that masked autoencoders (MAE) are scalable self-supervised learners for computer vision. Our MAE approach is simple: we mask random patches of the input image and reconstruct the missing pixels. It is based on two core designs. First, we develop an asymmetric encoder-decoder architecture, with an encoder that operates only on the visible subset of patches (without mask tokens), along with a lightweight decoder that reconstructs the original image from the latent representation and mask tokens. Second, we find that masking a high proportion of the input image, e.g., 75%, yields a nontrivial and meaningful self-supervisory task. Coupling these two designs enables us to train large models efficiently and effectively: we accelerate training (by 3x or more) and improve accuracy. Our scalable approach allows for learning high-capacity models that generalize well: e.g., a vanilla ViT-Huge model achieves the best accuracy (87.8%) among methods that use only ImageNet-1K data. Transfer performance in downstream tasks outperforms supervised pre-training and shows promising scaling behavior.

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/165456416-1cba54bf-b1b5-4bdf-ad86-d6390de7f342.png" width="70%"/>
</div>

## Citation

```bibtex
@article{he2021masked,
  title={Masked autoencoders are scalable vision learners},
  author={He, Kaiming and Chen, Xinlei and Xie, Saining and Li, Yanghao and Doll{\'a}r, Piotr and Girshick, Ross},
  journal={arXiv preprint arXiv:2111.06377},
  year={2021}
}
```

## Usage

To use pre-trained models from other repositories, you need to convert the checkpoint keys first.

We provide a script [`beit2mmseg.py`](../../tools/model_converters/beit2mmseg.py) in the tools directory to convert the keys of a MAE model from [the official repo](https://github.com/facebookresearch/mae) to MMSegmentation style.

```shell
python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```

For example:

```shell
python tools/model_converters/beit2mmseg.py https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth pretrain/mae_pretrain_vit_base_mmcls.pth
```

This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
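As a rough, hypothetical illustration of what such a key conversion involves (the real rules live in `tools/model_converters/beit2mmseg.py`; the single remapping below is only an example and may not match them):

```python
import torch


def convert_mae_checkpoint(src_path, dst_path):
    """Illustrative sketch of a checkpoint key conversion (hypothetical)."""
    ckpt = torch.load(src_path, map_location='cpu')
    # Official MAE checkpoints typically store the weights under 'model'.
    state_dict = ckpt.get('model', ckpt)
    converted = {}
    for key, value in state_dict.items():
        # Hypothetical example rule: rename timm-style 'blocks.*' keys to
        # mmseg-style 'layers.*'; the actual converter may apply more rules.
        converted[key.replace('blocks.', 'layers.')] = value
    torch.save(converted, dst_path)
```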
In our default setting, the pretrained models and their corresponding original models are:

| pretrained models | original models |
| ----------------- | --------------- |
| mae_pretrain_vit_base_mmcls.pth | ['mae_pretrain_vit_base'](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth) |

Verify the single-scale results of the model:

```shell
sh tools/dist_test.sh \
configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py \
upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
```

Since the relative position embedding requires the input height and width to be equal, sliding-window inference is adopted for multi-scale testing, with `min_size=512` so that the shortest edge is at least 512. Multi-scale inference therefore uses a separate config rather than the `--aug-test` flag:

```shell
sh tools/dist_test.sh \
configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py \
upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
```

## Results and models

### ADE20K

| Method  | Backbone | Crop Size | pretrain    | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU  | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ----------- | ----------------- | ---------- | ------- | -------- | -------------- | ----- | ------------: | ------ | -------- |
| UperNet | ViT-B    | 512x512   | ImageNet-1K | 224x224           | 16         | 160000  | 9.96     | 7.14           | 48.13 | 48.70         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752.log.json) |
configs/mae/mae.yml (new file, 23 lines)
@@ -0,0 +1,23 @@
Models:
- Name: upernet_mae-base_fp16_8x2_512x512_160k_ade20k
  In Collection: UperNet
  Metadata:
    backbone: ViT-B
    crop size: (512,512)
    lr schd: 160000
    inference time (ms/im):
    - value: 140.06
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP16
      resolution: (512,512)
    Training Memory (GB): 9.96
  Results:
  - Task: Semantic Segmentation
    Dataset: ADE20K
    Metrics:
      mIoU: 48.13
      mIoU(ms+flip): 48.7
  Config: configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth
configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py (new file, 24 lines)
@@ -0,0 +1,24 @@
_base_ = './upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True, min_size=512),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline),
    samples_per_gpu=2)
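A small sanity check (plain Python, not part of the config) of the nominal scales implied by `img_scale=(2048, 512)` and `img_ratios`; the `min_size=512` argument of `Resize` then keeps the shorter edge from dropping below 512, as explained in the README above.

```python
base_scale = (2048, 512)
ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
scales = [(int(base_scale[0] * r), int(base_scale[1] * r)) for r in ratios]
print(scales)
# [(1024, 256), (1536, 384), (2048, 512), (2560, 640), (3072, 768), (3584, 896)]
```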
configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py (new file, 48 lines)
@@ -0,0 +1,48 @@
_base_ = [
    '../_base_/models/upernet_mae.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

model = dict(
    pretrained='./pretrain/mae_pretrain_vit_base_mmcls.pth',
    backbone=dict(
        type='MAE',
        img_size=(512, 512),
        patch_size=16,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        init_values=1.0,
        drop_path_rate=0.1,
        out_indices=[3, 5, 7, 11]),
    neck=dict(embed_dim=768, rescales=[4, 2, 1, 0.5]),
    decode_head=dict(
        in_channels=[768, 768, 768, 768], num_classes=150, channels=768),
    auxiliary_head=dict(in_channels=768, num_classes=150),
    test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)))

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# mixed precision
fp16 = dict(loss_scale='dynamic')

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
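For reference, a minimal sketch of single-image inference with this config through the mmseg 0.x high-level API; the checkpoint filename matches the released weights listed in the README, while the device and demo image path are placeholders.

```python
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py'
checkpoint_file = 'upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth'

# Build the segmentor and load the trained weights.
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')

# Sliding-window inference on one image (path is a placeholder).
result = inference_segmentor(model, 'demo/demo.png')
print(result[0].shape)  # (H, W) label map over the 150 ADE20K classes
```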
mmseg/models/backbones/__init__.py
@@ -7,6 +7,7 @@ from .erfnet import ERFNet
 from .fast_scnn import FastSCNN
 from .hrnet import HRNet
 from .icnet import ICNet
+from .mae import MAE
 from .mit import MixVisionTransformer
 from .mobilenet_v2 import MobileNetV2
 from .mobilenet_v3 import MobileNetV3
@@ -25,5 +26,5 @@ __all__ = [
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
     'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
     'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
-    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT'
+    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE'
 ]
mmseg/models/backbones/mae.py (new file, 261 lines)
@@ -0,0 +1,261 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)
from mmcv.runner import ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer


class MAEAttention(BEiTAttention):
    """Multi-head self-attention with relative position bias used in MAE.

    This module is different from ``BEiTAttention`` by initializing the
    relative bias table with zeros.
    """

    def init_weights(self):
        """Initialize relative position bias with zeros."""

        # As MAE initializes relative position bias as zeros and this class
        # is inherited from BEiT which initializes relative position bias
        # with `trunc_normal`, `init_weights` here does
        # nothing and just passes directly

        pass


class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
    """Implements one encoder layer in Vision Transformer.

    This module is different from ``BEiTTransformerEncoderLayer`` by replacing
    ``BEiTAttention`` with ``MAEAttention``.
    """

    def build_attn(self, attn_cfg):
        self.attn = MAEAttention(**attn_cfg)


@BACKBONES.register_module()
class MAE(BEiT):
    """VisionTransformer with support for patch.

    Args:
        img_size (int | tuple): Input image size. Default: 224.
        patch_size (int): The patch size. Default: 16.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): embedding dimension. Default: 768.
        num_layers (int): depth of transformer. Default: 12.
        num_heads (int): number of attention heads. Default: 12.
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
            Default: 4.
        out_indices (list | tuple | int): Output from which stages.
            Default: -1.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): stochastic depth rate. Default 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
            Default: False.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Default: False.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        pretrained (str, optional): model pretrained path. Default: None.
        init_values (float): Initialize the values of Attention and FFN
            with learnable scaling. Defaults to 0.1.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dims=768,
                 num_layers=12,
                 num_heads=12,
                 mlp_ratio=4,
                 out_indices=-1,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
                 num_fcs=2,
                 norm_eval=False,
                 pretrained=None,
                 init_values=0.1,
                 init_cfg=None):
        super(MAE, self).__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            out_indices=out_indices,
            qv_bias=False,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            patch_norm=patch_norm,
            final_norm=final_norm,
            num_fcs=num_fcs,
            norm_eval=norm_eval,
            pretrained=pretrained,
            init_values=init_values,
            init_cfg=init_cfg)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        self.num_patches = self.patch_shape[0] * self.patch_shape[1]
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dims))

    def _build_layers(self):
        dpr = [
            x.item()
            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
        ]
        self.layers = ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                MAETransformerEncoderLayer(
                    embed_dims=self.embed_dims,
                    num_heads=self.num_heads,
                    feedforward_channels=self.mlp_ratio * self.embed_dims,
                    attn_drop_rate=self.attn_drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=self.num_fcs,
                    bias=True,
                    act_cfg=self.act_cfg,
                    norm_cfg=self.norm_cfg,
                    window_size=self.patch_shape,
                    init_values=self.init_values))

    def fix_init_weight(self):
        """Rescale the initialization according to layer id.

        This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
        Copyright (c) Microsoft Corporation
        Licensed under the MIT License
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.layers):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data, layer_id + 1)

    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
        self.fix_init_weight()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
            state_dict = self.resize_rel_pos_embed(checkpoint)
            state_dict = self.resize_abs_pos_embed(state_dict)
            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super(MAE, self).init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            # Copyright 2019 Ross Wightman
            # Licensed under the Apache License, Version 2.0 (the "License")
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)

    def resize_abs_pos_embed(self, state_dict):
        if 'pos_embed' in state_dict:
            pos_embed_checkpoint = state_dict['pos_embed']
            embedding_size = pos_embed_checkpoint.shape[-1]
            num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
            # height (== width) for the checkpoint position embedding
            orig_size = int(
                (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
            # height (== width) for the new position embedding
            new_size = int(self.num_patches**0.5)
            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                                embedding_size).permute(
                                                    0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens,
                    size=(new_size, new_size),
                    mode='bicubic',
                    align_corners=False)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                state_dict['pos_embed'] = new_pos_embed
        return state_dict

    def forward(self, inputs):
        B = inputs.shape[0]

        x, hw_shape = self.patch_embed(inputs)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == len(self.layers) - 1:
                if self.final_norm:
                    x = self.norm1(x)
            if i in self.out_indices:
                out = x[:, 1:]
                B, _, C = out.shape
                out = out.reshape(B, hw_shape[0], hw_shape[1],
                                  C).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)
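A minimal sketch of exercising the backbone on its own (random weights, shapes only); it mirrors the unit tests below and assumes `MAE` is importable from `mmseg.models.backbones`, as added by this commit's `__init__.py` change.

```python
import torch

from mmseg.models.backbones import MAE

backbone = MAE(
    img_size=(512, 512),
    patch_size=16,
    embed_dims=768,
    num_layers=12,
    num_heads=12,
    out_indices=(3, 5, 7, 11))
backbone.init_weights()
backbone.eval()

imgs = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    feats = backbone(imgs)

# One feature map per entry in out_indices, each 512 / 16 = 32 tokens per side.
for feat in feats:
    print(feat.shape)  # torch.Size([1, 768, 32, 32])
```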
model-index.yml
@@ -24,6 +24,7 @@ Import:
 - configs/icnet/icnet.yml
 - configs/isanet/isanet.yml
 - configs/knet/knet.yml
+- configs/mae/mae.yml
 - configs/mobilenet_v2/mobilenet_v2.yml
 - configs/mobilenet_v3/mobilenet_v3.yml
 - configs/nonlocal_net/nonlocal_net.yml
tests/test_models/test_backbones/test_mae.py (new file, 183 lines)
@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones.mae import MAE
from .utils import check_norm_state


def test_mae_backbone():
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = MAE()
        model.init_weights(pretrained=0)

    with pytest.raises(TypeError):
        # img_size must be int or tuple
        model = MAE(img_size=512.0)

    with pytest.raises(TypeError):
        # out_indices must be int, list or tuple
        model = MAE(out_indices=1.)

    with pytest.raises(AssertionError):
        # The length of img_size tuple must be lower than 3.
        MAE(img_size=(224, 224, 224))

    with pytest.raises(TypeError):
        # Pretrained must be None or Str.
        MAE(pretrained=123)

    # Test img_size isinstance tuple
    imgs = torch.randn(1, 3, 224, 224)
    model = MAE(img_size=(224, ))
    model.init_weights()
    model(imgs)

    # Test img_size isinstance tuple
    imgs = torch.randn(1, 3, 224, 224)
    model = MAE(img_size=(224, 224))
    model(imgs)

    # Test norm_eval = True
    model = MAE(norm_eval=True)
    model.train()

    # Test BEiT backbone with input size of 224 and patch size of 16
    model = MAE()
    model.init_weights()
    model.train()

    # Test out_indices = list
    model = MAE(out_indices=[2, 4, 8, 12])
    model.train()

    assert check_norm_state(model.modules(), True)

    # Test image size = (224, 224)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)

    # Test MAE backbone with input size of 256 and patch size of 16
    model = MAE(img_size=(256, 256))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 256, 256)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 16, 16)

    # Test MAE backbone with input size of 32 and patch size of 16
    model = MAE(img_size=(32, 32))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 32, 32)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 2, 2)

    # Test unbalanced size input image
    model = MAE(img_size=(112, 224))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 112, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 7, 14)

    # Test irregular input image
    model = MAE(img_size=(234, 345))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 234, 345)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 21)

    # Test init_values=0
    model = MAE(init_values=0)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)

    # Test final norm
    model = MAE(final_norm=True)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)

    # Test patch norm
    model = MAE(patch_norm=True)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)


def test_mae_init():
    path = 'PATH_THAT_DO_NOT_EXIST'
    # Test all combinations of pretrained and init_cfg
    # pretrained=None, init_cfg=None
    model = MAE(pretrained=None, init_cfg=None)
    assert model.init_cfg is None
    model.init_weights()

    # pretrained=None
    # init_cfg loads pretrain from an non-existent file
    model = MAE(
        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from an non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # test resize_rel_pos_embed
    value = torch.randn(732, 16)
    abs_pos_embed_value = torch.rand(1, 17, 768)
    ckpt = {
        'state_dict': {
            'layers.0.attn.relative_position_index': 0,
            'layers.0.attn.relative_position_bias_table': value,
            'pos_embed': abs_pos_embed_value
        }
    }
    model = MAE(img_size=(512, 512))
    with pytest.raises(AttributeError):
        model.resize_rel_pos_embed(ckpt)

    # test resize abs pos embed
    ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])

    # pretrained=None
    # init_cfg=123, whose type is unsupported
    model = MAE(pretrained=None, init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # pretrained loads pretrain from an non-existent file
    # init_cfg=None
    model = MAE(pretrained=path, init_cfg=None)
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from an non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained loads pretrain from an non-existent file
    # init_cfg loads pretrain from an non-existent file
    with pytest.raises(AssertionError):
        model = MAE(
            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        model = MAE(pretrained=path, init_cfg=123)

    # pretrain=123, whose type is unsupported
    # init_cfg=None
    with pytest.raises(TypeError):
        model = MAE(pretrained=123, init_cfg=None)

    # pretrain=123, whose type is unsupported
    # init_cfg loads pretrain from an non-existent file
    with pytest.raises(AssertionError):
        model = MAE(
            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))

    # pretrain=123, whose type is unsupported
    # init_cfg=123, whose type is unsupported
    with pytest.raises(AssertionError):
        model = MAE(pretrained=123, init_cfg=123)