diff --git a/configs/selfsup/mae/README.md b/configs/selfsup/mae/README.md index caeec51c..cc545a72 100644 --- a/configs/selfsup/mae/README.md +++ b/configs/selfsup/mae/README.md @@ -142,6 +142,19 @@ methods that use only ImageNet-1K data. Transfer performance in downstream tasks +## Evaluating MAE on Detection and Segmentation + +If you want to evaluate your model on detection or segmentation task, we provide a [script](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/model_converters/mmcls2timm.py) to convert the model keys from MMClassification style to timm style. + +```sh +cd $MMSELFSUP +python tools/model_converters/mmcls2timm.py $src_ckpt $dst_ckpt +``` + +Then, using this converted ckpt, you can evaluate your model on detection task, following [Detectron2](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet), +and on semantic segmentation task, following this [project](https://github.com/implus/mae_segmentation). Besides, using the unconverted ckpt, you can use +[MMSegmentation](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/mae) to evaluate your model. + ## Citation ```bibtex @@ -149,7 +162,7 @@ methods that use only ImageNet-1K data. Transfer performance in downstream tasks title={Masked Autoencoders Are Scalable Vision Learners}, author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Doll'ar and Ross B. Girshick}, - journal={ArXiv}, + journal={arXiv}, year={2021} } ``` diff --git a/tools/model_converters/mmcls2timm.py b/tools/model_converters/mmcls2timm.py new file mode 100644 index 00000000..99e7cd15 --- /dev/null +++ b/tools/model_converters/mmcls2timm.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse
+import os.path as osp
+from collections import OrderedDict
+from typing import Union
+
+import mmengine
+import torch
+from mmengine.runner.checkpoint import _load_checkpoint
+
+
+def convert_mmcls_to_timm(state_dict: Union[OrderedDict, dict]) -> OrderedDict:
+    """Convert keys in MMClassification pretrained vit models to timm style.
+
+    Args:
+        state_dict (Union[OrderedDict, dict]): The state dict of
+            MMClassification pretrained vit models.
+
+    Returns:
+        OrderedDict: The converted state dict.
+    """
+    # only keep the backbone weights and remove the backbone. prefix
+    state_dict = {
+        key.replace('backbone.', ''): value
+        for key, value in state_dict.items() if key.startswith('backbone.')
+    }
+
+    # replace projection with proj
+    state_dict = {
+        key.replace('projection', 'proj'): value
+        for key, value in state_dict.items()
+    }
+
+    # replace ffn.layers.0.0 with mlp.fc1
+    state_dict = {
+        key.replace('ffn.layers.0.0', 'mlp.fc1'): value
+        for key, value in state_dict.items()
+    }
+
+    # replace ffn.layers.1 with mlp.fc2
+    state_dict = {
+        key.replace('ffn.layers.1', 'mlp.fc2'): value
+        for key, value in state_dict.items()
+    }
+
+    # replace layers with blocks
+    state_dict = {
+        key.replace('layers', 'blocks'): value
+        for key, value in state_dict.items()
+    }
+
+    # replace ln with norm
+    state_dict = {
+        key.replace('ln', 'norm'): value
+        for key, value in state_dict.items()
+    }
+
+    # replace the last norm1 with norm
+    state_dict['norm.weight'] = state_dict.pop('norm1.weight')
+    state_dict['norm.bias'] = state_dict.pop('norm1.bias')
+
+    state_dict = OrderedDict({'model': state_dict})
+    return state_dict
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in MMClassification '
+        'pretrained vit models to timm style')
+    parser.add_argument('src', help='src model path or url')
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = _load_checkpoint(args.src)
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+
+    state_dict = convert_mmcls_to_timm(state_dict)
+    mmengine.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(state_dict, args.dst)
+
+
+if __name__ == '__main__':
+    main()