Compare commits
303 Commits
Author | SHA1 | Date |
---|---|---|
|
ee7f2e8850 | |
|
17a886cb58 | |
|
9ac4b316f0 | |
|
3022f9af7b | |
|
e95d9acb89 | |
|
6e00cbecaa | |
|
ed5924b6fe | |
|
a4c219e05d | |
|
d35c778a6f | |
|
c0766519b1 | |
|
4849324629 | |
|
b0a792eb08 | |
|
3bcf7e2d6e | |
|
06bb586eb7 | |
|
5c71de6b8e | |
|
7734f073e4 | |
|
b0b4422736 | |
|
9b75ce0aa4 | |
|
f4d372ba7d | |
|
ed3b7f8ae6 | |
|
ddc6d0b121 | |
|
da1da48eb6 | |
|
bb59c9ad82 | |
|
845b462190 | |
|
634852ad61 | |
|
e1675e893e | |
|
d2ccc44a2c | |
|
853f0c6bca | |
|
732b0f4c98 | |
|
b65a96a89c | |
|
6bb0c8a987 | |
|
bf62497e02 | |
|
6474d6befa | |
|
1be28ea7c4 | |
|
bff80d3c48 | |
|
29d706248c | |
|
fa53174fd9 | |
|
827a216155 | |
|
1dda91bf24 | |
|
2fb52eefdc | |
|
340d187765 | |
|
4f2f3752d9 | |
|
5c71eba13d | |
|
58a2243d99 | |
|
1f99279657 | |
|
0b96dcaa67 | |
|
b1cd05caf2 | |
|
e7fc25cf64 | |
|
c5248b17b7 | |
|
4d1dbafaa2 | |
|
2b8d8eecb2 | |
|
64c446d507 | |
|
60d780f99e | |
|
569324b180 | |
|
db395d35b1 | |
|
465b6bdeec | |
|
5c43d3ef42 | |
|
78d0ddc852 | |
|
ae7a7b7560 | |
|
0d80ab4650 | |
|
8eaf8090e6 | |
|
130751185c | |
|
7cbfb36c14 | |
|
feb0814b2f | |
|
00030e3f7d | |
|
59c077746f | |
|
8afad77a35 | |
|
658db80089 | |
|
68758db7a8 | |
|
10685fc81c | |
|
70ff2abbf7 | |
|
d4a6dfa00a | |
|
7d850dfadd | |
|
dbef2b41c6 | |
|
d6056af2b8 | |
|
6d7fe91a98 | |
|
a1cfe888e2 | |
|
bfd49b0d52 | |
|
e69bace03f | |
|
9d3fc43073 | |
|
a673b048a5 | |
|
aac398a83f | |
|
93e0f107c4 | |
|
7581b76233 | |
|
53648baca5 | |
|
3eaf719a64 | |
|
8e9e880601 | |
|
bb415b91be | |
|
dbfb84ccbd | |
|
057d7c6d6a | |
|
bddbc085fc | |
|
3a277ee9e6 | |
|
bc3c4a35ee | |
|
795607cfeb | |
|
5bd088ef43 | |
|
e4c4a81b56 | |
|
1f07c92ed1 | |
|
9bb692e440 | |
|
a779c8c5a7 | |
|
46a523ef63 | |
|
4dd8a86145 | |
|
be389eb846 | |
|
023d6869bd | |
|
b058912c0c | |
|
1e478462b8 | |
|
d04ef8a29e | |
|
74f24658e7 | |
|
13e4d6c512 | |
|
b0ad99afb9 | |
|
1537d46596 | |
|
87f849cbb6 | |
|
1b8e86dca6 | |
|
6847d20d57 | |
|
770eb8e24a | |
|
034919d032 | |
|
7f4eccbecf | |
|
9cf37b315c | |
|
afa60c73bb | |
|
d9e561a09d | |
|
496e098b21 | |
|
a3fa328f09 | |
|
b51d7d21de | |
|
15cc2a5193 | |
|
6ceba070a8 | |
|
3cd4fd4d64 | |
|
e954cf0aaf | |
|
fec3da781f | |
|
2c913020b9 | |
|
e93d124ad4 | |
|
02571fe4b8 | |
|
645e2b4ed4 | |
|
99e48116aa | |
|
0826df8963 | |
|
e80418a424 | |
|
9cbeceabb5 | |
|
47e033c466 | |
|
5ea46fbbbc | |
|
1e78f09d87 | |
|
79ddc0f874 | |
|
411e05a705 | |
|
05124dbb71 | |
|
b8cab5c9f7 | |
|
3932ddec10 | |
|
5c3abb2b2a | |
|
e115ac89f4 | |
|
53a57c6dad | |
|
e4d8511ddf | |
|
c9c7d9cc0f | |
|
a6c24d104e | |
|
e7da3f29f4 | |
|
61b795f21f | |
|
0ef0b5ce08 | |
|
32c258ff19 | |
|
0b70c108b0 | |
|
1ee9bbe050 | |
|
3069e43f77 | |
|
75dceaa78f | |
|
3a25b13eb3 | |
|
568188a6b0 | |
|
9fb4e9c911 | |
|
445eb3223a | |
|
b017670e1b | |
|
164f16e248 | |
|
555adab0a0 | |
|
53dc810c08 | |
|
c4ccae40db | |
|
a50d96f7f1 | |
|
175d19f67e | |
|
1f78ab410f | |
|
6038df9514 | |
|
f6b65fcbe7 | |
|
04e15ab347 | |
|
6cedce234e | |
|
4f5b38f225 | |
|
8875e9da92 | |
|
76a1f3f735 | |
|
3472ee5d2c | |
|
dbf3df21a3 | |
|
63e5b512cc | |
|
274a67223e | |
|
827be6e22d | |
|
08dc8c75d3 | |
|
a05c79e806 | |
|
1d6e37e56b | |
|
e035e03d59 | |
|
dda3d6565b | |
|
c9670173aa | |
|
414ba80274 | |
|
e453a45d31 | |
|
63d9f27fde | |
|
75c79311f4 | |
|
89000c10eb | |
|
36bea13fca | |
|
4016f1348e | |
|
0979e78573 | |
|
8352951f3d | |
|
bedf4e9f64 | |
|
b4ee9d2848 | |
|
841256b630 | |
|
1c1273abca | |
|
705ed2be49 | |
|
7ec6062415 | |
|
58cefa5c0f | |
|
4ce7be17c9 | |
|
a3f2effb17 | |
|
7e4502b0ac | |
|
353886eaca | |
|
6b9e2b55dd | |
|
c98dc4555c | |
|
c73a5a8b15 | |
|
97c4ae8805 | |
|
aa53f7790c | |
|
060b0ed3b5 | |
|
e880451a54 | |
|
c7ec630c37 | |
|
0d8f918eaa | |
|
4f5350f365 | |
|
88e5ba28db | |
|
e0e6a1f1ae | |
|
74743ef588 | |
|
9038c1c255 | |
|
bac181f393 | |
|
5b266d9e7c | |
|
14dcb69092 | |
|
7dcf34533d | |
|
5547f4cac4 | |
|
9e82db6032 | |
|
3006fa26ab | |
|
b63515111b | |
|
0e4163668f | |
|
6ea59bd846 | |
|
e9f9bb200e | |
|
46af7d3ed2 | |
|
2535c1ecd7 | |
|
210373c093 | |
|
1c6b077bb1 | |
|
ea53bce580 | |
|
458ac4c89a | |
|
12eca5b94a | |
|
ef3610d962 | |
|
a4ec2799d6 | |
|
c127c474b9 | |
|
d990982fc0 | |
|
df2f122daa | |
|
7b9a1010f5 | |
|
d80ec5a4b8 | |
|
35fb03a577 | |
|
f9be21ab74 | |
|
75e502ed75 | |
|
a4cfd55dd2 | |
|
44d2886422 | |
|
13ff394985 | |
|
b0007812d6 | |
|
4fb44f8770 | |
|
743ca2d602 | |
|
940a06f645 | |
|
4969830c8a | |
|
629f6447ef | |
|
0e8cfa6286 | |
|
72c6bc4864 | |
|
c3c1cb93aa | |
|
f458bf5a64 | |
|
e51ecfb129 | |
|
c4f3883a22 | |
|
992d13e772 | |
|
1b98fc13d9 | |
|
cf5879988d | |
|
11cd88f39a | |
|
ee9ee9cf3c | |
|
542143cb41 | |
|
2151beeb77 | |
|
c48cfa9f47 | |
|
28b71c15bd | |
|
9eb6fc4368 | |
|
8cc1fdef52 | |
|
d05cbbcf9b | |
|
b16938dc59 | |
|
6203fd6cc9 | |
|
63b124e2d7 | |
|
ef512c98d0 | |
|
9506241f73 | |
|
693596bc2f | |
|
29c46c8af2 | |
|
50aaa711ea | |
|
280e916979 | |
|
cccbedf22d | |
|
b526f018db | |
|
bcca619066 | |
|
29f066f7fb | |
|
06c919efc2 | |
|
31c67ffed4 | |
|
dfe0874102 | |
|
9bc58745d1 | |
|
a49c3076e1 | |
|
043574cbb2 | |
|
2153a16dc5 | |
|
dfb4e87123 | |
|
f452e242a7 | |
|
bf9f3bbdda | |
|
b9bb21738b | |
|
a1642e42da | |
|
ae37d7fd27 | |
|
23cad6a0e1 |
.circleci
configs/_base_/datasets
|
@ -22,11 +22,11 @@ workflows:
|
|||
# line:
|
||||
# <regex path-to-test> <parameter-to-set> <value-of-pipeline-parameter>
|
||||
mapping: |
|
||||
mmcls/.* lint_only false
|
||||
mmpretrain/.* lint_only false
|
||||
requirements/.* lint_only false
|
||||
tests/.* lint_only false
|
||||
.circleci/.* lint_only false
|
||||
base-revision: dev-1.x
|
||||
base-revision: main
|
||||
# this is the path of the configuration we should trigger once
|
||||
# path filtering and pipeline parameter value updates are
|
||||
# complete. In this case, we are using the parent dynamic
|
||||
|
|
|
@ -31,7 +31,58 @@ jobs:
|
|||
name: Check docstring coverage
|
||||
command: |
|
||||
pip install interrogate
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmcls
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmpretrain
|
||||
build_cpu_with_3rdparty:
|
||||
parameters:
|
||||
# The python version must match available image tags in
|
||||
# https://circleci.com/developer/images/image/cimg/python
|
||||
python:
|
||||
type: string
|
||||
torch:
|
||||
type: string
|
||||
torchvision:
|
||||
type: string
|
||||
docker:
|
||||
- image: cimg/python:<< parameters.python >>
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Libraries
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libjpeg8-dev zlib1g-dev
|
||||
- run:
|
||||
name: Configure Python & pip
|
||||
command: |
|
||||
pip install --upgrade pip
|
||||
pip install wheel
|
||||
- run:
|
||||
name: Install PyTorch
|
||||
command: |
|
||||
python -V
|
||||
pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- run:
|
||||
name: Install mmpretrain dependencies
|
||||
command: |
|
||||
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
pip install -U openmim
|
||||
mim install 'mmcv >= 2.0.0rc4'
|
||||
pip install timm
|
||||
pip install transformers
|
||||
pip install -r requirements.txt
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- run:
|
||||
name: Build and install
|
||||
command: |
|
||||
pip install -e .
|
||||
- run:
|
||||
name: Run unittests
|
||||
command: |
|
||||
coverage run --branch --source mmpretrain -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
|
||||
build_cpu:
|
||||
parameters:
|
||||
# The python version must match available image tags in
|
||||
|
@ -63,12 +114,11 @@ jobs:
|
|||
python -V
|
||||
pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- run:
|
||||
name: Install mmcls dependencies
|
||||
name: Install mmpretrain dependencies
|
||||
command: |
|
||||
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
pip install -U openmim
|
||||
mim install 'mmcv >= 2.0.0rc1'
|
||||
pip install timm
|
||||
mim install 'mmcv >= 2.0.0rc4'
|
||||
pip install -r requirements.txt
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- run:
|
||||
|
@ -78,7 +128,7 @@ jobs:
|
|||
- run:
|
||||
name: Run unittests
|
||||
command: |
|
||||
coverage run --branch --source mmcls -m pytest tests/
|
||||
coverage run --branch --source mmpretrain -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
|
||||
|
@ -86,15 +136,17 @@ jobs:
|
|||
machine:
|
||||
image: ubuntu-2004-cuda-11.4:202110-01
|
||||
resource_class: gpu.nvidia.small
|
||||
environment:
|
||||
MKL_SERVICE_FORCE_INTEL: 1
|
||||
parameters:
|
||||
torch:
|
||||
type: string
|
||||
cuda:
|
||||
type: enum
|
||||
enum: ["10.1", "10.2", "11.1"]
|
||||
enum: ["11.1", "11.7"]
|
||||
cudnn:
|
||||
type: integer
|
||||
default: 7
|
||||
default: 8
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
|
@ -105,24 +157,24 @@ jobs:
|
|||
- run:
|
||||
name: Build Docker image
|
||||
command: |
|
||||
docker build .circleci/docker -t mmcls:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >>
|
||||
docker run --gpus all -t -d -v /home/circleci/project:/mmcls -v /home/circleci/mmengine:/mmengine -w /mmcls --name mmcls mmcls:gpu
|
||||
docker build .circleci/docker -t mmpretrain:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >>
|
||||
docker run --gpus all -t -d -v /home/circleci/project:/mmpretrain -v /home/circleci/mmengine:/mmengine -w /mmpretrain --name mmpretrain mmpretrain:gpu
|
||||
- run:
|
||||
name: Install mmcls dependencies
|
||||
name: Install mmpretrain dependencies
|
||||
command: |
|
||||
docker exec mmcls pip install -e /mmengine
|
||||
docker exec mmcls pip install -U openmim
|
||||
docker exec mmcls mim install 'mmcv >= 2.0.0rc1'
|
||||
docker exec mmcls pip install -r requirements.txt
|
||||
docker exec mmcls python -c 'import mmcv; print(mmcv.__version__)'
|
||||
docker exec mmpretrain pip install -e /mmengine
|
||||
docker exec mmpretrain pip install -U openmim
|
||||
docker exec mmpretrain mim install 'mmcv >= 2.0.0rc4'
|
||||
docker exec mmpretrain pip install -r requirements.txt
|
||||
docker exec mmpretrain python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- run:
|
||||
name: Build and install
|
||||
command: |
|
||||
docker exec mmcls pip install -e .
|
||||
docker exec mmpretrain pip install -e .
|
||||
- run:
|
||||
name: Run unittests
|
||||
command: |
|
||||
docker exec mmcls python -m pytest tests/ -k 'not timm'
|
||||
docker exec mmpretrain python -m pytest tests/
|
||||
|
||||
# Invoke jobs via workflows
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference/#workflows
|
||||
|
@ -135,8 +187,8 @@ workflows:
|
|||
filters:
|
||||
branches:
|
||||
ignore:
|
||||
- dev-1.x
|
||||
- 1.x
|
||||
- dev
|
||||
- main
|
||||
pr_stage_test:
|
||||
when:
|
||||
not:
|
||||
|
@ -147,19 +199,19 @@ workflows:
|
|||
filters:
|
||||
branches:
|
||||
ignore:
|
||||
- dev-1.x
|
||||
- dev
|
||||
- build_cpu:
|
||||
name: minimum_version_cpu
|
||||
torch: 1.6.0
|
||||
torchvision: 0.7.0
|
||||
python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images
|
||||
torch: 1.8.0
|
||||
torchvision: 0.9.0
|
||||
python: 3.7.16
|
||||
requires:
|
||||
- lint
|
||||
- build_cpu:
|
||||
- build_cpu_with_3rdparty:
|
||||
name: maximum_version_cpu
|
||||
torch: 1.12.1
|
||||
torchvision: 0.13.1
|
||||
python: 3.9.0
|
||||
torch: 2.0.0
|
||||
torchvision: 0.15.1
|
||||
python: 3.10.0
|
||||
requires:
|
||||
- minimum_version_cpu
|
||||
- hold:
|
||||
|
@ -171,7 +223,14 @@ workflows:
|
|||
torch: 1.8.1
|
||||
# Use double quotation mark to explicitly specify its type
|
||||
# as string instead of number
|
||||
cuda: "10.2"
|
||||
cuda: "11.1"
|
||||
requires:
|
||||
- hold
|
||||
- build_cuda:
|
||||
name: maximum_version_gpu
|
||||
torch: 2.0.0
|
||||
cuda: "11.7"
|
||||
cudnn: 8
|
||||
requires:
|
||||
- hold
|
||||
merge_stage_test:
|
||||
|
@ -181,11 +240,11 @@ workflows:
|
|||
jobs:
|
||||
- build_cuda:
|
||||
name: minimum_version_gpu
|
||||
torch: 1.6.0
|
||||
torch: 1.8.0
|
||||
# Use double quotation mark to explicitly specify its type
|
||||
# as string instead of number
|
||||
cuda: "10.1"
|
||||
cuda: "11.1"
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- dev-1.x
|
||||
- pretrain
|
||||
|
|
|
@ -1,25 +1,27 @@
|
|||
import logging
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
from argparse import ArgumentParser
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from time import time
|
||||
from time import perf_counter
|
||||
from unittest.mock import Mock
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config, DictAction, MMLogger
|
||||
from mmengine import DictAction, MMLogger
|
||||
from mmengine.dataset import Compose, default_collate
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.runner import Runner
|
||||
from modelindex.load_model_index import load
|
||||
from mmengine.device import get_device
|
||||
from mmengine.model.utils import revert_sync_batchnorm
|
||||
from mmengine.runner import Runner, load_checkpoint
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from utils import substitute_weights
|
||||
|
||||
from mmcls.datasets import CIFAR10, CIFAR100, ImageNet
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmcls.visualization import ClsVisualizer
|
||||
from mmpretrain.apis import ModelHub, get_model, list_models
|
||||
from mmpretrain.datasets import CIFAR10, CIFAR100, ImageNet
|
||||
from mmpretrain.utils import register_all_modules
|
||||
from mmpretrain.visualization import UniversalVisualizer
|
||||
|
||||
console = Console()
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
|
@ -30,6 +32,12 @@ classes_map = {
|
|||
'CIFAR-100': CIFAR100.CLASSES,
|
||||
}
|
||||
|
||||
logger = MMLogger.get_instance('validation', logger_name='mmpretrain')
|
||||
logger.handlers[0].stream = sys.stderr
|
||||
logger.addHandler(logging.FileHandler('benchmark_valid.log', mode='w'))
|
||||
# Force to use the logger in runners.
|
||||
Runner.build_logger = Mock(return_value=logger)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description='Valid all models in model-index.yml')
|
||||
|
@ -48,6 +56,11 @@ def parse_args():
|
|||
'--inference-time',
|
||||
action='store_true',
|
||||
help='Test inference time by run 10 times for each model.')
|
||||
parser.add_argument(
|
||||
'--batch-size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The batch size during the inference.')
|
||||
parser.add_argument(
|
||||
'--flops', action='store_true', help='Get Flops and Params of models')
|
||||
parser.add_argument(
|
||||
|
@ -68,65 +81,76 @@ def parse_args():
|
|||
return args
|
||||
|
||||
|
||||
def inference(config_file, checkpoint, work_dir, args, exp_name):
|
||||
cfg = Config.fromfile(config_file)
|
||||
def inference(metainfo, checkpoint, work_dir, args, exp_name=None):
|
||||
cfg = metainfo.config
|
||||
cfg.work_dir = work_dir
|
||||
cfg.load_from = checkpoint
|
||||
cfg.log_level = 'WARN'
|
||||
cfg.experiment_name = exp_name
|
||||
cfg.experiment_name = exp_name or metainfo.name
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# build the data pipeline
|
||||
test_dataset = cfg.test_dataloader.dataset
|
||||
if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
|
||||
test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
|
||||
if test_dataset.type in ['CIFAR10', 'CIFAR100']:
|
||||
# The image shape of CIFAR is (32, 32, 3)
|
||||
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
|
||||
if 'test_dataloader' in cfg:
|
||||
# build the data pipeline
|
||||
test_dataset = cfg.test_dataloader.dataset
|
||||
if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
|
||||
test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
|
||||
if test_dataset.type in ['CIFAR10', 'CIFAR100']:
|
||||
# The image shape of CIFAR is (32, 32, 3)
|
||||
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
|
||||
|
||||
data = Compose(test_dataset.pipeline)({'img_path': args.img})
|
||||
data = default_collate([data])
|
||||
resolution = tuple(data['inputs'].shape[-2:])
|
||||
data = Compose(test_dataset.pipeline)({'img_path': args.img})
|
||||
data = default_collate([data] * args.batch_size)
|
||||
resolution = tuple(data['inputs'].shape[-2:])
|
||||
model = Runner.from_cfg(cfg).model
|
||||
model = revert_sync_batchnorm(model)
|
||||
model.eval()
|
||||
forward = model.val_step
|
||||
else:
|
||||
# For configs without data settings.
|
||||
model = get_model(cfg, device=get_device())
|
||||
model = revert_sync_batchnorm(model)
|
||||
model.eval()
|
||||
data = torch.rand(1, 3, 224, 224).to(model.data_preprocessor.device)
|
||||
resolution = (224, 224)
|
||||
forward = model.extract_feat
|
||||
|
||||
runner: Runner = Runner.from_cfg(cfg)
|
||||
model = runner.model
|
||||
if checkpoint is not None:
|
||||
load_checkpoint(model, checkpoint, map_location='cpu')
|
||||
|
||||
# forward the model
|
||||
result = {'resolution': resolution}
|
||||
result = {'model': metainfo.name, 'resolution': resolution}
|
||||
with torch.no_grad():
|
||||
if args.inference_time:
|
||||
time_record = []
|
||||
forward(data) # warmup before profiling
|
||||
for _ in range(10):
|
||||
model.val_step(data) # warmup before profiling
|
||||
torch.cuda.synchronize()
|
||||
start = time()
|
||||
model.val_step(data)
|
||||
start = perf_counter()
|
||||
forward(data)
|
||||
torch.cuda.synchronize()
|
||||
time_record.append((time() - start) * 1000)
|
||||
time_record.append(
|
||||
(perf_counter() - start) / args.batch_size * 1000)
|
||||
result['time_mean'] = np.mean(time_record[1:-1])
|
||||
result['time_std'] = np.std(time_record[1:-1])
|
||||
else:
|
||||
model.val_step(data)
|
||||
|
||||
result['model'] = config_file.stem
|
||||
forward(data)
|
||||
|
||||
if args.flops:
|
||||
from fvcore.nn import FlopCountAnalysis, parameter_count
|
||||
from fvcore.nn.print_model_statistics import _format_size
|
||||
from mmengine.analysis import FlopAnalyzer, parameter_count
|
||||
from mmengine.analysis.print_helper import _format_size
|
||||
_format_size = _format_size if args.flops_str else lambda x: x
|
||||
with torch.no_grad():
|
||||
if hasattr(model, 'extract_feat'):
|
||||
model.forward = model.extract_feat
|
||||
model.to('cpu')
|
||||
inputs = (torch.randn((1, 3, *resolution)), )
|
||||
flops = _format_size(FlopCountAnalysis(model, inputs).total())
|
||||
params = _format_size(parameter_count(model)[''])
|
||||
result['flops'] = flops if args.flops_str else int(flops)
|
||||
result['params'] = params if args.flops_str else int(params)
|
||||
else:
|
||||
result['flops'] = ''
|
||||
result['params'] = ''
|
||||
model.forward = model.extract_feat
|
||||
model.to('cpu')
|
||||
inputs = (torch.randn((1, 3, *resolution)), )
|
||||
analyzer = FlopAnalyzer(model, inputs)
|
||||
# extract_feat only includes backbone
|
||||
analyzer._enable_warn_uncalled_mods = False
|
||||
flops = _format_size(analyzer.total())
|
||||
params = _format_size(parameter_count(model)[''])
|
||||
result['flops'] = flops if args.flops_str else int(flops)
|
||||
result['params'] = params if args.flops_str else int(params)
|
||||
|
||||
return result
|
||||
|
||||
|
@ -135,17 +159,17 @@ def show_summary(summary_data, args):
|
|||
table = Table(title='Validation Benchmark Regression Summary')
|
||||
table.add_column('Model')
|
||||
table.add_column('Validation')
|
||||
table.add_column('Resolution (h, w)')
|
||||
table.add_column('Resolution (h w)')
|
||||
if args.inference_time:
|
||||
table.add_column('Inference Time (std) (ms/im)')
|
||||
if args.flops:
|
||||
table.add_column('Flops', justify='right')
|
||||
table.add_column('Params', justify='right')
|
||||
table.add_column('Flops', justify='right', width=13)
|
||||
table.add_column('Params', justify='right', width=11)
|
||||
|
||||
for model_name, summary in summary_data.items():
|
||||
row = [model_name]
|
||||
valid = summary['valid']
|
||||
color = 'green' if valid == 'PASS' else 'red'
|
||||
color = {'PASS': 'green', 'CUDA OOM': 'yellow'}.get(valid, 'red')
|
||||
row.append(f'[{color}]{valid}[/{color}]')
|
||||
if valid == 'PASS':
|
||||
row.append(str(summary['resolution']))
|
||||
|
@ -158,84 +182,55 @@ def show_summary(summary_data, args):
|
|||
row.append(str(summary['params']))
|
||||
table.add_row(*row)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
# Sample test whether the inference code is correct
|
||||
def main(args):
|
||||
register_all_modules()
|
||||
model_index_file = MMCLS_ROOT / 'model-index.yml'
|
||||
model_index = load(str(model_index_file))
|
||||
model_index.build_models_with_collections()
|
||||
models = OrderedDict({model.name: model for model in model_index.models})
|
||||
|
||||
logger = MMLogger(
|
||||
'validation',
|
||||
logger_name='validation',
|
||||
log_file='benchmark_test_image.log',
|
||||
log_level=logging.INFO)
|
||||
|
||||
if args.models:
|
||||
patterns = [re.compile(pattern) for pattern in args.models]
|
||||
filter_models = {}
|
||||
for k, v in models.items():
|
||||
if any([re.match(pattern, k) for pattern in patterns]):
|
||||
filter_models[k] = v
|
||||
if len(filter_models) == 0:
|
||||
models = set()
|
||||
for pattern in args.models:
|
||||
models.update(list_models(pattern=pattern))
|
||||
if len(models) == 0:
|
||||
print('No model found, please specify models in:')
|
||||
print('\n'.join(models.keys()))
|
||||
print('\n'.join(list_models()))
|
||||
return
|
||||
models = filter_models
|
||||
else:
|
||||
models = list_models()
|
||||
|
||||
summary_data = {}
|
||||
tmpdir = tempfile.TemporaryDirectory()
|
||||
for model_name, model_info in models.items():
|
||||
for model_name in models:
|
||||
|
||||
model_info = ModelHub.get(model_name)
|
||||
if model_info.config is None:
|
||||
continue
|
||||
|
||||
config = Path(model_info.config)
|
||||
assert config.exists(), f'{model_name}: {config} not found.'
|
||||
|
||||
logger.info(f'Processing: {model_name}')
|
||||
|
||||
http_prefix = 'https://download.openmmlab.com/mmclassification/'
|
||||
if args.checkpoint_root is not None:
|
||||
root = args.checkpoint_root
|
||||
if 's3://' in args.checkpoint_root:
|
||||
from petrel_client.common.exception import AccessDeniedError
|
||||
file_client = FileClient.infer_client(uri=root)
|
||||
checkpoint = file_client.join_path(
|
||||
root, model_info.weights[len(http_prefix):])
|
||||
try:
|
||||
exists = file_client.exists(checkpoint)
|
||||
except AccessDeniedError:
|
||||
exists = False
|
||||
else:
|
||||
checkpoint = Path(root) / model_info.weights[len(http_prefix):]
|
||||
exists = checkpoint.exists()
|
||||
if exists:
|
||||
checkpoint = str(checkpoint)
|
||||
else:
|
||||
print(f'WARNING: {model_name}: {checkpoint} not found.')
|
||||
checkpoint = None
|
||||
weights = model_info.weights
|
||||
if args.checkpoint_root is not None and weights is not None:
|
||||
checkpoint = substitute_weights(weights, args.checkpoint_root)
|
||||
else:
|
||||
checkpoint = None
|
||||
|
||||
try:
|
||||
# build the model from a config file and a checkpoint file
|
||||
result = inference(MMCLS_ROOT / config, checkpoint, tmpdir.name,
|
||||
args, model_name)
|
||||
result = inference(model_info, checkpoint, tmpdir.name, args)
|
||||
result['valid'] = 'PASS'
|
||||
except Exception:
|
||||
import traceback
|
||||
logger.error(f'"{config}" :\n{traceback.format_exc()}')
|
||||
result = {'valid': 'FAIL'}
|
||||
except Exception as e:
|
||||
if 'CUDA out of memory' in str(e):
|
||||
logger.error(f'"{model_name}" :\nCUDA out of memory')
|
||||
result = {'valid': 'CUDA OOM'}
|
||||
else:
|
||||
import traceback
|
||||
logger.error(f'"{model_name}" :\n{traceback.format_exc()}')
|
||||
result = {'valid': 'FAIL'}
|
||||
|
||||
summary_data[model_name] = result
|
||||
# show the results
|
||||
if args.show:
|
||||
vis = ClsVisualizer.get_instance('valid')
|
||||
vis = UniversalVisualizer.get_instance('valid')
|
||||
vis.set_image(mmcv.imread(args.img))
|
||||
vis.draw_texts(
|
||||
texts='\n'.join([f'{k}: {v}' for k, v in result.items()]),
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import argparse
|
||||
import fnmatch
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from collections import OrderedDict, defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -11,57 +12,57 @@ from modelindex.load_model_index import load
|
|||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
from rich.table import Table
|
||||
from utils import METRICS_MAP, MMCLS_ROOT, substitute_weights
|
||||
|
||||
# Avoid to import MMPretrain to accelerate speed to show summary
|
||||
|
||||
console = Console()
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
METRICS_MAP = {
|
||||
'Top 1 Accuracy': 'accuracy/top1',
|
||||
'Top 5 Accuracy': 'accuracy/top5'
|
||||
}
|
||||
logger = logging.getLogger('test')
|
||||
logger.addHandler(logging.StreamHandler())
|
||||
logger.addHandler(logging.FileHandler('benchmark_test.log', mode='w'))
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test all models' accuracy in model-index.yml")
|
||||
parser.add_argument(
|
||||
'partition', type=str, help='Cluster partition to use.')
|
||||
parser.add_argument('checkpoint_root', help='Checkpoint file root path.')
|
||||
parser.add_argument(
|
||||
'--job-name',
|
||||
type=str,
|
||||
default='cls-test-benchmark',
|
||||
help='Slurm job name prefix')
|
||||
parser.add_argument('--port', type=int, default=29666, help='dist port')
|
||||
'--local', action='store_true', help='run at local instead of slurm.')
|
||||
parser.add_argument(
|
||||
'--models', nargs='+', type=str, help='Specify model names to run.')
|
||||
parser.add_argument(
|
||||
'--work-dir',
|
||||
default='work_dirs/benchmark_test',
|
||||
help='the dir to save metric')
|
||||
parser.add_argument(
|
||||
'--run', action='store_true', help='run script directly')
|
||||
parser.add_argument(
|
||||
'--local',
|
||||
action='store_true',
|
||||
help='run at local instead of cluster.')
|
||||
parser.add_argument(
|
||||
'--mail', type=str, help='Mail address to watch test status.')
|
||||
parser.add_argument(
|
||||
'--mail-type',
|
||||
nargs='+',
|
||||
default=['BEGIN'],
|
||||
choices=['NONE', 'BEGIN', 'END', 'FAIL', 'REQUEUE', 'ALL'],
|
||||
help='Mail address to watch test status.')
|
||||
parser.add_argument(
|
||||
'--quotatype',
|
||||
default=None,
|
||||
choices=['reserved', 'auto', 'spot'],
|
||||
help='Quota type, only available for phoenix-slurm>=0.2')
|
||||
parser.add_argument(
|
||||
'--summary',
|
||||
action='store_true',
|
||||
help='Summarize benchmark test results.')
|
||||
parser.add_argument('--save', action='store_true', help='Save the summary')
|
||||
parser.add_argument(
|
||||
'--gpus', type=int, default=1, help='How many GPUS to use.')
|
||||
parser.add_argument(
|
||||
'--no-skip',
|
||||
action='store_true',
|
||||
help='Whether to skip models without results record in the metafile.')
|
||||
parser.add_argument(
|
||||
'--work-dir',
|
||||
default='work_dirs/benchmark_test',
|
||||
help='the dir to save metric')
|
||||
parser.add_argument('--port', type=int, default=29666, help='dist port')
|
||||
parser.add_argument(
|
||||
'--partition',
|
||||
type=str,
|
||||
default='mm_model',
|
||||
help='(for slurm) Cluster partition to use.')
|
||||
parser.add_argument(
|
||||
'--job-name',
|
||||
type=str,
|
||||
default='cls-test-benchmark',
|
||||
help='(for slurm) Slurm job name prefix')
|
||||
parser.add_argument(
|
||||
'--quotatype',
|
||||
default=None,
|
||||
choices=['reserved', 'auto', 'spot'],
|
||||
help='(for slurm) Quota type, only available for phoenix-slurm>=0.2')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
|
@ -74,64 +75,53 @@ def parse_args():
|
|||
|
||||
|
||||
def create_test_job_batch(commands, model_info, args, port, script_name):
|
||||
|
||||
fname = model_info.name
|
||||
|
||||
model_name = model_info.name
|
||||
config = Path(model_info.config)
|
||||
assert config.exists(), f'{fname}: {config} not found.'
|
||||
|
||||
http_prefix = 'https://download.openmmlab.com/mmclassification/'
|
||||
if 's3://' in args.checkpoint_root:
|
||||
from mmengine.fileio import FileClient
|
||||
from petrel_client.common.exception import AccessDeniedError
|
||||
file_client = FileClient.infer_client(uri=args.checkpoint_root)
|
||||
checkpoint = file_client.join_path(
|
||||
args.checkpoint_root, model_info.weights[len(http_prefix):])
|
||||
try:
|
||||
exists = file_client.exists(checkpoint)
|
||||
except AccessDeniedError:
|
||||
exists = False
|
||||
if model_info.weights is not None:
|
||||
checkpoint = substitute_weights(model_info.weights,
|
||||
args.checkpoint_root)
|
||||
if checkpoint is None:
|
||||
logger.warning(f'{model_name}: {checkpoint} not found.')
|
||||
return None
|
||||
else:
|
||||
checkpoint_root = Path(args.checkpoint_root)
|
||||
checkpoint = checkpoint_root / model_info.weights[len(http_prefix):]
|
||||
exists = checkpoint.exists()
|
||||
if not exists:
|
||||
print(f'WARNING: {fname}: {checkpoint} not found.')
|
||||
return None
|
||||
|
||||
job_name = f'{args.job_name}_{fname}'
|
||||
work_dir = Path(args.work_dir) / fname
|
||||
job_name = f'{args.job_name}_{model_name}'
|
||||
work_dir = Path(args.work_dir) / model_name
|
||||
work_dir.mkdir(parents=True, exist_ok=True)
|
||||
result_file = work_dir / 'result.pkl'
|
||||
|
||||
if args.mail is not None and 'NONE' not in args.mail_type:
|
||||
mail_cfg = (f'#SBATCH --mail {args.mail}\n'
|
||||
f'#SBATCH --mail-type {args.mail_type}\n')
|
||||
else:
|
||||
mail_cfg = ''
|
||||
|
||||
if args.quotatype is not None:
|
||||
quota_cfg = f'#SBATCH --quotatype {args.quotatype}\n'
|
||||
quota_cfg = f'#SBATCH --quotatype {args.quotatype}'
|
||||
else:
|
||||
quota_cfg = ''
|
||||
|
||||
launcher = 'none' if args.local else 'slurm'
|
||||
runner = 'python' if args.local else 'srun python'
|
||||
if not args.local:
|
||||
launcher = 'srun python'
|
||||
runner = 'slurm'
|
||||
elif args.gpus > 1:
|
||||
launcher = 'pytorch'
|
||||
runner = ('torchrun --master_addr="127.0.0.1" '
|
||||
f'--master_port={port} --nproc_per_node={args.gpus}')
|
||||
else:
|
||||
launcher = 'none'
|
||||
runner = 'python -u'
|
||||
|
||||
job_script = (f'#!/bin/bash\n'
|
||||
f'#SBATCH --output {work_dir}/job.%j.out\n'
|
||||
f'#SBATCH --partition={args.partition}\n'
|
||||
f'#SBATCH --job-name {job_name}\n'
|
||||
f'#SBATCH --gres=gpu:8\n'
|
||||
f'{mail_cfg}{quota_cfg}'
|
||||
f'#SBATCH --ntasks-per-node=8\n'
|
||||
f'#SBATCH --ntasks=8\n'
|
||||
f'#SBATCH --gres=gpu:{min(8, args.gpus)}\n'
|
||||
f'{quota_cfg}\n'
|
||||
f'#SBATCH --ntasks-per-node={min(8, args.gpus)}\n'
|
||||
f'#SBATCH --ntasks={args.gpus}\n'
|
||||
f'#SBATCH --cpus-per-task=5\n\n'
|
||||
f'{runner} -u {script_name} {config} {checkpoint} '
|
||||
f'--work-dir={work_dir} '
|
||||
f'--out={result_file} '
|
||||
f'--cfg-option dist_params.port={port} '
|
||||
f'{runner} {script_name} {config} {checkpoint} '
|
||||
f'--work-dir={work_dir} --cfg-option '
|
||||
f'env_cfg.dist_cfg.port={port} '
|
||||
f'{" ".join(args.cfg_options)} '
|
||||
f'--out={result_file} --out-item="metrics" '
|
||||
f'--launcher={launcher}\n')
|
||||
|
||||
with open(work_dir / 'job.sh', 'w') as f:
|
||||
|
@ -146,33 +136,17 @@ def create_test_job_batch(commands, model_info, args, port, script_name):
|
|||
return work_dir / 'job.sh'
|
||||
|
||||
|
||||
def test(args):
|
||||
# parse model-index.yml
|
||||
model_index_file = MMCLS_ROOT / 'model-index.yml'
|
||||
model_index = load(str(model_index_file))
|
||||
model_index.build_models_with_collections()
|
||||
models = OrderedDict({model.name: model for model in model_index.models})
|
||||
|
||||
def test(models, args):
|
||||
script_name = osp.join('tools', 'test.py')
|
||||
port = args.port
|
||||
|
||||
commands = []
|
||||
if args.models:
|
||||
patterns = [re.compile(pattern) for pattern in args.models]
|
||||
filter_models = {}
|
||||
for k, v in models.items():
|
||||
if any([re.match(pattern, k) for pattern in patterns]):
|
||||
filter_models[k] = v
|
||||
if len(filter_models) == 0:
|
||||
print('No model found, please specify models in:')
|
||||
print('\n'.join(models.keys()))
|
||||
return
|
||||
models = filter_models
|
||||
|
||||
preview_script = ''
|
||||
for model_info in models.values():
|
||||
|
||||
if model_info.results is None:
|
||||
# Skip pre-train model
|
||||
continue
|
||||
|
||||
script_path = create_test_job_batch(commands, model_info, args, port,
|
||||
|
@ -205,44 +179,41 @@ def test(args):
|
|||
console.print('Please set "--run" to start the job')
|
||||
|
||||
|
||||
def save_summary(summary_data, models_map, work_dir):
|
||||
summary_path = work_dir / 'test_benchmark_summary.md'
|
||||
def save_summary(summary_data, work_dir):
|
||||
summary_path = work_dir / 'test_benchmark_summary.csv'
|
||||
file = open(summary_path, 'w')
|
||||
headers = [
|
||||
'Model', 'Top-1 Expected(%)', 'Top-1 (%)', 'Top-5 Expected (%)',
|
||||
'Top-5 (%)', 'Config'
|
||||
]
|
||||
file.write('# Test Benchmark Regression Summary\n')
|
||||
file.write('| ' + ' | '.join(headers) + ' |\n')
|
||||
file.write('|:' + ':|:'.join(['---'] * len(headers)) + ':|\n')
|
||||
columns = defaultdict(list)
|
||||
for model_name, summary in summary_data.items():
|
||||
if len(summary) == 0:
|
||||
# Skip models without results
|
||||
continue
|
||||
row = [model_name]
|
||||
if 'Top 1 Accuracy' in summary:
|
||||
metric = summary['Top 1 Accuracy']
|
||||
row.append(str(round(metric['expect'], 2)))
|
||||
row.append(str(round(metric['result'], 2)))
|
||||
else:
|
||||
row.extend([''] * 2)
|
||||
if 'Top 5 Accuracy' in summary:
|
||||
metric = summary['Top 5 Accuracy']
|
||||
row.append(str(round(metric['expect'], 2)))
|
||||
row.append(str(round(metric['result'], 2)))
|
||||
else:
|
||||
row.extend([''] * 2)
|
||||
columns['Name'].append(model_name)
|
||||
|
||||
model_info = models_map[model_name]
|
||||
row.append(model_info.config)
|
||||
file.write('| ' + ' | '.join(row) + ' |\n')
|
||||
for metric_key in METRICS_MAP:
|
||||
if metric_key in summary:
|
||||
metric = summary[metric_key]
|
||||
expect = round(metric['expect'], 2)
|
||||
result = round(metric['result'], 2)
|
||||
columns[f'{metric_key} (expect)'].append(str(expect))
|
||||
columns[f'{metric_key}'].append(str(result))
|
||||
else:
|
||||
columns[f'{metric_key} (expect)'].append('')
|
||||
columns[f'{metric_key}'].append('')
|
||||
|
||||
columns = {
|
||||
field: column
|
||||
for field, column in columns.items() if ''.join(column)
|
||||
}
|
||||
file.write(','.join(columns.keys()) + '\n')
|
||||
for row in zip(*columns.values()):
|
||||
file.write(','.join(row) + '\n')
|
||||
file.close()
|
||||
print('Summary file saved at ' + str(summary_path))
|
||||
logger.info('Summary file saved at ' + str(summary_path))
|
||||
|
||||
|
||||
def show_summary(summary_data):
|
||||
table = Table(title='Test Benchmark Regression Summary')
|
||||
table.add_column('Model')
|
||||
table.add_column('Name')
|
||||
for metric in METRICS_MAP:
|
||||
table.add_column(f'{metric} (expect)')
|
||||
table.add_column(f'{metric}')
|
||||
|
@ -274,33 +245,20 @@ def show_summary(summary_data):
|
|||
row.append('')
|
||||
table.add_row(*row)
|
||||
|
||||
# Remove empty columns
|
||||
table.columns = [
|
||||
column for column in table.columns if ''.join(column._cells)
|
||||
]
|
||||
console.print(table)
|
||||
|
||||
|
||||
def summary(args):
|
||||
model_index_file = MMCLS_ROOT / 'model-index.yml'
|
||||
model_index = load(str(model_index_file))
|
||||
model_index.build_models_with_collections()
|
||||
models = OrderedDict({model.name: model for model in model_index.models})
|
||||
|
||||
def summary(models, args):
|
||||
work_dir = Path(args.work_dir)
|
||||
|
||||
if args.models:
|
||||
patterns = [re.compile(pattern) for pattern in args.models]
|
||||
filter_models = {}
|
||||
for k, v in models.items():
|
||||
if any([re.match(pattern, k) for pattern in patterns]):
|
||||
filter_models[k] = v
|
||||
if len(filter_models) == 0:
|
||||
print('No model found, please specify models in:')
|
||||
print('\n'.join(models.keys()))
|
||||
return
|
||||
models = filter_models
|
||||
|
||||
summary_data = {}
|
||||
for model_name, model_info in models.items():
|
||||
|
||||
if model_info.results is None:
|
||||
if model_info.results is None and not args.no_skip:
|
||||
continue
|
||||
|
||||
# Skip if not found result file.
|
||||
|
@ -327,16 +285,35 @@ def summary(args):
|
|||
|
||||
show_summary(summary_data)
|
||||
if args.save:
|
||||
save_summary(summary_data, models, work_dir)
|
||||
save_summary(summary_data, work_dir)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# parse model-index.yml
|
||||
model_index_file = MMCLS_ROOT / 'model-index.yml'
|
||||
model_index = load(str(model_index_file))
|
||||
model_index.build_models_with_collections()
|
||||
models = OrderedDict({model.name: model for model in model_index.models})
|
||||
|
||||
if args.models:
|
||||
filter_models = {}
|
||||
for pattern in args.models:
|
||||
filter_models.update({
|
||||
name: models[name]
|
||||
for name in fnmatch.filter(models, pattern + '*')
|
||||
})
|
||||
if len(filter_models) == 0:
|
||||
logger.error('No model found, please specify models in:\n' +
|
||||
'\n'.join(models.keys()))
|
||||
return
|
||||
models = filter_models
|
||||
|
||||
if args.summary:
|
||||
summary(args)
|
||||
summary(models, args)
|
||||
else:
|
||||
test(args)
|
||||
test(models, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
import argparse
|
||||
import fnmatch
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
@ -14,18 +17,20 @@ from rich.console import Console
|
|||
from rich.syntax import Syntax
|
||||
from rich.table import Table
|
||||
|
||||
from .utils import METRICS_MAP, MMCLS_ROOT
|
||||
|
||||
# Avoid to import MMPretrain to accelerate speed to show summary
|
||||
|
||||
console = Console()
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
logger = logging.getLogger('train')
|
||||
logger.addHandler(logging.StreamHandler())
|
||||
logger.addHandler(logging.FileHandler('benchmark_train.log', mode='w'))
|
||||
CYCLE_LEVELS = ['month', 'quarter', 'half-year', 'no-training']
|
||||
METRICS_MAP = {
|
||||
'Top 1 Accuracy': 'accuracy/top1',
|
||||
'Top 5 Accuracy': 'accuracy/top5'
|
||||
}
|
||||
|
||||
|
||||
class RangeAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values: str, option_string):
|
||||
def __call__(self, _, namespace, values: str, __):
|
||||
matches = re.match(r'([><=]*)([-\w]+)', values)
|
||||
if matches is None:
|
||||
raise ValueError(f'Unavailable range option {values}')
|
||||
|
@ -49,15 +54,25 @@ def parse_args():
|
|||
parser = argparse.ArgumentParser(
|
||||
description='Train models (in bench_train.yml) and compare accuracy.')
|
||||
parser.add_argument(
|
||||
'partition', type=str, help='Cluster partition to use.')
|
||||
parser.add_argument(
|
||||
'--job-name',
|
||||
type=str,
|
||||
default='cls-train-benchmark',
|
||||
help='Slurm job name prefix')
|
||||
parser.add_argument('--port', type=int, default=29666, help='dist port')
|
||||
'--local',
|
||||
action='store_true',
|
||||
help='run at local instead of cluster.')
|
||||
parser.add_argument(
|
||||
'--models', nargs='+', type=str, help='Specify model names to run.')
|
||||
parser.add_argument(
|
||||
'--run', action='store_true', help='run script directly')
|
||||
parser.add_argument(
|
||||
'--summary',
|
||||
action='store_true',
|
||||
help='Summarize benchmark train results.')
|
||||
parser.add_argument(
|
||||
'--save',
|
||||
action='store_true',
|
||||
help='Save the summary and archive log files.')
|
||||
parser.add_argument(
|
||||
'--non-distributed',
|
||||
action='store_true',
|
||||
help='Use non-distributed environment (for debug).')
|
||||
parser.add_argument(
|
||||
'--range',
|
||||
type=str,
|
||||
|
@ -70,33 +85,22 @@ def parse_args():
|
|||
'--work-dir',
|
||||
default='work_dirs/benchmark_train',
|
||||
help='the dir to save train log')
|
||||
parser.add_argument('--port', type=int, default=29666, help='dist port')
|
||||
parser.add_argument(
|
||||
'--run', action='store_true', help='run script directly')
|
||||
'--partition',
|
||||
type=str,
|
||||
default='mm_model',
|
||||
help='(for slurm) Cluster partition to use.')
|
||||
parser.add_argument(
|
||||
'--local',
|
||||
action='store_true',
|
||||
help='run at local instead of cluster.')
|
||||
parser.add_argument(
|
||||
'--mail', type=str, help='Mail address to watch train status.')
|
||||
parser.add_argument(
|
||||
'--mail-type',
|
||||
nargs='+',
|
||||
default=['BEGIN', 'END', 'FAIL'],
|
||||
choices=['NONE', 'BEGIN', 'END', 'FAIL', 'REQUEUE', 'ALL'],
|
||||
help='Mail address to watch train status.')
|
||||
'--job-name',
|
||||
type=str,
|
||||
default='cls-train-benchmark',
|
||||
help='(for slurm) Slurm job name prefix')
|
||||
parser.add_argument(
|
||||
'--quotatype',
|
||||
default=None,
|
||||
choices=['reserved', 'auto', 'spot'],
|
||||
help='Quota type, only available for phoenix-slurm>=0.2')
|
||||
parser.add_argument(
|
||||
'--summary',
|
||||
action='store_true',
|
||||
help='Summarize benchmark train results.')
|
||||
parser.add_argument(
|
||||
'--save',
|
||||
action='store_true',
|
||||
help='Save the summary and archive log files.')
|
||||
help='(for slurm) Quota type, only available for phoenix-slurm>=0.2')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
|
@ -118,72 +122,93 @@ def get_gpu_number(model_info):
|
|||
return gpus
|
||||
|
||||
|
||||
def create_train_job_batch(commands, model_info, args, port, script_name):
|
||||
|
||||
fname = model_info.name
|
||||
|
||||
gpus = get_gpu_number(model_info)
|
||||
gpus_per_node = min(gpus, 8)
|
||||
|
||||
def create_train_job_batch(model_info, args, port, pretrain_info=None):
|
||||
model_name = model_info.name
|
||||
config = Path(model_info.config)
|
||||
assert config.exists(), f'"{fname}": {config} not found.'
|
||||
gpus = get_gpu_number(model_info)
|
||||
|
||||
job_name = f'{args.job_name}_{fname}'
|
||||
work_dir = Path(args.work_dir) / fname
|
||||
job_name = f'{args.job_name}_{model_name}'
|
||||
work_dir = Path(args.work_dir) / model_name
|
||||
work_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.mail is not None and 'NONE' not in args.mail_type:
|
||||
mail_cfg = (f'#SBATCH --mail {args.mail}\n'
|
||||
f'#SBATCH --mail-type {args.mail_type}\n')
|
||||
else:
|
||||
mail_cfg = ''
|
||||
cfg_options = deepcopy(args.cfg_options)
|
||||
|
||||
if args.quotatype is not None:
|
||||
quota_cfg = f'#SBATCH --quotatype {args.quotatype}\n'
|
||||
quota_cfg = f'#SBATCH --quotatype {args.quotatype}'
|
||||
else:
|
||||
quota_cfg = ''
|
||||
|
||||
launcher = 'none' if args.local else 'slurm'
|
||||
runner = 'python' if args.local else 'srun python'
|
||||
if pretrain_info is not None:
|
||||
pretrain = Path(args.work_dir) / pretrain_info.name / 'last_checkpoint'
|
||||
pretrain_cfg = (f'model.backbone.init_cfg.checkpoint="$(<{pretrain})" '
|
||||
'model.backbone.init_cfg.type="Pretrained" '
|
||||
'model.backbone.init_cfg.prefix="backbone."')
|
||||
else:
|
||||
pretrain_cfg = ''
|
||||
|
||||
if not args.local:
|
||||
launcher = 'slurm'
|
||||
runner = 'srun python'
|
||||
if gpus > 8:
|
||||
gpus = 8
|
||||
cfg_options.append('auto_scale_lr.enable=True')
|
||||
elif not args.non_distributed:
|
||||
launcher = 'pytorch'
|
||||
if gpus > 8:
|
||||
gpus = 8
|
||||
cfg_options.append('auto_scale_lr.enable=True')
|
||||
runner = ('torchrun --master_addr="127.0.0.1" '
|
||||
f'--master_port={port} --nproc_per_node={gpus}')
|
||||
else:
|
||||
launcher = 'none'
|
||||
runner = 'python -u'
|
||||
|
||||
job_script = (f'#!/bin/bash\n'
|
||||
f'#SBATCH --output {work_dir}/job.%j.out\n'
|
||||
f'#SBATCH --partition={args.partition}\n'
|
||||
f'#SBATCH --job-name {job_name}\n'
|
||||
f'#SBATCH --gres=gpu:{gpus_per_node}\n'
|
||||
f'{mail_cfg}{quota_cfg}'
|
||||
f'#SBATCH --ntasks-per-node={gpus_per_node}\n'
|
||||
f'#SBATCH --gres=gpu:{min(8, gpus)}\n'
|
||||
f'{quota_cfg}\n'
|
||||
f'#SBATCH --ntasks-per-node={min(8, gpus)}\n'
|
||||
f'#SBATCH --ntasks={gpus}\n'
|
||||
f'#SBATCH --cpus-per-task=5\n\n'
|
||||
f'{runner} -u {script_name} {config} '
|
||||
f'{runner} tools/train.py {config} '
|
||||
f'--work-dir={work_dir} --cfg-option '
|
||||
f'env_cfg.dist_cfg.port={port} '
|
||||
f'{" ".join(args.cfg_options)} '
|
||||
f'{" ".join(cfg_options)} '
|
||||
f'default_hooks.checkpoint.max_keep_ckpts=2 '
|
||||
f'default_hooks.checkpoint.save_best="auto" '
|
||||
f'{pretrain_cfg} '
|
||||
f'--launcher={launcher}\n')
|
||||
|
||||
with open(work_dir / 'job.sh', 'w') as f:
|
||||
f.write(job_script)
|
||||
|
||||
commands.append(f'echo "{config}"')
|
||||
if args.local:
|
||||
commands.append(f'bash {work_dir}/job.sh')
|
||||
else:
|
||||
commands.append(f'sbatch {work_dir}/job.sh')
|
||||
|
||||
return work_dir / 'job.sh'
|
||||
|
||||
|
||||
def train(models, args):
|
||||
script_name = osp.join('tools', 'train.py')
|
||||
port = args.port
|
||||
|
||||
commands = []
|
||||
|
||||
for model_info in models.values():
|
||||
script_path = create_train_job_batch(commands, model_info, args, port,
|
||||
script_name)
|
||||
script_path = create_train_job_batch(model_info, args, port)
|
||||
if hasattr(model_info, 'downstream'):
|
||||
downstream_info = model_info.downstream
|
||||
downstream_script = create_train_job_batch(
|
||||
downstream_info, args, port, pretrain_info=model_info)
|
||||
else:
|
||||
downstream_script = None
|
||||
|
||||
if args.local:
|
||||
command = f'bash {script_path}'
|
||||
if downstream_script:
|
||||
command += f' && bash {downstream_script}'
|
||||
else:
|
||||
command = f'JOBID=$(sbatch --parsable {script_path})'
|
||||
if downstream_script:
|
||||
command += f' && sbatch --dependency=afterok:$JOBID {downstream_script}' # noqa: E501
|
||||
commands.append(command)
|
||||
|
||||
port += 1
|
||||
|
||||
command_str = '\n'.join(commands)
|
||||
|
@ -211,63 +236,67 @@ def train(models, args):
|
|||
console.print('Please set "--run" to start the job')
|
||||
|
||||
|
||||
def save_summary(summary_data, models_map, work_dir):
|
||||
def save_summary(summary_data, work_dir):
|
||||
date = datetime.now().strftime('%Y%m%d-%H%M%S')
|
||||
zip_path = work_dir / f'archive-{date}.zip'
|
||||
zip_file = ZipFile(zip_path, 'w')
|
||||
summary_path = work_dir / 'benchmark_summary.md'
|
||||
|
||||
summary_path = work_dir / 'benchmark_summary.csv'
|
||||
file = open(summary_path, 'w')
|
||||
headers = [
|
||||
'Model', 'Top-1 Expected(%)', 'Top-1 (%)', 'Top-1 best(%)',
|
||||
'best epoch', 'Top-5 Expected (%)', 'Top-5 (%)', 'Config', 'Log'
|
||||
]
|
||||
file.write('# Train Benchmark Regression Summary\n')
|
||||
file.write('| ' + ' | '.join(headers) + ' |\n')
|
||||
file.write('|:' + ':|:'.join(['---'] * len(headers)) + ':|\n')
|
||||
columns = defaultdict(list)
|
||||
for model_name, summary in summary_data.items():
|
||||
if len(summary) == 0:
|
||||
# Skip models without results
|
||||
continue
|
||||
row = [model_name]
|
||||
if 'Top 1 Accuracy' in summary:
|
||||
metric = summary['Top 1 Accuracy']
|
||||
row.append(f"{metric['expect']:.2f}")
|
||||
row.append(f"{metric['last']:.2f}")
|
||||
row.append(f"{metric['best']:.2f}")
|
||||
row.append(f"{metric['best_epoch']:.2f}")
|
||||
else:
|
||||
row.extend([''] * 4)
|
||||
if 'Top 5 Accuracy' in summary:
|
||||
metric = summary['Top 5 Accuracy']
|
||||
row.append(f"{metric['expect']:.2f}")
|
||||
row.append(f"{metric['last']:.2f}")
|
||||
else:
|
||||
row.extend([''] * 2)
|
||||
columns['Name'].append(model_name)
|
||||
|
||||
model_info = models_map[model_name]
|
||||
row.append(model_info.config)
|
||||
row.append(str(summary['log_file'].relative_to(work_dir)))
|
||||
for metric_key in METRICS_MAP:
|
||||
if metric_key in summary:
|
||||
metric = summary[metric_key]
|
||||
expect = str(round(metric['expect'], 2))
|
||||
result = str(round(metric['result'], 2))
|
||||
columns[f'{metric_key} (expect)'].append(expect)
|
||||
columns[f'{metric_key}'].append(result)
|
||||
best = str(round(metric['best'], 2))
|
||||
best_epoch = str(int(metric['best_epoch']))
|
||||
columns[f'{metric_key} (best)'].append(best)
|
||||
columns[f'{metric_key} (best epoch)'].append(best_epoch)
|
||||
else:
|
||||
columns[f'{metric_key} (expect)'].append('')
|
||||
columns[f'{metric_key}'].append('')
|
||||
columns[f'{metric_key} (best)'].append('')
|
||||
columns[f'{metric_key} (best epoch)'].append('')
|
||||
|
||||
columns['Log'].append(str(summary['log_file'].relative_to(work_dir)))
|
||||
zip_file.write(summary['log_file'])
|
||||
file.write('| ' + ' | '.join(row) + ' |\n')
|
||||
|
||||
columns = {
|
||||
field: column
|
||||
for field, column in columns.items() if ''.join(column)
|
||||
}
|
||||
file.write(','.join(columns.keys()) + '\n')
|
||||
for row in zip(*columns.values()):
|
||||
file.write(','.join(row) + '\n')
|
||||
file.close()
|
||||
zip_file.write(summary_path)
|
||||
zip_file.close()
|
||||
print('Summary file saved at ' + str(summary_path))
|
||||
print('Log files archived at ' + str(zip_path))
|
||||
logger.info('Summary file saved at ' + str(summary_path))
|
||||
logger.info('Log files archived at ' + str(zip_path))
|
||||
|
||||
|
||||
def show_summary(summary_data):
|
||||
table = Table(title='Train Benchmark Regression Summary')
|
||||
table.add_column('Model')
|
||||
table.add_column('Name')
|
||||
for metric in METRICS_MAP:
|
||||
table.add_column(f'{metric} (expect)')
|
||||
table.add_column(f'{metric}')
|
||||
table.add_column(f'{metric} (best)')
|
||||
table.add_column('Date')
|
||||
|
||||
def set_color(value, expect):
|
||||
if value > expect:
|
||||
return 'green'
|
||||
elif value > expect - 0.2:
|
||||
elif value >= expect - 0.2:
|
||||
return 'white'
|
||||
else:
|
||||
return 'red'
|
||||
|
@ -277,25 +306,30 @@ def show_summary(summary_data):
|
|||
for metric_key in METRICS_MAP:
|
||||
if metric_key in summary:
|
||||
metric = summary[metric_key]
|
||||
expect = metric['expect']
|
||||
last = metric['last']
|
||||
expect = round(metric['expect'], 2)
|
||||
last = round(metric['last'], 2)
|
||||
last_epoch = metric['last_epoch']
|
||||
last_color = set_color(last, expect)
|
||||
best = metric['best']
|
||||
best_color = set_color(best, expect)
|
||||
best_epoch = metric['best_epoch']
|
||||
best_epoch = round(metric['best_epoch'], 2)
|
||||
row.append(f'{expect:.2f}')
|
||||
row.append(
|
||||
f'[{last_color}]{last:.2f}[/{last_color}] ({last_epoch})')
|
||||
row.append(
|
||||
f'[{best_color}]{best:.2f}[/{best_color}] ({best_epoch})')
|
||||
else:
|
||||
row.extend([''] * 3)
|
||||
table.add_row(*row)
|
||||
|
||||
# Remove empty columns
|
||||
table.columns = [
|
||||
column for column in table.columns if ''.join(column._cells)
|
||||
]
|
||||
console.print(table)
|
||||
|
||||
|
||||
def summary(models, args):
|
||||
|
||||
work_dir = Path(args.work_dir)
|
||||
dir_map = {p.name: p for p in work_dir.iterdir() if p.is_dir()}
|
||||
|
||||
|
@ -306,9 +340,17 @@ def summary(models, args):
|
|||
|
||||
if model_name not in dir_map:
|
||||
continue
|
||||
elif hasattr(model_info, 'downstream'):
|
||||
downstream_name = model_info.downstream.name
|
||||
if downstream_name not in dir_map:
|
||||
continue
|
||||
else:
|
||||
sub_dir = dir_map[downstream_name]
|
||||
model_info = model_info.downstream
|
||||
else:
|
||||
# Skip if not found any vis_data folder.
|
||||
sub_dir = dir_map[model_name]
|
||||
|
||||
# Skip if not found any vis_data folder.
|
||||
sub_dir = dir_map[model_name]
|
||||
log_files = [f for f in sub_dir.glob('*/vis_data/scalars.json')]
|
||||
if len(log_files) == 0:
|
||||
continue
|
||||
|
@ -317,11 +359,8 @@ def summary(models, args):
|
|||
# parse train log
|
||||
with open(log_file) as f:
|
||||
json_logs = [json.loads(s) for s in f.readlines()]
|
||||
val_logs = [
|
||||
log for log in json_logs
|
||||
# TODO: need a better method to extract validate log
|
||||
if 'loss' not in log and 'accuracy/top1' in log
|
||||
]
|
||||
# TODO: need a better method to extract validate log
|
||||
val_logs = [log for log in json_logs if 'loss' not in log]
|
||||
|
||||
if len(val_logs) == 0:
|
||||
continue
|
||||
|
@ -351,12 +390,13 @@ def summary(models, args):
|
|||
|
||||
show_summary(summary_data)
|
||||
if args.save:
|
||||
save_summary(summary_data, models, work_dir)
|
||||
save_summary(summary_data, work_dir)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# parse model-index.yml
|
||||
model_index_file = MMCLS_ROOT / 'model-index.yml'
|
||||
model_index = load(str(model_index_file))
|
||||
model_index.build_models_with_collections()
|
||||
|
@ -364,25 +404,28 @@ def main():
|
|||
|
||||
with open(Path(__file__).parent / 'bench_train.yml', 'r') as f:
|
||||
train_items = yaml.safe_load(f)
|
||||
models = OrderedDict()
|
||||
models = {}
|
||||
for item in train_items:
|
||||
name = item['Name']
|
||||
model_info = all_models[name]
|
||||
model_info.cycle = item.get('Cycle', None)
|
||||
cycle = getattr(model_info, 'cycle', 'month')
|
||||
cycle = item['Cycle']
|
||||
cycle_level = CYCLE_LEVELS.index(cycle)
|
||||
if cycle_level in args.range:
|
||||
model_info = all_models[name]
|
||||
if 'Downstream' in item:
|
||||
downstream = item['Downstream']
|
||||
setattr(model_info, 'downstream', all_models[downstream])
|
||||
models[name] = model_info
|
||||
|
||||
if args.models:
|
||||
patterns = [re.compile(pattern) for pattern in args.models]
|
||||
filter_models = {}
|
||||
for k, v in models.items():
|
||||
if any([re.match(pattern, k) for pattern in patterns]):
|
||||
filter_models[k] = v
|
||||
for pattern in args.models:
|
||||
filter_models.update({
|
||||
name: models[name]
|
||||
for name in fnmatch.filter(models, pattern + '*')
|
||||
})
|
||||
if len(filter_models) == 0:
|
||||
print('No model found, please specify models in:')
|
||||
print('\n'.join(models.keys()))
|
||||
logger.error('No model found, please specify models in:\n' +
|
||||
'\n'.join(models.keys()))
|
||||
return
|
||||
models = filter_models
|
||||
|
||||
|
|
|
@ -18,9 +18,9 @@ from modelindex.load_model_index import load
|
|||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from mmcls.datasets.builder import build_dataloader
|
||||
from mmcls.datasets.pipelines import Compose
|
||||
from mmcls.models.builder import build_classifier
|
||||
from mmpretrain.datasets.builder import build_dataloader
|
||||
from mmpretrain.datasets.pipelines import Compose
|
||||
from mmpretrain.models.builder import build_classifier
|
||||
|
||||
console = Console()
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
- Name: convnext-base_32xb128_in1k
|
||||
- Name: convnext-v2-atto_fcmae-pre_3rdparty_in1k
|
||||
- Name: mobilenet-v2_8xb32_in1k
|
||||
- Name: mobilenet-v3-small-050_3rdparty_in1k
|
||||
- Name: swin-tiny_16xb64_in1k
|
||||
- Name: swinv2-tiny-w8_3rdparty_in1k-256px
|
||||
- Name: vit-base-p16_32xb128-mae_in1k
|
||||
- Name: resnet34_8xb32_in1k
|
||||
- Name: resnext50-32x4d_8xb32_in1k
|
||||
- Name: shufflenet-v2-1x_16xb64_in1k
|
||||
- Name: riformer-s12_in1k
|
||||
- Name: blip-base_3rdparty_retrieval
|
||||
- Name: blip2-opt2.7b_3rdparty-zeroshot_caption
|
||||
- Name: ofa-base_3rdparty-finetuned_caption
|
|
@ -1,18 +1,21 @@
|
|||
- Name: mobilenet-v2_8xb32_in1k
|
||||
Cycle: month
|
||||
|
||||
- Name: resnet50_8xb32_in1k
|
||||
Cycle: month
|
||||
|
||||
- Name: seresnet50_8xb32_in1k
|
||||
- Name: resnet50_8xb256-rsb-a1-600e_in1k
|
||||
Cycle: month
|
||||
|
||||
- Name: swin-small_16xb64_in1k
|
||||
Cycle: month
|
||||
|
||||
- Name: vit-base-p16_pt-32xb128-mae_in1k
|
||||
- Name: vit-base-p16_32xb128-mae_in1k
|
||||
Cycle: month
|
||||
|
||||
- Name: seresnet50_8xb32_in1k
|
||||
Cycle: quarter
|
||||
|
||||
- Name: resnet50_8xb32_in1k
|
||||
Cycle: quarter
|
||||
|
||||
- Name: resnet50_8xb256-rsb-a1-600e_in1k
|
||||
Cycle: quarter
|
||||
|
||||
|
@ -34,53 +37,85 @@
|
|||
- Name: regnetx-1.6gf_8xb128_in1k
|
||||
Cycle: half-year
|
||||
|
||||
- Name: van-small_8xb128_in1k
|
||||
Cycle: no-training
|
||||
- Name: conformer-small-p32_8xb128_in1k
|
||||
Cycle: half-year
|
||||
|
||||
- Name: res2net50-w14-s8_3rdparty_8xb32_in1k
|
||||
Cycle: no-training
|
||||
- Name: convnext-small_32xb128_in1k
|
||||
Cycle: month
|
||||
|
||||
- Name: repvgg-A2_3rdparty_4xb64-coslr-120e_in1k
|
||||
Cycle: no-training
|
||||
- Name: mobilenet-v3-small_8xb128_in1k
|
||||
Cycle: half-year
|
||||
|
||||
- Name: tnt-small-p16_3rdparty_in1k
|
||||
Cycle: no-training
|
||||
- Name: mobileone-s2_8xb32_in1k
|
||||
Cycle: quarter
|
||||
|
||||
- Name: mlp-mixer-base-p16_3rdparty_64xb64_in1k
|
||||
Cycle: no-training
|
||||
- Name: repvgg-b2g4_8xb32_in1k
|
||||
Cycle: half-year
|
||||
|
||||
- Name: conformer-small-p16_3rdparty_8xb128_in1k
|
||||
Cycle: no-training
|
||||
- Name: barlowtwins_resnet50_8xb256-coslr-300e_in1k
|
||||
Cycle: half-year
|
||||
Downstream: resnet50_barlowtwins-pre_8xb32-linear-coslr-100e_in1k
|
||||
|
||||
- Name: twins-pcpvt-base_3rdparty_8xb128_in1k
|
||||
Cycle: no-training
|
||||
- Name: beit_beit-base-p16_8xb256-amp-coslr-300e_in1k
|
||||
Cycle: quarter
|
||||
Downstream: beit-base-p16_beit-pre_8xb128-coslr-100e_in1k
|
||||
|
||||
- Name: efficientnet-b0_3rdparty_8xb32_in1k
|
||||
Cycle: no-training
|
||||
- Name: beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k
|
||||
Cycle: quarter
|
||||
Downstream: beit-base-p16_beitv2-pre_8xb128-coslr-100e_in1k
|
||||
|
||||
- Name: convnext-small_3rdparty_32xb128_in1k
|
||||
Cycle: no-training
|
||||
- Name: byol_resnet50_16xb256-coslr-200e_in1k
|
||||
Cycle: quarter
|
||||
Downstream: resnet50_byol-pre_8xb512-linear-coslr-90e_in1k
|
||||
|
||||
- Name: hrnet-w18_3rdparty_8xb32_in1k
|
||||
Cycle: no-training
|
||||
- Name: cae_beit-base-p16_8xb256-amp-coslr-300e_in1k
|
||||
Cycle: half-year
|
||||
Downstream: beit-base-p16_cae-pre_8xb128-coslr-100e_in1k
|
||||
|
||||
- Name: repmlp-base_3rdparty_8xb64_in1k
|
||||
Cycle: no-training
|
||||
- Name: densecl_resnet50_8xb32-coslr-200e_in1k
|
||||
Cycle: half-year
|
||||
Downstream: resnet50_densecl-pre_8xb32-linear-steplr-100e_in1k
|
||||
|
||||
- Name: wide-resnet50_3rdparty_8xb32_in1k
|
||||
Cycle: no-training
|
||||
- Name: eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k
|
||||
Cycle: half-year
|
||||
Downstream: vit-base-p16_eva-mae-style-pre_8xb2048-linear-coslr-100e_in1k
|
||||
|
||||
- Name: cspresnet50_3rdparty_8xb32_in1k
|
||||
Cycle: no-training
|
||||
- Name: mae_vit-base-p16_8xb512-amp-coslr-300e_in1k
|
||||
Cycle: month
|
||||
Downstream: vit-base-p16_mae-300e-pre_8xb2048-linear-coslr-90e_in1k
|
||||
|
||||
- Name: convmixer-768-32_10xb64_in1k
|
||||
Cycle: no-training
|
||||
- Name: maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k
|
||||
Cycle: quarter
|
||||
Downstream: vit-base-p16_maskfeat-pre_8xb256-coslr-100e_in1k
|
||||
|
||||
- Name: densenet169_4xb256_in1k
|
||||
Cycle: no-training
|
||||
- Name: milan_vit-base-p16_16xb256-amp-coslr-400e_in1k
|
||||
Cycle: quarter
|
||||
Downstream: vit-base-p16_milan-pre_8xb2048-linear-coslr-100e_in1k
|
||||
|
||||
- Name: poolformer-s24_3rdparty_32xb128_in1k
|
||||
Cycle: no-training
|
||||
- Name: mixmim_mixmim-base_16xb128-coslr-300e_in1k
|
||||
Cycle: half-year
|
||||
Downstream: mixmim-base_mixmim-pre_8xb128-coslr-100e_in1k
|
||||
|
||||
- Name: inception-v3_3rdparty_8xb32_in1k
|
||||
Cycle: no-training
|
||||
- Name: mocov2_resnet50_8xb32-coslr-200e_in1k
|
||||
Cycle: quarter
|
||||
Downstream: resnet50_mocov2-pre_8xb32-linear-steplr-100e_in1k
|
||||
|
||||
- Name: mocov3_vit-small-p16_16xb256-amp-coslr-300e_in1k
|
||||
Cycle: month
|
||||
Downstream: vit-small-p16_mocov3-pre_8xb128-linear-coslr-90e_in1k
|
||||
|
||||
- Name: simclr_resnet50_16xb256-coslr-200e_in1k
|
||||
Cycle: quarter
|
||||
Downstream: resnet50_simclr-200e-pre_8xb512-linear-coslr-90e_in1k
|
||||
|
||||
- Name: simmim_swin-base-w6_8xb256-amp-coslr-100e_in1k-192px
|
||||
Cycle: month
|
||||
Downstream: swin-base-w6_simmim-100e-pre_8xb256-coslr-100e_in1k-192px
|
||||
|
||||
- Name: simsiam_resnet50_8xb32-coslr-100e_in1k
|
||||
Cycle: quarter
|
||||
Downstream: resnet50_simsiam-100e-pre_8xb512-linear-coslr-90e_in1k
|
||||
|
||||
- Name: swav_resnet50_8xb32-mcrop-coslr-200e_in1k-224px-96px
|
||||
Cycle: half-year
|
||||
Downstream: resnet50_swav-pre_8xb32-linear-coslr-100e_in1k
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
from pathlib import Path
|
||||
|
||||
HTTP_PREFIX = 'https://download.openmmlab.com/'
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
METRICS_MAP = {
|
||||
'Top 1 Accuracy': 'accuracy/top1',
|
||||
'Top 5 Accuracy': 'accuracy/top5',
|
||||
'Recall@1': 'retrieval/Recall@1',
|
||||
'Recall@5': 'retrieval/Recall@5',
|
||||
'BLEU-4': 'Bleu_4',
|
||||
'CIDER': 'CIDEr',
|
||||
}
|
||||
|
||||
|
||||
def substitute_weights(download_link, root):
|
||||
if 's3://' in root:
|
||||
from mmengine.fileio.backends import PetrelBackend
|
||||
from petrel_client.common.exception import AccessDeniedError
|
||||
file_backend = PetrelBackend()
|
||||
checkpoint = file_backend.join_path(root,
|
||||
download_link[len(HTTP_PREFIX):])
|
||||
try:
|
||||
exists = file_backend.exists(checkpoint)
|
||||
except AccessDeniedError:
|
||||
exists = False
|
||||
else:
|
||||
checkpoint = Path(root) / download_link[len(HTTP_PREFIX):]
|
||||
exists = checkpoint.exists()
|
||||
|
||||
if exists:
|
||||
return str(checkpoint)
|
||||
else:
|
||||
return None
|
|
@ -0,0 +1,207 @@
|
|||
import argparse
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from modelindex.load_model_index import load
|
||||
from modelindex.models.Collection import Collection
|
||||
from modelindex.models.Model import Model
|
||||
from modelindex.models.ModelIndex import ModelIndex
|
||||
|
||||
|
||||
class ContextFilter(logging.Filter):
|
||||
metafile = None
|
||||
name = None
|
||||
failed = False
|
||||
|
||||
def filter(self, record: logging.LogRecord):
|
||||
record.color = {
|
||||
logging.WARNING: '\x1b[33;20m',
|
||||
logging.ERROR: '\x1b[31;1m',
|
||||
}.get(record.levelno, '')
|
||||
self.failed = self.failed or (record.levelno >= logging.ERROR)
|
||||
record.metafile = self.metafile or ''
|
||||
record.name = ('' if self.name is None else '\x1b[32m' + self.name +
|
||||
'\x1b[0m: ')
|
||||
return True
|
||||
|
||||
|
||||
context = ContextFilter()
|
||||
logging.basicConfig(
|
||||
format='[%(metafile)s] %(color)s%(levelname)s\x1b[0m - %(name)s%(message)s'
|
||||
)
|
||||
logger = logging.getLogger()
|
||||
logger.addFilter(context)
|
||||
|
||||
prog_description = """\
|
||||
Check the format of metafile.
|
||||
"""
|
||||
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[1]
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line arguments of the metafile checker."""
    ap = argparse.ArgumentParser(description=prog_description)
    ap.add_argument(
        'metafile', type=Path, nargs='+', help='The path of the matafile.')
    ap.add_argument(
        '--Wall', '-w', action='store_true',
        help='Whether to enable all warnings.')
    ap.add_argument('--skip', action='append', help='Rules to skip check.')
    args = ap.parse_args()
    # `append` leaves the attribute as None when never given.
    if args.skip is None:
        args.skip = []
    return args
|
||||
|
||||
|
||||
def check_collection(modelindex: ModelIndex, skip=None):
    """Validate the (single) collection entry of a metafile.

    Problems are reported through the module logger. Returns a list of
    fatal messages when no collection exists at all, otherwise ``None``.

    Args:
        modelindex (ModelIndex): The loaded metafile content.
        skip (list, optional): Rule names to skip (kept for interface
            symmetry with ``check_model``).
    """
    skip = skip or []  # avoid the mutable-default-argument pitfall

    if len(modelindex.collections) == 0:
        return ['No collection field.']
    elif len(modelindex.collections) > 1:
        logger.error('One metafile should have only one collection.')

    collection: Collection = modelindex.collections[0]

    if collection.name is None:
        logger.error('The collection should have `Name` field.')
    if collection.readme is None:
        logger.error('The collection should have `README` field.')
    elif not (MMCLS_ROOT / collection.readme).exists():
        # Guarded by `elif`: joining a None readme onto a Path raises
        # TypeError, which crashed the checker instead of reporting.
        logger.error(f'The README {collection.readme} is not found.')
    if not isinstance(collection.paper, dict):
        logger.error('The collection should have `Paper` field with '
                     '`Title` and `URL`.')
    elif 'Title' not in collection.paper:
        # URL is not necessary.
        logger.error("The collection's paper should have `Paper` field.")
|
||||
|
||||
|
||||
def check_model_name(name):
    """Warn when a model name deviates from the underscore convention.

    Expected layout is ``[algo_]model[_pre]_train_data`` (3 to 5 fields),
    where the pre-training field ends with ``-pre`` and the training field
    is ``3rdparty`` or ``{num_device}xb{batch_per_device}``.
    """
    parts = name.split('_')
    n = len(parts)

    if n > 5:
        logger.warning('Too many fields.')
        return
    if n < 3:
        logger.warning('Too few fields.')
        return

    # Pick out the pre-training and training fields per field count.
    if n == 5:
        pre, train = parts[2], parts[3]
    elif n == 3:
        pre, train = None, parts[1]
    elif n == 4 and parts[1].endswith('-pre'):
        pre, train = parts[1], parts[2]
    else:
        pre, train = None, parts[2]

    if pre is not None and not pre.endswith('-pre'):
        logger.warning(f'The position of `{pre}` should be '
                       'pre-training information, and ends with `-pre`.')

    if '3rdparty' not in train and re.match(r'\d+xb\d+', train) is None:
        logger.warning(f'The position of `{train}` should be training '
                       'infomation, and starts with `3rdparty` or '
                       '`{num_device}xb{batch_per_device}`')
|
||||
|
||||
|
||||
def check_model(model: Model, skip=None):
    """Validate one model entry of a metafile.

    Logs errors/warnings via the module logger; ``context.name`` is set so
    the records carry the model name.

    Args:
        model (Model): The model entry to check.
        skip (list, optional): Rule names to ignore ('metadata',
            'flops-param', 'result', 'config', 'ckpt-name').
    """
    skip = skip or []  # avoid the mutable-default-argument pitfall

    context.name = None
    if model.name is None:
        logger.error("A model doesn't have `Name` field.")
        return
    context.name = model.name
    check_model_name(model.name)

    if model.name.endswith('.py'):
        logger.error("Don't add `.py` suffix in model name.")

    if model.metadata is None:
        if 'metadata' not in skip:
            logger.error('No `Metadata` field.')
    elif (model.metadata.parameters is None
          or model.metadata.flops is None) and 'flops-param' not in skip:
        # Only reached when metadata exists; the previous flat `if` chain
        # dereferenced a missing `Metadata` and crashed with AttributeError.
        logger.error('Metadata should have `Parameters` and `FLOPs` fields. '
                     'You can use `tools/analysis_tools/get_flops.py` '
                     'to calculate them.')

    if model.results is not None and 'result' not in skip:
        result = model.results[0]
        if not isinstance(result.dataset, str):
            logger.error('Dataset field of Results should be a string. '
                         'If you want to specify the training dataset, '
                         'please use `Metadata.Training Data` field.')

    if 'config' not in skip:
        if model.config is None:
            logger.error('No `Config` field.')
        elif not (MMCLS_ROOT / model.config).exists():
            logger.error(f'The config {model.config} is not found.')

    if model.in_collection is None:
        logger.error('No `In Collection` field.')

    if (model.data.get('Converted From') is not None
            and '3rdparty' not in model.name):
        logger.warning("The model name should include '3rdparty' "
                       "since it's converted from other repository.")

    if (model.weights is not None and model.weights.endswith('.pth')
            and 'ckpt-name' not in skip):
        basename = model.weights.rsplit('/', 1)[-1]
        if not basename.startswith(model.name):
            logger.warning(f'The checkpoint name {basename} is not the '
                           'same as the model name.')

    context.name = None
|
||||
|
||||
|
||||
def main(metafile: Path, args):
    """Run all checks on a single metafile path."""
    if metafile.name != 'metafile.yml':
        # Avoid checking other yaml file.
        return
    if metafile.samefile(MMCLS_ROOT / 'model-index.yml'):
        return

    context.metafile = metafile

    # The metafile must be registered in the global model index.
    with open(MMCLS_ROOT / 'model-index.yml', 'r') as f:
        imported = yaml.load(f, yaml.Loader)['Import']
    registered = any(
        metafile.samefile(MMCLS_ROOT / file) for file in imported)
    if not registered:
        logger.error(
            'The metafile is not imported in the `model-index.yml`.')

    modelindex = load(str(metafile))
    modelindex.build_models_with_collections()
    check_collection(modelindex, args.skip)

    known_names = {model.name for model in modelindex.models}

    for model in modelindex.models:
        check_model(model, args.skip)

        # Every `Downstream` entry must name another model of this metafile.
        for downstream in model.data.get('Downstream', []):
            if downstream not in known_names:
                context.name = model.name
                logger.error(
                    f"The downstream model {downstream} doesn't exist.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = parse_args()
    # --Wall surfaces warnings too; otherwise only errors are reported.
    if args.Wall:
        logger.setLevel(logging.WARNING)
    else:
        logger.setLevel(logging.ERROR)
    for metafile in args.metafile:
        main(metafile, args)
    # Exit 1 when any ERROR record was logged (tracked by the filter).
    sys.exit(int(context.failed))
|
|
@ -0,0 +1,186 @@
|
|||
import argparse
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from rich.console import Console
|
||||
|
||||
# Shared rich console used for all tree rendering in this script.
console = Console()

prog_description = """\
Draw the state dict tree.
"""
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for drawing a state-dict tree."""
    ap = argparse.ArgumentParser(description=prog_description)
    ap.add_argument(
        'path', type=Path,
        help='The path of the checkpoint or model config to draw.')
    ap.add_argument('--depth', type=int, help='The max depth to draw.')
    ap.add_argument(
        '--full-name', action='store_true',
        help='Whether to print the full name of the key.')
    ap.add_argument(
        '--shape', action='store_true',
        help='Whether to print the shape of the parameter.')
    ap.add_argument(
        '--state-key', type=str,
        help='The key of the state dict in the checkpoint.')
    ap.add_argument(
        '--number', action='store_true',
        help='Mark all parameters and their index number.')
    ap.add_argument(
        '--node', type=str,
        help='Show the sub-tree of a node, like "backbone.layers".')
    return ap.parse_args()
|
||||
|
||||
|
||||
def ckpt_to_state_dict(checkpoint, key=None):
    """Extract the state dict from a loaded checkpoint.

    Args:
        checkpoint (dict): The loaded checkpoint object.
        key (str, optional): Explicit key holding the state dict. When
            ``None``, common layouts ('state_dict', 'model', or a flat
            name-to-tensor mapping) are tried in order.

    Returns:
        dict: The extracted state dict.

    Raises:
        KeyError: If the state dict cannot be located automatically.
    """
    if key is not None:
        state_dict = checkpoint[key]
    elif 'state_dict' in checkpoint:
        # try mmpretrain style
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif checkpoint and isinstance(
            next(iter(checkpoint.values())), torch.Tensor):
        # try native style
        state_dict = checkpoint
    else:
        # Also reached for an empty checkpoint, which previously crashed
        # with StopIteration on `next(iter(...))` instead of a clear error.
        raise KeyError('Please specify the key of state '
                       f'dict from {list(checkpoint.keys())}.')
    return state_dict
|
||||
|
||||
|
||||
class StateDictTree:
    """A prefix tree over dotted state-dict keys.

    Leaf nodes hold a parameter ``value``; inner nodes have
    ``value is None`` and group entries by key segment.
    """

    def __init__(self, key='', value=None):
        self.children = {}   # segment -> StateDictTree
        self.key: str = key  # this node's key segment
        self.value = value   # parameter payload; None for inner nodes

    def add_parameter(self, key, value):
        """Insert ``key`` (dotted path) with ``value``, creating nodes."""
        # Split off the first dotted segment and recurse into the tree.
        keys = key.split('.', 1)
        if len(keys) == 1:
            self.children[key] = StateDictTree(key, value)
        elif keys[0] in self.children:
            self.children[keys[0]].add_parameter(keys[1], value)
        else:
            node = StateDictTree(keys[0])
            node.add_parameter(keys[1], value)
            self.children[keys[0]] = node

    def __getitem__(self, key: str):
        return self.children[key]

    def __repr__(self) -> str:
        # Render through rich so the repr matches `draw_tree` output.
        with console.capture() as capture:
            for line in self.iter_tree():
                console.print(line)
        return capture.get()

    def __len__(self):
        return len(self.children)

    def draw_tree(self,
                  max_depth=None,
                  full_name=False,
                  with_shape=False,
                  with_value=False):
        """Print the tree to the console (see ``iter_tree`` for options)."""
        for line in self.iter_tree(
                max_depth=max_depth,
                full_name=full_name,
                with_shape=with_shape,
                with_value=with_value,
        ):
            console.print(line, highlight=False)

    def iter_tree(
        self,
        lead='',
        prefix='',
        max_depth=None,
        full_name=False,
        with_shape=False,
        with_value=False,
    ):
        """Yield rich-markup lines of the tree, depth first.

        Inner nodes render blue and parameters green; ``lead`` carries the
        box-drawing guides inherited from ancestor levels and ``prefix``
        the accumulated dotted name when ``full_name`` is set.
        """
        if self.value is None:
            key_str = f'[blue]{self.key}[/]'
        elif with_shape:
            # assumes leaf values expose `.shape` (tensors) — TODO confirm
            key_str = f'[green]{self.key}[/] {tuple(self.value.shape)}'
        elif with_value:
            key_str = f'[green]{self.key}[/] {self.value}'
        else:
            key_str = f'[green]{self.key}[/]'

        yield lead + prefix + key_str

        # Children extend a vertical guide under '├─' and blanks under '└─'.
        lead = lead.replace('├─', '│ ')
        lead = lead.replace('└─', '  ')
        if self.key and full_name:
            prefix = f'{prefix}{self.key}.'

        # max_depth counts down per level; 0 stops descending.
        if max_depth == 0:
            return
        elif max_depth is not None:
            max_depth -= 1

        for i, child in enumerate(self.children.values()):
            level_lead = '├─' if i < len(self.children) - 1 else '└─'
            yield from child.iter_tree(
                lead=f'{lead}{level_lead} ',
                prefix=prefix,
                max_depth=max_depth,
                full_name=full_name,
                with_shape=with_shape,
                with_value=with_value)
|
||||
|
||||
|
||||
def main():
    """Load a checkpoint or config and print its state-dict tree."""
    args = parse_args()
    if args.path.suffix in ['.json', '.py', '.yml']:
        # Config file: build the model and read its state dict.
        from mmengine.runner import get_state_dict

        from mmpretrain.apis import init_model
        model = init_model(args.path, device='cpu')
        state_dict = get_state_dict(model)
    else:
        # Checkpoint file.
        # NOTE(review): torch.load without weights_only unpickles arbitrary
        # objects — only use on trusted checkpoints.
        ckpt = torch.load(args.path, map_location='cpu')
        state_dict = ckpt_to_state_dict(ckpt, args.state_key)

    root = StateDictTree()
    for k, v in state_dict.items():
        root.add_parameter(k, v)

    # Width of the numeric index column when --number is given.
    para_index = 0
    mark_width = math.floor(math.log(len(state_dict), 10) + 1)
    if args.node is not None:
        # Descend to the requested sub-tree, e.g. "backbone.layers".
        for key in args.node.split('.'):
            root = root[key]

    for line in root.iter_tree(
            max_depth=args.depth,
            full_name=args.full_name,
            with_shape=args.shape,
    ):
        if not args.number:
            mark = ''
        # A hack method to determine whether a line is parameter.
        elif '[green]' in line:
            mark = f'[red]({str(para_index).ljust(mark_width)})[/]'
            para_index += 1
        else:
            mark = ' ' * (mark_width + 2)
        console.print(mark + line, highlight=False)
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == '__main__':
    main()
|
|
@ -0,0 +1,121 @@
|
|||
#!/usr/bin/env python
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from ckpt_tree import StateDictTree, ckpt_to_state_dict
|
||||
from rich.progress import track
|
||||
from scipy import stats
|
||||
|
||||
prog_description = """\
|
||||
Compare the initialization distribution between state dicts by Kolmogorov-Smirnov test.
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for the distribution comparison."""
    ap = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=prog_description)
    ap.add_argument(
        'model_a', type=Path,
        help='The path of the first checkpoint or model config.')
    ap.add_argument(
        'model_b', type=Path,
        help='The path of the second checkpoint or model config.')
    ap.add_argument(
        '--show', action='store_true',
        help='Whether to draw the KDE of variables')
    ap.add_argument(
        '-p', default=0.01, type=float,
        help='The threshold of p-value. '
        'Higher threshold means more strict test.')
    return ap.parse_args()
|
||||
|
||||
|
||||
def compare_distribution(state_dict_a, state_dict_b, p_thres):
    """Yield parameters whose distributions differ between two state dicts.

    For each shared key, a two-sample Kolmogorov-Smirnov test is run and
    ``(key, pvalue, flat_a, flat_b)`` is yielded when ``pvalue < p_thres``.
    """
    assert len(state_dict_a) == len(state_dict_b)
    for key, tensor_a in state_dict_a.items():
        assert key in state_dict_b
        flat_a = tensor_a.cpu().flatten()
        flat_b = state_dict_b[key].cpu().flatten()
        pvalue = stats.kstest(flat_a, flat_b).pvalue
        if pvalue < p_thres:
            yield key, pvalue, flat_a, flat_b
|
||||
|
||||
|
||||
def state_dict_from_cfg_or_ckpt(path, state_key=None):
    """Load a state dict from either a model config or a checkpoint file."""
    if path.suffix not in ['.json', '.py', '.yml']:
        # Plain checkpoint file.
        ckpt = torch.load(path, map_location='cpu')
        return ckpt_to_state_dict(ckpt, state_key)

    # Config file: build the model and initialize its weights.
    from mmengine.runner import get_state_dict

    from mmpretrain.apis import init_model
    model = init_model(path, device='cpu')
    model.init_weights()
    return get_state_dict(model)
|
||||
|
||||
|
||||
def main():
    """Compare two models' weight distributions key by key."""
    args = parse_args()

    state_dict_a = state_dict_from_cfg_or_ckpt(args.model_a)
    state_dict_b = state_dict_from_cfg_or_ckpt(args.model_b)
    compare_keys = state_dict_a.keys() & state_dict_b.keys()
    if len(compare_keys) == 0:
        raise ValueError("The state dicts don't match, please convert "
                         'to the same keys before comparison.')

    root = StateDictTree()
    for key in track(compare_keys):
        if state_dict_a[key].shape != state_dict_b[key].shape:
            raise ValueError(f'The shapes of "{key}" are different. '
                             'Please check models in the same architecture.')

        # Sample at most 30000 items to prevent long-time calcuation.
        perm_ids = torch.randperm(state_dict_a[key].numel())[:30000]
        value_a = state_dict_a[key].flatten()[perm_ids]
        value_b = state_dict_b[key].flatten()[perm_ids]
        pvalue = stats.kstest(value_a, value_b).pvalue
        if pvalue < args.p:
            # Failed the KS test: record the rounded p-value in the tree.
            root.add_parameter(key, round(pvalue, 4))
            if args.show:
                try:
                    import seaborn as sns
                except ImportError:
                    raise ImportError('Please install `seaborn` by '
                                      '`pip install seaborn` to show KDE.')
                sample_a = str([round(v.item(), 2) for v in value_a[:10]])
                sample_b = str([round(v.item(), 2) for v in value_b[:10]])
                # KDE needs variance; fall back to a point for constants.
                if value_a.std() > 0:
                    sns.kdeplot(value_a, fill=True)
                else:
                    sns.scatterplot(x=[value_a[0].item()], y=[1])
                if value_b.std() > 0:
                    sns.kdeplot(value_b, fill=True)
                else:
                    sns.scatterplot(x=[value_b[0].item()], y=[1])
                plt.legend([
                    f'{args.model_a.stem}: {sample_a}',
                    f'{args.model_b.stem}: {sample_b}'
                ])
                plt.title(key)
                plt.show()
    if len(root) > 0:
        root.draw_tree(with_value=True)
        print("Above parameters didn't pass the test, "
              'and the values are their similarity score.')
    else:
        print('The distributions of all weights are the same.')
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == '__main__':
    main()
|
|
@ -0,0 +1,501 @@
|
|||
import argparse
|
||||
import copy
|
||||
import re
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from prompt_toolkit import ANSI
|
||||
from prompt_toolkit import prompt as _prompt
|
||||
from prompt_toolkit.completion import (FuzzyCompleter, FuzzyWordCompleter,
|
||||
PathCompleter)
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm, Prompt
|
||||
from rich.syntax import Syntax
|
||||
|
||||
# Help text shown by argparse (--help).
prog_description = """\
To display metafile or fill missing fields of the metafile.
"""

# Repository root (this script lives one directory below it).
MMCLS_ROOT = Path(__file__).absolute().parents[1].resolve().absolute()
console = Console()
# Fuzzy completion of common dataset names for interactive prompts.
dataset_completer = FuzzyWordCompleter([
    'ImageNet-1k', 'ImageNet-21k', 'CIFAR-10', 'CIFAR-100', 'RefCOCO', 'VQAv2',
    'COCO', 'OpenImages', 'Object365', 'CC3M', 'CC12M', 'YFCC100M', 'VG'
])
|
||||
|
||||
|
||||
def prompt(message,
           allow_empty=True,
           default=None,
           multiple=False,
           completer=None):
    """Ask the user for input, rendering rich markup in the message.

    Args:
        message (str): Prompt text with rich markup (rendered to ANSI).
        allow_empty (bool): If False, re-ask until a non-empty answer is
            given (ignored when ``multiple`` is True).
        default: Pre-filled answer. Also acts as a marker: when it is
            ``None``, an empty answer is returned as ``None``.
        multiple (bool): Collect answers until an empty one is entered;
            returns them as a (non-stripped) list.
        completer: Optional prompt_toolkit completer.

    Returns:
        list | str | None: See ``multiple`` and ``default`` above.
    """
    with console.capture() as capture:
        console.print(message, end='')

    message = ANSI(capture.get())
    # NOTE(review): falsy non-None defaults (e.g. 0) collapse to '' here.
    ask = partial(
        _prompt, message=message, default=default or '', completer=completer)

    out = ask()

    if multiple:
        outs = []
        while out != '':
            outs.append(out)
            out = ask()
        return outs

    if not allow_empty and out == '':
        while out == '':
            out = ask()

    if default is None and out == '':
        return None
    else:
        return out.strip()
|
||||
|
||||
|
||||
class MyDumper(yaml.Dumper):
    """YAML dumper that indents list items under their parent key."""

    def increase_indent(self, flow=False, indentless=False):
        # Force indented (non-flush) block sequences for readability.
        return super(MyDumper, self).increase_indent(flow, False)


# Shared dump settings: keep insertion order, use block style.
yaml_dump = partial(
    yaml.dump, Dumper=MyDumper, default_flow_style=False, sort_keys=False)
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for the metafile filler."""
    ap = argparse.ArgumentParser(description=prog_description)
    ap.add_argument('--src', type=Path, help='The path of the matafile.')
    ap.add_argument('--out', '-o', type=Path, help='The output path.')
    ap.add_argument(
        '--inplace', '-i', action='store_true',
        help='Modify the source metafile inplace.')
    ap.add_argument(
        '--view', action='store_true', help='Only pretty print the metafile.')
    ap.add_argument('--csv', type=str, help='Use a csv to update models.')
    args = ap.parse_args()
    # --inplace is shorthand for writing back to the source path.
    if args.inplace:
        args.out = args.src
    return args
|
||||
|
||||
|
||||
def get_flops_params(config_path):
    """Compute (FLOPs, parameter count) for the model built from a config.

    Returns:
        tuple: ``(flops, params)``; ``flops`` is ``None`` when the flop
        analysis fails (parameters are always counted).
    """
    import numpy as np
    import torch
    from mmengine.analysis import FlopAnalyzer, parameter_count
    from mmengine.dataset import Compose
    from mmengine.model.utils import revert_sync_batchnorm
    from mmengine.registry import DefaultScope

    from mmpretrain.apis import get_model
    from mmpretrain.models.utils import no_load_hf_pretrained_model

    # Build the model without fetching HuggingFace pretrained weights.
    with no_load_hf_pretrained_model():
        model = get_model(config_path, device='cpu')
    model = revert_sync_batchnorm(model)
    model.eval()
    params = int(parameter_count(model)[''])

    # get flops
    try:
        if 'test_dataloader' in model._config:
            # build the data pipeline
            test_dataset = model._config.test_dataloader.dataset
            if test_dataset.pipeline[0]['type'] == 'LoadImageFromFile':
                test_dataset.pipeline.pop(0)
            if test_dataset.type in ['CIFAR10', 'CIFAR100']:
                # The image shape of CIFAR is (32, 32, 3)
                test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))

            # Push a random image through the test pipeline to discover
            # the model's actual input resolution.
            with DefaultScope.overwrite_default_scope('mmpretrain'):
                data = Compose(test_dataset.pipeline)({
                    'img':
                    np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
                })
            resolution = tuple(data['inputs'].shape[-2:])
        else:
            # For configs only for get model.
            resolution = (224, 224)

        with torch.no_grad():
            # Skip flops if the model doesn't have `extract_feat` method.
            model.forward = model.extract_feat
            model.to('cpu')
            inputs = (torch.randn((1, 3, *resolution)), )
            analyzer = FlopAnalyzer(model, inputs)
            analyzer.unsupported_ops_warnings(False)
            analyzer.uncalled_modules_warnings(False)
            flops = int(analyzer.total())
    except Exception:
        # Best-effort: flop analysis fails for some architectures.
        print('Unable to calculate flops.')
        flops = None
    return flops, params
|
||||
|
||||
|
||||
def fill_collection(collection: dict):
    """Interactively fill missing fields of the collection dict.

    Prompts only for fields that are absent; returns the collection with
    its keys re-ordered canonically for YAML dumping.
    """
    if collection.get('Name') is None:
        name = prompt(
            'Please input the collection [red]name[/]: ', allow_empty=False)
        collection['Name'] = name

    if collection.get('Metadata', {}).get('Architecture') is None:
        architecture = prompt(
            'Please input the model [red]architecture[/] '
            '(input empty to finish): ',
            multiple=True)
        if len(architecture) > 0:
            collection.setdefault('Metadata', {})
            collection['Metadata']['Architecture'] = architecture

    # Paper is normalized to a {Title, URL} mapping.
    if collection.get('Paper', {}).get('Title') is None:
        title = prompt('Please input the [red]paper title[/]: ')
    else:
        title = collection['Paper']['Title']
    if collection.get('Paper', {}).get('URL') is None:
        url = prompt('Please input the [red]paper url[/]: ')
    else:
        url = collection['Paper']['URL']
    paper = dict(Title=title, URL=url)
    collection['Paper'] = paper

    if collection.get('README') is None:
        # Complete only directories and README.md files.
        readme = prompt(
            'Please input the [red]README[/] file path: ',
            completer=PathCompleter(file_filter=lambda name: Path(name).is_dir(
            ) or 'README.md' in name))
        if readme is not None:
            collection['README'] = str(
                Path(readme).absolute().relative_to(MMCLS_ROOT))
        else:
            collection['README'] = None

    # Canonical key order for the dumped YAML.
    # NOTE(review): a key outside this list makes `order.index` raise.
    order = ['Name', 'Metadata', 'Paper', 'README', 'Code']
    collection = {
        k: collection[k]
        for k in sorted(collection.keys(), key=order.index)
    }
    return collection
|
||||
|
||||
|
||||
def fill_model_by_prompt(model: dict, defaults: dict):
    """Interactively fill missing fields of one model entry.

    Prompts only for absent fields; FLOPs/params are computed from the
    config when possible. ``defaults`` supplies collection-level values
    and is mutated to remember the 'Convert From.Code' answer between
    models. Returns the model with keys re-ordered canonically.
    """
    # Name
    if model.get('Name') is None:
        name = prompt(
            'Please input the model [red]name[/]: ', allow_empty=False)
        model['Name'] = name

    # In Collection
    model['In Collection'] = defaults.get('In Collection')

    # Config
    config = model.get('Config')
    if config is None:
        config = prompt(
            'Please input the [red]config[/] file path: ',
            completer=FuzzyCompleter(PathCompleter()))
        if config is not None:
            # Store configs relative to the repository root.
            config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
    model['Config'] = config

    # Metadata.Flops, Metadata.Parameters
    flops = model.get('Metadata', {}).get('FLOPs')
    params = model.get('Metadata', {}).get('Parameters')
    # Compute from the config only when both are missing and the config
    # actually exists on disk.
    if model.get('Config') is not None and (
            MMCLS_ROOT / model['Config']).exists() and (flops is None
                                                        and params is None):
        print('Automatically compute FLOPs and Parameters from config.')
        flops, params = get_flops_params(str(MMCLS_ROOT / model['Config']))

    if flops is None:
        flops = prompt('Please specify the [red]FLOPs[/]: ')
        if flops is not None:
            flops = int(flops)
    if params is None:
        params = prompt('Please specify the [red]number of parameters[/]: ')
        if params is not None:
            params = int(params)

    model.setdefault('Metadata', {})
    model['Metadata'].setdefault('FLOPs', flops)
    model['Metadata'].setdefault('Parameters', params)

    # Metadata.Training Data (may also be inherited from the collection).
    if 'Training Data' not in model.get('Metadata', {}) and \
            'Training Data' not in defaults.get('Metadata', {}):
        training_data = prompt(
            'Please input all [red]training dataset[/], '
            'include pre-training (input empty to finish): ',
            completer=dataset_completer,
            multiple=True)
        # A single dataset is stored as a scalar, several as a list.
        if len(training_data) > 1:
            model['Metadata']['Training Data'] = training_data
        elif len(training_data) == 1:
            model['Metadata']['Training Data'] = training_data[0]

    # Results
    results = model.get('Results')
    if results is None:
        test_dataset = prompt(
            'Please input the [red]test dataset[/]: ',
            completer=dataset_completer)
        if test_dataset is not None:
            task = Prompt.ask(
                'Please input the [red]test task[/]',
                default='Image Classification')
            if task == 'Image Classification':
                metrics = {}
                top1 = prompt('Please input the [red]top-1 accuracy[/]: ')
                top5 = prompt('Please input the [red]top-5 accuracy[/]: ')
                if top1 is not None:
                    metrics['Top 1 Accuracy'] = round(float(top1), 2)
                if top5 is not None:
                    metrics['Top 5 Accuracy'] = round(float(top5), 2)
            else:
                # Free-form "name=value" metrics for other tasks.
                metrics_list = prompt(
                    'Please input the [red]metrics[/] like "mAP=94.98" '
                    '(input empty to finish): ',
                    multiple=True)
                metrics = {}
                for metric in metrics_list:
                    k, v = metric.split('=')[:2]
                    metrics[k] = round(float(v), 2)
            results = [{
                'Task': task,
                'Dataset': test_dataset,
                'Metrics': metrics or None,
            }]
    model['Results'] = results

    # Weights
    weights = model.get('Weights')
    if weights is None:
        weights = prompt('Please input the [red]checkpoint download link[/]: ')
    model['Weights'] = weights

    # Converted From — remember the repository link in `defaults` so it can
    # be offered as the default answer for the next model.
    if model.get('Converted From') is None and model.get(
            'Weights') is not None:
        if '3rdparty' in model['Name'] or Confirm.ask(
                'Is the checkpoint is converted '
                'from [red]other repository[/]?',
                default=False):
            converted_from = {}
            converted_from['Weights'] = prompt(
                'Please fill the original checkpoint download link: ')
            converted_from['Code'] = Prompt.ask(
                'Please fill the original repository link',
                default=defaults.get('Convert From.Code', None))
            defaults['Convert From.Code'] = converted_from['Code']
            model['Converted From'] = converted_from
    elif model.get('Converted From', {}).get('Code') is not None:
        defaults['Convert From.Code'] = model['Converted From']['Code']

    # Canonical key order for the dumped YAML.
    order = [
        'Name', 'Metadata', 'In Collection', 'Results', 'Weights', 'Config',
        'Converted From', 'Downstream'
    ]
    model = {k: model[k] for k in sorted(model.keys(), key=order.index)}
    return model
|
||||
|
||||
|
||||
def update_model_by_dict(model: dict, update_dict: dict, defaults: dict):
    """Update a model entry from one csv row (column names lower-cased).

    Recognized keys: 'name override', 'config', 'training dataset',
    'test dataset', 'top-1', 'top-5', 'weights', 'converted from.code',
    'converted from.weights'. Returns the model dict with keys re-ordered
    canonically for YAML dumping.
    """
    # Name
    if 'name override' in update_dict:
        model['Name'] = update_dict['name override'].strip()

    # In Collection
    model['In Collection'] = defaults.get('In Collection')

    # Config
    if 'config' in update_dict:
        config = update_dict['config'].strip()
        config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
        config_updated = (config != model.get('Config'))
        model['Config'] = config
    else:
        config_updated = False

    # Metadata.Flops, Metadata.Parameters — recompute when the config
    # changed or both values are missing.
    flops = model.get('Metadata', {}).get('FLOPs')
    params = model.get('Metadata', {}).get('Parameters')
    if config_updated or (flops is None and params is None):
        print(f'Automatically compute FLOPs and Parameters of {model["Name"]}')
        flops, params = get_flops_params(str(MMCLS_ROOT / model['Config']))

    model.setdefault('Metadata', {})
    model['Metadata']['FLOPs'] = flops
    model['Metadata']['Parameters'] = params

    # Metadata.Training Data (whitespace-separated list in the csv cell).
    if 'training dataset' in update_dict:
        train_data = update_dict['training dataset'].strip()
        train_data = re.split(r'\s+', train_data)
        if len(train_data) > 1:
            model['Metadata']['Training Data'] = train_data
        elif len(train_data) == 1:
            model['Metadata']['Training Data'] = train_data[0]

    # Results.Dataset
    if 'test dataset' in update_dict:
        test_data = update_dict['test dataset'].strip()
        results = model.get('Results') or [{}]
        result = results[0]
        result['Dataset'] = test_data
        model['Results'] = results

    # Results.Metrics.Top 1 Accuracy
    result = None
    if 'top-1' in update_dict:
        top1 = update_dict['top-1']
        results = model.get('Results') or [{}]
        result = results[0]
        result.setdefault('Metrics', {})
        result['Metrics']['Top 1 Accuracy'] = round(float(top1), 2)
        task = 'Image Classification'
        model['Results'] = results

    # Results.Metrics.Top 5 Accuracy
    if 'top-5' in update_dict:
        top5 = update_dict['top-5']
        results = model.get('Results') or [{}]
        result = results[0]
        result.setdefault('Metrics', {})
        result['Metrics']['Top 5 Accuracy'] = round(float(top5), 2)
        task = 'Image Classification'
        model['Results'] = results

    if result is not None:
        # BUGFIX: `Task` is a field of the result itself (as built by
        # `fill_model_by_prompt`), not of the `Metrics` mapping.
        result['Task'] = task

    # Weights
    if 'weights' in update_dict:
        weights = update_dict['weights'].strip()
        model['Weights'] = weights

    # Converted From.Code
    if 'converted from.code' in update_dict:
        from_code = update_dict['converted from.code'].strip()
        model.setdefault('Converted From', {})
        model['Converted From']['Code'] = from_code

    # Converted From.Weights
    if 'converted from.weights' in update_dict:
        from_weight = update_dict['converted from.weights'].strip()
        model.setdefault('Converted From', {})
        model['Converted From']['Weights'] = from_weight

    # Canonical key order for the dumped YAML.
    order = [
        'Name', 'Metadata', 'In Collection', 'Results', 'Weights', 'Config',
        'Converted From', 'Downstream'
    ]
    model = {k: model[k] for k in sorted(model.keys(), key=order.index)}
    return model
|
||||
|
||||
|
||||
def format_collection(collection: dict):
    """Render a collection dict as a YAML syntax panel for the console."""
    body = Syntax(yaml_dump(collection), 'yaml', background_color='default')
    return Panel(body, width=150, title='Collection')
|
||||
|
||||
|
||||
def format_model(model: dict):
    """Render a model dict as a YAML syntax panel for the console."""
    body = Syntax(yaml_dump(model), 'yaml', background_color='default')
    return Panel(body, width=150, title='Model')
|
||||
|
||||
|
||||
def order_models(model):
    """Sort key for model entries in the dumped metafile.

    Ranks models with `Downstream` (pre-trained) first, then non-3rdparty
    ones, then smaller, faster, and shorter-named ones.
    """
    metadata = model.get('Metadata', {})
    return (
        int('Downstream' not in model),    # pre-trained models first
        int('3rdparty' in model['Name']),  # non-3rdparty models first
        metadata.get('Parameters', 0),     # smaller model first
        metadata.get('FLOPs', 0),          # faster model first
        len(model['Name']),                # shorter name first
    )
|
||||
|
||||
|
||||
def main():
    """Entry point: view, fill interactively, or batch-update a metafile."""
    args = parse_args()

    # Load the source metafile (or start from an empty one).
    if args.src is not None:
        with open(args.src, 'r') as f:
            content = yaml.load(f, yaml.SafeLoader)
    else:
        content = {}

    if args.view:
        # Display-only mode: pretty-print everything and exit.
        collection = content.get('Collections', [{}])[0]
        console.print(format_collection(collection))
        models = content.get('Models', [])
        for model in models:
            console.print(format_model(model))
        return

    collection = content.get('Collections', [{}])[0]
    ori_collection = copy.deepcopy(collection)

    console.print(format_collection(collection))
    collection = fill_collection(collection)
    # Re-print only when the interactive filling changed something.
    if ori_collection != collection:
        console.print(format_collection(collection))
    # Values inherited by every model in this metafile.
    model_defaults = {
        'In Collection': collection['Name'],
        'Metadata': collection.get('Metadata', {}),
    }

    models = content.get('Models', [])
    updated_models = []

    if args.csv is not None:
        # Batch mode: update models from csv rows keyed by model name.
        import pandas as pd
        df = pd.read_csv(args.csv).rename(columns=lambda x: x.strip().lower())
        assert df['name'].is_unique, 'The csv has duplicated model names.'
        models_dict = {item['Name']: item for item in models}
        for update_dict in df.to_dict('records'):
            assert 'name' in update_dict, 'The csv must have the `Name` field.'
            model_name = update_dict['name'].strip()
            model = models_dict.pop(model_name, {'Name': model_name})
            model = update_model_by_dict(model, update_dict, model_defaults)
            updated_models.append(model)
        # Keep models that were not mentioned in the csv.
        updated_models.extend(models_dict.values())
    else:
        # Interactive mode: walk every model, then offer to add new ones.
        for model in models:
            console.print(format_model(model))
            ori_model = copy.deepcopy(model)
            model = fill_model_by_prompt(model, model_defaults)
            if ori_model != model:
                console.print(format_model(model))
            updated_models.append(model)

        while Confirm.ask('Add new model?', default=False):
            model = fill_model_by_prompt({}, model_defaults)
            updated_models.append(model)

    # Save updated models even error happened.
    updated_models.sort(key=order_models)
    if args.out is not None:
        with open(args.out, 'w') as f:
            yaml_dump({'Collections': [collection]}, f)
            f.write('\n')
            yaml_dump({'Models': updated_models}, f)
    else:
        # No output path: show the would-be file on the console instead.
        modelindex = {'Collections': [collection], 'Models': updated_models}
        yaml_str = yaml_dump(modelindex)
        console.print(Syntax(yaml_str, 'yaml', background_color='default'))
        console.print('Specify [red]`--out`[/] to dump to file.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,453 @@
|
|||
# flake8: noqa
|
||||
import argparse
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
from modelindex.load_model_index import load
|
||||
from modelindex.models.ModelIndex import ModelIndex
|
||||
from tabulate import tabulate
|
||||
|
||||
MMPT_ROOT = Path(__file__).absolute().parents[1]
|
||||
|
||||
prog_description = """\
|
||||
Use metafile to generate a README.md.
|
||||
|
||||
Notice that the tool may fail in some corner cases, and you still need to check and fill some contents manually in the generated README.
|
||||
"""
|
||||
|
||||
PREDICT_TEMPLATE = """\
|
||||
**Predict image**
|
||||
|
||||
```python
|
||||
from mmpretrain import inference_model
|
||||
|
||||
predict = inference_model('{model_name}', 'demo/bird.JPEG')
|
||||
print(predict['pred_class'])
|
||||
print(predict['pred_score'])
|
||||
```
|
||||
"""
|
||||
|
||||
RETRIEVE_TEMPLATE = """\
|
||||
**Retrieve image**
|
||||
|
||||
```python
|
||||
from mmpretrain import ImageRetrievalInferencer
|
||||
|
||||
inferencer = ImageRetrievalInferencer('{model_name}', prototype='demo/')
|
||||
predict = inferencer('demo/dog.jpg', topk=2)[0]
|
||||
print(predict[0])
|
||||
print(predict[1])
|
||||
```
|
||||
"""
|
||||
|
||||
USAGE_TEMPLATE = """\
|
||||
**Use the model**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from mmpretrain import get_model
|
||||
|
||||
model = get_model('{model_name}', pretrained=True)
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
out = model(inputs)
|
||||
print(type(out))
|
||||
# To extract features.
|
||||
feats = model.extract_feat(inputs)
|
||||
print(type(feats))
|
||||
```
|
||||
"""
|
||||
|
||||
TRAIN_TEST_TEMPLATE = """\
|
||||
**Train/Test Command**
|
||||
|
||||
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
|
||||
|
||||
Train:
|
||||
|
||||
```shell
|
||||
python tools/train.py {train_config}
|
||||
```
|
||||
|
||||
Test:
|
||||
|
||||
```shell
|
||||
python tools/test.py {test_config} {test_weights}
|
||||
```
|
||||
"""
|
||||
|
||||
TEST_ONLY_TEMPLATE = """\
|
||||
**Test Command**
|
||||
|
||||
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
|
||||
|
||||
Test:
|
||||
|
||||
```shell
|
||||
python tools/test.py {test_config} {test_weights}
|
||||
```
|
||||
"""
|
||||
|
||||
METRIC_MAPPING = {
|
||||
'Top 1 Accuracy': 'Top-1 (%)',
|
||||
'Top 5 Accuracy': 'Top-5 (%)',
|
||||
}
|
||||
|
||||
DATASET_PRIORITY = {
|
||||
'ImageNet-1k': 0,
|
||||
'CIFAR-10': 10,
|
||||
'CIFAR-100': 20,
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
    """Parse command line arguments for the README generator.

    Returns:
        argparse.Namespace: Parsed arguments with fields ``metafile``,
        ``table``, ``update``, ``out`` and ``update_items``.
    """
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('metafile', type=Path, help='The path of metafile')
    parser.add_argument(
        '--table', action='store_true', help='Only generate summary tables')
    parser.add_argument(
        '--update', type=str, help='Update the specified readme file.')
    parser.add_argument('--out', type=str, help='Output to the file.')
    # The previous help text here was a copy-paste of `--update`'s help and
    # did not describe the option.
    parser.add_argument(
        '--update-items',
        type=str,
        nargs='+',
        default=['models'],
        help='Which sections to regenerate when updating a readme file '
        '(e.g. title, abs, image, usage, models).')
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def filter_models_by_task(models, task):
    """Select models matching ``task``.

    ``task is None`` selects models without results (pre-trained
    checkpoints), ``task == 'any'`` selects every model that has results,
    and any other value selects models whose first result matches that task.
    """
    selected = []
    for model in models:
        if model.results is None:
            # Result-less models only match the `None` task.
            if task is None:
                selected.append(model)
        elif task == 'any' or model.results[0].task == task:
            selected.append(model)
    return selected
|
||||
|
||||
|
||||
def add_title(metafile: ModelIndex):
    """Generate the README title section from the first collection."""
    collection = metafile.collections[0]
    paper = collection.paper
    abbr = collection.name
    # Collections without an explicit type default to algorithm papers.
    papertype = collection.data.get('type', 'Algorithm')

    return (f"# {abbr}\n> [{paper['Title']}]({paper['URL']})\n"
            f'<!-- [{papertype.upper()}] -->\n')
|
||||
|
||||
|
||||
def add_abstract(metafile: ModelIndex):
    """Generate the README abstract section.

    For arXiv papers, tries to fetch the abstract via the ``arxiv`` package;
    otherwise (or if the package is missing) leaves the section body empty
    for manual filling.
    """
    paper = metafile.collections[0].paper
    url = paper['URL']
    if 'arxiv' in url:
        try:
            import arxiv
            # The arXiv id is the last path component of the paper URL.
            search = arxiv.Search(id_list=[url.split('/')[-1]])
            info = next(search.results())
            abstract = info.summary.replace('\n', ' ')
        except ImportError:
            # NOTE(review): network or API errors from the arxiv package are
            # not caught here -- confirm whether they should abort README
            # generation.
            warnings.warn('Install arxiv parser by `pip install arxiv` '
                          'to automatically generate abstract.')
            abstract = None
    else:
        abstract = None

    content = '## Abstract\n'
    if abstract is not None:
        content += f'\n{abstract}\n'
    return content
|
||||
|
||||
|
||||
def add_usage(metafile):
    """Generate the "How to use it?" README section.

    Returns None when the metafile has no models; callers are expected to
    handle the missing section.
    """
    models = metafile.models
    if len(models) == 0:
        return

    content = []
    content.append('## How to use it?\n\n<!-- [TABS-BEGIN] -->\n')

    # Predict image
    cls_models = filter_models_by_task(models, 'Image Classification')
    if cls_models:
        model_name = cls_models[0].name
        content.append(PREDICT_TEMPLATE.format(model_name=model_name))

    # Retrieve image
    retrieval_models = filter_models_by_task(models, 'Image Retrieval')
    if retrieval_models:
        model_name = retrieval_models[0].name
        content.append(RETRIEVE_TEMPLATE.format(model_name=model_name))

    # Use the model
    model_name = models[0].name
    content.append(USAGE_TEMPLATE.format(model_name=model_name))

    # Train/Test Command: models trained in-house (neither headless nor
    # converted from 3rdparty) get a train command, otherwise test-only.
    inputs = {}
    train_model = [
        model for model in models
        if 'headless' not in model.name and '3rdparty' not in model.name
    ]
    if train_model:
        template = TRAIN_TEST_TEMPLATE
        inputs['train_config'] = train_model[0].config
    elif len(filter_models_by_task(models, task='any')) > 0:
        template = TEST_ONLY_TEMPLATE
    else:
        # Nothing trainable or testable; close the tabs block and stop.
        content.append('\n<!-- [TABS-END] -->\n')
        return '\n'.join(content)

    # NOTE(review): if `train_model` is non-empty but no model has results,
    # this indexing raises IndexError -- confirm that metafiles always pair
    # a trainable model with at least one result-bearing model.
    test_model = filter_models_by_task(models, task='any')[0]
    inputs['test_config'] = test_model.config
    inputs['test_weights'] = test_model.weights
    content.append(template.format(**inputs))

    content.append('\n<!-- [TABS-END] -->\n')
    return '\n'.join(content)
|
||||
|
||||
|
||||
def format_pretrain(pretrain_field):
    """Convert a pretrain tag from a model name into a readable string.

    The field looks like ``mae-300e-in21k-pre``: the trailing ``pre`` part is
    dropped, ``<N>e`` becomes ``<N>-Epochs``, ``in<N>k`` becomes
    ``ImageNet-<N>k`` and everything else is upper-cased.
    """
    # Drop the trailing `-pre` marker.
    pretrain_infos = pretrain_field.split('-')[:-1]
    infos = []
    for info in pretrain_infos:
        # Raw strings fix the invalid `\d` escape sequences of the previous
        # version (a SyntaxWarning on Python >= 3.12).
        if re.match(r'^\d+e$', info):
            info = f'{info[:-1]}-Epochs'
        elif re.match(r'^in\d+k$', info):
            info = f'ImageNet-{info[2:-1]}k'
        else:
            info = info.upper()
        infos.append(info)
    return ' '.join(infos)
|
||||
|
||||
|
||||
def generate_model_table(models,
                         folder,
                         with_pretrain=True,
                         with_metric=True,
                         pretrained_models=()):
    """Build a markdown table summarizing ``models``.

    Args:
        models: Model entries (modelindex objects) to put in the table.
        folder: The algorithm folder; config paths are made relative to it.
        with_pretrain: Whether to add a "Pretrain" column.
        with_metric: Whether to add one column per metric.
        pretrained_models: Upstream pre-trained models, used to link the
            "Pretrain" cell to the upstream checkpoint. Defaults to an empty
            tuple (previously a mutable ``[]`` default).

    Returns:
        str: The markdown table, plus a footnote when any model was converted
        from a third-party repo.
    """
    header = ['Model']
    if with_pretrain:
        header.append('Pretrain')
    header.extend(['Params (M)', 'Flops (G)'])
    if with_metric:
        metrics = set()
        for model in models:
            metrics.update(model.results[0].metrics.keys())
        # `metrics` is already a set; just sort it.
        metrics = sorted(metrics)
        for metric in metrics:
            header.append(METRIC_MAPPING.get(metric, metric))
    header.extend(['Config', 'Download'])

    rows = []
    for model in models:
        model_name = f'`{model.name}`'
        config = (MMPT_ROOT / model.config).relative_to(folder)
        if model.weights is not None:
            download = f'[model]({model.weights})'
        else:
            download = 'N/A'

        if 'Converted From' in model.data:
            # Mark converted models with a star; the footnote below links to
            # the official repo.
            model_name += '\\*'
            converted_from = model.data['Converted From']
        elif model.weights is not None:
            # The training log lives next to the checkpoint with a `.json`
            # suffix. The dot is escaped so only a literal `.pth` extension
            # is replaced (previously `.` matched any character).
            log = re.sub(r'\.pth$', '.json', model.weights)
            download += f' \\| [log]({log})'

        row = [model_name]
        if with_pretrain:
            pretrain_field = [
                field for field in model.name.split('_')
                if field.endswith('-pre')
            ]
            if pretrain_field:
                pretrain = format_pretrain(pretrain_field[0])
                # Link to the upstream checkpoint if this model is listed as
                # one of its downstream models.
                upstream = [
                    pretrain_model for pretrain_model in pretrained_models
                    if model.name in pretrain_model.data.get('Downstream', [])
                ]
                if upstream:
                    pretrain = f'[{pretrain}]({upstream[0].weights})'
            else:
                pretrain = 'From scratch'
            row.append(pretrain)

        if model.metadata.parameters is not None:
            row.append(f'{model.metadata.parameters / 1e6:.2f}')  # Params
        else:
            row.append('N/A')
        if model.metadata.flops is not None:
            row.append(f'{model.metadata.flops / 1e9:.2f}')  # Flops
        else:
            row.append('N/A')

        if with_metric:
            for metric in metrics:
                row.append(model.results[0].metrics.get(metric, 'N/A'))
        row.append(f'[config]({config})')
        row.append(download)

        rows.append(row)

    table_cfg = dict(
        tablefmt='pipe',
        floatfmt='.2f',
        # Every row has `len(header)` columns; using the header avoids a
        # NameError on an empty `models` list (previously `len(row)`).
        colalign=['left'] + ['center'] * (len(header) - 1))
    table_string = tabulate(rows, header, **table_cfg) + '\n'
    if any('Converted From' in model.data for model in models):
        table_string += (
            f"\n*Models with \\* are converted from the [official repo]({converted_from['Code']}). "
            "The config files of these models are only for inference. We haven't reproduce the training results.*\n"
        )

    return table_string
|
||||
|
||||
|
||||
def add_models(metafile):
    """Generate the "Models and results" README section.

    Emits one table per (task, dataset) pair, plus a dedicated table for
    pre-trained checkpoints (models without results).
    """
    models = metafile.models
    if len(models) == 0:
        return ''

    content = ['## Models and results\n']
    algo_folder = Path(metafile.filepath).parent.absolute().resolve()

    # Pre-trained checkpoints get a table of their own.
    pretrain_models = filter_models_by_task(models, task=None)
    if pretrain_models:
        content.append('### Pretrained models\n')
        content.append(
            generate_model_table(
                pretrain_models,
                algo_folder,
                with_pretrain=False,
                with_metric=False))

    # Downstream tasks, in display order.
    tasks = [
        'Image Classification',
        'Image Retrieval',
        'Multi-Label Classification',
        'Image Caption',
        'Visual Grounding',
        'Visual Question Answering',
        'Image-To-Text Retrieval',
        'Text-To-Image Retrieval',
        'NLVR',
    ]

    for task in tasks:
        task_models = filter_models_by_task(models, task=task)
        if not task_models:
            continue
        # Group the task's models by dataset, well-known datasets first.
        datasets = {model.results[0].dataset for model in task_models}
        ordered = sorted(datasets, key=lambda name: DATASET_PRIORITY.get(name, 50))
        for dataset in ordered:
            content.append(f'### {task} on {dataset}\n')
            dataset_models = [
                model for model in task_models
                if model.results[0].dataset == dataset
            ]
            content.append(
                generate_model_table(
                    dataset_models,
                    algo_folder,
                    pretrained_models=pretrain_models))
    return '\n'.join(content)
|
||||
|
||||
|
||||
def parse_readme(readme):
    """Split an existing README into named sections.

    Returns a dict mapping section keys (``title``, ``intro``, ``abs``,
    ``usage``, ``models``, ``citation``, ``image``, or the raw heading of a
    custom section) to their markdown text.
    """
    with open(readme, 'r') as f:
        text = f.read()

    content = {}

    # Extract the first teaser-image <div> block and cut it out of the text
    # so it does not end up inside another section.
    img_match = re.search('^<div.*\n.*\n</div>\n', text, flags=re.MULTILINE)
    if img_match is not None:
        content['image'] = img_match.group()
        start, end = img_match.span()
        text = text[:start] + text[end:]

    # Map well-known section headings to fixed keys.
    known = {
        'Introduction': 'intro',
        'Abstract': 'abs',
        'How to use it': 'usage',
        'Models and results': 'models',
        'Citation': 'citation',
    }
    for section in re.split('^## ', text, flags=re.MULTILINE):
        if section.startswith('# '):
            content['title'] = section.strip() + '\n'
            continue
        for heading, key in known.items():
            if section.startswith(heading):
                content[key] = '## ' + section.strip() + '\n'
                break
        else:
            # Custom sections keep their first line as the dict key.
            section_title = section.split('\n', maxsplit=1)[0]
            content[section_title] = '## ' + section.strip() + '\n'
    return content
|
||||
|
||||
|
||||
def combine_readme(content: dict):
    """Assemble README sections (as produced by ``parse_readme``) back into a
    single document, enforcing the canonical section order."""
    remaining = content.copy()
    parts = [remaining.pop('title')]
    if 'intro' in remaining:
        # Old-style layout: intro, teaser image, then abstract.
        parts += [
            remaining.pop('intro'),
            remaining.pop('image'),
            remaining.pop('abs'),
        ]
    else:
        # New-style layout: abstract first, then teaser image.
        parts += [remaining.pop('abs'), remaining.pop('image')]

    parts += [remaining.pop('usage'), remaining.pop('models')]

    # Custom sections keep their original order and go before the citation.
    citation = remaining.pop('citation')
    parts.extend(remaining.values())
    parts.append(citation)
    return '\n'.join(parts)
|
||||
|
||||
|
||||
def main():
    """Generate or update a README from a metafile.

    ``--table`` prints only the model tables; ``--update`` merges generated
    sections into an existing README, regenerating only ``--update-items``;
    otherwise a full README is generated from scratch.
    """
    args = parse_args()
    metafile = load(str(args.metafile))
    if args.table:
        # Table-only mode, for pasting into existing documents.
        print(add_models(metafile))
        return

    if args.update is not None:
        content = parse_readme(args.update)
    else:
        content = {}

    # Regenerate a section when it is missing or explicitly requested.
    if 'title' not in content or 'title' in args.update_items:
        content['title'] = add_title(metafile)
    if 'abs' not in content or 'abs' in args.update_items:
        content['abs'] = add_abstract(metafile)
    if 'image' not in content or 'image' in args.update_items:
        # Placeholder teaser image; the `src` must be filled in manually.
        img = '<div align=center>\n<img src="" width="50%"/>\n</div>\n'
        content['image'] = img
    if 'usage' not in content or 'usage' in args.update_items:
        # NOTE(review): add_usage returns None for metafiles without models,
        # which would break combine_readme below -- confirm metafiles always
        # contain at least one model.
        content['usage'] = add_usage(metafile)
    if 'models' not in content or 'models' in args.update_items:
        content['models'] = add_models(metafile)
    if 'citation' not in content:
        content['citation'] = '## Citation\n```bibtex\n```\n'

    content = combine_readme(content)
    if args.out is not None:
        with open(args.out, 'w') as f:
            f.write(content)
    else:
        print(content)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,33 +0,0 @@
|
|||
---
|
||||
name: 寻求帮助
|
||||
about: 遇到问题并寻求帮助
|
||||
title: ''
|
||||
labels: help wanted
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
推荐使用英语模板 General question,以便你的问题帮助更多人。
|
||||
|
||||
### 首先确认以下内容
|
||||
|
||||
- 我已经查询了相关的 issue,但没有找到需要的帮助。
|
||||
- 我已经阅读了相关文档,但仍不知道如何解决。
|
||||
|
||||
### 描述你遇到的问题
|
||||
|
||||
\[填写这里\]
|
||||
|
||||
### 相关信息
|
||||
|
||||
1. `pip list | grep "mmcv\|mmcls\|^torch"` 命令的输出
|
||||
\[填写这里\]
|
||||
2. 如果你修改了,或者使用了新的配置文件,请在这里写明
|
||||
|
||||
```python
|
||||
[填写这里]
|
||||
```
|
||||
|
||||
3. 如果你是在训练过程中遇到的问题,请填写完整的训练日志和报错信息
|
||||
\[填写这里\]
|
||||
4. 如果你对 `mmcls` 文件夹下的代码做了其他相关的修改,请在这里写明
|
||||
\[填写这里\]
|
|
@ -1,34 +0,0 @@
|
|||
---
|
||||
name: 新功能
|
||||
about: 为项目提一个建议
|
||||
title: '[Feature]'
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
推荐使用英语模板 Feature request,以便你的问题帮助更多人。
|
||||
|
||||
### 描述这个功能
|
||||
|
||||
\[填写这里\]
|
||||
|
||||
### 动机
|
||||
|
||||
请简要说明以下为什么需要添加这个新功能
|
||||
例 1. 现在进行 xxx 的时候不方便
|
||||
例 2. 最近的论文中提出了有一个很有帮助的 xx
|
||||
|
||||
\[填写这里\]
|
||||
|
||||
### 相关资源
|
||||
|
||||
是否有相关的官方实现或者第三方实现?这些会很有参考意义。
|
||||
|
||||
\[填写这里\]
|
||||
|
||||
### 其他相关信息
|
||||
|
||||
其他和这个功能相关的信息或者截图,请放在这里。
|
||||
另外如果你愿意参与实现这个功能并提交 PR,请在这里说明,我们将非常欢迎。
|
||||
|
||||
\[填写这里\]
|
|
@ -1,44 +0,0 @@
|
|||
---
|
||||
name: 报告 Bug
|
||||
about: 报告问题以帮助我们提升
|
||||
title: '[Bug]'
|
||||
labels: bug
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
推荐使用英语模板 Bug report,以便你的问题帮助更多人。
|
||||
|
||||
### 描述 bug
|
||||
|
||||
简单地描述一下遇到了什么 bug
|
||||
|
||||
\[填写这里\]
|
||||
|
||||
### 复现流程
|
||||
|
||||
在命令行中执行的详细操作
|
||||
|
||||
```shell
|
||||
[填写这里]
|
||||
```
|
||||
|
||||
### 相关信息
|
||||
|
||||
1. `pip list | grep "mmcv\|mmcls\|^torch"` 命令的输出
|
||||
\[填写这里\]
|
||||
2. 如果你修改了,或者使用了新的配置文件,请在这里写明
|
||||
|
||||
```python
|
||||
[填写这里]
|
||||
```
|
||||
|
||||
3. 如果你是在训练过程中遇到的问题,请填写完整的训练日志和报错信息
|
||||
\[填写这里\]
|
||||
4. 如果你对 `mmcls` 文件夹下的代码做了其他相关的修改,请在这里写明
|
||||
\[填写这里\]
|
||||
|
||||
### 附加内容
|
||||
|
||||
任何其他有关该 bug 的信息、截图等
|
||||
|
||||
\[填写这里\]
|
|
@ -0,0 +1,69 @@
|
|||
name: 🐞 Bug report
|
||||
description: Create a report to help us improve
|
||||
labels: ["bug"]
|
||||
title: "[Bug] "
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
If you have already identified the reason, we strongly appreciate you creating a new PR according to [the tutorial](https://mmpretrain.readthedocs.io/en/master/community/CONTRIBUTING.html)!
|
||||
If you need our help, please fill in the following form to help us to identify the bug.
|
||||
|
||||
- type: dropdown
|
||||
id: version
|
||||
attributes:
|
||||
label: Branch
|
||||
description: Which branch/version are you using?
|
||||
options:
|
||||
- main branch (mmpretrain version)
|
||||
- mmcls-1.x branch (v1.0.0rc6 or other 1.x version)
|
||||
- mmcls-0.x branch (v0.25.0 or other 0.x version)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: describe
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Describe the bug
|
||||
description: |
|
||||
Please provide a clear and concise description of what the bug is.
|
||||
Preferably a simple and minimal code snippet that we can reproduce the error by running the code.
|
||||
placeholder: |
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
```python
|
||||
# Sample code to reproduce the problem
|
||||
```
|
||||
|
||||
```shell
|
||||
The command or script you run.
|
||||
```
|
||||
|
||||
```
|
||||
The error message or logs you got, with the full traceback.
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: environment
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Environment
|
||||
description: |
|
||||
Please run `python -c "import mmpretrain.utils;import pprint;pprint.pp(dict(mmpretrain.utils.collect_env()))"` to collect necessary environment information and paste it here.
|
||||
placeholder: |
|
||||
```python
|
||||
# The output of the above command
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: other
|
||||
attributes:
|
||||
label: Other information
|
||||
description: |
|
||||
Tell us anything else you think we should know.
|
||||
|
||||
1. Did you make any modifications on the code or config?
|
||||
2. What do you think might be the reason?
|
|
@ -0,0 +1,29 @@
|
|||
name: 🚀 Feature request
|
||||
description: Suggest an idea for this project
|
||||
labels: ["enhancement"]
|
||||
title: "[Feature] "
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
If you have already implemented the feature, we strongly appreciate you creating a new PR according to [the tutorial](https://mmpretrain.readthedocs.io/en/master/community/CONTRIBUTING.html)!
|
||||
|
||||
- type: textarea
|
||||
id: describe
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Describe the feature
|
||||
description: |
|
||||
What kind of feature do you want MMPreTrain to add. If there is an official code release or third-party implementation, please also provide the information here, which would be very helpful.
|
||||
placeholder: |
|
||||
A clear and concise description of the motivation of the feature.
|
||||
Ex1. It is inconvenient when \[....\].
|
||||
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
|
||||
|
||||
- type: checkboxes
|
||||
id: pr
|
||||
attributes:
|
||||
label: Will you implement it?
|
||||
options:
|
||||
- label: I would like to implement this feature and create a PR!
|
|
@ -0,0 +1,70 @@
|
|||
name: 🐞 报告 Bug
|
||||
description: 报告你在使用中遇到的不合预期的情况
|
||||
labels: ["bug"]
|
||||
title: "[Bug] "
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
我们推荐使用英语模板 Bug report,以便你的问题帮助更多人。
|
||||
|
||||
如果你已经有了解决方案,我们非常欢迎你直接创建一个新的 PR 来解决这个问题。创建 PR 的流程可以参考[文档](https://mmpretrain.readthedocs.io/zh_CN/master/community/CONTRIBUTING.html)。
|
||||
如果你需要我们的帮助,请填写以下内容帮助我们定位 Bug。
|
||||
|
||||
- type: dropdown
|
||||
id: version
|
||||
attributes:
|
||||
label: 分支
|
||||
description: 你正在使用的分支/版本是哪个?
|
||||
options:
|
||||
- main 分支 (mmpretrain 版本)
|
||||
- mmcls-1.x 分支 (v1.0.0rc6 或者其它 1.x 版本)
|
||||
- mmcls-0.x 分支 (v0.25.0 或者其它 0.x 版本)
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: describe
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: 描述该错误
|
||||
description: |
|
||||
请简要说明你遇到的错误。如果可以的话,请提供一个简短的代码片段帮助我们复现这一错误。
|
||||
placeholder: |
|
||||
问题的简要说明
|
||||
|
||||
```python
|
||||
# 复现错误的代码片段
|
||||
```
|
||||
|
||||
```shell
|
||||
# 发生错误时你的运行命令
|
||||
```
|
||||
|
||||
```
|
||||
错误信息和日志,请展示全部的错误日志和 traceback
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: environment
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: 环境信息
|
||||
description: |
|
||||
请运行指令 `python -c "import mmpretrain.utils;import pprint;pprint.pp(dict(mmpretrain.utils.collect_env()))"` 来收集必要的环境信息,并贴在下方。
|
||||
placeholder: |
|
||||
```python
|
||||
# 上述命令的输出
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: other
|
||||
attributes:
|
||||
label: 其他信息
|
||||
description: |
|
||||
告诉我们其他有价值的信息。
|
||||
|
||||
1. 你是否对代码或配置文件做了任何改动?
|
||||
2. 你认为可能的原因是什么?
|
|
@ -0,0 +1,31 @@
|
|||
name: 🚀 功能建议
|
||||
description: 建议一项新的功能
|
||||
labels: ["enhancement"]
|
||||
title: "[Feature] "
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
推荐使用英语模板 Feature request,以便你的问题帮助更多人。
|
||||
|
||||
如果你已经实现了该功能,我们非常欢迎你直接创建一个新的 PR 来解决这个问题。创建 PR 的流程可以参考[文档](https://mmpretrain.readthedocs.io/zh_CN/master/community/CONTRIBUTING.html)。
|
||||
|
||||
- type: textarea
|
||||
id: describe
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: 描述该功能
|
||||
description: |
|
||||
你希望 MMPreTrain 添加什么功能?如果存在相关的论文、官方实现或者第三方实现,请同时贴出链接,这将非常有帮助。
|
||||
placeholder: |
|
||||
简要说明该功能,及为什么需要该功能
|
||||
例 1. 现在进行 xxx 的时候不方便
|
||||
例 2. 最近的论文中提出了有一个很有帮助的 xx
|
||||
|
||||
- type: checkboxes
|
||||
id: pr
|
||||
attributes:
|
||||
label: 是否希望自己实现该功能?
|
||||
options:
|
||||
- label: 我希望自己来实现这一功能,并向 MMPreTrain 贡献代码!
|
|
@ -1,42 +0,0 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: '[Bug]'
|
||||
labels: bug
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
### Describe the bug
|
||||
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
\[here\]
|
||||
|
||||
### To Reproduce
|
||||
|
||||
The command you executed.
|
||||
|
||||
```shell
|
||||
[here]
|
||||
```
|
||||
|
||||
### Post related information
|
||||
|
||||
1. The output of `pip list | grep "mmcv\|mmcls\|^torch"`
|
||||
\[here\]
|
||||
2. Your config file if you modified it or created a new one.
|
||||
|
||||
```python
|
||||
[here]
|
||||
```
|
||||
|
||||
3. Your train log file if you meet the problem during training.
|
||||
\[here\]
|
||||
4. Other code you modified in the `mmcls` folder.
|
||||
\[here\]
|
||||
|
||||
### Additional context
|
||||
|
||||
Add any other context about the problem here.
|
||||
|
||||
\[here\]
|
|
@ -1,6 +1,12 @@
|
|||
blank_issues_enabled: false
|
||||
|
||||
contact_links:
|
||||
- name: MMClassification Documentation
|
||||
url: https://mmclassification.readthedocs.io/en/latest/
|
||||
- name: 📚 MMPreTrain Documentation (官方文档)
|
||||
url: https://mmpretrain.readthedocs.io/en/latest/
|
||||
about: Check if your question is answered in docs
|
||||
- name: 💬 General questions (寻求帮助)
|
||||
url: https://github.com/open-mmlab/mmpretrain/discussions
|
||||
about: Ask general usage questions and discuss with other MMPreTrain community members
|
||||
- name: 🌐 Explore OpenMMLab (官网)
|
||||
url: https://openmmlab.com/
|
||||
about: Get to know more about OpenMMLab
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: '[Feature]'
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
### Describe the feature
|
||||
|
||||
\[here\]
|
||||
|
||||
### Motivation
|
||||
|
||||
A clear and concise description of the motivation of the feature.
|
||||
Ex1. It is inconvenient when \[....\].
|
||||
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
|
||||
|
||||
\[here\]
|
||||
|
||||
### Related resources
|
||||
|
||||
If there is an official code release or third-party implementation, please also provide the information here, which would be very helpful.
|
||||
|
||||
\[here\]
|
||||
|
||||
### Additional context
|
||||
|
||||
Add any other context or screenshots about the feature request here.
|
||||
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
|
||||
|
||||
\[here\]
|
|
@ -1,31 +0,0 @@
|
|||
---
|
||||
name: General questions
|
||||
about: 'Ask general questions to get help '
|
||||
title: ''
|
||||
labels: help wanted
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
### Checklist
|
||||
|
||||
- I have searched related issues but cannot get the expected help.
|
||||
- I have read related documents and don't know what to do.
|
||||
|
||||
### Describe the question you meet
|
||||
|
||||
\[here\]
|
||||
|
||||
### Post related information
|
||||
|
||||
1. The output of `pip list | grep "mmcv\|mmcls\|^torch"`
|
||||
\[here\]
|
||||
2. Your config file if you modified it or created a new one.
|
||||
|
||||
```python
|
||||
[here]
|
||||
```
|
||||
|
||||
3. Your train log file if you meet the problem during training.
|
||||
\[here\]
|
||||
4. Other code you modified in the `mmcls` folder.
|
||||
\[here\]
|
|
@ -1,22 +0,0 @@
|
|||
name: deploy
|
||||
|
||||
on: push
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Build MMClassification
|
||||
run: |
|
||||
pip install wheel
|
||||
python setup.py sdist bdist_wheel
|
||||
- name: Publish distribution to PyPI
|
||||
run: |
|
||||
pip install twine
|
||||
twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
|
|
@ -10,9 +10,9 @@ jobs:
|
|||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Install pre-commit hook
|
||||
|
@ -24,4 +24,4 @@ jobs:
|
|||
- name: Check docstring coverage
|
||||
run: |
|
||||
pip install interrogate
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmcls
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmpretrain
|
||||
|
|
|
@ -18,7 +18,7 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-18.04
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
|
@ -26,29 +26,77 @@ jobs:
|
|||
- torch: 1.8.1
|
||||
torchvision: 0.9.1
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install mmcls dependencies
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
- name: Install mmpretrain dependencies
|
||||
run: |
|
||||
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
pip install -U openmim
|
||||
mim install 'mmcv >= 2.0.0rc1'
|
||||
mim install 'mmcv >= 2.0.0rc4'
|
||||
pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: mim install .
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmpretrain -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v1.0.14
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
env_vars: OS,PYTHON
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
build_cu117:
|
||||
runs-on: ubuntu-22.04
|
||||
container:
|
||||
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.9]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Fetch GPG keys
|
||||
run: |
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
|
||||
- name: Install Python-dev
|
||||
run: apt-get update && apt-get install -y python${{matrix.python-version}}-dev
|
||||
if: ${{matrix.python-version != 3.9}}
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
apt-get update
|
||||
apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev libc6 libc6-dev
|
||||
- name: Install mmpretrain dependencies
|
||||
run: |
|
||||
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
pip install -U openmim
|
||||
mim install 'mmcv >= 2.0.0rc4'
|
||||
pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: pip install -e .
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmcls -m pytest tests/ -k 'not timm'
|
||||
coverage run --branch --source mmpretrain -m pytest tests/ --ignore tests/test_tools.py
|
||||
coverage xml
|
||||
coverage report -m
|
||||
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v1.0.14
|
||||
with:
|
||||
|
@ -62,26 +110,26 @@ jobs:
|
|||
runs-on: windows-2022
|
||||
strategy:
|
||||
matrix:
|
||||
python: [3.7]
|
||||
platform: [cu111]
|
||||
python-version: [3.7]
|
||||
platform: [cpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
run: python -m pip install pip --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==1.8.2+${{matrix.platform}} torchvision==0.9.2+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
|
||||
- name: Install mmcls dependencies
|
||||
- name: Install mmpretrain dependencies
|
||||
run: |
|
||||
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
pip install -U openmim
|
||||
mim install 'mmcv >= 2.0.0rc1'
|
||||
mim install mmengine
|
||||
mim install 'mmcv >= 2.0.0rc4'
|
||||
pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: pip install -e .
|
||||
run: mim install .
|
||||
- name: Run unittests
|
||||
run: |
|
||||
pytest tests/ -k 'not timm' --ignore tests/test_models/test_backbones
|
||||
pytest tests/ --ignore tests/test_models/test_backbones
|
||||
|
|
|
@ -7,12 +7,12 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Build MMClassification
|
||||
- name: Build MMPretrain
|
||||
run: |
|
||||
pip install wheel
|
||||
python setup.py sdist bdist_wheel
|
||||
|
|
|
@ -17,7 +17,7 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
build_cpu:
|
||||
runs-on: ubuntu-18.04
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
|
@ -27,9 +27,9 @@ jobs:
|
|||
torch_version: torch1.8
|
||||
torchvision: 0.9.0
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
|
@ -39,6 +39,6 @@ jobs:
|
|||
- name: Install openmim
|
||||
run: pip install openmim
|
||||
- name: Build and install
|
||||
run: mim install -e .
|
||||
run: mim install .
|
||||
- name: test commands of mim
|
||||
run: mim search mmcls
|
||||
run: mim search mmpretrain
|
||||
|
|
|
@ -76,6 +76,9 @@ docs/zh_CN/_model_zoo.rst
|
|||
docs/zh_CN/modelzoo_statistics.md
|
||||
docs/zh_CN/papers/
|
||||
docs/zh_CN/api/generated/
|
||||
docs/zh_CN/locales/
|
||||
!docs/zh_CN/locales/zh_CN/LC_MESSAGES/api.po
|
||||
!docs/zh_CN/locales/zh_CN/LC_MESSAGES/papers.po
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
@ -122,7 +125,9 @@ venv.bak/
|
|||
*.pkl.json
|
||||
*.log.json
|
||||
/work_dirs
|
||||
/mmcls/.mim
|
||||
/projects/*/work_dirs
|
||||
/projects/*/data
|
||||
/mmpretrain/.mim
|
||||
.DS_Store
|
||||
|
||||
# Pytorch
|
||||
|
@ -133,3 +138,13 @@ venv.bak/
|
|||
*.pvti-journal
|
||||
/cache_engine
|
||||
/report
|
||||
|
||||
# slurm
|
||||
*.out
|
||||
|
||||
# tensorflow
|
||||
*.tar.gz
|
||||
checkpoint
|
||||
model_params.txt
|
||||
*.ckpt*
|
||||
results.txt
|
||||
|
|
|
@ -5,7 +5,7 @@ repos:
|
|||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.10.1
|
||||
rev: 5.11.5
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
|
@ -29,9 +29,9 @@ repos:
|
|||
rev: 0.7.9
|
||||
hooks:
|
||||
- id: mdformat
|
||||
args: ["--number", "--table-width", "200"]
|
||||
args: ["--number", "--table-width", "200", '--disable-escape', 'backslash', '--disable-escape', 'link-enclosure']
|
||||
additional_dependencies:
|
||||
- mdformat-openmmlab
|
||||
- "mdformat-openmmlab>=0.0.4"
|
||||
- mdformat_frontmatter
|
||||
- linkify-it-py
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
|
@ -47,7 +47,18 @@ repos:
|
|||
rev: v0.4.0
|
||||
hooks:
|
||||
- id: check-copyright
|
||||
args: ["mmcls", "tests", "demo", "tools", "--excludes", "mmcls/.mim/", "--ignore-file-not-found-error"]
|
||||
args: ["mmpretrain", "tests", "demo", "tools", "--excludes", "mmpretrain/.mim/", "--ignore-file-not-found-error"]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: metafile
|
||||
args: ['--skip', 'flops-param']
|
||||
name: metafile
|
||||
description: Check the format of metafile
|
||||
entry: python .dev_scripts/check_metafile.py
|
||||
language: python
|
||||
files: (metafile)\.(yml)$
|
||||
additional_dependencies:
|
||||
- modelindex
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: clang-format
|
||||
|
|
|
@ -1,9 +1,15 @@
|
|||
version: 2
|
||||
|
||||
formats: all
|
||||
# Set the version of Python and other tools you might need
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.8"
|
||||
|
||||
formats:
|
||||
- epub
|
||||
|
||||
python:
|
||||
version: 3.7
|
||||
install:
|
||||
- requirements: requirements/docs.txt
|
||||
- requirements: requirements/readthedocs.txt
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
cff-version: 1.2.0
|
||||
message: "If you use this software, please cite it as below."
|
||||
title: "OpenMMLab's Image Classification Toolbox and Benchmark"
|
||||
title: "OpenMMLab's Pre-training Toolbox and Benchmark"
|
||||
authors:
|
||||
- name: "MMClassification Contributors"
|
||||
- name: "MMPreTrain Contributors"
|
||||
version: 0.15.0
|
||||
date-released: 2020-07-09
|
||||
repository-code: "https://github.com/open-mmlab/mmclassification"
|
||||
date-released: 2023-04-06
|
||||
repository-code: "https://github.com/open-mmlab/mmpretrain"
|
||||
license: Apache-2.0
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
# Contributing to MMClassification
|
||||
# Contributing to MMPreTrain
|
||||
|
||||
- [Contributing to MMClassification](#contributing-to-mmclassification)
|
||||
- [Contributing to MMPreTrain](#contributing-to-mmpretrain)
|
||||
- [Workflow](#workflow)
|
||||
- [Code style](#code-style)
|
||||
- [Python](#python)
|
||||
- [C++ and CUDA](#c-and-cuda)
|
||||
- [Pre-commit Hook](#pre-commit-hook)
|
||||
|
||||
Thanks for your interest in contributing to MMClassification! All kinds of contributions are welcome, including but not limited to the following.
|
||||
Thanks for your interest in contributing to MMPreTrain! All kinds of contributions are welcome, including but not limited to the following.
|
||||
|
||||
- Fix typo or bugs
|
||||
- Add documentation or translate the documentation into other languages
|
||||
|
@ -17,7 +17,7 @@ Thanks for your interest in contributing to MMClassification! All kinds of contr
|
|||
|
||||
We recommend the potential contributors follow this workflow for contribution.
|
||||
|
||||
1. Fork and pull the latest MMClassification repository, follow [get started](https://mmclassification.readthedocs.io/en/1.x/get_started.html) to setup the environment.
|
||||
1. Fork and pull the latest MMPreTrain repository, follow [get started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) to setup the environment.
|
||||
2. Checkout a new branch (**do not use the master or dev branch** for PRs)
|
||||
|
||||
```bash
|
||||
|
@ -44,7 +44,7 @@ We use the following tools for linting and formatting:
|
|||
- [mdformat](https://github.com/executablebooks/mdformat): Mdformat is an opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files.
|
||||
- [docformatter](https://github.com/myint/docformatter): A formatter to format docstring.
|
||||
|
||||
Style configurations of yapf and isort can be found in [setup.cfg](https://github.com/open-mmlab/mmclassification/blob/1.x/setup.cfg).
|
||||
Style configurations of yapf and isort can be found in [setup.cfg](https://github.com/open-mmlab/mmpretrain/blob/main/setup.cfg).
|
||||
|
||||
### C++ and CUDA
|
||||
|
||||
|
@ -54,7 +54,7 @@ We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppgu
|
|||
|
||||
We use [pre-commit hook](https://pre-commit.com/) that checks and formats for `flake8`, `yapf`, `isort`, `trailing whitespaces`, `markdown files`,
|
||||
fixes `end-of-files`, `double-quoted-strings`, `python-encoding-pragma`, `mixed-line-ending`, sorts `requirments.txt` automatically on every commit.
|
||||
The config for a pre-commit hook is stored in [.pre-commit-config](https://github.com/open-mmlab/mmclassification/blob/1.x/.pre-commit-config.yaml).
|
||||
The config for a pre-commit hook is stored in [.pre-commit-config](https://github.com/open-mmlab/mmpretrain/blob/main/.pre-commit-config.yaml).
|
||||
|
||||
After you clone the repository, you will need to install initialize pre-commit hook.
|
||||
|
||||
|
|
2
LICENSE
2
LICENSE
|
@ -188,7 +188,7 @@ Copyright (c) OpenMMLab. All rights reserved
|
|||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 MMClassification Authors.
|
||||
Copyright 2020 MMPreTrain Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
include requirements/*.txt
|
||||
include mmcls/.mim/model-index.yml
|
||||
recursive-include mmcls/.mim/configs *.py *.yml
|
||||
recursive-include mmcls/.mim/tools *.py *.sh
|
||||
include mmpretrain/.mim/model-index.yml
|
||||
include mmpretrain/.mim/dataset-index.yml
|
||||
recursive-include mmpretrain/.mim/configs *.py *.yml
|
||||
recursive-include mmpretrain/.mim/tools *.py *.sh
|
||||
|
|
331
README.md
331
README.md
|
@ -1,6 +1,6 @@
|
|||
<div align="center">
|
||||
|
||||
<img src="resources/mmcls-logo.png" width="600"/>
|
||||
<img src="resources/mmpt-logo.png" width="600"/>
|
||||
<div> </div>
|
||||
<div align="center">
|
||||
<b><font size="5">OpenMMLab website</font></b>
|
||||
|
@ -19,60 +19,103 @@
|
|||
</div>
|
||||
<div> </div>
|
||||
|
||||
[](https://pypi.org/project/mmcls)
|
||||
[](https://mmclassification.readthedocs.io/en/1.x/)
|
||||
[](https://github.com/open-mmlab/mmclassification/actions)
|
||||
[](https://codecov.io/gh/open-mmlab/mmclassification)
|
||||
[](https://github.com/open-mmlab/mmclassification/blob/1.x/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmclassification/issues)
|
||||
[](https://github.com/open-mmlab/mmclassification/issues)
|
||||
[](https://pypi.org/project/mmpretrain)
|
||||
[](https://mmpretrain.readthedocs.io/en/latest/)
|
||||
[](https://github.com/open-mmlab/mmpretrain/actions)
|
||||
[](https://codecov.io/gh/open-mmlab/mmpretrain)
|
||||
[](https://github.com/open-mmlab/mmpretrain/blob/main/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmpretrain/issues)
|
||||
[](https://github.com/open-mmlab/mmpretrain/issues)
|
||||
|
||||
[📘 Documentation](https://mmclassification.readthedocs.io/en/1.x/) |
|
||||
[🛠️ Installation](https://mmclassification.readthedocs.io/en/1.xget_started.html) |
|
||||
[👀 Model Zoo](https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html) |
|
||||
[🆕 Update News](https://mmclassification.readthedocs.io/en/1.x/notes/changelog.html) |
|
||||
[🤔 Reporting Issues](https://github.com/open-mmlab/mmclassification/issues/new/choose)
|
||||
[📘 Documentation](https://mmpretrain.readthedocs.io/en/latest/) |
|
||||
[🛠️ Installation](https://mmpretrain.readthedocs.io/en/latest/get_started.html#installation) |
|
||||
[👀 Model Zoo](https://mmpretrain.readthedocs.io/en/latest/modelzoo_statistics.html) |
|
||||
[🆕 Update News](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html) |
|
||||
[🤔 Reporting Issues](https://github.com/open-mmlab/mmpretrain/issues/new/choose)
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/36138628/230307505-4727ad0a-7d71-4069-939d-b499c7e272b7.png" width="400"/>
|
||||
|
||||
English | [简体中文](/README_zh-CN.md)
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://openmmlab.medium.com/" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://discord.gg/raweFPmdzG" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
|
||||
</div>
|
||||
|
||||
## Introduction
|
||||
|
||||
English | [简体中文](/README_zh-CN.md)
|
||||
MMPreTrain is an open source pre-training toolbox based on PyTorch. It is a part of the [OpenMMLab](https://openmmlab.com/) project.
|
||||
|
||||
MMClassification is an open source image classification toolbox based on PyTorch. It is
|
||||
a part of the [OpenMMLab](https://openmmlab.com/) project.
|
||||
|
||||
The `1.x` branch works with **PyTorch 1.6+**.
|
||||
|
||||
<div align="center">
|
||||
<img src="https://user-images.githubusercontent.com/9102141/87268895-3e0d0780-c4fe-11ea-849e-6140b7e0d4de.gif" width="70%"/>
|
||||
</div>
|
||||
The `main` branch works with **PyTorch 1.8+**.
|
||||
|
||||
### Major features
|
||||
|
||||
- Various backbones and pretrained models
|
||||
- Rich training strategies (supervised learning, self-supervised learning, multi-modality learning etc.)
|
||||
- Bag of training tricks
|
||||
- Large-scale training configs
|
||||
- High efficiency and extensibility
|
||||
- Powerful toolkits
|
||||
- Powerful toolkits for model analysis and experiments
|
||||
- Various out-of-box inference tasks.
|
||||
- Image Classification
|
||||
- Image Caption
|
||||
- Visual Question Answering
|
||||
- Visual Grounding
|
||||
- Retrieval (Image-To-Image, Text-To-Image, Image-To-Text)
|
||||
|
||||
https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351-fbc74a04e904
|
||||
|
||||
## What's new
|
||||
|
||||
v1.0.0rc1 was released in 30/9/2022.
|
||||
🌟 v1.2.0 was released in 04/01/2023
|
||||
|
||||
- Support MViT, EdgeNeXt, Swin-Transformer V2, EfficientFormer and MobileOne.
|
||||
- Support BEiT type transformer layer.
|
||||
- Support LLaVA 1.5.
|
||||
- Implement of RAM with a gradio interface.
|
||||
|
||||
v1.0.0rc0 was released in 31/8/2022.
|
||||
🌟 v1.1.0 was released in 12/10/2023
|
||||
|
||||
- Support Mini-GPT4 training and provide a Chinese model (based on Baichuan-7B)
|
||||
- Support zero-shot classification based on CLIP.
|
||||
|
||||
🌟 v1.0.0 was released in 04/07/2023
|
||||
|
||||
- Support inference of more **multi-modal** algorithms, such as [**LLaVA**](./configs/llava/), [**MiniGPT-4**](./configs/minigpt4), [**Otter**](./configs/otter/), etc.
|
||||
- Support around **10 multi-modal** datasets!
|
||||
- Add [**iTPN**](./configs/itpn/), [**SparK**](./configs/spark/) self-supervised learning algorithms.
|
||||
- Provide examples of [New Config](./mmpretrain/configs/) and [DeepSpeed/FSDP with FlexibleRunner](./configs/mae/benchmarks/). Here are the documentation links of [New Config](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta) and [DeepSpeed/FSDP with FlexibleRunner](https://mmengine.readthedocs.io/en/latest/api/generated/mmengine.runner.FlexibleRunner.html#mmengine.runner.FlexibleRunner).
|
||||
|
||||
🌟 Upgrade from MMClassification to MMPreTrain
|
||||
|
||||
- Integrated Self-supervised learning algorithms from **MMSelfSup**, such as **MAE**, **BEiT**, etc.
|
||||
- Support **RIFormer**, a simple but effective vision backbone by removing token mixer.
|
||||
- Refactor dataset pipeline visualization.
|
||||
- Support **LeViT**, **XCiT**, **ViG**, **ConvNeXt-V2**, **EVA**, **RevViT**, **EfficientnetV2**, **CLIP**, **TinyViT** and **MixMIM** backbones.
|
||||
|
||||
This release introduced a brand new and flexible training & test engine, but it's still in progress. Welcome
|
||||
to try according to [the documentation](https://mmclassification.readthedocs.io/en/1.x/).
|
||||
to try according to [the documentation](https://mmpretrain.readthedocs.io/en/latest/).
|
||||
|
||||
And there are some BC-breaking changes. Please check [the migration tutorial](https://mmclassification.readthedocs.io/en/1.x/migration.html).
|
||||
And there are some BC-breaking changes. Please check [the migration tutorial](https://mmpretrain.readthedocs.io/en/latest/migration.html).
|
||||
|
||||
The release candidate will last until the end of 2022, and during the release candidate, we will develop on the `1.x` branch. And we will still maintain 0.x version still at least the end of 2023.
|
||||
|
||||
Please refer to [changelog.md](https://mmclassification.readthedocs.io/en/1.x/notes/changelog.html) for more details and other release history.
|
||||
Please refer to [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html) for more details and other release history.
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -82,89 +125,186 @@ Below are quick steps for installation:
|
|||
conda create -n open-mmlab python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
|
||||
conda activate open-mmlab
|
||||
pip install openmim
|
||||
git clone -b 1.x https://github.com/open-mmlab/mmclassification.git
|
||||
cd mmclassification
|
||||
git clone https://github.com/open-mmlab/mmpretrain.git
|
||||
cd mmpretrain
|
||||
mim install -e .
|
||||
```
|
||||
|
||||
Please refer to [install.md](https://mmclassification.readthedocs.io/en/1.x/get_started.html) for more detailed installation and dataset preparation.
|
||||
Please refer to [installation documentation](https://mmpretrain.readthedocs.io/en/latest/get_started.html) for more detailed installation and dataset preparation.
|
||||
|
||||
For multi-modality models support, please install the extra dependencies by:
|
||||
|
||||
```shell
|
||||
mim install -e ".[multimodal]"
|
||||
```
|
||||
|
||||
## User Guides
|
||||
|
||||
We provided a series of tutorials about the basic usage of MMClassification for new users:
|
||||
We provided a series of tutorials about the basic usage of MMPreTrain for new users:
|
||||
|
||||
- [Inference with existing models](https://mmclassification.readthedocs.io/en/1.x/user_guides/inference.html)
|
||||
- [Prepare Dataset](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html)
|
||||
- [Training and Test](https://mmclassification.readthedocs.io/en/1.x/user_guides/train_test.html)
|
||||
- [Learn about Configs](https://mmclassification.readthedocs.io/en/1.x/user_guides/config.html)
|
||||
- [Fine-tune Models](https://mmclassification.readthedocs.io/en/1.x/user_guides/finetune.html)
|
||||
- [Analysis Tools](https://mmclassification.readthedocs.io/en/1.x/user_guides/analysis.html)
|
||||
- [Visualization Tools](https://mmclassification.readthedocs.io/en/1.x/user_guides/visualization.html)
|
||||
- [Other Useful Tools](https://mmclassification.readthedocs.io/en/1.x/user_guides/useful_tools.html)
|
||||
- [Learn about Configs](https://mmpretrain.readthedocs.io/en/latest/user_guides/config.html)
|
||||
- [Prepare Dataset](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html)
|
||||
- [Inference with existing models](https://mmpretrain.readthedocs.io/en/latest/user_guides/inference.html)
|
||||
- [Train](https://mmpretrain.readthedocs.io/en/latest/user_guides/train.html)
|
||||
- [Test](https://mmpretrain.readthedocs.io/en/latest/user_guides/test.html)
|
||||
- [Downstream tasks](https://mmpretrain.readthedocs.io/en/latest/user_guides/downstream.html)
|
||||
|
||||
For more information, please refer to [our documentation](https://mmpretrain.readthedocs.io/en/latest/).
|
||||
|
||||
## Model zoo
|
||||
|
||||
Results and models are available in the [model zoo](https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html).
|
||||
Results and models are available in the [model zoo](https://mmpretrain.readthedocs.io/en/latest/modelzoo_statistics.html).
|
||||
|
||||
<details open>
|
||||
<summary>Supported backbones</summary>
|
||||
|
||||
- [x] [VGG](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/vgg)
|
||||
- [x] [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet)
|
||||
- [x] [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext)
|
||||
- [x] [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet)
|
||||
- [x] [SE-ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet)
|
||||
- [x] [RegNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/regnet)
|
||||
- [x] [ShuffleNetV1](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/shufflenet_v1)
|
||||
- [x] [ShuffleNetV2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/shufflenet_v2)
|
||||
- [x] [MobileNetV2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilenet_v2)
|
||||
- [x] [MobileNetV3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilenet_v3)
|
||||
- [x] [Swin-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/swin_transformer)
|
||||
- [x] [Swin-Transformer V2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/swin_transformer_v2)
|
||||
- [x] [RepVGG](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/repvgg)
|
||||
- [x] [Vision-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/vision_transformer)
|
||||
- [x] [Transformer-in-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/tnt)
|
||||
- [x] [Res2Net](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/res2net)
|
||||
- [x] [MLP-Mixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mlp_mixer)
|
||||
- [x] [DeiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/deit)
|
||||
- [x] [Conformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/conformer)
|
||||
- [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/t2t_vit)
|
||||
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/twins)
|
||||
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientnet)
|
||||
- [x] [EdgeNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/edgenext)
|
||||
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/convnext)
|
||||
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/hrnet)
|
||||
- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/van)
|
||||
- [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/convmixer)
|
||||
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/cspnet)
|
||||
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/poolformer)
|
||||
- [x] [Inception V3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/inception_v3)
|
||||
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
|
||||
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
|
||||
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
|
||||
|
||||
</details>
|
||||
<div align="center">
|
||||
<b>Overview</b>
|
||||
</div>
|
||||
<table align="center">
|
||||
<tbody>
|
||||
<tr align="center" valign="bottom">
|
||||
<td>
|
||||
<b>Supported Backbones</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>Self-supervised Learning</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>Multi-Modality Algorithms</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>Others</b>
|
||||
</td>
|
||||
</tr>
|
||||
<tr valign="top">
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="configs/vgg">VGG</a></li>
|
||||
<li><a href="configs/resnet">ResNet</a></li>
|
||||
<li><a href="configs/resnext">ResNeXt</a></li>
|
||||
<li><a href="configs/seresnet">SE-ResNet</a></li>
|
||||
<li><a href="configs/seresnet">SE-ResNeXt</a></li>
|
||||
<li><a href="configs/regnet">RegNet</a></li>
|
||||
<li><a href="configs/shufflenet_v1">ShuffleNet V1</a></li>
|
||||
<li><a href="configs/shufflenet_v2">ShuffleNet V2</a></li>
|
||||
<li><a href="configs/mobilenet_v2">MobileNet V2</a></li>
|
||||
<li><a href="configs/mobilenet_v3">MobileNet V3</a></li>
|
||||
<li><a href="configs/swin_transformer">Swin-Transformer</a></li>
|
||||
<li><a href="configs/swin_transformer_v2">Swin-Transformer V2</a></li>
|
||||
<li><a href="configs/repvgg">RepVGG</a></li>
|
||||
<li><a href="configs/vision_transformer">Vision-Transformer</a></li>
|
||||
<li><a href="configs/tnt">Transformer-in-Transformer</a></li>
|
||||
<li><a href="configs/res2net">Res2Net</a></li>
|
||||
<li><a href="configs/mlp_mixer">MLP-Mixer</a></li>
|
||||
<li><a href="configs/deit">DeiT</a></li>
|
||||
<li><a href="configs/deit3">DeiT-3</a></li>
|
||||
<li><a href="configs/conformer">Conformer</a></li>
|
||||
<li><a href="configs/t2t_vit">T2T-ViT</a></li>
|
||||
<li><a href="configs/twins">Twins</a></li>
|
||||
<li><a href="configs/efficientnet">EfficientNet</a></li>
|
||||
<li><a href="configs/edgenext">EdgeNeXt</a></li>
|
||||
<li><a href="configs/convnext">ConvNeXt</a></li>
|
||||
<li><a href="configs/hrnet">HRNet</a></li>
|
||||
<li><a href="configs/van">VAN</a></li>
|
||||
<li><a href="configs/convmixer">ConvMixer</a></li>
|
||||
<li><a href="configs/cspnet">CSPNet</a></li>
|
||||
<li><a href="configs/poolformer">PoolFormer</a></li>
|
||||
<li><a href="configs/inception_v3">Inception V3</a></li>
|
||||
<li><a href="configs/mobileone">MobileOne</a></li>
|
||||
<li><a href="configs/efficientformer">EfficientFormer</a></li>
|
||||
<li><a href="configs/mvit">MViT</a></li>
|
||||
<li><a href="configs/hornet">HorNet</a></li>
|
||||
<li><a href="configs/mobilevit">MobileViT</a></li>
|
||||
<li><a href="configs/davit">DaViT</a></li>
|
||||
<li><a href="configs/replknet">RepLKNet</a></li>
|
||||
<li><a href="configs/beit">BEiT</a></li>
|
||||
<li><a href="configs/mixmim">MixMIM</a></li>
|
||||
<li><a href="configs/efficientnet_v2">EfficientNet V2</a></li>
|
||||
<li><a href="configs/revvit">RevViT</a></li>
|
||||
<li><a href="configs/convnext_v2">ConvNeXt V2</a></li>
|
||||
<li><a href="configs/vig">ViG</a></li>
|
||||
<li><a href="configs/xcit">XCiT</a></li>
|
||||
<li><a href="configs/levit">LeViT</a></li>
|
||||
<li><a href="configs/riformer">RIFormer</a></li>
|
||||
<li><a href="configs/glip">GLIP</a></li>
|
||||
<li><a href="configs/sam">ViT SAM</a></li>
|
||||
<li><a href="configs/eva02">EVA02</a></li>
|
||||
<li><a href="configs/dinov2">DINO V2</a></li>
|
||||
<li><a href="configs/hivit">HiViT</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="configs/mocov2">MoCo V1 (CVPR'2020)</a></li>
|
||||
<li><a href="configs/simclr">SimCLR (ICML'2020)</a></li>
|
||||
<li><a href="configs/mocov2">MoCo V2 (arXiv'2020)</a></li>
|
||||
<li><a href="configs/byol">BYOL (NeurIPS'2020)</a></li>
|
||||
<li><a href="configs/swav">SwAV (NeurIPS'2020)</a></li>
|
||||
<li><a href="configs/densecl">DenseCL (CVPR'2021)</a></li>
|
||||
<li><a href="configs/simsiam">SimSiam (CVPR'2021)</a></li>
|
||||
<li><a href="configs/barlowtwins">Barlow Twins (ICML'2021)</a></li>
|
||||
<li><a href="configs/mocov3">MoCo V3 (ICCV'2021)</a></li>
|
||||
<li><a href="configs/beit">BEiT (ICLR'2022)</a></li>
|
||||
<li><a href="configs/mae">MAE (CVPR'2022)</a></li>
|
||||
<li><a href="configs/simmim">SimMIM (CVPR'2022)</a></li>
|
||||
<li><a href="configs/maskfeat">MaskFeat (CVPR'2022)</a></li>
|
||||
<li><a href="configs/cae">CAE (arXiv'2022)</a></li>
|
||||
<li><a href="configs/milan">MILAN (arXiv'2022)</a></li>
|
||||
<li><a href="configs/beitv2">BEiT V2 (arXiv'2022)</a></li>
|
||||
<li><a href="configs/eva">EVA (CVPR'2023)</a></li>
|
||||
<li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
|
||||
<li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
|
||||
<li><a href="configs/spark">SparK (ICLR'2023)</a></li>
|
||||
<li><a href="configs/mff">MFF (ICCV'2023)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="configs/blip">BLIP (arxiv'2022)</a></li>
|
||||
<li><a href="configs/blip2">BLIP-2 (arxiv'2023)</a></li>
|
||||
<li><a href="configs/ofa">OFA (CoRR'2022)</a></li>
|
||||
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
|
||||
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
|
||||
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
|
||||
<li><a href="configs/llava">LLaVA (arxiv'2023)</a></li>
|
||||
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
Image Retrieval Task:
|
||||
<ul>
|
||||
<li><a href="configs/arcface">ArcFace (CVPR'2019)</a></li>
|
||||
</ul>
|
||||
Training&Test Tips:
|
||||
<ul>
|
||||
<li><a href="https://arxiv.org/abs/1909.13719">RandAug</a></li>
|
||||
<li><a href="https://arxiv.org/abs/1805.09501">AutoAug</a></li>
|
||||
<li><a href="mmpretrain/datasets/samplers/repeat_aug.py">RepeatAugSampler</a></li>
|
||||
<li><a href="mmpretrain/models/tta/score_tta.py">TTA</a></li>
|
||||
<li>...</li>
|
||||
</ul>
|
||||
</td>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
## Contributing
|
||||
|
||||
We appreciate all contributions to improve MMClassification.
|
||||
Please refer to [CONTRUBUTING.md](https://mmclassification.readthedocs.io/en/1.x/notes/contribution_guide.html) for the contributing guideline.
|
||||
We appreciate all contributions to improve MMPreTrain.
|
||||
Please refer to [CONTRUBUTING](https://mmpretrain.readthedocs.io/en/latest/notes/contribution_guide.html) for the contributing guideline.
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
MMClassification is an open source project that is contributed by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as users who give valuable feedbacks.
|
||||
We wish that the toolbox and benchmark could serve the growing research community by providing a flexible toolkit to reimplement existing methods and develop their own new classifiers.
|
||||
MMPreTrain is an open source project that is contributed by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as users who give valuable feedbacks.
|
||||
We wish that the toolbox and benchmark could serve the growing research community by providing a flexible toolkit to reimplement existing methods and supporting their own academic research.
|
||||
|
||||
## Citation
|
||||
|
||||
If you find this project useful in your research, please consider cite:
|
||||
|
||||
```BibTeX
|
||||
@misc{2020mmclassification,
|
||||
title={OpenMMLab's Image Classification Toolbox and Benchmark},
|
||||
author={MMClassification Contributors},
|
||||
howpublished = {\url{https://github.com/open-mmlab/mmclassification}},
|
||||
year={2020}
|
||||
@misc{2023mmpretrain,
|
||||
title={OpenMMLab's Pre-training Toolbox and Benchmark},
|
||||
author={MMPreTrain Contributors},
|
||||
howpublished = {\url{https://github.com/open-mmlab/mmpretrain}},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -177,10 +317,12 @@ This project is released under the [Apache 2.0 license](LICENSE).
|
|||
- [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab foundational library for training deep learning models.
|
||||
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
|
||||
- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
|
||||
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
|
||||
- [MMEval](https://github.com/open-mmlab/mmeval): A unified evaluation library for multiple machine learning libraries.
|
||||
- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab pre-training toolbox and benchmark.
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
|
||||
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
|
||||
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
|
||||
- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark.
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
|
||||
|
@ -191,6 +333,7 @@ This project is released under the [Apache 2.0 license](LICENSE).
|
|||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
|
||||
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
|
||||
- [MMagic](https://github.com/open-mmlab/mmagic): Open**MM**Lab **A**dvanced, **G**enerative and **I**ntelligent **C**reation toolbox.
|
||||
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.
|
||||
- [Playground](https://github.com/open-mmlab/playground): A central hub for gathering and showcasing amazing projects built upon OpenMMLab.
|
||||
|
|
329
README_zh-CN.md
329
README_zh-CN.md
|
@ -1,6 +1,6 @@
|
|||
<div align="center">
|
||||
|
||||
<img src="resources/mmcls-logo.png" width="600"/>
|
||||
<img src="resources/mmpt-logo.png" width="600"/>
|
||||
<div> </div>
|
||||
<div align="center">
|
||||
<b><font size="5">OpenMMLab 官网</font></b>
|
||||
|
@ -19,59 +19,100 @@
|
|||
</div>
|
||||
<div> </div>
|
||||
|
||||
[](https://pypi.org/project/mmcls)
|
||||
[](https://mmclassification.readthedocs.io/zh_CN/1.x/)
|
||||
[](https://github.com/open-mmlab/mmclassification/actions)
|
||||
[](https://codecov.io/gh/open-mmlab/mmclassification)
|
||||
[](https://github.com/open-mmlab/mmclassification/blob/1.x/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmclassification/issues)
|
||||
[](https://github.com/open-mmlab/mmclassification/issues)
|
||||
[](https://pypi.org/project/mmpretrain)
|
||||
[](https://mmpretrain.readthedocs.io/zh_CN/latest/)
|
||||
[](https://github.com/open-mmlab/mmpretrain/actions)
|
||||
[](https://codecov.io/gh/open-mmlab/mmpretrain)
|
||||
[](https://github.com/open-mmlab/mmpretrain/blob/main/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmpretrain/issues)
|
||||
[](https://github.com/open-mmlab/mmpretrain/issues)
|
||||
|
||||
[📘 中文文档](https://mmclassification.readthedocs.io/zh_CN/1.x/) |
|
||||
[🛠️ 安装教程](https://mmclassification.readthedocs.io/zh_CN/1.x/get_started.html) |
|
||||
[👀 模型库](https://mmclassification.readthedocs.io/zh_CN/1.x/modelzoo_statistics.html) |
|
||||
[🆕 更新日志](https://mmclassification.readthedocs.io/en/1.x/notes/changelog.html) |
|
||||
[🤔 报告问题](https://github.com/open-mmlab/mmclassification/issues/new/choose)
|
||||
[📘 中文文档](https://mmpretrain.readthedocs.io/zh_CN/latest/) |
|
||||
[🛠️ 安装教程](https://mmpretrain.readthedocs.io/zh_CN/latest/get_started.html) |
|
||||
[👀 模型库](https://mmpretrain.readthedocs.io/zh_CN/latest/modelzoo_statistics.html) |
|
||||
[🆕 更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html) |
|
||||
[🤔 报告问题](https://github.com/open-mmlab/mmpretrain/issues/new/choose)
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/36138628/230307505-4727ad0a-7d71-4069-939d-b499c7e272b7.png" width="400"/>
|
||||
|
||||
[English](/README.md) | 简体中文
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://openmmlab.medium.com/" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://discord.gg/raweFPmdzG" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
|
||||
</div>
|
||||
|
||||
## Introduction
|
||||
|
||||
[English](/README.md) | 简体中文
|
||||
MMPreTrain 是一款基于 PyTorch 的开源深度学习预训练工具箱,是 [OpenMMLab](https://openmmlab.com/) 项目的成员之一
|
||||
|
||||
MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [OpenMMLab](https://openmmlab.com/) 项目的成员之一
|
||||
|
||||
主分支代码目前支持 PyTorch 1.5 以上的版本。
|
||||
|
||||
<div align="center">
|
||||
<img src="https://user-images.githubusercontent.com/9102141/87268895-3e0d0780-c4fe-11ea-849e-6140b7e0d4de.gif" width="70%"/>
|
||||
</div>
|
||||
`主分支`代码目前支持 PyTorch 1.8 以上的版本。
|
||||
|
||||
### 主要特性
|
||||
|
||||
- 支持多样的主干网络与预训练模型
|
||||
- 支持配置多种训练技巧
|
||||
- 支持多种训练策略(有监督学习,无监督学习,多模态学习等)
|
||||
- 提供多种训练技巧
|
||||
- 大量的训练配置文件
|
||||
- 高效率和高可扩展性
|
||||
- 功能强大的工具箱
|
||||
- 功能强大的工具箱,有助于模型分析和实验
|
||||
- 支持多种开箱即用的推理任务
|
||||
- 图像分类
|
||||
- 图像描述(Image Caption)
|
||||
- 视觉问答(Visual Question Answering)
|
||||
- 视觉定位(Visual Grounding)
|
||||
- 检索(图搜图,图搜文,文搜图)
|
||||
|
||||
https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351-fbc74a04e904
|
||||
|
||||
## 更新日志
|
||||
|
||||
2022/9/30 发布了 v1.0.0rc1 版本
|
||||
🌟 2024/01/04 发布了 v1.2.0 版本
|
||||
|
||||
- 支持了 MViT,EdgeNeXt,Swin-Transformer V2,EfficientFormer,MobileOne 等主干网络。
|
||||
- 支持了 BEiT 风格的 transformer 层。
|
||||
- 支持了 LLaVA 1.5
|
||||
- 实现了一个 RAM 模型的 gradio 推理例程
|
||||
|
||||
2022/8/31 发布了 v1.0.0rc0 版本
|
||||
🌟 2023/10/12 发布了 v1.1.0 版本
|
||||
|
||||
这个版本引入一个全新的,可扩展性强的训练和测试引擎,但目前仍在开发中。欢迎根据[文档](https://mmclassification.readthedocs.io/zh_CN/1.x/)进行试用。
|
||||
- 支持 Mini-GPT4 训练并提供一个基于 Baichuan-7B 的中文模型
|
||||
- 支持基于 CLIP 的零样本分类。
|
||||
|
||||
同时,新版本中存在一些与旧版本不兼容的修改。请查看[迁移文档](https://mmclassification.readthedocs.io/zh_CN/1.x/migration.html)来详细了解这些变动。
|
||||
🌟 2023/7/4 发布了 v1.0.0 版本
|
||||
|
||||
新版本的公测将持续到 2022 年末,在此期间,我们将基于 `1.x` 分支进行更新,不会合入到 `master` 分支。另外,至少
|
||||
到 2023 年末,我们会保持对 0.x 版本的维护。
|
||||
- 支持更多**多模态**算法的推理, 例如 [**LLaVA**](./configs/llava/), [**MiniGPT-4**](./configs/minigpt4), [**Otter**](./configs/otter/) 等。
|
||||
- 支持约 **10 个多模态**数据集!
|
||||
- 添加自监督学习算法 [**iTPN**](./configs/itpn/), [**SparK**](./configs/spark/)。
|
||||
- 提供[新配置文件](./mmpretrain/configs/)和 [DeepSpeed/FSDP](./configs/mae/benchmarks/) 的样例。这是[新配置文件](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta) 和 [DeepSpeed/FSDP with FlexibleRunner](https://mmengine.readthedocs.io/en/latest/api/generated/mmengine.runner.FlexibleRunner.html#mmengine.runner.FlexibleRunner) 的文档链接。
|
||||
|
||||
发布历史和更新细节请参考 [更新日志](https://mmclassification.readthedocs.io/zh_CN/1.x/notes/changelog.html)
|
||||
🌟 从 MMClassification 升级到 MMPreTrain
|
||||
|
||||
- 整合来自 MMSelfSup 的自监督学习算法,例如 `MAE`, `BEiT` 等
|
||||
- 支持了 **RIFormer**,简单但有效的视觉主干网络,却移除了 token mixer
|
||||
- 重构数据管道可视化
|
||||
- 支持了 **LeViT**, **XCiT**, **ViG**, **ConvNeXt-V2**, **EVA**, **RevViT**, **EfficientnetV2**, **CLIP**, **TinyViT** 和 **MixMIM** 等骨干网络结构
|
||||
|
||||
这个版本引入一个全新的,可扩展性强的训练和测试引擎,但目前仍在开发中。欢迎根据 [文档](https://mmpretrain.readthedocs.io/zh_CN/latest/) 进行试用。
|
||||
|
||||
同时,新版本中存在一些与旧版本不兼容的修改。请查看 [迁移文档](https://mmpretrain.readthedocs.io/zh_CN/latest/migration.html) 来详细了解这些变动。
|
||||
|
||||
发布历史和更新细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。
|
||||
|
||||
## 安装
|
||||
|
||||
|
@ -81,89 +122,184 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
|
|||
conda create -n open-mmlab python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
|
||||
conda activate open-mmlab
|
||||
pip3 install openmim
|
||||
git clone -b 1.x https://github.com/open-mmlab/mmclassification.git
|
||||
cd mmclassification
|
||||
git clone https://github.com/open-mmlab/mmpretrain.git
|
||||
cd mmpretrain
|
||||
mim install -e .
|
||||
```
|
||||
|
||||
更详细的步骤请参考 [安装指南](https://mmclassification.readthedocs.io/zh_CN/1.x/get_started.html) 进行安装。
|
||||
更详细的步骤请参考 [安装指南](https://mmpretrain.readthedocs.io/zh_CN/latest/get_started.html) 进行安装。
|
||||
|
||||
如果需要多模态模型,请使用如下方式安装额外的依赖:
|
||||
|
||||
```shell
|
||||
mim install -e ".[multimodal]"
|
||||
```
|
||||
|
||||
## 基础教程
|
||||
|
||||
我们为新用户提供了一系列基础教程:
|
||||
|
||||
- [使用现有模型推理](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/inference.html)
|
||||
- [准备数据集](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/dataset_prepare.html)
|
||||
- [训练与测试](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/train_test.html)
|
||||
- [学习配置文件](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/config.html)
|
||||
- [如何微调模型](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/finetune.html)
|
||||
- [分析工具](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/analysis.html)
|
||||
- [可视化工具](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/visualization.html)
|
||||
- [其他工具](https://mmclassification.readthedocs.io/zh_CN/1.x/user_guides/useful_tools.html)
|
||||
- [学习配置文件](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/config.html)
|
||||
- [准备数据集](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/dataset_prepare.html)
|
||||
- [使用现有模型推理](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/inference.html)
|
||||
- [训练](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/train.html)
|
||||
- [测试](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/test.html)
|
||||
- [下游任务](https://mmpretrain.readthedocs.io/zh_CN/latest/user_guides/downstream.html)
|
||||
|
||||
关于更多的信息,请查阅我们的 [相关文档](https://mmpretrain.readthedocs.io/zh_CN/latest/)。
|
||||
|
||||
## 模型库
|
||||
|
||||
相关结果和模型可在 [model zoo](https://mmclassification.readthedocs.io/zh_CN/1.x/modelzoo_statistics.html) 中获得
|
||||
相关结果和模型可在 [模型库](https://mmpretrain.readthedocs.io/zh_CN/latest/modelzoo_statistics.html) 中获得。
|
||||
|
||||
<details open>
|
||||
<summary>支持的主干网络</summary>
|
||||
|
||||
- [x] [VGG](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/vgg)
|
||||
- [x] [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet)
|
||||
- [x] [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext)
|
||||
- [x] [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet)
|
||||
- [x] [SE-ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet)
|
||||
- [x] [RegNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/regnet)
|
||||
- [x] [ShuffleNetV1](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/shufflenet_v1)
|
||||
- [x] [ShuffleNetV2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/shufflenet_v2)
|
||||
- [x] [MobileNetV2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilenet_v2)
|
||||
- [x] [MobileNetV3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilenet_v3)
|
||||
- [x] [Swin-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/swin_transformer)
|
||||
- [x] [Swin-Transformer V2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/swin_transformer_v2)
|
||||
- [x] [RepVGG](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/repvgg)
|
||||
- [x] [Vision-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/vision_transformer)
|
||||
- [x] [Transformer-in-Transformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/tnt)
|
||||
- [x] [Res2Net](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/res2net)
|
||||
- [x] [MLP-Mixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mlp_mixer)
|
||||
- [x] [DeiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/deit)
|
||||
- [x] [Conformer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/conformer)
|
||||
- [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/t2t_vit)
|
||||
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/twins)
|
||||
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientnet)
|
||||
- [x] [EdgeNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/edgenext)
|
||||
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/convnext)
|
||||
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/hrnet)
|
||||
- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/van)
|
||||
- [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/convmixer)
|
||||
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/cspnet)
|
||||
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/poolformer)
|
||||
- [x] [Inception V3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/inception_v3)
|
||||
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)
|
||||
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/efficientformer)
|
||||
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
|
||||
|
||||
</details>
|
||||
<div align="center">
|
||||
<b>概览</b>
|
||||
</div>
|
||||
<table align="center">
|
||||
<tbody>
|
||||
<tr align="center" valign="bottom">
|
||||
<td>
|
||||
<b>支持的主干网络</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>自监督学习</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>多模态算法</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>其它</b>
|
||||
</td>
|
||||
</tr>
|
||||
<tr valign="top">
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="configs/vgg">VGG</a></li>
|
||||
<li><a href="configs/resnet">ResNet</a></li>
|
||||
<li><a href="configs/resnext">ResNeXt</a></li>
|
||||
<li><a href="configs/seresnet">SE-ResNet</a></li>
|
||||
<li><a href="configs/seresnet">SE-ResNeXt</a></li>
|
||||
<li><a href="configs/regnet">RegNet</a></li>
|
||||
<li><a href="configs/shufflenet_v1">ShuffleNet V1</a></li>
|
||||
<li><a href="configs/shufflenet_v2">ShuffleNet V2</a></li>
|
||||
<li><a href="configs/mobilenet_v2">MobileNet V2</a></li>
|
||||
<li><a href="configs/mobilenet_v3">MobileNet V3</a></li>
|
||||
<li><a href="configs/swin_transformer">Swin-Transformer</a></li>
|
||||
<li><a href="configs/swin_transformer_v2">Swin-Transformer V2</a></li>
|
||||
<li><a href="configs/repvgg">RepVGG</a></li>
|
||||
<li><a href="configs/vision_transformer">Vision-Transformer</a></li>
|
||||
<li><a href="configs/tnt">Transformer-in-Transformer</a></li>
|
||||
<li><a href="configs/res2net">Res2Net</a></li>
|
||||
<li><a href="configs/mlp_mixer">MLP-Mixer</a></li>
|
||||
<li><a href="configs/deit">DeiT</a></li>
|
||||
<li><a href="configs/deit3">DeiT-3</a></li>
|
||||
<li><a href="configs/conformer">Conformer</a></li>
|
||||
<li><a href="configs/t2t_vit">T2T-ViT</a></li>
|
||||
<li><a href="configs/twins">Twins</a></li>
|
||||
<li><a href="configs/efficientnet">EfficientNet</a></li>
|
||||
<li><a href="configs/edgenext">EdgeNeXt</a></li>
|
||||
<li><a href="configs/convnext">ConvNeXt</a></li>
|
||||
<li><a href="configs/hrnet">HRNet</a></li>
|
||||
<li><a href="configs/van">VAN</a></li>
|
||||
<li><a href="configs/convmixer">ConvMixer</a></li>
|
||||
<li><a href="configs/cspnet">CSPNet</a></li>
|
||||
<li><a href="configs/poolformer">PoolFormer</a></li>
|
||||
<li><a href="configs/inception_v3">Inception V3</a></li>
|
||||
<li><a href="configs/mobileone">MobileOne</a></li>
|
||||
<li><a href="configs/efficientformer">EfficientFormer</a></li>
|
||||
<li><a href="configs/mvit">MViT</a></li>
|
||||
<li><a href="configs/hornet">HorNet</a></li>
|
||||
<li><a href="configs/mobilevit">MobileViT</a></li>
|
||||
<li><a href="configs/davit">DaViT</a></li>
|
||||
<li><a href="configs/replknet">RepLKNet</a></li>
|
||||
<li><a href="configs/beit">BEiT</a></li>
|
||||
<li><a href="configs/mixmim">MixMIM</a></li>
|
||||
<li><a href="configs/revvit">RevViT</a></li>
|
||||
<li><a href="configs/convnext_v2">ConvNeXt V2</a></li>
|
||||
<li><a href="configs/vig">ViG</a></li>
|
||||
<li><a href="configs/xcit">XCiT</a></li>
|
||||
<li><a href="configs/levit">LeViT</a></li>
|
||||
<li><a href="configs/riformer">RIFormer</a></li>
|
||||
<li><a href="configs/glip">GLIP</a></li>
|
||||
<li><a href="configs/sam">ViT SAM</a></li>
|
||||
<li><a href="configs/eva02">EVA02</a></li>
|
||||
<li><a href="configs/dinov2">DINO V2</a></li>
|
||||
<li><a href="configs/hivit">HiViT</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="configs/mocov2">MoCo V1 (CVPR'2020)</a></li>
|
||||
<li><a href="configs/simclr">SimCLR (ICML'2020)</a></li>
|
||||
<li><a href="configs/mocov2">MoCo V2 (arXiv'2020)</a></li>
|
||||
<li><a href="configs/byol">BYOL (NeurIPS'2020)</a></li>
|
||||
<li><a href="configs/swav">SwAV (NeurIPS'2020)</a></li>
|
||||
<li><a href="configs/densecl">DenseCL (CVPR'2021)</a></li>
|
||||
<li><a href="configs/simsiam">SimSiam (CVPR'2021)</a></li>
|
||||
<li><a href="configs/barlowtwins">Barlow Twins (ICML'2021)</a></li>
|
||||
<li><a href="configs/mocov3">MoCo V3 (ICCV'2021)</a></li>
|
||||
<li><a href="configs/beit">BEiT (ICLR'2022)</a></li>
|
||||
<li><a href="configs/mae">MAE (CVPR'2022)</a></li>
|
||||
<li><a href="configs/simmim">SimMIM (CVPR'2022)</a></li>
|
||||
<li><a href="configs/maskfeat">MaskFeat (CVPR'2022)</a></li>
|
||||
<li><a href="configs/cae">CAE (arXiv'2022)</a></li>
|
||||
<li><a href="configs/milan">MILAN (arXiv'2022)</a></li>
|
||||
<li><a href="configs/beitv2">BEiT V2 (arXiv'2022)</a></li>
|
||||
<li><a href="configs/eva">EVA (CVPR'2023)</a></li>
|
||||
<li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
|
||||
<li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
|
||||
<li><a href="configs/spark">SparK (ICLR'2023)</a></li>
|
||||
<li><a href="configs/mff">MFF (ICCV'2023)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="configs/blip">BLIP (arxiv'2022)</a></li>
|
||||
<li><a href="configs/blip2">BLIP-2 (arxiv'2023)</a></li>
|
||||
<li><a href="configs/ofa">OFA (CoRR'2022)</a></li>
|
||||
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
|
||||
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
|
||||
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
|
||||
<li><a href="configs/llava">LLaVA (arxiv'2023)</a></li>
|
||||
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
图像检索任务:
|
||||
<ul>
|
||||
<li><a href="configs/arcface">ArcFace (CVPR'2019)</a></li>
|
||||
</ul>
|
||||
训练和测试 Tips:
|
||||
<ul>
|
||||
<li><a href="https://arxiv.org/abs/1909.13719">RandAug</a></li>
|
||||
<li><a href="https://arxiv.org/abs/1805.09501">AutoAug</a></li>
|
||||
<li><a href="mmpretrain/datasets/samplers/repeat_aug.py">RepeatAugSampler</a></li>
|
||||
<li><a href="mmpretrain/models/tta/score_tta.py">TTA</a></li>
|
||||
<li>...</li>
|
||||
</ul>
|
||||
</td>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
## 参与贡献
|
||||
|
||||
我们非常欢迎任何有助于提升 MMClassification 的贡献,请参考 [贡献指南](https://mmclassification.readthedocs.io/zh_CN/1.x/notes/contribution_guide.html) 来了解如何参与贡献。
|
||||
我们非常欢迎任何有助于提升 MMPreTrain 的贡献,请参考 [贡献指南](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/contribution_guide.html) 来了解如何参与贡献。
|
||||
|
||||
## 致谢
|
||||
|
||||
MMClassification 是一款由不同学校和公司共同贡献的开源项目。我们感谢所有为项目提供算法复现和新功能支持的贡献者,以及提供宝贵反馈的用户。
|
||||
|
||||
MMPreTrain 是一款由不同学校和公司共同贡献的开源项目。我们感谢所有为项目提供算法复现和新功能支持的贡献者,以及提供宝贵反馈的用户。
|
||||
我们希望该工具箱和基准测试可以为社区提供灵活的代码工具,供用户复现现有算法并开发自己的新模型,从而不断为开源社区提供贡献。
|
||||
|
||||
## 引用
|
||||
|
||||
如果你在研究中使用了本项目的代码或者性能基准,请参考如下 bibtex 引用 MMClassification。
|
||||
如果你在研究中使用了本项目的代码或者性能基准,请参考如下 bibtex 引用 MMPreTrain。
|
||||
|
||||
```BibTeX
|
||||
@misc{2020mmclassification,
|
||||
title={OpenMMLab's Image Classification Toolbox and Benchmark},
|
||||
author={MMClassification Contributors},
|
||||
howpublished = {\url{https://github.com/open-mmlab/mmclassification}},
|
||||
year={2020}
|
||||
@misc{2023mmpretrain,
|
||||
title={OpenMMLab's Pre-training Toolbox and Benchmark},
|
||||
author={MMPreTrain Contributors},
|
||||
howpublished = {\url{https://github.com/open-mmlab/mmpretrain}},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -176,10 +312,12 @@ MMClassification 是一款由不同学校和公司共同贡献的开源项目。
|
|||
- [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab 深度学习模型训练基础库
|
||||
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab 计算机视觉基础库
|
||||
- [MIM](https://github.com/open-mmlab/mim): MIM 是 OpenMMlab 项目、算法、模型的统一入口
|
||||
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab 图像分类工具箱
|
||||
- [MMEval](https://github.com/open-mmlab/mmeval): 统一开放的跨框架算法评测库
|
||||
- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab 深度学习预训练工具箱
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱
|
||||
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台
|
||||
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准
|
||||
- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO 系列工具箱与测试基准
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具包
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱
|
||||
|
@ -190,16 +328,17 @@ MMClassification 是一款由不同学校和公司共同贡献的开源项目。
|
|||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab 新一代视频理解工具箱
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准
|
||||
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab 图像视频编辑工具箱
|
||||
- [MMagic](https://github.com/open-mmlab/mmagic): OpenMMLab 新一代人工智能内容生成(AIGC)工具箱
|
||||
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab 图片视频生成模型工具箱
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab 模型部署框架
|
||||
- [Playground](https://github.com/open-mmlab/playground): 收集和展示 OpenMMLab 相关的前沿、有趣的社区项目
|
||||
|
||||
## 欢迎加入 OpenMMLab 社区
|
||||
|
||||
扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),加入 OpenMMLab 团队的 [官方交流 QQ 群](https://jq.qq.com/?_wv=1027&k=aCvMxdr3) 或联络 OpenMMLab 官方微信小助手
|
||||
扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),扫描下方微信二维码添加喵喵好友,进入 MMPretrain 微信交流社群。【加好友申请格式:研究方向+地区+学校/公司+姓名】
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/open-mmlab/mmcv/raw/master/docs/en/_static/zhihu_qrcode.jpg" height="400" /> <img src="https://github.com/open-mmlab/mmcv/raw/master/docs/en/_static/qq_group_qrcode.jpg" height="400" /> <img src="https://github.com/open-mmlab/mmcv/raw/master/docs/en/_static/wechat_qrcode.jpg" height="400" />
|
||||
<img src="./resources/zhihu_qrcode.jpg" height="400"/> <img src="./resources/miaomiao_qrcode.jpg" height="400"/>
|
||||
</div>
|
||||
|
||||
我们会在 OpenMMLab 社区为大家
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'CIFAR100'
|
||||
data_preprocessor = dict(
|
||||
num_classes=100,
|
||||
# RGB format normalization parameters
|
||||
mean=[129.304, 124.070, 112.434],
|
||||
std=[68.170, 65.392, 70.418],
|
||||
|
@ -10,11 +11,11 @@ data_preprocessor = dict(
|
|||
train_pipeline = [
|
||||
dict(type='RandomCrop', crop_size=32, padding=4),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -22,11 +23,10 @@ train_dataloader = dict(
|
|||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/cifar100',
|
||||
test_mode=False,
|
||||
data_root='data/cifar100',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -34,11 +34,10 @@ val_dataloader = dict(
|
|||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/cifar100/',
|
||||
test_mode=True,
|
||||
data_root='data/cifar100/',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, ))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'CIFAR10'
|
||||
data_preprocessor = dict(
|
||||
num_classes=10,
|
||||
# RGB format normalization parameters
|
||||
mean=[125.307, 122.961, 113.8575],
|
||||
std=[51.5865, 50.847, 51.255],
|
||||
|
@ -10,11 +11,11 @@ data_preprocessor = dict(
|
|||
train_pipeline = [
|
||||
dict(type='RandomCrop', crop_size=32, padding=4),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -22,11 +23,10 @@ train_dataloader = dict(
|
|||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/cifar10',
|
||||
test_mode=False,
|
||||
data_root='data/cifar10',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -34,11 +34,10 @@ val_dataloader = dict(
|
|||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/cifar10/',
|
||||
test_mode=True,
|
||||
data_root='data/cifar10/',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, ))
|
||||
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# data settings
|
||||
# coco caption annotations can be grabbed from LAVIS repo
|
||||
# https://github.com/salesforce/LAVIS/blob/main/lavis/configs/datasets/coco/defaults_cap.yaml
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='CleanCaption', keys='gt_caption'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['gt_caption'],
|
||||
meta_keys=['image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(384, 384),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='PackInputs', meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='COCOCaption',
|
||||
data_root='data/coco',
|
||||
ann_file='annotations/coco_karpathy_train.json',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='COCOCaption',
|
||||
data_root='data/coco',
|
||||
ann_file='annotations/coco_karpathy_val.json',
|
||||
pipeline=test_pipeline,
|
||||
),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_evaluator = dict(
|
||||
type='COCOCaption',
|
||||
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
|
||||
)
|
||||
|
||||
# # If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,75 @@
|
|||
# data settings
|
||||
|
||||
data_preprocessor = dict(
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||
meta_keys=['question_id', 'image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(480, 480),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='CleanCaption',
|
||||
keys=['question'],
|
||||
),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||
meta_keys=['question_id', 'image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='COCOVQA',
|
||||
data_root='data/coco',
|
||||
data_prefix='train2014',
|
||||
question_file=
|
||||
'annotations/okvqa_OpenEnded_mscoco_train2014_questions.json',
|
||||
ann_file='annotations/okvqa_mscoco_train2014_annotations.json',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='COCOVQA',
|
||||
data_root='data/coco',
|
||||
data_prefix='val2014',
|
||||
question_file=
|
||||
'annotations/okvqa_OpenEnded_mscoco_val2014_questions.json',
|
||||
ann_file='annotations/okvqa_mscoco_val2014_annotations.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_evaluator = dict(type='VQAAcc')
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,99 @@
|
|||
# data settings
|
||||
# Here are the links to download the annotations for coco retrieval for conveniency # noqa
|
||||
# https://download.openmmlab.com/mmclassification/datasets/coco_retrieval/caption_karpathy_train2014.json
|
||||
# https://download.openmmlab.com/mmclassification/datasets/coco_retrieval/caption_karpathy_val2014.json
|
||||
# https://download.openmmlab.com/mmclassification/datasets/coco_retrieval/caption_karpathy_test2014.json
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
rand_increasing_policies = [
|
||||
dict(type='AutoContrast'),
|
||||
dict(type='Equalize'),
|
||||
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
|
||||
dict(
|
||||
type='Brightness', magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.0)),
|
||||
dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)),
|
||||
dict(
|
||||
type='Shear',
|
||||
magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.3),
|
||||
direction='horizontal'),
|
||||
dict(
|
||||
type='Shear',
|
||||
magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.3),
|
||||
direction='vertical'),
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
crop_ratio_range=(0.5, 1.0),
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies=rand_increasing_policies,
|
||||
num_policies=2,
|
||||
magnitude_level=5),
|
||||
dict(type='CleanCaption', keys='text'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text', 'is_matched'],
|
||||
meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(384, 384),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='CleanCaption', keys='text'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text', 'gt_text_id', 'gt_image_id'],
|
||||
meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=16,
|
||||
dataset=dict(
|
||||
type='COCORetrieval',
|
||||
data_root='data/coco',
|
||||
ann_file='annotations/caption_karpathy_train2014.json',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=16,
|
||||
dataset=dict(
|
||||
type='COCORetrieval',
|
||||
data_root='data/coco',
|
||||
ann_file='annotations/caption_karpathy_val2014.json',
|
||||
pipeline=test_pipeline,
|
||||
# This is required for evaluation
|
||||
test_mode=True,
|
||||
),
|
||||
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,96 @@
|
|||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=(480, 480),
|
||||
crop_ratio_range=(0.5, 1.0),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='simple_increasing', # slightly different from LAVIS
|
||||
num_policies=2,
|
||||
magnitude_level=5),
|
||||
dict(type='CleanCaption', keys=['question', 'gt_answer']),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight']),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(480, 480),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='CleanCaption', keys=['question']),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question'],
|
||||
meta_keys=['question_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='ConcatDataset',
|
||||
datasets=[
|
||||
# VQAv2 train
|
||||
dict(
|
||||
type='COCOVQA',
|
||||
data_root='data/coco',
|
||||
data_prefix='train2014',
|
||||
question_file=
|
||||
'annotations/v2_OpenEnded_mscoco_train2014_questions.json',
|
||||
ann_file='annotations/v2_mscoco_train2014_annotations.json',
|
||||
pipeline=train_pipeline,
|
||||
),
|
||||
# VQAv2 val
|
||||
dict(
|
||||
type='COCOVQA',
|
||||
data_root='data/coco',
|
||||
data_prefix='val2014',
|
||||
question_file=
|
||||
'annotations/v2_OpenEnded_mscoco_val2014_questions.json',
|
||||
ann_file='annotations/v2_mscoco_val2014_annotations.json',
|
||||
pipeline=train_pipeline,
|
||||
),
|
||||
# Visual Genome
|
||||
dict(
|
||||
type='VisualGenomeQA',
|
||||
data_root='visual_genome',
|
||||
data_prefix='image',
|
||||
ann_file='question_answers.json',
|
||||
pipeline=train_pipeline,
|
||||
)
|
||||
]),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='COCOVQA',
|
||||
data_root='data/coco',
|
||||
data_prefix='test2015',
|
||||
question_file=
|
||||
'annotations/v2_OpenEnded_mscoco_test2015_questions.json', # noqa: E501
|
||||
pipeline=test_pipeline,
|
||||
),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='ReportVQA', file_path='vqa_test.json')
|
|
@ -0,0 +1,84 @@
|
|||
# data settings
|
||||
|
||||
data_preprocessor = dict(
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||
meta_keys=['question_id', 'image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(480, 480),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='CleanCaption',
|
||||
keys=['question'],
|
||||
),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||
meta_keys=['question_id', 'image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='COCOVQA',
|
||||
data_root='data/coco',
|
||||
data_prefix='train2014',
|
||||
question_file=
|
||||
'annotations/v2_OpenEnded_mscoco_train2014_questions.json',
|
||||
ann_file='annotations/v2_mscoco_train2014_annotations.json',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='COCOVQA',
|
||||
data_root='data/coco',
|
||||
data_prefix='val2014',
|
||||
question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
|
||||
ann_file='annotations/v2_mscoco_val2014_annotations.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='VQAAcc')
|
||||
|
||||
test_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='COCOVQA',
|
||||
data_root='data/coco',
|
||||
data_prefix='test2015',
|
||||
question_file= # noqa: E251
|
||||
'annotations/v2_OpenEnded_mscoco_test2015_questions.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='ReportVQA', file_path='vqa_test.json')
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'CUB'
|
||||
data_preprocessor = dict(
|
||||
num_classes=200,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -13,14 +14,14 @@ train_pipeline = [
|
|||
dict(type='Resize', scale=510),
|
||||
dict(type='RandomCrop', crop_size=384),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=510),
|
||||
dict(type='CenterCrop', crop_size=384),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -29,10 +30,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
test_mode=False,
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -41,10 +41,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
test_mode=True,
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, ))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'CUB'
|
||||
data_preprocessor = dict(
|
||||
num_classes=200,
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
|
@ -12,14 +13,14 @@ train_pipeline = [
|
|||
dict(type='Resize', scale=600),
|
||||
dict(type='RandomCrop', crop_size=448),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=600),
|
||||
dict(type='CenterCrop', crop_size=448),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -28,10 +29,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
test_mode=False,
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -40,10 +40,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
test_mode=True,
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, ))
|
||||
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# data settings
|
||||
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='CleanCaption', keys='gt_caption'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['gt_caption'],
|
||||
meta_keys=['image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(384, 384),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='PackInputs', meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='Flickr30kCaption',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='Flickr30kCaption',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='val',
|
||||
pipeline=test_pipeline,
|
||||
),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
# refer tools/dataset_converters/convert_flickr30k_ann.py
|
||||
val_evaluator = dict(
|
||||
type='COCOCaption',
|
||||
ann_file='data/flickr30k_val_gt.json',
|
||||
)
|
||||
|
||||
# # If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='Flickr30kCaption',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='test',
|
||||
pipeline=test_pipeline,
|
||||
),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
# refer tools/dataset_converters/convert_flickr30k_ann.py
|
||||
test_evaluator = dict(
|
||||
type='COCOCaption',
|
||||
ann_file='data/flickr30k_test_gt.json',
|
||||
)
|
|
@ -0,0 +1,112 @@
|
|||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
rand_increasing_policies = [
|
||||
dict(type='AutoContrast'),
|
||||
dict(type='Equalize'),
|
||||
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
|
||||
dict(
|
||||
type='Brightness', magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.0)),
|
||||
dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)),
|
||||
dict(
|
||||
type='Shear',
|
||||
magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.3),
|
||||
direction='horizontal'),
|
||||
dict(
|
||||
type='Shear',
|
||||
magnitude_key='magnitude',
|
||||
magnitude_range=(0, 0.3),
|
||||
direction='vertical'),
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
crop_ratio_range=(0.5, 1.0),
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies=rand_increasing_policies,
|
||||
num_policies=2,
|
||||
magnitude_level=5),
|
||||
dict(type='CleanCaption', keys='text'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text', 'is_matched'],
|
||||
meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(384, 384),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='CleanCaption', keys='text'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text', 'gt_text_id', 'gt_image_id'],
|
||||
meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=16,
|
||||
dataset=dict(
|
||||
type='Flickr30kRetrieval',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=16,
|
||||
dataset=dict(
|
||||
type='Flickr30kRetrieval',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='val',
|
||||
pipeline=test_pipeline,
|
||||
test_mode=True, # This is required for evaluation
|
||||
),
|
||||
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=16,
|
||||
dataset=dict(
|
||||
type='Flickr30kRetrieval',
|
||||
data_root='data/flickr30k',
|
||||
ann_file='annotations/dataset_flickr30k.json',
|
||||
data_prefix='images',
|
||||
split='test',
|
||||
pipeline=test_pipeline,
|
||||
test_mode=True, # This is required for evaluation
|
||||
),
|
||||
sampler=dict(type='SequentialSampler', subsample_type='sequential'),
|
||||
persistent_workers=True,
|
||||
)
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,81 @@
|
|||
# data settings
|
||||
|
||||
data_preprocessor = dict(
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||
meta_keys=['question_id', 'image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(480, 480),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='CleanCaption',
|
||||
keys=['question'],
|
||||
),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
|
||||
meta_keys=['question_id', 'image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='GQA',
|
||||
data_root='data/gqa',
|
||||
data_prefix='images',
|
||||
ann_file='annotations/train_balanced_questions.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='GQA',
|
||||
data_root='data/gqa',
|
||||
data_prefix='images',
|
||||
ann_file='annotations/testdev_balanced_questions.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='GQAAcc')
|
||||
|
||||
test_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='GQA',
|
||||
data_root='data/gqa',
|
||||
data_prefix='images',
|
||||
ann_file='annotations/testdev_balanced_questions.json',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
test_evaluator = val_evaluator
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet21k'
|
||||
data_preprocessor = dict(
|
||||
num_classes=21842,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -12,14 +13,7 @@ train_pipeline = [
|
|||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', scale=224),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='ResizeEdge', scale=256, edge='short'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -28,27 +22,7 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet21k',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=128,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet21k',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -27,14 +28,14 @@ train_pipeline = [
|
|||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -43,11 +44,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -56,11 +55,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -36,7 +37,7 @@ train_pipeline = [
|
|||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -48,7 +49,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -57,11 +58,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -70,11 +69,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -36,7 +37,7 @@ train_pipeline = [
|
|||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -48,7 +49,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -57,11 +58,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -70,11 +69,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=7,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand', # should be 'pixel', but currently not supported
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=256,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,80 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=404,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=384),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=128,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,80 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=426,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=384),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=128,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,80 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[127.5, 127.5, 127.5],
|
||||
std=[127.5, 127.5, 127.5],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=248,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=128,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=128,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,60 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=196,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=196,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=196),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,60 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=336,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=336,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=336),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,62 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=448,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=448,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=448),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,60 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=560,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=560,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=560),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,53 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=384, backend='pillow', interpolation='bicubic'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,47 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='TwoNormDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
second_mean=[127.5, 127.5, 127.5],
|
||||
second_std=[127.5, 127.5, 127.5],
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandomResizedCropAndInterpolationWithTwoPic',
|
||||
size=224,
|
||||
second_size=224,
|
||||
interpolation='bicubic',
|
||||
second_interpolation='bicubic',
|
||||
scale=(0.2, 1.0)),
|
||||
dict(
|
||||
type='BEiTMaskGenerator',
|
||||
input_size=(14, 14),
|
||||
num_masking_patches=75,
|
||||
max_num_patches=75,
|
||||
min_num_patches=16),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=256,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='train',
|
||||
pipeline=train_pipeline))
|
|
@ -0,0 +1,80 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=236,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,49 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='TwoNormDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# clip mean & std
|
||||
second_mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
second_std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandomResizedCropAndInterpolationWithTwoPic',
|
||||
size=224,
|
||||
second_size=224,
|
||||
interpolation='bicubic',
|
||||
second_interpolation='bicubic',
|
||||
scale=(0.2, 1.0)),
|
||||
dict(
|
||||
type='BEiTMaskGenerator',
|
||||
input_size=(14, 14),
|
||||
num_masking_patches=75,
|
||||
max_num_patches=75,
|
||||
min_num_patches=16),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=256,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix=dict(img_path='train/'),
|
||||
pipeline=train_pipeline))
|
|
@ -0,0 +1,80 @@
|
|||
dataset_type = 'ImageNet'
|
||||
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=256,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=256,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -28,7 +29,7 @@ train_pipeline = [
|
|||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -40,7 +41,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs')
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -49,11 +50,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -62,11 +61,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -28,7 +29,7 @@ train_pipeline = [
|
|||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -40,7 +41,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs')
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -49,11 +50,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -62,11 +61,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', scale=192, crop_ratio_range=(0.67, 1.0)),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='SimMIMMaskGenerator',
|
||||
input_size=192,
|
||||
mask_patch_size=32,
|
||||
model_patch_size=4,
|
||||
mask_ratio=0.6),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=256,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='train',
|
||||
pipeline=train_pipeline))
|
|
@ -0,0 +1,81 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=192,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=[103.53, 116.28, 123.675],
|
||||
fill_std=[57.375, 57.12, 58.395]),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=219,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=192),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=256,
|
||||
num_workers=8,
|
||||
collate_fn=dict(type='default_collate'),
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
collate_fn=dict(type='default_collate'),
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -12,14 +13,14 @@ train_pipeline = [
|
|||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', scale=224),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='ResizeEdge', scale=256, edge='short'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -28,11 +29,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -41,11 +40,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True)
|
||||
|
||||
view_pipeline1 = [
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.2,
|
||||
hue=0.1)
|
||||
],
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(
|
||||
type='GaussianBlur',
|
||||
magnitude_range=(0.1, 2.0),
|
||||
magnitude_std='inf',
|
||||
prob=1.),
|
||||
dict(type='Solarize', thr=128, prob=0.),
|
||||
]
|
||||
view_pipeline2 = [
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.2,
|
||||
hue=0.1)
|
||||
],
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(
|
||||
type='GaussianBlur',
|
||||
magnitude_range=(0.1, 2.0),
|
||||
magnitude_std='inf',
|
||||
prob=0.1),
|
||||
dict(type='Solarize', thr=128, prob=0.2)
|
||||
]
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiView',
|
||||
num_views=[1, 1],
|
||||
transforms=[view_pipeline1, view_pipeline2]),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='train',
|
||||
pipeline=train_pipeline))
|
|
@ -0,0 +1,58 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True)
|
||||
|
||||
# The difference between mocov2 and mocov1 is the transforms in the pipeline
|
||||
view_pipeline = [
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
crop_ratio_range=(0.2, 1.),
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.1)
|
||||
],
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(
|
||||
type='GaussianBlur',
|
||||
magnitude_range=(0.1, 2.0),
|
||||
magnitude_std='inf',
|
||||
prob=0.5),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='MultiView', num_views=2, transforms=[view_pipeline]),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='train',
|
||||
pipeline=train_pipeline))
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -16,7 +17,7 @@ train_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -28,7 +29,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -37,11 +38,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -50,11 +49,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -12,14 +13,14 @@ train_pipeline = [
|
|||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -28,11 +29,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -41,11 +40,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True)
|
||||
|
||||
view_pipeline = [
|
||||
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.8,
|
||||
contrast=0.8,
|
||||
saturation=0.8,
|
||||
hue=0.2)
|
||||
],
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(
|
||||
type='GaussianBlur',
|
||||
magnitude_range=(0.1, 2.0),
|
||||
magnitude_std='inf',
|
||||
prob=0.5),
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='MultiView', num_views=2, transforms=[view_pipeline]),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='train',
|
||||
pipeline=train_pipeline))
|
|
@ -0,0 +1,32 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
crop_ratio_range=(0.2, 1.0),
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=512,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='train',
|
||||
pipeline=train_pipeline))
|
|
@ -0,0 +1,90 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True)
|
||||
|
||||
view_pipeline1 = [
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
crop_ratio_range=(0.2, 1.),
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.2,
|
||||
hue=0.1)
|
||||
],
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(
|
||||
type='GaussianBlur',
|
||||
magnitude_range=(0.1, 2.0),
|
||||
magnitude_std='inf',
|
||||
prob=1.),
|
||||
dict(type='Solarize', thr=128, prob=0.),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
]
|
||||
view_pipeline2 = [
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
crop_ratio_range=(0.2, 1.),
|
||||
backend='pillow'),
|
||||
dict(
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.2,
|
||||
hue=0.1)
|
||||
],
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(
|
||||
type='GaussianBlur',
|
||||
magnitude_range=(0.1, 2.0),
|
||||
magnitude_std='inf',
|
||||
prob=0.1),
|
||||
dict(type='Solarize', thr=128, prob=0.2),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
]
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiView',
|
||||
num_views=[1, 1],
|
||||
transforms=[view_pipeline1, view_pipeline2]),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=512,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
split='train',
|
||||
pipeline=train_pipeline))
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -12,14 +13,14 @@ train_pipeline = [
|
|||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', scale=224),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='ResizeEdge', scale=256, edge='short'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -28,11 +29,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -41,11 +40,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -20,14 +21,14 @@ train_pipeline = [
|
|||
policies='imagenet',
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='ResizeEdge', scale=256, edge='short'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -36,11 +37,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -49,11 +48,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True)
|
||||
image_size = 224
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
size=image_size,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
# dict(
|
||||
# type='RandAugment',
|
||||
# policies={{_base_.rand_increasing_policies}},
|
||||
# num_policies=2,
|
||||
# total_level=10,
|
||||
# magnitude_level=9,
|
||||
# magnitude_std=0.5,
|
||||
# hparams=dict(
|
||||
# pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
|
||||
# interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=img_norm_cfg['mean'][::-1],
|
||||
fill_std=img_norm_cfg['std'][::-1]),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
size=(image_size, -1),
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=image_size),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=64,
|
||||
workers_per_gpu=8,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=10, metric='accuracy')
|
|
@ -0,0 +1,73 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True)
|
||||
image_size = 384
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
size=image_size,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
# dict(
|
||||
# type='RandAugment',
|
||||
# policies={{_base_.rand_increasing_policies}},
|
||||
# num_policies=2,
|
||||
# total_level=10,
|
||||
# magnitude_level=9,
|
||||
# magnitude_std=0.5,
|
||||
# hparams=dict(
|
||||
# pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
|
||||
# interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=img_norm_cfg['mean'][::-1],
|
||||
fill_std=img_norm_cfg['std'][::-1]),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
size=(image_size, -1),
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=image_size),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=64,
|
||||
workers_per_gpu=8,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=10, metric='accuracy')
|
|
@ -0,0 +1,74 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True)
|
||||
image_size = 448
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
size=image_size,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
# dict(
|
||||
# type='RandAugment',
|
||||
# policies={{_base_.rand_increasing_policies}},
|
||||
# num_policies=2,
|
||||
# total_level=10,
|
||||
# magnitude_level=9,
|
||||
# magnitude_std=0.5,
|
||||
# hparams=dict(
|
||||
# pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
|
||||
# interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=img_norm_cfg['mean'][::-1],
|
||||
fill_std=img_norm_cfg['std'][::-1]),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
size=(image_size, -1),
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=image_size),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=64,
|
||||
workers_per_gpu=8,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=10, metric='accuracy')
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -36,7 +37,7 @@ train_pipeline = [
|
|||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs')
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -48,7 +49,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs')
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -57,11 +58,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -70,11 +69,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=224,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,60 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=384,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=384),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -36,7 +37,7 @@ train_pipeline = [
|
|||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -48,7 +49,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=256),
|
||||
dict(type='PackClsInputs')
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -57,11 +58,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -70,11 +69,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -3,6 +3,7 @@ dataset_type = 'ImageNet'
|
|||
|
||||
# Google research usually use the below normalization setting.
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
mean=[127.5, 127.5, 127.5],
|
||||
std=[127.5, 127.5, 127.5],
|
||||
# convert image from BGR to RGB
|
||||
|
@ -13,14 +14,14 @@ train_pipeline = [
|
|||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', scale=224),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='ResizeEdge', scale=256, edge='short', interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -29,11 +30,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -42,11 +41,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -12,14 +13,14 @@ train_pipeline = [
|
|||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -28,11 +29,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -41,11 +40,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -24,7 +25,7 @@ train_pipeline = [
|
|||
policies='imagenet',
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -36,7 +37,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -45,11 +46,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -58,11 +57,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -36,7 +37,7 @@ train_pipeline = [
|
|||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -48,7 +49,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -57,11 +58,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -70,11 +69,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -36,7 +37,7 @@ train_pipeline = [
|
|||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -48,7 +49,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=256),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -57,11 +58,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -70,11 +69,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -16,13 +17,13 @@ train_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=384, backend='pillow', interpolation='bicubic'),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -31,11 +32,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -44,11 +43,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
|
@ -36,7 +37,7 @@ train_pipeline = [
|
|||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
|
@ -48,7 +49,7 @@ test_pipeline = [
|
|||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -57,11 +58,9 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
|
@ -70,11 +69,9 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
# RGB format normalization parameters
|
||||
mean=[122.5, 122.5, 122.5],
|
||||
std=[122.5, 122.5, 122.5],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=320,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=int(320 / 224 * 256),
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=320),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,64 @@
|
|||
# dataset settings
|
||||
dataset_type = 'InShop'
|
||||
data_preprocessor = dict(
|
||||
num_classes=3997,
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=512),
|
||||
dict(type='RandomCrop', crop_size=448),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=512),
|
||||
dict(type='CenterCrop', crop_size=448),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/inshop',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
query_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/inshop',
|
||||
split='query',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
|
||||
gallery_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/inshop',
|
||||
split='gallery',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_dataloader = query_dataloader
|
||||
val_evaluator = [
|
||||
dict(type='RetrievalRecall', topk=1),
|
||||
dict(type='RetrievalAveragePrecision', topk=10),
|
||||
]
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,86 @@
|
|||
# dataset settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(
|
||||
type='ApplyToList',
|
||||
# NLVR requires to load two images in task.
|
||||
scatter_key='img_path',
|
||||
transforms=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=384,
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
],
|
||||
collate_keys=['img', 'scale_factor', 'ori_shape'],
|
||||
),
|
||||
dict(type='CleanCaption', keys='text'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(
|
||||
type='ApplyToList',
|
||||
# NLVR requires to load two images in task.
|
||||
scatter_key='img_path',
|
||||
transforms=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(384, 384),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
],
|
||||
collate_keys=['img', 'scale_factor', 'ori_shape'],
|
||||
),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='NLVR2',
|
||||
data_root='data/nlvr2',
|
||||
ann_file='dev.json',
|
||||
data_prefix='dev',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='NLVR2',
|
||||
data_root='data/nlvr2',
|
||||
ann_file='dev.json',
|
||||
data_prefix='dev',
|
||||
pipeline=test_pipeline,
|
||||
),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy')
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue