Compare commits


303 Commits

Author SHA1 Message Date
BBC ee7f2e8850
Correct EfficientNetV2 URL in metafile.yml () 2024-11-01 14:27:36 +08:00
Ma Zerun 17a886cb58
Bump version to v1.2.0 ()
* [Fix] Fix resize mix argument bug.

* Bump version to v1.2.0

* Fix UT
2024-01-04 20:43:27 +08:00
mzr1996 9ac4b316f0 Merge remote-tracking branch 'origin/main' into dev 2024-01-04 18:07:08 +08:00
Ma Zerun 3022f9af7b
[Feature] Support LLaVA 1.5 ()
* Support LLaVA 1.5

* Fix lint
2023-12-22 16:28:20 +08:00
mzr1996 e95d9acb89 Update mmcv requirements 2023-11-16 10:29:10 +08:00
mzr1996 6e00cbecaa Update mmcv requirements 2023-11-15 17:34:13 +08:00
Coobiw ed5924b6fe
[Feature] Implementation of RAM with a gradio interface. ()
* [CodeCamp2023-584]Support DINO self-supervised learning in project ()

* feat: implement DINO

* chore: delete debug code

* chore: implement pre-commit

* fix: fix imported package

* chore: pre-commit check

* [CodeCamp2023-340] New Version of config Adapting MobileNet Algorithm ()

* add new config adapting MobileNetV2, V3

* add base model config for MobileNetV3; modified all training configs of MobileNetV3 to inherit from the base model config

* removed directory _base_/models/mobilenet_v3

* [Feature] Implementation of Zero-Shot CLIP Classifier ()

* zero-shot CLIP

* modify zero-shot clip config

* add in1k_sub_prompt(8 prompts) for improvement

* add some annotations doc

* clip base class & clip_zs sub-class

* some modifications of details after review

* convert into and use mmpretrain-vit

* modify names of some files and directories

* ram init commit

* [Fix] Fix pipeline bug in image retrieval inferencer

* [CodeCamp2023-341] Supplement the multimodal dataset documentation - COCO Retrieval

* Update OFA to be compatible with the latest huggingface.

* Update train.py to be compatible with the new config

* Bump version to v1.1.0

* Update __init__.py

---------

Co-authored-by: LALBJ <40877073+LALBJ@users.noreply.github.com>
Co-authored-by: DE009 <57087096+DE009@users.noreply.github.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
Co-authored-by: 飞飞 <102729089+ASHORE1225@users.noreply.github.com>
2023-10-25 16:23:45 +08:00
mzr1996 a4c219e05d Bump version to v1.1.0 2023-10-12 17:20:22 +08:00
mzr1996 d35c778a6f Merge remote-tracking branch 'origin/main' into dev 2023-10-12 17:19:27 +08:00
hmtbgc c0766519b1
[Feature] Add minigpt4 gradio demo and training script. ()
* Add minigpt4 gradio demo

* update minigpt4 demo

* update minigpt4 demo (inference with float16)

* update minigpt4 and some dependent files

* add minigpt4 dataset for training

* add training script for minigpt4

* restore files deleted by mistake

* fix an error

* remove useless modification

* provide command line arguments for minigpt4 gradio demo and update some comments

* update code

* Update minigpt-4 readme

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-10-12 10:36:17 +08:00
mzr1996 4849324629 Update train.py to be compatible with new config 2023-10-11 11:13:40 +08:00
mzr1996 b0a792eb08 Update OFA to be compatible with latest huggingface. 2023-10-11 11:13:40 +08:00
飞飞 3bcf7e2d6e
[CodeCamp2023-341] Supplement the multimodal dataset documentation - COCO Retrieval 2023-10-08 15:46:47 +08:00
mzr1996 06bb586eb7 [Fix] Fix pipeline bug in image retrieval inferencer 2023-10-08 15:44:37 +08:00
Ma Zerun 5c71de6b8e
Merge pull request from timerring/dev
[CodeCamp2023-338] New Version of config Adapting Swin Transformer Algorithm
2023-09-08 16:01:38 +08:00
John 7734f073e4 set arch etc 2023-09-06 23:56:03 +08:00
John b0b4422736 fix a redundant 2023-09-05 22:22:43 +08:00
John 9b75ce0aa4 only keep one file to set swin transformer v2 model config 2023-09-05 22:16:07 +08:00
John f4d372ba7d only keep one file to set swin transformer model config 2023-09-05 21:26:43 +08:00
John ed3b7f8ae6 format all file names 2023-09-05 16:00:29 +08:00
John ddc6d0b121 Merge remote-tracking branch 'upstream/dev' into dev 2023-09-05 15:23:21 +08:00
ZhangYiqin da1da48eb6
[Enhance] Add iTPN support for non-three-channel images ()
* Add channel arguments to mae_head

When trying iTPN pretraining, it only supports images with 3 channels. One of the restrictions comes from MAEHead.

* Transfer other arguments from iTPNHiViT to HiViT

HiViT supports specifying channels, but the iTPNHiViT class can't pass channel arguments to it. This is one of the reasons the iTPNHiViT implementation only supports images with 3 channels.

* Update itpn.py

Fix hint problem
2023-09-04 13:11:16 +08:00
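For context, a minimal sketch of the pass-through fix this commit describes. The class and argument names (`in_chans` in particular) are illustrative assumptions; the actual mmpretrain signatures may differ:

```python
class HiViT:
    """Backbone that already supports a configurable channel count."""

    def __init__(self, in_chans=3, **kwargs):
        self.in_chans = in_chans  # the patch embedding is built from this


class ITPNHiViT(HiViT):
    """Wrapper class; before the fix it did not forward in_chans, hard-coding RGB."""

    def __init__(self, in_chans=3, **kwargs):
        super().__init__(in_chans=in_chans, **kwargs)  # forward the channel argument
```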
Coobiw bb59c9ad82
[Feature] Implementation of Zero-Shot CLIP Classifier ()
* zero-shot CLIP

* modify zero-shot clip config

* add in1k_sub_prompt(8 prompts) for improvement

* add some annotations doc

* clip base class & clip_zs sub-class

* some modifications of details after review

* convert into and use mmpretrain-vit

* modify names of some files and directories
2023-09-04 10:30:28 +08:00
DE009 845b462190
[CodeCamp2023-340] New Version of config Adapting MobileNet Algorithm ()
* add new config adapting MobileNetV2, V3

* add base model config for MobileNetV3; modified all training configs of MobileNetV3 to inherit from the base model config

* removed directory _base_/models/mobilenet_v3
2023-09-01 17:54:18 +08:00
John 634852ad61 [CodeCamp2023-338] New Version of config Adapting Swin Transformer Algorithm 2023-08-31 18:15:47 +08:00
zhengjie.xu e1675e893e
[Docs] Update QRcode ()
* Add miaomiao_qrcode.jpg

* Update qrcode
2023-08-30 19:47:21 +08:00
LALBJ d2ccc44a2c
[CodeCamp2023-584]Support DINO self-supervised learning in project ()
* feat: implement DINO

* chore: delete debug code

* chore: implement pre-commit

* fix: fix imported package

* chore: pre-commit check
2023-08-23 10:45:18 +08:00
Ezra-Yu 853f0c6bca
[DOC] Update dataset download source from opendatalab to openXlab ()
* update opendatalab to openXlab

* update dataset-index

---------

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
2023-08-22 11:29:42 +08:00
Ma Zerun 732b0f4c98
Merge pull request from mzr1996/bump-v1.0.2
Bump version to v1.0.2
2023-08-15 15:10:04 +08:00
Ma Zerun b65a96a89c
Apply suggestions from code review
Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
2023-08-15 14:43:21 +08:00
mzr1996 6bb0c8a987 Bump version to v1.0.2 2023-08-15 11:51:12 +08:00
mzr1996 bf62497e02 Merge remote-tracking branch 'origin/main' into dev 2023-08-15 11:37:22 +08:00
mstwutao 6474d6befa
[CodeCamp2023-336] New Version of `config` Adapting MAE Algorithm ()
* fix typo MIMHIVIT to MAEHiViT

* fix typo MIMHiViT to MAEHiViT

* [CodeCamp2023-336] New version of config adapting MAE algorithm

* pre-commit check

* Revert soft-link modification

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-08-14 17:20:39 +08:00
AzulaFire 1be28ea7c4
[CodeCamp2023-337] New Version of config Adapting ConvNeXt Algorithm ()
* add configs/_base_/datasets/imagenet21k_bs128.py

* update convnext_base_32xb128_in1k_384px.py

* add convnext-base_32xb128_in1k.py

* add convnext-base_32xb128_in21k.py

* add convnext-large_64xb64_in1k-384px.py

* add convnext-large_64xb64_in1k.py

* add convnext-large_64xb64_in21k.py

* add convnext-small_32xb128_in1k-384px.py

* add convnext-small_32xb128_in1k.py

* add convnext-tiny_32xb128_in1k-384px.py

* add convnext-tiny_32xb128_in1k.py

* add convnext-xlarge_64xb64_in1k-384px.py

* add convnext-xlarge_64xb64_in1k.py

* add convnext-xlarge_64xb64_in21k.py

* pre-commit check
2023-08-14 15:25:59 +08:00
Am_mu bff80d3c48
[CodeCamp2023-335]New version of config adapting BeitV2 Algorithm () 2023-08-14 15:04:42 +08:00
fanqiNO1 29d706248c
[Enhancement] Support training of BLIP2 ()
* [Fix] Fix BEiT pre_norm

* [Enhancement] Support BLIP2 training

* [Fix] Fix quoted strings

* [Fix] Fix init_weights

* [Fix] Fix with_cls_token

* [Fix] Fix tokenizer

* [Fix] Fix quoted strings

* [Fix] Fix predict

* [Fix] Cancel changing BEiT

* [Fix] Add loading hook

* [Fix] Reformat with yapf

* [Fix] Fix prompt

* [Fix] Fix typo
2023-08-10 11:15:38 +08:00
Yuan Liu fa53174fd9
[Feature]: Add MFF ()
* [Feature]: Add MFF

* [Feature]: Add mff linear prob

* [Feature]: Add ft

* [Fix]: Update docstring

* [Feature]: Update out_indices

* [Feature]: Add prefix to ft

* [Feature]: Add README

* [Feature]: Update readme

* [Feature]: Update README

* [Feature]: Add metafile

* [Feature]: Update README

* [Fix]: Fix lint

* [Feature]: Add UT

* [Feature]: Update paper link
2023-08-08 16:01:07 +08:00
mstwutao 827a216155
[Fix] Fix typo MIMHIVIT to MAEHiViT ()
* fix typo MIMHIVIT to MAEHiViT

* fix typo MIMHiViT to MAEHiViT
2023-08-08 15:38:18 +08:00
No-518 1dda91bf24
[CodeCamp2023-343] Update dataset_prepare.md ()
* Update dataset_prepare.md

* Enhanced docstring for RefCOCO and updated datasets.rst

* fix ln

* update

---------

Co-authored-by: No-518 <wybang@gmail.com>
Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
2023-08-03 19:24:23 +08:00
Zeyuan 2fb52eefdc
[CodeCamp2023-339] New Version of `config` Adapting Vision Transformer Algorithm ()
* add old config

* add old config

* add old config

* renew vit-base-p16_64xb64_in1k.py

* rename

* finish vit_base_p16_64xb64_in1k_384px.py

* finish vit_base_p32_64xb64_in1k.py and 384px

* finish 4 vit_large*.py

* finish vit_base_p16_32xb128_mae_in1k.py

* add vit_base_p16_4xb544_ipu_in1k.py

* modify data_root

* using  to modify cfg

* pre-commit check

* ignore ipu

* keep other files no change

* remove redefinition

* only keep vit_base_p16.py

* move init_cfg into model.update
2023-08-02 10:06:08 +08:00
Yike Yuan 340d187765
Support Infographic VQA dataset and ANLS metric. () 2023-08-01 16:22:34 +08:00
Yike Yuan 4f2f3752d9
Support IconQA dataset. () 2023-08-01 16:14:40 +08:00
Yixiao Fang 5c71eba13d
Bump version to 1.0.1 ()
* bump version to 1.0.1

* update changelog

* update readme

* Update changelog.md

* update requirements

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
2023-07-31 17:08:05 +08:00
fangyixiao18 58a2243d99 Merge branch 'main' into dev 2023-07-28 15:35:55 +08:00
Yixiao Fang 1f99279657
[Fix] Fix dict update in minigpt4. () 2023-07-28 15:30:30 +08:00
Yixiao Fang 0b96dcaa67
[Enhance] Add init_cfg with type='pretrained' to downstream tasks. () 2023-07-28 15:28:29 +08:00
Yixiao Fang b1cd05caf2
[Enhance] Set 'is_init' in some multimodal methods ()
* update is_init of multimodal

* Update minigpt4.py

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
2023-07-28 15:28:07 +08:00
marouane amzil e7fc25cf64
[Fix] Fix nested predict for multi-task prediction. ()
* fix: multi task predict

* change the loop

---------

Co-authored-by: Pierre Colle <piercus@gmail.com>
2023-07-28 13:44:12 +08:00
Yinlei Sun c5248b17b7
[Enhance] Adapt test cases on Ascend NPU. () 2023-07-28 13:39:38 +08:00
Nripesh Niketan 4d1dbafaa2
[Enhance] Add GPU acceleration for Apple silicon Mac ()
* Add GPU acceleration for Apple silicon Mac

* lint fix

* Update launch.py

* Use  to refactor the device selection.

* Update launch.py

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-07-26 17:51:00 +08:00
liyl 2b8d8eecb2
[Fix] Fix the issue "GaussianBlur doesn't work" ()
* Fix issue 1711. GaussianBlur.

* Fix UT

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-07-25 11:25:32 +08:00
fanqiNO1 64c446d507
[Feature] Support LoRA. ()
* [Feature] Support LoRA

* [Feature] Support LoRA

* [Fix] Fix bugs

* [Refactor] Add copyright

* [Fix] Fix bugs

* [Enhancement] Add

* [Fix] Fix bugs

* [Fix] Fix bugs

* [Fix] Fix bugs

* [Fix] Fix bugs

* [Fix] Fix bugs

* [Docs] Update docstring

* [Docs] Update docstring

* [Refactor] Reformat with yapf

* [Docs] Update docstring

* [Refactor] Docformat

* [Refactor] Fix double-quote-string

* [Fix] fix pytorch version

* [Fix] isort

* [Fix] isort

* [Enhancement] Extend forward

* [Enhancement] Extend test

* [Fix] Fix targets

* [Enhancement] Extend LoRA to frozen models

* [Fix] Fix spelling

* [Fix] Override __getattr__

* [Fix] Add init_cfg

* [Enhancement] Add example config

* [Fix] Fix init_cfg

* [Enhancement] Add merging script

* [Fix] Remove init_cfg

* [Fix] Change lora key

* [Fix] Fix merge scripts

* [Fix] Fix merge scripts

* [Docs] Add docs

* [Fix] fix
2023-07-24 11:30:57 +08:00
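For readers unfamiliar with the technique this commit adds, here is a minimal, self-contained sketch of a LoRA layer (low-rank adapters on a frozen linear weight). It illustrates the general idea only and is not mmpretrain's actual implementation:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Sketch of LoRA: frozen base weight W plus a trainable low-rank update B @ A."""

    def __init__(self, in_features, out_features, rank=4, alpha=1.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad_(False)  # freeze the pretrained weight
        self.lora_a = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank))  # zero init => no change at start
        self.scaling = alpha / rank  # scale of the low-rank update

    def forward(self, x):
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
```

A merging script like the one mentioned in the bullets above would fold `scaling * B @ A` back into the base weight so the adapted model can be used without extra parameters.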
mzr1996 60d780f99e Fix docs 2023-07-20 10:21:15 +08:00
BBC 569324b180
Just to correct a typo of 'target' () 2023-07-14 16:16:35 +08:00
Fabien Merceron PRL db395d35b1
fix_freeze_without_cls_token_vit () 2023-07-14 15:43:19 +08:00
fanqiNO1 465b6bdeec
[Refactor] Fix spelling () 2023-07-13 15:38:58 +08:00
fanqiNO1 5c43d3ef42
[Refactor] BEiT refactor ()
* [Refactor] BEiT refactor

* [Fix] Fix arch zoo

* [Fix] Fix arch zoo

* [Fix] Fix freeze stages

* [Fix] Fix freeze ln2

* [Fix] Fix freezing vit ln2
2023-07-11 15:49:41 +08:00
Ezra-Yu 78d0ddc852
[Fix] Fix RandomCrop bug () 2023-07-11 10:18:08 +08:00
Yixiao Fang ae7a7b7560
Bump version to 1.0.0 ()
* bump version to 1.0.0

* update

* update

* fix lint

* update

* update

* update changelog

* update
2023-07-05 11:51:12 +08:00
liweiwp 0d80ab4650
[Docs] fix doc typos ()
* fix doc typos

* fix link

---------

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
2023-07-05 11:25:55 +08:00
Minato 8eaf8090e6
docs(advances_guides.modules): Correcting a typo ()
"涵盖了模型直接绝大多数的差异" -> "涵盖了模型之间绝大多数的差异"

English edition:
"backbone: usually a feature extraction network that records the major differences between models, e.g., ResNet, MobileNet."
2023-07-05 11:14:46 +08:00
Lamply 130751185c
[DOC] Fix typo in docs/*/migration.md ()
* Update migration.md

* Update migration.md
2023-07-05 11:12:25 +08:00
fanqiNO1 7cbfb36c14
[Refactor] Fix spelling ()
* [Refactor] Fix spelling

* [Refactor] Fix spelling

* [Refactor] Fix spelling

* [Refactor] Fix spelling
2023-07-05 11:07:43 +08:00
Wang Xiang feb0814b2f
[Feature] Transfer shape-bias tool from mmselfsup ()
* Transfer shape-bias tool from mmselfsup

* run shape-bias successfully, add CN docs

* fix unit test bug

* add shape_bias to index.rst in docs

* modified mistakes in shape-bias docs
2023-07-03 11:39:23 +08:00
Peng Lu 00030e3f7d
[Fix] refactor _prepare_pos_embed in ViT to fix bug in loading old checkpoint () 2023-07-03 11:36:44 +08:00
Ezra-Yu 59c077746f
[Feat] Download dataset by using MIM&OpenDataLab ()
* add dataset.index

* update preprocess shell

* update shell

* update docs

* update docs
2023-06-30 13:55:13 +08:00
Mashiro 8afad77a35
[Enhance] Update fsdp vit-huge and vit-large config ()
* Update fsdp vit-huge and vit-large config

* Update fsdp vit-huge and vit-large config

* rename
2023-06-30 11:15:18 +08:00
fanqiNO1 658db80089
[Enhancement] Support deepspeed with flexible runner ()
* [Feature] Support deepspeed with flexible runner

* [Fix] Reformat with yapf

* [Refactor] Rename configs

* [Fix] Reformat with yapf

* [Refactor] Remove unused keys

* [Refactor] Change the _base_ path

* [Refactor] Reformat
2023-06-29 10:16:27 +08:00
Wangbo Zhao(黑色枷锁) 68758db7a8
[Fix] freeze pre norm in vision transformer. () 2023-06-28 17:00:27 +08:00
Yixiao Fang 10685fc81c
[Refactor] Replace if '_base_' with read_base(). () 2023-06-28 16:57:18 +08:00
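The `read_base()` idiom referenced here comes from MMEngine's new-style pure-Python configs; a usage sketch (the base module path is illustrative, and the file is assumed to live inside a config package):

```python
from mmengine.config import read_base

with read_base():
    # Imports under read_base() are interpreted as config inheritance,
    # replacing the old string-based _base_ = ['...'] mechanism.
    from .._base_.models.resnet18 import *  # noqa: F401,F403
```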
Yixiao Fang 70ff2abbf7
[Refactor] Refactor _prepare_pos_embed in ViT ()
* deal with cls_token

* Update implement

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-06-20 17:37:08 +08:00
Yixiao Fang d4a6dfa00a
Add benchmark options ()
* update dev_scripts

* update metafile

* update multimodal floating range

* fix lint

* update

* update

* fix lint

* Update metric map

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-06-20 14:18:57 +08:00
Ma Zerun 7d850dfadd
[Improve] Update Otter and LLaVA docs and config. () 2023-06-19 20:16:13 +08:00
mzr1996 dbef2b41c6 [Fix] Align COCO dataset format. 2023-06-19 07:24:07 +00:00
Mashiro d6056af2b8
[Fix][New_config] Fix demo bug ()
* Fix demo

* Update implement

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-06-19 15:15:28 +08:00
Yiqin Wang 王逸钦 6d7fe91a98
[Feature] Support Flickr30k Retrieval dataset ()
* format

* remove abs path

* init add flickr30k caption

* remove abs dir

* update blip readme

* add convert scripts

* minor

* minor
2023-06-19 15:15:03 +08:00
Yixiao Fang a1cfe888e2
[Feature] Support SparK. ()
* add spark configs

* fix configs

* remove repeat aug

* add module codes

* support lr layer decay of resnet

* update

* fix lint

* add metafile and readme

* fix lint

* add models and logs

* refactor codes

* fix lint

* update model rst

* update name

* add docstring

* add ut

* fix lint

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
2023-06-19 11:27:50 +08:00
Ma Zerun bfd49b0d52
[Feature] Support LLaVA () 2023-06-17 16:05:52 +08:00
Ma Zerun e69bace03f
[Feature] Support otter ()
* [Feature] Support Otter

* Update docs
2023-06-17 16:03:21 +08:00
Yixiao Fang 9d3fc43073
[Feature] Support MiniGPT-4 ()
* support inference of MiniGPT-4

* refine codes

* update metafile, readme and docs

* fix typo

* fix lint

* add ckpt load hook
2023-06-16 22:50:34 +08:00
Yike Yuan a673b048a5
[Feature] Add support for VizWiz dataset. ()
* add vizwiz

* update dataset

* [Fix] Build img_path in data_sample.

* Fix isort.

---------

Co-authored-by: ZhangYuanhan-AI <yuanhan002@ntu.edu.sg>
2023-06-16 17:16:17 +08:00
Yixiao Fang aac398a83f
[Feature] Support new configs. ()
* [Feature] Support new configs ()

* add new config of mae and simclr

* update

* update setup.cfg

* update eva

* update

* update new config

* Add new config

* remove __init__.py

* 1. remove ; 2. remove mmpretrain/configs/_base_/models/convnext

* remove model_wrapper_cfg and add out type

* Add comment for setting default_scope to None

* update if '_base_' order

* update

* revert changes

---------

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>

* Add warn at the head of new config files

---------

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
2023-06-16 16:54:45 +08:00
Ezra-Yu 93e0f107c4
[Fix] Fix bug loading IN1k dataset. () 2023-06-16 15:35:27 +08:00
Yike Yuan 7581b76233
[Feature] Add support for vsr dataset ()
* add VSR dataset

* [Fix] Modify example and load gt_answer as string.

---------

Co-authored-by: ZhangYuanhan-AI <yuanhan002@ntu.edu.sg>
2023-06-15 19:17:02 +08:00
zzc98 53648baca5
[Fix] fix sam bug () 2023-06-15 10:10:51 +08:00
zzc98 3eaf719a64
[Feature] Add InternImage Classification project ()
* [Feature] add internimage project

* [Feature] add internimage project

* update license

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* update license

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* update internimage configs

* support internimage project

* support internimage project

* support internimage project

* internimage
2023-06-13 19:11:54 +08:00
Hubert 8e9e880601
[Feat] Add download link for coco caption and retrieval annotations. ()
* [Feat] Add download link for coco caption and retrieval annotations.

* minor fix
2023-06-13 10:29:54 +08:00
Yiqin Wang 王逸钦 bb415b91be
[Feature] Support OCR-VQA dataset ()
* support ocrvqa dataset

* minor

* remove abs path

* refine README
2023-06-13 10:28:45 +08:00
Yiqin Wang 王逸钦 dbfb84ccbd
[Feature] Support OK-VQA dataset ()
* add okvqa

* refine README
2023-06-08 16:57:18 +08:00
Mr.Li 057d7c6d6a
[BUG] Fixed circular import error for new transform () 2023-06-08 14:00:41 +08:00
Yuan Liu bddbc085fc
[Feature]: Add image_only param ()
* [Feature]: Add image_only param

* [Feature]: Use image_only
2023-06-06 12:50:42 +08:00
Wangbo Zhao(黑色枷锁) 3a277ee9e6
[Feature] support TextVQA dataset ()
* [Support] Suport TextVQA dataset

* add folder structure

* fix readme
2023-06-02 11:50:38 +08:00
zzc98 bc3c4a35ee
[Refactor] Support using "split" to specify the training set/validation set in the ImageNet dataset ()
* [Feature]: Add caption

* [Feature]: Update scienceqa

* [CI] Add test mim CI. ()

* refactor imagenet dataset

* refactor imagenet dataset

* refactor imagenet dataset

* update imagenet21k

* update configs

* update mnist

* update dataset_prepare.md

* fix sun397 url and update user_guides/dataset_prepare.md

* update dataset_prepare.md

* fix sun397 dataset

* fix sun397

* update chinese dataset_prepare.md

* update dataset_prepare.md

* [Refactor] update voc dataset

* [Refactor] update voc dataset

* refactor imagenet

* refactor imagenet

* use mmengine.fileio

---------

Co-authored-by: liuyuan <3463423099@qq.com>
Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-06-02 11:03:18 +08:00
Wang Xiang 795607cfeb
[Docs] Add t-SNE visualization doc ()
* 2023-05-08 add t-sne docs

* 2023-05-08 add t-sne docs

* 2023-05-10 add t-sne docs CN

* 2023-05-25 rebase dev

* add docs for running t-sne on mae models, and fix a bug in vis_tsne.py

* rewrite t-sne docs to correct some mistakes
2023-06-01 10:04:06 +08:00
Ma Zerun 5bd088ef43
[Fix] Update torchvision transform wrapper ()
* Update torchvision transform wrapper

* Update requirements

* fix unit tests

---------

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-05-26 17:56:09 +08:00
Yixiao Fang e4c4a81b56
[Feature] Support iTPN and HiViT ()
* hivit added

* Update hivit.py

* Update hivit.py

* Add files via upload

* Update __init__.py

* Add files via upload

* Update __init__.py

* Add files via upload

* Update hivit.py

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Update itpn.py

* Add files via upload

* Update __init__.py

* Update mae_hivit-base-p16.py

* Delete mim_itpn-base-p16.py

* Add files via upload

* Update itpn_hivit-base-p16.py

* Update itpn.py

* Update hivit.py

* Update __init__.py

* Update mae.py

* Delete hivit.py

* Update __init__.py

* Delete configs/itpn directory

* Add files via upload

* Add files via upload

* Delete configs/hivit directory

* Add files via upload

* refactor and add metafile and readme

* update clip

* add ut

* update ut

* update

* update docstring

* update model.rst

---------

Co-authored-by: 田运杰 <48153283+sunsmarterjie@users.noreply.github.com>
2023-05-26 12:08:34 +08:00
Ezra-Yu 1f07c92ed1
[Feature] Add retrieval mAP metric. ()
* rebase

* refine

* fix lint

* update readme

* rebase

* fix lint

* update docstring

* update docstring

* rebase

* rename corresponding names

* rebase
2023-05-26 10:40:08 +08:00
Ezra-Yu 9bb692e440
[Fix] Set default out_type in CAM visualization. () 2023-05-24 14:09:41 +08:00
Wangbo Zhao(黑色枷锁) a779c8c5a7
[Feature] Support NoCap dataset based on BLIP. ()
* [Feature] Support nocaps dataset

* precommit

* Use official coco format

* add nocaps readme

* fix readme

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-05-23 18:06:43 +08:00
Yuan Liu 46a523ef63
[Feature] Add GQA dataset. ()
* [Feature]: Add GQA dataset

* [Feature]: Add GQA

* [Feature]: Add GQA UT

* [Fix]: Fix hint

* [Feature]: Add BLIP2 GQA

* [Fix]: Fix lint

* [Feature]: Update anno link

* [Fix]: Update docstring

* [Feature]: Update all links
2023-05-23 11:25:42 +08:00
Ma Zerun 4dd8a86145
Bump version to v1.0.0rc8 ()
* Bump version to v1.0.0rc8

* Apply suggestions from code review

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>

* Update README.md

---------

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
2023-05-23 11:22:51 +08:00
Yuan Liu be389eb846
[Fix] Fix scienceqa () 2023-05-22 16:10:17 +08:00
ZhangYiqin 023d6869bd
[Fix] Incorrect stage freeze on RIFormer Model ()
* [Doc] RIFormer's README did not link to its paper properly

* Incorrect code for reproducing RIFormer 

the default value of frozen_stages is set to 0, and the doc says that this will lead to no stage being frozen. But the actual case is that the patch_embed will be frozen.

This may cause incorrect training, thus influencing the result.

I suggest a careful review.
2023-05-22 16:01:32 +08:00
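A hedged sketch of the freezing logic at issue (names are illustrative, not mmpretrain's exact code): the reported bug is that `patch_embed` was frozen even when `frozen_stages=0`, which the docstring promises means "freeze nothing".

```python
def freeze_stages(model, frozen_stages=0):
    # Correct behaviour: frozen_stages=0 freezes nothing at all.
    if frozen_stages >= 1:  # the buggy version froze patch_embed unconditionally
        for p in model.patch_embed.parameters():
            p.requires_grad = False
    for i in range(frozen_stages):
        for p in model.stages[i].parameters():
            p.requires_grad = False
```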
zzc98 b058912c0c
[Docs] Fix example_project README () 2023-05-22 15:47:03 +08:00
Yixiao Fang 1e478462b8
[Feature] Support Chinese CLIP. ()
* support cn-clip

* update README

* Update progress bar

* update order of category

* fix lint

* update

* update readme and metafile

* update

* update docstring

* refactor tokenizer

* fix lint

* Update README and progress bar

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-05-22 15:46:13 +08:00
Yuan Liu d04ef8a29e
Merge pull request from YuanLiuuuuuu/scienceqa_metrics
[Feature]: Add ScienceQA Metrics
2023-05-22 13:08:06 +08:00
liuyuan 74f24658e7 [Fix]: Delete GQA 2023-05-22 11:57:18 +08:00
liuyuan 13e4d6c512 [Fix]: Fix UT 2023-05-22 11:55:08 +08:00
liuyuan b0ad99afb9 [Fix]: Fix bug 2023-05-22 11:38:34 +08:00
liuyuan 1537d46596 [Feature]: Update scienceqa 2023-05-22 11:31:07 +08:00
liuyuan 87f849cbb6 [Feature]: Add scienceqa metric 2023-05-22 11:31:07 +08:00
liuyuan 1b8e86dca6 [Feature]: Add caption 2023-05-22 11:31:07 +08:00
Ma Zerun 6847d20d57
[Feature] Support multiple multi-modal algorithms and inferencers. ()
* [Feat] Migrate blip caption to mmpretrain. ()

* Migrate blip caption to mmpretrain

* minor fix

* support train

* [Feature] Support OFA caption task. ()

* [Feature] Support OFA caption task.

* Remove duplicated files.

* [Feature] Support OFA vqa task. ()

* [Feature] Support OFA vqa task.

* Fix lint.

* [Feat] Add BLIP retrieval to mmpretrain. ()

* init

* minor fix for train

* fix according to comments

* refactor

* Update Blip retrieval. ()

* [Feature] Support OFA visual grounding task. ()

* [Feature] Support OFA visual grounding task.

* minor add TODO

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feat] Add flamingos coco caption and vqa. ()

* first init

* init flamingo coco

* add vqa

* minor fix

* remove unnecessary modules

* Update config

* Use `ApplyToList`.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature]: BLIP2 coco retrieval  ()

* [Feature]: Add blip2 retriever

* [Feature]: Add blip2 all modules

* [Feature]: Refine model

* [Feature]: x1

* [Feature]: Runnable coco ret

* [Feature]: Runnable version

* [Feature]: Fix lint

* [Fix]: Fix lint

* [Feature]: Use 364 img size

* [Feature]: Refactor blip2

* [Fix]: Fix lint

* refactor files

* minor fix

* minor fix

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* Remove

* fix blip caption inputs ()

* [Feat] Add BLIP NLVR support. ()

* first init

* init flamingo coco

* add vqa

* add nlvr

* refactor nlvr

* minor fix

* minor fix

* Update dataset

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature]: BLIP2 Caption ()

* [Feature]: Add language model

* [Feature]: blip2 caption forward

* [Feature]: Reproduce the results

* [Feature]: Refactor caption

* refine config

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feat] Migrate BLIP VQA to mmpretrain ()

* reformat

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* refactor code

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* Update RefCOCO dataset

* [Fix] fix lint

* [Feature] Implement inference APIs for multi-modal tasks. ()

* [Feature] Implement inference APIs for multi-modal tasks.

* [Project] Add gradio demo.

* [Improve] Update requirements

* Update flamingo

* Update blip

* Add NLVR inferencer

* Update flamingo

* Update hugging face model register

* Update ofa vqa

* Update BLIP-vqa ()

* Update blip-vqa docstring ()

* Refine flamingo docstring ()

* [Feature]: BLIP2 VQA ()

* [Feature]: VQA forward

* [Feature]: Reproduce accuracy

* [Fix]: Fix lint

* [Fix]: Add blank line

* minor fix

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feature]: BLIP2 docstring ()

* [Feature]: Add caption docstring

* [Feature]: Add docstring to blip2 vqa

* [Feature]: Add docstring to retrieval

* Update BLIP-2 metafile and README ()

* [Feature]: Add readme and docstring

* Update blip2 results

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature] BLIP Visual Grounding on MMPretrain Branch ()

* blip grounding merge with mmpretrain

* remove commit

* blip grounding test and inference api

* refcoco dataset

* refcoco dataset refine config

* rebasing

* gitignore

* rebasing

* minor edit

* minor edit

* Update blip-vqa docstring ()

* rebasing

* Revert "minor edit"

This reverts commit 639cec757c215e654625ed0979319e60f0be9044.

* blip grounding final

* precommit

* refine config

* refine config

* Update blip visual grounding

---------

Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com>
Co-authored-by: mzr1996 <mzr1996@163.com>

* Update visual grounding metric

* Update OFA docstring, README and metafiles. ()

* [Docs] Update installation docs and gradio demo docs. ()

* Update OFA name

* Update Visual Grounding Visualizer

* Integrate accelerate support

* Fix imports.

* Fix timm backbone

* Update imports

* Update README

* Update circle ci

* Update flamingo config

* Add gradio demo README

* [Feature]: Add scienceqa ()

* [Feature]: Add scienceqa

* [Feature]: Change param name

* Update docs

* Update video

---------

Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
Co-authored-by: yingfhu <yingfhu@gmail.com>
Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com>
Co-authored-by: Rongjie Li <limo97@163.com>
2023-05-19 16:50:04 +08:00
Yixiao Fang 770eb8e24a
[Fix] Fix ddp bugs caused by `out_type`. ()
* set out_type to be 'raw'

* update test
2023-05-17 17:32:10 +08:00
zzc98 034919d032
[Feature] add eva02 backbone ()
* [CI] Add test mim CI. ()

* [CI] Add test mim CI. ()

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* update

* update ci

* rebase

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* update

* update readme and configs

* update readme and configs

* refactore eva02

* [CI] Add test mim CI. ()

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* update

* update ci

* rebase

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* update

* update readme and configs

* refactore eva02

* update readme and metafile

* update readme and metafile

* update readme and metafile

* update

* rename eva02

* rename eva02

* fix uts

* rename configs

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-05-06 19:28:31 +08:00
Ezra-Yu 7f4eccbecf
[Fix] Fix multi-task-head loss potential bug ()
* fix bug

* add comments
2023-05-06 18:04:34 +08:00
Ezra-Yu 9cf37b315c
[DOC] Refine Inference Doc ()
* update en doc

* update

* update zh doc

* refine

* refine
2023-05-06 17:54:13 +08:00
Kei-Chi Tse afa60c73bb
[Fix] Support bce loss without batch augmentations ()
* Support bce loss without batch augmentations

---------

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-05-05 17:19:42 +08:00
Yixiao Fang d9e561a09d
[Feature] Support dinov2 backbone ()
* support dinov2 backbone

* update metafile and readme

* compatible to use_layer_scale

* update SwiGLUFFN

* add deprecation warning

* update
2023-05-05 16:59:37 +08:00
zzc98 496e098b21
[Feature] Support some downstream classification datasets. ()
* feat: support some downstream classification datasets

* update sun397

* sum

* update sun397

* [CI] Add test mim CI. ()

* feat: support some downstream classification datasets

* update sun397

* sum

* update sun397

* rebase

* feat: support some downstream classification datasets

* update sun397

* update sun397

* update sun397

* update sun397

* fix unittest

* update docstring

* rm

* update

* update

* refactor names of datasets

* refactor some implements of datasets

* refactor some implements of datasets

* fix datasets unittest

* refactor cub and stanford cars

* refactor cub and cifar

* refactor cub and cifar

* refactor cub and cifar

* update downstream datasets and docs

* update docstring

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-05-05 14:43:14 +08:00
Yixiao Fang a3fa328f09
Fix config of beit () 2023-04-28 16:33:14 +08:00
Choi Sau Deng b51d7d21de
[DOC] Add doc for usage of confusion matrix ()
* add_doc_for_confusion_matrix

* add_doc_for_confusion_matrix_fix_mmcls

* add_doc_for_confusion_matrix_fix_shell

* add_doc_for_confusion_matrix_fix_shell

* fix

* update

---------

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
2023-04-27 14:56:44 +08:00
Yixiao Fang 15cc2a5193
[Fix] Fix clip generator init bug () 2023-04-25 19:35:09 +08:00
Ezra-Yu 6ceba070a8
[DOC] Update MMagic link ()
* update repo links

* update mmengine links

* update mmengine links
2023-04-25 19:12:34 +08:00
Weihao Yu 3cd4fd4d64
Update PoolFormer citation to CVPR version () 2023-04-20 20:22:22 +08:00
Wangbo Zhao(黑色枷锁) e954cf0aaf
[Fix] Fix the bug in binary cross entropy loss ()
* [Fix] Fix the bug in binary cross entropy loss

Fix the bug in binary cross entropy loss when using multi-label datasets, e.g. VOC2007

* update ci

---------

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-04-19 13:53:31 +08:00
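For context on why multi-label datasets such as VOC2007 exercise this code path, a small illustrative example (not mmpretrain's loss module): each class is an independent binary target, so the loss is BCE-with-logits over a multi-hot label vector rather than softmax cross-entropy.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 20)                     # 4 images, 20 VOC classes
targets = torch.randint(0, 2, (4, 20)).float()  # multi-hot labels
loss = F.binary_cross_entropy_with_logits(logits, targets)
```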
takuoko fec3da781f
[Feature] Support GLIP ()
* rebase

* add glip

* update glip

* add links

* rename

* fix doc

---------

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-04-17 19:19:23 +08:00
Yixiao Fang 2c913020b9
[Refactor] Support freezing channel reduction and add layer decay function ()
* support freezing the channel reduction module

* add layer decay setting function
2023-04-17 13:36:47 +08:00
Yixiao Fang e93d124ad4
[Refactor] Support resizing pos_embed while loading ckpt and format output ()
* support resize pos_embed while loading ckpt

* update
2023-04-14 19:08:35 +08:00
Yixiao Fang 02571fe4b8
[Docs] Add NPU support page ()
* add npu docs

* fix lint
2023-04-14 13:58:10 +08:00
Ezra-Yu 645e2b4ed4
[DOC] Fix typo in MultiLabelDataset docstring ()
* fix doc

* fix ci

* fix ci

* fix ci

* fix ci
2023-04-14 13:57:54 +08:00
Ezra-Yu 99e48116aa
[Feature] Register torchvision transforms into mmcls ()
* [Enhance] Add stochastic depth decay rule in resnet. ()

* add stochastic depth decay rule to drop path rate

* add default value

* update

* pass ut

* update

* pass ut

* remove np

* rebase

* update ToPIL and ToNumpy

* rebase

* rebase

* rebase

* rebase

* add readme

* fix review suggestions

* rebase

* fix conflicts

* fix conflicts

* fix lint

* remove comments

* remove useless code

* update docstring

* update doc API

* update doc

---------

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
2023-04-13 18:05:57 +08:00
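Once registered, torchvision transforms become usable inside config pipelines; a hedged usage sketch (the `torchvision/` type prefix follows common OpenMMLab conventions and is an assumption here, as is the `PackInputs` step):

```python
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='torchvision/RandomResizedCrop', size=224),   # wrapped torchvision transform
    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
    dict(type='PackInputs'),
]
```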
Yixiao Fang 0826df8963
[Feature] Add ViT of SAM ()
* add vit of sam

* update

* update

* add ut

* update ut

* remove num_classes

* support dynamic input

* add ut

* add comments

* update ut
2023-04-13 17:03:28 +08:00
Mr.Li e80418a424
[Docs] train cfg: Removed old description () 2023-04-10 15:37:14 +08:00
Yixiao Fang 9cbeceabb5
Bump version to v1.0.0rc7 ()
* update

* update info

* update changelog

* update

* update description

* change to v1.0.0rc7
2023-04-07 17:34:21 +08:00
Ezra-Yu 47e033c466
[Fix] fix pbn bug () 2023-04-07 15:38:04 +08:00
Yixiao Fang 5ea46fbbbc
[Docs] Fix docs link () 2023-04-07 15:06:18 +08:00
mzr1996 1e78f09d87 Update docs style. 2023-04-07 14:52:34 +08:00
Yixiao Fang 79ddc0f874
[Refactor] Update CI and issue template ()
* Update CI

* update issue template

* update

* update collect_env function
2023-04-07 14:29:41 +08:00
Ezra-Yu 411e05a705
Merge pull request from techmonsterwang/riformer_mmpt
[Feature] Add RIFormer Backbone
2023-04-07 10:32:23 +08:00
Ezra-Yu 05124dbb71 fix lint 2023-04-06 22:01:11 +08:00
Ezra-Yu b8cab5c9f7 update readme 2023-04-06 21:56:25 +08:00
Ezra-Yu 3932ddec10 update ckpt path 2023-04-06 21:56:25 +08:00
techmonsterwang 5c3abb2b2a update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang e115ac89f4 update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang 53a57c6dad update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang e4d8511ddf update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang c9c7d9cc0f update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang a6c24d104e update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang e7da3f29f4 update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang 61b795f21f update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang 0ef0b5ce08 update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang 32c258ff19 update riformer mmpretrain 2023-04-06 21:56:25 +08:00
techmonsterwang 0b70c108b0 update riformer mmpretrain 2023-04-06 21:56:25 +08:00
Yixiao Fang 1ee9bbe050
[Docs] Update links ()
* update links

* update readthedocs

* update

* update

* fix lint

* update

* update

* update

* update cov branch

* update

* update

* update
2023-04-06 20:58:52 +08:00
Yixiao Fang 3069e43f77
[Docs] Update readme ()
* update readme

* update

* refine

* refine

* update cn version

* update installation

* update modelzoo table

* fix lint

* update

* update

* update

* update

* fix lint

* update

* update

* update changelog

* remove gif

* fix typo

* update announcement

* update

* fix typo

* update
2023-04-06 17:17:56 +08:00
Yixiao Fang 75dceaa78f
[Refactor] Add ln to vit avg_featmap output () 2023-04-06 11:59:39 +08:00
Ezra-Yu 3a25b13eb3
[Fix] Update CI ()
* update ci

* update window ci

* update window ci

* update window ci

* update window ci

* update window ci

* update window ci

* update window ci

* update window ci

* update window ci

* update window ci

* update window ci

* ignore tests/test_tools.py
2023-04-06 10:52:08 +08:00
Yixiao Fang 568188a6b0
[Docs] add overview in the homepage ()
* update index overview

* fix lint

* refine with chatgpt

* fix lint

* update according to review

* update cn version
2023-04-03 16:39:58 +08:00
Yixiao Fang 9fb4e9c911
[Fix] fix config of maskfeat () 2023-03-30 11:45:18 +08:00
Yixiao Fang 445eb3223a
[Docs] Refine advanced guides ()
* refine

* update description

* update links

* update links

* update installation

* refine
2023-03-29 16:23:57 +08:00
Ma Zerun b017670e1b
[Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`. ()
* [Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`.

* Support `--local-rank` and `--amp` option for new version PyTorch.

* Fix imports and UT.
2023-03-29 15:50:44 +08:00
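A small sketch of the acceleration this commit describes, assuming PyTorch >= 2.0 and shapes `(batch, heads, seq, head_dim)`; this is not mmpretrain's exact code:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 8, 196, 64)
out = F.scaled_dot_product_attention(q, k, v)  # fused, memory-efficient kernel

# Reference implementation it replaces:
attn = (q @ k.transpose(-2, -1)) / (64 ** 0.5)
ref = attn.softmax(dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-4)
```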
Yixiao Fang 164f16e248
[Fix] Fix init bug of r50 in contrastive learning () 2023-03-29 15:49:51 +08:00
Ezra-Yu 555adab0a0
[Doc] Update logos and Add more social networking links ()
* add social media link in README

* replace logos

* update links

* update logo

* update logo size
2023-03-29 13:35:56 +08:00
Yixiao Fang 53dc810c08
[Refactor] Add projects from mmselfsup ()
* update projects/fgia

* add video maskfeat in projects

* update according to review
2023-03-27 16:59:08 +08:00
Ma Zerun c4ccae40db
[Docs] Update user guides docs and tools for MMPretrain. ()
* [Docs] Update user guides docs and tools for MMPretrain.

* Fix UT

* Fix Chinese docs.

* Improve according to comments.

* Fix windows CI.
2023-03-27 14:32:26 +08:00
mzr1996 a50d96f7f1 Update docs. 2023-03-20 16:12:10 +08:00
mzr1996 175d19f67e Update docs. 2023-03-20 16:10:33 +08:00
mzr1996 1f78ab410f Update docs. 2023-03-20 16:05:53 +08:00
mzr1996 6038df9514 Update docs. 2023-03-20 16:03:57 +08:00
Yixiao Fang f6b65fcbe7
[Docs] Update get start docs and user guides. ()
* update user_guides

* update test.md

* fix lint

* fix typo

* refine

* fix typo

* update retriever to api

* update rst and downstream

* update index.rst

* update index.rst

* update custom.js

* update chinese docs

* update config.md

* update train and test

* add pretrain on custom dataset

* fix lint
2023-03-20 15:56:09 +08:00
mzr1996 04e15ab347 Update circle-CI 2023-03-20 14:55:08 +08:00
Ma Zerun 6cedce234e
[Refactor] Update dev scripts to be compatible with selfsup tasks. ()
* [Refactor] Update dev scripts to be compatible with selfsup tasks.

* Fix some missing fields in config files.

* Set maximum number of gpus for local training.

* Update README files

* Update according to comments.
2023-03-20 14:30:57 +08:00
Ma Zerun 4f5b38f225
[Refactor] Update almost tools and add unit tests for these tools. ()
* [Refactor] Update almost tools and add unit tests for these tools.

* Fix Windows UT.
2023-03-17 10:50:51 +08:00
Yixiao Fang 8875e9da92
[Docs] Update migration.md ()
* update migration

* refine table

* update zh_cn

* fix lint

* Polish the documentation by ChatGPT.

* Update sphinx version and fix some warning.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-03-17 10:30:09 +08:00
Yixiao Fang 76a1f3f735
[Refactor] Refactor the `browse_dataset.py` to support selfsup pipeline. ()
* refactor browsedataset to support selfsup pipeline

* update make_grid to support list input

* mode 'transformed' supports list

* Beautify the visualization image.

* Fix compatitibly bug with matplotlib=3.5

* remove print

* fix bug of resize

* Apply mask only on the first image.

* Remove master only for some API in visualizer.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-03-15 14:18:36 +08:00
Ma Zerun 3472ee5d2c
[Feature] Implement the universal visualizer for multiple tasks. ()
* [Feature] Implement the universal visualizer for multiple tasks.

* Update tools

* Improve according to comments.

* Fix tools docs

* Add --test-cfg option and set default collate function.
2023-03-09 11:36:54 +08:00
Ma Zerun dbf3df21a3
[Refactor] Use `out_type` to specify ViT-like backbone output. ()
* [Refactor] Use  to specify ViT-like backbone output.

* Fix ClsBatchNormNeck

* Update mmpretrain/models/necks/mae_neck.py

---------

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
2023-03-09 11:02:58 +08:00
Yixiao Fang 63e5b512cc
[Refactor] Move part of tool scripts from mmselfsup. ()
* add dataset converters and benchmark .sh

* refine

* fix lint

* add tsne

* rename visualizaition

* update configs and script
2023-03-07 17:57:08 +08:00
Ma Zerun 274a67223e
[Feature] Implement layer-wise learning rate decay optimizer constructor. ()
* [Feature] Implement layer-wise learning rate decay optimizer constructor.

* Use num_layers instead of max_depth to avoid misleading

* Add UT

* Update docstring

* Update log info

* update LearningRateDecay configs

---------

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
2023-03-07 17:30:39 +08:00
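The idea of this constructor, in a hedged sketch (the numbers are illustrative): layer `i` out of `num_layers` gets the base learning rate scaled by `decay ** (num_layers - i)`, so earlier layers learn more slowly than later ones.

```python
base_lr, decay, num_layers = 1e-3, 0.65, 12
lrs = [base_lr * decay ** (num_layers - i) for i in range(num_layers + 1)]
# lrs[0] is the stem/embedding lr; lrs[-1] == base_lr for the final layer.
```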
Yixiao Fang 827be6e22d
[Fix] Fix value error while start training. ()
* fix value error of randint

* add missing key in configs
2023-03-07 08:51:31 +08:00
Yixiao Fang 08dc8c75d3
[Refactor] Add selfsup algorithms. ()
* remove basehead

* add moco series

* add byol simclr simsiam

* add ut

* update configs

* add simsiam hook

* add and refactor beit

* update ut

* add cae

* update extract_feat

* refactor cae

* add mae

* refactor data preprocessor

* update heads

* add maskfeat

* add milan

* add simmim

* add mixmim

* fix lint

* fix ut

* fix lint

* add eva

* add densecl

* add barlowtwins

* add swav

* fix lint

* update readthedocs rst

* update docs

* update

* Decrease UT memory usage

* Fix docstring

* update DALLEEncoder

* Update model docs

* refactor dalle encoder

* update docstring

* fix ut

* fix config error

* add val_cfg and test_cfg

* refactor clip generator

* fix lint

* pass check

* fix ut

* add lars

* update type of BEiT in configs

* Use MMEngine style momentum in EMA.

* apply mmpretrain solarize

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-03-06 16:53:15 +08:00
Ma Zerun a05c79e806
[Refactor] Move transforms in mmselfsup to mmpretrain. ()
* [Refactor] Move transforms in mmselfsup to mmpretrain.

* Update transform docs and configs. And register some mmcv transforms in
mmpretrain.

* Fix missing transform wrapper.

* update selfsup transforms

* Fix UT

* Fix UT

* update gaussianblur in configs

---------

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
2023-03-03 15:01:11 +08:00
mzr1996 1d6e37e56b Avoid build pdf docs. 2023-03-02 18:14:35 +08:00
mzr1996 e035e03d59 Update docs style. 2023-03-02 13:29:48 +08:00
Ma Zerun dda3d6565b
[Docs] Update generate_readme.py and readme files. ()
* Update generate_readme.py and readme files.

* Update reamde

* Update docs

* update metafile

* update simmim readme

* update

* update mae

* fix lint

* update mocov2

* update readme pic

* fix lint

* Fix mmcls download links.

* Fix Chinese docs.

* Decrease readthedocs requirements.

---------

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
2023-03-02 13:29:07 +08:00
Yixiao Fang c9670173aa
[Refactor] Move and refactor utils from mmselfsup. ()
* add heads

* add losses

* fix

* remove mim head

* add modified backbones and target generators

* fix lint

* fix lint

* add heads

* add losses

* fix

* add data preprocessor from mmselfsup

* add ut for data prepocessor

* add GatherLayer

* add ema

* add batch shuffle

* add misc

* fix lint

* update

* update docstring
2023-02-28 17:04:40 +08:00
Ma Zerun 414ba80274
[Refactor] Refactor APIs, add `ImageRetrievalInferencer` and `FeatureExtractor`. ()
* [Refactor] Refactor APIs, add `ImageRetrievalInferencer` and `FeatureExtractor`.

* Update image retrieval

* Update FeatureExtractor

* Fix UT
2023-02-28 16:31:42 +08:00
Yixiao Fang e453a45d31
[Refactor] Add self-supervised backbones and target generators. ()
* add heads

* add losses

* fix

* remove mim head

* add modified backbones and target generators

* add unittest

* refactor caevit

* add window_size check

* fix lint

* apply new DataSample

* fix ut error

* update ut

* fix ut

* fix lint

* Update base modules.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-02-28 15:59:17 +08:00
Yixiao Fang 63d9f27fde
[Refactor] Add necks, heads and losses for the self-supervised task. ()
* add necks

* refactor linear neck

* rename simmim neck

* add heads

* add losses

* fix

* add unittest

* update

* update cae

* remove mim head

* update config
2023-02-28 10:05:00 +08:00
Yixiao Fang 75c79311f4
[Refactor] Update datasets ()
* add ut

* add places205

* support ann_file without labels

* temp test

* update custom

* update

* update ut

* Update CustomDataset.

* Update Places205.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-02-27 15:42:22 +08:00
Yixiao Fang 89000c10eb
[Refactor] Refactor configs and metafile ()
* update base datasets

* update base

* update barlowtwins

* update with new convention

* update

* update

* update

* add schedule

* add densecl

* add eva

* add mae

* add maskfeat

* add milan and mixmim

* add moco

* add swav simclr

* add simmim and simsiam

* refine

* update

* add to model index

* update config inheritance

* fix error in metafile

* Update pre-commit and metafile check script

* update metafile

* fix name error

* Fix classification model name and config name

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-02-23 11:17:16 +08:00
Ma Zerun 36bea13fca
[Refactor] Refactor ClsDatasample to a union DataSample. ()
* [Refactor] Refactor ClsDatasample to a union DataSample.

* Add  method

* Fix docstring

* Update docstring.
2023-02-23 10:07:53 +08:00
mzr1996 4016f1348e Fix CI 2023-02-17 15:33:16 +08:00
mzr1996 0979e78573 Rename the package name to `mmpretrain`. 2023-02-17 15:20:55 +08:00
QINGTIAN 8352951f3d
[Feature] Support XCiT Backbone. ()
* update model file

* Update XCiT implementation and configs.

* Update metafiles

* Update metafile

* Fix floor divide

* Imporve memory usage

---------

Co-authored-by: qingtian <459291290@qq.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
2023-02-15 10:32:35 +08:00
Ma Zerun bedf4e9f64
[Refactor] Update analysis tools and documentations. ()
* [Refactor] Update analysis tools and documentations.

* Update migration.md and add unit test.

* Fix print_config.py
2023-02-15 10:28:08 +08:00
Ma Zerun b4ee9d2848
[Feature] Support calculate confusion matrix and plot it. ()
* [Feature] Support calculate confusion matrix and plot it.

* Fix keepdim

* Update confusion_matrix tools and the plot graph.

* Revert accidental modification.

* Update docstring

* Move confusion matrix tool to
2023-02-14 12:58:11 +08:00
takuoko 841256b630
[Feature] Support RetrieverRecall metric & Add ArcFace config ()
* rebase

* add ap metric

* fix multi-gpu bug in retrieval

* rebase

* rebase

* add training cfgs and update readme.md

* fix bugs (cannot load vecs in dist and diff test-val recall)

* update configs and readme

* fix ut

* fix doc

* rebase

* fix rebase conflicts

* fix rebase error

* fix UT error

* fix docs

* fix typo

---------

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-02-14 12:46:21 +08:00
huangxuming 1c1273abca
[Docs] Translate some tools tutorials to Chinese. ()
* Delete print_config.md

* Add files via upload
2023-02-09 16:40:13 +08:00
Ezra-Yu 705ed2be49
[Fix] Fix retrieval multi gpu bug ()
* fix multi-gpu bug in retrieval

* fix bugs (cannot load vecs in dist and diff test-val recall)

* load weight each process
2023-02-09 15:55:47 +08:00
Ma Zerun 7ec6062415
[Refactor] Unify the `--out` and `--dump` in `tools/test.py`. () 2023-02-09 14:05:03 +08:00
Ezra-Yu 58cefa5c0f
[Fix] Fix error repvgg-deploy base config path. () 2023-02-09 11:27:35 +08:00
fam_taro 4ce7be17c9
[Enhance] Enable toggling whether GeM Pooling is trainable or not. ()
* Enable toggling whether GeM Pooling is trainable or not.

* Add a test case for whether GeM Pooling is trainable or not.

* Enable toggling whether GeM Pooling is trainable or not via requires_grad

---------

Co-authored-by: Yusuke Fujimoto <yusuke.fujimoto@rist.co.jp>
2023-02-09 11:27:05 +08:00
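A hedged sketch of the toggle this commit adds: GeM pooling has a learnable exponent `p`, and whether it trains is just the `requires_grad` flag on that parameter (class and argument names are illustrative, not mmpretrain's exact API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeM(nn.Module):
    """Generalized-mean pooling with an optionally trainable exponent p."""

    def __init__(self, p=3.0, eps=1e-6, requires_grad=False):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p, requires_grad=requires_grad)
        self.eps = eps

    def forward(self, x):  # x: (N, C, H, W)
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p),
                            x.shape[-2:]).pow(1.0 / self.p)
```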
Ma Zerun a3f2effb17
[Feature] Add `ImageClassificationInferencer`. ()
* [Feature] Add ImageClassificationInferencer.

* Update inferencer implementation and add unit tests.

* Update documentations.

* Update pre-commit hook

* Update docs
2023-02-08 14:30:12 +08:00
zzc98 7e4502b0ac
[Feature] Support InShop Dataset (Image Retrieval). ()
* rebase

* feat: add inshop dataset (retrieval)

* update fileIO

* update unit tests

* fix windows ci

* fix windows ci

* fix lint

* update unit tests

* update docs

* update docs

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2023-01-18 17:16:54 +08:00
GAO SHIQI 353886eaca
[Docs] Add Chinese translation for runtime.md. ()
* [Docs] Add Chinese translation for runtime.md

* [Docs] Update Chinese translation for runtime.md

* [Docs] Update for runtime.md

* [Docs] Update for runtime.md

* [Docs] Update for runtime.md
2023-01-18 17:15:43 +08:00
aso538 6b9e2b55dd
[Feature] Support LeViT backbone. ()
* Network construction complete; inference works correctly

* Network construction complete; inference works correctly

* Network construction complete; inference works correctly

* Added model conversion (unverified) and config files, but they cannot run yet

* Model conversion and structure verification complete; inference produces correct results

* Inference accuracy matches the original paper; conversion complete

* Converted three functions into classes (temporary stash)

* Finished aligning inference accuracy; error 0.04

* Temporarily using levit2mmcls

* Training runs, but training-related parameters are not yet aligned

* Aligned training-related parameters

* Fixed the issue where validation during training changed the model structure irreversibly

* Fixed the issue where validation during training changed the model structure irreversibly

* Added mixup and label smoothing

* Completed config files

* Added model conversion

* Added meta file

* Added meta file

* Removed demo.py test file

* Added model README file

* Rolled back docs files

* Removed trailing whitespace on the last line of model-index

* Updated model metafile

* Updated metafile

* Updated metafile

* Updated README and metafile

* Updated model README

* Updated model metafile

* Delete the model class and get_LeViT_model methods in the mmcls.models.backbones.levit file

* Change the class name to Google Code Style

* use arch to provide default architectures

* use nn.Conv2d

* mmcv.cnn.fuse_conv_bn

* modify some details

* remove down_ops from the architectures.

* remove init_weight function

* Modify ambiguous variable names

* Change the drop_path in config to drop_path_rate

* Add unit test

* remove train function

* add unit test

* modify nn.norm1d to build_norm_layer

* update metafile and readme

* Update configs and LeViT implementations.

* Update README.

* Add docstring and update unit tests.

* Revert irrelative modification.

* Fix unit tests

* minor fix

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-01-17 17:43:42 +08:00
szwlh-c c98dc4555c
[Feature] Support VIG Backbone. ()
* Added VIG source files

* Adapted some modules to mmcls style

* Adapted to mmcls style

* Modifications

* Added VIG model and source files

* update model file

* update model file and config

* change class name and some variable name

* change class name and some variable name

* update

* update

* change nn.BatchNorm to mmcv.cnn.build_norm_layer

* update

* change nn.Seq to mmcls

* change backbone to stage_blocks

* add vig_head

* update

* update config file

* update

* add readme and metafile

* update model-index.yml

* update model file

* rename config file and add docstring

* variable rename

* update readme and metafile

* update readme

* update

* Update VIG backbone implementation and docs.

* Fix configs.

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-01-17 16:55:56 +08:00
CSH c73a5a8b15
[Fix] Fix bug in test tools. ()
* fix append

* Update test.py

* Update test.py
2023-01-12 14:18:00 +08:00
Ma Zerun 97c4ae8805
[Improve] Update registries of mmcls. ()
* [Improve] Update registries of mmcls.

* Update according to comments
2023-01-11 15:20:51 +08:00
Ma Zerun aa53f7790c
[Tool] Add metafile fill and validation tools. ()
* [Tool] Add metafile fill and validation tools.

* Fix pre-commit hooks and update fill tool.

* Update metafile

* Update fill tool

* Remove support no config.

* Add 3rdparty check.
2023-01-10 10:06:56 +08:00
Ma Zerun 060b0ed3b5
[Feature] Support ConvNeXt-V2 backbone. ()
* [Feature] Support ConvNeXt-V2.

* Use registry of mmcls instead of mmengine.

* Add README.

* Add unit tests and docs.
2023-01-06 16:13:41 +08:00
QINGTIAN e880451a54
[Improve] Remove useless EfficientNetV2 config files. ()
* update model file

* Delete xcit.py

* Delete efficientnet_v2_b0.py

* Delete efficientnet_v2_b1.py

* Delete efficientnet_v2_b2.py

* Delete efficientnet_v2_b3.py

* Delete efficientnet_v2_l.py

* Delete efficientnet_v2_m.py

* Delete efficientnet_v2_s.py

* Delete efficientnet_v2_xl.py

* Delete efficientnet_v2-b0_8xb32_in1k.py

* Delete efficientnet_v2-b1_8xb32_in1k.py

* Delete efficientnet_v2-b2_8xb32_in1k.py

* Delete efficientnet_v2-b3_8xb32_in1k.py

* Delete efficientnet_v2-l_8xb32_in1k.py

* Delete efficientnet_v2-l_8xb32_in21ft1k.py

* Delete efficientnet_v2-m_8xb32_in1k.py

* Delete efficientnet_v2-m_8xb32_in21ft1k.py

* Delete efficientnet_v2-s_8xb32_in1k.py

* Delete efficientnet_v2-s_8xb32_in21ft1k.py

* Delete efficientnet_v2-xl_8xb32_in21ft1k.py

Co-authored-by: qingtian <459291290@qq.com>
2023-01-06 16:10:59 +08:00
mzr1996 c7ec630c37 Merge branch 'dev-1.x' into 1.x 2022-12-30 17:32:33 +08:00
Ma Zerun 0d8f918eaa
Bump version to v1.0.0rc5. () 2022-12-30 17:32:04 +08:00
Mr.Li 4f5350f365
[Doc] Fix typo. ()
* [Fix] Fix imports in transforms. ()

* fix import

* import from mmegine.utils

* Fix typo

Co-authored-by: Xieyuan Zhang <25652281+Francis777@users.noreply.github.com>
2022-12-30 15:52:57 +08:00
Ezra-Yu 88e5ba28db
[Reproduce] Reproduce RepVGG Training Accuracy. ()
* repr repvgg

* add VisionRRC

* update repvgg configs

* add BCD series cfgs

* add cv backend config

* add vision configs

* add L2se configs

* add ra configs

* add num-works configs

* add num-works configs

* configs

* update README

* rm extra config

* reset un-needed changes

* update

* reset pbn

* update readme

* update code

* update code

* refine doc
2022-12-30 15:49:56 +08:00
WINDSKY45 e0e6a1f1ae
[Docs] Fix typo. () 2022-12-30 15:21:51 +08:00
QINGTIAN 74743ef588
[Feature] [CodeCamp ] Add EfficientNetV2 Backbone. ()
* add efficientnet_v2.py

* add efficientnet_v2 in __init__.py

* add efficientnet_v2_s base config file

* add efficientnet_v2 config file

* add efficientnet_v2 config file

* update tuple output

* update config file

* update model file

* update model file

* update model file

* update config file

* update model file

* update config file

* update model file

* update model file

* update model file

* update model file

* update model file

* update config file

* update config file

* update model file

* update model file

* update model file

* update model file

* update model config file

* Update efficientnet_v2.py

* add config file and modify arch

* add config file and modify arch

* add the file about convert_pth from timm to mmcls

* update efficientnetv2 model file with mmcls style

* add the file about convert_pth from timm to mmcls

* add the file about convert_pth from timm to mmcls

* update convert file

* update model file

* update convert file

* update model file

* update model file

* update model file

* add metafile and README

* Update tools/model_converters/efficientnetv2-timm_to_mmcls.py

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>

* update model file and convert file

* update model file and convert file

* update model file and convert file

* update model file and convert file

* update model file

* update model file

* update model file

* update config file and README file

* update metafile

* Update efficientnetv2_to_mmcls.py

* update model-index.yml

* update metafile.yml

* update b0 and s train pipeline

* update b0 and s train pipeline

* update b0 and s train pipeline

* add test_efficientnet_v2

* update test_efficientnet_v2

* update model file docs

* update test_efficientnet_v2

* update test_efficientnet_v2

* pass pre-commit hook

* refactor efficientnetv2

* refactor efficientnetv2

* update readme, metafile and weight links

* update model-index.yml

* fix lint

* fix typo

* Update efficientnetv2-b1_8xb32_in1k.py

* Update efficientnetv2-b2_8xb32_in1k.py

* Update efficientnetv2-b3_8xb32_in1k.py

* update two modules and model file

* update module file

* update accuracies

* update accuracies

* update metafile

* fix build docs

* update links

* update README.md

Co-authored-by: qingtian <459291290@qq.com>
Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2022-12-30 15:18:39 +08:00
Ma Zerun 9038c1c255
[Feature] Support TTA and add `--tta` in `tools/test.py`. ()
* [Feature] Support TTA and add `--tta` in `tools/test.py`.

* Add unit tests.

* Rename the TTA model to `AverageClsScoreTTA`.
2022-12-30 11:46:17 +08:00
Colle bac181f393
[Feature] Support Multi-task. ()
* unit test for multi_task_head

* [Feature] MultiTaskHead (, )

* [Fix] lint for multi_task_head

* [Feature] Add `MultiTaskDataset` to support multi-task training.

* Update MultiTaskClsHead

* Update docs

* [CI] Add test mim CI. ()

* [Fix] Remove duplicated wide-resnet metafile.

* [Feature] Support MPS device. ()

* [Feature] Support MPS device.

* Add `auto_select_device`

* Add unit tests

* [Fix] Fix Albu crash bug. ()

* Fix albu bug: using albu changes the label from array(x) to array([x]) and crashes the training

* Fix common

* Use copy in case of a potential bug in multi-label tasks

* Improve coding

* Improve code logic

* Add unit test

* Fix typo

* Fix yapf

* Bump version to 0.23.2. ()

* [Improve] Use `forward_dummy` to calculate FLOPS. ()

* Update README

* [Docs] Fix typo for wrong reference. ()

* [Doc] Fix typo in tutorial 2 ()

* [Docs] Fix a typo in ImageClassifier ()

* add mask to loss

* add another pipeline

* adapt the pipeline if there is no mask

* switch mask and task

* first version of multi data sample

* fix problem with attribute by getattr

* rm img_label suffix, fix 'LabelData' object has no attribute 'gt_label'

* training without evaluation

* first version work

* add others metrics

* delete evaluation from dataset

* fix linter

* fix linter

* multi metrics

* first version of test

* change evaluate metric

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* add tests

* add test for multidatasample

* create a generic test

* create a generic test

* create a generic test

* change multi data sample

* correct test

* test

* add new test

* add test for dataset

* correct test

* correct test

* correct test

* correct test

* fix : 

* run yapf

* fix linter

* fix linter

* fix linter

* fix isort

* fix isort

* fix docformatter

* fix docformatter

* fix linter

* fix linter

* fix data sample

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update data sample

* update head

* update head

* update multi data sample

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* update head

* fix problem when we don't set pred or gt

* fix problem when we don't set pred or gt

* fix problem when we don't set pred or gt

* fix linter

* fix : 

* fix: linter

* update multi head

* fix linter

* fix linter

* update data sample

* update data sample

* fix: linter

* update test

* test pipeline

* update pipeline

* update test

* update dataset

* update dataset

* fix linter

* fix linter

* update formatting

* add test for multi-task-eval

* update formatting

* fix linter

* update test

* update

* add test

* update metrics

* update metrics

* add doc for functions

* fix linter

* training for multitask 1.x

* fix linter

* run flake8

* run linter

* update test

* add mask in evaluation

* update metric doc

* update metric doc

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update metric doc

* update metric doc

* Fix cannot import name MultiTaskDataSample

* fix test_datasets

* fix test_datasets

* fix linter

* add an example of multitask

* change name of configs dataset

* Refactor the multi-task support

* correct test and metric

* add test to multidatasample

* add test to multidatasample

* correct test

* correct metrics and clshead

* Update mmcls/models/heads/cls_head.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update cls_head.py documentation

* lint

* lint

* fix: lint

* fix linter

* add eval mask

* fix documentation

* fix: single_label.py back to 1.x

* Update mmcls/models/heads/multi_task_head.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Remove multi-task configs.

Co-authored-by: mzr1996 <mzr1996@163.com>
Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
Co-authored-by: Ming-Hsuan-Tu <alec.tu@acer.com>
Co-authored-by: Lei Lei <18294546+Crescent-Saturn@users.noreply.github.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: marouaneamz <maroineamil99@gmail.com>
Co-authored-by: marouane amzil <53240092+marouaneamz@users.noreply.github.com>
2022-12-30 10:36:00 +08:00
Rongjie Li 5b266d9e7c
[Feature] Add clip backbone. ()
* clip backbone added

* passed precommit

* update readme

* update according to PR review

* add missing file

* add unittest

* refine metafile

* refine metafile and readme for readdocs

* refine metafile

* refine metafile

* Update metafile

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-12-22 16:33:57 +08:00
Wangbo Zhao(黑色枷锁) 14dcb69092
[Feature] Add mixmim backbone with checkpoints. ()
* add mixmim backbone

* add mixmim inference

* add docstring, metafile, test and modify readme

* Update README and metafile

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-12-20 16:52:54 +08:00
Ma Zerun 7dcf34533d
[Docs] Update TinyViT links. () 2022-12-20 16:22:31 +08:00
Songyang Zhang 5547f4cac4
[Feature] Add TinyViT for dev-1.x. ()
* [Feature] add TinyViT for dev-1.x

* [Feature] update readme

* fix lint error

* refactor the code

* [Update] update the args

* [Update] add unit test and fix bugs

* Rename the configuration file

* delete invalid files

* [Feature] update tinyvit readme

* [Feature] update tinyvit readme

* [Feature] update metafile

* Update tinyvit metafile
2022-12-20 13:04:00 +08:00
takuoko 9e82db6032
[Enhance] Support ConvNeXt More Weights. ()
* convnext more weights

* Update metafile and README

* Fix link

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-12-19 17:17:57 +08:00
Ma Zerun 3006fa26ab
[Fix] Fix CAM visualization. () 2022-12-19 13:54:52 +08:00
Ma Zerun b63515111b
[Reproduce] Update ConvNeXt config files. ()
* Update ConvNeXt training configs.

* Update ConvNeXt network.

* Update metafile and README.

* Update README
2022-12-19 13:54:24 +08:00
Ma Zerun 0e4163668f
[Feature] Add some scripts for development. ()
* [Feature] Add some scripts for development.

* Add `generate_readme.py`.

* Update according to comments
2022-12-19 13:53:13 +08:00
Ma Zerun 6ea59bd846
[Fix] Fix the requirements and lazy register mmcls models. () 2022-12-19 13:01:11 +08:00
Xieyuan Zhang e9f9bb200e
[Fix] Fix imports in transforms. ()
* fix import

* import from mmengine.utils
2022-12-19 13:00:48 +08:00
Ma Zerun 46af7d3ed2
[CI] Update CI to test PyTorch 1.13.0. () 2022-12-14 13:47:32 +08:00
takuoko 2535c1ecd7
[Feature] Support EVA. ()
* add eva

* add eva

* add eva

* sklearn -> scikit-learn

* add large

* Update model names and links.

* Fix resize pos embed error when loading fp16 weight.

* Remove verbose configs.

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-12-14 13:21:33 +08:00
Songyang Zhang 210373c093
[Feature] Implementation of RevViT. ()
* [Feature] implement rev-vit network

* can reproduce the RevViT-Small accuracy 79.9

* update

* [Feature] update revvit

* [Feature] update revvit readme

* Update links

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-12-14 11:46:39 +08:00
Ezra-Yu 1c6b077bb1
[Project] Add ACCV workshop 1st Solution. ()
* add accv workshop 1st project

* update projects

* update projects

* fix lint

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* update

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* update

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* Update projects/fgia_accv2022_1st/README.md

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>

* update

* update

* update

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
2022-12-12 18:55:09 +08:00
Ma Zerun ea53bce580
[Project] Add Example project. () 2022-12-12 17:57:09 +08:00
mzr1996 458ac4c89a Merge remote-tracking branch 'origin/dev-1.x' into 1.x 2022-12-06 18:00:59 +08:00
Ma Zerun 12eca5b94a
Bump version to v1.0.0rc4. ()
* Bump version to v1.0.0rc4

* Update according to comments
2022-12-06 18:00:01 +08:00
kitecats ef3610d962
[Bug] Fix `reparameterize_model.py` doesn't save meta info. ()
* fix reparameterize_model.py doesn't save meta info

* fix error

* Fix symbol link

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-12-06 17:06:31 +08:00
Ma Zerun a4ec2799d6
[Docs] Update install tutorial. ()
* [Docs] Update install tutorial.

* Update collapsed sections.

* Update requirements

* Update install tutorial.
2022-12-06 17:00:32 +08:00
Ma Zerun c127c474b9
[Feature] Support getting model from the name defined in the model-index file. ()
* [Feature] Support getting model from the name defined in the model-index file.

* Add unit tests.

* Prevent import `timm` if the `TIMMBackbone` is not used.

* Fix Windows CI.

* Move `init_model` to `mmcls.apis.hub`, and support pass nn.Module to all
model components.

* Fix requirements

* Rename `hub.py` to `model.py` and add unit tests.
2022-12-06 17:00:22 +08:00
Jiahao Wang d990982fc0
[Docs] Update MobileNetv2 & MobileNetv3 readme. ()
* update mobilenetv2/v3 doc

* update mobilenetv2/v3 doc

* update mobilenetv2/v3 doc
2022-12-06 11:53:03 +08:00
Yixiao Fang df2f122daa
[Fix] Fix dict update in BeIT. ()
* fix dict update

* remove breakpoint

* fix key error

* fix lint

* update

* update
2022-12-05 17:59:36 +08:00
Ma Zerun 7b9a1010f5
[Enhance] Support evaluate on both EMA and non-EMA models. ()
* [Enhance] Support evaluate on both EMA and original models.

* Fix lint
2022-12-05 14:16:12 +08:00
Yixiao Fang d80ec5a4b8
[Refactor] Refactor BEiT backbone and support v1/v2 inference. ()
* refactor beit backbone

* use LinearClsHead

* fix mean and std value

* fix lint

* support inference of beit-v2

* update encoder layer and init

* update

* add ut

* add prepare_relative_position_bias_table function

* add cls_token

* fix lint

* add pos_embed check

* update metafile and readme

* update weights link

* update link of weights

* update metafile

* update

* update docstrings

* update according to review

* rename readme

* update docstring

* fix lint
2022-11-29 12:56:33 +08:00
mzr1996 35fb03a577 Remove wrong files. 2022-11-25 09:54:36 +08:00
Ma Zerun f9be21ab74
[Docs] Add version selection in the banner. ()
* [Docs] Add version selection in the banner.

* Small fix.
2022-11-23 13:38:23 +08:00
mzr1996 75e502ed75 Update Chinese translation. 2022-11-22 15:19:03 +08:00
mzr1996 a4cfd55dd2 Update chinese translation. 2022-11-22 11:39:49 +08:00
mzr1996 44d2886422 [Docs] Add some Chinese translation for API pages. 2022-11-22 09:49:39 +08:00
Ma Zerun 13ff394985
Bump version to v1.0.0rc3. ()
* Bump version to v1.0.0rc3

* Update pre-commit hook
2022-11-21 18:21:48 +08:00
Ezra-Yu b0007812d6
[Enhance] Enhance ArcFaceClsHead. ()
* update arcface

* fix unit tests

* add adv-margins

add adv-margins

update arcface

* rebase

* update doc and fix ut

* rebase

* update code

* rebase

* use label data

* update set-margins

* Modify Arcface related method names.

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-21 18:10:39 +08:00
takuoko 4fb44f8770
[Feature] EfficientNets NoisyStudent & L2. ()
* add mvit 21k

* add mvit 21k

* add effnet noisy student

* Revert "add mvit 21k"

This reverts commit f51067c559.

* revert mvit pr

* update link and readme

* update readme

* update l2 link

* update link

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2022-11-21 16:57:13 +08:00
marouane amzil 743ca2d602
[Fix] Fix the torchserve. ()
* rebase

* update docker and rm deprecated deployment tools

* update docs

* rebase

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2022-11-21 11:18:42 +08:00
Ma Zerun 940a06f645
[Refactor] Refactor to use new fileio API in MMEngine. ()
* [Refactor] Refactor to use new fileio API in MMEngine.

* Add comment about why use `backend`
2022-11-21 10:56:35 +08:00
Ezra-Yu 4969830c8a
[Enhance] Reproduce mobileone training accuracy. ()
* add switch hook and UTs

* update doc

* update doc

* fix lint

* fix ci

* fix ci

* fix typo

* fix ci

* update configs names

* update configs

* update configs

* update links

* update readme

* update vis_scheduler

* update metafile

* update configs

* rebase

* fix ci

* rebase
2022-11-21 10:43:34 +08:00
Hubert 629f6447ef
[Feature] Migrate CSRA head to 1.x. ()
* [Feat] add csra to 1x

* minor fix

* add voc metrics

* refine

* add unittest

* minor fix

* add more comments

* Fix docs and metafile.

* Fix docs.

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-21 10:39:16 +08:00
Ma Zerun 0e8cfa6286
[Docs] Add not-found page extension. ()
* [Docs] Add not-found page extension.

* Mock rich during generate docs.

* Fix multiple broken links in docs.

* Fix "right" to "left".
2022-11-21 10:34:05 +08:00
Jiahao Wang 72c6bc4864
[Feature] Support RepLKnet backbone. ()
* update replknet configs

* update replknet test

* update replknet model

* update replknet model

* update replknet model

* update replknet model

* Fix docs and config names

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-21 10:18:58 +08:00
Ezra-Yu c3c1cb93aa
[Feature] Add Switch Recipe Hook. ()
* add switch hook and UTs

* update doc

* update doc

* fix lint

* fix ci

* fix ci

* fix typo

* fix ci

* switchTrainAugHook to switchRecipeHook

* fix lint

* Refactor the `SwitchRecipeHook`.

* Fix windows CI

* Fix windows CI

* Fix windows CI.

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-18 18:12:03 +08:00
Songyang Zhang f458bf5a64
[Docs] update visualization doc. ()
* [Docs] update visualization doc

* update doc

* update folder

* update analysis

* Update print config tool

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-18 12:44:55 +08:00
XingyuXie e51ecfb129
[Feature] Add adan optimizer. ()
* add adan optimizer

* update adan optimizer

* update adan optimizer

* update adan optimizer

* update adan optimizer

* update adan optimizer

* update adan optimizer

* update adan optimizer

* update adan optimizer

* update adan optimizer

* update init
2022-11-17 08:11:25 +08:00
takuoko c4f3883a22
[Feature] Support DaViT. ()
* add davit

* fix mixup config

* convert scripts

* lint

* test

* test

* Add checkpoint links.

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-16 17:23:55 +08:00
Hubert 992d13e772
[Enhance] add deleting params info in swinv2. () 2022-11-14 17:07:21 +08:00
Hubert 1b98fc13d9
[Enhance] Add more mobilenetv3 pretrains. ()
* add small-050/075 and move files

* add previous results

* Update checkpoint link

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-14 17:06:39 +08:00
Hakjin Lee cf5879988d
[Feature] Support Activation Checkpointing for ConvNeXt. ()
* Support Activation Checkpointing for ConvNeXt

* Add docstring.

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-14 15:04:28 +08:00
Yang Gao 11cd88f39a
[Fix] Fix configs due to api refactor of `num_classes`. () 2022-11-14 14:43:56 +08:00
Austin Welch ee9ee9cf3c
[Fix] Update mmcls2torchserve. () 2022-11-14 14:17:11 +08:00
Ma Zerun 542143cb41
[Feature] Add TIMM and HuggingFace wrappers to build classifiers from them directly. ()
* [Feature] Add TIMM and HuggingFace wrappers to build classifiers from them directly.

* Support `with_cp` and add docstring.

* Add unit tests.

* Update CI.

* Update docs.
2022-11-10 14:56:19 +08:00
Ma Zerun 2151beeb77
[Docs] Support sort and search the Model Summary table. ()
* [Docs] Support sort and search the Model Summary table.

* Add description.

* Fix according to comments
2022-11-08 12:03:06 +08:00
Ma Zerun c48cfa9f47
[Docs] Improve the ResNet model page. ()
* [Docs] Improve the ResNet model page.

* Fix lint.

* Update stat.py
2022-11-08 11:09:51 +08:00
Songyang Zhang 28b71c15bd
[Docs] update the readme of convnext. ()
* [Doc] update the readme of convnext

* update
2022-11-04 17:18:08 +08:00
zzc98 9eb6fc4368
[Feature] Add reduction for neck ()
* feat: add reduction for neck

* feat: add reduction for neck

* feat: add reduction for neck

* feat: add linear reduction neck

* feat: add reduction neck

* modify output of LinearReduction to a tuple

* fix typo

* fix unit tests

* fix unit tests

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2022-11-04 15:44:37 +08:00
takuoko 8cc1fdef52
[Enhancement] RepVGG for YOLOX-PAI for dev-1.x. () 2022-11-04 15:36:18 +08:00
takuoko d05cbbcf9b
[Feature] Support HorNet Backbone for dev1.x. ()
* add hornet

* add hornet

* fix mixup config

* add optim cfgs

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2022-11-04 15:33:46 +08:00
Sanbu b16938dc59
[Docs] Fix the installation docs link in README. () 2022-11-04 15:26:47 +08:00
Hubert 6203fd6cc9
[Docs] Improve ViT and MobileViT model pages. ()
* [Docs] Improve the ViT model page

* [Docs] Improve the MobileViT model page

* fix
2022-11-04 14:53:26 +08:00
Ezra-Yu 63b124e2d7
[Docs] Improve Swin Doc and Add Tabs extension. ()
* improve_swin_doc

* fix requirements

* improve swin2 docs

* improve swin2 docs

* update

* update CN doc

* update CN doc

* update comments

* fix error

* update register_all_modules

* Update README.md

* Update stat.py

* Update readme

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-04 11:58:39 +08:00
kitecats ef512c98d0
[Docs] Add MMEval projects link in README. ()
* Update README.md

* Update README_zh-CN.md
2022-11-04 10:41:58 +08:00
zzc98 9506241f73
[Feature] Add arcface head. ()
* feat: add arcface head

* feat: add arcface head

* update arcface

* refactor arcface head

* update arcface head

* update arcface head

* use torch.cat instead of torch.hstack to fix ci

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
2022-11-02 17:45:33 +08:00
zzc98 693596bc2f
[Feature] Add Base Retriever and Image2Image Retriever for retrieval tasks. ()
* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* update retriever

* fix lint

* add hook unit test

* Use `register_buffer` to save prototype vectors and add a progress bar
while preparing prototypes.

* update UTs

* update UTs

* fix typo

* modify the hook

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-02 17:43:56 +08:00
Ma Zerun 29c46c8af2
[Docs] Add runtime configuration docs. ()
* [Docs] Add runtime configuration docs.

* Fix grammar errors.

* Improve docs according to comments
2022-11-02 10:59:59 +08:00
Hubert 50aaa711ea
[Docs] Add custom evaluation docs ()
* [Docs] Add evaluation docs

* minor fix

* Fix docs.

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-01 18:54:06 +08:00
Ma Zerun 280e916979
[Docs] Add custom pipeline docs. ()
* [Docs] Add custom pipeline docs.

* Fix link.

* Fix according to comments
2022-10-27 10:35:20 +08:00
kitecats cccbedf22d
[Docs] Add MMYOLO projects link in MMCLS1.x ()
* Fix for MMCLS 1.x not being able to get class information from the checkpoint during inference

Let MMCLS 1.x read class information from the checkpoint during inference instead of initializing with ImageNet classes

* Update inference.py

* Update README.md

* Update README_zh-CN.md
2022-10-21 16:25:20 +08:00
mzr1996 b526f018db [Fix] Fix broken inference api because of the modification in data
preprocessor.
2022-10-18 17:50:35 +08:00
Hubert bcca619066
[Feature] Support MobileViT backbone. ()
* init

* fix

* add config

* add meta

* add unittest

* fix for comments

* Improve docstring and support custom arch.

* Update README

* Update windows CI

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-10-18 17:05:59 +08:00
Ma Zerun 29f066f7fb
[Improve] Speed up data preprocessor. ()
* [Improve] Speed up data preprocessor.

* Add ClsDataSample serialization override functions.

* Add unit tests

* Modify configs to fit new mixup args.

* Fix `num_classes` of the ImageNet-21k config.

* Update docs.
2022-10-17 17:08:18 +08:00
kitecats 06c919efc2
[Fix] Fix for `inference_model` cannot get classes information in checkpoint. ()
* Fix for MMCLS 1.x not being able to get class information from the checkpoint during inference

Let MMCLS 1.x read class information from the checkpoint during inference instead of initializing with ImageNet classes

* Update inference.py
2022-10-14 08:27:01 +08:00
Ma Zerun 31c67ffed4
Bump version to v1.0.0rc2. () 2022-10-12 16:52:31 +08:00
Ma Zerun dfe0874102
Update requirements. () 2022-10-12 10:57:37 +08:00
takuoko 9bc58745d1
[Enhance] Update analyze_results.py for dev-1.x. ()
* update analyze_results

* lint

* add --rescale-factor and fix filename logic

* lint
2022-10-11 11:34:23 +08:00
takuoko a49c3076e1
[Feature] Support DeiT3. ()
* deit3

deit3

lint

* add tools and test

* deit3

* deit3

* fix preprocess

* lint

* Update config names and checkpoint paths

* Update convert tools to use mmengine, and fix docstring.

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-10-10 14:54:20 +08:00
mzr1996 043574cbb2 Merge branch '1.x' into dev-1.x 2022-10-10 11:18:05 +08:00
mzr1996 2153a16dc5 [Docs] Fix some docstrings. 2022-10-10 11:17:48 +08:00
Hubert dfb4e87123
[Docs] Add 1x docs schedule. ()
* [Docs] add schedule en docstring

* [Docs] add schedule cn docstring

* Improve schedule docs.

* refine according to comments

Co-authored-by: mzr1996 <mzr1996@163.com>
2022-10-09 10:39:53 +08:00
mzr1996 f452e242a7 [Docs] Fix `mmcv-full` in docs. 2022-10-09 08:15:15 +08:00
mzr1996 bf9f3bbdda [Docs] Fix `mmcv-full` in docs. 2022-10-09 08:14:06 +08:00
mzr1996 b9bb21738b [CI] Skip timm unit tests with the minimum version PyTorch. 2022-10-08 16:07:52 +08:00
takuoko a1642e42da
[Enhancement] Get scores from inference api. () 2022-10-08 15:21:34 +08:00
mzr1996 ae37d7fd27 [Fix] Fix visualization hook. 2022-10-08 11:29:18 +08:00
mzr1996 23cad6a0e1 [Enhance] Support `--batch-size` option in the validation benchmark tool. 2022-10-08 11:14:35 +08:00
1723 changed files with 142305 additions and 13314 deletions

View File

@ -22,11 +22,11 @@ workflows:
# line:
# <regex path-to-test> <parameter-to-set> <value-of-pipeline-parameter>
mapping: |
mmcls/.* lint_only false
mmpretrain/.* lint_only false
requirements/.* lint_only false
tests/.* lint_only false
.circleci/.* lint_only false
base-revision: dev-1.x
base-revision: main
# this is the path of the configuration we should trigger once
# path filtering and pipeline parameter value updates are
# complete. In this case, we are using the parent dynamic
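
For context, the `mapping` block above drives CircleCI's path-filtering orb: each line pairs a path regex with a pipeline parameter to set whenever a changed file matches. A minimal Python sketch of that matching logic (the rules are copied from the diff; `resolve_parameters` and its exact semantics are assumptions for illustration, not part of the repo):

```python
import re

# Rules copied from the `mapping` block above: <path-regex> <parameter> <value>.
MAPPING = [
    (r'mmpretrain/.*', 'lint_only', 'false'),
    (r'requirements/.*', 'lint_only', 'false'),
    (r'tests/.*', 'lint_only', 'false'),
    (r'.circleci/.*', 'lint_only', 'false'),
]

def resolve_parameters(changed_files):
    """Collect the pipeline parameters set by matching rules."""
    params = {}
    for path in changed_files:
        for pattern, name, value in MAPPING:
            if re.match(pattern, path):
                params[name] = value
    return params

# A docs-only change sets nothing, so `lint_only` keeps its default and the
# heavy test jobs are skipped; touching package code flips it to 'false'.
print(resolve_parameters(['docs/en/index.md']))       # {}
print(resolve_parameters(['mmpretrain/version.py']))  # {'lint_only': 'false'}
```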

View File

@ -31,7 +31,58 @@ jobs:
name: Check docstring coverage
command: |
pip install interrogate
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmcls
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmpretrain
build_cpu_with_3rdparty:
parameters:
# The python version must match available image tags in
# https://circleci.com/developer/images/image/cimg/python
python:
type: string
torch:
type: string
torchvision:
type: string
docker:
- image: cimg/python:<< parameters.python >>
resource_class: large
steps:
- checkout
- run:
name: Install Libraries
command: |
sudo apt-get update
sudo apt-get install -y libjpeg8-dev zlib1g-dev
- run:
name: Configure Python & pip
command: |
pip install --upgrade pip
pip install wheel
- run:
name: Install PyTorch
command: |
python -V
pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
- run:
name: Install mmpretrain dependencies
command: |
pip install git+https://github.com/open-mmlab/mmengine.git@main
pip install -U openmim
mim install 'mmcv >= 2.0.0rc4'
pip install timm
pip install transformers
pip install -r requirements.txt
python -c 'import mmcv; print(mmcv.__version__)'
- run:
name: Build and install
command: |
pip install -e .
- run:
name: Run unittests
command: |
coverage run --branch --source mmpretrain -m pytest tests/
coverage xml
coverage report -m
build_cpu:
parameters:
# The python version must match available image tags in
@ -63,12 +114,11 @@ jobs:
python -V
pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
- run:
name: Install mmcls dependencies
name: Install mmpretrain dependencies
command: |
pip install git+https://github.com/open-mmlab/mmengine.git@main
pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
pip install timm
mim install 'mmcv >= 2.0.0rc4'
pip install -r requirements.txt
python -c 'import mmcv; print(mmcv.__version__)'
- run:
@ -78,7 +128,7 @@ jobs:
- run:
name: Run unittests
command: |
coverage run --branch --source mmcls -m pytest tests/
coverage run --branch --source mmpretrain -m pytest tests/
coverage xml
coverage report -m
@ -86,15 +136,17 @@ jobs:
machine:
image: ubuntu-2004-cuda-11.4:202110-01
resource_class: gpu.nvidia.small
environment:
MKL_SERVICE_FORCE_INTEL: 1
parameters:
torch:
type: string
cuda:
type: enum
enum: ["10.1", "10.2", "11.1"]
enum: ["11.1", "11.7"]
cudnn:
type: integer
default: 7
default: 8
steps:
- checkout
- run:
@ -105,24 +157,24 @@ jobs:
- run:
name: Build Docker image
command: |
docker build .circleci/docker -t mmcls:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >>
docker run --gpus all -t -d -v /home/circleci/project:/mmcls -v /home/circleci/mmengine:/mmengine -w /mmcls --name mmcls mmcls:gpu
docker build .circleci/docker -t mmpretrain:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >>
docker run --gpus all -t -d -v /home/circleci/project:/mmpretrain -v /home/circleci/mmengine:/mmengine -w /mmpretrain --name mmpretrain mmpretrain:gpu
- run:
name: Install mmcls dependencies
name: Install mmpretrain dependencies
command: |
docker exec mmcls pip install -e /mmengine
docker exec mmcls pip install -U openmim
docker exec mmcls mim install 'mmcv >= 2.0.0rc1'
docker exec mmcls pip install -r requirements.txt
docker exec mmcls python -c 'import mmcv; print(mmcv.__version__)'
docker exec mmpretrain pip install -e /mmengine
docker exec mmpretrain pip install -U openmim
docker exec mmpretrain mim install 'mmcv >= 2.0.0rc4'
docker exec mmpretrain pip install -r requirements.txt
docker exec mmpretrain python -c 'import mmcv; print(mmcv.__version__)'
- run:
name: Build and install
command: |
docker exec mmcls pip install -e .
docker exec mmpretrain pip install -e .
- run:
name: Run unittests
command: |
docker exec mmcls python -m pytest tests/ -k 'not timm'
docker exec mmpretrain python -m pytest tests/
# Invoke jobs via workflows
# See: https://circleci.com/docs/2.0/configuration-reference/#workflows
@ -135,8 +187,8 @@ workflows:
filters:
branches:
ignore:
- dev-1.x
- 1.x
- dev
- main
pr_stage_test:
when:
not:
@ -147,19 +199,19 @@ workflows:
filters:
branches:
ignore:
- dev-1.x
- dev
- build_cpu:
name: minimum_version_cpu
torch: 1.6.0
torchvision: 0.7.0
python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images
torch: 1.8.0
torchvision: 0.9.0
python: 3.7.16
requires:
- lint
- build_cpu:
- build_cpu_with_3rdparty:
name: maximum_version_cpu
torch: 1.12.1
torchvision: 0.13.1
python: 3.9.0
torch: 2.0.0
torchvision: 0.15.1
python: 3.10.0
requires:
- minimum_version_cpu
- hold:
@ -171,7 +223,14 @@ workflows:
torch: 1.8.1
# Use double quotation mark to explicitly specify its type
# as string instead of number
cuda: "10.2"
cuda: "11.1"
requires:
- hold
- build_cuda:
name: maximum_version_gpu
torch: 2.0.0
cuda: "11.7"
cudnn: 8
requires:
- hold
merge_stage_test:
@ -181,11 +240,11 @@ workflows:
jobs:
- build_cuda:
name: minimum_version_gpu
torch: 1.6.0
torch: 1.8.0
# Use double quotation mark to explicitly specify its type
# as string instead of number
cuda: "10.1"
cuda: "11.1"
filters:
branches:
only:
- dev-1.x
- pretrain
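
The version-matrix jobs above pin CPU-only wheels straight from the PyTorch wheel index. A small sketch of the install command those `torch` / `torchvision` parameters expand to (`torch_install_cmd` is an illustrative helper; the version pairs are the ones from the workflow above):

```python
def torch_install_cmd(torch: str, torchvision: str) -> str:
    # Mirrors the "Install PyTorch" step above: CPU-only wheels resolved
    # from the stable PyTorch wheel index.
    return (f'pip install torch=={torch}+cpu torchvision=={torchvision}+cpu '
            '-f https://download.pytorch.org/whl/torch_stable.html')

# The minimum/maximum combinations exercised by pr_stage_test:
print(torch_install_cmd('1.8.0', '0.9.0'))
print(torch_install_cmd('2.0.0', '0.15.1'))
```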

View File

@ -1,25 +1,27 @@
import logging
import re
import sys
import tempfile
from argparse import ArgumentParser
from collections import OrderedDict
from pathlib import Path
from time import time
from time import perf_counter
from unittest.mock import Mock
import mmcv
import numpy as np
import torch
from mmengine import Config, DictAction, MMLogger
from mmengine import DictAction, MMLogger
from mmengine.dataset import Compose, default_collate
from mmengine.fileio import FileClient
from mmengine.runner import Runner
from modelindex.load_model_index import load
from mmengine.device import get_device
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.runner import Runner, load_checkpoint
from rich.console import Console
from rich.table import Table
from utils import substitute_weights
from mmcls.datasets import CIFAR10, CIFAR100, ImageNet
from mmcls.utils import register_all_modules
from mmcls.visualization import ClsVisualizer
from mmpretrain.apis import ModelHub, get_model, list_models
from mmpretrain.datasets import CIFAR10, CIFAR100, ImageNet
from mmpretrain.utils import register_all_modules
from mmpretrain.visualization import UniversalVisualizer
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]
@ -30,6 +32,12 @@ classes_map = {
'CIFAR-100': CIFAR100.CLASSES,
}
logger = MMLogger.get_instance('validation', logger_name='mmpretrain')
logger.handlers[0].stream = sys.stderr
logger.addHandler(logging.FileHandler('benchmark_valid.log', mode='w'))
# Force to use the logger in runners.
Runner.build_logger = Mock(return_value=logger)
def parse_args():
parser = ArgumentParser(description='Valid all models in model-index.yml')
@ -48,6 +56,11 @@ def parse_args():
'--inference-time',
action='store_true',
help='Test inference time by run 10 times for each model.')
parser.add_argument(
'--batch-size',
type=int,
default=1,
help='The batch size during the inference.')
parser.add_argument(
'--flops', action='store_true', help='Get Flops and Params of models')
parser.add_argument(
@ -68,65 +81,76 @@ def parse_args():
return args
def inference(config_file, checkpoint, work_dir, args, exp_name):
cfg = Config.fromfile(config_file)
def inference(metainfo, checkpoint, work_dir, args, exp_name=None):
cfg = metainfo.config
cfg.work_dir = work_dir
cfg.load_from = checkpoint
cfg.log_level = 'WARN'
cfg.experiment_name = exp_name
cfg.experiment_name = exp_name or metainfo.name
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# build the data pipeline
test_dataset = cfg.test_dataloader.dataset
if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
if test_dataset.type in ['CIFAR10', 'CIFAR100']:
# The image shape of CIFAR is (32, 32, 3)
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
if 'test_dataloader' in cfg:
# build the data pipeline
test_dataset = cfg.test_dataloader.dataset
if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
if test_dataset.type in ['CIFAR10', 'CIFAR100']:
# The image shape of CIFAR is (32, 32, 3)
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
data = Compose(test_dataset.pipeline)({'img_path': args.img})
data = default_collate([data])
resolution = tuple(data['inputs'].shape[-2:])
data = Compose(test_dataset.pipeline)({'img_path': args.img})
data = default_collate([data] * args.batch_size)
resolution = tuple(data['inputs'].shape[-2:])
model = Runner.from_cfg(cfg).model
model = revert_sync_batchnorm(model)
model.eval()
forward = model.val_step
else:
# For configs without data settings.
model = get_model(cfg, device=get_device())
model = revert_sync_batchnorm(model)
model.eval()
data = torch.rand(1, 3, 224, 224).to(model.data_preprocessor.device)
resolution = (224, 224)
forward = model.extract_feat
runner: Runner = Runner.from_cfg(cfg)
model = runner.model
if checkpoint is not None:
load_checkpoint(model, checkpoint, map_location='cpu')
# forward the model
result = {'resolution': resolution}
result = {'model': metainfo.name, 'resolution': resolution}
with torch.no_grad():
if args.inference_time:
time_record = []
forward(data) # warmup before profiling
for _ in range(10):
model.val_step(data) # warmup before profiling
torch.cuda.synchronize()
start = time()
model.val_step(data)
start = perf_counter()
forward(data)
torch.cuda.synchronize()
time_record.append((time() - start) * 1000)
time_record.append(
(perf_counter() - start) / args.batch_size * 1000)
result['time_mean'] = np.mean(time_record[1:-1])
result['time_std'] = np.std(time_record[1:-1])
else:
model.val_step(data)
result['model'] = config_file.stem
forward(data)
if args.flops:
from fvcore.nn import FlopCountAnalysis, parameter_count
from fvcore.nn.print_model_statistics import _format_size
from mmengine.analysis import FlopAnalyzer, parameter_count
from mmengine.analysis.print_helper import _format_size
_format_size = _format_size if args.flops_str else lambda x: x
with torch.no_grad():
if hasattr(model, 'extract_feat'):
model.forward = model.extract_feat
model.to('cpu')
inputs = (torch.randn((1, 3, *resolution)), )
flops = _format_size(FlopCountAnalysis(model, inputs).total())
params = _format_size(parameter_count(model)[''])
result['flops'] = flops if args.flops_str else int(flops)
result['params'] = params if args.flops_str else int(params)
else:
result['flops'] = ''
result['params'] = ''
model.forward = model.extract_feat
model.to('cpu')
inputs = (torch.randn((1, 3, *resolution)), )
analyzer = FlopAnalyzer(model, inputs)
# extract_feat only includes backbone
analyzer._enable_warn_uncalled_mods = False
flops = _format_size(analyzer.total())
params = _format_size(parameter_count(model)[''])
result['flops'] = flops if args.flops_str else int(flops)
result['params'] = params if args.flops_str else int(params)
return result
@ -135,17 +159,17 @@ def show_summary(summary_data, args):
table = Table(title='Validation Benchmark Regression Summary')
table.add_column('Model')
table.add_column('Validation')
table.add_column('Resolution (h, w)')
table.add_column('Resolution (h w)')
if args.inference_time:
table.add_column('Inference Time (std) (ms/im)')
if args.flops:
table.add_column('Flops', justify='right')
table.add_column('Params', justify='right')
table.add_column('Flops', justify='right', width=13)
table.add_column('Params', justify='right', width=11)
for model_name, summary in summary_data.items():
row = [model_name]
valid = summary['valid']
color = 'green' if valid == 'PASS' else 'red'
color = {'PASS': 'green', 'CUDA OOM': 'yellow'}.get(valid, 'red')
row.append(f'[{color}]{valid}[/{color}]')
if valid == 'PASS':
row.append(str(summary['resolution']))
@ -158,84 +182,55 @@ def show_summary(summary_data, args):
row.append(str(summary['params']))
table.add_row(*row)
console.print(table)
# Sample test whether the inference code is correct
def main(args):
register_all_modules()
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})
logger = MMLogger(
'validation',
logger_name='validation',
log_file='benchmark_test_image.log',
log_level=logging.INFO)
if args.models:
patterns = [re.compile(pattern) for pattern in args.models]
filter_models = {}
for k, v in models.items():
if any([re.match(pattern, k) for pattern in patterns]):
filter_models[k] = v
if len(filter_models) == 0:
models = set()
for pattern in args.models:
models.update(list_models(pattern=pattern))
if len(models) == 0:
print('No model found, please specify models in:')
print('\n'.join(models.keys()))
print('\n'.join(list_models()))
return
models = filter_models
else:
models = list_models()
summary_data = {}
tmpdir = tempfile.TemporaryDirectory()
for model_name, model_info in models.items():
for model_name in models:
model_info = ModelHub.get(model_name)
if model_info.config is None:
continue
config = Path(model_info.config)
assert config.exists(), f'{model_name}: {config} not found.'
logger.info(f'Processing: {model_name}')
http_prefix = 'https://download.openmmlab.com/mmclassification/'
if args.checkpoint_root is not None:
root = args.checkpoint_root
if 's3://' in args.checkpoint_root:
from petrel_client.common.exception import AccessDeniedError
file_client = FileClient.infer_client(uri=root)
checkpoint = file_client.join_path(
root, model_info.weights[len(http_prefix):])
try:
exists = file_client.exists(checkpoint)
except AccessDeniedError:
exists = False
else:
checkpoint = Path(root) / model_info.weights[len(http_prefix):]
exists = checkpoint.exists()
if exists:
checkpoint = str(checkpoint)
else:
print(f'WARNING: {model_name}: {checkpoint} not found.')
checkpoint = None
weights = model_info.weights
if args.checkpoint_root is not None and weights is not None:
checkpoint = substitute_weights(weights, args.checkpoint_root)
else:
checkpoint = None
try:
# build the model from a config file and a checkpoint file
result = inference(MMCLS_ROOT / config, checkpoint, tmpdir.name,
args, model_name)
result = inference(model_info, checkpoint, tmpdir.name, args)
result['valid'] = 'PASS'
except Exception:
import traceback
logger.error(f'"{config}" :\n{traceback.format_exc()}')
result = {'valid': 'FAIL'}
except Exception as e:
if 'CUDA out of memory' in str(e):
logger.error(f'"{model_name}" :\nCUDA out of memory')
result = {'valid': 'CUDA OOM'}
else:
import traceback
logger.error(f'"{model_name}" :\n{traceback.format_exc()}')
result = {'valid': 'FAIL'}
summary_data[model_name] = result
# show the results
if args.show:
vis = ClsVisualizer.get_instance('valid')
vis = UniversalVisualizer.get_instance('valid')
vis.set_image(mmcv.imread(args.img))
vis.draw_texts(
texts='\n'.join([f'{k}: {v}' for k, v in result.items()]),
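
The benchmark timing above changes in three subtle ways: `time()` becomes `perf_counter()`, a warmup forward runs before profiling, and each measurement is synchronized and normalized by `--batch-size`. A self-contained sketch of that pattern, assuming a CUDA device and a collated `data` batch like the one built above (`benchmark` is an illustrative helper, not a function in the repo):

```python
from time import perf_counter

import numpy as np
import torch

def benchmark(forward, data, batch_size, repeats=10):
    """Per-sample latency in ms, following the timing pattern above."""
    with torch.no_grad():
        forward(data)  # warmup before profiling
        times = []
        for _ in range(repeats):
            torch.cuda.synchronize()  # flush queued kernels before timing
            start = perf_counter()
            forward(data)
            torch.cuda.synchronize()  # wait until this forward really ends
            times.append((perf_counter() - start) / batch_size * 1000)
    # Drop the first and last runs, as the script does with time_record[1:-1].
    return np.mean(times[1:-1]), np.std(times[1:-1])
```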

View File

@ -1,9 +1,10 @@
import argparse
import fnmatch
import logging
import os
import os.path as osp
import pickle
import re
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from datetime import datetime
from pathlib import Path
@ -11,57 +12,57 @@ from modelindex.load_model_index import load
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table
from utils import METRICS_MAP, MMCLS_ROOT, substitute_weights
# Avoid importing MMPretrain to speed up showing the summary
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]
METRICS_MAP = {
'Top 1 Accuracy': 'accuracy/top1',
'Top 5 Accuracy': 'accuracy/top5'
}
logger = logging.getLogger('test')
logger.addHandler(logging.StreamHandler())
logger.addHandler(logging.FileHandler('benchmark_test.log', mode='w'))
def parse_args():
parser = argparse.ArgumentParser(
description="Test all models' accuracy in model-index.yml")
parser.add_argument(
'partition', type=str, help='Cluster partition to use.')
parser.add_argument('checkpoint_root', help='Checkpoint file root path.')
parser.add_argument(
'--job-name',
type=str,
default='cls-test-benchmark',
help='Slurm job name prefix')
parser.add_argument('--port', type=int, default=29666, help='dist port')
'--local', action='store_true', help='run at local instead of slurm.')
parser.add_argument(
'--models', nargs='+', type=str, help='Specify model names to run.')
parser.add_argument(
'--work-dir',
default='work_dirs/benchmark_test',
help='the dir to save metric')
parser.add_argument(
'--run', action='store_true', help='run script directly')
parser.add_argument(
'--local',
action='store_true',
help='run at local instead of cluster.')
parser.add_argument(
'--mail', type=str, help='Mail address to watch test status.')
parser.add_argument(
'--mail-type',
nargs='+',
default=['BEGIN'],
choices=['NONE', 'BEGIN', 'END', 'FAIL', 'REQUEUE', 'ALL'],
help='Mail address to watch test status.')
parser.add_argument(
'--quotatype',
default=None,
choices=['reserved', 'auto', 'spot'],
help='Quota type, only available for phoenix-slurm>=0.2')
parser.add_argument(
'--summary',
action='store_true',
help='Summarize benchmark test results.')
parser.add_argument('--save', action='store_true', help='Save the summary')
parser.add_argument(
'--gpus', type=int, default=1, help='How many GPUS to use.')
parser.add_argument(
'--no-skip',
action='store_true',
help='Whether to skip models without results record in the metafile.')
parser.add_argument(
'--work-dir',
default='work_dirs/benchmark_test',
help='the dir to save metric')
parser.add_argument('--port', type=int, default=29666, help='dist port')
parser.add_argument(
'--partition',
type=str,
default='mm_model',
help='(for slurm) Cluster partition to use.')
parser.add_argument(
'--job-name',
type=str,
default='cls-test-benchmark',
help='(for slurm) Slurm job name prefix')
parser.add_argument(
'--quotatype',
default=None,
choices=['reserved', 'auto', 'spot'],
help='(for slurm) Quota type, only available for phoenix-slurm>=0.2')
parser.add_argument(
'--cfg-options',
nargs='+',
@ -74,64 +75,53 @@ def parse_args():
def create_test_job_batch(commands, model_info, args, port, script_name):
fname = model_info.name
model_name = model_info.name
config = Path(model_info.config)
assert config.exists(), f'{fname}: {config} not found.'
http_prefix = 'https://download.openmmlab.com/mmclassification/'
if 's3://' in args.checkpoint_root:
from mmengine.fileio import FileClient
from petrel_client.common.exception import AccessDeniedError
file_client = FileClient.infer_client(uri=args.checkpoint_root)
checkpoint = file_client.join_path(
args.checkpoint_root, model_info.weights[len(http_prefix):])
try:
exists = file_client.exists(checkpoint)
except AccessDeniedError:
exists = False
if model_info.weights is not None:
checkpoint = substitute_weights(model_info.weights,
args.checkpoint_root)
if checkpoint is None:
logger.warning(f'{model_name}: {checkpoint} not found.')
return None
else:
checkpoint_root = Path(args.checkpoint_root)
checkpoint = checkpoint_root / model_info.weights[len(http_prefix):]
exists = checkpoint.exists()
if not exists:
print(f'WARNING: {fname}: {checkpoint} not found.')
return None
job_name = f'{args.job_name}_{fname}'
work_dir = Path(args.work_dir) / fname
job_name = f'{args.job_name}_{model_name}'
work_dir = Path(args.work_dir) / model_name
work_dir.mkdir(parents=True, exist_ok=True)
result_file = work_dir / 'result.pkl'
if args.mail is not None and 'NONE' not in args.mail_type:
mail_cfg = (f'#SBATCH --mail {args.mail}\n'
f'#SBATCH --mail-type {args.mail_type}\n')
else:
mail_cfg = ''
if args.quotatype is not None:
quota_cfg = f'#SBATCH --quotatype {args.quotatype}\n'
quota_cfg = f'#SBATCH --quotatype {args.quotatype}'
else:
quota_cfg = ''
launcher = 'none' if args.local else 'slurm'
runner = 'python' if args.local else 'srun python'
if not args.local:
launcher = 'slurm'
runner = 'srun python'
elif args.gpus > 1:
launcher = 'pytorch'
runner = ('torchrun --master_addr="127.0.0.1" '
f'--master_port={port} --nproc_per_node={args.gpus}')
else:
launcher = 'none'
runner = 'python -u'
job_script = (f'#!/bin/bash\n'
f'#SBATCH --output {work_dir}/job.%j.out\n'
f'#SBATCH --partition={args.partition}\n'
f'#SBATCH --job-name {job_name}\n'
f'#SBATCH --gres=gpu:8\n'
f'{mail_cfg}{quota_cfg}'
f'#SBATCH --ntasks-per-node=8\n'
f'#SBATCH --ntasks=8\n'
f'#SBATCH --gres=gpu:{min(8, args.gpus)}\n'
f'{quota_cfg}\n'
f'#SBATCH --ntasks-per-node={min(8, args.gpus)}\n'
f'#SBATCH --ntasks={args.gpus}\n'
f'#SBATCH --cpus-per-task=5\n\n'
f'{runner} -u {script_name} {config} {checkpoint} '
f'--work-dir={work_dir} '
f'--out={result_file} '
f'--cfg-option dist_params.port={port} '
f'{runner} {script_name} {config} {checkpoint} '
f'--work-dir={work_dir} --cfg-option '
f'env_cfg.dist_cfg.port={port} '
f'{" ".join(args.cfg_options)} '
f'--out={result_file} --out-item="metrics" '
f'--launcher={launcher}\n')
with open(work_dir / 'job.sh', 'w') as f:
@ -146,33 +136,17 @@ def create_test_job_batch(commands, model_info, args, port, script_name):
return work_dir / 'job.sh'
def test(args):
# parse model-index.yml
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})
def test(models, args):
script_name = osp.join('tools', 'test.py')
port = args.port
commands = []
if args.models:
patterns = [re.compile(pattern) for pattern in args.models]
filter_models = {}
for k, v in models.items():
if any([re.match(pattern, k) for pattern in patterns]):
filter_models[k] = v
if len(filter_models) == 0:
print('No model found, please specify models in:')
print('\n'.join(models.keys()))
return
models = filter_models
preview_script = ''
for model_info in models.values():
if model_info.results is None:
# Skip pre-train model
continue
script_path = create_test_job_batch(commands, model_info, args, port,
@ -205,44 +179,41 @@ def test(args):
console.print('Please set "--run" to start the job')
def save_summary(summary_data, models_map, work_dir):
summary_path = work_dir / 'test_benchmark_summary.md'
def save_summary(summary_data, work_dir):
summary_path = work_dir / 'test_benchmark_summary.csv'
file = open(summary_path, 'w')
headers = [
'Model', 'Top-1 Expected(%)', 'Top-1 (%)', 'Top-5 Expected (%)',
'Top-5 (%)', 'Config'
]
file.write('# Test Benchmark Regression Summary\n')
file.write('| ' + ' | '.join(headers) + ' |\n')
file.write('|:' + ':|:'.join(['---'] * len(headers)) + ':|\n')
columns = defaultdict(list)
for model_name, summary in summary_data.items():
if len(summary) == 0:
# Skip models without results
continue
row = [model_name]
if 'Top 1 Accuracy' in summary:
metric = summary['Top 1 Accuracy']
row.append(str(round(metric['expect'], 2)))
row.append(str(round(metric['result'], 2)))
else:
row.extend([''] * 2)
if 'Top 5 Accuracy' in summary:
metric = summary['Top 5 Accuracy']
row.append(str(round(metric['expect'], 2)))
row.append(str(round(metric['result'], 2)))
else:
row.extend([''] * 2)
columns['Name'].append(model_name)
model_info = models_map[model_name]
row.append(model_info.config)
file.write('| ' + ' | '.join(row) + ' |\n')
for metric_key in METRICS_MAP:
if metric_key in summary:
metric = summary[metric_key]
expect = round(metric['expect'], 2)
result = round(metric['result'], 2)
columns[f'{metric_key} (expect)'].append(str(expect))
columns[f'{metric_key}'].append(str(result))
else:
columns[f'{metric_key} (expect)'].append('')
columns[f'{metric_key}'].append('')
columns = {
field: column
for field, column in columns.items() if ''.join(column)
}
file.write(','.join(columns.keys()) + '\n')
for row in zip(*columns.values()):
file.write(','.join(row) + '\n')
file.close()
print('Summary file saved at ' + str(summary_path))
logger.info('Summary file saved at ' + str(summary_path))
def show_summary(summary_data):
table = Table(title='Test Benchmark Regression Summary')
table.add_column('Model')
table.add_column('Name')
for metric in METRICS_MAP:
table.add_column(f'{metric} (expect)')
table.add_column(f'{metric}')
@ -274,33 +245,20 @@ def show_summary(summary_data):
row.append('')
table.add_row(*row)
# Remove empty columns
table.columns = [
column for column in table.columns if ''.join(column._cells)
]
console.print(table)
def summary(args):
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})
def summary(models, args):
work_dir = Path(args.work_dir)
if args.models:
patterns = [re.compile(pattern) for pattern in args.models]
filter_models = {}
for k, v in models.items():
if any([re.match(pattern, k) for pattern in patterns]):
filter_models[k] = v
if len(filter_models) == 0:
print('No model found, please specify models in:')
print('\n'.join(models.keys()))
return
models = filter_models
summary_data = {}
for model_name, model_info in models.items():
if model_info.results is None:
if model_info.results is None and not args.no_skip:
continue
# Skip if not found result file.
@ -327,16 +285,35 @@ def summary(args):
show_summary(summary_data)
if args.save:
save_summary(summary_data, models, work_dir)
save_summary(summary_data, work_dir)
def main():
args = parse_args()
# parse model-index.yml
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})
if args.models:
filter_models = {}
for pattern in args.models:
filter_models.update({
name: models[name]
for name in fnmatch.filter(models, pattern + '*')
})
if len(filter_models) == 0:
logger.error('No model found, please specify models in:\n' +
'\n'.join(models.keys()))
return
models = filter_models
if args.summary:
summary(args)
summary(models, args)
else:
test(args)
test(models, args)
if __name__ == '__main__':
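
The `main()` rewrite above drops the per-key `re.match` loops in favor of shell-style wildcards. A short sketch of what `fnmatch.filter(models, pattern + '*')` does in practice (the model names here are illustrative, not the real metafile contents):

```python
import fnmatch

models = {
    'resnet50_8xb32_in1k': ...,
    'resnet101_8xb32_in1k': ...,
    'vit-base-p16_64xb64_in1k': ...,
}

def filter_models(models, patterns):
    picked = {}
    for pattern in patterns:
        # Appending '*' lets a bare prefix such as 'resnet' match every model
        # name starting with it, with no regex anchors required from the user.
        picked.update({name: models[name]
                       for name in fnmatch.filter(models, pattern + '*')})
    return picked

print(sorted(filter_models(models, ['resnet'])))
# ['resnet101_8xb32_in1k', 'resnet50_8xb32_in1k']
```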

View File

@ -1,9 +1,12 @@
import argparse
import fnmatch
import json
import logging
import os
import os.path as osp
import re
from collections import OrderedDict
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile
@ -14,18 +17,20 @@ from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table
from .utils import METRICS_MAP, MMCLS_ROOT
# Avoid importing MMPretrain to speed up showing the summary
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]
logger = logging.getLogger('train')
logger.addHandler(logging.StreamHandler())
logger.addHandler(logging.FileHandler('benchmark_train.log', mode='w'))
CYCLE_LEVELS = ['month', 'quarter', 'half-year', 'no-training']
METRICS_MAP = {
'Top 1 Accuracy': 'accuracy/top1',
'Top 5 Accuracy': 'accuracy/top5'
}
class RangeAction(argparse.Action):
def __call__(self, parser, namespace, values: str, option_string):
def __call__(self, _, namespace, values: str, __):
matches = re.match(r'([><=]*)([-\w]+)', values)
if matches is None:
raise ValueError(f'Unavailable range option {values}')
@ -49,15 +54,25 @@ def parse_args():
parser = argparse.ArgumentParser(
description='Train models (in bench_train.yml) and compare accuracy.')
parser.add_argument(
'partition', type=str, help='Cluster partition to use.')
parser.add_argument(
'--job-name',
type=str,
default='cls-train-benchmark',
help='Slurm job name prefix')
parser.add_argument('--port', type=int, default=29666, help='dist port')
'--local',
action='store_true',
help='run at local instead of cluster.')
parser.add_argument(
'--models', nargs='+', type=str, help='Specify model names to run.')
parser.add_argument(
'--run', action='store_true', help='run script directly')
parser.add_argument(
'--summary',
action='store_true',
help='Summarize benchmark train results.')
parser.add_argument(
'--save',
action='store_true',
help='Save the summary and archive log files.')
parser.add_argument(
'--non-distributed',
action='store_true',
help='Use non-distributed environment (for debug).')
parser.add_argument(
'--range',
type=str,
@@ -70,33 +85,22 @@ def parse_args():
'--work-dir',
default='work_dirs/benchmark_train',
help='the dir to save train log')
parser.add_argument('--port', type=int, default=29666, help='dist port')
parser.add_argument(
'--run', action='store_true', help='run script directly')
'--partition',
type=str,
default='mm_model',
help='(for slurm) Cluster partition to use.')
parser.add_argument(
'--local',
action='store_true',
help='run at local instead of cluster.')
parser.add_argument(
'--mail', type=str, help='Mail address to watch train status.')
parser.add_argument(
'--mail-type',
nargs='+',
default=['BEGIN', 'END', 'FAIL'],
choices=['NONE', 'BEGIN', 'END', 'FAIL', 'REQUEUE', 'ALL'],
help='Mail address to watch train status.')
'--job-name',
type=str,
default='cls-train-benchmark',
help='(for slurm) Slurm job name prefix')
parser.add_argument(
'--quotatype',
default=None,
choices=['reserved', 'auto', 'spot'],
help='Quota type, only available for phoenix-slurm>=0.2')
parser.add_argument(
'--summary',
action='store_true',
help='Summarize benchmark train results.')
parser.add_argument(
'--save',
action='store_true',
help='Save the summary and archive log files.')
help='(for slurm) Quota type, only available for phoenix-slurm>=0.2')
parser.add_argument(
'--cfg-options',
nargs='+',
@@ -118,72 +122,93 @@ def get_gpu_number(model_info):
return gpus
def create_train_job_batch(commands, model_info, args, port, script_name):
fname = model_info.name
gpus = get_gpu_number(model_info)
gpus_per_node = min(gpus, 8)
def create_train_job_batch(model_info, args, port, pretrain_info=None):
model_name = model_info.name
config = Path(model_info.config)
assert config.exists(), f'"{fname}": {config} not found.'
gpus = get_gpu_number(model_info)
job_name = f'{args.job_name}_{fname}'
work_dir = Path(args.work_dir) / fname
job_name = f'{args.job_name}_{model_name}'
work_dir = Path(args.work_dir) / model_name
work_dir.mkdir(parents=True, exist_ok=True)
if args.mail is not None and 'NONE' not in args.mail_type:
mail_cfg = (f'#SBATCH --mail {args.mail}\n'
f'#SBATCH --mail-type {args.mail_type}\n')
else:
mail_cfg = ''
cfg_options = deepcopy(args.cfg_options)
if args.quotatype is not None:
quota_cfg = f'#SBATCH --quotatype {args.quotatype}\n'
quota_cfg = f'#SBATCH --quotatype {args.quotatype}'
else:
quota_cfg = ''
launcher = 'none' if args.local else 'slurm'
runner = 'python' if args.local else 'srun python'
if pretrain_info is not None:
pretrain = Path(args.work_dir) / pretrain_info.name / 'last_checkpoint'
pretrain_cfg = (f'model.backbone.init_cfg.checkpoint="$(<{pretrain})" '
'model.backbone.init_cfg.type="Pretrained" '
'model.backbone.init_cfg.prefix="backbone."')
else:
pretrain_cfg = ''
if not args.local:
launcher = 'slurm'
runner = 'srun python'
if gpus > 8:
gpus = 8
cfg_options.append('auto_scale_lr.enable=True')
elif not args.non_distributed:
launcher = 'pytorch'
if gpus > 8:
gpus = 8
cfg_options.append('auto_scale_lr.enable=True')
runner = ('torchrun --master_addr="127.0.0.1" '
f'--master_port={port} --nproc_per_node={gpus}')
else:
launcher = 'none'
runner = 'python -u'
job_script = (f'#!/bin/bash\n'
f'#SBATCH --output {work_dir}/job.%j.out\n'
f'#SBATCH --partition={args.partition}\n'
f'#SBATCH --job-name {job_name}\n'
f'#SBATCH --gres=gpu:{gpus_per_node}\n'
f'{mail_cfg}{quota_cfg}'
f'#SBATCH --ntasks-per-node={gpus_per_node}\n'
f'#SBATCH --gres=gpu:{min(8, gpus)}\n'
f'{quota_cfg}\n'
f'#SBATCH --ntasks-per-node={min(8, gpus)}\n'
f'#SBATCH --ntasks={gpus}\n'
f'#SBATCH --cpus-per-task=5\n\n'
f'{runner} -u {script_name} {config} '
f'{runner} tools/train.py {config} '
f'--work-dir={work_dir} --cfg-option '
f'env_cfg.dist_cfg.port={port} '
f'{" ".join(args.cfg_options)} '
f'{" ".join(cfg_options)} '
f'default_hooks.checkpoint.max_keep_ckpts=2 '
f'default_hooks.checkpoint.save_best="auto" '
f'{pretrain_cfg} '
f'--launcher={launcher}\n')
with open(work_dir / 'job.sh', 'w') as f:
f.write(job_script)
commands.append(f'echo "{config}"')
if args.local:
commands.append(f'bash {work_dir}/job.sh')
else:
commands.append(f'sbatch {work_dir}/job.sh')
return work_dir / 'job.sh'
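For illustration, a job script produced by the template above might look roughly as follows (the model name, config path, and GPU count are hypothetical, assuming the default partition, job-name prefix, and port):

```shell
#!/bin/bash
#SBATCH --output work_dirs/benchmark_train/resnet50_8xb32_in1k/job.%j.out
#SBATCH --partition=mm_model
#SBATCH --job-name cls-train-benchmark_resnet50_8xb32_in1k
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8
#SBATCH --ntasks=8
#SBATCH --cpus-per-task=5

srun python tools/train.py configs/resnet/resnet50_8xb32_in1k.py \
    --work-dir=work_dirs/benchmark_train/resnet50_8xb32_in1k \
    --cfg-option env_cfg.dist_cfg.port=29666 \
    default_hooks.checkpoint.max_keep_ckpts=2 \
    default_hooks.checkpoint.save_best="auto" \
    --launcher=slurm
```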
def train(models, args):
script_name = osp.join('tools', 'train.py')
port = args.port
commands = []
for model_info in models.values():
script_path = create_train_job_batch(commands, model_info, args, port,
script_name)
script_path = create_train_job_batch(model_info, args, port)
if hasattr(model_info, 'downstream'):
downstream_info = model_info.downstream
downstream_script = create_train_job_batch(
downstream_info, args, port, pretrain_info=model_info)
else:
downstream_script = None
if args.local:
command = f'bash {script_path}'
if downstream_script:
command += f' && bash {downstream_script}'
else:
command = f'JOBID=$(sbatch --parsable {script_path})'
if downstream_script:
command += f' && sbatch --dependency=afterok:$JOBID {downstream_script}' # noqa: E501
commands.append(command)
port += 1
command_str = '\n'.join(commands)
@@ -211,63 +236,67 @@ def train(models, args):
console.print('Please set "--run" to start the job')
def save_summary(summary_data, models_map, work_dir):
def save_summary(summary_data, work_dir):
date = datetime.now().strftime('%Y%m%d-%H%M%S')
zip_path = work_dir / f'archive-{date}.zip'
zip_file = ZipFile(zip_path, 'w')
summary_path = work_dir / 'benchmark_summary.md'
summary_path = work_dir / 'benchmark_summary.csv'
file = open(summary_path, 'w')
headers = [
'Model', 'Top-1 Expected(%)', 'Top-1 (%)', 'Top-1 best(%)',
'best epoch', 'Top-5 Expected (%)', 'Top-5 (%)', 'Config', 'Log'
]
file.write('# Train Benchmark Regression Summary\n')
file.write('| ' + ' | '.join(headers) + ' |\n')
file.write('|:' + ':|:'.join(['---'] * len(headers)) + ':|\n')
columns = defaultdict(list)
for model_name, summary in summary_data.items():
if len(summary) == 0:
# Skip models without results
continue
row = [model_name]
if 'Top 1 Accuracy' in summary:
metric = summary['Top 1 Accuracy']
row.append(f"{metric['expect']:.2f}")
row.append(f"{metric['last']:.2f}")
row.append(f"{metric['best']:.2f}")
row.append(f"{metric['best_epoch']:.2f}")
else:
row.extend([''] * 4)
if 'Top 5 Accuracy' in summary:
metric = summary['Top 5 Accuracy']
row.append(f"{metric['expect']:.2f}")
row.append(f"{metric['last']:.2f}")
else:
row.extend([''] * 2)
columns['Name'].append(model_name)
model_info = models_map[model_name]
row.append(model_info.config)
row.append(str(summary['log_file'].relative_to(work_dir)))
for metric_key in METRICS_MAP:
if metric_key in summary:
metric = summary[metric_key]
expect = str(round(metric['expect'], 2))
result = str(round(metric['result'], 2))
columns[f'{metric_key} (expect)'].append(expect)
columns[f'{metric_key}'].append(result)
best = str(round(metric['best'], 2))
best_epoch = str(int(metric['best_epoch']))
columns[f'{metric_key} (best)'].append(best)
columns[f'{metric_key} (best epoch)'].append(best_epoch)
else:
columns[f'{metric_key} (expect)'].append('')
columns[f'{metric_key}'].append('')
columns[f'{metric_key} (best)'].append('')
columns[f'{metric_key} (best epoch)'].append('')
columns['Log'].append(str(summary['log_file'].relative_to(work_dir)))
zip_file.write(summary['log_file'])
file.write('| ' + ' | '.join(row) + ' |\n')
columns = {
field: column
for field, column in columns.items() if ''.join(column)
}
file.write(','.join(columns.keys()) + '\n')
for row in zip(*columns.values()):
file.write(','.join(row) + '\n')
file.close()
zip_file.write(summary_path)
zip_file.close()
print('Summary file saved at ' + str(summary_path))
print('Log files archived at ' + str(zip_path))
logger.info('Summary file saved at ' + str(summary_path))
logger.info('Log files archived at ' + str(zip_path))
def show_summary(summary_data):
table = Table(title='Train Benchmark Regression Summary')
table.add_column('Model')
table.add_column('Name')
for metric in METRICS_MAP:
table.add_column(f'{metric} (expect)')
table.add_column(f'{metric}')
table.add_column(f'{metric} (best)')
table.add_column('Date')
def set_color(value, expect):
if value > expect:
return 'green'
elif value > expect - 0.2:
elif value >= expect - 0.2:
return 'white'
else:
return 'red'
@@ -277,25 +306,30 @@ def show_summary(summary_data):
for metric_key in METRICS_MAP:
if metric_key in summary:
metric = summary[metric_key]
expect = metric['expect']
last = metric['last']
expect = round(metric['expect'], 2)
last = round(metric['last'], 2)
last_epoch = metric['last_epoch']
last_color = set_color(last, expect)
best = metric['best']
best_color = set_color(best, expect)
best_epoch = metric['best_epoch']
best_epoch = round(metric['best_epoch'], 2)
row.append(f'{expect:.2f}')
row.append(
f'[{last_color}]{last:.2f}[/{last_color}] ({last_epoch})')
row.append(
f'[{best_color}]{best:.2f}[/{best_color}] ({best_epoch})')
else:
row.extend([''] * 3)
table.add_row(*row)
# Remove empty columns
table.columns = [
column for column in table.columns if ''.join(column._cells)
]
console.print(table)
def summary(models, args):
work_dir = Path(args.work_dir)
dir_map = {p.name: p for p in work_dir.iterdir() if p.is_dir()}
@@ -306,9 +340,17 @@ def summary(models, args):
if model_name not in dir_map:
continue
elif hasattr(model_info, 'downstream'):
downstream_name = model_info.downstream.name
if downstream_name not in dir_map:
continue
else:
sub_dir = dir_map[downstream_name]
model_info = model_info.downstream
else:
# Skip if no vis_data folder is found.
sub_dir = dir_map[model_name]
log_files = [f for f in sub_dir.glob('*/vis_data/scalars.json')]
if len(log_files) == 0:
continue
@@ -317,11 +359,8 @@ def summary(models, args):
# parse train log
with open(log_file) as f:
json_logs = [json.loads(s) for s in f.readlines()]
val_logs = [
log for log in json_logs
# TODO: need a better method to extract validation logs
if 'loss' not in log and 'accuracy/top1' in log
]
# TODO: need a better method to extract validation logs
val_logs = [log for log in json_logs if 'loss' not in log]
if len(val_logs) == 0:
continue
@@ -351,12 +390,13 @@ def summary(models, args):
show_summary(summary_data)
if args.save:
save_summary(summary_data, models, work_dir)
save_summary(summary_data, work_dir)
def main():
args = parse_args()
# parse model-index.yml
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
@@ -364,25 +404,28 @@ def main():
with open(Path(__file__).parent / 'bench_train.yml', 'r') as f:
train_items = yaml.safe_load(f)
models = OrderedDict()
models = {}
for item in train_items:
name = item['Name']
model_info = all_models[name]
model_info.cycle = item.get('Cycle', None)
cycle = getattr(model_info, 'cycle', 'month')
cycle = item['Cycle']
cycle_level = CYCLE_LEVELS.index(cycle)
if cycle_level in args.range:
model_info = all_models[name]
if 'Downstream' in item:
downstream = item['Downstream']
setattr(model_info, 'downstream', all_models[downstream])
models[name] = model_info
if args.models:
patterns = [re.compile(pattern) for pattern in args.models]
filter_models = {}
for k, v in models.items():
if any([re.match(pattern, k) for pattern in patterns]):
filter_models[k] = v
for pattern in args.models:
filter_models.update({
name: models[name]
for name in fnmatch.filter(models, pattern + '*')
})
if len(filter_models) == 0:
print('No model found, please specify models in:')
print('\n'.join(models.keys()))
logger.error('No model found, please specify models in:\n' +
'\n'.join(models.keys()))
return
models = filter_models
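A small sketch of how the fnmatch-based filter above behaves (model names taken from bench_train.yml for illustration):

```python
import fnmatch

models = ['resnet50_8xb32_in1k', 'resnet50_8xb256-rsb-a1-600e_in1k',
          'swin-small_16xb64_in1k']
pattern = 'resnet50'
print(fnmatch.filter(models, pattern + '*'))
# ['resnet50_8xb32_in1k', 'resnet50_8xb256-rsb-a1-600e_in1k']
```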

View File

@@ -18,9 +18,9 @@ from modelindex.load_model_index import load
from rich.console import Console
from rich.table import Table
from mmcls.datasets.builder import build_dataloader
from mmcls.datasets.pipelines import Compose
from mmcls.models.builder import build_classifier
from mmpretrain.datasets.builder import build_dataloader
from mmpretrain.datasets.pipelines import Compose
from mmpretrain.models.builder import build_classifier
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]

View File

@@ -0,0 +1,14 @@
- Name: convnext-base_32xb128_in1k
- Name: convnext-v2-atto_fcmae-pre_3rdparty_in1k
- Name: mobilenet-v2_8xb32_in1k
- Name: mobilenet-v3-small-050_3rdparty_in1k
- Name: swin-tiny_16xb64_in1k
- Name: swinv2-tiny-w8_3rdparty_in1k-256px
- Name: vit-base-p16_32xb128-mae_in1k
- Name: resnet34_8xb32_in1k
- Name: resnext50-32x4d_8xb32_in1k
- Name: shufflenet-v2-1x_16xb64_in1k
- Name: riformer-s12_in1k
- Name: blip-base_3rdparty_retrieval
- Name: blip2-opt2.7b_3rdparty-zeroshot_caption
- Name: ofa-base_3rdparty-finetuned_caption

View File

@@ -1,18 +1,21 @@
- Name: mobilenet-v2_8xb32_in1k
Cycle: month
- Name: resnet50_8xb32_in1k
Cycle: month
- Name: seresnet50_8xb32_in1k
- Name: resnet50_8xb256-rsb-a1-600e_in1k
Cycle: month
- Name: swin-small_16xb64_in1k
Cycle: month
- Name: vit-base-p16_pt-32xb128-mae_in1k
- Name: vit-base-p16_32xb128-mae_in1k
Cycle: month
- Name: seresnet50_8xb32_in1k
Cycle: quarter
- Name: resnet50_8xb32_in1k
Cycle: quarter
- Name: resnet50_8xb256-rsb-a1-600e_in1k
Cycle: quarter
@@ -34,53 +37,85 @@
- Name: regnetx-1.6gf_8xb128_in1k
Cycle: half-year
- Name: van-small_8xb128_in1k
Cycle: no-training
- Name: conformer-small-p32_8xb128_in1k
Cycle: half-year
- Name: res2net50-w14-s8_3rdparty_8xb32_in1k
Cycle: no-training
- Name: convnext-small_32xb128_in1k
Cycle: month
- Name: repvgg-A2_3rdparty_4xb64-coslr-120e_in1k
Cycle: no-training
- Name: mobilenet-v3-small_8xb128_in1k
Cycle: half-year
- Name: tnt-small-p16_3rdparty_in1k
Cycle: no-training
- Name: mobileone-s2_8xb32_in1k
Cycle: quarter
- Name: mlp-mixer-base-p16_3rdparty_64xb64_in1k
Cycle: no-training
- Name: repvgg-b2g4_8xb32_in1k
Cycle: half-year
- Name: conformer-small-p16_3rdparty_8xb128_in1k
Cycle: no-training
- Name: barlowtwins_resnet50_8xb256-coslr-300e_in1k
Cycle: half-year
Downstream: resnet50_barlowtwins-pre_8xb32-linear-coslr-100e_in1k
- Name: twins-pcpvt-base_3rdparty_8xb128_in1k
Cycle: no-training
- Name: beit_beit-base-p16_8xb256-amp-coslr-300e_in1k
Cycle: quarter
Downstream: beit-base-p16_beit-pre_8xb128-coslr-100e_in1k
- Name: efficientnet-b0_3rdparty_8xb32_in1k
Cycle: no-training
- Name: beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k
Cycle: quarter
Downstream: beit-base-p16_beitv2-pre_8xb128-coslr-100e_in1k
- Name: convnext-small_3rdparty_32xb128_in1k
Cycle: no-training
- Name: byol_resnet50_16xb256-coslr-200e_in1k
Cycle: quarter
Downstream: resnet50_byol-pre_8xb512-linear-coslr-90e_in1k
- Name: hrnet-w18_3rdparty_8xb32_in1k
Cycle: no-training
- Name: cae_beit-base-p16_8xb256-amp-coslr-300e_in1k
Cycle: half-year
Downstream: beit-base-p16_cae-pre_8xb128-coslr-100e_in1k
- Name: repmlp-base_3rdparty_8xb64_in1k
Cycle: no-training
- Name: densecl_resnet50_8xb32-coslr-200e_in1k
Cycle: half-year
Downstream: resnet50_densecl-pre_8xb32-linear-steplr-100e_in1k
- Name: wide-resnet50_3rdparty_8xb32_in1k
Cycle: no-training
- Name: eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k
Cycle: half-year
Downstream: vit-base-p16_eva-mae-style-pre_8xb2048-linear-coslr-100e_in1k
- Name: cspresnet50_3rdparty_8xb32_in1k
Cycle: no-training
- Name: mae_vit-base-p16_8xb512-amp-coslr-300e_in1k
Cycle: month
Downstream: vit-base-p16_mae-300e-pre_8xb2048-linear-coslr-90e_in1k
- Name: convmixer-768-32_10xb64_in1k
Cycle: no-training
- Name: maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k
Cycle: quarter
Downstream: vit-base-p16_maskfeat-pre_8xb256-coslr-100e_in1k
- Name: densenet169_4xb256_in1k
Cycle: no-training
- Name: milan_vit-base-p16_16xb256-amp-coslr-400e_in1k
Cycle: quarter
Downstream: vit-base-p16_milan-pre_8xb2048-linear-coslr-100e_in1k
- Name: poolformer-s24_3rdparty_32xb128_in1k
Cycle: no-training
- Name: mixmim_mixmim-base_16xb128-coslr-300e_in1k
Cycle: half-year
Downstream: mixmim-base_mixmim-pre_8xb128-coslr-100e_in1k
- Name: inception-v3_3rdparty_8xb32_in1k
Cycle: no-training
- Name: mocov2_resnet50_8xb32-coslr-200e_in1k
Cycle: quarter
Downstream: resnet50_mocov2-pre_8xb32-linear-steplr-100e_in1k
- Name: mocov3_vit-small-p16_16xb256-amp-coslr-300e_in1k
Cycle: month
Downstream: vit-small-p16_mocov3-pre_8xb128-linear-coslr-90e_in1k
- Name: simclr_resnet50_16xb256-coslr-200e_in1k
Cycle: quarter
Downstream: resnet50_simclr-200e-pre_8xb512-linear-coslr-90e_in1k
- Name: simmim_swin-base-w6_8xb256-amp-coslr-100e_in1k-192px
Cycle: month
Downstream: swin-base-w6_simmim-100e-pre_8xb256-coslr-100e_in1k-192px
- Name: simsiam_resnet50_8xb32-coslr-100e_in1k
Cycle: quarter
Downstream: resnet50_simsiam-100e-pre_8xb512-linear-coslr-90e_in1k
- Name: swav_resnet50_8xb32-mcrop-coslr-200e_in1k-224px-96px
Cycle: half-year
Downstream: resnet50_swav-pre_8xb32-linear-coslr-100e_in1k

View File

@@ -0,0 +1,33 @@
from pathlib import Path
HTTP_PREFIX = 'https://download.openmmlab.com/'
MMCLS_ROOT = Path(__file__).absolute().parents[2]
METRICS_MAP = {
'Top 1 Accuracy': 'accuracy/top1',
'Top 5 Accuracy': 'accuracy/top5',
'Recall@1': 'retrieval/Recall@1',
'Recall@5': 'retrieval/Recall@5',
'BLEU-4': 'Bleu_4',
'CIDER': 'CIDEr',
}
def substitute_weights(download_link, root):
if 's3://' in root:
from mmengine.fileio.backends import PetrelBackend
from petrel_client.common.exception import AccessDeniedError
file_backend = PetrelBackend()
checkpoint = file_backend.join_path(root,
download_link[len(HTTP_PREFIX):])
try:
exists = file_backend.exists(checkpoint)
except AccessDeniedError:
exists = False
else:
checkpoint = Path(root) / download_link[len(HTTP_PREFIX):]
exists = checkpoint.exists()
if exists:
return str(checkpoint)
else:
return None
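A usage sketch, assuming checkpoints mirror the OpenMMLab download tree under a local root (the link and paths are hypothetical):

```python
link = ('https://download.openmmlab.com/mmclassification/'
        'v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth')
# Returns '/data/checkpoints/mmclassification/v0/resnet/...' if the file
# exists under that root, otherwise None.
print(substitute_weights(link, '/data/checkpoints'))
```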

View File

@@ -0,0 +1,207 @@
import argparse
import logging
import re
import sys
from pathlib import Path
import yaml
from modelindex.load_model_index import load
from modelindex.models.Collection import Collection
from modelindex.models.Model import Model
from modelindex.models.ModelIndex import ModelIndex
class ContextFilter(logging.Filter):
metafile = None
name = None
failed = False
def filter(self, record: logging.LogRecord):
record.color = {
logging.WARNING: '\x1b[33;20m',
logging.ERROR: '\x1b[31;1m',
}.get(record.levelno, '')
self.failed = self.failed or (record.levelno >= logging.ERROR)
record.metafile = self.metafile or ''
record.name = ('' if self.name is None else '\x1b[32m' + self.name +
'\x1b[0m: ')
return True
context = ContextFilter()
logging.basicConfig(
format='[%(metafile)s] %(color)s%(levelname)s\x1b[0m - %(name)s%(message)s'
)
logger = logging.getLogger()
logger.addFilter(context)
prog_description = """\
Check the format of metafile.
"""
MMCLS_ROOT = Path(__file__).absolute().parents[1]
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument(
'metafile', type=Path, nargs='+', help='The path of the metafile.')
parser.add_argument(
'--Wall',
'-w',
action='store_true',
help='Whether to enable all warnings.')
parser.add_argument('--skip', action='append', help='Rules to skip check.')
args = parser.parse_args()
args.skip = args.skip or []
return args
def check_collection(modelindex: ModelIndex, skip=[]):
if len(modelindex.collections) == 0:
return ['No collection field.']
elif len(modelindex.collections) > 1:
logger.error('One metafile should have only one collection.')
collection: Collection = modelindex.collections[0]
if collection.name is None:
logger.error('The collection should have `Name` field.')
if collection.readme is None:
logger.error('The collection should have `README` field.')
if not (MMCLS_ROOT / collection.readme).exists():
logger.error(f'The README {collection.readme} is not found.')
if not isinstance(collection.paper, dict):
logger.error('The collection should have `Paper` field with '
'`Title` and `URL`.')
elif 'Title' not in collection.paper:
# URL is not necessary.
logger.error("The collection's paper should have `Paper` field.")
def check_model_name(name):
fields = name.split('_')
if len(fields) > 5:
logger.warning('Too many fields.')
return
elif len(fields) < 3:
logger.warning('Too few fields.')
return
elif len(fields) == 5:
algo, model, pre, train, data = fields
elif len(fields) == 3:
model, train, data = fields
algo, pre = None, None
elif len(fields) == 4 and fields[1].endswith('-pre'):
model, pre, train, data = fields
algo = None
else:
algo, model, train, data = fields
pre = None
if pre is not None and not pre.endswith('-pre'):
logger.warning(f'The position of `{pre}` should be '
'pre-training information, and end with `-pre`.')
if '3rdparty' not in train and re.match(r'\d+xb\d+', train) is None:
logger.warning(f'The position of `{train}` should be training '
'information, and start with `3rdparty` or '
'`{num_device}xb{batch_per_device}`')
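For reference, some names and how `check_model_name` would split them (illustrative only; the names come from the benchmark lists above):

```python
# 'resnet50_8xb32_in1k'                            -> model / train / data
# 'resnet50_barlowtwins-pre_8xb32-linear-coslr-100e_in1k'
#                                         -> model / pretrain / train / data
# 'mae_vit-base-p16_8xb512-amp-coslr-300e_in1k'    -> algo / model / train / data
# 'swinv2-tiny-w8_3rdparty_in1k-256px'             -> '3rdparty' counts as training info
```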
def check_model(model: Model, skip=[]):
context.name = None
if model.name is None:
logger.error("A model doesn't have `Name` field.")
return
context.name = model.name
check_model_name(model.name)
if model.name.endswith('.py'):
logger.error("Don't add `.py` suffix in model name.")
if model.metadata is None and 'metadata' not in skip:
logger.error('No `Metadata` field.')
if (model.metadata.parameters is None
or model.metadata.flops is None) and 'flops-param' not in skip:
logger.error('Metadata should have `Parameters` and `FLOPs` fields. '
'You can use `tools/analysis_tools/get_flops.py` '
'to calculate them.')
if model.results is not None and 'result' not in skip:
result = model.results[0]
if not isinstance(result.dataset, str):
logger.error('Dataset field of Results should be a string. '
'If you want to specify the training dataset, '
'please use `Metadata.Training Data` field.')
if 'config' not in skip:
if model.config is None:
logger.error('No `Config` field.')
elif not (MMCLS_ROOT / model.config).exists():
logger.error(f'The config {model.config} is not found.')
if model.in_collection is None:
logger.error('No `In Collection` field.')
if (model.data.get('Converted From') is not None
and '3rdparty' not in model.name):
logger.warning("The model name should include '3rdparty' "
"since it's converted from other repository.")
if (model.weights is not None and model.weights.endswith('.pth')
and 'ckpt-name' not in skip):
basename = model.weights.rsplit('/', 1)[-1]
if not basename.startswith(model.name):
logger.warning(f'The checkpoint name {basename} is not the '
'same as the model name.')
context.name = None
def main(metafile: Path, args):
if metafile.name != 'metafile.yml':
# Avoid checking other yaml file.
return
elif metafile.samefile(MMCLS_ROOT / 'model-index.yml'):
return
context.metafile = metafile
with open(MMCLS_ROOT / 'model-index.yml', 'r') as f:
metafile_list = yaml.load(f, yaml.Loader)['Import']
if not any(
metafile.samefile(MMCLS_ROOT / file)
for file in metafile_list):
logger.error(
'The metafile is not imported in the `model-index.yml`.')
modelindex = load(str(metafile))
modelindex.build_models_with_collections()
check_collection(modelindex, args.skip)
names = {model.name for model in modelindex.models}
for model in modelindex.models:
check_model(model, args.skip)
for downstream in model.data.get('Downstream', []):
if downstream not in names:
context.name = model.name
logger.error(
f"The downstream model {downstream} doesn't exist.")
if __name__ == '__main__':
args = parse_args()
if args.Wall:
logger.setLevel(logging.WARNING)
else:
logger.setLevel(logging.ERROR)
for metafile in args.metafile:
main(metafile, args)
sys.exit(int(context.failed))

View File

@@ -0,0 +1,186 @@
import argparse
import math
from pathlib import Path
import torch
from rich.console import Console
console = Console()
prog_description = """\
Draw the state dict tree.
"""
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument(
'path',
type=Path,
help='The path of the checkpoint or model config to draw.')
parser.add_argument('--depth', type=int, help='The max depth to draw.')
parser.add_argument(
'--full-name',
action='store_true',
help='Whether to print the full name of the key.')
parser.add_argument(
'--shape',
action='store_true',
help='Whether to print the shape of the parameter.')
parser.add_argument(
'--state-key',
type=str,
help='The key of the state dict in the checkpoint.')
parser.add_argument(
'--number',
action='store_true',
help='Mark all parameters and their index number.')
parser.add_argument(
'--node',
type=str,
help='Show the sub-tree of a node, like "backbone.layers".')
args = parser.parse_args()
return args
def ckpt_to_state_dict(checkpoint, key=None):
if key is not None:
state_dict = checkpoint[key]
elif 'state_dict' in checkpoint:
# try mmpretrain style
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
elif isinstance(next(iter(checkpoint.values())), torch.Tensor):
# try native style
state_dict = checkpoint
else:
raise KeyError('Please specify the key of the state '
f'dict from {list(checkpoint.keys())}.')
return state_dict
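A usage sketch ('checkpoint.pth' and the 'teacher' key are hypothetical):

```python
import torch

ckpt = torch.load('checkpoint.pth', map_location='cpu')
state_dict = ckpt_to_state_dict(ckpt)  # auto-detects 'state_dict' / 'model'
# For a non-standard layout, pass the key explicitly:
# state_dict = ckpt_to_state_dict(ckpt, key='teacher')
```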
class StateDictTree:
def __init__(self, key='', value=None):
self.children = {}
self.key: str = key
self.value = value
def add_parameter(self, key, value):
keys = key.split('.', 1)
if len(keys) == 1:
self.children[key] = StateDictTree(key, value)
elif keys[0] in self.children:
self.children[keys[0]].add_parameter(keys[1], value)
else:
node = StateDictTree(keys[0])
node.add_parameter(keys[1], value)
self.children[keys[0]] = node
def __getitem__(self, key: str):
return self.children[key]
def __repr__(self) -> str:
with console.capture() as capture:
for line in self.iter_tree():
console.print(line)
return capture.get()
def __len__(self):
return len(self.children)
def draw_tree(self,
max_depth=None,
full_name=False,
with_shape=False,
with_value=False):
for line in self.iter_tree(
max_depth=max_depth,
full_name=full_name,
with_shape=with_shape,
with_value=with_value,
):
console.print(line, highlight=False)
def iter_tree(
self,
lead='',
prefix='',
max_depth=None,
full_name=False,
with_shape=False,
with_value=False,
):
if self.value is None:
key_str = f'[blue]{self.key}[/]'
elif with_shape:
key_str = f'[green]{self.key}[/] {tuple(self.value.shape)}'
elif with_value:
key_str = f'[green]{self.key}[/] {self.value}'
else:
key_str = f'[green]{self.key}[/]'
yield lead + prefix + key_str
lead = lead.replace('├─', '')
lead = lead.replace('└─', ' ')
if self.key and full_name:
prefix = f'{prefix}{self.key}.'
if max_depth == 0:
return
elif max_depth is not None:
max_depth -= 1
for i, child in enumerate(self.children.values()):
level_lead = '├─' if i < len(self.children) - 1 else '└─'
yield from child.iter_tree(
lead=f'{lead}{level_lead} ',
prefix=prefix,
max_depth=max_depth,
full_name=full_name,
with_shape=with_shape,
with_value=with_value)
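A minimal sketch of building and drawing a tree from a tiny fake state dict:

```python
import torch

fake = {
    'backbone.conv1.weight': torch.zeros(64, 3, 7, 7),
    'backbone.bn1.weight': torch.zeros(64),
    'head.fc.weight': torch.zeros(1000, 2048),
}
root = StateDictTree()
for k, v in fake.items():
    root.add_parameter(k, v)
root.draw_tree(with_shape=True)  # prints an indented backbone/head tree with leaf shapes
```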
def main():
args = parse_args()
if args.path.suffix in ['.json', '.py', '.yml']:
from mmengine.runner import get_state_dict
from mmpretrain.apis import init_model
model = init_model(args.path, device='cpu')
state_dict = get_state_dict(model)
else:
ckpt = torch.load(args.path, map_location='cpu')
state_dict = ckpt_to_state_dict(ckpt, args.state_key)
root = StateDictTree()
for k, v in state_dict.items():
root.add_parameter(k, v)
para_index = 0
mark_width = math.floor(math.log(len(state_dict), 10) + 1)
if args.node is not None:
for key in args.node.split('.'):
root = root[key]
for line in root.iter_tree(
max_depth=args.depth,
full_name=args.full_name,
with_shape=args.shape,
):
if not args.number:
mark = ''
# A hack method to determine whether a line is parameter.
elif '[green]' in line:
mark = f'[red]({str(para_index).ljust(mark_width)})[/]'
para_index += 1
else:
mark = ' ' * (mark_width + 2)
console.print(mark + line, highlight=False)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,121 @@
#!/usr/bin/env python
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from ckpt_tree import StateDictTree, ckpt_to_state_dict
from rich.progress import track
from scipy import stats
prog_description = """\
Compare the initialization distributions of two state dicts with the Kolmogorov-Smirnov test.
""" # noqa: E501
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description=prog_description)
parser.add_argument(
'model_a',
type=Path,
help='The path of the first checkpoint or model config.')
parser.add_argument(
'model_b',
type=Path,
help='The path of the second checkpoint or model config.')
parser.add_argument(
'--show',
action='store_true',
help='Whether to draw the KDE of variables')
parser.add_argument(
'-p',
default=0.01,
type=float,
help='The threshold of the p-value. '
'A higher threshold means a stricter test.')
args = parser.parse_args()
return args
def compare_distribution(state_dict_a, state_dict_b, p_thres):
assert len(state_dict_a) == len(state_dict_b)
for k, v1 in state_dict_a.items():
assert k in state_dict_b
v2 = state_dict_b[k]
v1 = v1.cpu().flatten()
v2 = v2.cpu().flatten()
pvalue = stats.kstest(v1, v2).pvalue
if pvalue < p_thres:
yield k, pvalue, v1, v2
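A sketch of consuming the generator directly (`sd_a` and `sd_b` stand for two already-loaded state dicts, as produced by `state_dict_from_cfg_or_ckpt` below):

```python
for key, pvalue, v1, v2 in compare_distribution(sd_a, sd_b, p_thres=0.01):
    # Only parameters whose distributions differ significantly are yielded;
    # v1 and v2 are the flattened tensors.
    print(f'{key}: p={pvalue:.4f}')
```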
def state_dict_from_cfg_or_ckpt(path, state_key=None):
if path.suffix in ['.json', '.py', '.yml']:
from mmengine.runner import get_state_dict
from mmpretrain.apis import init_model
model = init_model(path, device='cpu')
model.init_weights()
return get_state_dict(model)
else:
ckpt = torch.load(path, map_location='cpu')
return ckpt_to_state_dict(ckpt, state_key)
def main():
args = parse_args()
state_dict_a = state_dict_from_cfg_or_ckpt(args.model_a)
state_dict_b = state_dict_from_cfg_or_ckpt(args.model_b)
compare_keys = state_dict_a.keys() & state_dict_b.keys()
if len(compare_keys) == 0:
raise ValueError("The state dicts don't match, please convert "
'to the same keys before comparison.')
root = StateDictTree()
for key in track(compare_keys):
if state_dict_a[key].shape != state_dict_b[key].shape:
raise ValueError(f'The shapes of "{key}" are different. '
'Please make sure the models share the same architecture.')
# Sample at most 30000 items to prevent long calculation time.
perm_ids = torch.randperm(state_dict_a[key].numel())[:30000]
value_a = state_dict_a[key].flatten()[perm_ids]
value_b = state_dict_b[key].flatten()[perm_ids]
pvalue = stats.kstest(value_a, value_b).pvalue
if pvalue < args.p:
root.add_parameter(key, round(pvalue, 4))
if args.show:
try:
import seaborn as sns
except ImportError:
raise ImportError('Please install `seaborn` by '
'`pip install seaborn` to show KDE.')
sample_a = str([round(v.item(), 2) for v in value_a[:10]])
sample_b = str([round(v.item(), 2) for v in value_b[:10]])
if value_a.std() > 0:
sns.kdeplot(value_a, fill=True)
else:
sns.scatterplot(x=[value_a[0].item()], y=[1])
if value_b.std() > 0:
sns.kdeplot(value_b, fill=True)
else:
sns.scatterplot(x=[value_b[0].item()], y=[1])
plt.legend([
f'{args.model_a.stem}: {sample_a}',
f'{args.model_b.stem}: {sample_b}'
])
plt.title(key)
plt.show()
if len(root) > 0:
root.draw_tree(with_value=True)
print("Above parameters didn't pass the test, "
'and the values are their similarity score.')
else:
print('The distributions of all weights are the same.')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,501 @@
import argparse
import copy
import re
from functools import partial
from pathlib import Path
import yaml
from prompt_toolkit import ANSI
from prompt_toolkit import prompt as _prompt
from prompt_toolkit.completion import (FuzzyCompleter, FuzzyWordCompleter,
PathCompleter)
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Confirm, Prompt
from rich.syntax import Syntax
prog_description = """\
To display a metafile or fill in its missing fields.
"""
MMCLS_ROOT = Path(__file__).absolute().parents[1].resolve().absolute()
console = Console()
dataset_completer = FuzzyWordCompleter([
'ImageNet-1k', 'ImageNet-21k', 'CIFAR-10', 'CIFAR-100', 'RefCOCO', 'VQAv2',
'COCO', 'OpenImages', 'Object365', 'CC3M', 'CC12M', 'YFCC100M', 'VG'
])
def prompt(message,
allow_empty=True,
default=None,
multiple=False,
completer=None):
with console.capture() as capture:
console.print(message, end='')
message = ANSI(capture.get())
ask = partial(
_prompt, message=message, default=default or '', completer=completer)
out = ask()
if multiple:
outs = []
while out != '':
outs.append(out)
out = ask()
return outs
if not allow_empty and out == '':
while out == '':
out = ask()
if default is None and out == '':
return None
else:
return out.strip()
class MyDumper(yaml.Dumper):
def increase_indent(self, flow=False, indentless=False):
return super(MyDumper, self).increase_indent(flow, False)
yaml_dump = partial(
yaml.dump, Dumper=MyDumper, default_flow_style=False, sort_keys=False)
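The custom dumper exists to indent list items under their parent key; a quick illustration (the output is assumed from PyYAML's behavior with these settings):

```python
print(yaml_dump({'Models': [{'Name': 'resnet50_8xb32_in1k'}]}))
# Models:
#   - Name: resnet50_8xb32_in1k
# The default Dumper would left-align the '- Name' item with 'Models:'.
```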
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument('--src', type=Path, help='The path of the metafile.')
parser.add_argument('--out', '-o', type=Path, help='The output path.')
parser.add_argument(
'--inplace',
'-i',
action='store_true',
help='Modify the source metafile inplace.')
parser.add_argument(
'--view', action='store_true', help='Only pretty print the metafile.')
parser.add_argument('--csv', type=str, help='Use a csv to update models.')
args = parser.parse_args()
if args.inplace:
args.out = args.src
return args
def get_flops_params(config_path):
import numpy as np
import torch
from mmengine.analysis import FlopAnalyzer, parameter_count
from mmengine.dataset import Compose
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import DefaultScope
from mmpretrain.apis import get_model
from mmpretrain.models.utils import no_load_hf_pretrained_model
with no_load_hf_pretrained_model():
model = get_model(config_path, device='cpu')
model = revert_sync_batchnorm(model)
model.eval()
params = int(parameter_count(model)[''])
# get flops
try:
if 'test_dataloader' in model._config:
# build the data pipeline
test_dataset = model._config.test_dataloader.dataset
if test_dataset.pipeline[0]['type'] == 'LoadImageFromFile':
test_dataset.pipeline.pop(0)
if test_dataset.type in ['CIFAR10', 'CIFAR100']:
# The image shape of CIFAR is (32, 32, 3)
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
with DefaultScope.overwrite_default_scope('mmpretrain'):
data = Compose(test_dataset.pipeline)({
'img':
np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
})
resolution = tuple(data['inputs'].shape[-2:])
else:
# For configs only for get model.
resolution = (224, 224)
with torch.no_grad():
# Skip flops if the model doesn't have `extract_feat` method.
model.forward = model.extract_feat
model.to('cpu')
inputs = (torch.randn((1, 3, *resolution)), )
analyzer = FlopAnalyzer(model, inputs)
analyzer.unsupported_ops_warnings(False)
analyzer.uncalled_modules_warnings(False)
flops = int(analyzer.total())
except Exception:
print('Unable to calculate flops.')
flops = None
return flops, params
def fill_collection(collection: dict):
if collection.get('Name') is None:
name = prompt(
'Please input the collection [red]name[/]: ', allow_empty=False)
collection['Name'] = name
if collection.get('Metadata', {}).get('Architecture') is None:
architecture = prompt(
'Please input the model [red]architecture[/] '
'(input empty to finish): ',
multiple=True)
if len(architecture) > 0:
collection.setdefault('Metadata', {})
collection['Metadata']['Architecture'] = architecture
if collection.get('Paper', {}).get('Title') is None:
title = prompt('Please input the [red]paper title[/]: ')
else:
title = collection['Paper']['Title']
if collection.get('Paper', {}).get('URL') is None:
url = prompt('Please input the [red]paper url[/]: ')
else:
url = collection['Paper']['URL']
paper = dict(Title=title, URL=url)
collection['Paper'] = paper
if collection.get('README') is None:
readme = prompt(
'Please input the [red]README[/] file path: ',
completer=PathCompleter(file_filter=lambda name: Path(name).is_dir(
) or 'README.md' in name))
if readme is not None:
collection['README'] = str(
Path(readme).absolute().relative_to(MMCLS_ROOT))
else:
collection['README'] = None
order = ['Name', 'Metadata', 'Paper', 'README', 'Code']
collection = {
k: collection[k]
for k in sorted(collection.keys(), key=order.index)
}
return collection
def fill_model_by_prompt(model: dict, defaults: dict):
# Name
if model.get('Name') is None:
name = prompt(
'Please input the model [red]name[/]: ', allow_empty=False)
model['Name'] = name
# In Collection
model['In Collection'] = defaults.get('In Collection')
# Config
config = model.get('Config')
if config is None:
config = prompt(
'Please input the [red]config[/] file path: ',
completer=FuzzyCompleter(PathCompleter()))
if config is not None:
config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
model['Config'] = config
# Metadata.Flops, Metadata.Parameters
flops = model.get('Metadata', {}).get('FLOPs')
params = model.get('Metadata', {}).get('Parameters')
if model.get('Config') is not None and (
MMCLS_ROOT / model['Config']).exists() and (flops is None
and params is None):
print('Automatically computing FLOPs and Parameters from the config.')
flops, params = get_flops_params(str(MMCLS_ROOT / model['Config']))
if flops is None:
flops = prompt('Please specify the [red]FLOPs[/]: ')
if flops is not None:
flops = int(flops)
if params is None:
params = prompt('Please specify the [red]number of parameters[/]: ')
if params is not None:
params = int(params)
model.setdefault('Metadata', {})
model['Metadata'].setdefault('FLOPs', flops)
model['Metadata'].setdefault('Parameters', params)
if 'Training Data' not in model.get('Metadata', {}) and \
'Training Data' not in defaults.get('Metadata', {}):
training_data = prompt(
'Please input all [red]training datasets[/], '
'including pre-training (input empty to finish): ',
completer=dataset_completer,
multiple=True)
if len(training_data) > 1:
model['Metadata']['Training Data'] = training_data
elif len(training_data) == 1:
model['Metadata']['Training Data'] = training_data[0]
results = model.get('Results')
if results is None:
test_dataset = prompt(
'Please input the [red]test dataset[/]: ',
completer=dataset_completer)
if test_dataset is not None:
task = Prompt.ask(
'Please input the [red]test task[/]',
default='Image Classification')
if task == 'Image Classification':
metrics = {}
top1 = prompt('Please input the [red]top-1 accuracy[/]: ')
top5 = prompt('Please input the [red]top-5 accuracy[/]: ')
if top1 is not None:
metrics['Top 1 Accuracy'] = round(float(top1), 2)
if top5 is not None:
metrics['Top 5 Accuracy'] = round(float(top5), 2)
else:
metrics_list = prompt(
'Please input the [red]metrics[/] like "mAP=94.98" '
'(input empty to finish): ',
multiple=True)
metrics = {}
for metric in metrics_list:
k, v = metric.split('=')[:2]
metrics[k] = round(float(v), 2)
results = [{
'Task': task,
'Dataset': test_dataset,
'Metrics': metrics or None,
}]
model['Results'] = results
weights = model.get('Weights')
if weights is None:
weights = prompt('Please input the [red]checkpoint download link[/]: ')
model['Weights'] = weights
if model.get('Converted From') is None and model.get(
'Weights') is not None:
if '3rdparty' in model['Name'] or Confirm.ask(
'Is the checkpoint converted '
'from [red]another repository[/]?',
default=False):
converted_from = {}
converted_from['Weights'] = prompt(
'Please fill in the original checkpoint download link: ')
converted_from['Code'] = Prompt.ask(
'Please fill in the original repository link',
default=defaults.get('Convert From.Code', None))
defaults['Convert From.Code'] = converted_from['Code']
model['Converted From'] = converted_from
elif model.get('Converted From', {}).get('Code') is not None:
defaults['Convert From.Code'] = model['Converted From']['Code']
order = [
'Name', 'Metadata', 'In Collection', 'Results', 'Weights', 'Config',
'Converted From', 'Downstream'
]
model = {k: model[k] for k in sorted(model.keys(), key=order.index)}
return model
def update_model_by_dict(model: dict, update_dict: dict, defaults: dict):
# Name
if 'name override' in update_dict:
model['Name'] = update_dict['name override'].strip()
# In Collection
model['In Collection'] = defaults.get('In Collection')
# Config
if 'config' in update_dict:
config = update_dict['config'].strip()
config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
config_updated = (config != model.get('Config'))
model['Config'] = config
else:
config_updated = False
# Metadata.Flops, Metadata.Parameters
flops = model.get('Metadata', {}).get('FLOPs')
params = model.get('Metadata', {}).get('Parameters')
if config_updated or (flops is None and params is None):
print(f'Automatically compute FLOPs and Parameters of {model["Name"]}')
flops, params = get_flops_params(str(MMCLS_ROOT / model['Config']))
model.setdefault('Metadata', {})
model['Metadata']['FLOPs'] = flops
model['Metadata']['Parameters'] = params
# Metadata.Training Data
if 'training dataset' in update_dict:
train_data = update_dict['training dataset'].strip()
train_data = re.split(r'\s+', train_data)
if len(train_data) > 1:
model['Metadata']['Training Data'] = train_data
elif len(train_data) == 1:
model['Metadata']['Training Data'] = train_data[0]
# Results.Dataset
if 'test dataset' in update_dict:
test_data = update_dict['test dataset'].strip()
results = model.get('Results') or [{}]
result = results[0]
result['Dataset'] = test_data
model['Results'] = results
# Results.Metrics.Top 1 Accuracy
result = None
if 'top-1' in update_dict:
top1 = update_dict['top-1']
results = model.get('Results') or [{}]
result = results[0]
result.setdefault('Metrics', {})
result['Metrics']['Top 1 Accuracy'] = round(float(top1), 2)
task = 'Image Classification'
model['Results'] = results
# Results.Metrics.Top 5 Accuracy
if 'top-5' in update_dict:
top5 = update_dict['top-5']
results = model.get('Results') or [{}]
result = results[0]
result.setdefault('Metrics', {})
result['Metrics']['Top 5 Accuracy'] = round(float(top5), 2)
task = 'Image Classification'
model['Results'] = results
if result is not None:
result['Task'] = task
# Weights
if 'weights' in update_dict:
weights = update_dict['weights'].strip()
model['Weights'] = weights
# Converted From.Code
if 'converted from.code' in update_dict:
from_code = update_dict['converted from.code'].strip()
model.setdefault('Converted From', {})
model['Converted From']['Code'] = from_code
# Converted From.Weights
if 'converted from.weights' in update_dict:
from_weight = update_dict['converted from.weights'].strip()
model.setdefault('Converted From', {})
model['Converted From']['Weights'] = from_weight
order = [
'Name', 'Metadata', 'In Collection', 'Results', 'Weights', 'Config',
'Converted From', 'Downstream'
]
model = {k: model[k] for k in sorted(model.keys(), key=order.index)}
return model
def format_collection(collection: dict):
yaml_str = yaml_dump(collection)
return Panel(
Syntax(yaml_str, 'yaml', background_color='default'),
width=150,
title='Collection')
def format_model(model: dict):
yaml_str = yaml_dump(model)
return Panel(
Syntax(yaml_str, 'yaml', background_color='default'),
width=150,
title='Model')
def order_models(model):
order = []
# Pre-trained model
order.append(int('Downstream' not in model))
# non-3rdparty model
order.append(int('3rdparty' in model['Name']))
# smaller model
order.append(model.get('Metadata', {}).get('Parameters', 0))
# faster model
order.append(model.get('Metadata', {}).get('FLOPs', 0))
# name order
order.append(len(model['Name']))
return tuple(order)
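A sketch of the resulting sort keys (the names are partly hypothetical, and metadata is omitted, so the size terms are zero):

```python
m1 = {'Name': 'mae_vit-base-p16_8xb512-amp-coslr-300e_in1k',
      'Downstream': ['vit-base-p16_mae-pre_8xb128_in1k']}
m2 = {'Name': 'vit-base-p16_3rdparty_in1k'}
print(order_models(m1))  # (0, 0, 0, 0, 43) -- pre-training model sorts first
print(order_models(m2))  # (1, 1, 0, 0, 26)
```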
def main():
args = parse_args()
if args.src is not None:
with open(args.src, 'r') as f:
content = yaml.load(f, yaml.SafeLoader)
else:
content = {}
if args.view:
collection = content.get('Collections', [{}])[0]
console.print(format_collection(collection))
models = content.get('Models', [])
for model in models:
console.print(format_model(model))
return
collection = content.get('Collections', [{}])[0]
ori_collection = copy.deepcopy(collection)
console.print(format_collection(collection))
collection = fill_collection(collection)
if ori_collection != collection:
console.print(format_collection(collection))
model_defaults = {
'In Collection': collection['Name'],
'Metadata': collection.get('Metadata', {}),
}
models = content.get('Models', [])
updated_models = []
if args.csv is not None:
import pandas as pd
df = pd.read_csv(args.csv).rename(columns=lambda x: x.strip().lower())
assert df['name'].is_unique, 'The csv has duplicated model names.'
models_dict = {item['Name']: item for item in models}
for update_dict in df.to_dict('records'):
assert 'name' in update_dict, 'The csv must have the `Name` field.'
model_name = update_dict['name'].strip()
model = models_dict.pop(model_name, {'Name': model_name})
model = update_model_by_dict(model, update_dict, model_defaults)
updated_models.append(model)
updated_models.extend(models_dict.values())
else:
for model in models:
console.print(format_model(model))
ori_model = copy.deepcopy(model)
model = fill_model_by_prompt(model, model_defaults)
if ori_model != model:
console.print(format_model(model))
updated_models.append(model)
while Confirm.ask('Add new model?', default=False):
model = fill_model_by_prompt({}, model_defaults)
updated_models.append(model)
# Save the updated models even if an error happened.
updated_models.sort(key=order_models)
if args.out is not None:
with open(args.out, 'w') as f:
yaml_dump({'Collections': [collection]}, f)
f.write('\n')
yaml_dump({'Models': updated_models}, f)
else:
modelindex = {'Collections': [collection], 'Models': updated_models}
yaml_str = yaml_dump(modelindex)
console.print(Syntax(yaml_str, 'yaml', background_color='default'))
console.print('Specify [red]`--out`[/] to dump to file.')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,453 @@
# flake8: noqa
import argparse
import re
import warnings
from pathlib import Path
from modelindex.load_model_index import load
from modelindex.models.ModelIndex import ModelIndex
from tabulate import tabulate
MMPT_ROOT = Path(__file__).absolute().parents[1]
prog_description = """\
Use metafile to generate a README.md.
Note that the tool may fail in some corner cases; you still need to check the generated README and fill in some content manually.
"""
PREDICT_TEMPLATE = """\
**Predict image**
```python
from mmpretrain import inference_model
predict = inference_model('{model_name}', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
```
"""
RETRIEVE_TEMPLATE = """\
**Retrieve image**
```python
from mmpretrain import ImageRetrievalInferencer
inferencer = ImageRetrievalInferencer('{model_name}', prototype='demo/')
predict = inferencer('demo/dog.jpg', topk=2)[0]
print(predict[0])
print(predict[1])
```
"""
USAGE_TEMPLATE = """\
**Use the model**
```python
import torch
from mmpretrain import get_model
model = get_model('{model_name}', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```
"""
TRAIN_TEST_TEMPLATE = """\
**Train/Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py {train_config}
```
Test:
```shell
python tools/test.py {test_config} {test_weights}
```
"""
TEST_ONLY_TEMPLATE = """\
**Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Test:
```shell
python tools/test.py {test_config} {test_weights}
```
"""
METRIC_MAPPING = {
'Top 1 Accuracy': 'Top-1 (%)',
'Top 5 Accuracy': 'Top-5 (%)',
}
DATASET_PRIORITY = {
'ImageNet-1k': 0,
'CIFAR-10': 10,
'CIFAR-100': 20,
}
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument('metafile', type=Path, help='The path of metafile')
parser.add_argument(
'--table', action='store_true', help='Only generate summary tables')
parser.add_argument(
'--update', type=str, help='Update the specified readme file.')
parser.add_argument('--out', type=str, help='Output to the file.')
parser.add_argument(
'--update-items',
type=str,
nargs='+',
default=['models'],
help='The readme sections to update.')
args = parser.parse_args()
return args
def filter_models_by_task(models, task):
model_list = []
for model in models:
if model.results is None and task is None:
model_list.append(model)
elif model.results is None:
continue
elif model.results[0].task == task or task == 'any':
model_list.append(model)
return model_list
def add_title(metafile: ModelIndex):
paper = metafile.collections[0].paper
title = paper['Title']
url = paper['URL']
abbr = metafile.collections[0].name
papertype = metafile.collections[0].data.get('type', 'Algorithm')
return f'# {abbr}\n> [{title}]({url})\n<!-- [{papertype.upper()}] -->\n'
def add_abstract(metafile: ModelIndex):
paper = metafile.collections[0].paper
url = paper['URL']
if 'arxiv' in url:
try:
import arxiv
search = arxiv.Search(id_list=[url.split('/')[-1]])
info = next(search.results())
abstract = info.summary.replace('\n', ' ')
except ImportError:
warnings.warn('Install the arxiv package by `pip install arxiv` '
'to automatically generate the abstract.')
abstract = None
else:
abstract = None
content = '## Abstract\n'
if abstract is not None:
content += f'\n{abstract}\n'
return content
def add_usage(metafile):
models = metafile.models
if len(models) == 0:
return
content = []
content.append('## How to use it?\n\n<!-- [TABS-BEGIN] -->\n')
# Predict image
cls_models = filter_models_by_task(models, 'Image Classification')
if cls_models:
model_name = cls_models[0].name
content.append(PREDICT_TEMPLATE.format(model_name=model_name))
# Retrieve image
retrieval_models = filter_models_by_task(models, 'Image Retrieval')
if retrieval_models:
model_name = retrieval_models[0].name
content.append(RETRIEVE_TEMPLATE.format(model_name=model_name))
# Use the model
model_name = models[0].name
content.append(USAGE_TEMPLATE.format(model_name=model_name))
# Train/Test Command
inputs = {}
train_model = [
model for model in models
if 'headless' not in model.name and '3rdparty' not in model.name
]
if train_model:
template = TRAIN_TEST_TEMPLATE
inputs['train_config'] = train_model[0].config
elif len(filter_models_by_task(models, task='any')) > 0:
template = TEST_ONLY_TEMPLATE
else:
content.append('\n<!-- [TABS-END] -->\n')
return '\n'.join(content)
test_model = filter_models_by_task(models, task='any')[0]
inputs['test_config'] = test_model.config
inputs['test_weights'] = test_model.weights
content.append(template.format(**inputs))
content.append('\n<!-- [TABS-END] -->\n')
return '\n'.join(content)
def format_pretrain(pretrain_field):
pretrain_infos = pretrain_field.split('-')[:-1]
infos = []
for info in pretrain_infos:
if re.match(r'^\d+e$', info):
info = f'{info[:-1]}-Epochs'
elif re.match(r'^in\d+k$', info):
info = f'ImageNet-{info[2:-1]}k'
else:
info = info.upper()
infos.append(info)
return ' '.join(infos)
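Illustrative expansions of the pretrain field ('beitv2-in21k-pre' is a hypothetical name; 'mae-300e-pre' appears in bench_train.yml):

```python
print(format_pretrain('mae-300e-pre'))      # 'MAE 300-Epochs'
print(format_pretrain('beitv2-in21k-pre'))  # 'BEITV2 ImageNet-21k'
```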
def generate_model_table(models,
folder,
with_pretrain=True,
with_metric=True,
pretrained_models=[]):
header = ['Model']
if with_pretrain:
header.append('Pretrain')
header.extend(['Params (M)', 'Flops (G)'])
if with_metric:
metrics = set()
for model in models:
metrics.update(model.results[0].metrics.keys())
metrics = sorted(list(set(metrics)))
for metric in metrics:
header.append(METRIC_MAPPING.get(metric, metric))
header.extend(['Config', 'Download'])
rows = []
for model in models:
model_name = f'`{model.name}`'
config = (MMPT_ROOT / model.config).relative_to(folder)
if model.weights is not None:
download = f'[model]({model.weights})'
else:
download = 'N/A'
if 'Converted From' in model.data:
model_name += '\*'
converted_from = model.data['Converted From']
elif model.weights is not None:
log = re.sub(r'.pth$', '.json', model.weights)
download += f' \| [log]({log})'
row = [model_name]
if with_pretrain:
pretrain_field = [
field for field in model.name.split('_')
if field.endswith('-pre')
]
if pretrain_field:
pretrain = format_pretrain(pretrain_field[0])
upstream = [
pretrain_model for pretrain_model in pretrained_models
if model.name in pretrain_model.data.get('Downstream', [])
]
if upstream:
pretrain = f'[{pretrain}]({upstream[0].weights})'
else:
pretrain = 'From scratch'
row.append(pretrain)
if model.metadata.parameters is not None:
row.append(f'{model.metadata.parameters / 1e6:.2f}') # Params
else:
row.append('N/A')
if model.metadata.flops is not None:
row.append(f'{model.metadata.flops / 1e9:.2f}') # Flops
else:
row.append('N/A')
if with_metric:
for metric in metrics:
row.append(model.results[0].metrics.get(metric, 'N/A'))
row.append(f'[config]({config})')
row.append(download)
rows.append(row)
table_cfg = dict(
tablefmt='pipe',
floatfmt='.2f',
colalign=['left'] + ['center'] * (len(row) - 1))
table_string = tabulate(rows, header, **table_cfg) + '\n'
if any('Converted From' in model.data for model in models):
table_string += (
f"\n*Models with \* are converted from the [official repo]({converted_from['Code']}). "
"The config files of these models are only for inference. We haven't reproduce the training results.*\n"
)
return table_string
def add_models(metafile):
models = metafile.models
if len(models) == 0:
return ''
content = ['## Models and results\n']
algo_folder = Path(metafile.filepath).parent.absolute().resolve()
# Pretrained models
pretrain_models = filter_models_by_task(models, task=None)
if pretrain_models:
content.append('### Pretrained models\n')
content.append(
generate_model_table(
pretrain_models,
algo_folder,
with_pretrain=False,
with_metric=False))
# Classification models
tasks = [
'Image Classification',
'Image Retrieval',
'Multi-Label Classification',
'Image Caption',
'Visual Grounding',
'Visual Question Answering',
'Image-To-Text Retrieval',
'Text-To-Image Retrieval',
'NLVR',
]
for task in tasks:
task_models = filter_models_by_task(models, task=task)
if task_models:
datasets = {model.results[0].dataset for model in task_models}
datasets = sorted(
list(datasets), key=lambda x: DATASET_PRIORITY.get(x, 50))
for dataset in datasets:
content.append(f'### {task} on {dataset}\n')
dataset_models = [
model for model in task_models
if model.results[0].dataset == dataset
]
content.append(
generate_model_table(
dataset_models,
algo_folder,
pretrained_models=pretrain_models))
return '\n'.join(content)
def parse_readme(readme):
with open(readme, 'r') as f:
file = f.read()
content = {}
for img_match in re.finditer(
'^<div.*\n.*\n</div>\n', file, flags=re.MULTILINE):
content['image'] = img_match.group()
start, end = img_match.span()
file = file[:start] + file[end:]
break
sections = re.split('^## ', file, flags=re.MULTILINE)
for section in sections:
if section.startswith('# '):
content['title'] = section.strip() + '\n'
elif section.startswith('Introduction'):
content['intro'] = '## ' + section.strip() + '\n'
elif section.startswith('Abstract'):
content['abs'] = '## ' + section.strip() + '\n'
elif section.startswith('How to use it'):
content['usage'] = '## ' + section.strip() + '\n'
elif section.startswith('Models and results'):
content['models'] = '## ' + section.strip() + '\n'
elif section.startswith('Citation'):
content['citation'] = '## ' + section.strip() + '\n'
else:
section_title = section.split('\n', maxsplit=1)[0]
content[section_title] = '## ' + section.strip() + '\n'
return content
def combine_readme(content: dict):
content = content.copy()
readme = content.pop('title')
if 'intro' in content:
readme += f"\n{content.pop('intro')}"
readme += f"\n{content.pop('image')}"
readme += f"\n{content.pop('abs')}"
else:
readme += f"\n{content.pop('abs')}"
readme += f"\n{content.pop('image')}"
readme += f"\n{content.pop('usage')}"
readme += f"\n{content.pop('models')}"
citation = content.pop('citation')
if content:
# Custom sections
for v in content.values():
readme += f'\n{v}'
readme += f'\n{citation}'
return readme
def main():
args = parse_args()
metafile = load(str(args.metafile))
if args.table:
print(add_models(metafile))
return
if args.update is not None:
content = parse_readme(args.update)
else:
content = {}
if 'title' not in content or 'title' in args.update_items:
content['title'] = add_title(metafile)
if 'abs' not in content or 'abs' in args.update_items:
content['abs'] = add_abstract(metafile)
if 'image' not in content or 'image' in args.update_items:
img = '<div align=center>\n<img src="" width="50%"/>\n</div>\n'
content['image'] = img
if 'usage' not in content or 'usage' in args.update_items:
content['usage'] = add_usage(metafile)
if 'models' not in content or 'models' in args.update_items:
content['models'] = add_models(metafile)
if 'citation' not in content:
content['citation'] = '## Citation\n```bibtex\n```\n'
content = combine_readme(content)
if args.out is not None:
with open(args.out, 'w') as f:
f.write(content)
else:
print(content)
if __name__ == '__main__':
main()
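For reference, the helpers above compose just as well outside the CLI. A minimal sketch, assuming `load` comes from the `modelindex` package (as used by `main()`); the metafile path below is purely illustrative:

```python
from modelindex import load  # assumed provider of the `load` used in main()

# Illustrative path; point this at any algorithm's metafile.yml.
metafile = load('configs/resnet/metafile.yml')

# Prints only the '## Models and results' section, like the --table mode.
print(add_models(metafile))
```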


@ -1,33 +0,0 @@
---
name: Ask for help
about: Ran into a problem and need help
title: ''
labels: help wanted
assignees: ''
---
We recommend using the English template "General questions" so that your issue can help more people.
### Checklist
- I have searched related issues but did not find the help I need.
- I have read the related documentation but still do not know how to solve the problem.
### Describe the problem you ran into
\[here\]
### Post related information
1. The output of `pip list | grep "mmcv\|mmcls\|^torch"`
\[here\]
2. Your config file if you modified it or created a new one.
```python
[here]
```
3. Your full training log and error messages if the problem occurred during training.
\[here\]
4. Any other related modifications you made to the code under the `mmcls` folder.
\[here\]


@ -1,34 +0,0 @@
---
name: New feature
about: Suggest an idea for this project
title: '[Feature]'
labels: enhancement
assignees: ''
---
We recommend using the English template "Feature request" so that your issue can help more people.
### Describe the feature
\[here\]
### Motivation
Please briefly explain why this new feature is needed.
Ex1. It is inconvenient when doing xxx.
Ex2. A recent paper proposed a very helpful xx.
\[here\]
### Related resources
Is there an official or third-party implementation of it? Such references are very helpful.
\[here\]
### Additional context
Any other information or screenshots related to this feature can be put here.
If you would like to implement the feature and submit a PR, please also note it here; contributions are very welcome.
\[here\]


@ -1,44 +0,0 @@
---
name: Report a bug
about: Create a report to help us improve
title: '[Bug]'
labels: bug
assignees: ''
---
We recommend using the English template "Bug report" so that your issue can help more people.
### Describe the bug
A brief description of the bug you ran into.
\[here\]
### To reproduce
The exact commands you executed on the command line.
```shell
[here]
```
### Post related information
1. The output of `pip list | grep "mmcv\|mmcls\|^torch"`
\[here\]
2. Your config file if you modified it or created a new one.
```python
[here]
```
3. Your full training log and error messages if the problem occurred during training.
\[here\]
4. Any other related modifications you made to the code under the `mmcls` folder.
\[here\]
### Additional context
Any other information or screenshots about the bug can be put here.
\[here\]


@ -0,0 +1,69 @@
name: 🐞 Bug report
description: Create a report to help us improve
labels: ["bug"]
title: "[Bug] "

body:
  - type: markdown
    attributes:
      value: |
        If you have already identified the reason, we strongly appreciate you creating a new PR according to [the tutorial](https://mmpretrain.readthedocs.io/en/master/community/CONTRIBUTING.html)!
        If you need our help, please fill in the following form to help us identify the bug.

  - type: dropdown
    id: version
    attributes:
      label: Branch
      description: Which branch/version are you using?
      options:
        - main branch (mmpretrain version)
        - mmcls-1.x branch (v1.0.0rc6 or other 1.x version)
        - mmcls-0.x branch (v0.25.0 or other 0.x version)
    validations:
      required: true

  - type: textarea
    id: describe
    validations:
      required: true
    attributes:
      label: Describe the bug
      description: |
        Please provide a clear and concise description of what the bug is.
        Preferably include a simple and minimal code snippet so that we can reproduce the error by running it.
      placeholder: |
        A clear and concise description of what the bug is.
        ```python
        # Sample code to reproduce the problem
        ```
        ```shell
        The command or script you run.
        ```
        ```
        The error message or logs you got, with the full traceback.
        ```

  - type: textarea
    id: environment
    validations:
      required: true
    attributes:
      label: Environment
      description: |
        Please run `python -c "import mmpretrain.utils;import pprint;pprint.pp(dict(mmpretrain.utils.collect_env()))"` to collect necessary environment information and paste it here.
      placeholder: |
        ```python
        # The output of the above command
        ```

  - type: textarea
    id: other
    attributes:
      label: Other information
      description: |
        Tell us anything else you think we should know.
        1. Did you make any modifications to the code or config?
        2. What do you think might be the reason?


@ -0,0 +1,29 @@
name: 🚀 Feature request
description: Suggest an idea for this project
labels: ["enhancement"]
title: "[Feature] "

body:
  - type: markdown
    attributes:
      value: |
        If you have already implemented the feature, we strongly appreciate you creating a new PR according to [the tutorial](https://mmpretrain.readthedocs.io/en/master/community/CONTRIBUTING.html)!

  - type: textarea
    id: describe
    validations:
      required: true
    attributes:
      label: Describe the feature
      description: |
        What kind of feature do you want MMPreTrain to add? If there is an official code release or third-party implementation, please also provide the information here, which would be very helpful.
      placeholder: |
        A clear and concise description of the motivation of the feature.
        Ex1. It is inconvenient when \[....\].
        Ex2. There is a recent paper \[....\], which is very helpful for \[....\].

  - type: checkboxes
    id: pr
    attributes:
      label: Will you implement it?
      options:
        - label: I would like to implement this feature and create a PR!


@ -0,0 +1,70 @@
name: 🐞 Report a bug
description: Report unexpected behavior you ran into while using the project
labels: ["bug"]
title: "[Bug] "

body:
  - type: markdown
    attributes:
      value: |
        We recommend using the English template "Bug report" so that your issue can help more people.
        If you already have a solution, we warmly welcome you to create a new PR to fix it directly. See the [documentation](https://mmpretrain.readthedocs.io/zh_CN/master/community/CONTRIBUTING.html) for the PR workflow.
        If you need our help, please fill in the form below to help us locate the bug.

  - type: dropdown
    id: version
    attributes:
      label: Branch
      description: Which branch/version are you using?
      options:
        - main branch (mmpretrain version)
        - mmcls-1.x branch (v1.0.0rc6 or other 1.x version)
        - mmcls-0.x branch (v0.25.0 or other 0.x version)
    validations:
      required: true

  - type: textarea
    id: describe
    validations:
      required: true
    attributes:
      label: Describe the bug
      description: |
        Please briefly describe the error you ran into. If possible, provide a short code snippet that helps us reproduce it.
      placeholder: |
        A brief description of the problem.
        ```python
        # Code snippet that reproduces the error
        ```
        ```shell
        # The command you ran when the error occurred
        ```
        ```
        The error message and logs, with the full traceback.
        ```

  - type: textarea
    id: environment
    validations:
      required: true
    attributes:
      label: Environment
      description: |
        Please run `python -c "import mmpretrain.utils;import pprint;pprint.pp(dict(mmpretrain.utils.collect_env()))"` to collect the necessary environment information and paste it below.
      placeholder: |
        ```python
        # The output of the above command
        ```

  - type: textarea
    id: other
    attributes:
      label: Other information
      description: |
        Tell us anything else you think is valuable.
        1. Did you make any changes to the code or config files?
        2. What do you think might be the cause?


@ -0,0 +1,31 @@
name: 🚀 Feature request
description: Suggest a new feature
labels: ["enhancement"]
title: "[Feature] "

body:
  - type: markdown
    attributes:
      value: |
        We recommend using the English template "Feature request" so that your issue can help more people.
        If you have already implemented the feature, we warmly welcome you to create a new PR directly. See the [documentation](https://mmpretrain.readthedocs.io/zh_CN/master/community/CONTRIBUTING.html) for the PR workflow.

  - type: textarea
    id: describe
    validations:
      required: true
    attributes:
      label: Describe the feature
      description: |
        What feature would you like MMPreTrain to add? If there is a related paper, an official implementation, or a third-party implementation, please also post the links here; that would be very helpful.
      placeholder: |
        A brief description of the feature and why it is needed.
        Ex1. It is inconvenient when doing xxx.
        Ex2. A recent paper proposed a very helpful xx.

  - type: checkboxes
    id: pr
    attributes:
      label: Would you like to implement it yourself?
      options:
        - label: I would like to implement this feature myself and contribute the code to MMPreTrain!


@ -1,42 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: '[Bug]'
labels: bug
assignees: ''
---
### Describe the bug
A clear and concise description of what the bug is.
\[here\]
### To Reproduce
The command you executed.
```shell
[here]
```
### Post related information
1. The output of `pip list | grep "mmcv\|mmcls\|^torch"`
\[here\]
2. Your config file if you modified it or created a new one.
```python
[here]
```
3. Your train log file if you meet the problem during training.
\[here\]
4. Other code you modified in the `mmcls` folder.
\[here\]
### Additional context
Add any other context about the problem here.
\[here\]


@ -1,6 +1,12 @@
blank_issues_enabled: false
contact_links:
  - name: MMClassification Documentation
    url: https://mmclassification.readthedocs.io/en/latest/
  - name: 📚 MMPreTrain Documentation (official documentation)
    url: https://mmpretrain.readthedocs.io/en/latest/
    about: Check if your question is answered in docs
  - name: 💬 General questions (ask for help)
    url: https://github.com/open-mmlab/mmpretrain/discussions
    about: Ask general usage questions and discuss with other MMPreTrain community members
  - name: 🌐 Explore OpenMMLab (official website)
    url: https://openmmlab.com/
    about: Get to know more about OpenMMLab


@ -1,32 +0,0 @@
---
name: Feature request
about: Suggest an idea for this project
title: '[Feature]'
labels: enhancement
assignees: ''
---
### Describe the feature
\[here\]
### Motivation
A clear and concise description of the motivation of the feature.
Ex1. It is inconvenient when \[....\].
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
\[here\]
### Related resources
If there is an official code release or third-party implementation, please also provide the information here, which would be very helpful.
\[here\]
### Additional context
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
\[here\]


@ -1,31 +0,0 @@
---
name: General questions
about: 'Ask general questions to get help '
title: ''
labels: help wanted
assignees: ''
---
### Checklist
- I have searched related issues but cannot get the expected help.
- I have read related documents and don't know what to do.
### Describe the question you meet
\[here\]
### Post related information
1. The output of `pip list | grep "mmcv\|mmcls\|^torch"`
\[here\]
2. Your config file if you modified it or created a new one.
```python
[here]
```
3. Your train log file if you meet the problem during training.
\[here\]
4. Other code you modified in the `mmcls` folder.
\[here\]


@ -1,22 +0,0 @@
name: deploy

on: push

jobs:
  build-n-publish:
    runs-on: ubuntu-latest
    if: startsWith(github.event.ref, 'refs/tags')
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.7
        uses: actions/setup-python@v2
        with:
          python-version: 3.7
      - name: Build MMClassification
        run: |
          pip install wheel
          python setup.py sdist bdist_wheel
      - name: Publish distribution to PyPI
        run: |
          pip install twine
          twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}


@ -10,9 +10,9 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python 3.7
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.7
- name: Install pre-commit hook
@ -24,4 +24,4 @@ jobs:
- name: Check docstring coverage
run: |
pip install interrogate
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmcls
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmpretrain


@ -18,7 +18,7 @@ concurrency:
jobs:
build:
runs-on: ubuntu-18.04
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: [3.7]
@ -26,29 +26,77 @@ jobs:
- torch: 1.8.1
torchvision: 0.9.1
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install mmcls dependencies
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Install mmpretrain dependencies
run: |
pip install git+https://github.com/open-mmlab/mmengine.git@main
pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
mim install 'mmcv >= 2.0.0rc4'
pip install -r requirements.txt
- name: Build and install
run: mim install .
- name: Run unittests and generate coverage report
run: |
coverage run --branch --source mmpretrain -m pytest tests/
coverage xml
coverage report -m
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1.0.14
with:
file: ./coverage.xml
flags: unittests
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false
build_cu117:
runs-on: ubuntu-22.04
container:
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
strategy:
matrix:
python-version: [3.9]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
- name: Fetch GPG keys
run: |
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
- name: Install Python-dev
run: apt-get update && apt-get install -y python${{matrix.python-version}}-dev
if: ${{matrix.python-version != 3.9}}
- name: Install system dependencies
run: |
apt-get update
apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev libc6 libc6-dev
- name: Install mmpretrain dependencies
run: |
pip install git+https://github.com/open-mmlab/mmengine.git@main
pip install -U openmim
mim install 'mmcv >= 2.0.0rc4'
pip install -r requirements.txt
- name: Build and install
run: pip install -e .
- name: Run unittests and generate coverage report
run: |
coverage run --branch --source mmcls -m pytest tests/ -k 'not timm'
coverage run --branch --source mmpretrain -m pytest tests/ --ignore tests/test_tools.py
coverage xml
coverage report -m
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1.0.14
with:
@ -62,26 +110,26 @@ jobs:
runs-on: windows-2022
strategy:
matrix:
python: [3.7]
platform: [cu111]
python-version: [3.7]
platform: [cpu]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
run: python -m pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==1.8.2+${{matrix.platform}} torchvision==0.9.2+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
- name: Install mmcls dependencies
- name: Install mmpretrain dependencies
run: |
pip install git+https://github.com/open-mmlab/mmengine.git@main
pip install -U openmim
mim install 'mmcv >= 2.0.0rc1'
mim install mmengine
mim install 'mmcv >= 2.0.0rc4'
pip install -r requirements.txt
- name: Build and install
run: pip install -e .
run: mim install .
- name: Run unittests
run: |
pytest tests/ -k 'not timm' --ignore tests/test_models/test_backbones
pytest tests/ --ignore tests/test_models/test_backbones


@ -7,12 +7,12 @@ jobs:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python 3.7
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.7
- name: Build MMClassification
- name: Build MMPretrain
run: |
pip install wheel
python setup.py sdist bdist_wheel


@ -17,7 +17,7 @@ concurrency:
jobs:
build_cpu:
runs-on: ubuntu-18.04
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: [3.7]
@ -27,9 +27,9 @@ jobs:
torch_version: torch1.8
torchvision: 0.9.0
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
@ -39,6 +39,6 @@ jobs:
- name: Install openmim
run: pip install openmim
- name: Build and install
run: mim install -e .
run: mim install .
- name: test commands of mim
run: mim search mmcls
run: mim search mmpretrain

.gitignore

@ -76,6 +76,9 @@ docs/zh_CN/_model_zoo.rst
docs/zh_CN/modelzoo_statistics.md
docs/zh_CN/papers/
docs/zh_CN/api/generated/
docs/zh_CN/locales/
!docs/zh_CN/locales/zh_CN/LC_MESSAGES/api.po
!docs/zh_CN/locales/zh_CN/LC_MESSAGES/papers.po
# PyBuilder
target/
@ -122,7 +125,9 @@ venv.bak/
*.pkl.json
*.log.json
/work_dirs
/mmcls/.mim
/projects/*/work_dirs
/projects/*/data
/mmpretrain/.mim
.DS_Store
# Pytorch
@ -133,3 +138,13 @@ venv.bak/
*.pvti-journal
/cache_engine
/report
# slurm
*.out
# tensorflow
*.tar.gz
checkpoint
model_params.txt
*.ckpt*
results.txt


@ -5,7 +5,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
@ -29,9 +29,9 @@ repos:
rev: 0.7.9
hooks:
- id: mdformat
args: ["--number", "--table-width", "200"]
args: ["--number", "--table-width", "200", '--disable-escape', 'backslash', '--disable-escape', 'link-enclosure']
additional_dependencies:
- mdformat-openmmlab
- "mdformat-openmmlab>=0.0.4"
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
@ -47,7 +47,18 @@ repos:
rev: v0.4.0
hooks:
- id: check-copyright
args: ["mmcls", "tests", "demo", "tools", "--excludes", "mmcls/.mim/", "--ignore-file-not-found-error"]
args: ["mmpretrain", "tests", "demo", "tools", "--excludes", "mmpretrain/.mim/", "--ignore-file-not-found-error"]
- repo: local
hooks:
- id: metafile
args: ['--skip', 'flops-param']
name: metafile
description: Check the format of metafile
entry: python .dev_scripts/check_metafile.py
language: python
files: (metafile)\.(yml)$
additional_dependencies:
- modelindex
# - repo: local
# hooks:
# - id: clang-format
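
The new local `metafile` hook can also be run on demand instead of waiting for a commit. A minimal sketch, assuming `pre-commit` is installed and this config is in place:

```shell
# Run only the metafile format check against all matching files
pre-commit run metafile --all-files
```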


@ -1,9 +1,15 @@
version: 2
formats: all
# Set the version of Python and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.8"
formats:
- epub
python:
version: 3.7
install:
- requirements: requirements/docs.txt
- requirements: requirements/readthedocs.txt


@ -1,9 +1,9 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
title: "OpenMMLab's Image Classification Toolbox and Benchmark"
title: "OpenMMLab's Pre-training Toolbox and Benchmark"
authors:
- name: "MMClassification Contributors"
- name: "MMPreTrain Contributors"
version: 0.15.0
date-released: 2020-07-09
repository-code: "https://github.com/open-mmlab/mmclassification"
date-released: 2023-04-06
repository-code: "https://github.com/open-mmlab/mmpretrain"
license: Apache-2.0


@ -1,13 +1,13 @@
# Contributing to MMClassification
# Contributing to MMPreTrain
- [Contributing to MMClassification](#contributing-to-mmclassification)
- [Contributing to MMPreTrain](#contributing-to-mmpretrain)
- [Workflow](#workflow)
- [Code style](#code-style)
- [Python](#python)
- [C++ and CUDA](#c-and-cuda)
- [Pre-commit Hook](#pre-commit-hook)
Thanks for your interest in contributing to MMClassification! All kinds of contributions are welcome, including but not limited to the following.
Thanks for your interest in contributing to MMPreTrain! All kinds of contributions are welcome, including but not limited to the following.
- Fix typo or bugs
- Add documentation or translate the documentation into other languages
@ -17,7 +17,7 @@ Thanks for your interest in contributing to MMClassification! All kinds of contr
We recommend that potential contributors follow this workflow for contribution.
1. Fork and pull the latest MMClassification repository, follow [get started](https://mmclassification.readthedocs.io/en/1.x/get_started.html) to setup the environment.
1. Fork and pull the latest MMPreTrain repository, follow [get started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) to setup the environment.
2. Checkout a new branch (**do not use the master or dev branch** for PRs)
```bash
@ -44,7 +44,7 @@ We use the following tools for linting and formatting:
- [mdformat](https://github.com/executablebooks/mdformat): Mdformat is an opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files.
- [docformatter](https://github.com/myint/docformatter): A formatter to format docstring.
Style configurations of yapf and isort can be found in [setup.cfg](https://github.com/open-mmlab/mmclassification/blob/1.x/setup.cfg).
Style configurations of yapf and isort can be found in [setup.cfg](https://github.com/open-mmlab/mmpretrain/blob/main/setup.cfg).
### C++ and CUDA
@ -54,7 +54,7 @@ We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppgu
We use [pre-commit hook](https://pre-commit.com/) that checks and formats for `flake8`, `yapf`, `isort`, `trailing whitespaces`, `markdown files`,
fixes `end-of-files`, `double-quoted-strings`, `python-encoding-pragma`, `mixed-line-ending`, and sorts `requirements.txt` automatically on every commit.
The config for a pre-commit hook is stored in [.pre-commit-config](https://github.com/open-mmlab/mmclassification/blob/1.x/.pre-commit-config.yaml).
The config for a pre-commit hook is stored in [.pre-commit-config](https://github.com/open-mmlab/mmpretrain/blob/main/.pre-commit-config.yaml).
After you clone the repository, you will need to install and initialize the pre-commit hook.
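
For completeness, the usual two commands to do so look like the following; a sketch assuming `pip` manages your environment:

```shell
pip install -U pre-commit  # install the tool itself
pre-commit install         # register the git hook in this repository
```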


@ -188,7 +188,7 @@ Copyright (c) OpenMMLab. All rights reserved
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 MMClassification Authors.
Copyright 2020 MMPreTrain Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.


@ -1,4 +1,5 @@
include requirements/*.txt
include mmcls/.mim/model-index.yml
recursive-include mmcls/.mim/configs *.py *.yml
recursive-include mmcls/.mim/tools *.py *.sh
include mmpretrain/.mim/model-index.yml
include mmpretrain/.mim/dataset-index.yml
recursive-include mmpretrain/.mim/configs *.py *.yml
recursive-include mmpretrain/.mim/tools *.py *.sh

README.md

@ -1,6 +1,6 @@
<div align="center">
<img src="resources/mmcls-logo.png" width="600"/>
<img src="resources/mmpt-logo.png" width="600"/>
<div>&nbsp;</div>
<div align="center">
<b><font size="5">OpenMMLab website</font></b>
@ -19,60 +19,103 @@
</div>
<div>&nbsp;</div>
[![PyPI](https://img.shields.io/pypi/v/mmcls)](https://pypi.org/project/mmcls)
[![Docs](https://img.shields.io/badge/docs-latest-blue)](https://mmclassification.readthedocs.io/en/1.x/)
[![Build Status](https://github.com/open-mmlab/mmclassification/workflows/build/badge.svg)](https://github.com/open-mmlab/mmclassification/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmclassification/branch/1.x/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmclassification)
[![license](https://img.shields.io/github/license/open-mmlab/mmclassification.svg)](https://github.com/open-mmlab/mmclassification/blob/1.x/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmclassification.svg)](https://github.com/open-mmlab/mmclassification/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmclassification.svg)](https://github.com/open-mmlab/mmclassification/issues)
[![PyPI](https://img.shields.io/pypi/v/mmpretrain)](https://pypi.org/project/mmpretrain)
[![Docs](https://img.shields.io/badge/docs-latest-blue)](https://mmpretrain.readthedocs.io/en/latest/)
[![Build Status](https://github.com/open-mmlab/mmpretrain/workflows/build/badge.svg)](https://github.com/open-mmlab/mmpretrain/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmpretrain/branch/main/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmpretrain)
[![license](https://img.shields.io/github/license/open-mmlab/mmpretrain.svg)](https://github.com/open-mmlab/mmpretrain/blob/main/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmpretrain.svg)](https://github.com/open-mmlab/mmpretrain/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmpretrain.svg)](https://github.com/open-mmlab/mmpretrain/issues)
[📘 Documentation](https://mmclassification.readthedocs.io/en/1.x/) |
[🛠️ Installation](https://mmclassification.readthedocs.io/en/1.x/get_started.html) |
[👀 Model Zoo](https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html) |
[🆕 Update News](https://mmclassification.readthedocs.io/en/1.x/notes/changelog.html) |
[🤔 Reporting Issues](https://github.com/open-mmlab/mmclassification/issues/new/choose)
[📘 Documentation](https://mmpretrain.readthedocs.io/en/latest/) |
[🛠️ Installation](https://mmpretrain.readthedocs.io/en/latest/get_started.html#installation) |
[👀 Model Zoo](https://mmpretrain.readthedocs.io/en/latest/modelzoo_statistics.html) |
[🆕 Update News](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html) |
[🤔 Reporting Issues](https://github.com/open-mmlab/mmpretrain/issues/new/choose)
<img src="https://user-images.githubusercontent.com/36138628/230307505-4727ad0a-7d71-4069-939d-b499c7e272b7.png" width="400"/>
English | [Simplified Chinese](/README_zh-CN.md)
</div>
</div>
<div align="center">
<a href="https://openmmlab.medium.com/" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://discord.gg/raweFPmdzG" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
</div>
## Introduction
English | [Simplified Chinese](/README_zh-CN.md)
MMPreTrain is an open source pre-training toolbox based on PyTorch. It is a part of the [OpenMMLab](https://openmmlab.com/) project.
MMClassification is an open source image classification toolbox based on PyTorch. It is
a part of the [OpenMMLab](https://openmmlab.com/) project.
The `1.x` branch works with **PyTorch 1.6+**.
<div align="center">
<img src="https://user-images.githubusercontent.com/9102141/87268895-3e0d0780-c4fe-11ea-849e-6140b7e0d4de.gif" width="70%"/>
</div>
The `main` branch works with **PyTorch 1.8+**.
### Major features
- Various backbones and pretrained models
- Rich training strategies (supervised learning, self-supervised learning, multi-modality learning etc.)
- Bag of training tricks
- Large-scale training configs
- High efficiency and extensibility
- Powerful toolkits
- Powerful toolkits for model analysis and experiments
- Various out-of-box inference tasks.
- Image Classification
- Image Caption
- Visual Question Answering
- Visual Grounding
- Retrieval (Image-To-Image, Text-To-Image, Image-To-Text)
https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351-fbc74a04e904
## What's new
v1.0.0rc1 was released in 30/9/2022.
🌟 v1.2.0 was released in 04/01/2024
- Support MViT, EdgeNeXt, Swin-Transformer V2, EfficientFormer and MobileOne.
- Support BEiT type transformer layer.
- Support LLaVA 1.5.
- Implement RAM with a Gradio interface.
v1.0.0rc0 was released in 31/8/2022.
🌟 v1.1.0 was released in 12/10/2023
- Support Mini-GPT4 training and provide a Chinese model (based on Baichuan-7B)
- Support zero-shot classification based on CLIP.
🌟 v1.0.0 was released in 04/07/2023
- Support inference of more **multi-modal** algorithms, such as [**LLaVA**](./configs/llava/), [**MiniGPT-4**](./configs/minigpt4), [**Otter**](./configs/otter/), etc.
- Support around **10 multi-modal** datasets!
- Add [**iTPN**](./configs/itpn/), [**SparK**](./configs/spark/) self-supervised learning algorithms.
- Provide examples of [New Config](./mmpretrain/configs/) and [DeepSpeed/FSDP with FlexibleRunner](./configs/mae/benchmarks/). Here are the documentation links of [New Config](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta) and [DeepSpeed/FSDP with FlexibleRunner](https://mmengine.readthedocs.io/en/latest/api/generated/mmengine.runner.FlexibleRunner.html#mmengine.runner.FlexibleRunner).
🌟 Upgrade from MMClassification to MMPreTrain
- Integrated Self-supervised learning algorithms from **MMSelfSup**, such as **MAE**, **BEiT**, etc.
- Support **RIFormer**, a simple but effective vision backbone by removing token mixer.
- Refactor dataset pipeline visualization.
- Support **LeViT**, **XCiT**, **ViG**, **ConvNeXt-V2**, **EVA**, **RevViT**, **EfficientnetV2**, **CLIP**, **TinyViT** and **MixMIM** backbones.
This release introduced a brand new and flexible training & test engine, but it's still in progress. Welcome
to try according to [the documentation](https://mmclassification.readthedocs.io/en/1.x/).
to try according to [the documentation](https://mmpretrain.readthedocs.io/en/latest/).
And there are some BC-breaking changes. Please check [the migration tutorial](https://mmclassification.readthedocs.io/en/1.x/migration.html).
And there are some BC-breaking changes. Please check [the migration tutorial](https://mmpretrain.readthedocs.io/en/latest/migration.html).
The release candidate will last until the end of 2022, and during the release candidate, we will develop on the `1.x` branch. And we will still maintain the 0.x version until at least the end of 2023.
Please refer to [changelog.md](https://mmclassification.readthedocs.io/en/1.x/notes/changelog.html) for more details and other release history.
Please refer to [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html) for more details and other release history.
## Installation
@ -82,89 +125,186 @@ Below are quick steps for installation:
conda create -n open-mmlab python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
conda activate open-mmlab
pip install openmim
git clone -b 1.x https://github.com/open-mmlab/mmclassification.git
cd mmclassification
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
mim install -e .
```
Please refer to [install.md](https://mmclassification.readthedocs.io/en/1.x/get_started.html) for more detailed installation and dataset preparation.
Please refer to [installation documentation](https://mmpretrain.readthedocs.io/en/latest/get_started.html) for more detailed installation and dataset preparation.
For multi-modality models support, please install the extra dependencies by:
```shell
mim install -e ".[multimodal]"
```
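
After installation, one quick sanity check is an out-of-box inference task. A minimal sketch using the high-level API; the model name and demo image path are illustrative:

```python
from mmpretrain import inference_model

# Any image classification model name from the model zoo works here.
result = inference_model('resnet18_8xb32_in1k', 'demo/demo.JPEG')
print(result['pred_class'])
```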
## User Guides
We provided a series of tutorials about the basic usage of MMClassification for new users:
We provided a series of tutorials about the basic usage of MMPreTrain for new users:
- [Inference with existing models](https://mmclassification.readthedocs.io/en/1.x/user_guides/inference.html)
- [Prepare Dataset](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html)
- [Training and Test](https://mmclassification.readthedocs.io/en/1.x/user_guides/train_test.html)
- [Learn about Configs](https://mmclassification.readthedocs.io/en/1.x/user_guides/config.html)
- [Fine-tune Models](https://mmclassification.readthedocs.io/en/1.x/user_guides/finetune.html)
- [Analysis Tools](https://mmclassification.readthedocs.io/en/1.x/user_guides/analysis.html)
- [Visualization Tools](https://mmclassification.readthedocs.io/en/1.x/user_guides/visualization.html)
- [Other Useful Tools](https://mmclassification.readthedocs.io/en/1.x/user_guides/useful_tools.html)
- [Learn about Configs](https://mmpretrain.readthedocs.io/en/latest/user_guides/config.html)
- [Prepare Dataset](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html)
- [Inference with existing models](https://mmpretrain.readthedocs.io/en/latest/user_guides/inference.html)
- [Train](https://mmpretrain.readthedocs.io/en/latest/user_guides/train.html)
- [Test](https://mmpretrain.readthedocs.io/en/latest/user_guides/test.html)
- [Downstream tasks](https://mmpretrain.readthedocs.io/en/latest/user_guides/downstream.html)
For more information, please refer to [our documentation](https://mmpretrain.readthedocs.io/en/latest/).
## Model zoo
Results and models are available in the [model zoo](https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html).
Results and models are available in the [model zoo](https://mmpretrain.readthedocs.io/en/latest/modelzoo_statistics.html).
<details open>
<summary>Supported backbones</summary>
- [x] [VGG](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/vgg)
- [x] [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet)
- [x] [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext)
- [x] [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet)
- [x] [SE-ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet)
- [x] [RegNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/regnet)
- [x] [ShuffleNetV1](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/shufflenet_v1)
- [x] [ShuffleNetV2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/shufflenet_v2)
- [x] [MobileNetV2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilenet_v2)
- [x] [MobileNetV3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilenet_v3)
- [x] [Swin-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/swin_transformer)
- [x] [Swin-Transformer V2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/swin_transformer_v2)
- [x] [RepVGG](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/repvgg)
- [x] [Vision-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/vision_transformer)
- [x] [Transformer-in-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/tnt)
- [x] [Res2Net](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/res2net)
- [x] [MLP-Mixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mlp_mixer)
- [x] [DeiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/deit)
- [x] [Conformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/conformer)
- [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/t2t_vit)
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/twins)
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientnet)
- [x] [EdgeNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/edgenext)
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/convnext)
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/hrnet)
- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/van)
- [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/convmixer)
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/poolformer)
- [x] [Inception V3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/inception_v3)
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
</details>
<div align="center">
<b>Overview</b>
</div>
<table align="center">
<tbody>
<tr align="center" valign="bottom">
<td>
<b>Supported Backbones</b>
</td>
<td>
<b>Self-supervised Learning</b>
</td>
<td>
<b>Multi-Modality Algorithms</b>
</td>
<td>
<b>Others</b>
</td>
</tr>
<tr valign="top">
<td>
<ul>
<li><a href="configs/vgg">VGG</a></li>
<li><a href="configs/resnet">ResNet</a></li>
<li><a href="configs/resnext">ResNeXt</a></li>
<li><a href="configs/seresnet">SE-ResNet</a></li>
<li><a href="configs/seresnet">SE-ResNeXt</a></li>
<li><a href="configs/regnet">RegNet</a></li>
<li><a href="configs/shufflenet_v1">ShuffleNet V1</a></li>
<li><a href="configs/shufflenet_v2">ShuffleNet V2</a></li>
<li><a href="configs/mobilenet_v2">MobileNet V2</a></li>
<li><a href="configs/mobilenet_v3">MobileNet V3</a></li>
<li><a href="configs/swin_transformer">Swin-Transformer</a></li>
<li><a href="configs/swin_transformer_v2">Swin-Transformer V2</a></li>
<li><a href="configs/repvgg">RepVGG</a></li>
<li><a href="configs/vision_transformer">Vision-Transformer</a></li>
<li><a href="configs/tnt">Transformer-in-Transformer</a></li>
<li><a href="configs/res2net">Res2Net</a></li>
<li><a href="configs/mlp_mixer">MLP-Mixer</a></li>
<li><a href="configs/deit">DeiT</a></li>
<li><a href="configs/deit3">DeiT-3</a></li>
<li><a href="configs/conformer">Conformer</a></li>
<li><a href="configs/t2t_vit">T2T-ViT</a></li>
<li><a href="configs/twins">Twins</a></li>
<li><a href="configs/efficientnet">EfficientNet</a></li>
<li><a href="configs/edgenext">EdgeNeXt</a></li>
<li><a href="configs/convnext">ConvNeXt</a></li>
<li><a href="configs/hrnet">HRNet</a></li>
<li><a href="configs/van">VAN</a></li>
<li><a href="configs/convmixer">ConvMixer</a></li>
<li><a href="configs/cspnet">CSPNet</a></li>
<li><a href="configs/poolformer">PoolFormer</a></li>
<li><a href="configs/inception_v3">Inception V3</a></li>
<li><a href="configs/mobileone">MobileOne</a></li>
<li><a href="configs/efficientformer">EfficientFormer</a></li>
<li><a href="configs/mvit">MViT</a></li>
<li><a href="configs/hornet">HorNet</a></li>
<li><a href="configs/mobilevit">MobileViT</a></li>
<li><a href="configs/davit">DaViT</a></li>
<li><a href="configs/replknet">RepLKNet</a></li>
<li><a href="configs/beit">BEiT</a></li>
<li><a href="configs/mixmim">MixMIM</a></li>
<li><a href="configs/efficientnet_v2">EfficientNet V2</a></li>
<li><a href="configs/revvit">RevViT</a></li>
<li><a href="configs/convnext_v2">ConvNeXt V2</a></li>
<li><a href="configs/vig">ViG</a></li>
<li><a href="configs/xcit">XCiT</a></li>
<li><a href="configs/levit">LeViT</a></li>
<li><a href="configs/riformer">RIFormer</a></li>
<li><a href="configs/glip">GLIP</a></li>
<li><a href="configs/sam">ViT SAM</a></li>
<li><a href="configs/eva02">EVA02</a></li>
<li><a href="configs/dinov2">DINO V2</a></li>
<li><a href="configs/hivit">HiViT</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/mocov2">MoCo V1 (CVPR'2020)</a></li>
<li><a href="configs/simclr">SimCLR (ICML'2020)</a></li>
<li><a href="configs/mocov2">MoCo V2 (arXiv'2020)</a></li>
<li><a href="configs/byol">BYOL (NeurIPS'2020)</a></li>
<li><a href="configs/swav">SwAV (NeurIPS'2020)</a></li>
<li><a href="configs/densecl">DenseCL (CVPR'2021)</a></li>
<li><a href="configs/simsiam">SimSiam (CVPR'2021)</a></li>
<li><a href="configs/barlowtwins">Barlow Twins (ICML'2021)</a></li>
<li><a href="configs/mocov3">MoCo V3 (ICCV'2021)</a></li>
<li><a href="configs/beit">BEiT (ICLR'2022)</a></li>
<li><a href="configs/mae">MAE (CVPR'2022)</a></li>
<li><a href="configs/simmim">SimMIM (CVPR'2022)</a></li>
<li><a href="configs/maskfeat">MaskFeat (CVPR'2022)</a></li>
<li><a href="configs/cae">CAE (arXiv'2022)</a></li>
<li><a href="configs/milan">MILAN (arXiv'2022)</a></li>
<li><a href="configs/beitv2">BEiT V2 (arXiv'2022)</a></li>
<li><a href="configs/eva">EVA (CVPR'2023)</a></li>
<li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
<li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
<li><a href="configs/spark">SparK (ICLR'2023)</a></li>
<li><a href="configs/mff">MFF (ICCV'2023)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/blip">BLIP (arxiv'2022)</a></li>
<li><a href="configs/blip2">BLIP-2 (arxiv'2023)</a></li>
<li><a href="configs/ofa">OFA (CoRR'2022)</a></li>
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/llava">LLaVA (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>
<td>
Image Retrieval Task:
<ul>
<li><a href="configs/arcface">ArcFace (CVPR'2019)</a></li>
</ul>
Training & Test Tips:
<ul>
<li><a href="https://arxiv.org/abs/1909.13719">RandAug</a></li>
<li><a href="https://arxiv.org/abs/1805.09501">AutoAug</a></li>
<li><a href="mmpretrain/datasets/samplers/repeat_aug.py">RepeatAugSampler</a></li>
<li><a href="mmpretrain/models/tta/score_tta.py">TTA</a></li>
<li>...</li>
</ul>
</td>
</tbody>
</table>
## Contributing
We appreciate all contributions to improve MMClassification.
Please refer to [CONTRIBUTING.md](https://mmclassification.readthedocs.io/en/1.x/notes/contribution_guide.html) for the contributing guideline.
We appreciate all contributions to improve MMPreTrain.
Please refer to [CONTRIBUTING](https://mmpretrain.readthedocs.io/en/latest/notes/contribution_guide.html) for the contributing guideline.
## Acknowledgement
MMClassification is an open source project that is contributed by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as users who give valuable feedback.
We wish that the toolbox and benchmark could serve the growing research community by providing a flexible toolkit to reimplement existing methods and develop their own new classifiers.
MMPreTrain is an open source project that is contributed by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as users who give valuable feedback.
We wish that the toolbox and benchmark could serve the growing research community by providing a flexible toolkit to reimplement existing methods and supporting their own academic research.
## Citation
If you find this project useful in your research, please consider citing:
```BibTeX
@misc{2020mmclassification,
title={OpenMMLab's Image Classification Toolbox and Benchmark},
author={MMClassification Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmclassification}},
year={2020}
@misc{2023mmpretrain,
title={OpenMMLab's Pre-training Toolbox and Benchmark},
author={MMPreTrain Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmpretrain}},
year={2023}
}
```
@ -177,10 +317,12 @@ This project is released under the [Apache 2.0 license](LICENSE).
- [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab foundational library for training deep learning models.
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
- [MMEval](https://github.com/open-mmlab/mmeval): A unified evaluation library for multiple machine learning libraries.
- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab pre-training toolbox and benchmark.
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark.
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
@ -191,6 +333,7 @@ This project is released under the [Apache 2.0 license](LICENSE).
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
- [MMagic](https://github.com/open-mmlab/mmagic): Open**MM**Lab **A**dvanced, **G**enerative and **I**ntelligent **C**reation toolbox.
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.
- [Playground](https://github.com/open-mmlab/playground): A central hub for gathering and showcasing amazing projects built upon OpenMMLab.


@ -1,6 +1,6 @@
<div align="center">
<img src="resources/mmcls-logo.png" width="600"/>
<img src="resources/mmpt-logo.png" width="600"/>
<div>&nbsp;</div>
<div align="center">
<b><font size="5">OpenMMLab 官网</font></b>
@ -19,59 +19,100 @@
</div>
<div>&nbsp;</div>
[![PyPI](https://img.shields.io/pypi/v/mmcls)](https://pypi.org/project/mmcls)
[![Docs](https://img.shields.io/badge/docs-latest-blue)](https://mmclassification.readthedocs.io/zh_CN/1.x/)
[![Build Status](https://github.com/open-mmlab/mmclassification/workflows/build/badge.svg)](https://github.com/open-mmlab/mmclassification/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmclassification/branch/1.x/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmclassification)
[![license](https://img.shields.io/github/license/open-mmlab/mmclassification.svg)](https://github.com/open-mmlab/mmclassification/blob/1.x/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmclassification.svg)](https://github.com/open-mmlab/mmclassification/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmclassification.svg)](https://github.com/open-mmlab/mmclassification/issues)
[![PyPI](https://img.shields.io/pypi/v/mmpretrain)](https://pypi.org/project/mmpretrain)
[![Docs](https://img.shields.io/badge/docs-latest-blue)](https://mmpretrain.readthedocs.io/zh_CN/latest/)
[![Build Status](https://github.com/open-mmlab/mmpretrain/workflows/build/badge.svg)](https://github.com/open-mmlab/mmpretrain/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmpretrain/branch/main/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmpretrain)
[![license](https://img.shields.io/github/license/open-mmlab/mmpretrain.svg)](https://github.com/open-mmlab/mmpretrain/blob/main/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmpretrain.svg)](https://github.com/open-mmlab/mmpretrain/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmpretrain.svg)](https://github.com/open-mmlab/mmpretrain/issues)
[📘 Documentation (Chinese)](https://mmclassification.readthedocs.io/zh_CN/1.x/) |
[🛠️ Installation Guide](https://mmclassification.readthedocs.io/zh_CN/1.x/get_started.html) |
[👀 Model Zoo](https://mmclassification.readthedocs.io/zh_CN/1.x/modelzoo_statistics.html) |
[🆕 Changelog](https://mmclassification.readthedocs.io/en/1.x/notes/changelog.html) |
[🤔 Report Issues](https://github.com/open-mmlab/mmclassification/issues/new/choose)
[📘 Documentation (Chinese)](https://mmpretrain.readthedocs.io/zh_CN/latest/) |
[🛠️ Installation Guide](https://mmpretrain.readthedocs.io/zh_CN/latest/get_started.html) |
[👀 Model Zoo](https://mmpretrain.readthedocs.io/zh_CN/latest/modelzoo_statistics.html) |
[🆕 Changelog](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html) |
[🤔 Report Issues](https://github.com/open-mmlab/mmpretrain/issues/new/choose)
<img src="https://user-images.githubusercontent.com/36138628/230307505-4727ad0a-7d71-4069-939d-b499c7e272b7.png" width="400"/>
[English](/README.md) | Simplified Chinese
</div>
<div align="center">
<a href="https://openmmlab.medium.com/" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://discord.gg/raweFPmdzG" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
<a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
</div>
## Introduction
[English](/README.md) | Simplified Chinese
MMPreTrain is an open-source deep-learning pre-training toolbox based on PyTorch, and is a member of the [OpenMMLab](https://openmmlab.com/) project.
MMClassification is an open-source image classification toolbox based on PyTorch, and is a member of the [OpenMMLab](https://openmmlab.com/) project.
The main branch currently supports PyTorch 1.5 and above.
<div align="center">
<img src="https://user-images.githubusercontent.com/9102141/87268895-3e0d0780-c4fe-11ea-849e-6140b7e0d4de.gif" width="70%"/>
</div>
The `main` branch currently supports PyTorch 1.8 and above.
### Major features
- Various backbones and pretrained models
- Support for configuring various training tricks
- Various training strategies (supervised learning, self-supervised learning, multi-modality learning, etc.)
- A bag of training tricks
- Large-scale training configs
- High efficiency and extensibility
- Powerful toolkits
- Powerful toolkits for model analysis and experiments
- Various out-of-the-box inference tasks
  - Image classification
  - Image caption
  - Visual question answering
  - Visual grounding
  - Retrieval (image-to-image, text-to-image, image-to-text)
https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351-fbc74a04e904
## What's new
v1.0.0rc1 was released on 2022/9/30.
🌟 v1.2.0 was released on 2024/01/04
- Support backbones such as MViT, EdgeNeXt, Swin-Transformer V2, EfficientFormer, and MobileOne.
- Support BEiT-style transformer layers.
- Support LLaVA 1.5.
- Implement a Gradio inference demo for the RAM model.
v1.0.0rc0 was released on 2022/8/31.
🌟 v1.1.0 was released on 2023/10/12
This version introduces a brand-new, highly extensible training and test engine, which is still under development. You are welcome to try it following the [documentation](https://mmclassification.readthedocs.io/zh_CN/1.x/).
- Support Mini-GPT4 training and provide a Chinese model based on Baichuan-7B.
- Support zero-shot classification based on CLIP.
Meanwhile, the new version contains some changes that are incompatible with the old versions. Please check the [migration guide](https://mmclassification.readthedocs.io/zh_CN/1.x/migration.html) for details.
🌟 v1.0.0 was released on 2023/7/4
The public beta of the new version will last until the end of 2022. During this period, we will develop on the `1.x` branch without merging into the `master` branch. In addition, we will
maintain the 0.x versions until at least the end of 2023.
- Support inference of more **multi-modal** algorithms, such as [**LLaVA**](./configs/llava/), [**MiniGPT-4**](./configs/minigpt4), [**Otter**](./configs/otter/), etc.
- Support around **10 multi-modal** datasets!
- Add the [**iTPN**](./configs/itpn/) and [**SparK**](./configs/spark/) self-supervised learning algorithms.
- Provide examples of the [new config files](./mmpretrain/configs/) and [DeepSpeed/FSDP](./configs/mae/benchmarks/). Here are the documentation links for the [new config files](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta) and [DeepSpeed/FSDP with FlexibleRunner](https://mmengine.readthedocs.io/en/latest/api/generated/mmengine.runner.FlexibleRunner.html#mmengine.runner.FlexibleRunner).
For the release history and update details, please refer to the [changelog](https://mmclassification.readthedocs.io/zh_CN/1.x/notes/changelog.html).
🌟 Upgrade from MMClassification to MMPreTrain
- Integrate self-supervised learning algorithms from MMSelfSup, such as `MAE` and `BEiT`.
- Support **RIFormer**, a simple but effective vision backbone that removes the token mixer.
- Refactor dataset pipeline visualization.
- Support the **LeViT**, **XCiT**, **ViG**, **ConvNeXt-V2**, **EVA**, **RevViT**, **EfficientnetV2**, **CLIP**, **TinyViT** and **MixMIM** backbones.
This version introduces a brand-new, highly extensible training and test engine, which is still under development. You are welcome to try it following the [documentation](https://mmpretrain.readthedocs.io/zh_CN/latest/).
Meanwhile, the new version contains some changes that are incompatible with the old versions. Please check the [migration guide](https://mmpretrain.readthedocs.io/zh_CN/latest/migration.html) for details.
For the release history and update details, please refer to the [changelog](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html).
## Installation
@ -81,89 +122,184 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
conda create -n open-mmlab python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
conda activate open-mmlab
pip3 install openmim
git clone -b 1.x https://github.com/open-mmlab/mmclassification.git
cd mmclassification
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
mim install -e .
```
For more detailed installation steps, please refer to the [installation guide](https://mmclassification.readthedocs.io/zh_CN/1.x/get_started.html).
For more detailed installation steps, please refer to the [installation guide](https://mmpretrain.readthedocs.io/zh_CN/latest/get_started.html).
If you need multi-modal models, please install the extra dependencies as follows:
```shell
mim install -e ".[multimodal]"
```
## User Guides
We provide a series of basic tutorials for new users:
- [Inference with existing models](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/inference.html)
- [Prepare datasets](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/dataset_prepare.html)
- [Training and test](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/train_test.html)
- [Learn about configs](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/config.html)
- [Fine-tune models](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/finetune.html)
- [Analysis tools](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/analysis.html)
- [Visualization tools](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/visualization.html)
- [Other useful tools](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/useful_tools.html)
- [Learn about configs](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/config.html)
- [Prepare datasets](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/dataset_prepare.html)
- [Inference with existing models](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/inference.html)
- [Train](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/train.html)
- [Test](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/test.html)
- [Downstream tasks](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/downstream.html)
For more information, please check our [documentation](https://mmpretrain.readthedocs.io/zh_CN/latest/).
## Model Zoo
Results and models are available in the [model zoo](https://mmclassification.readthedocs.io/zh_CN/1.x/modelzoo_statistics.html).
Results and models are available in the [model zoo](https://mmpretrain.readthedocs.io/zh_CN/latest/modelzoo_statistics.html).
<details open>
<summary>Supported backbones</summary>
- [x] [VGG](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/vgg)
- [x] [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet)
- [x] [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext)
- [x] [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet)
- [x] [SE-ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet)
- [x] [RegNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/regnet)
- [x] [ShuffleNetV1](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/shufflenet_v1)
- [x] [ShuffleNetV2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/shufflenet_v2)
- [x] [MobileNetV2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilenet_v2)
- [x] [MobileNetV3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilenet_v3)
- [x] [Swin-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/swin_transformer)
- [x] [Swin-Transformer V2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/swin_transformer_v2)
- [x] [RepVGG](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/repvgg)
- [x] [Vision-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/vision_transformer)
- [x] [Transformer-in-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/tnt)
- [x] [Res2Net](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/res2net)
- [x] [MLP-Mixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mlp_mixer)
- [x] [DeiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/deit)
- [x] [Conformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/conformer)
- [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/t2t_vit)
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/twins)
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientnet)
- [x] [EdgeNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/edgenext)
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/convnext)
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/hrnet)
- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/van)
- [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/convmixer)
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/poolformer)
- [x] [Inception V3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/inception_v3)
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
</details>
<div align="center">
<b>Overview</b>
</div>
<table align="center">
<tbody>
<tr align="center" valign="bottom">
<td>
<b>Supported Backbones</b>
</td>
<td>
<b>Self-supervised Learning</b>
</td>
<td>
<b>Multi-Modality Algorithms</b>
</td>
<td>
<b>Others</b>
</td>
</tr>
<tr valign="top">
<td>
<ul>
<li><a href="configs/vgg">VGG</a></li>
<li><a href="configs/resnet">ResNet</a></li>
<li><a href="configs/resnext">ResNeXt</a></li>
<li><a href="configs/seresnet">SE-ResNet</a></li>
<li><a href="configs/seresnet">SE-ResNeXt</a></li>
<li><a href="configs/regnet">RegNet</a></li>
<li><a href="configs/shufflenet_v1">ShuffleNet V1</a></li>
<li><a href="configs/shufflenet_v2">ShuffleNet V2</a></li>
<li><a href="configs/mobilenet_v2">MobileNet V2</a></li>
<li><a href="configs/mobilenet_v3">MobileNet V3</a></li>
<li><a href="configs/swin_transformer">Swin-Transformer</a></li>
<li><a href="configs/swin_transformer_v2">Swin-Transformer V2</a></li>
<li><a href="configs/repvgg">RepVGG</a></li>
<li><a href="configs/vision_transformer">Vision-Transformer</a></li>
<li><a href="configs/tnt">Transformer-in-Transformer</a></li>
<li><a href="configs/res2net">Res2Net</a></li>
<li><a href="configs/mlp_mixer">MLP-Mixer</a></li>
<li><a href="configs/deit">DeiT</a></li>
<li><a href="configs/deit3">DeiT-3</a></li>
<li><a href="configs/conformer">Conformer</a></li>
<li><a href="configs/t2t_vit">T2T-ViT</a></li>
<li><a href="configs/twins">Twins</a></li>
<li><a href="configs/efficientnet">EfficientNet</a></li>
<li><a href="configs/edgenext">EdgeNeXt</a></li>
<li><a href="configs/convnext">ConvNeXt</a></li>
<li><a href="configs/hrnet">HRNet</a></li>
<li><a href="configs/van">VAN</a></li>
<li><a href="configs/convmixer">ConvMixer</a></li>
<li><a href="configs/cspnet">CSPNet</a></li>
<li><a href="configs/poolformer">PoolFormer</a></li>
<li><a href="configs/inception_v3">Inception V3</a></li>
<li><a href="configs/mobileone">MobileOne</a></li>
<li><a href="configs/efficientformer">EfficientFormer</a></li>
<li><a href="configs/mvit">MViT</a></li>
<li><a href="configs/hornet">HorNet</a></li>
<li><a href="configs/mobilevit">MobileViT</a></li>
<li><a href="configs/davit">DaViT</a></li>
<li><a href="configs/replknet">RepLKNet</a></li>
<li><a href="configs/beit">BEiT</a></li>
<li><a href="configs/mixmim">MixMIM</a></li>
<li><a href="configs/revvit">RevViT</a></li>
<li><a href="configs/convnext_v2">ConvNeXt V2</a></li>
<li><a href="configs/vig">ViG</a></li>
<li><a href="configs/xcit">XCiT</a></li>
<li><a href="configs/levit">LeViT</a></li>
<li><a href="configs/riformer">RIFormer</a></li>
<li><a href="configs/glip">GLIP</a></li>
<li><a href="configs/sam">ViT SAM</a></li>
<li><a href="configs/eva02">EVA02</a></li>
<li><a href="configs/dinov2">DINO V2</a></li>
<li><a href="configs/hivit">HiViT</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/mocov2">MoCo V1 (CVPR'2020)</a></li>
<li><a href="configs/simclr">SimCLR (ICML'2020)</a></li>
<li><a href="configs/mocov2">MoCo V2 (arXiv'2020)</a></li>
<li><a href="configs/byol">BYOL (NeurIPS'2020)</a></li>
<li><a href="configs/swav">SwAV (NeurIPS'2020)</a></li>
<li><a href="configs/densecl">DenseCL (CVPR'2021)</a></li>
<li><a href="configs/simsiam">SimSiam (CVPR'2021)</a></li>
<li><a href="configs/barlowtwins">Barlow Twins (ICML'2021)</a></li>
<li><a href="configs/mocov3">MoCo V3 (ICCV'2021)</a></li>
<li><a href="configs/beit">BEiT (ICLR'2022)</a></li>
<li><a href="configs/mae">MAE (CVPR'2022)</a></li>
<li><a href="configs/simmim">SimMIM (CVPR'2022)</a></li>
<li><a href="configs/maskfeat">MaskFeat (CVPR'2022)</a></li>
<li><a href="configs/cae">CAE (arXiv'2022)</a></li>
<li><a href="configs/milan">MILAN (arXiv'2022)</a></li>
<li><a href="configs/beitv2">BEiT V2 (arXiv'2022)</a></li>
<li><a href="configs/eva">EVA (CVPR'2023)</a></li>
<li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
<li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
<li><a href="configs/spark">SparK (ICLR'2023)</a></li>
<li><a href="configs/mff">MFF (ICCV'2023)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/blip">BLIP (arxiv'2022)</a></li>
<li><a href="configs/blip2">BLIP-2 (arxiv'2023)</a></li>
<li><a href="configs/ofa">OFA (CoRR'2022)</a></li>
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/llava">LLaVA (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>
<td>
Image Retrieval Task:
<ul>
<li><a href="configs/arcface">ArcFace (CVPR'2019)</a></li>
</ul>
Training &amp; Test Tips:
<ul>
<li><a href="https://arxiv.org/abs/1909.13719">RandAug</a></li>
<li><a href="https://arxiv.org/abs/1805.09501">AutoAug</a></li>
<li><a href="mmpretrain/datasets/samplers/repeat_aug.py">RepeatAugSampler</a></li>
<li><a href="mmpretrain/models/tta/score_tta.py">TTA</a></li>
<li>...</li>
</ul>
</td>
</tr>
</tbody>
</table>
## Contributing
We appreciate all contributions that help improve MMPreTrain. Please refer to the [contribution guide](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/contribution_guide.html) to learn how to participate.
## Acknowledgement
MMPreTrain is an open-source project jointly contributed by researchers and engineers from various universities and companies. We thank all the contributors who implemented algorithms and added new features, as well as the users who have provided valuable feedback.
We hope this toolbox and benchmark can serve the community by offering flexible tools to reproduce existing algorithms and develop new models, and thereby keep contributing to the open-source ecosystem.
## Citation
If you use this project's code or benchmarks in your research, please cite MMPreTrain with the following BibTeX entry.
```BibTeX
@misc{2023mmpretrain,
    title={OpenMMLab's Pre-training Toolbox and Benchmark},
    author={MMPreTrain Contributors},
    howpublished = {\url{https://github.com/open-mmlab/mmpretrain}},
    year={2023}
}
```
## Other Projects in OpenMMLab
- [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab foundational library for training deep learning models
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision
- [MIM](https://github.com/open-mmlab/mim): MIM is the unified entry point for OpenMMLab projects, algorithms, and models
- [MMEval](https://github.com/open-mmlab/mmeval): A unified and open cross-framework evaluation library
- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab deep learning pre-training toolbox
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab object detection toolbox
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark
- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab toolkit for full-pipeline text detection, recognition, and understanding
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation video understanding toolbox
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab unified video perception platform
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark
- [MMagic](https://github.com/open-mmlab/mmagic): OpenMMLab's next-generation toolbox for AI-generated content (AIGC)
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generation model toolbox
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework
- [Playground](https://github.com/open-mmlab/playground): A hub for collecting and showcasing cutting-edge, interesting community projects built on OpenMMLab
## Welcome to the OpenMMLab Community
Scan the QR codes below to follow the OpenMMLab team's [official Zhihu account](https://www.zhihu.com/people/openmmlab), and scan the WeChat QR code to add our assistant (喵喵) as a friend and join the MMPretrain WeChat community. [Friend-request format: research interest + region + school/company + name]
<div align="center">
<img src="./resources/zhihu_qrcode.jpg" height="400"/> <img src="./resources/miaomiao_qrcode.jpg" height="400"/>
</div>
In the OpenMMLab community, we will …
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'CIFAR100'
data_preprocessor = dict(
num_classes=100,
# RGB format normalization parameters
mean=[129.304, 124.070, 112.434],
std=[68.170, 65.392, 70.418],
@ -10,11 +11,11 @@ data_preprocessor = dict(
train_pipeline = [
dict(type='RandomCrop', crop_size=32, padding=4),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -22,11 +23,10 @@ train_dataloader = dict(
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/cifar100',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -34,11 +34,10 @@ val_dataloader = dict(
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/cifar100/',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, ))
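Base dataset files such as the one above are usually inherited rather than run directly. A minimal sketch of a top-level config that reuses it (the `_base_` paths are illustrative and must match your local config tree):
```python
# Illustrative top-level config inheriting the CIFAR-100 dataset settings above.
_base_ = [
    '../_base_/models/resnet18_cifar.py',    # model definition (example path)
    '../_base_/datasets/cifar100_bs16.py',   # the dataset file shown above
    '../_base_/schedules/cifar10_bs128.py',  # optimizer & LR schedule (example)
    '../_base_/default_runtime.py',          # logging, hooks, etc.
]

# Inherited fields can still be overridden, e.g. to enlarge the batch size:
train_dataloader = dict(batch_size=128)
```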
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'CIFAR10'
data_preprocessor = dict(
num_classes=10,
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
@ -10,11 +11,11 @@ data_preprocessor = dict(
train_pipeline = [
dict(type='RandomCrop', crop_size=32, padding=4),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -22,11 +23,10 @@ train_dataloader = dict(
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/cifar10',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -34,11 +34,10 @@ val_dataloader = dict(
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/cifar10/',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, ))
View File
@ -0,0 +1,70 @@
# data settings
# COCO caption annotations can be downloaded from the LAVIS repo
# https://github.com/salesforce/LAVIS/blob/main/lavis/configs/datasets/coco/defaults_cap.yaml
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='CleanCaption', keys='gt_caption'),
dict(
type='PackInputs',
algorithm_keys=['gt_caption'],
meta_keys=['image_id'],
),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(384, 384),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]
train_dataloader = dict(
batch_size=32,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_train.json',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)
val_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
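Models trained with a caption config like the one above can be used through the high-level inferencer API. A minimal sketch (the model name is an example from the zoo and is downloaded on demand):
```python
# Run image captioning with a pretrained model from the zoo.
from mmpretrain import ImageCaptionInferencer

# 'blip-base_3rdparty_caption' is an illustrative zoo model name.
inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption')
result = inferencer('demo/cat-dog.png')[0]
print(result['pred_caption'])
```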
View File
@ -0,0 +1,75 @@
# data settings
data_preprocessor = dict(
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(480, 480),
interpolation='bicubic',
backend='pillow'),
dict(
type='CleanCaption',
keys=['question'],
),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]
train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='COCOVQA',
data_root='data/coco',
data_prefix='train2014',
question_file=
'annotations/okvqa_OpenEnded_mscoco_train2014_questions.json',
ann_file='annotations/okvqa_mscoco_train2014_annotations.json',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)
val_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='COCOVQA',
data_root='data/coco',
data_prefix='val2014',
question_file=
'annotations/okvqa_OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/okvqa_mscoco_val2014_annotations.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,99 @@
# data settings
# Here are the links to download the annotations for COCO retrieval, for convenience  # noqa
# https://download.openmmlab.com/mmclassification/datasets/coco_retrieval/caption_karpathy_train2014.json
# https://download.openmmlab.com/mmclassification/datasets/coco_retrieval/caption_karpathy_val2014.json
# https://download.openmmlab.com/mmclassification/datasets/coco_retrieval/caption_karpathy_test2014.json
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
rand_increasing_policies = [
dict(type='AutoContrast'),
dict(type='Equalize'),
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
dict(
type='Brightness', magnitude_key='magnitude',
magnitude_range=(0, 0.0)),
dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)),
dict(
type='Shear',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
direction='horizontal'),
dict(
type='Shear',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
direction='vertical'),
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
crop_ratio_range=(0.5, 1.0),
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies=rand_increasing_policies,
num_policies=2,
magnitude_level=5),
dict(type='CleanCaption', keys='text'),
dict(
type='PackInputs',
algorithm_keys=['text', 'is_matched'],
meta_keys=['image_id']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(384, 384),
interpolation='bicubic',
backend='pillow'),
dict(type='CleanCaption', keys='text'),
dict(
type='PackInputs',
algorithm_keys=['text', 'gt_text_id', 'gt_image_id'],
meta_keys=['image_id']),
]
train_dataloader = dict(
batch_size=32,
num_workers=16,
dataset=dict(
type='COCORetrieval',
data_root='data/coco',
ann_file='annotations/caption_karpathy_train2014.json',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)
val_dataloader = dict(
batch_size=64,
num_workers=16,
dataset=dict(
type='COCORetrieval',
data_root='data/coco',
ann_file='annotations/caption_karpathy_val2014.json',
pipeline=test_pipeline,
# This is required for evaluation
test_mode=True,
),
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
persistent_workers=True,
)
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,96 @@
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=(480, 480),
crop_ratio_range=(0.5, 1.0),
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='simple_increasing', # slightly different from LAVIS
num_policies=2,
magnitude_level=5),
dict(type='CleanCaption', keys=['question', 'gt_answer']),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(480, 480),
interpolation='bicubic',
backend='pillow'),
dict(type='CleanCaption', keys=['question']),
dict(
type='PackInputs',
algorithm_keys=['question'],
meta_keys=['question_id']),
]
train_dataloader = dict(
batch_size=32,
num_workers=8,
dataset=dict(
type='ConcatDataset',
datasets=[
# VQAv2 train
dict(
type='COCOVQA',
data_root='data/coco',
data_prefix='train2014',
question_file=
'annotations/v2_OpenEnded_mscoco_train2014_questions.json',
ann_file='annotations/v2_mscoco_train2014_annotations.json',
pipeline=train_pipeline,
),
# VQAv2 val
dict(
type='COCOVQA',
data_root='data/coco',
data_prefix='val2014',
question_file=
'annotations/v2_OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/v2_mscoco_val2014_annotations.json',
pipeline=train_pipeline,
),
# Visual Genome
dict(
type='VisualGenomeQA',
data_root='visual_genome',
data_prefix='image',
ann_file='question_answers.json',
pipeline=train_pipeline,
)
]),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)
test_dataloader = dict(
batch_size=32,
num_workers=8,
dataset=dict(
type='COCOVQA',
data_root='data/coco',
data_prefix='test2015',
question_file=
'annotations/v2_OpenEnded_mscoco_test2015_questions.json', # noqa: E501
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='ReportVQA', file_path='vqa_test.json')
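Similarly, a VQA model trained with settings like these can be exercised with the VQA inferencer. A hedged sketch (the model name and question are illustrative):
```python
# Ask a free-form question about an image with a pretrained VQA model.
from mmpretrain import VisualQuestionAnsweringInferencer

# 'ofa-base_3rdparty-zeroshot_vqa' is an example model name from the zoo.
inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa')
result = inferencer('demo/cat-dog.png', 'What animal is next to the dog?')[0]
print(result['pred_answer'])
```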
View File
@ -0,0 +1,84 @@
# data settings
data_preprocessor = dict(
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(480, 480),
interpolation='bicubic',
backend='pillow'),
dict(
type='CleanCaption',
keys=['question'],
),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]
train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='COCOVQA',
data_root='data/coco',
data_prefix='train2014',
question_file=
'annotations/v2_OpenEnded_mscoco_train2014_questions.json',
ann_file='annotations/v2_mscoco_train2014_annotations.json',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)
val_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='COCOVQA',
data_root='data/coco',
data_prefix='val2014',
question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/v2_mscoco_val2014_annotations.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')
test_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='COCOVQA',
data_root='data/coco',
data_prefix='test2015',
question_file= # noqa: E251
'annotations/v2_OpenEnded_mscoco_test2015_questions.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='ReportVQA', file_path='vqa_test.json')
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'CUB'
data_preprocessor = dict(
num_classes=200,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -13,14 +14,14 @@ train_pipeline = [
dict(type='Resize', scale=510),
dict(type='RandomCrop', crop_size=384),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=510),
dict(type='CenterCrop', crop_size=384),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -29,10 +30,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/CUB_200_2011',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -41,10 +41,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/CUB_200_2011',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, ))
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'CUB'
data_preprocessor = dict(
num_classes=200,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
@ -12,14 +13,14 @@ train_pipeline = [
dict(type='Resize', scale=600),
dict(type='RandomCrop', crop_size=448),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=600),
dict(type='CenterCrop', crop_size=448),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -28,10 +29,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/CUB_200_2011',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -40,10 +40,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/CUB_200_2011',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, ))
View File
@ -0,0 +1,92 @@
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='CleanCaption', keys='gt_caption'),
dict(
type='PackInputs',
algorithm_keys=['gt_caption'],
meta_keys=['image_id'],
),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(384, 384),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]
train_dataloader = dict(
batch_size=32,
num_workers=5,
dataset=dict(
type='Flickr30kCaption',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)
val_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type='Flickr30kCaption',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='val',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
# refer to tools/dataset_converters/convert_flickr30k_ann.py
val_evaluator = dict(
type='COCOCaption',
ann_file='data/flickr30k_val_gt.json',
)
# If you want a standard test, please manually configure the test dataset
test_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type='Flickr30kCaption',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='test',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
# refer to tools/dataset_converters/convert_flickr30k_ann.py
test_evaluator = dict(
type='COCOCaption',
ann_file='data/flickr30k_test_gt.json',
)
View File
@ -0,0 +1,112 @@
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
rand_increasing_policies = [
dict(type='AutoContrast'),
dict(type='Equalize'),
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
dict(
type='Brightness', magnitude_key='magnitude',
magnitude_range=(0, 0.0)),
dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)),
dict(
type='Shear',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
direction='horizontal'),
dict(
type='Shear',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
direction='vertical'),
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
crop_ratio_range=(0.5, 1.0),
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies=rand_increasing_policies,
num_policies=2,
magnitude_level=5),
dict(type='CleanCaption', keys='text'),
dict(
type='PackInputs',
algorithm_keys=['text', 'is_matched'],
meta_keys=['image_id']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(384, 384),
interpolation='bicubic',
backend='pillow'),
dict(type='CleanCaption', keys='text'),
dict(
type='PackInputs',
algorithm_keys=['text', 'gt_text_id', 'gt_image_id'],
meta_keys=['image_id']),
]
train_dataloader = dict(
batch_size=32,
num_workers=16,
dataset=dict(
type='Flickr30kRetrieval',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)
val_dataloader = dict(
batch_size=64,
num_workers=16,
dataset=dict(
type='Flickr30kRetrieval',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='val',
pipeline=test_pipeline,
test_mode=True, # This is required for evaluation
),
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
persistent_workers=True,
)
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))
# If you want a standard test, please manually configure the test dataset
test_dataloader = dict(
batch_size=64,
num_workers=16,
dataset=dict(
type='Flickr30kRetrieval',
data_root='data/flickr30k',
ann_file='annotations/dataset_flickr30k.json',
data_prefix='images',
split='test',
pipeline=test_pipeline,
test_mode=True, # This is required for evaluation
),
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
persistent_workers=True,
)
test_evaluator = val_evaluator
View File
@ -0,0 +1,81 @@
# data settings
data_preprocessor = dict(
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(480, 480),
interpolation='bicubic',
backend='pillow'),
dict(
type='CleanCaption',
keys=['question'],
),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]
train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/train_balanced_questions.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
drop_last=True,
)
val_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/testdev_balanced_questions.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='GQAAcc')
test_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/testdev_balanced_questions.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = val_evaluator
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet21k'
data_preprocessor = dict(
num_classes=21842,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -12,14 +13,7 @@ train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -28,27 +22,7 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet21k',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -27,14 +28,14 @@ train_pipeline = [
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -43,11 +44,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -56,11 +55,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -36,7 +37,7 @@ train_pipeline = [
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
@ -48,7 +49,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs')
]
train_dataloader = dict(
@ -57,11 +58,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -70,11 +69,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -36,7 +37,7 @@ train_pipeline = [
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
@ -48,7 +49,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs')
]
train_dataloader = dict(
@ -57,11 +58,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -70,11 +69,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
View File
@ -0,0 +1,83 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=7,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand', # should be 'pixel', but currently not supported
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=256,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,80 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=404,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=384),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,80 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=426,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=384),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=32,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,80 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=248,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,60 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=196,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=196,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=196),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,60 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=336,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=336,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=336),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,62 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=448,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=448,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=448),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,60 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=560,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=560,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=560),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,53 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=384, backend='pillow', interpolation='bicubic'),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,47 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='TwoNormDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
second_mean=[127.5, 127.5, 127.5],
second_std=[127.5, 127.5, 127.5],
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandomResizedCropAndInterpolationWithTwoPic',
size=224,
second_size=224,
interpolation='bicubic',
second_interpolation='bicubic',
scale=(0.2, 1.0)),
dict(
type='BEiTMaskGenerator',
input_size=(14, 14),
num_masking_patches=75,
max_num_patches=75,
min_num_patches=16),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=256,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline))
View File
@ -0,0 +1,80 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=236,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -0,0 +1,49 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='TwoNormDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# clip mean & std
second_mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
second_std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandomResizedCropAndInterpolationWithTwoPic',
size=224,
second_size=224,
interpolation='bicubic',
second_interpolation='bicubic',
scale=(0.2, 1.0)),
dict(
type='BEiTMaskGenerator',
input_size=(14, 14),
num_masking_patches=75,
max_num_patches=75,
min_num_patches=16),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=256,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
View File
@ -0,0 +1,80 @@
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=256,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=256,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
View File
@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -28,7 +29,7 @@ train_pipeline = [
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(type='PackInputs'),
]
test_pipeline = [
@ -40,7 +41,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs')
]
train_dataloader = dict(
@ -49,11 +50,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -62,11 +61,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -28,7 +29,7 @@ train_pipeline = [
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
@ -40,7 +41,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs')
dict(type='PackInputs')
]
train_dataloader = dict(
@ -49,11 +50,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -62,11 +61,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -0,0 +1,33 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='SelfSupDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=192, crop_ratio_range=(0.67, 1.0)),
dict(type='RandomFlip', prob=0.5),
dict(
type='SimMIMMaskGenerator',
input_size=192,
mask_patch_size=32,
model_patch_size=4,
mask_ratio=0.6),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=256,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline))
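The mask settings above decompose neatly: 192 / 32 gives a 6 × 6 grid of mask units, mask_ratio=0.6 masks int(36 × 0.6) = 21 of them, and each unit spans (32 / 4)² = 64 model patches, so the final mask lives on a 48 × 48 patch grid. A rough, illustrative sketch of that arithmetic (not mmpretrain's implementation):

import numpy as np

input_size, mask_patch_size, model_patch_size, mask_ratio = 192, 32, 4, 0.6
rand_size = input_size // mask_patch_size    # 6 mask units per side
scale = mask_patch_size // model_patch_size  # 8 model patches per unit side
num_mask = int(rand_size ** 2 * mask_ratio)  # 21 of 36 units masked
ids = np.random.permutation(rand_size ** 2)[:num_mask]
mask = np.zeros(rand_size ** 2, dtype=int)
mask[ids] = 1
mask = mask.reshape(rand_size, rand_size)
mask = mask.repeat(scale, axis=0).repeat(scale, axis=1)  # 48 x 48 patch mask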

View File

@ -0,0 +1,81 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=192,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=219,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=192),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=256,
num_workers=8,
collate_fn=dict(type='default_collate'),
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline),
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
collate_fn=dict(type='default_collate'),
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='val',
pipeline=test_pipeline),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please configure the test dataset manually
test_dataloader = val_dataloader
test_evaluator = val_evaluator
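A note on the odd-looking ResizeEdge scale of 219 above: it preserves the conventional 256-to-224 resize-to-crop ratio at the 192-pixel crop, since 192 × 256 / 224 ≈ 219.4, truncated to 219.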

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -12,14 +13,14 @@ train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -28,11 +29,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -41,11 +40,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -0,0 +1,89 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='SelfSupDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
view_pipeline1 = [
dict(
type='RandomResizedCrop',
scale=224,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomApply',
transforms=[
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.2,
hue=0.1)
],
prob=0.8),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(
type='GaussianBlur',
magnitude_range=(0.1, 2.0),
magnitude_std='inf',
prob=1.),
dict(type='Solarize', thr=128, prob=0.),
]
view_pipeline2 = [
dict(
type='RandomResizedCrop',
scale=224,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomApply',
transforms=[
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.2,
hue=0.1)
],
prob=0.8),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(
type='GaussianBlur',
magnitude_range=(0.1, 2.0),
magnitude_std='inf',
prob=0.1),
dict(type='Solarize', thr=128, prob=0.2)
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiView',
num_views=[1, 1],
transforms=[view_pipeline1, view_pipeline2]),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=32,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline))
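MultiView with num_views=[1, 1] draws one crop from each of the two asymmetric pipelines above (the first always blurs and never solarizes; the second rarely blurs and sometimes solarizes), so every sample yields a pair of differently augmented views. A minimal sketch of the assumed semantics, treating each pipeline as a list of callables:

def multi_view(img, num_views, pipelines):
    views = []
    for n, pipeline in zip(num_views, pipelines):
        for _ in range(n):
            out = img
            for transform in pipeline:
                out = transform(out)  # apply the sub-pipeline end to end
            views.append(out)
    return views  # here: two views, one per pipeline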

View File

@ -0,0 +1,58 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='SelfSupDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
# The difference between MoCo v2 and MoCo v1 is in the pipeline transforms
view_pipeline = [
dict(
type='RandomResizedCrop',
scale=224,
crop_ratio_range=(0.2, 1.),
backend='pillow'),
dict(
type='RandomApply',
transforms=[
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.1)
],
prob=0.8),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(
type='GaussianBlur',
magnitude_range=(0.1, 2.0),
magnitude_std='inf',
prob=0.5),
dict(type='RandomFlip', prob=0.5),
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='MultiView', num_views=2, transforms=[view_pipeline]),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=32,
num_workers=8,
drop_last=True,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline))
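Two details worth noting in this file: RandomApply runs its wrapped ColorJitter with probability 0.8 and otherwise passes the image through untouched, and drop_last=True keeps every batch the same size, which the MoCo-style memory-queue update presumably relies on. A sketch of the assumed RandomApply semantics:

import random

def random_apply(img, transforms, prob=0.8):
    # With probability `prob`, run the wrapped transforms in order;
    # otherwise return the input unchanged (assumed semantics).
    if random.random() < prob:
        for t in transforms:
            img = t(img)
    return img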

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -16,7 +17,7 @@ train_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
@ -28,7 +29,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -37,11 +38,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -50,11 +49,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -12,14 +13,14 @@ train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -28,11 +29,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -41,11 +40,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -0,0 +1,52 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='SelfSupDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
view_pipeline = [
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomApply',
transforms=[
dict(
type='ColorJitter',
brightness=0.8,
contrast=0.8,
saturation=0.8,
hue=0.2)
],
prob=0.8),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(
type='GaussianBlur',
magnitude_range=(0.1, 2.0),
magnitude_std='inf',
prob=0.5),
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='MultiView', num_views=2, transforms=[view_pipeline]),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=32,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline))

View File

@ -0,0 +1,32 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='SelfSupDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
crop_ratio_range=(0.2, 1.0),
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=512,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline))

View File

@ -0,0 +1,90 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='SelfSupDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
view_pipeline1 = [
dict(
type='RandomResizedCrop',
scale=224,
crop_ratio_range=(0.2, 1.),
backend='pillow'),
dict(
type='RandomApply',
transforms=[
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.2,
hue=0.1)
],
prob=0.8),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(
type='GaussianBlur',
magnitude_range=(0.1, 2.0),
magnitude_std='inf',
prob=1.),
dict(type='Solarize', thr=128, prob=0.),
dict(type='RandomFlip', prob=0.5),
]
view_pipeline2 = [
dict(
type='RandomResizedCrop',
scale=224,
crop_ratio_range=(0.2, 1.),
backend='pillow'),
dict(
type='RandomApply',
transforms=[
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.2,
hue=0.1)
],
prob=0.8),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(
type='GaussianBlur',
magnitude_range=(0.1, 2.0),
magnitude_std='inf',
prob=0.1),
dict(type='Solarize', thr=128, prob=0.2),
dict(type='RandomFlip', prob=0.5),
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiView',
num_views=[1, 1],
transforms=[view_pipeline1, view_pipeline2]),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=512,
num_workers=8,
persistent_workers=True,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
split='train',
pipeline=train_pipeline))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -12,14 +13,14 @@ train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -28,11 +29,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -41,11 +40,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -20,14 +21,14 @@ train_pipeline = [
policies='imagenet',
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -36,11 +37,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -49,11 +48,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -0,0 +1,73 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=True)
image_size = 224
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
size=image_size,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
# dict(
# type='RandAugment',
# policies={{_base_.rand_increasing_policies}},
# num_policies=2,
# total_level=10,
# magnitude_level=9,
# magnitude_std=0.5,
# hparams=dict(
# pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
# interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=img_norm_cfg['mean'][::-1],
fill_std=img_norm_cfg['std'][::-1]),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
size=(image_size, -1),
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=image_size),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
test=dict(
        # replace `data/val` with `data/test` for the standard test
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')
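This file (and the two that follow) use the legacy mmcls 0.x format: normalization lives in img_norm_cfg and in explicit Normalize / ImageToTensor / Collect steps, and loaders are declared via data=dict(samples_per_gpu=..., workers_per_gpu=...). In the 1.x format used elsewhere in this diff, normalization moves into data_preprocessor, the packing steps collapse into PackInputs, and the train entry maps roughly as sketched below (a sketch, not a mechanical conversion):

train_dataloader = dict(
    batch_size=64,   # was samples_per_gpu
    num_workers=8,   # was workers_per_gpu
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)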

View File

@ -0,0 +1,73 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=True)
image_size = 384
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
size=image_size,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
# dict(
# type='RandAugment',
# policies={{_base_.rand_increasing_policies}},
# num_policies=2,
# total_level=10,
# magnitude_level=9,
# magnitude_std=0.5,
# hparams=dict(
# pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
# interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=img_norm_cfg['mean'][::-1],
fill_std=img_norm_cfg['std'][::-1]),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
size=(image_size, -1),
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=image_size),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
test=dict(
        # replace `data/val` with `data/test` for the standard test
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')

View File

@ -0,0 +1,74 @@
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=True)
image_size = 448
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
size=image_size,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
# dict(
# type='RandAugment',
# policies={{_base_.rand_increasing_policies}},
# num_policies=2,
# total_level=10,
# magnitude_level=9,
# magnitude_std=0.5,
# hparams=dict(
# pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
# interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=img_norm_cfg['mean'][::-1],
fill_std=img_norm_cfg['std'][::-1]),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
size=(image_size, -1),
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=image_size),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
test=dict(
        # replace `data/val` with `data/test` for the standard test
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -36,7 +37,7 @@ train_pipeline = [
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackClsInputs')
dict(type='PackInputs')
]
test_pipeline = [
@ -48,7 +49,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs')
dict(type='PackInputs')
]
train_dataloader = dict(
@ -57,11 +58,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -70,11 +69,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -0,0 +1,80 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=224,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please configure the test dataset manually
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,60 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=384,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=384),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please configure the test dataset manually
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -36,7 +37,7 @@ train_pipeline = [
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
@ -48,7 +49,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=256),
dict(type='PackClsInputs')
dict(type='PackInputs')
]
train_dataloader = dict(
@ -57,11 +58,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -70,11 +69,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -0,0 +1,83 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please configure the test dataset manually
test_dataloader = val_dataloader
test_evaluator = val_evaluator
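Note that this file keeps the ann_file + data_prefix addressing rather than the split shorthand used by its neighbors: the annotation file lists one image per line as a path relative to data_prefix, followed by an integer class label. Illustrative (hypothetical) contents of meta/train.txt:

n01440764/n01440764_10026.JPEG 0
n01440764/n01440764_10027.JPEG 0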

View File

@ -3,6 +3,7 @@ dataset_type = 'ImageNet'
# Google Research usually uses the normalization setting below.
data_preprocessor = dict(
num_classes=1000,
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
# convert image from BGR to RGB
@ -13,14 +14,14 @@ train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short', interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -29,11 +30,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -42,11 +41,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -12,14 +13,14 @@ train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -28,11 +29,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -41,11 +40,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -24,7 +25,7 @@ train_pipeline = [
policies='imagenet',
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
@ -36,7 +37,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -45,11 +46,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -58,11 +57,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -36,7 +37,7 @@ train_pipeline = [
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
@ -48,7 +49,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -57,11 +58,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -70,11 +69,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -36,7 +37,7 @@ train_pipeline = [
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
@ -48,7 +49,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=256),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -57,11 +58,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -70,11 +69,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -16,13 +17,13 @@ train_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=384, backend='pillow', interpolation='bicubic'),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -31,11 +32,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -44,11 +43,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
@ -36,7 +37,7 @@ train_pipeline = [
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
test_pipeline = [
@ -48,7 +49,7 @@ test_pipeline = [
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
dict(type='PackInputs'),
]
train_dataloader = dict(
@ -57,11 +58,9 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
@ -70,11 +69,9 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

View File

@ -0,0 +1,59 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
# RGB format normalization parameters
mean=[122.5, 122.5, 122.5],
std=[122.5, 122.5, 122.5],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=320,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=int(320 / 224 * 256),
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=320),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want a standard test, please configure the test dataset manually
test_dataloader = val_dataloader
test_evaluator = val_evaluator
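The derived edge scale int(320 / 224 * 256) keeps the conventional 256-to-224 resize headroom at the 320-pixel crop: 320 × 256 / 224 ≈ 365.7, truncated to 365. This is the same ratio that gives the hard-coded 219 in the 192-pixel config earlier in this diff.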

View File

@ -0,0 +1,64 @@
# dataset settings
dataset_type = 'InShop'
data_preprocessor = dict(
num_classes=3997,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=512),
dict(type='RandomCrop', crop_size=448),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=512),
dict(type='CenterCrop', crop_size=448),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=32,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/inshop',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
query_dataloader = dict(
batch_size=32,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/inshop',
split='query',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
gallery_dataloader = dict(
batch_size=32,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/inshop',
split='gallery',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_dataloader = query_dataloader
val_evaluator = [
dict(type='RetrievalRecall', topk=1),
dict(type='RetrievalAveragePrecision', topk=10),
]
test_dataloader = val_dataloader
test_evaluator = val_evaluator
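With val_dataloader aliased to the query split, evaluation ranks every query against the gallery; RetrievalRecall with topk=1 asks whether the nearest gallery item shares the query's label. A toy, illustrative version of that check (assumes L2-normalized feature matrices; not mmpretrain's implementation):

import numpy as np

def recall_at_1(query_feats, gallery_feats, query_labels, gallery_labels):
    sims = query_feats @ gallery_feats.T  # cosine similarity for unit vectors
    top1 = sims.argmax(axis=1)            # nearest gallery index per query
    return float((gallery_labels[top1] == query_labels).mean())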

View File

@ -0,0 +1,86 @@
# dataset settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
train_pipeline = [
dict(
type='ApplyToList',
        # NLVR requires loading two images per sample.
scatter_key='img_path',
transforms=[
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
],
collate_keys=['img', 'scale_factor', 'ori_shape'],
),
dict(type='CleanCaption', keys='text'),
dict(
type='PackInputs',
algorithm_keys=['text'],
meta_keys=['image_id'],
),
]
test_pipeline = [
dict(
type='ApplyToList',
        # NLVR requires loading two images per sample.
scatter_key='img_path',
transforms=[
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(384, 384),
interpolation='bicubic',
backend='pillow'),
],
collate_keys=['img', 'scale_factor', 'ori_shape'],
),
dict(
type='PackInputs',
algorithm_keys=['text'],
meta_keys=['image_id'],
),
]
train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='NLVR2',
data_root='data/nlvr2',
ann_file='dev.json',
data_prefix='dev',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)
val_dataloader = dict(
batch_size=64,
num_workers=8,
dataset=dict(
type='NLVR2',
data_root='data/nlvr2',
ann_file='dev.json',
data_prefix='dev',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy')
# If you want a standard test, please configure the test dataset manually
test_dataloader = val_dataloader
test_evaluator = val_evaluator
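ApplyToList is what handles NLVR2's paired inputs: img_path holds two paths, the wrapped loading and augmentation transforms run once per path, and the keys named in collate_keys are gathered across the two results so the model receives both images. A minimal sketch of the assumed semantics (not mmpretrain's actual implementation):

def apply_to_list(results, scatter_key='img_path', transforms=()):
    # Run the wrapped transforms once per scattered value (here: per image path).
    outs = []
    for value in results[scatter_key]:
        item = {**results, scatter_key: value}
        for t in transforms:
            item = t(item)
        outs.append(item)
    return outs  # entries under `collate_keys` are then stacked pairwise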

Some files were not shown because too many files have changed in this diff.