mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
Merge remote-tracking branch 'origin/dev'
This commit is contained in:
commit
2037260ea6
87
.github/workflows/build.yml
vendored
87
.github/workflows/build.yml
vendored
@ -40,10 +40,13 @@ jobs:
|
||||
include:
|
||||
- torch: 1.5.0
|
||||
torchvision: 0.6.0
|
||||
torch_major: 1.5.0
|
||||
- torch: 1.8.0
|
||||
torchvision: 0.9.0
|
||||
torch_major: 1.8.0
|
||||
- torch: 1.9.0
|
||||
torchvision: 0.10.0
|
||||
torch_major: 1.9.0
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@ -51,14 +54,11 @@ jobs:
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install Pillow
|
||||
run: pip install Pillow==6.2.2
|
||||
if: ${{matrix.torchvision < 0.5}}
|
||||
- name: Install PyTorch
|
||||
run: pip install --use-deprecated=legacy-resolver torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install MMCV
|
||||
run: |
|
||||
pip install --use-deprecated=legacy-resolver mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
|
||||
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Install mmcls dependencies
|
||||
run: |
|
||||
@ -82,25 +82,37 @@ jobs:
|
||||
include:
|
||||
- torch: 1.5.0
|
||||
torchvision: 0.6.0
|
||||
torch_major: 1.5.0
|
||||
- torch: 1.6.0
|
||||
torchvision: 0.7.0
|
||||
torch_major: 1.6.0
|
||||
- torch: 1.7.0
|
||||
torchvision: 0.8.1
|
||||
torch_major: 1.7.0
|
||||
- torch: 1.8.0
|
||||
torchvision: 0.9.0
|
||||
torch_major: 1.8.0
|
||||
- torch: 1.8.0
|
||||
torchvision: 0.9.0
|
||||
torch_major: 1.8.0
|
||||
python-version: 3.8
|
||||
- torch: 1.8.0
|
||||
torchvision: 0.9.0
|
||||
torch_major: 1.8.0
|
||||
python-version: 3.9
|
||||
- torch: 1.9.0
|
||||
torchvision: 0.10.0
|
||||
- torch: 1.9.0
|
||||
torchvision: 0.10.0
|
||||
torch_major: 1.9.0
|
||||
- torch: 1.10.0
|
||||
torchvision: 0.11.1
|
||||
torch_major: 1.10.0
|
||||
- torch: 1.10.0
|
||||
torchvision: 0.11.1
|
||||
torch_major: 1.10.0
|
||||
python-version: 3.8
|
||||
- torch: 1.9.0
|
||||
torchvision: 0.10.0
|
||||
- torch: 1.10.0
|
||||
torchvision: 0.11.1
|
||||
torch_major: 1.10.0
|
||||
python-version: 3.9
|
||||
|
||||
steps:
|
||||
@ -109,14 +121,11 @@ jobs:
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install Pillow
|
||||
run: pip install Pillow==6.2.2
|
||||
if: ${{matrix.torchvision < 0.5}}
|
||||
- name: Install PyTorch
|
||||
run: pip install --use-deprecated=legacy-resolver torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install MMCV
|
||||
run: |
|
||||
pip install --use-deprecated=legacy-resolver mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
|
||||
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Install mmcls dependencies
|
||||
run: |
|
||||
@ -133,10 +142,54 @@ jobs:
|
||||
coverage run --branch --source mmcls -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m --omit="mmcls/utils/*","mmcls/apis/*"
|
||||
# Only upload coverage report for python3.7 && pytorch1.5
|
||||
- name: Upload coverage to Codecov
|
||||
if: ${{matrix.torch == '1.5.0' && matrix.python-version == '3.7'}}
|
||||
uses: codecov/codecov-action@v1.0.10
|
||||
uses: codecov/codecov-action@v2
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
env_vars: OS,PYTHON
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
build-windows:
|
||||
runs-on: windows-2022
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.8]
|
||||
torch: [1.8.1]
|
||||
include:
|
||||
- torch: 1.8.1
|
||||
torchvision: 0.9.1
|
||||
torch_major: 1.8.0
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install MMCV & OpenCV
|
||||
run: |
|
||||
pip install opencv-python
|
||||
pip install mmcv-full==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch_major}}/index.html
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Install mmcls dependencies
|
||||
run: |
|
||||
pip install -r requirements.txt
|
||||
- name: Install timm
|
||||
run: |
|
||||
pip install timm
|
||||
- name: Build and install
|
||||
run: |
|
||||
pip install -e . -U
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmcls -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m --omit="mmcls/utils/*","mmcls/apis/*"
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v2
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
|
@ -1,15 +1,11 @@
|
||||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://gitlab.com/pycqa/flake8.git
|
||||
rev: 3.8.3
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 4.0.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/asottile/seed-isort-config
|
||||
rev: v2.2.0
|
||||
hooks:
|
||||
- id: seed-isort-config
|
||||
- repo: https://github.com/timothycrosley/isort
|
||||
rev: 4.3.21
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
@ -44,6 +40,11 @@ repos:
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
- repo: https://github.com/open-mmlab/pre-commit-hooks
|
||||
rev: v0.2.0
|
||||
hooks:
|
||||
- id: check-copyright
|
||||
args: ["mmcls", "tests", "demo", "tools"]
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: clang-format
|
||||
|
37
README.md
37
README.md
@ -59,6 +59,13 @@ The master branch works with **PyTorch 1.5+**.
|
||||
|
||||
## What's new
|
||||
|
||||
v0.21.0 was released in 4/3/2022.
|
||||
|
||||
Highlights of the new version:
|
||||
- Support **ResNetV1c** and **Wide-ResNet**, and provide pre-trained models.
|
||||
- Support **dynamic input shape** for ViT-based algorithms. Now our ViT, DeiT, Swin-Transformer and T2T-ViT support forwarding with any input shape.
|
||||
- Reproduce training results of DeiT. And our DeiT-T and DeiT-S have **higher accuracy** comparing with the official weights.
|
||||
|
||||
v0.20.0 was released in 30/1/2022.
|
||||
|
||||
Highlights of the new version:
|
||||
@ -66,15 +73,6 @@ Highlights of the new version:
|
||||
- Support **HRNet**, **ConvNeXt**, **Twins** and **EfficientNet**.
|
||||
- Support model conversion from PyTorch to **Core ML** by a tool.
|
||||
|
||||
v0.19.0 was released in 31/12/2021.
|
||||
|
||||
Highlights of the new version:
|
||||
- The **feature extraction** function has been enhanced. See [#593](https://github.com/open-mmlab/mmclassification/pull/593) for more details.
|
||||
- Provide the high-acc **ResNet-50** training settings from [*ResNet strikes back*](https://arxiv.org/abs/2110.00476).
|
||||
- Reproduce the training accuracy of **T2T-ViT** & **RegNetX**, and provide self-training checkpoints.
|
||||
- Support **DeiT** & **Conformer** backbone and checkpoints.
|
||||
- Provide a **CAM visualization** tool based on [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam), and detailed [user guide](https://mmclassification.readthedocs.io/en/latest/tools/visualization.html#class-activation-map-visualization)!
|
||||
|
||||
Please refer to [changelog.md](docs/en/changelog.md) for more details and other release history.
|
||||
|
||||
## Installation
|
||||
@ -161,20 +159,21 @@ This project is released under the [Apache 2.0 license](LICENSE).
|
||||
## Projects in OpenMMLab
|
||||
|
||||
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
|
||||
- [MIM](https://github.com/open-mmlab/mim): MIM Installs OpenMMLab Packages.
|
||||
- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
|
||||
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
|
||||
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
|
||||
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
|
||||
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
|
||||
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
|
||||
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
|
||||
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
|
||||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
|
||||
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab toolbox for text detection, recognition and understanding.
|
||||
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMlab toolkit for generative models.
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow) OpenMMLab optical flow toolbox and benchmark.
|
||||
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab FewShot Learning Toolbox and Benchmark.
|
||||
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D Human Parametric Model Toolbox and Benchmark.
|
||||
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
|
||||
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab Model Compression Toolbox and Benchmark.
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab Model Deployment Framework.
|
||||
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.
|
||||
|
@ -57,6 +57,13 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
|
||||
|
||||
## 更新日志
|
||||
|
||||
2022/3/4 发布了 v0.21.0 版本
|
||||
|
||||
新版本亮点:
|
||||
- 支持了 **ResNetV1c** 和 **Wide-ResNet** 两个 ResNet 变种,并提供了预训练模型
|
||||
- ViT相关模型支持 **动态输入尺寸**。现在我们的 ViT,DeiT,Swin-Transformer 和 T2T-ViT 支持任意尺寸的输入。
|
||||
- 复现了 DeiT 的训练结果,并且我们的 DeiT-T 和 DeiT-S 拥有比官方权重 **更高的精度**。
|
||||
|
||||
2022/1/30 发布了 v0.20.0 版本
|
||||
|
||||
新版本亮点:
|
||||
@ -64,15 +71,6 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
|
||||
- 支持了 **HRNet**,**ConvNeXt**,**Twins** 以及 **EfficientNet** 四个主干网络,欢迎使用!
|
||||
- 支持了从 PyTorch 模型到 Core-ML 模型的转换工具。
|
||||
|
||||
2021/12/31 发布了 v0.19.0 版本
|
||||
|
||||
新版本亮点:
|
||||
- **特征提取**功能得到了加强。详见 [#593](https://github.com/open-mmlab/mmclassification/pull/593)。
|
||||
- 提供了 **ResNet-50** 的高精度训练配置,原论文参见 [*ResNet strikes back*](https://arxiv.org/abs/2110.00476)。
|
||||
- 复现了 **T2T-ViT** 和 **RegNetX** 的训练精度,并提供了自训练的模型权重文件。
|
||||
- 支持了 **DeiT** 和 **Conformer** 主干网络,并提供了预训练模型。
|
||||
- 提供了一个 **CAM 可视化** 工具。该工具基于 [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam),我们提供了详细的 [使用教程](https://mmclassification.readthedocs.io/en/latest/tools/visualization.html#class-activation-map-visualization)!
|
||||
|
||||
发布历史和更新细节请参考 [更新日志](docs/en/changelog.md)
|
||||
|
||||
## 安装
|
||||
@ -160,20 +158,22 @@ MMClassification 是一款由不同学校和公司共同贡献的开源项目。
|
||||
|
||||
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab 计算机视觉基础库
|
||||
- [MIM](https://github.com/open-mmlab/mim): MIM 是 OpenMMlab 项目、算法、模型的统一入口
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 检测工具箱与测试基准
|
||||
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab 图像分类工具箱
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱
|
||||
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱与测试基准
|
||||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab 新一代视频理解工具箱与测试基准
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱与测试基准
|
||||
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab 图像视频编辑工具箱
|
||||
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具包
|
||||
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab 生成模型工具箱
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准
|
||||
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab 少样本学习工具箱与测试基准
|
||||
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d):OpenMMLab 人体参数化模型工具箱与测试基准
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱
|
||||
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 人体参数化模型工具箱与测试基准
|
||||
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab 自监督学习工具箱与测试基准
|
||||
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab 模型压缩工具箱与测试基准
|
||||
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab 少样本学习工具箱与测试基准
|
||||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab 新一代视频理解工具箱
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准
|
||||
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab 图像视频编辑工具箱
|
||||
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab 图片视频生成模型工具箱
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab 模型部署框架
|
||||
|
||||
## 欢迎加入 OpenMMLab 社区
|
||||
|
48
configs/_base_/datasets/imagenet_bs32_pil_bicubic.py
Normal file
48
configs/_base_/datasets/imagenet_bs32_pil_bicubic.py
Normal file
@ -0,0 +1,48 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
size=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
size=(256, -1),
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=32,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
evaluation = dict(interval=1, metric='accuracy')
|
17
configs/_base_/models/resnet34_gem.py
Normal file
17
configs/_base_/models/resnet34_gem.py
Normal file
@ -0,0 +1,17 @@
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=34,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch'),
|
||||
neck=dict(type='GeneralizedMeanPooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
17
configs/_base_/models/resnetv1c50.py
Normal file
17
configs/_base_/models/resnetv1c50.py
Normal file
@ -0,0 +1,17 @@
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=2048,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
20
configs/_base_/models/wide-resnet50.py
Normal file
20
configs/_base_/models/wide-resnet50.py
Normal file
@ -0,0 +1,20 @@
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
stem_channels=64,
|
||||
base_channels=128,
|
||||
expansion=2,
|
||||
style='pytorch'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=2048,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
@ -10,7 +10,7 @@ paramwise_cfg = dict(
|
||||
# lr = 5e-4 * 128 * 8 / 512 = 0.001
|
||||
optimizer = dict(
|
||||
type='AdamW',
|
||||
lr=5e-4 * 128 * 8 / 512,
|
||||
lr=5e-4 * 1024 / 512,
|
||||
weight_decay=0.05,
|
||||
eps=1e-8,
|
||||
betas=(0.9, 0.999),
|
||||
|
@ -29,15 +29,18 @@ The "Roaring 20s" of visual recognition began with the introduction of Vision Tr
|
||||
|
||||
*Models with \* are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
### ImageNet-21k
|
||||
### Pre-trained Models
|
||||
|
||||
The pre-trained models on ImageNet-21k are used to fine-tune, and therefore don't have evaluation results.
|
||||
The pre-trained models on ImageNet-1k or ImageNet-21k are used to fine-tune on the downstream tasks.
|
||||
|
||||
| Model | Params(M) | Flops(G) | Download |
|
||||
|:--------------------------------:|:---------:|:--------:|:--------:|
|
||||
| convnext-base_3rdparty_in21k\* | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth) |
|
||||
| convnext-large_3rdparty_in21k\* | 197.77 | 34.37 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth) |
|
||||
| convnext-xlarge_3rdparty_in21k\* | 350.20 | 60.93 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth) |
|
||||
| Model | Training Data | Params(M) | Flops(G) | Download |
|
||||
|:--------------:|:-------------:|:---------:|:--------:|:--------:|
|
||||
| ConvNeXt-T\* | ImageNet-1k | 28.59 | 4.46 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128-noema_in1k_20220222-2908964a.pth) |
|
||||
| ConvNeXt-S\* | ImageNet-1k | 50.22 | 8.69 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128-noema_in1k_20220222-fa001ca5.pth) |
|
||||
| ConvNeXt-B\* | ImageNet-1k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128-noema_in1k_20220222-dba4f95f.pth) |
|
||||
| ConvNeXt-B\* | ImageNet-21k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth) |
|
||||
| ConvNeXt-L\* | ImageNet-21k | 197.77 | 34.37 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth) |
|
||||
| ConvNeXt-XL\* | ImageNet-21k | 350.20 | 60.93 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth) |
|
||||
|
||||
*Models with \* are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt).*
|
||||
|
||||
|
@ -30,6 +30,23 @@ Models:
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-tiny_3rdparty_32xb128-noema_in1k
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
FLOPs: 4457472768
|
||||
Parameters: 28589128
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 81.81
|
||||
Top 5 Accuracy: 95.67
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128-noema_in1k_20220222-2908964a.pth
|
||||
Config: configs/convnext/convnext-tiny_32xb128_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-small_3rdparty_32xb128_in1k
|
||||
Metadata:
|
||||
FLOPs: 8687008512
|
||||
@ -46,6 +63,23 @@ Models:
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-small_3rdparty_32xb128-noema_in1k
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
FLOPs: 8687008512
|
||||
Parameters: 50223688
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.11
|
||||
Top 5 Accuracy: 96.34
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128-noema_in1k_20220222-fa001ca5.pth
|
||||
Config: configs/convnext/convnext-small_32xb128_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-base_3rdparty_32xb128_in1k
|
||||
Metadata:
|
||||
FLOPs: 15359124480
|
||||
@ -62,12 +96,30 @@ Models:
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-base_3rdparty_32xb128-noema_in1k
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
FLOPs: 15359124480
|
||||
Parameters: 88591464
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.71
|
||||
Top 5 Accuracy: 96.60
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128-noema_in1k_20220222-dba4f95f.pth
|
||||
Config: configs/convnext/convnext-base_32xb128_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-base_3rdparty_in21k
|
||||
Metadata:
|
||||
Training Data: ImageNet-21k
|
||||
FLOPs: 15359124480
|
||||
Parameters: 88591464
|
||||
In Collections: ConvNeXt
|
||||
Results: null
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
|
||||
@ -113,6 +165,7 @@ Models:
|
||||
FLOPs: 34368026112
|
||||
Parameters: 197767336
|
||||
In Collections: ConvNeXt
|
||||
Results: null
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
|
||||
@ -142,6 +195,7 @@ Models:
|
||||
FLOPs: 60929820672
|
||||
Parameters: 350196968
|
||||
In Collections: ConvNeXt
|
||||
Results: null
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
|
||||
|
@ -19,11 +19,12 @@ The teacher of the distilled version DeiT is RegNetY-16GF.
|
||||
|
||||
| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:---------------------:|:------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
|
||||
| DeiT-tiny\* | From scratch | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
|
||||
| DeiT-tiny | From scratch | 5.72 | 1.08 | 74.50 | 92.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.log.json) |
|
||||
| DeiT-tiny distilled\* | From scratch | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
|
||||
| DeiT-small\* | From scratch | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
|
||||
| DeiT-small | From scratch | 22.05 | 4.24 | 80.69 | 95.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.log.json) |
|
||||
| DeiT-small distilled\*| From scratch | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
|
||||
| DeiT-base\* | From scratch | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
|
||||
| DeiT-base | From scratch | 86.57 | 16.86 | 81.76 | 95.81 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.log.json) |
|
||||
| DeiT-base\* | From scratch | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
|
||||
| DeiT-base distilled\* | From scratch | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |
|
||||
| DeiT-base 384px\* | ImageNet-1k | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
|
||||
| DeiT-base distilled 384px\* | ImageNet-1k | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |
|
||||
|
@ -2,9 +2,12 @@ _base_ = './deit-small_pt-4xb256_in1k.py'
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(type='VisionTransformer', arch='deit-base'),
|
||||
backbone=dict(
|
||||
type='VisionTransformer', arch='deit-base', drop_path_rate=0.1),
|
||||
head=dict(type='VisionTransformerClsHead', in_channels=768),
|
||||
)
|
||||
|
||||
# data settings
|
||||
data = dict(samples_per_gpu=64, workers_per_gpu=5)
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
||||
|
@ -1,6 +1,8 @@
|
||||
# In small and tiny arch, remove drop path and EMA hook comparing with the
|
||||
# original config
|
||||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py',
|
||||
'../_base_/schedules/imagenet_bs4096_AdamW.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
@ -23,7 +25,20 @@ model = dict(
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
])
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
|
||||
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
|
||||
]))
|
||||
|
||||
# data settings
|
||||
data = dict(samples_per_gpu=256, workers_per_gpu=5)
|
||||
|
||||
paramwise_cfg = dict(
|
||||
norm_decay_mult=0.0,
|
||||
bias_decay_mult=0.0,
|
||||
custom_keys={
|
||||
'.cls_token': dict(decay_mult=0.0),
|
||||
'.pos_embed': dict(decay_mult=0.0)
|
||||
})
|
||||
optimizer = dict(paramwise_cfg=paramwise_cfg)
|
||||
|
@ -16,7 +16,7 @@ Collections:
|
||||
Version: https://github.com/open-mmlab/mmclassification/blob/v0.19.0/mmcls/models/backbones/deit.py
|
||||
|
||||
Models:
|
||||
- Name: deit-tiny_3rdparty_pt-4xb256_in1k
|
||||
- Name: deit-tiny_pt-4xb256_in1k
|
||||
Metadata:
|
||||
FLOPs: 1080000000
|
||||
Parameters: 5720000
|
||||
@ -24,13 +24,10 @@ Models:
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 72.13
|
||||
Top 5 Accuracy: 91.13
|
||||
Top 1 Accuracy: 74.50
|
||||
Top 5 Accuracy: 92.24
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth
|
||||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L63
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth
|
||||
Config: configs/deit/deit-tiny_pt-4xb256_in1k.py
|
||||
- Name: deit-tiny-distilled_3rdparty_pt-4xb256_in1k
|
||||
Metadata:
|
||||
@ -48,7 +45,7 @@ Models:
|
||||
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
|
||||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
|
||||
Config: configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
|
||||
- Name: deit-small_3rdparty_pt-4xb256_in1k
|
||||
- Name: deit-small_pt-4xb256_in1k
|
||||
Metadata:
|
||||
FLOPs: 4240000000
|
||||
Parameters: 22050000
|
||||
@ -56,13 +53,10 @@ Models:
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 79.83
|
||||
Top 5 Accuracy: 94.95
|
||||
Top 1 Accuracy: 80.69
|
||||
Top 5 Accuracy: 95.06
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth
|
||||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L78
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth
|
||||
Config: configs/deit/deit-small_pt-4xb256_in1k.py
|
||||
- Name: deit-small-distilled_3rdparty_pt-4xb256_in1k
|
||||
Metadata:
|
||||
@ -80,6 +74,19 @@ Models:
|
||||
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
|
||||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
|
||||
Config: configs/deit/deit-small-distilled_pt-4xb256_in1k.py
|
||||
- Name: deit-base_pt-16xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 16860000000
|
||||
Parameters: 86570000
|
||||
In Collection: DeiT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 81.76
|
||||
Top 5 Accuracy: 95.81
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
|
||||
Config: configs/deit/deit-base_pt-16xb64_in1k.py
|
||||
- Name: deit-base_3rdparty_pt-16xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 16860000000
|
||||
|
@ -40,16 +40,23 @@ The depth of representations is of central importance for many visual recognitio
|
||||
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.log.json) |
|
||||
| ResNet-101 | 44.55 | 7.85 | 77.97 | 94.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.log.json) |
|
||||
| ResNet-152 | 60.19 | 11.58 | 78.48 | 94.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.log.json) |
|
||||
| ResNetV1C-50 | 25.58 | 4.36 | 77.01 | 93.58 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.log.json) |
|
||||
| ResNetV1C-101 | 44.57 | 8.09 | 78.30 | 94.27 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.log.json) |
|
||||
| ResNetV1C-152 | 60.21 | 11.82 | 78.76 | 94.41 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1c152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.log.json) |
|
||||
| ResNetV1D-50 | 25.58 | 4.36 | 77.54 | 93.57 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.log.json) |
|
||||
| ResNetV1D-101 | 44.57 | 8.09 | 78.93 | 94.48 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.log.json) |
|
||||
| ResNetV1D-152 | 60.21 | 11.82 | 79.41 | 94.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.log.json) |
|
||||
| ResNet-50 (fp16) | 25.56 | 4.12 | 76.30 | 93.07 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32-fp16_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/fp16/resnet50_batch256_fp16_imagenet_20210320-b3964210.pth) | [log](https://download.openmmlab.com/mmclassification/v0/fp16/resnet50_batch256_fp16_imagenet_20210320-b3964210.log.json) |
|
||||
| Wide-ResNet-50\* | 68.88 | 11.44 | 78.48 | 94.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/wide-resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet50_3rdparty_8xb32_in1k_20220304-66678344.pth) |
|
||||
| Wide-ResNet-101\* | 126.89 | 22.81 | 78.84 | 94.28 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth) |
|
||||
| ResNet-50 (rsb-a1) | 25.56 | 4.12 | 80.12 | 94.78 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a1-600e_in1k_20211228-20e21305.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a1-600e_in1k_20211228-20e21305.log.json) |
|
||||
| ResNet-50 (rsb-a2) | 25.56 | 4.12 | 79.55 | 94.37 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a2-300e_in1k_20211228-0fd8be6e.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a2-300e_in1k_20211228-0fd8be6e.log.json) |
|
||||
| ResNet-50 (rsb-a3) | 25.56 | 4.12 | 78.30 | 93.80 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.log.json) |
|
||||
|
||||
*The "rsb" means using the training settings from [ResNet strikes back: An improved training procedure in timm](https://arxiv.org/abs/2110.00476).*
|
||||
|
||||
*Models with \* are converted from the [official repo](https://github.com/pytorch/vision). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
|
@ -298,3 +298,80 @@ Models:
|
||||
Top 5 Accuracy: 93.80
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
|
||||
Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
|
||||
- Name: wide-resnet50_3rdparty_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 11440000000 # 11.44G
|
||||
Parameters: 68880000 # 68.88M
|
||||
Training Techniques:
|
||||
- SGD with Momentum
|
||||
- Weight Decay
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.48
|
||||
Top 5 Accuracy: 94.08
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet50_3rdparty_8xb32_in1k_20220304-66678344.pth
|
||||
Config: configs/resnet/wide-resnet50_8xb32_in1k.py
|
||||
Converted From:
|
||||
Weights: https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth
|
||||
Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
|
||||
- Name: wide-resnet101_3rdparty_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 22810000000 # 22.81G
|
||||
Parameters: 126890000 # 126.89M
|
||||
Training Techniques:
|
||||
- SGD with Momentum
|
||||
- Weight Decay
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.84
|
||||
Top 5 Accuracy: 94.28
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth
|
||||
Config: configs/resnet/wide-resnet101_8xb32_in1k.py
|
||||
Converted From:
|
||||
Weights: https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth
|
||||
Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
|
||||
- Name: resnetv1c50_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 4360000000
|
||||
Parameters: 25580000
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 77.01
|
||||
Top 5 Accuracy: 93.58
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c50_8xb32_in1k_20220214-3343eccd.pth
|
||||
Config: configs/resnet/resnetv1c50_8xb32_in1k.py
|
||||
- Name: resnetv1c101_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 8090000000
|
||||
Parameters: 44570000
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.30
|
||||
Top 5 Accuracy: 94.27
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c101_8xb32_in1k_20220214-434fe45f.pth
|
||||
Config: configs/resnet/resnetv1c101_8xb32_in1k.py
|
||||
- Name: resnetv1c152_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 11820000000
|
||||
Parameters: 60210000
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.76
|
||||
Top 5 Accuracy: 94.41
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth
|
||||
Config: configs/resnet/resnetv1c152_8xb32_in1k.py
|
||||
|
7
configs/resnet/resnetv1c101_8xb32_in1k.py
Normal file
7
configs/resnet/resnetv1c101_8xb32_in1k.py
Normal file
@ -0,0 +1,7 @@
|
||||
_base_ = [
|
||||
'../_base_/models/resnetv1c50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(backbone=dict(depth=101))
|
7
configs/resnet/resnetv1c152_8xb32_in1k.py
Normal file
7
configs/resnet/resnetv1c152_8xb32_in1k.py
Normal file
@ -0,0 +1,7 @@
|
||||
_base_ = [
|
||||
'../_base_/models/resnetv1c50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(backbone=dict(depth=152))
|
5
configs/resnet/resnetv1c50_8xb32_in1k.py
Normal file
5
configs/resnet/resnetv1c50_8xb32_in1k.py
Normal file
@ -0,0 +1,5 @@
|
||||
_base_ = [
|
||||
'../_base_/models/resnetv1c50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
34
configs/wrn/README.md
Normal file
34
configs/wrn/README.md
Normal file
@ -0,0 +1,34 @@
|
||||
# Wide-ResNet
|
||||
|
||||
> [Wide Residual Networks](https://arxiv.org/abs/1605.07146)
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Deep residual networks were shown to be able to scale up to thousands of layers and still have improving performance. However, each fraction of a percent of improved accuracy costs nearly doubling the number of layers, and so training very deep residual networks has a problem of diminishing feature reuse, which makes these networks very slow to train. To tackle these problems, in this paper we conduct a detailed experimental study on the architecture of ResNet blocks, based on which we propose a novel architecture where we decrease depth and increase width of residual networks. We call the resulting network structures wide residual networks (WRNs) and show that these are far superior over their commonly used thin and very deep counterparts. For example, we demonstrate that even a simple 16-layer-deep wide residual network outperforms in accuracy and efficiency all previous deep residual networks, including thousand-layer-deep networks, achieving new state-of-the-art results on CIFAR, SVHN, COCO, and significant improvements on ImageNet.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/26739999/156701329-2c7ec7bc-23da-401b-86bf-dea8567ccee8.png" width="90%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
|
||||
| WRN-50\* | 68.88 | 11.44 | 78.48 | 94.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/wrn/wide-resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/wrn/wide-resnet50_3rdparty_8xb32_in1k_20220304-66678344.pth) |
|
||||
| WRN-101\* | 126.89 | 22.81 | 78.84 | 94.28 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/wrn/wide-resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/wrn/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth) |
|
||||
| WRN-50 (timm)\* | 68.88 | 11.44 | 81.45 | 95.53 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/wrn/wide-resnet50_timm_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/wrn/wide-resnet50_3rdparty-timm_8xb32_in1k_20220304-83ae4399.pth) |
|
||||
|
||||
*Models with \* are converted from the [TorchVision](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py) and [TIMM](https://github.com/rwightman/pytorch-image-models/blob/master). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@INPROCEEDINGS{Zagoruyko2016WRN,
|
||||
author = {Sergey Zagoruyko and Nikos Komodakis},
|
||||
title = {Wide Residual Networks},
|
||||
booktitle = {BMVC},
|
||||
year = {2016}}
|
||||
```
|
77
configs/wrn/metafile.yml
Normal file
77
configs/wrn/metafile.yml
Normal file
@ -0,0 +1,77 @@
|
||||
Collections:
|
||||
- Name: Wide-ResNet
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
Training Techniques:
|
||||
- SGD with Momentum
|
||||
- Weight Decay
|
||||
Training Resources: 8x V100 GPUs
|
||||
Epochs: 100
|
||||
Batch Size: 256
|
||||
Architecture:
|
||||
- 1x1 Convolution
|
||||
- Batch Normalization
|
||||
- Convolution
|
||||
- Global Average Pooling
|
||||
- Max Pooling
|
||||
- ReLU
|
||||
- Residual Connection
|
||||
- Softmax
|
||||
- Wide Residual Block
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/1605.07146
|
||||
Title: "Wide Residual Networks"
|
||||
README: configs/wrn/README.md
|
||||
Code:
|
||||
URL: https://github.com/open-mmlab/mmclassification/blob/v0.20.1/mmcls/models/backbones/resnet.py#L383
|
||||
Version: v0.20.1
|
||||
|
||||
Models:
|
||||
- Name: wide-resnet50_3rdparty_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 11440000000 # 11.44G
|
||||
Parameters: 68880000 # 68.88M
|
||||
In Collection: Wide-ResNet
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.48
|
||||
Top 5 Accuracy: 94.08
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/wrn/wide-resnet50_3rdparty_8xb32_in1k_20220304-66678344.pth
|
||||
Config: configs/wrn/wide-resnet50_8xb32_in1k.py
|
||||
Converted From:
|
||||
Weights: https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth
|
||||
Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
|
||||
- Name: wide-resnet101_3rdparty_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 22810000000 # 22.81G
|
||||
Parameters: 126890000 # 126.89M
|
||||
In Collection: Wide-ResNet
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.84
|
||||
Top 5 Accuracy: 94.28
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/wrn/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth
|
||||
Config: configs/wrn/wide-resnet101_8xb32_in1k.py
|
||||
Converted From:
|
||||
Weights: https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth
|
||||
Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
|
||||
- Name: wide-resnet50_3rdparty-timm_8xb32_in1k
|
||||
Metadata:
|
||||
FLOPs: 11440000000 # 11.44G
|
||||
Parameters: 68880000 # 68.88M
|
||||
In Collection: Wide-ResNet
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 81.45
|
||||
Top 5 Accuracy: 95.53
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/wrn/wide-resnet50_3rdparty-timm_8xb32_in1k_20220304-83ae4399.pth
|
||||
Config: configs/wrn/wide-resnet50_timm_8xb32_in1k.py
|
||||
Converted From:
|
||||
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth
|
||||
Code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/resnet.py
|
7
configs/wrn/wide-resnet101_8xb32_in1k.py
Normal file
7
configs/wrn/wide-resnet101_8xb32_in1k.py
Normal file
@ -0,0 +1,7 @@
|
||||
_base_ = [
|
||||
'../_base_/models/wide-resnet50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(backbone=dict(depth=101))
|
5
configs/wrn/wide-resnet50_8xb32_in1k.py
Normal file
5
configs/wrn/wide-resnet50_8xb32_in1k.py
Normal file
@ -0,0 +1,5 @@
|
||||
_base_ = [
|
||||
'../_base_/models/wide-resnet50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
5
configs/wrn/wide-resnet50_timm_8xb32_in1k.py
Normal file
5
configs/wrn/wide-resnet50_timm_8xb32_in1k.py
Normal file
@ -0,0 +1,5 @@
|
||||
_base_ = [
|
||||
'../_base_/models/wide-resnet50.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_bicubic.py',
|
||||
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
|
||||
]
|
@ -4,7 +4,7 @@ ARG CUDNN="7"
|
||||
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
|
||||
|
||||
ARG MMCV="1.4.2"
|
||||
ARG MMCLS="0.20.1"
|
||||
ARG MMCLS="0.21.0"
|
||||
|
||||
ENV PYTHONUNBUFFERED TRUE
|
||||
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 44 KiB |
Binary file not shown.
Before Width: | Height: | Size: 9.2 KiB |
Binary file not shown.
Before Width: | Height: | Size: 19 KiB |
@ -1,5 +1,39 @@
|
||||
# Changelog
|
||||
|
||||
## v0.21.0(04/03/2022)
|
||||
|
||||
### Highlights
|
||||
|
||||
- Support ResNetV1c and Wide-ResNet, and provide pre-trained models.
|
||||
- Support dynamic input shape for ViT-based algorithms. Now our ViT, DeiT, Swin-Transformer and T2T-ViT support forwarding with any input shape.
|
||||
- Reproduce training results of DeiT. And our DeiT-T and DeiT-S have higher accuracy comparing with the official weights.
|
||||
|
||||
### New Features
|
||||
|
||||
- Add ResNetV1c. ([#692](https://github.com/open-mmlab/mmclassification/pull/692))
|
||||
- Support Wide-ResNet. ([#715](https://github.com/open-mmlab/mmclassification/pull/715))
|
||||
- Support gem pooling ([#677](https://github.com/open-mmlab/mmclassification/pull/677))
|
||||
|
||||
### Improvements
|
||||
|
||||
- Reproduce training results of DeiT. ([#711](https://github.com/open-mmlab/mmclassification/pull/711))
|
||||
- Add ConvNeXt pretrain models on ImageNet-1k. ([#707](https://github.com/open-mmlab/mmclassification/pull/707))
|
||||
- Support dynamic input shape for ViT-based algorithms. ([#706](https://github.com/open-mmlab/mmclassification/pull/706))
|
||||
- Add `evaluate` function for ConcatDataset. ([#650](https://github.com/open-mmlab/mmclassification/pull/650))
|
||||
- Enhance vis-pipeline tool. ([#604](https://github.com/open-mmlab/mmclassification/pull/604))
|
||||
- Return code 1 if scripts runs failed. ([#694](https://github.com/open-mmlab/mmclassification/pull/694))
|
||||
- Use PyTorch official `one_hot` to implement `convert_to_one_hot`. ([#696](https://github.com/open-mmlab/mmclassification/pull/696))
|
||||
- Add a new pre-commit-hook to automatically add a copyright. ([#710](https://github.com/open-mmlab/mmclassification/pull/710))
|
||||
- Add deprecation message for deploy tools. ([#697](https://github.com/open-mmlab/mmclassification/pull/697))
|
||||
- Upgrade isort pre-commit hooks. ([#687](https://github.com/open-mmlab/mmclassification/pull/687))
|
||||
- Use `--gpu-id` instead of `--gpu-ids` in non-distributed multi-gpu training/testing. ([#688](https://github.com/open-mmlab/mmclassification/pull/688))
|
||||
- Remove deprecation. ([#633](https://github.com/open-mmlab/mmclassification/pull/633))
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fix Conformer forward with irregular input size. ([#686](https://github.com/open-mmlab/mmclassification/pull/686))
|
||||
- Add `dist.barrier` to fix a bug in directory checking. ([#666](https://github.com/open-mmlab/mmclassification/pull/666))
|
||||
|
||||
## v0.20.1(07/02/2022)
|
||||
|
||||
### Bug Fixes
|
||||
|
@ -10,8 +10,9 @@ The compatible MMClassification and MMCV versions are as below. Please install t
|
||||
|
||||
| MMClassification version | MMCV version |
|
||||
|:------------------------:|:---------------------:|
|
||||
| dev | mmcv>=1.4.4, <=1.5.0 |
|
||||
| 0.20.1 (master) | mmcv>=1.4.2, <=1.5.0 |
|
||||
| dev | mmcv>=1.4.6, <=1.5.0 |
|
||||
| 0.21.0 (master) | mmcv>=1.4.2, <=1.5.0 |
|
||||
| 0.20.1 | mmcv>=1.4.2, <=1.5.0 |
|
||||
| 0.19.0 | mmcv>=1.3.16, <=1.5.0 |
|
||||
| 0.18.0 | mmcv>=1.3.16, <=1.5.0 |
|
||||
| 0.17.0 | mmcv>=1.3.8, <=1.5.0 |
|
||||
|
@ -71,11 +71,11 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
||||
| T2T-ViT_t-24 | 64.00 | 12.69 | 82.71 | 96.09 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_8xb64_in1k_20211214-b2a68ae3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_8xb64_in1k_20211214-b2a68ae3.log.json)|
|
||||
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) |
|
||||
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) |
|
||||
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
|
||||
| DeiT-tiny | 5.72 | 1.08 | 74.50 | 92.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.log.json) |
|
||||
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
|
||||
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
|
||||
| DeiT-small | 22.05 | 4.24 | 80.69 | 95.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.log.json) |
|
||||
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
|
||||
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
|
||||
| DeiT-base | 86.57 | 16.86 | 81.76 | 95.81 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.log.json) |
|
||||
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |
|
||||
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
|
||||
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |
|
||||
@ -128,6 +128,8 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
||||
| HRNet-W64\* | 128.06 | 29.00 | 79.46 | 94.65 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w64_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w64_3rdparty_8xb32_in1k_20220120-19126642.pth) |
|
||||
| HRNet-W18 (ssld)\* | 21.30 | 4.33 | 81.06 | 95.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32-ssld_in1k_20220120-455f69ea.pth) |
|
||||
| HRNet-W48 (ssld)\* | 77.47 | 17.36 | 83.63 | 96.79 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32-ssld_in1k_20220120-d0459c38.pth) |
|
||||
| WRN-50\* | 68.88 | 11.44 | 81.45 | 95.53 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/wrn/wide-resnet50_timm_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/wrn/wide-resnet50_3rdparty-timm_8xb32_in1k_20220304-83ae4399.pth) |
|
||||
| WRN-101\* | 126.89| 22.81 | 78.84 | 94.28 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/wrn/wide-resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/wrn/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth) |
|
||||
|
||||
*Models with \* are converted from other repos, others are trained by ourselves.*
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# MISCELLANEOUS
|
||||
# Miscellaneous
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
|
@ -2,11 +2,10 @@
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [Visualization](#visualization)
|
||||
- [Pipeline Visualization](#pipeline-visualization)
|
||||
- [Learning Rate Schedule Visualization](#learning-rate-schedule-visualization)
|
||||
- [Class Activation Map Visualization](#class-activation-map-visualization)
|
||||
- [FAQs](#faqs)
|
||||
- [Pipeline Visualization](#pipeline-visualization)
|
||||
- [Learning Rate Schedule Visualization](#learning-rate-schedule-visualization)
|
||||
- [Class Activation Map Visualization](#class-activation-map-visualization)
|
||||
- [FAQs](#faqs)
|
||||
|
||||
<!-- TOC -->
|
||||
## Pipeline Visualization
|
||||
@ -14,17 +13,18 @@
|
||||
```bash
|
||||
python tools/visualizations/vis_pipeline.py \
|
||||
${CONFIG_FILE} \
|
||||
--output-dir ${OUTPUT_DIR} \
|
||||
--phase ${DATASET_PHASE} \
|
||||
--number ${BUNBER_IMAGES_DISPLAY} \
|
||||
--skip-type ${SKIP_TRANSFORM_TYPE}
|
||||
--mode ${DISPLAY_MODE} \
|
||||
--show \
|
||||
--adaptive \
|
||||
--min-edge-length ${MIN_EDGE_LENGTH} \
|
||||
--max-edge-length ${MAX_EDGE_LENGTH} \
|
||||
--bgr2rgb \
|
||||
--window-size ${WINDOW_SIZE}
|
||||
[--output-dir ${OUTPUT_DIR}] \
|
||||
[--phase ${DATASET_PHASE}] \
|
||||
[--number ${BUNBER_IMAGES_DISPLAY}] \
|
||||
[--skip-type ${SKIP_TRANSFORM_TYPE}] \
|
||||
[--mode ${DISPLAY_MODE}] \
|
||||
[--show] \
|
||||
[--adaptive] \
|
||||
[--min-edge-length ${MIN_EDGE_LENGTH}] \
|
||||
[--max-edge-length ${MAX_EDGE_LENGTH}] \
|
||||
[--bgr2rgb] \
|
||||
[--window-size ${WINDOW_SIZE}] \
|
||||
[--cfg-options ${CFG_OPTIONS}]
|
||||
```
|
||||
|
||||
**Description of all arguments**:
|
||||
@ -32,48 +32,57 @@ python tools/visualizations/vis_pipeline.py \
|
||||
- `config` : The path of a model config file.
|
||||
- `--output-dir`: The output path for visualized images. If not specified, it will be set to `''`, which means not to save.
|
||||
- `--phase`: Phase of visualizing dataset,must be one of `[train, val, test]`. If not specified, it will be set to `train`.
|
||||
- `--number`: The number of samples to visualize. If not specified, display all images in the dataset.
|
||||
- `--number`: The number of samples to visualized. If not specified, display all images in the dataset.
|
||||
- `--skip-type`: The pipelines to be skipped. If not specified, it will be set to `['ToTensor', 'Normalize', 'ImageToTensor', 'Collect']`.
|
||||
- `--mode`: The display mode, can be one of `[original, pipeline, concat]`. If not specified, it will be set to `concat`.
|
||||
- `--show`: If set, display pictures in pop-up windows.
|
||||
- `--adaptive`: If set, automatically adjust the size of the visualization images.
|
||||
- `--adaptive`: If set, adaptively resize images for better visualization.
|
||||
- `--min-edge-length`: The minimum edge length, used when `--adaptive` is set. When any side of the picture is smaller than `${MIN_EDGE_LENGTH}`, the picture will be enlarged while keeping the aspect ratio unchanged, and the short side will be aligned to `${MIN_EDGE_LENGTH}`. If not specified, it will be set to 200.
|
||||
- `--max-edge-length`: The maximum edge length, used when `--adaptive` is set. When any side of the picture is larger than `${MAX_EDGE_LENGTH}`, the picture will be reduced while keeping the aspect ratio unchanged, and the long side will be aligned to `${MAX_EDGE_LENGTH}`. If not specified, it will be set to 1000.
|
||||
- `--bgr2rgb`: If set, flip the color channel order of images.
|
||||
- `--window-size`: The shape of the display window. If not specified, it will be set to `12*7`. If used, it must be in the format `'W*H'`.
|
||||
- `--cfg-options` : Modifications to the configuration file, refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/en/latest/tutorials/config.html).
|
||||
|
||||
```{note}
|
||||
|
||||
1. If the `--mode` is not specified, it will be set to `concat` as default, get the pictures stitched together by original pictures and transformed pictures; if the `--mode` is set to `original`, get the original pictures; if the `--mode` is set to `pipeline`, get the transformed pictures.
|
||||
1. If the `--mode` is not specified, it will be set to `concat` as default, get the pictures stitched together by original pictures and transformed pictures; if the `--mode` is set to `original`, get the original pictures; if the `--mode` is set to `transformed`, get the transformed pictures; if the `--mode` is set to `pipeline`, get all the intermediate images through the pipeline.
|
||||
|
||||
2. When `--adaptive` option is set, images that are too large or too small will be automatically adjusted, you can use `--min-edge-length` and `--max-edge-length` to set the adjust size.
|
||||
```
|
||||
|
||||
**Examples**:
|
||||
|
||||
1. Visualize all the transformed pictures of the `ImageNet` training set and display them in pop-up windows:
|
||||
1. In **'original'** mode, visualize 100 original pictures in the `CIFAR100` validation set, then display and save them in the `./tmp` folder:
|
||||
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_8xb32_in1k.py --show --mode pipeline
|
||||
```
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/resnet/resnet50_8xb16_cifar100.py --phase val --output-dir tmp --mode original --number 100 --show --adaptive --bgr2rgb
|
||||
```
|
||||
|
||||
<div align=center><img src="../_static/image/tools/visualization/pipeline-pipeline.jpg" style=" width: auto; height: 40%; "></div>
|
||||
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146117528-1ec2d918-57f8-4ae4-8ca3-a8d31b602f64.jpg" style=" width: auto; height: 40%; "></div>
|
||||
|
||||
2. Visualize 10 comparison pictures in the `ImageNet` train set and save them in the `./tmp` folder:
|
||||
2. In **'transformed'** mode, visualize all the transformed pictures of the `ImageNet` training set and display them in pop-up windows:
|
||||
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
|
||||
```
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_8xb32_in1k.py --show --mode transformed
|
||||
```
|
||||
|
||||
<div align=center><img src="../_static/image/tools/visualization/pipeline-concat.jpg" style=" width: auto; height: 40%; "></div>
|
||||
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146117553-8006a4ba-e2fa-4f53-99bc-42a4b06e413f.jpg" style=" width: auto; height: 40%; "></div>
|
||||
|
||||
3. Visualize 100 original pictures in the `CIFAR100` validation set, then display and save them in the `./tmp` folder:
|
||||
3. In **'concat'** mode, visualize 10 pairs of origin and transformed images for comparison in the `ImageNet` train set and save them in the `./tmp` folder:
|
||||
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/resnet/resnet50_8xb16_cifar100.py --phase val --output-dir tmp --mode original --number 100 --show --adaptive --bgr2rgb
|
||||
```
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
|
||||
```
|
||||
|
||||
<div align=center><img src="../_static/image/tools/visualization/pipeline-original.jpg" style=" width: auto; height: 40%; "></div>
|
||||
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146128259-0a369991-7716-411d-8c27-c6863e6d76ea.JPEG" style=" width: auto; height: 40%; "></div>
|
||||
|
||||
4. In **'pipeline'** mode, visualize all the intermediate pictures in the `ImageNet` train set through the pipeline:
|
||||
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --adaptive --mode pipeline --show
|
||||
```
|
||||
|
||||
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146128201-eb97c2aa-a615-4a81-a649-38db1c315d0e.JPEG" style=" width: auto; height: 40%; "></div>
|
||||
|
||||
## Learning Rate Schedule Visualization
|
||||
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 44 KiB |
Binary file not shown.
Before Width: | Height: | Size: 9.2 KiB |
Binary file not shown.
Before Width: | Height: | Size: 19 KiB |
@ -10,8 +10,9 @@ MMClassification 和 MMCV 的适配关系如下,请安装正确版本的 MMCV
|
||||
|
||||
| MMClassification 版本 | MMCV 版本 |
|
||||
|:---------------------:|:---------------------:|
|
||||
| dev | mmcv>=1.4.4, <=1.5.0 |
|
||||
| 0.20.1 (master)| mmcv>=1.4.2, <=1.5.0 |
|
||||
| dev | mmcv>=1.4.6, <=1.5.0 |
|
||||
| 0.21.0 (master)| mmcv>=1.4.2, <=1.5.0 |
|
||||
| 0.20.1 | mmcv>=1.4.2, <=1.5.0 |
|
||||
| 0.19.0 | mmcv>=1.3.16, <=1.5.0 |
|
||||
| 0.18.0 | mmcv>=1.3.16, <=1.5.0 |
|
||||
| 0.17.0 | mmcv>=1.3.8, <=1.5.0 |
|
||||
|
@ -2,11 +2,10 @@
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [可视化](#可视化)
|
||||
- [数据流水线可视化](#数据流水线可视化)
|
||||
- [学习率策略可视化](#学习率策略可视化)
|
||||
- [类别激活图可视化](#类别激活图可视化)
|
||||
- [常见问题](#常见问题)
|
||||
- [数据流水线可视化](#数据流水线可视化)
|
||||
- [学习率策略可视化](#学习率策略可视化)
|
||||
- [类别激活图可视化](#类别激活图可视化)
|
||||
- [常见问题](#常见问题)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
@ -15,17 +14,18 @@
|
||||
```bash
|
||||
python tools/visualizations/vis_pipeline.py \
|
||||
${CONFIG_FILE} \
|
||||
--output-dir ${OUTPUT_DIR} \
|
||||
--phase ${DATASET_PHASE} \
|
||||
--number ${BUNBER_IMAGES_DISPLAY} \
|
||||
--skip-type ${SKIP_TRANSFORM_TYPE} \
|
||||
--mode ${DISPLAY_MODE} \
|
||||
--show \
|
||||
--adaptive \
|
||||
--min-edge-length ${MIN_EDGE_LENGTH} \
|
||||
--max-edge-length ${MAX_EDGE_LENGTH} \
|
||||
--bgr2rgb \
|
||||
--window-size ${WINDOW_SIZE}
|
||||
[--output-dir ${OUTPUT_DIR}] \
|
||||
[--phase ${DATASET_PHASE}] \
|
||||
[--number ${BUNBER_IMAGES_DISPLAY}] \
|
||||
[--skip-type ${SKIP_TRANSFORM_TYPE}] \
|
||||
[--mode ${DISPLAY_MODE}] \
|
||||
[--show] \
|
||||
[--adaptive] \
|
||||
[--min-edge-length ${MIN_EDGE_LENGTH}] \
|
||||
[--max-edge-length ${MAX_EDGE_LENGTH}] \
|
||||
[--bgr2rgb] \
|
||||
[--window-size ${WINDOW_SIZE}] \
|
||||
[--cfg-options ${CFG_OPTIONS}]
|
||||
```
|
||||
|
||||
**所有参数的说明**:
|
||||
@ -35,71 +35,80 @@ python tools/visualizations/vis_pipeline.py \
|
||||
- `--phase`: 可视化数据集的阶段,只能为 `[train, val, test]` 之一,默认为 `train`。
|
||||
- `--number`: 可视化样本数量。如果没有指定,默认展示数据集的所有图片。
|
||||
- `--skip-type`: 预设跳过的数据流水线过程。如果没有指定,默认为 `['ToTensor', 'Normalize', 'ImageToTensor', 'Collect']`。
|
||||
- `--mode`: 可视化的模式,只能为 `[original, pipeline, concat]` 之一,如果没有指定,默认为 `concat`。
|
||||
- `--mode`: 可视化的模式,只能为 `[original, transformed, concat, pipeline]` 之一,如果没有指定,默认为 `concat`。
|
||||
- `--show`: 将可视化图片以弹窗形式展示。
|
||||
- `--adaptive`: 自动调节可视化图片的大小。
|
||||
- `--min-edge-length`: 最短边长度,当使用了 `--adaptive` 时有效。 当图片任意边小于 `${MIN_EDGE_LENGTH}` 时,会保持长宽比不变放大图片,短边对齐至 `${MIN_EDGE_LENGTH}`,默认为200。
|
||||
- `--max-edge-length`: 最长边长度,当使用了 `--adaptive` 时有效。 当图片任意边大于 `${MAX_EDGE_LENGTH}` 时,会保持长宽比不变缩小图片,短边对齐至 `${MAX_EDGE_LENGTH}`,默认为1000。
|
||||
- `--bgr2rgb`: 将图片的颜色通道翻转。
|
||||
- `--window-size`: 可视化窗口大小,如果没有指定,默认为 `12*7`。如果需要指定,按照格式 `'W*H'`。
|
||||
- `--cfg-options` : 对配置文件的修改,参考[教程 1:如何编写配置文件](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html)。
|
||||
|
||||
```{note}
|
||||
|
||||
1. 如果不指定 `--mode`,默认设置为 `concat`,获取原始图片和预处理后图片拼接的图片;如果 `--mode` 设置为 `original`,则获取原始图片; 如果 `--mode` 设置为 `pipeline`,则获取预处理后的图片。
|
||||
1. 如果不指定 `--mode`,默认设置为 `concat`,获取原始图片和预处理后图片拼接的图片;如果 `--mode` 设置为 `original`,则获取原始图片;如果 `--mode` 设置为 `transformed`,则获取预处理后的图片;如果 `--mode` 设置为 `pipeline`,则获得数据流水线所有中间过程图片。
|
||||
|
||||
2. 当指定了 `--adaptive` 选项时,会自动的调整尺寸过大和过小的图片,你可以通过设定 `--min-edge-length` 与 `--max-edge-length` 来指定自动调整的图片尺寸。
|
||||
```
|
||||
|
||||
**示例**:
|
||||
|
||||
1. 可视化 `ImageNet` 训练集的所有经过预处理的图片,并以弹窗形式显示:
|
||||
1. **'original'** 模式,可视化 `CIFAR100` 验证集中的100张原始图片,显示并保存在 `./tmp` 文件夹下:
|
||||
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_8xb32_in1k.py --show --mode pipeline
|
||||
```
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/resnet/resnet50_8xb16_cifar100.py --phase val --output-dir tmp --mode original --number 100 --show --adaptive --bgr2rgb
|
||||
```
|
||||
|
||||
<div align=center><img src="../_static/image/tools/visualization/pipeline-pipeline.jpg" style=" width: auto; height: 40%; "></div>
|
||||
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146117528-1ec2d918-57f8-4ae4-8ca3-a8d31b602f64.jpg" style=" width: auto; height: 40%; "></div>
|
||||
|
||||
2. 可视化 `ImageNet` 训练集的10张原始图片与预处理后图片对比图,保存在 `./tmp` 文件夹下:
|
||||
2. **'transformed'** 模式,可视化 `ImageNet` 训练集的所有经过预处理的图片,并以弹窗形式显示:
|
||||
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
|
||||
```
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_8xb32_in1k.py --show --mode transformed
|
||||
```
|
||||
|
||||
<div align=center><img src="../_static/image/tools/visualization/pipeline-concat.jpg" style=" width: auto; height: 40%; "></div>
|
||||
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146117553-8006a4ba-e2fa-4f53-99bc-42a4b06e413f.jpg" style=" width: auto; height: 40%; "></div>
|
||||
|
||||
3. 可视化 `CIFAR100` 验证集中的100张原始图片,显示并保存在 `./tmp` 文件夹下:
|
||||
3. **'concat'** 模式,可视化 `ImageNet` 训练集的10张原始图片与预处理后图片对比图,保存在 `./tmp` 文件夹下:
|
||||
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/resnet/resnet50_8xb16_cifar100.py --phase val --output-dir tmp --mode original --number 100 --show --adaptive --bgr2rgb
|
||||
```
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
|
||||
```
|
||||
|
||||
<div align=center><img src="../_static/image/tools/visualization/pipeline-original.jpg" style=" width: auto; height: 40%; "></div>
|
||||
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146128259-0a369991-7716-411d-8c27-c6863e6d76ea.JPEG" style=" width: auto; height: 40%; "></div>
|
||||
|
||||
4. **'pipeline'** 模式,可视化 `ImageNet` 训练集经过数据流水线的过程图像:
|
||||
|
||||
```shell
|
||||
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --adaptive --mode pipeline --show
|
||||
```
|
||||
|
||||
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146128201-eb97c2aa-a615-4a81-a649-38db1c315d0e.JPEG" style=" width: auto; height: 40%; "></div>
|
||||
|
||||
## 学习率策略可视化
|
||||
|
||||
```bash
|
||||
python tools/visualizations/vis_lr.py \
|
||||
${CONFIG_FILE} \
|
||||
--dataset-size ${Dataset_Size} \
|
||||
--ngpus ${NUM_GPUs}
|
||||
--save-path ${SAVE_PATH} \
|
||||
--title ${TITLE} \
|
||||
--style ${STYLE} \
|
||||
--window-size ${WINDOW_SIZE}
|
||||
--cfg-options
|
||||
[--dataset-size ${Dataset_Size}] \
|
||||
[--ngpus ${NUM_GPUs}] \
|
||||
[--save-path ${SAVE_PATH}] \
|
||||
[--title ${TITLE}] \
|
||||
[--style ${STYLE}] \
|
||||
[--window-size ${WINDOW_SIZE}] \
|
||||
[--cfg-options ${CFG_OPTIONS}] \
|
||||
```
|
||||
|
||||
**所有参数的说明**:
|
||||
|
||||
- `config` : 模型配置文件的路径。
|
||||
- `dataset-size` : 数据集的大小。如果指定,`build_dataset` 将被跳过并使用这个大小作为数据集大小,默认使用 `build_dataset` 所得数据集的大小。
|
||||
- `ngpus` : 使用 GPU 的数量。
|
||||
- `save-path` : 保存的可视化图片的路径,默认不保存。
|
||||
- `title` : 可视化图片的标题,默认为配置文件名。
|
||||
- `style` : 可视化图片的风格,默认为 `whitegrid`。
|
||||
- `window-size`: 可视化窗口大小,如果没有指定,默认为 `12*7`。如果需要指定,按照格式 `'W*H'`。
|
||||
- `cfg-options` : 对配置文件的修改,参考[教程 1:如何编写配置文件](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html)。
|
||||
- `--dataset-size` : 数据集的大小。如果指定,`build_dataset` 将被跳过并使用这个大小作为数据集大小,默认使用 `build_dataset` 所得数据集的大小。
|
||||
- `--ngpus` : 使用 GPU 的数量。
|
||||
- `--save-path` : 保存的可视化图片的路径,默认不保存。
|
||||
- `--title` : 可视化图片的标题,默认为配置文件名。
|
||||
- `--style` : 可视化图片的风格,默认为 `whitegrid`。
|
||||
- `--window-size`: 可视化窗口大小,如果没有指定,默认为 `12*7`。如果需要指定,按照格式 `'W*H'`。
|
||||
- `--cfg-options` : 对配置文件的修改,参考[教程 1:如何编写配置文件](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html)。
|
||||
|
||||
```{note}
|
||||
|
||||
|
@ -6,30 +6,14 @@ import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import (DistSamplerSeedHook, build_optimizer, build_runner,
|
||||
get_dist_info)
|
||||
from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
|
||||
build_optimizer, build_runner, get_dist_info)
|
||||
from mmcv.runner.hooks import DistEvalHook, EvalHook
|
||||
|
||||
from mmcls.core import DistOptimizerHook
|
||||
from mmcls.datasets import build_dataloader, build_dataset
|
||||
from mmcls.utils import get_root_logger
|
||||
|
||||
# TODO import eval hooks from mmcv and delete them from mmcls
|
||||
try:
|
||||
from mmcv.runner.hooks import EvalHook, DistEvalHook
|
||||
except ImportError:
|
||||
warnings.warn('DeprecationWarning: EvalHook and DistEvalHook from mmcls '
|
||||
'will be deprecated.'
|
||||
'Please install mmcv through master branch.')
|
||||
from mmcls.core import EvalHook, DistEvalHook
|
||||
|
||||
# TODO import optimizer hook from mmcv and delete them from mmcls
|
||||
try:
|
||||
from mmcv.runner import Fp16OptimizerHook
|
||||
except ImportError:
|
||||
warnings.warn('DeprecationWarning: FP16OptimizerHook from mmcls will be '
|
||||
'deprecated. Please install mmcv>=1.1.4.')
|
||||
from mmcls.core import Fp16OptimizerHook
|
||||
|
||||
|
||||
def init_random_seed(seed=None, device='cuda'):
|
||||
"""Initialize random seed.
|
||||
@ -131,7 +115,7 @@ def train_model(model,
|
||||
else:
|
||||
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
|
||||
if not model.device_ids:
|
||||
from mmcv import digit_version, __version__
|
||||
from mmcv import __version__, digit_version
|
||||
assert digit_version(__version__) >= (1, 4, 4), \
|
||||
'To train with CPU, please confirm your mmcv version ' \
|
||||
'is not lower than v1.4.4'
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .evaluation import * # noqa: F401, F403
|
||||
from .fp16 import * # noqa: F401, F403
|
||||
from .hook import * # noqa: F401, F403
|
||||
from .optimizers import * # noqa: F401, F403
|
||||
from .utils import * # noqa: F401, F403
|
||||
|
@ -1,12 +1,10 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .eval_hooks import DistEvalHook, EvalHook
|
||||
from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
|
||||
precision_recall_f1, recall, support)
|
||||
from .mean_ap import average_precision, mAP
|
||||
from .multilabel_eval_metrics import average_performance
|
||||
|
||||
__all__ = [
|
||||
'DistEvalHook', 'EvalHook', 'precision', 'recall', 'f1_score', 'support',
|
||||
'average_precision', 'mAP', 'average_performance',
|
||||
'calculate_confusion_matrix', 'precision_recall_f1'
|
||||
'precision', 'recall', 'f1_score', 'support', 'average_precision', 'mAP',
|
||||
'average_performance', 'calculate_confusion_matrix', 'precision_recall_f1'
|
||||
]
|
||||
|
@ -1,107 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import warnings
|
||||
|
||||
from mmcv.runner import Hook
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class EvalHook(Hook):
|
||||
"""Evaluation hook.
|
||||
|
||||
Args:
|
||||
dataloader (DataLoader): A PyTorch dataloader.
|
||||
interval (int): Evaluation interval (by epochs). Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader, interval=1, by_epoch=True, **eval_kwargs):
|
||||
warnings.warn(
|
||||
'DeprecationWarning: EvalHook and DistEvalHook in mmcls will be '
|
||||
'deprecated, please install mmcv through master branch.')
|
||||
if not isinstance(dataloader, DataLoader):
|
||||
raise TypeError('dataloader must be a pytorch DataLoader, but got'
|
||||
f' {type(dataloader)}')
|
||||
self.dataloader = dataloader
|
||||
self.interval = interval
|
||||
self.eval_kwargs = eval_kwargs
|
||||
self.by_epoch = by_epoch
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
||||
return
|
||||
from mmcls.apis import single_gpu_test
|
||||
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
||||
self.evaluate(runner, results)
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
||||
return
|
||||
from mmcls.apis import single_gpu_test
|
||||
runner.log_buffer.clear()
|
||||
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
||||
self.evaluate(runner, results)
|
||||
|
||||
def evaluate(self, runner, results):
|
||||
eval_res = self.dataloader.dataset.evaluate(
|
||||
results, logger=runner.logger, **self.eval_kwargs)
|
||||
for name, val in eval_res.items():
|
||||
runner.log_buffer.output[name] = val
|
||||
runner.log_buffer.ready = True
|
||||
|
||||
|
||||
class DistEvalHook(EvalHook):
|
||||
"""Distributed evaluation hook.
|
||||
|
||||
Args:
|
||||
dataloader (DataLoader): A PyTorch dataloader.
|
||||
interval (int): Evaluation interval (by epochs). Default: 1.
|
||||
tmpdir (str, optional): Temporary directory to save the results of all
|
||||
processes. Default: None.
|
||||
gpu_collect (bool): Whether to use gpu or cpu to collect results.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataloader,
|
||||
interval=1,
|
||||
gpu_collect=False,
|
||||
by_epoch=True,
|
||||
**eval_kwargs):
|
||||
warnings.warn(
|
||||
'DeprecationWarning: EvalHook and DistEvalHook in mmcls will be '
|
||||
'deprecated, please install mmcv through master branch.')
|
||||
if not isinstance(dataloader, DataLoader):
|
||||
raise TypeError('dataloader must be a pytorch DataLoader, but got '
|
||||
f'{type(dataloader)}')
|
||||
self.dataloader = dataloader
|
||||
self.interval = interval
|
||||
self.gpu_collect = gpu_collect
|
||||
self.by_epoch = by_epoch
|
||||
self.eval_kwargs = eval_kwargs
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
||||
return
|
||||
from mmcls.apis import multi_gpu_test
|
||||
results = multi_gpu_test(
|
||||
runner.model,
|
||||
self.dataloader,
|
||||
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
||||
gpu_collect=self.gpu_collect)
|
||||
if runner.rank == 0:
|
||||
print('\n')
|
||||
self.evaluate(runner, results)
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
||||
return
|
||||
from mmcls.apis import multi_gpu_test
|
||||
runner.log_buffer.clear()
|
||||
results = multi_gpu_test(
|
||||
runner.model,
|
||||
self.dataloader,
|
||||
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
||||
gpu_collect=self.gpu_collect)
|
||||
if runner.rank == 0:
|
||||
print('\n')
|
||||
self.evaluate(runner, results)
|
@ -80,7 +80,7 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
|
||||
assert isinstance(pred, torch.Tensor), \
|
||||
(f'pred should be torch.Tensor or np.ndarray, but got {type(pred)}.')
|
||||
if isinstance(target, np.ndarray):
|
||||
target = torch.from_numpy(target)
|
||||
target = torch.from_numpy(target).long()
|
||||
assert isinstance(target, torch.Tensor), \
|
||||
f'target should be torch.Tensor or np.ndarray, ' \
|
||||
f'but got {type(target)}.'
|
||||
|
@ -1,5 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .decorators import auto_fp16, force_fp32
|
||||
from .hooks import Fp16OptimizerHook, wrap_fp16_model
|
||||
|
||||
__all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model']
|
@ -1,161 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import functools
|
||||
from inspect import getfullargspec
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import cast_tensor_type
|
||||
|
||||
|
||||
def auto_fp16(apply_to=None, out_fp32=False):
|
||||
"""Decorator to enable fp16 training automatically.
|
||||
|
||||
This decorator is useful when you write custom modules and want to support
|
||||
mixed precision training. If inputs arguments are fp32 tensors, they will
|
||||
be converted to fp16 automatically. Arguments other than fp32 tensors are
|
||||
ignored.
|
||||
|
||||
Args:
|
||||
apply_to (Iterable, optional): The argument names to be converted.
|
||||
`None` indicates all arguments.
|
||||
out_fp32 (bool): Whether to convert the output back to fp32.
|
||||
|
||||
:Example:
|
||||
|
||||
class MyModule1(nn.Module)
|
||||
|
||||
# Convert x and y to fp16
|
||||
@auto_fp16()
|
||||
def forward(self, x, y):
|
||||
pass
|
||||
|
||||
class MyModule2(nn.Module):
|
||||
|
||||
# convert pred to fp16
|
||||
@auto_fp16(apply_to=('pred', ))
|
||||
def do_something(self, pred, others):
|
||||
pass
|
||||
"""
|
||||
|
||||
def auto_fp16_wrapper(old_func):
|
||||
|
||||
@functools.wraps(old_func)
|
||||
def new_func(*args, **kwargs):
|
||||
# check if the module has set the attribute `fp16_enabled`, if not,
|
||||
# just fallback to the original method.
|
||||
if not isinstance(args[0], torch.nn.Module):
|
||||
raise TypeError('@auto_fp16 can only be used to decorate the '
|
||||
'method of nn.Module')
|
||||
if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
|
||||
return old_func(*args, **kwargs)
|
||||
# get the arg spec of the decorated method
|
||||
args_info = getfullargspec(old_func)
|
||||
# get the argument names to be casted
|
||||
args_to_cast = args_info.args if apply_to is None else apply_to
|
||||
# convert the args that need to be processed
|
||||
new_args = []
|
||||
# NOTE: default args are not taken into consideration
|
||||
if args:
|
||||
arg_names = args_info.args[:len(args)]
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if arg_name in args_to_cast:
|
||||
new_args.append(
|
||||
cast_tensor_type(args[i], torch.float, torch.half))
|
||||
else:
|
||||
new_args.append(args[i])
|
||||
# convert the kwargs that need to be processed
|
||||
new_kwargs = {}
|
||||
if kwargs:
|
||||
for arg_name, arg_value in kwargs.items():
|
||||
if arg_name in args_to_cast:
|
||||
new_kwargs[arg_name] = cast_tensor_type(
|
||||
arg_value, torch.float, torch.half)
|
||||
else:
|
||||
new_kwargs[arg_name] = arg_value
|
||||
# apply converted arguments to the decorated method
|
||||
output = old_func(*new_args, **new_kwargs)
|
||||
# cast the results back to fp32 if necessary
|
||||
if out_fp32:
|
||||
output = cast_tensor_type(output, torch.half, torch.float)
|
||||
return output
|
||||
|
||||
return new_func
|
||||
|
||||
return auto_fp16_wrapper
|
||||
|
||||
|
||||
def force_fp32(apply_to=None, out_fp16=False):
|
||||
"""Decorator to convert input arguments to fp32 in force.
|
||||
|
||||
This decorator is useful when you write custom modules and want to support
|
||||
mixed precision training. If there are some inputs that must be processed
|
||||
in fp32 mode, then this decorator can handle it. If inputs arguments are
|
||||
fp16 tensors, they will be converted to fp32 automatically. Arguments other
|
||||
than fp16 tensors are ignored.
|
||||
|
||||
Args:
|
||||
apply_to (Iterable, optional): The argument names to be converted.
|
||||
`None` indicates all arguments.
|
||||
out_fp16 (bool): Whether to convert the output back to fp16.
|
||||
|
||||
:Example:
|
||||
|
||||
class MyModule1(nn.Module)
|
||||
|
||||
# Convert x and y to fp32
|
||||
@force_fp32()
|
||||
def loss(self, x, y):
|
||||
pass
|
||||
|
||||
class MyModule2(nn.Module):
|
||||
|
||||
# convert pred to fp32
|
||||
@force_fp32(apply_to=('pred', ))
|
||||
def post_process(self, pred, others):
|
||||
pass
|
||||
"""
|
||||
|
||||
def force_fp32_wrapper(old_func):
|
||||
|
||||
@functools.wraps(old_func)
|
||||
def new_func(*args, **kwargs):
|
||||
# check if the module has set the attribute `fp16_enabled`, if not,
|
||||
# just fallback to the original method.
|
||||
if not isinstance(args[0], torch.nn.Module):
|
||||
raise TypeError('@force_fp32 can only be used to decorate the '
|
||||
'method of nn.Module')
|
||||
if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
|
||||
return old_func(*args, **kwargs)
|
||||
# get the arg spec of the decorated method
|
||||
args_info = getfullargspec(old_func)
|
||||
# get the argument names to be casted
|
||||
args_to_cast = args_info.args if apply_to is None else apply_to
|
||||
# convert the args that need to be processed
|
||||
new_args = []
|
||||
if args:
|
||||
arg_names = args_info.args[:len(args)]
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if arg_name in args_to_cast:
|
||||
new_args.append(
|
||||
cast_tensor_type(args[i], torch.half, torch.float))
|
||||
else:
|
||||
new_args.append(args[i])
|
||||
# convert the kwargs that need to be processed
|
||||
new_kwargs = dict()
|
||||
if kwargs:
|
||||
for arg_name, arg_value in kwargs.items():
|
||||
if arg_name in args_to_cast:
|
||||
new_kwargs[arg_name] = cast_tensor_type(
|
||||
arg_value, torch.half, torch.float)
|
||||
else:
|
||||
new_kwargs[arg_name] = arg_value
|
||||
# apply converted arguments to the decorated method
|
||||
output = old_func(*new_args, **new_kwargs)
|
||||
# cast the results back to fp32 if necessary
|
||||
if out_fp16:
|
||||
output = cast_tensor_type(output, torch.float, torch.half)
|
||||
return output
|
||||
|
||||
return new_func
|
||||
|
||||
return force_fp32_wrapper
|
@ -1,129 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import OptimizerHook
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..utils import allreduce_grads
|
||||
from .utils import cast_tensor_type
|
||||
|
||||
|
||||
class Fp16OptimizerHook(OptimizerHook):
|
||||
"""FP16 optimizer hook.
|
||||
|
||||
The steps of fp16 optimizer is as follows.
|
||||
1. Scale the loss value.
|
||||
2. BP in the fp16 model.
|
||||
2. Copy gradients from fp16 model to fp32 weights.
|
||||
3. Update fp32 weights.
|
||||
4. Copy updated parameters from fp32 weights to fp16 model.
|
||||
|
||||
Refer to https://arxiv.org/abs/1710.03740 for more details.
|
||||
|
||||
Args:
|
||||
loss_scale (float): Scale factor multiplied with loss.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
grad_clip=None,
|
||||
coalesce=True,
|
||||
bucket_size_mb=-1,
|
||||
loss_scale=512.,
|
||||
distributed=True):
|
||||
self.grad_clip = grad_clip
|
||||
self.coalesce = coalesce
|
||||
self.bucket_size_mb = bucket_size_mb
|
||||
self.loss_scale = loss_scale
|
||||
self.distributed = distributed
|
||||
|
||||
def before_run(self, runner):
|
||||
# keep a copy of fp32 weights
|
||||
runner.optimizer.param_groups = copy.deepcopy(
|
||||
runner.optimizer.param_groups)
|
||||
# convert model to fp16
|
||||
wrap_fp16_model(runner.model)
|
||||
|
||||
def copy_grads_to_fp32(self, fp16_net, fp32_weights):
|
||||
"""Copy gradients from fp16 model to fp32 weight copy."""
|
||||
for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()):
|
||||
if fp16_param.grad is not None:
|
||||
if fp32_param.grad is None:
|
||||
fp32_param.grad = fp32_param.data.new(fp32_param.size())
|
||||
fp32_param.grad.copy_(fp16_param.grad)
|
||||
|
||||
def copy_params_to_fp16(self, fp16_net, fp32_weights):
|
||||
"""Copy updated params from fp32 weight copy to fp16 model."""
|
||||
for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights):
|
||||
fp16_param.data.copy_(fp32_param.data)
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
# clear grads of last iteration
|
||||
runner.model.zero_grad()
|
||||
runner.optimizer.zero_grad()
|
||||
# scale the loss value
|
||||
scaled_loss = runner.outputs['loss'] * self.loss_scale
|
||||
scaled_loss.backward()
|
||||
# copy fp16 grads in the model to fp32 params in the optimizer
|
||||
fp32_weights = []
|
||||
for param_group in runner.optimizer.param_groups:
|
||||
fp32_weights += param_group['params']
|
||||
self.copy_grads_to_fp32(runner.model, fp32_weights)
|
||||
# allreduce grads
|
||||
if self.distributed:
|
||||
allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb)
|
||||
# scale the gradients back
|
||||
for param in fp32_weights:
|
||||
if param.grad is not None:
|
||||
param.grad.div_(self.loss_scale)
|
||||
if self.grad_clip is not None:
|
||||
self.clip_grads(fp32_weights)
|
||||
# update fp32 params
|
||||
runner.optimizer.step()
|
||||
# copy fp32 params to the fp16 model
|
||||
self.copy_params_to_fp16(runner.model, fp32_weights)
|
||||
|
||||
|
||||
def wrap_fp16_model(model):
|
||||
# convert model to fp16
|
||||
model.half()
|
||||
# patch the normalization layers to make it work in fp32 mode
|
||||
patch_norm_fp32(model)
|
||||
# set `fp16_enabled` flag
|
||||
for m in model.modules():
|
||||
if hasattr(m, 'fp16_enabled'):
|
||||
m.fp16_enabled = True
|
||||
|
||||
|
||||
def patch_norm_fp32(module):
|
||||
if isinstance(module, (_BatchNorm, nn.GroupNorm)):
|
||||
module.float()
|
||||
module.forward = patch_forward_method(module.forward, torch.half,
|
||||
torch.float)
|
||||
for child in module.children():
|
||||
patch_norm_fp32(child)
|
||||
return module
|
||||
|
||||
|
||||
def patch_forward_method(func, src_type, dst_type, convert_output=True):
|
||||
"""Patch the forward method of a module.
|
||||
|
||||
Args:
|
||||
func (callable): The original forward method.
|
||||
src_type (torch.dtype): Type of input arguments to be converted from.
|
||||
dst_type (torch.dtype): Type of input arguments to be converted to.
|
||||
convert_output (bool): Whether to convert the output back to src_type.
|
||||
|
||||
Returns:
|
||||
callable: The patched forward method.
|
||||
"""
|
||||
|
||||
def new_forward(*args, **kwargs):
|
||||
output = func(*cast_tensor_type(args, src_type, dst_type),
|
||||
**cast_tensor_type(kwargs, src_type, dst_type))
|
||||
if convert_output:
|
||||
output = cast_tensor_type(output, dst_type, src_type)
|
||||
return output
|
||||
|
||||
return new_forward
|
@ -1,24 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import abc
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def cast_tensor_type(inputs, src_type, dst_type):
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
return inputs.to(dst_type)
|
||||
elif isinstance(inputs, str):
|
||||
return inputs
|
||||
elif isinstance(inputs, np.ndarray):
|
||||
return inputs
|
||||
elif isinstance(inputs, abc.Mapping):
|
||||
return type(inputs)({
|
||||
k: cast_tensor_type(v, src_type, dst_type)
|
||||
for k, v in inputs.items()
|
||||
})
|
||||
elif isinstance(inputs, abc.Iterable):
|
||||
return type(inputs)(
|
||||
cast_tensor_type(item, src_type, dst_type) for item in inputs)
|
||||
else:
|
||||
return inputs
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .lamb import Lamb
|
||||
|
||||
__all__ = [
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .image import (BaseFigureContextManager, ImshowInfosContextManager,
|
||||
color_val_matplotlib, imshow_infos)
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
@ -59,7 +59,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
||||
"""Get all ground-truth labels (categories).
|
||||
|
||||
Returns:
|
||||
list[int]: categories for all images.
|
||||
np.ndarray: categories for all images.
|
||||
"""
|
||||
|
||||
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
|
||||
|
@ -25,10 +25,14 @@ SAMPLERS = Registry('sampler')
|
||||
|
||||
|
||||
def build_dataset(cfg, default_args=None):
|
||||
from .dataset_wrappers import (ConcatDataset, RepeatDataset,
|
||||
ClassBalancedDataset, KFoldDataset)
|
||||
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
|
||||
KFoldDataset, RepeatDataset)
|
||||
if isinstance(cfg, (list, tuple)):
|
||||
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
|
||||
elif cfg['type'] == 'ConcatDataset':
|
||||
dataset = ConcatDataset(
|
||||
[build_dataset(c, default_args) for c in cfg['datasets']],
|
||||
separate_eval=cfg.get('separate_eval', True))
|
||||
elif cfg['type'] == 'RepeatDataset':
|
||||
dataset = RepeatDataset(
|
||||
build_dataset(cfg['dataset'], default_args), cfg['times'])
|
||||
|
@ -4,6 +4,7 @@ import math
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from mmcv.utils import print_log
|
||||
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
||||
|
||||
from .builder import DATASETS
|
||||
@ -18,12 +19,23 @@ class ConcatDataset(_ConcatDataset):
|
||||
|
||||
Args:
|
||||
datasets (list[:obj:`Dataset`]): A list of datasets.
|
||||
separate_eval (bool): Whether to evaluate the results
|
||||
separately if it is used as validation dataset.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets):
|
||||
def __init__(self, datasets, separate_eval=True):
|
||||
super(ConcatDataset, self).__init__(datasets)
|
||||
self.separate_eval = separate_eval
|
||||
|
||||
self.CLASSES = datasets[0].CLASSES
|
||||
|
||||
if not separate_eval:
|
||||
if len(set([type(ds) for ds in datasets])) != 1:
|
||||
raise NotImplementedError(
|
||||
'To evaluate a concat dataset non-separately, '
|
||||
'all the datasets should have same types')
|
||||
|
||||
def get_cat_ids(self, idx):
|
||||
if idx < 0:
|
||||
if -idx > len(self):
|
||||
@ -37,6 +49,63 @@ class ConcatDataset(_ConcatDataset):
|
||||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
||||
return self.datasets[dataset_idx].get_cat_ids(sample_idx)
|
||||
|
||||
def evaluate(self, results, *args, indices=None, logger=None, **kwargs):
|
||||
"""Evaluate the results.
|
||||
|
||||
Args:
|
||||
results (list[list | tuple]): Testing results of the dataset.
|
||||
indices (list, optional): The indices of samples corresponding to
|
||||
the results. It's unavailable on ConcatDataset.
|
||||
Defaults to None.
|
||||
logger (logging.Logger | str, optional): Logger used for printing
|
||||
related information during evaluation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict[str: float]: AP results of the total dataset or each separate
|
||||
dataset if `self.separate_eval=True`.
|
||||
"""
|
||||
if indices is not None:
|
||||
raise NotImplementedError(
|
||||
'Use indices to evaluate speific samples in a ConcatDataset '
|
||||
'is not supported by now.')
|
||||
|
||||
assert len(results) == len(self), \
|
||||
('Dataset and results have different sizes: '
|
||||
f'{len(self)} v.s. {len(results)}')
|
||||
|
||||
# Check whether all the datasets support evaluation
|
||||
for dataset in self.datasets:
|
||||
assert hasattr(dataset, 'evaluate'), \
|
||||
f"{type(dataset)} haven't implemented the evaluate function."
|
||||
|
||||
if self.separate_eval:
|
||||
total_eval_results = dict()
|
||||
for dataset_idx, dataset in enumerate(self.datasets):
|
||||
start_idx = 0 if dataset_idx == 0 else \
|
||||
self.cumulative_sizes[dataset_idx-1]
|
||||
end_idx = self.cumulative_sizes[dataset_idx]
|
||||
|
||||
results_per_dataset = results[start_idx:end_idx]
|
||||
print_log(
|
||||
f'Evaluateing dataset-{dataset_idx} with '
|
||||
f'{len(results_per_dataset)} images now',
|
||||
logger=logger)
|
||||
|
||||
eval_results_per_dataset = dataset.evaluate(
|
||||
results_per_dataset, *args, logger=logger, **kwargs)
|
||||
for k, v in eval_results_per_dataset.items():
|
||||
total_eval_results.update({f'{dataset_idx}_{k}': v})
|
||||
|
||||
return total_eval_results
|
||||
else:
|
||||
original_data_infos = self.datasets[0].data_infos
|
||||
self.datasets[0].data_infos = sum(
|
||||
[dataset.data_infos for dataset in self.datasets], [])
|
||||
eval_results = self.datasets[0].evaluate(
|
||||
results, logger=logger, **kwargs)
|
||||
self.datasets[0].data_infos = original_data_infos
|
||||
return eval_results
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class RepeatDataset(object):
|
||||
@ -68,6 +137,20 @@ class RepeatDataset(object):
|
||||
def __len__(self):
|
||||
return self.times * self._ori_len
|
||||
|
||||
def evaluate(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'evaluate results on a repeated dataset is weird. '
|
||||
'Please inference and evaluate on the original dataset.')
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the number of instance number."""
|
||||
dataset_type = 'Test' if self.test_mode else 'Train'
|
||||
result = (
|
||||
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
|
||||
f'{dataset_type} dataset with total number of samples {len(self)}.'
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
|
||||
@DATASETS.register_module()
|
||||
@ -171,6 +254,20 @@ class ClassBalancedDataset(object):
|
||||
def __len__(self):
|
||||
return len(self.repeat_indices)
|
||||
|
||||
def evaluate(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'evaluate results on a class-balanced dataset is weird. '
|
||||
'Please inference and evaluate on the original dataset.')
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the number of instance number."""
|
||||
dataset_type = 'Test' if self.test_mode else 'Train'
|
||||
result = (
|
||||
f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
|
||||
f'{dataset_type} dataset with total number of samples {len(self)}.'
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class KFoldDataset:
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
@ -29,8 +28,7 @@ class MultiLabelDataset(BaseDataset):
|
||||
metric='mAP',
|
||||
metric_options=None,
|
||||
indices=None,
|
||||
logger=None,
|
||||
**deprecated_kwargs):
|
||||
logger=None):
|
||||
"""Evaluate the dataset.
|
||||
|
||||
Args:
|
||||
@ -42,7 +40,6 @@ class MultiLabelDataset(BaseDataset):
|
||||
Allowed keys are 'k' and 'thr'. Defaults to None
|
||||
logger (logging.Logger | str, optional): Logger used for printing
|
||||
related information during evaluation. Defaults to None.
|
||||
deprecated_kwargs (dict): Used for containing deprecated arguments.
|
||||
|
||||
Returns:
|
||||
dict: evaluation results
|
||||
@ -50,11 +47,6 @@ class MultiLabelDataset(BaseDataset):
|
||||
if metric_options is None or metric_options == {}:
|
||||
metric_options = {'thr': 0.5}
|
||||
|
||||
if deprecated_kwargs != {}:
|
||||
warnings.warn('Option arguments for metrics has been changed to '
|
||||
'`metric_options`.')
|
||||
metric_options = {**deprecated_kwargs}
|
||||
|
||||
if isinstance(metric, str):
|
||||
metrics = [metric]
|
||||
else:
|
||||
|
@ -1,9 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# flake8: noqa
|
||||
import warnings
|
||||
|
||||
from .formatting import *
|
||||
|
||||
warnings.warn('DeprecationWarning: mmcls.datasets.pipelines.formating will be '
|
||||
'deprecated in 2021, please replace it with '
|
||||
'mmcls.datasets.pipelines.formatting.')
|
@ -13,7 +13,7 @@ from .regnet import RegNet
|
||||
from .repvgg import RepVGG
|
||||
from .res2net import Res2Net
|
||||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1d
|
||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||
from .resnet_cifar import ResNet_CIFAR
|
||||
from .resnext import ResNeXt
|
||||
from .seresnet import SEResNet
|
||||
@ -34,5 +34,5 @@ __all__ = [
|
||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
|
||||
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
|
||||
'EfficientNet', 'ConvNeXt', 'HRNet'
|
||||
'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c'
|
||||
]
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
@ -5,6 +6,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import DropPath
|
||||
from mmcv.cnn.bricks.transformer import AdaptivePadding
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.utils import get_root_logger
|
||||
@ -438,9 +440,16 @@ class Conformer(BaseBackbone):
|
||||
self.maxpool = nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56]
|
||||
|
||||
assert patch_size % 16 == 0, 'The patch size of Conformer must ' \
|
||||
'be divisible by 16.'
|
||||
trans_down_stride = patch_size // 4
|
||||
|
||||
# To solve the issue #680
|
||||
# Auto pad the feature map to be divisible by trans_down_stride
|
||||
self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride)
|
||||
|
||||
# 1 stage
|
||||
stage1_channels = int(base_channels * self.channel_ratio)
|
||||
trans_down_stride = patch_size // 4
|
||||
self.conv_1 = ConvBlock(
|
||||
in_channels=64,
|
||||
out_channels=stage1_channels,
|
||||
@ -587,6 +596,7 @@ class Conformer(BaseBackbone):
|
||||
|
||||
# stem
|
||||
x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
|
||||
x_base = self.auto_pad(x_base)
|
||||
|
||||
# 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56]
|
||||
x = self.conv_1(x_base, out_conv2=False)
|
||||
|
@ -15,21 +15,38 @@ class DistilledVisionTransformer(VisionTransformer):
|
||||
distillation through attention <https://arxiv.org/abs/2012.12877>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture
|
||||
Default: 'b'
|
||||
img_size (int | tuple): Input image size
|
||||
patch_size (int | tuple): The patch size
|
||||
arch (str | dict): Vision Transformer architecture. If use string,
|
||||
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
|
||||
and 'deit-base'. If use dict, it should have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of transformer encoder layers.
|
||||
- **num_heads** (int): The number of heads in attention modules.
|
||||
- **feedforward_channels** (int): The hidden dimensions in
|
||||
feedforward modules.
|
||||
|
||||
Defaults to 'deit-base'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Defaults to True.
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Defaults to "bicubic".
|
||||
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
|
||||
@ -40,22 +57,31 @@ class DistilledVisionTransformer(VisionTransformer):
|
||||
"""
|
||||
num_extra_tokens = 2 # cls_token, dist_token
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistilledVisionTransformer, self).__init__(*args, **kwargs)
|
||||
def __init__(self, arch='deit-base', *args, **kwargs):
|
||||
super(DistilledVisionTransformer, self).__init__(
|
||||
arch=arch, *args, **kwargs)
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
patch_resolution = self.patch_embed.patches_resolution
|
||||
x, patch_resolution = self.patch_embed(x)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
dist_token = self.dist_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
x = x + self.resize_pos_embed(
|
||||
self.pos_embed,
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 2:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
@ -65,10 +91,16 @@ class DistilledVisionTransformer(VisionTransformer):
|
||||
|
||||
if i in self.out_indices:
|
||||
B, _, C = x.shape
|
||||
patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
dist_token = x[:, 1]
|
||||
if self.with_cls_token:
|
||||
patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
dist_token = x[:, 1]
|
||||
else:
|
||||
patch_token = x.reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = None
|
||||
dist_token = None
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token, dist_token]
|
||||
else:
|
||||
|
@ -3,11 +3,11 @@ from typing import Sequence
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import PatchEmbed, to_2tuple
|
||||
from ..utils import to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
@ -105,10 +105,20 @@ class MlpMixer(BaseBackbone):
|
||||
<https://arxiv.org/pdf/2105.01601.pdf>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): MLP Mixer architecture
|
||||
Defaults to 'b'.
|
||||
img_size (int | tuple): Input image size.
|
||||
patch_size (int | tuple): The patch size.
|
||||
arch (str | dict): MLP Mixer architecture. If use string, choose from
|
||||
'small', 'base' and 'large'. If use dict, it should have below
|
||||
keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of MLP blocks.
|
||||
- **tokens_mlp_dims** (int): The hidden dimensions for tokens FFNs.
|
||||
- **channels_mlp_dims** (int): The The hidden dimensions for
|
||||
channels FFNs.
|
||||
|
||||
Defaults to 'base'.
|
||||
img_size (int | tuple): The input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
out_indices (Sequence | int): Output from which layer.
|
||||
Defaults to -1, means the last layer.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
@ -149,7 +159,7 @@ class MlpMixer(BaseBackbone):
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch='b',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
out_indices=-1,
|
||||
@ -184,14 +194,16 @@ class MlpMixer(BaseBackbone):
|
||||
self.img_size = to_2tuple(img_size)
|
||||
|
||||
_patch_cfg = dict(
|
||||
img_size=img_size,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_cfg=dict(
|
||||
type='Conv2d', kernel_size=patch_size, stride=patch_size),
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
@ -232,7 +244,10 @@ class MlpMixer(BaseBackbone):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
assert x.shape[2:] == self.img_size, \
|
||||
"The MLP-Mixer doesn't support dynamic input shape. " \
|
||||
f'Please input images with shape {self.img_size}'
|
||||
x, _ = self.patch_embed(x)
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
|
@ -653,6 +653,22 @@ class ResNet(BaseBackbone):
|
||||
m.eval()
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNetV1c(ResNet):
|
||||
"""ResNetV1c backbone.
|
||||
|
||||
This variant is described in `Bag of Tricks.
|
||||
<https://arxiv.org/pdf/1812.01187.pdf>`_.
|
||||
|
||||
Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
|
||||
in the input stem with three 3x3 convs.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(ResNetV1c, self).__init__(
|
||||
deep_stem=True, avg_down=False, **kwargs)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNetV1d(ResNet):
|
||||
"""ResNetV1d backbone.
|
||||
|
@ -2,17 +2,18 @@
|
||||
from copy import deepcopy
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import PatchEmbed, PatchMerging, ShiftWindowMSA
|
||||
from ..utils import ShiftWindowMSA, resize_pos_embed, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
@ -21,45 +22,41 @@ class SwinBlock(BaseModule):
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input feature
|
||||
map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int, optional): The height and width of the window.
|
||||
Defaults to 7.
|
||||
shift (bool, optional): Shift the attention window or not.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
shift (bool): Shift the attention window or not. Defaults to False.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
drop_path (float): The drop path rate after attention and ffn.
|
||||
Defaults to 0.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
ffn_ratio (float, optional): The expansion ratio of feedforward network
|
||||
hidden layer channels. Defaults to 4.
|
||||
drop_path (float, optional): The drop path rate after attention and
|
||||
ffn. Defaults to 0.
|
||||
attn_cfgs (dict, optional): The extra config of Shift Window-MSA.
|
||||
attn_cfgs (dict): The extra config of Shift Window-MSA.
|
||||
Defaults to empty dict.
|
||||
ffn_cfgs (dict, optional): The extra config of FFN.
|
||||
Defaults to empty dict.
|
||||
norm_cfg (dict, optional): The config of norm layers.
|
||||
Defaults to dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Defaults to False.
|
||||
auto_pad (bool, optional): Auto pad the feature map to be divisible by
|
||||
window_size, Defaults to False.
|
||||
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
|
||||
norm_cfg (dict): The config of norm layers.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift=False,
|
||||
ffn_ratio=4.,
|
||||
drop_path=0.,
|
||||
pad_small_map=False,
|
||||
attn_cfgs=dict(),
|
||||
ffn_cfgs=dict(),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
auto_pad=False,
|
||||
init_cfg=None):
|
||||
|
||||
super(SwinBlock, self).__init__(init_cfg)
|
||||
@ -67,12 +64,11 @@ class SwinBlock(BaseModule):
|
||||
|
||||
_attn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'input_resolution': input_resolution,
|
||||
'num_heads': num_heads,
|
||||
'shift_size': window_size // 2 if shift else 0,
|
||||
'window_size': window_size,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'auto_pad': auto_pad,
|
||||
'pad_small_map': pad_small_map,
|
||||
**attn_cfgs
|
||||
}
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
@ -90,12 +86,12 @@ class SwinBlock(BaseModule):
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.ffn = FFN(**_ffn_cfgs)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
x = self.norm1(x)
|
||||
x = self.attn(x)
|
||||
x = self.attn(x, hw_shape)
|
||||
x = x + identity
|
||||
|
||||
identity = x
|
||||
@ -117,38 +113,39 @@ class SwinBlockSequence(BaseModule):
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input feature
|
||||
map.
|
||||
depth (int): Number of successive swin transformer blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
downsample (bool, optional): Downsample the output of blocks by patch
|
||||
merging. Defaults to False.
|
||||
downsample_cfg (dict, optional): The extra config of the patch merging
|
||||
layer. Defaults to empty dict.
|
||||
drop_paths (Sequence[float] | float, optional): The drop path rate in
|
||||
each block. Defaults to 0.
|
||||
block_cfgs (Sequence[dict] | dict, optional): The extra config of each
|
||||
block. Defaults to empty dicts.
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
downsample (bool): Downsample the output of blocks by patch merging.
|
||||
Defaults to False.
|
||||
downsample_cfg (dict): The extra config of the patch merging layer.
|
||||
Defaults to empty dict.
|
||||
drop_paths (Sequence[float] | float): The drop path rate in each block.
|
||||
Defaults to 0.
|
||||
block_cfgs (Sequence[dict] | dict): The extra config of each block.
|
||||
Defaults to empty dicts.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
auto_pad (bool, optional): Auto pad the feature map to be divisible by
|
||||
window_size, Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
input_resolution,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
downsample=False,
|
||||
downsample_cfg=dict(),
|
||||
drop_paths=0.,
|
||||
block_cfgs=dict(),
|
||||
with_cp=False,
|
||||
auto_pad=False,
|
||||
pad_small_map=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
@ -159,17 +156,16 @@ class SwinBlockSequence(BaseModule):
|
||||
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.input_resolution = input_resolution
|
||||
self.blocks = ModuleList()
|
||||
for i in range(depth):
|
||||
_block_cfg = {
|
||||
'embed_dims': embed_dims,
|
||||
'input_resolution': input_resolution,
|
||||
'num_heads': num_heads,
|
||||
'window_size': window_size,
|
||||
'shift': False if i % 2 == 0 else True,
|
||||
'drop_path': drop_paths[i],
|
||||
'with_cp': with_cp,
|
||||
'auto_pad': auto_pad,
|
||||
'pad_small_map': pad_small_map,
|
||||
**block_cfgs[i]
|
||||
}
|
||||
block = SwinBlock(**_block_cfg)
|
||||
@ -177,9 +173,8 @@ class SwinBlockSequence(BaseModule):
|
||||
|
||||
if downsample:
|
||||
_downsample_cfg = {
|
||||
'input_resolution': input_resolution,
|
||||
'in_channels': embed_dims,
|
||||
'expansion_ratio': 2,
|
||||
'out_channels': 2 * embed_dims,
|
||||
'norm_cfg': dict(type='LN'),
|
||||
**downsample_cfg
|
||||
}
|
||||
@ -187,20 +182,15 @@ class SwinBlockSequence(BaseModule):
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, in_shape):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = block(x, in_shape)
|
||||
|
||||
if self.downsample:
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def out_resolution(self):
|
||||
if self.downsample:
|
||||
return self.downsample.output_resolution
|
||||
x, out_shape = self.downsample(x, in_shape)
|
||||
else:
|
||||
return self.input_resolution
|
||||
out_shape = in_shape
|
||||
return x, out_shape
|
||||
|
||||
@property
|
||||
def out_channels(self):
|
||||
@ -212,7 +202,8 @@ class SwinBlockSequence(BaseModule):
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class SwinTransformer(BaseBackbone):
|
||||
""" Swin Transformer
|
||||
"""Swin Transformer.
|
||||
|
||||
A PyTorch implement of : `Swin Transformer:
|
||||
Hierarchical Vision Transformer using Shifted Windows
|
||||
<https://arxiv.org/abs/2103.14030>`_
|
||||
@ -221,34 +212,47 @@ class SwinTransformer(BaseBackbone):
|
||||
https://github.com/microsoft/Swin-Transformer
|
||||
|
||||
Args:
|
||||
arch (str | dict): Swin Transformer architecture
|
||||
Defaults to 'T'.
|
||||
img_size (int | tuple): The size of input image.
|
||||
Defaults to 224.
|
||||
in_channels (int): The num of input channels.
|
||||
Defaults to 3.
|
||||
drop_rate (float): Dropout rate after embedding.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate.
|
||||
Defaults to 0.1.
|
||||
arch (str | dict): Swin Transformer architecture. If use string, choose
|
||||
from 'tiny', 'small', 'base' and 'large'. If use dict, it should
|
||||
have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **depths** (List[int]): The number of blocks in each stage.
|
||||
- **num_heads** (List[int]): The number of heads in attention
|
||||
modules of each stage.
|
||||
|
||||
Defaults to 'tiny'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 4.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
drop_rate (float): Dropout rate after embedding. Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
|
||||
use_abs_pos_embed (bool): If True, add absolute position embedding to
|
||||
the patch embedding. Defaults to False.
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Defaults to False.
|
||||
interpolate_mode (str): Select the interpolate mode for absolute
|
||||
position embeding vector resize. Defaults to "bicubic".
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
auto_pad (bool): If True, auto pad feature map to fit window_size.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer at end
|
||||
of backone. Defaults to dict(type='LN')
|
||||
stage_cfgs (Sequence | dict, optional): Extra config dict for each
|
||||
stage. Defaults to empty dict.
|
||||
patch_cfg (dict, optional): Extra config dict for patch embedding.
|
||||
Defaults to empty dict.
|
||||
norm_cfg (dict): Config dict for normalization layer for all output
|
||||
features. Defaults to ``dict(type='LN')``
|
||||
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
|
||||
stage. Defaults to an empty dict.
|
||||
patch_cfg (dict): Extra config dict for patch embedding.
|
||||
Defaults to an empty dict.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
|
||||
@ -258,8 +262,7 @@ class SwinTransformer(BaseBackbone):
|
||||
>>> extra_config = dict(
|
||||
>>> arch='tiny',
|
||||
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
|
||||
>>> 'expansion_ratio': 3}),
|
||||
>>> auto_pad=True)
|
||||
>>> 'expansion_ratio': 3}))
|
||||
>>> self = SwinTransformer(**extra_config)
|
||||
>>> inputs = torch.rand(1, 3, 224, 224)
|
||||
>>> output = self.forward(inputs)
|
||||
@ -285,25 +288,29 @@ class SwinTransformer(BaseBackbone):
|
||||
'num_heads': [6, 12, 24, 48]}),
|
||||
} # yapf: disable
|
||||
|
||||
_version = 2
|
||||
_version = 3
|
||||
num_extra_tokens = 0
|
||||
|
||||
def __init__(self,
|
||||
arch='T',
|
||||
arch='tiny',
|
||||
img_size=224,
|
||||
patch_size=4,
|
||||
in_channels=3,
|
||||
window_size=7,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
out_indices=(3, ),
|
||||
use_abs_pos_embed=False,
|
||||
auto_pad=False,
|
||||
interpolate_mode='bicubic',
|
||||
with_cp=False,
|
||||
frozen_stages=-1,
|
||||
norm_eval=False,
|
||||
pad_small_map=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
stage_cfgs=dict(),
|
||||
patch_cfg=dict(),
|
||||
init_cfg=None):
|
||||
super(SwinTransformer, self).__init__(init_cfg)
|
||||
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
@ -311,7 +318,7 @@ class SwinTransformer(BaseBackbone):
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {'embed_dims', 'depths', 'num_head'}
|
||||
essential_keys = {'embed_dims', 'depths', 'num_heads'}
|
||||
assert isinstance(arch, dict) and set(arch) == essential_keys, \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
@ -322,26 +329,28 @@ class SwinTransformer(BaseBackbone):
|
||||
self.num_layers = len(self.depths)
|
||||
self.out_indices = out_indices
|
||||
self.use_abs_pos_embed = use_abs_pos_embed
|
||||
self.auto_pad = auto_pad
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.frozen_stages = frozen_stages
|
||||
self.num_extra_tokens = 0
|
||||
|
||||
_patch_cfg = {
|
||||
'img_size': img_size,
|
||||
'in_channels': in_channels,
|
||||
'embed_dims': self.embed_dims,
|
||||
'conv_cfg': dict(type='Conv2d', kernel_size=4, stride=4),
|
||||
'norm_cfg': dict(type='LN'),
|
||||
**patch_cfg
|
||||
}
|
||||
_patch_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
norm_cfg=dict(type='LN'),
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
patches_resolution = self.patch_embed.patches_resolution
|
||||
self.patches_resolution = patches_resolution
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches, self.embed_dims))
|
||||
self._register_load_state_dict_pre_hook(
|
||||
self._prepare_abs_pos_embed)
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
self.norm_eval = norm_eval
|
||||
@ -354,7 +363,6 @@ class SwinTransformer(BaseBackbone):
|
||||
|
||||
self.stages = ModuleList()
|
||||
embed_dims = [self.embed_dims]
|
||||
input_resolution = patches_resolution
|
||||
for i, (depth,
|
||||
num_heads) in enumerate(zip(self.depths, self.num_heads)):
|
||||
if isinstance(stage_cfgs, Sequence):
|
||||
@ -366,11 +374,11 @@ class SwinTransformer(BaseBackbone):
|
||||
'embed_dims': embed_dims[-1],
|
||||
'depth': depth,
|
||||
'num_heads': num_heads,
|
||||
'window_size': window_size,
|
||||
'downsample': downsample,
|
||||
'input_resolution': input_resolution,
|
||||
'drop_paths': dpr[:depth],
|
||||
'with_cp': with_cp,
|
||||
'auto_pad': auto_pad,
|
||||
'pad_small_map': pad_small_map,
|
||||
**stage_cfg
|
||||
}
|
||||
|
||||
@ -379,7 +387,6 @@ class SwinTransformer(BaseBackbone):
|
||||
|
||||
dpr = dpr[depth:]
|
||||
embed_dims.append(stage.out_channels)
|
||||
input_resolution = stage.out_resolution
|
||||
|
||||
for i in out_indices:
|
||||
if norm_cfg is not None:
|
||||
@ -401,18 +408,20 @@ class SwinTransformer(BaseBackbone):
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = x + resize_pos_embed(
|
||||
self.absolute_pos_embed, self.patch_resolution, hw_shape,
|
||||
self.interpolate_mode, self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = stage(x)
|
||||
x, hw_shape = stage(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *stage.out_resolution,
|
||||
out = out.view(-1, *hw_shape,
|
||||
stage.out_channels).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
@ -433,6 +442,12 @@ class SwinTransformer(BaseBackbone):
|
||||
convert_key = k.replace('norm.', f'norm{final_stage_num}.')
|
||||
state_dict[convert_key] = state_dict[k]
|
||||
del state_dict[k]
|
||||
if (version is None
|
||||
or version < 3) and self.__class__ is SwinTransformer:
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
for k in state_dict_keys:
|
||||
if 'attn_mask' in k:
|
||||
del state_dict[k]
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata,
|
||||
*args, **kwargs)
|
||||
@ -461,3 +476,26 @@ class SwinTransformer(BaseBackbone):
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
|
||||
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'absolute_pos_embed'
|
||||
if name not in state_dict.keys():
|
||||
return
|
||||
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.info(
|
||||
'Resize the absolute_pos_embed shape from '
|
||||
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
|
||||
|
||||
ckpt_pos_embed_shape = to_2tuple(
|
||||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.patch_embed.init_out_size
|
||||
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
|
@ -11,7 +11,7 @@ from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import MultiheadAttention
|
||||
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
@ -173,26 +173,44 @@ class T2TModule(BaseModule):
|
||||
raise NotImplementedError("Performer hasn't been implemented.")
|
||||
|
||||
# there are 3 soft split, stride are 4,2,2 separately
|
||||
self.num_patches = (img_size // (4 * 2 * 2))**2
|
||||
out_side = img_size // (4 * 2 * 2)
|
||||
self.init_out_size = [out_side, out_side]
|
||||
self.num_patches = out_side**2
|
||||
|
||||
@staticmethod
|
||||
def _get_unfold_size(unfold: nn.Unfold, input_size):
|
||||
h, w = input_size
|
||||
kernel_size = to_2tuple(unfold.kernel_size)
|
||||
stride = to_2tuple(unfold.stride)
|
||||
padding = to_2tuple(unfold.padding)
|
||||
dilation = to_2tuple(unfold.dilation)
|
||||
|
||||
h_out = (h + 2 * padding[0] - dilation[0] *
|
||||
(kernel_size[0] - 1) - 1) // stride[0] + 1
|
||||
w_out = (w + 2 * padding[1] - dilation[1] *
|
||||
(kernel_size[1] - 1) - 1) // stride[1] + 1
|
||||
return (h_out, w_out)
|
||||
|
||||
def forward(self, x):
|
||||
# step0: soft split
|
||||
hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
|
||||
x = self.soft_split0(x).transpose(1, 2)
|
||||
|
||||
for step in [1, 2]:
|
||||
# re-structurization/reconstruction
|
||||
attn = getattr(self, f'attention{step}')
|
||||
x = attn(x).transpose(1, 2)
|
||||
B, C, new_HW = x.shape
|
||||
x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
|
||||
B, C, _ = x.shape
|
||||
x = x.reshape(B, C, hw_shape[0], hw_shape[1])
|
||||
|
||||
# soft split
|
||||
soft_split = getattr(self, f'soft_split{step}')
|
||||
hw_shape = self._get_unfold_size(soft_split, hw_shape)
|
||||
x = soft_split(x).transpose(1, 2)
|
||||
|
||||
# final tokens
|
||||
x = self.project(x)
|
||||
return x
|
||||
return x, hw_shape
|
||||
|
||||
|
||||
def get_sinusoid_encoding(n_position, embed_dims):
|
||||
@ -231,43 +249,52 @@ class T2T_ViT(BaseBackbone):
|
||||
Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_
|
||||
|
||||
Args:
|
||||
img_size (int): Input image size.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
in_channels (int): Number of input channels.
|
||||
embed_dims (int): Embedding dimension.
|
||||
t2t_cfg (dict): Extra config of Tokens-to-Token module.
|
||||
Defaults to an empty dict.
|
||||
drop_rate (float): Dropout rate after position embedding.
|
||||
Defaults to 0.
|
||||
num_layers (int): Num of transformer layers in encoder.
|
||||
Defaults to 14.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
drop_rate (float): Dropout rate after position embedding.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
norm_cfg (dict): Config dict for normalization layer. Defaults to
|
||||
``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token.
|
||||
Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Defaults to "bicubic".
|
||||
t2t_cfg (dict): Extra config of Tokens-to-Token module.
|
||||
Defaults to an empty dict.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
num_extra_tokens = 1 # cls_token
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
t2t_cfg=dict(),
|
||||
drop_rate=0.,
|
||||
num_layers=14,
|
||||
out_indices=-1,
|
||||
layer_cfgs=dict(),
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
final_norm=True,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True,
|
||||
interpolate_mode='bicubic',
|
||||
t2t_cfg=dict(),
|
||||
layer_cfgs=dict(),
|
||||
init_cfg=None):
|
||||
super(T2T_ViT, self).__init__(init_cfg)
|
||||
|
||||
@ -277,30 +304,41 @@ class T2T_ViT(BaseBackbone):
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
**t2t_cfg)
|
||||
num_patches = self.tokens_to_token.num_patches
|
||||
self.patch_resolution = self.tokens_to_token.init_out_size
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
# Class token
|
||||
# Set cls token
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.num_extra_tokens = 1
|
||||
|
||||
# Position Embedding
|
||||
sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
|
||||
# Set position embedding
|
||||
self.interpolate_mode = interpolate_mode
|
||||
sinusoid_table = get_sinusoid_encoding(
|
||||
num_patches + self.num_extra_tokens, embed_dims)
|
||||
self.register_buffer('pos_embed', sinusoid_table)
|
||||
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'"out_indices" must be a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = num_layers + index
|
||||
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
|
||||
assert 0 <= out_indices[i] <= num_layers, \
|
||||
f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
|
||||
|
||||
self.encoder = ModuleList()
|
||||
for i in range(num_layers):
|
||||
if isinstance(layer_cfgs, Sequence):
|
||||
@ -336,17 +374,49 @@ class T2T_ViT(BaseBackbone):
|
||||
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
||||
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'pos_embed'
|
||||
if name not in state_dict.keys():
|
||||
return
|
||||
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
if self.pos_embed.shape != ckpt_pos_embed_shape:
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.info(
|
||||
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
|
||||
f'to {self.pos_embed.shape}.')
|
||||
|
||||
ckpt_pos_embed_shape = to_2tuple(
|
||||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.tokens_to_token.init_out_size
|
||||
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.tokens_to_token(x)
|
||||
num_patches = self.tokens_to_token.num_patches
|
||||
patch_resolution = [int(np.sqrt(num_patches))] * 2
|
||||
x, patch_resolution = self.tokens_to_token(x)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
|
||||
x = x + resize_pos_embed(
|
||||
self.pos_embed,
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.encoder):
|
||||
x = layer(x)
|
||||
@ -356,9 +426,14 @@ class T2T_ViT(BaseBackbone):
|
||||
|
||||
if i in self.out_indices:
|
||||
B, _, C = x.shape
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
if self.with_cls_token:
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
else:
|
||||
patch_token = x.reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = None
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token]
|
||||
else:
|
||||
|
@ -4,15 +4,14 @@ from typing import Sequence
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
|
||||
from mmcls.utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import MultiheadAttention, PatchEmbed, to_2tuple
|
||||
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
@ -108,21 +107,38 @@ class VisionTransformer(BaseBackbone):
|
||||
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture
|
||||
Default: 'b'
|
||||
img_size (int | tuple): Input image size
|
||||
patch_size (int | tuple): The patch size
|
||||
arch (str | dict): Vision Transformer architecture. If use string,
|
||||
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
|
||||
and 'deit-base'. If use dict, it should have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of transformer encoder layers.
|
||||
- **num_heads** (int): The number of heads in attention modules.
|
||||
- **feedforward_channels** (int): The hidden dimensions in
|
||||
feedforward modules.
|
||||
|
||||
Defaults to 'base'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Defaults to True.
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Defaults to "bicubic".
|
||||
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
|
||||
@ -138,7 +154,6 @@ class VisionTransformer(BaseBackbone):
|
||||
'num_layers': 8,
|
||||
'num_heads': 8,
|
||||
'feedforward_channels': 768 * 3,
|
||||
'qkv_bias': False
|
||||
}),
|
||||
**dict.fromkeys(
|
||||
['b', 'base'], {
|
||||
@ -180,14 +195,17 @@ class VisionTransformer(BaseBackbone):
|
||||
num_extra_tokens = 1 # cls_token
|
||||
|
||||
def __init__(self,
|
||||
arch='b',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
out_indices=-1,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
qkv_bias=True,
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
final_norm=True,
|
||||
with_cls_token=True,
|
||||
output_cls_token=True,
|
||||
interpolate_mode='bicubic',
|
||||
patch_cfg=dict(),
|
||||
@ -214,16 +232,23 @@ class VisionTransformer(BaseBackbone):
|
||||
|
||||
# Set patch embedding
|
||||
_patch_cfg = dict(
|
||||
img_size=img_size,
|
||||
in_channels=in_channels,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_cfg=dict(
|
||||
type='Conv2d', kernel_size=patch_size, stride=patch_size),
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
# Set cls token
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
|
||||
@ -232,6 +257,8 @@ class VisionTransformer(BaseBackbone):
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + self.num_extra_tokens,
|
||||
self.embed_dims))
|
||||
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
@ -242,11 +269,12 @@ class VisionTransformer(BaseBackbone):
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = self.num_layers + index
|
||||
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
|
||||
assert 0 <= out_indices[i] <= self.num_layers, \
|
||||
f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = np.linspace(0, drop_path_rate, self.arch_settings['num_layers'])
|
||||
dpr = np.linspace(0, drop_path_rate, self.num_layers)
|
||||
|
||||
self.layers = ModuleList()
|
||||
if isinstance(layer_cfgs, dict):
|
||||
@ -259,7 +287,7 @@ class VisionTransformer(BaseBackbone):
|
||||
arch_settings['feedforward_channels'],
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
qkv_bias=self.arch_settings.get('qkv_bias', True),
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
|
||||
@ -270,8 +298,6 @@ class VisionTransformer(BaseBackbone):
|
||||
norm_cfg, self.embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
@ -283,7 +309,7 @@ class VisionTransformer(BaseBackbone):
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
|
||||
def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs):
|
||||
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
|
||||
name = prefix + 'pos_embed'
|
||||
if name not in state_dict.keys():
|
||||
return
|
||||
@ -299,61 +325,38 @@ class VisionTransformer(BaseBackbone):
|
||||
|
||||
ckpt_pos_embed_shape = to_2tuple(
|
||||
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
|
||||
pos_embed_shape = self.patch_embed.patches_resolution
|
||||
pos_embed_shape = self.patch_embed.init_out_size
|
||||
|
||||
state_dict[name] = self.resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
state_dict[name] = resize_pos_embed(state_dict[name],
|
||||
ckpt_pos_embed_shape,
|
||||
pos_embed_shape,
|
||||
self.interpolate_mode,
|
||||
self.num_extra_tokens)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed,
|
||||
src_shape,
|
||||
dst_shape,
|
||||
mode='bicubic',
|
||||
num_extra_tokens=1):
|
||||
"""Resize pos_embed weights.
|
||||
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights with shape
|
||||
[1, L, C].
|
||||
src_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
dst_shape (tuple): The resolution of downsampled new training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'bicubic'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [1, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
|
||||
_, L, C = pos_embed.shape
|
||||
src_h, src_w = src_shape
|
||||
assert L == src_h * src_w + num_extra_tokens
|
||||
extra_tokens = pos_embed[:, :num_extra_tokens]
|
||||
|
||||
src_weight = pos_embed[:, num_extra_tokens:]
|
||||
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
|
||||
|
||||
dst_weight = F.interpolate(
|
||||
src_weight, size=dst_shape, align_corners=False, mode=mode)
|
||||
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
|
||||
|
||||
return torch.cat((extra_tokens, dst_weight), dim=1)
|
||||
def resize_pos_embed(*args, **kwargs):
|
||||
"""Interface for backward-compatibility."""
|
||||
return resize_pos_embed(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
patch_resolution = self.patch_embed.patches_resolution
|
||||
x, patch_resolution = self.patch_embed(x)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
x = x + resize_pos_embed(
|
||||
self.pos_embed,
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
@ -363,9 +366,14 @@ class VisionTransformer(BaseBackbone):
|
||||
|
||||
if i in self.out_indices:
|
||||
B, _, C = x.shape
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
if self.with_cls_token:
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
else:
|
||||
patch_token = x.reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = None
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token]
|
||||
else:
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Sequence
|
||||
@ -7,18 +6,10 @@ from typing import Sequence
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import BaseModule
|
||||
from mmcv.runner import BaseModule, auto_fp16
|
||||
|
||||
from mmcls.core.visualization import imshow_infos
|
||||
|
||||
# TODO import `auto_fp16` from mmcv and delete them from mmcls
|
||||
try:
|
||||
from mmcv.runner import auto_fp16
|
||||
except ImportError:
|
||||
warnings.warn('auto_fp16 from mmcls will be deprecated.'
|
||||
'Please install mmcv>=1.1.4.')
|
||||
from mmcls.core import auto_fp16
|
||||
|
||||
|
||||
class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for classifiers."""
|
||||
|
@ -1,14 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import warnings
|
||||
|
||||
from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck
|
||||
from ..heads import MultiLabelClsHead
|
||||
from ..utils.augment import Augments
|
||||
from .base import BaseClassifier
|
||||
|
||||
warnings.simplefilter('once')
|
||||
|
||||
|
||||
@CLASSIFIERS.register_module()
|
||||
class ImageClassifier(BaseClassifier):
|
||||
@ -23,18 +18,8 @@ class ImageClassifier(BaseClassifier):
|
||||
super(ImageClassifier, self).__init__(init_cfg)
|
||||
|
||||
if pretrained is not None:
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated \
|
||||
key, please consider using init_cfg')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
|
||||
return_tuple = backbone.pop('return_tuple', True)
|
||||
self.backbone = build_backbone(backbone)
|
||||
if return_tuple is False:
|
||||
warnings.warn(
|
||||
'The `return_tuple` is a temporary arg, we will force to '
|
||||
'return tuple in the future. Please handle tuple in your '
|
||||
'custom neck or head.', DeprecationWarning)
|
||||
self.return_tuple = return_tuple
|
||||
|
||||
if neck is not None:
|
||||
self.neck = build_neck(neck)
|
||||
@ -47,29 +32,6 @@ class ImageClassifier(BaseClassifier):
|
||||
augments_cfg = train_cfg.get('augments', None)
|
||||
if augments_cfg is not None:
|
||||
self.augments = Augments(augments_cfg)
|
||||
else:
|
||||
# Considering BC-breaking
|
||||
mixup_cfg = train_cfg.get('mixup', None)
|
||||
cutmix_cfg = train_cfg.get('cutmix', None)
|
||||
assert mixup_cfg is None or cutmix_cfg is None, \
|
||||
'If mixup and cutmix are set simultaneously,' \
|
||||
'use augments instead.'
|
||||
if mixup_cfg is not None:
|
||||
warnings.warn('The mixup attribute will be deprecated. '
|
||||
'Please use augments instead.')
|
||||
cfg = copy.deepcopy(mixup_cfg)
|
||||
cfg['type'] = 'BatchMixup'
|
||||
# In the previous version, mixup_prob is always 1.0.
|
||||
cfg['prob'] = 1.0
|
||||
self.augments = Augments(cfg)
|
||||
if cutmix_cfg is not None:
|
||||
warnings.warn('The cutmix attribute will be deprecated. '
|
||||
'Please use augments instead.')
|
||||
cfg = copy.deepcopy(cutmix_cfg)
|
||||
cutmix_prob = cfg.pop('cutmix_prob')
|
||||
cfg['type'] = 'BatchCutMix'
|
||||
cfg['prob'] = cutmix_prob
|
||||
self.augments = Augments(cfg)
|
||||
|
||||
def extract_feat(self, img, stage='neck'):
|
||||
"""Directly extract features from the specified stage.
|
||||
@ -140,16 +102,7 @@ class ImageClassifier(BaseClassifier):
|
||||
'"neck" and "pre_logits"')
|
||||
|
||||
x = self.backbone(img)
|
||||
if self.return_tuple:
|
||||
if not isinstance(x, tuple):
|
||||
x = (x, )
|
||||
warnings.warn(
|
||||
'We will force all backbones to return a tuple in the '
|
||||
'future. Please check your backbone and wrap the output '
|
||||
'as a tuple.', DeprecationWarning)
|
||||
else:
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
|
||||
if stage == 'backbone':
|
||||
return x
|
||||
|
||||
@ -181,17 +134,7 @@ class ImageClassifier(BaseClassifier):
|
||||
x = self.extract_feat(img)
|
||||
|
||||
losses = dict()
|
||||
try:
|
||||
loss = self.head.forward_train(x, gt_label)
|
||||
except TypeError as e:
|
||||
if 'not tuple' in str(e) and self.return_tuple:
|
||||
return TypeError(
|
||||
'Seems the head cannot handle tuple input. We have '
|
||||
'changed all backbones\' output to a tuple. Please '
|
||||
'update your custom head\'s forward function. '
|
||||
'Temporarily, you can set "return_tuple=False" in '
|
||||
'your backbone config to disable this feature.')
|
||||
raise e
|
||||
loss = self.head.forward_train(x, gt_label)
|
||||
|
||||
losses.update(loss)
|
||||
|
||||
@ -201,20 +144,10 @@ class ImageClassifier(BaseClassifier):
|
||||
"""Test without augmentation."""
|
||||
x = self.extract_feat(img)
|
||||
|
||||
try:
|
||||
if isinstance(self.head, MultiLabelClsHead):
|
||||
assert 'softmax' not in kwargs, (
|
||||
'Please use `sigmoid` instead of `softmax` '
|
||||
'in multi-label tasks.')
|
||||
res = self.head.simple_test(x, **kwargs)
|
||||
except TypeError as e:
|
||||
if 'not tuple' in str(e) and self.return_tuple:
|
||||
return TypeError(
|
||||
'Seems the head cannot handle tuple input. We have '
|
||||
'changed all backbones\' output to a tuple. Please '
|
||||
'update your custom head\'s forward function. '
|
||||
'Temporarily, you can set "return_tuple=False" in '
|
||||
'your backbone config to disable this feature.')
|
||||
raise e
|
||||
if isinstance(self.head, MultiLabelClsHead):
|
||||
assert 'softmax' not in kwargs, (
|
||||
'Please use `sigmoid` instead of `softmax` '
|
||||
'in multi-label tasks.')
|
||||
res = self.head.simple_test(x, **kwargs)
|
||||
|
||||
return res
|
||||
|
@ -1,6 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -26,7 +24,7 @@ class LabelSmoothLoss(nn.Module):
|
||||
label_smooth_val (float): The degree of label smoothing.
|
||||
num_classes (int, optional): Number of classes. Defaults to None.
|
||||
mode (str): Refers to notes, Options are 'original', 'classy_vision',
|
||||
'multi_label'. Defaults to 'classy_vision'
|
||||
'multi_label'. Defaults to 'original'
|
||||
reduction (str): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum". Defaults to 'mean'.
|
||||
loss_weight (float): Weight of the loss. Defaults to 1.0.
|
||||
@ -57,7 +55,7 @@ class LabelSmoothLoss(nn.Module):
|
||||
def __init__(self,
|
||||
label_smooth_val,
|
||||
num_classes=None,
|
||||
mode=None,
|
||||
mode='original',
|
||||
reduction='mean',
|
||||
loss_weight=1.0):
|
||||
super().__init__()
|
||||
@ -76,14 +74,6 @@ class LabelSmoothLoss(nn.Module):
|
||||
f'but gets {mode}.'
|
||||
self.reduction = reduction
|
||||
|
||||
if mode is None:
|
||||
warnings.warn(
|
||||
'LabelSmoothLoss mode is not set, use "classy_vision" '
|
||||
'by default. The default value will be changed to '
|
||||
'"original" recently. Please set mode manually if want '
|
||||
'to keep "classy_vision".', UserWarning)
|
||||
mode = 'classy_vision'
|
||||
|
||||
accept_mode = {'original', 'classy_vision', 'multi_label'}
|
||||
assert mode in accept_mode, \
|
||||
f'LabelSmoothLoss supports mode {accept_mode}, but gets {mode}.'
|
||||
|
@ -114,8 +114,6 @@ def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
|
||||
"""
|
||||
assert (torch.max(targets).item() <
|
||||
classes), 'Class Index must be less than number of classes'
|
||||
one_hot_targets = torch.zeros((targets.shape[0], classes),
|
||||
dtype=torch.long,
|
||||
device=targets.device)
|
||||
one_hot_targets.scatter_(1, targets.long(), 1)
|
||||
one_hot_targets = F.one_hot(
|
||||
targets.long().squeeze(-1), num_classes=classes)
|
||||
return one_hot_targets
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .gap import GlobalAveragePooling
|
||||
from .gem import GeneralizedMeanPooling
|
||||
from .hr_fuse import HRFuseScales
|
||||
|
||||
__all__ = ['GlobalAveragePooling', 'HRFuseScales']
|
||||
__all__ = ['GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales']
|
||||
|
53
mmcls/models/necks/gem.py
Normal file
53
mmcls/models/necks/gem.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ..builder import NECKS
|
||||
|
||||
|
||||
def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor:
|
||||
if clamp:
|
||||
x = x.clamp(min=eps)
|
||||
return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
class GeneralizedMeanPooling(nn.Module):
|
||||
"""Generalized Mean Pooling neck.
|
||||
|
||||
Note that we use `view` to remove extra channel after pooling. We do not
|
||||
use `squeeze` as it will also remove the batch dimension when the tensor
|
||||
has a batch dimension of size 1, which can lead to unexpected errors.
|
||||
|
||||
Args:
|
||||
p (float): Parameter value.
|
||||
Default: 3.
|
||||
eps (float): epsilon.
|
||||
Default: 1e-6
|
||||
clamp (bool): Use clamp before pooling.
|
||||
Default: True
|
||||
"""
|
||||
|
||||
def __init__(self, p=3., eps=1e-6, clamp=True):
|
||||
assert p >= 1, "'p' must be a value greater then 1"
|
||||
super(GeneralizedMeanPooling, self).__init__()
|
||||
self.p = Parameter(torch.ones(1) * p)
|
||||
self.eps = eps
|
||||
self.clamp = clamp
|
||||
|
||||
def forward(self, inputs):
|
||||
if isinstance(inputs, tuple):
|
||||
outs = tuple([
|
||||
gem(x, p=self.p, eps=self.eps, clamp=self.clamp)
|
||||
for x in inputs
|
||||
])
|
||||
outs = tuple(
|
||||
[out.view(x.size(0), -1) for out, x in zip(outs, inputs)])
|
||||
elif isinstance(inputs, torch.Tensor):
|
||||
outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp)
|
||||
outs = outs.view(inputs.size(0), -1)
|
||||
else:
|
||||
raise TypeError('neck inputs should be tuple or torch.tensor')
|
||||
return outs
|
@ -2,7 +2,7 @@
|
||||
from .attention import MultiheadAttention, ShiftWindowMSA
|
||||
from .augment.augments import Augments
|
||||
from .channel_shuffle import channel_shuffle
|
||||
from .embed import HybridEmbed, PatchEmbed, PatchMerging
|
||||
from .embed import HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed
|
||||
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
|
||||
from .inverted_residual import InvertedResidual
|
||||
from .make_divisible import make_divisible
|
||||
@ -13,5 +13,5 @@ __all__ = [
|
||||
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
|
||||
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
|
||||
'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
|
||||
'MultiheadAttention', 'ConditionalPositionEncoding'
|
||||
'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed'
|
||||
]
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -126,14 +128,12 @@ class ShiftWindowMSA(BaseModule):
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input feature
|
||||
map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window.
|
||||
shift_size (int, optional): The shift step of each window towards
|
||||
right-bottom. If zero, act as regular window-msa. Defaults to 0.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Default: True
|
||||
Defaults to True
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Defaults to None.
|
||||
attn_drop (float, optional): Dropout ratio of attention weight.
|
||||
@ -141,15 +141,17 @@ class ShiftWindowMSA(BaseModule):
|
||||
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
|
||||
dropout_layer (dict, optional): The dropout_layer used before output.
|
||||
Defaults to dict(type='DropPath', drop_prob=0.).
|
||||
auto_pad (bool, optional): Auto pad the feature map to be divisible by
|
||||
window_size, Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size,
|
||||
shift_size=0,
|
||||
@ -158,53 +160,134 @@ class ShiftWindowMSA(BaseModule):
|
||||
attn_drop=0,
|
||||
proj_drop=0,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=0.),
|
||||
auto_pad=False,
|
||||
pad_small_map=False,
|
||||
input_resolution=None,
|
||||
auto_pad=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.input_resolution = input_resolution
|
||||
if input_resolution is not None or auto_pad is not None:
|
||||
warnings.warn(
|
||||
'The ShiftWindowMSA in new version has supported auto padding '
|
||||
'and dynamic input shape in all condition. And the argument '
|
||||
'`auto_pad` and `input_resolution` have been deprecated.',
|
||||
DeprecationWarning)
|
||||
|
||||
self.shift_size = shift_size
|
||||
self.window_size = window_size
|
||||
if min(self.input_resolution) <= self.window_size:
|
||||
# if window size is larger than input resolution, don't partition
|
||||
self.shift_size = 0
|
||||
self.window_size = min(self.input_resolution)
|
||||
assert 0 <= self.shift_size < self.window_size
|
||||
|
||||
self.w_msa = WindowMSA(embed_dims, to_2tuple(self.window_size),
|
||||
num_heads, qkv_bias, qk_scale, attn_drop,
|
||||
proj_drop)
|
||||
self.w_msa = WindowMSA(
|
||||
embed_dims=embed_dims,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
|
||||
self.drop = build_dropout(dropout_layer)
|
||||
self.pad_small_map = pad_small_map
|
||||
|
||||
H, W = self.input_resolution
|
||||
# Handle auto padding
|
||||
self.auto_pad = auto_pad
|
||||
if self.auto_pad:
|
||||
self.pad_r = (self.window_size -
|
||||
W % self.window_size) % self.window_size
|
||||
self.pad_b = (self.window_size -
|
||||
H % self.window_size) % self.window_size
|
||||
self.H_pad = H + self.pad_b
|
||||
self.W_pad = W + self.pad_r
|
||||
else:
|
||||
H_pad, W_pad = self.input_resolution
|
||||
assert H_pad % self.window_size + W_pad % self.window_size == 0,\
|
||||
f'input_resolution({self.input_resolution}) is not divisible '\
|
||||
f'by window_size({self.window_size}). Please check feature '\
|
||||
f'map shape or set `auto_pad=True`.'
|
||||
self.H_pad, self.W_pad = H_pad, W_pad
|
||||
self.pad_r, self.pad_b = 0, 0
|
||||
def forward(self, query, hw_shape):
|
||||
B, L, C = query.shape
|
||||
H, W = hw_shape
|
||||
assert L == H * W, f"The query length {L} doesn't match the input "\
|
||||
f'shape ({H}, {W}).'
|
||||
query = query.view(B, H, W, C)
|
||||
|
||||
window_size = self.window_size
|
||||
shift_size = self.shift_size
|
||||
|
||||
if min(H, W) == window_size:
|
||||
# If not pad small feature map, avoid shifting when the window size
|
||||
# is equal to the size of feature map. It's to align with the
|
||||
# behavior of the original implementation.
|
||||
shift_size = shift_size if self.pad_small_map else 0
|
||||
elif min(H, W) < window_size:
|
||||
# In the original implementation, the window size will be shrunk
|
||||
# to the size of feature map. The behavior is different with
|
||||
# swin-transformer for downstream tasks. To support dynamic input
|
||||
# shape, we don't allow this feature.
|
||||
assert self.pad_small_map, \
|
||||
f'The input shape ({H}, {W}) is smaller than the window ' \
|
||||
f'size ({window_size}). Please set `pad_small_map=True`, or ' \
|
||||
'decrease the `window_size`.'
|
||||
|
||||
pad_r = (window_size - W % window_size) % window_size
|
||||
pad_b = (window_size - H % window_size) % window_size
|
||||
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
|
||||
|
||||
H_pad, W_pad = query.shape[1], query.shape[2]
|
||||
|
||||
# cyclic shift
|
||||
if shift_size > 0:
|
||||
query = torch.roll(
|
||||
query, shifts=(-shift_size, -shift_size), dims=(1, 2))
|
||||
|
||||
attn_mask = self.get_attn_mask((H_pad, W_pad),
|
||||
window_size=window_size,
|
||||
shift_size=shift_size,
|
||||
device=query.device)
|
||||
|
||||
# nW*B, window_size, window_size, C
|
||||
query_windows = self.window_partition(query, window_size)
|
||||
# nW*B, window_size*window_size, C
|
||||
query_windows = query_windows.view(-1, window_size**2, C)
|
||||
|
||||
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
|
||||
attn_windows = self.w_msa(query_windows, mask=attn_mask)
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, window_size, window_size, C)
|
||||
|
||||
# B H' W' C
|
||||
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad,
|
||||
window_size)
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
img_mask = torch.zeros((1, self.H_pad, self.W_pad, 1)) # 1 H W 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
x = torch.roll(
|
||||
shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if H != H_pad or W != W_pad:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = self.drop(x)
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def window_reverse(windows, H, W, window_size):
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def window_partition(x, window_size):
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size,
|
||||
window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
|
||||
windows = windows.view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
@staticmethod
|
||||
def get_attn_mask(hw_shape, window_size, shift_size, device=None):
|
||||
if shift_size > 0:
|
||||
img_mask = torch.zeros(1, *hw_shape, 1, device=device)
|
||||
h_slices = (slice(0, -window_size), slice(-window_size,
|
||||
-shift_size),
|
||||
slice(-shift_size, None))
|
||||
w_slices = (slice(0, -window_size), slice(-window_size,
|
||||
-shift_size),
|
||||
slice(-shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
@ -212,83 +295,15 @@ class ShiftWindowMSA(BaseModule):
|
||||
cnt += 1
|
||||
|
||||
# nW, window_size, window_size, 1
|
||||
mask_windows = self.window_partition(img_mask)
|
||||
mask_windows = mask_windows.view(
|
||||
-1, self.window_size * self.window_size)
|
||||
mask_windows = ShiftWindowMSA.window_partition(
|
||||
img_mask, window_size)
|
||||
mask_windows = mask_windows.view(-1, window_size * window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0)
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
self.register_buffer('attn_mask', attn_mask)
|
||||
|
||||
def forward(self, query):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = query.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
query = query.view(B, H, W, C)
|
||||
|
||||
if self.pad_r or self.pad_b:
|
||||
query = F.pad(query, (0, 0, 0, self.pad_r, 0, self.pad_b))
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_query = torch.roll(
|
||||
query,
|
||||
shifts=(-self.shift_size, -self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
shifted_query = query
|
||||
|
||||
# nW*B, window_size, window_size, C
|
||||
query_windows = self.window_partition(shifted_query)
|
||||
# nW*B, window_size*window_size, C
|
||||
query_windows = query_windows.view(-1, self.window_size**2, C)
|
||||
|
||||
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
|
||||
attn_windows = self.w_msa(query_windows, mask=self.attn_mask)
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size,
|
||||
self.window_size, C)
|
||||
|
||||
# B H' W' C
|
||||
shifted_x = self.window_reverse(attn_windows, self.H_pad, self.W_pad)
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(
|
||||
shifted_x,
|
||||
shifts=(self.shift_size, self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if self.pad_r or self.pad_b:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
def window_reverse(self, windows, H, W):
|
||||
window_size = self.window_size
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
def window_partition(self, x):
|
||||
B, H, W, C = x.shape
|
||||
window_size = self.window_size
|
||||
x = x.view(B, H // window_size, window_size, W // window_size,
|
||||
window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
|
||||
windows = windows.view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
return attn_mask
|
||||
|
||||
|
||||
class MultiheadAttention(BaseModule):
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
|
@ -1,12 +1,59 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmcv.runner.base_module import BaseModule
|
||||
|
||||
from .helpers import to_2tuple
|
||||
|
||||
|
||||
def resize_pos_embed(pos_embed,
|
||||
src_shape,
|
||||
dst_shape,
|
||||
mode='bicubic',
|
||||
num_extra_tokens=1):
|
||||
"""Resize pos_embed weights.
|
||||
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights with shape
|
||||
[1, L, C].
|
||||
src_shape (tuple): The resolution of downsampled origin training
|
||||
image, in format (H, W).
|
||||
dst_shape (tuple): The resolution of downsampled new training
|
||||
image, in format (H, W).
|
||||
mode (str): Algorithm used for upsampling. Choose one from 'nearest',
|
||||
'linear', 'bilinear', 'bicubic' and 'trilinear'.
|
||||
Defaults to 'bicubic'.
|
||||
num_extra_tokens (int): The number of extra tokens, such as cls_token.
|
||||
Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The resized pos_embed of shape [1, L_new, C]
|
||||
"""
|
||||
if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
|
||||
return pos_embed
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
|
||||
_, L, C = pos_embed.shape
|
||||
src_h, src_w = src_shape
|
||||
assert L == src_h * src_w + num_extra_tokens, \
|
||||
f"The length of `pos_embed` ({L}) doesn't match the expected " \
|
||||
f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
|
||||
'`img_size` argument.'
|
||||
extra_tokens = pos_embed[:, :num_extra_tokens]
|
||||
|
||||
src_weight = pos_embed[:, num_extra_tokens:]
|
||||
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
|
||||
|
||||
dst_weight = F.interpolate(
|
||||
src_weight, size=dst_shape, align_corners=False, mode=mode)
|
||||
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
|
||||
|
||||
return torch.cat((extra_tokens, dst_weight), dim=1)
|
||||
|
||||
|
||||
class PatchEmbed(BaseModule):
|
||||
"""Image to Patch Embedding.
|
||||
|
||||
@ -32,6 +79,9 @@ class PatchEmbed(BaseModule):
|
||||
conv_cfg=None,
|
||||
init_cfg=None):
|
||||
super(PatchEmbed, self).__init__(init_cfg)
|
||||
warnings.warn('The `PatchEmbed` in mmcls will be deprecated. '
|
||||
'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. '
|
||||
"It's more general and supports dynamic input shape")
|
||||
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
@ -203,6 +253,10 @@ class PatchMerging(BaseModule):
|
||||
norm_cfg=dict(type='LN'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
warnings.warn('The `PatchMerging` in mmcls will be deprecated. '
|
||||
'Please use `mmcv.cnn.bricks.transformer.PatchMerging`. '
|
||||
"It's more general and supports dynamic input shape")
|
||||
|
||||
H, W = input_resolution
|
||||
self.input_resolution = input_resolution
|
||||
self.in_channels = in_channels
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved
|
||||
|
||||
__version__ = '0.20.1'
|
||||
__version__ = '0.21.0'
|
||||
|
||||
|
||||
def parse_version_info(version_str):
|
||||
|
@ -20,3 +20,4 @@ Import:
|
||||
- configs/efficientnet/metafile.yml
|
||||
- configs/convnext/metafile.yml
|
||||
- configs/hrnet/metafile.yml
|
||||
- configs/wrn/metafile.yml
|
||||
|
@ -1,2 +1,4 @@
|
||||
albumentations>=0.3.2
|
||||
albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
colorama
|
||||
requests
|
||||
rich
|
||||
|
@ -12,9 +12,8 @@ split_before_expression_after_opening_paren = true
|
||||
[isort]
|
||||
line_length = 79
|
||||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
extra_standard_library = pkg_resources,setuptools
|
||||
known_first_party = mmcls
|
||||
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,modelindex,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,tensorflow,torch,torchvision,ts
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
|
4
setup.py
4
setup.py
@ -66,6 +66,9 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
||||
info['platform_deps'] = platform_deps
|
||||
else:
|
||||
version = rest # NOQA
|
||||
if '--' in version:
|
||||
# the `extras_require` doesn't accept options.
|
||||
version = version.split('--')[0].strip()
|
||||
info['version'] = (op, version)
|
||||
yield info
|
||||
|
||||
@ -171,7 +174,6 @@ if __name__ == '__main__':
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Operating System :: OS Independent',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
|
2
tests/data/dataset/classes.txt
Normal file
2
tests/data/dataset/classes.txt
Normal file
@ -0,0 +1,2 @@
|
||||
bus
|
||||
car
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# small RetinaNet
|
||||
num_classes = 3
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from copy import deepcopy
|
||||
from unittest.mock import patch
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import tempfile
|
||||
import os.path as osp
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
@ -65,15 +65,13 @@ def test_datasets_override_default(dataset_name):
|
||||
assert dataset.CLASSES == ['bus', 'car']
|
||||
|
||||
# Test setting classes through a file
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
with open(tmp_file.name, 'w') as f:
|
||||
f.write('bus\ncar\n')
|
||||
classes_file = osp.join(
|
||||
osp.dirname(__file__), '../../data/dataset/classes.txt')
|
||||
dataset = dataset_class(
|
||||
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
|
||||
pipeline=[],
|
||||
classes=tmp_file.name,
|
||||
classes=classes_file,
|
||||
test_mode=True)
|
||||
tmp_file.close()
|
||||
|
||||
assert dataset.CLASSES == ['bus', 'car']
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
|
||||
from mmcls.datasets.utils import check_integrity, rm_suffix
|
||||
|
||||
@ -17,6 +17,6 @@ def test_dataset_utils():
|
||||
rand_file = ''.join(random.sample(string.ascii_letters, 10))
|
||||
assert not check_integrity(rand_file, md5=None)
|
||||
assert not check_integrity(rand_file, md5=2333)
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
assert check_integrity(tmp_file.name, md5=None)
|
||||
assert not check_integrity(tmp_file.name, md5=2333)
|
||||
test_file = osp.join(osp.dirname(__file__), '../../data/color.jpg')
|
||||
assert check_integrity(test_file, md5='08252e5100cb321fe74e0e12a724ce14')
|
||||
assert not check_integrity(test_file, md5=2333)
|
||||
|
@ -760,7 +760,7 @@ def test_equalize(nb_rand_test=100):
|
||||
|
||||
def _imequalize(img):
|
||||
# equalize the image using PIL.ImageOps.equalize
|
||||
from PIL import ImageOps, Image
|
||||
from PIL import Image, ImageOps
|
||||
img = Image.fromarray(img)
|
||||
equalized_img = np.asarray(ImageOps.equalize(img))
|
||||
return equalized_img
|
||||
@ -932,8 +932,9 @@ def test_posterize():
|
||||
def test_contrast(nb_rand_test=100):
|
||||
|
||||
def _adjust_contrast(img, factor):
|
||||
from PIL.ImageEnhance import Contrast
|
||||
from PIL import Image
|
||||
from PIL.ImageEnhance import Contrast
|
||||
|
||||
# Image.fromarray defaultly supports RGB, not BGR.
|
||||
# convert from BGR to RGB
|
||||
img = Image.fromarray(img[..., ::-1], mode='RGB')
|
||||
@ -1066,8 +1067,8 @@ def test_brightness(nb_rand_test=100):
|
||||
def _adjust_brightness(img, factor):
|
||||
# adjust the brightness of image using
|
||||
# PIL.ImageEnhance.Brightness
|
||||
from PIL.ImageEnhance import Brightness
|
||||
from PIL import Image
|
||||
from PIL.ImageEnhance import Brightness
|
||||
img = Image.fromarray(img)
|
||||
brightened_img = Brightness(img).enhance(factor)
|
||||
return np.asarray(brightened_img)
|
||||
@ -1128,8 +1129,8 @@ def test_sharpness(nb_rand_test=100):
|
||||
def _adjust_sharpness(img, factor):
|
||||
# adjust the sharpness of image using
|
||||
# PIL.ImageEnhance.Sharpness
|
||||
from PIL.ImageEnhance import Sharpness
|
||||
from PIL import Image
|
||||
from PIL.ImageEnhance import Sharpness
|
||||
img = Image.fromarray(img)
|
||||
sharpened_img = Sharpness(img).enhance(factor)
|
||||
return np.asarray(sharpened_img)
|
||||
|
@ -52,8 +52,7 @@ backbone_configs = dict(
|
||||
arch='small',
|
||||
drop_path_rate=0.2,
|
||||
img_size=800,
|
||||
out_indices=(2, 3),
|
||||
auto_pad=True),
|
||||
out_indices=(2, 3)),
|
||||
out_channels=[384, 768]),
|
||||
timm_efficientnet=dict(
|
||||
backbone=dict(
|
||||
|
49
tests/test_metrics/test_utils.py
Normal file
49
tests/test_metrics/test_utils.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.losses.utils import convert_to_one_hot
|
||||
|
||||
|
||||
def ori_convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
|
||||
assert (torch.max(targets).item() <
|
||||
classes), 'Class Index must be less than number of classes'
|
||||
one_hot_targets = torch.zeros((targets.shape[0], classes),
|
||||
dtype=torch.long,
|
||||
device=targets.device)
|
||||
one_hot_targets.scatter_(1, targets.long(), 1)
|
||||
return one_hot_targets
|
||||
|
||||
|
||||
def test_convert_to_one_hot():
|
||||
# label should smaller than classes
|
||||
targets = torch.tensor([1, 2, 3, 8, 5])
|
||||
classes = 5
|
||||
with pytest.raises(AssertionError):
|
||||
_ = convert_to_one_hot(targets, classes)
|
||||
|
||||
# test with original impl
|
||||
classes = 10
|
||||
targets = torch.randint(high=classes, size=(10, 1))
|
||||
ori_one_hot_targets = torch.zeros((targets.shape[0], classes),
|
||||
dtype=torch.long,
|
||||
device=targets.device)
|
||||
ori_one_hot_targets.scatter_(1, targets.long(), 1)
|
||||
one_hot_targets = convert_to_one_hot(targets, classes)
|
||||
assert torch.equal(ori_one_hot_targets, one_hot_targets)
|
||||
|
||||
|
||||
# test cuda version
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_convert_to_one_hot_cuda():
|
||||
# test with original impl
|
||||
classes = 10
|
||||
targets = torch.randint(high=classes, size=(10, 1)).cuda()
|
||||
ori_one_hot_targets = torch.zeros((targets.shape[0], classes),
|
||||
dtype=torch.long,
|
||||
device=targets.device)
|
||||
ori_one_hot_targets.scatter_(1, targets.long(), 1)
|
||||
one_hot_targets = convert_to_one_hot(targets, classes)
|
||||
assert torch.equal(ori_one_hot_targets, one_hot_targets)
|
||||
assert ori_one_hot_targets.device == one_hot_targets.device
|
1
tests/test_models/test_backbones/__init__.py
Normal file
1
tests/test_models/test_backbones/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
@ -57,6 +57,25 @@ def test_conformer_backbone():
|
||||
) # base_channels * channel_ratio * 4
|
||||
assert transformer_feature.shape == (3, 384)
|
||||
|
||||
# Test Conformer with irregular input size.
|
||||
model = Conformer(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
imgs = torch.randn(3, 3, 241, 241)
|
||||
conv_feature, transformer_feature = model(imgs)[-1]
|
||||
assert conv_feature.shape == (3, 64 * 1 * 4
|
||||
) # base_channels * channel_ratio * 4
|
||||
assert transformer_feature.shape == (3, 384)
|
||||
|
||||
imgs = torch.randn(3, 3, 321, 221)
|
||||
conv_feature, transformer_feature = model(imgs)[-1]
|
||||
assert conv_feature.shape == (3, 64 * 1 * 4
|
||||
) # base_channels * channel_ratio * 4
|
||||
assert transformer_feature.shape == (3, 384)
|
||||
|
||||
# Test custom arch Conformer without output cls token
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = {
|
||||
@ -72,7 +91,7 @@ def test_conformer_backbone():
|
||||
assert conv_feature.shape == (3, 32 * 3 * 4)
|
||||
assert transformer_feature.shape == (3, 128)
|
||||
|
||||
# Test ViT with multi out indices
|
||||
# Test Conformer with multi out indices
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['out_indices'] = [4, 8, 12]
|
||||
model = Conformer(**cfg)
|
||||
|
@ -1,43 +1,131 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import DistilledVisionTransformer
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
class TestDeiT(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='deit-base', img_size=224, patch_size=16, drop_rate=0.1)
|
||||
|
||||
def test_deit_backbone():
|
||||
cfg_ori = dict(arch='deit-b', img_size=224, patch_size=16)
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
]
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
ori_weight = model.patch_embed.projection.weight.clone().detach()
|
||||
# The pos_embed is all zero before initialize
|
||||
self.assertTrue(torch.allclose(model.dist_token, torch.tensor(0.)))
|
||||
|
||||
# Test structure
|
||||
model = DistilledVisionTransformer(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
model.init_weights()
|
||||
initialized_weight = model.patch_embed.projection.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
self.assertFalse(torch.allclose(model.dist_token, torch.tensor(0.)))
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
assert model.dist_token.shape == (1, 1, 768)
|
||||
assert model.pos_embed.shape == (1, model.patch_embed.num_patches + 2, 768)
|
||||
# test load checkpoint
|
||||
pretrain_pos_embed = model.pos_embed.clone().detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
save_checkpoint(model, checkpoint)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
|
||||
|
||||
# Test forward
|
||||
imgs = torch.rand(1, 3, 224, 224)
|
||||
outs = model(imgs)
|
||||
patch_token, cls_token, dist_token = outs[0]
|
||||
assert patch_token.shape == (1, 768, 14, 14)
|
||||
assert cls_token.shape == (1, 768)
|
||||
assert dist_token.shape == (1, 768)
|
||||
# test load checkpoint with different img_size
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['img_size'] = 384
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
resized_pos_embed = timm_resize_pos_embed(
|
||||
pretrain_pos_embed, model.pos_embed, num_tokens=2)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
|
||||
|
||||
# Test multiple out_indices
|
||||
model = DistilledVisionTransformer(
|
||||
**cfg_ori, out_indices=(0, 1, 2, 3), output_cls_token=False)
|
||||
outs = model(imgs)
|
||||
for out in outs:
|
||||
assert out.shape == (1, 768, 14, 14)
|
||||
os.remove(checkpoint)
|
||||
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
# test with_cls_token=False
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = True
|
||||
with self.assertRaisesRegex(AssertionError, 'but got False'):
|
||||
DistilledVisionTransformer(**cfg)
|
||||
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = False
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
|
||||
# test with output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token, dist_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
self.assertEqual(dist_token.shape, (3, 768))
|
||||
|
||||
# test without output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['output_cls_token'] = False
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
|
||||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
for out in outs:
|
||||
patch_token, cls_token, dist_token = out
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
self.assertEqual(dist_token.shape, (3, 768))
|
||||
|
||||
# Test forward with dynamic input size
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token, dist_token = outs[-1]
|
||||
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
|
||||
math.ceil(imgs.shape[3] / 16))
|
||||
self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
self.assertEqual(dist_token.shape, (3, 768))
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
@ -25,51 +25,95 @@ def check_norm_state(modules, train_state):
|
||||
return True
|
||||
|
||||
|
||||
def test_mlp_mixer_backbone():
|
||||
cfg_ori = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.1,
|
||||
init_cfg=[
|
||||
class TestMLPMixer(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
])
|
||||
|
||||
def test_arch(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
MlpMixer(**cfg)
|
||||
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 24,
|
||||
'num_layers': 16,
|
||||
'tokens_mlp_dims': 4096
|
||||
}
|
||||
MlpMixer(**cfg)
|
||||
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 128,
|
||||
'num_layers': 6,
|
||||
'tokens_mlp_dims': 256,
|
||||
'channels_mlp_dims': 1024
|
||||
}
|
||||
model = MlpMixer(**cfg)
|
||||
self.assertEqual(model.embed_dims, 128)
|
||||
self.assertEqual(model.num_layers, 6)
|
||||
for layer in model.layers:
|
||||
self.assertEqual(layer.token_mix.feedforward_channels, 256)
|
||||
self.assertEqual(layer.channel_mix.feedforward_channels, 1024)
|
||||
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
])
|
||||
]
|
||||
model = MlpMixer(**cfg)
|
||||
ori_weight = model.patch_embed.projection.weight.clone().detach()
|
||||
model.init_weights()
|
||||
initialized_weight = model.patch_embed.projection.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test invalid arch
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = 'unknown'
|
||||
MlpMixer(**cfg)
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test arch without essential keys
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = {
|
||||
'num_layers': 24,
|
||||
'tokens_mlp_dims': 384,
|
||||
'channels_mlp_dims': 3072,
|
||||
}
|
||||
MlpMixer(**cfg)
|
||||
# test forward with single out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = MlpMixer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 768, 196))
|
||||
|
||||
# Test MlpMixer base model with input size of 224
|
||||
# and patch size of 16
|
||||
model = MlpMixer(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = MlpMixer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
for feat in outs:
|
||||
self.assertEqual(feat.shape, (3, 768, 196))
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
feat = model(imgs)[-1]
|
||||
assert feat.shape == (3, 768, 196)
|
||||
|
||||
# Test MlpMixer with multi out indices
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = MlpMixer(**cfg)
|
||||
for out in model(imgs):
|
||||
assert out.shape == (3, 768, 196)
|
||||
# test with invalid input shape
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = MlpMixer(**cfg)
|
||||
with self.assertRaisesRegex(AssertionError, 'dynamic input shape.'):
|
||||
model(imgs2)
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import ResNet, ResNetV1d
|
||||
from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d
|
||||
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
|
||||
get_expansion)
|
||||
|
||||
@ -526,6 +526,45 @@ def test_resnet():
|
||||
assert not all_zeros(m.norm2)
|
||||
|
||||
|
||||
def test_resnet_v1c():
|
||||
model = ResNetV1c(depth=50, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
assert len(model.stem) == 3
|
||||
for i in range(3):
|
||||
assert isinstance(model.stem[i], ConvModule)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model.stem(imgs)
|
||||
assert feat.shape == (1, 64, 112, 112)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == (1, 256, 56, 56)
|
||||
assert feat[1].shape == (1, 512, 28, 28)
|
||||
assert feat[2].shape == (1, 1024, 14, 14)
|
||||
assert feat[3].shape == (1, 2048, 7, 7)
|
||||
|
||||
# Test ResNet50V1d with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
|
||||
assert len(model.stem) == 3
|
||||
for i in range(3):
|
||||
assert isinstance(model.stem[i], ConvModule)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
check_norm_state(model.stem, False)
|
||||
for param in model.stem.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
layer = getattr(model, f'layer{i}')
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
|
||||
def test_resnet_v1d():
|
||||
model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
|
@ -1,16 +1,18 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from math import ceil
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import SwinTransformer
|
||||
from mmcls.models.backbones.swin_transformer import SwinBlock
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
@ -22,215 +24,232 @@ def check_norm_state(modules, train_state):
|
||||
return True
|
||||
|
||||
|
||||
def test_assertion():
|
||||
"""Test Swin Transformer backbone."""
|
||||
with pytest.raises(AssertionError):
|
||||
# Swin Transformer arch string should be in
|
||||
SwinTransformer(arch='unknown')
|
||||
class TestSwinTransformer(TestCase):
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Swin Transformer arch dict should include 'embed_dims',
|
||||
# 'depths' and 'num_head' keys.
|
||||
SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='b', img_size=224, patch_size=4, drop_path_rate=0.1)
|
||||
|
||||
def test_arch(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
SwinTransformer(**cfg)
|
||||
|
||||
def test_forward():
|
||||
# Test tiny arch forward
|
||||
model = SwinTransformer(arch='Tiny')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 96,
|
||||
'num_heads': [3, 6, 12, 16],
|
||||
}
|
||||
SwinTransformer(**cfg)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 768, 7, 7)
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
depths = [2, 2, 4, 2]
|
||||
num_heads = [6, 12, 6, 12]
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 256,
|
||||
'depths': depths,
|
||||
'num_heads': num_heads
|
||||
}
|
||||
model = SwinTransformer(**cfg)
|
||||
for i, stage in enumerate(model.stages):
|
||||
self.assertEqual(stage.embed_dims, 256 * (2**i))
|
||||
self.assertEqual(len(stage.blocks), depths[i])
|
||||
self.assertEqual(stage.blocks[0].attn.w_msa.num_heads,
|
||||
num_heads[i])
|
||||
|
||||
# Test small arch forward
|
||||
model = SwinTransformer(arch='small')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['use_abs_pos_embed'] = True
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
]
|
||||
model = SwinTransformer(**cfg)
|
||||
ori_weight = model.patch_embed.projection.weight.clone().detach()
|
||||
# The pos_embed is all zero before initialize
|
||||
self.assertTrue(
|
||||
torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 768, 7, 7)
|
||||
model.init_weights()
|
||||
initialized_weight = model.patch_embed.projection.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
self.assertFalse(
|
||||
torch.allclose(model.absolute_pos_embed, torch.tensor(0.)))
|
||||
|
||||
# Test base arch forward
|
||||
model = SwinTransformer(arch='B')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
pretrain_pos_embed = model.absolute_pos_embed.clone().detach()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 1024, 7, 7)
|
||||
tmpdir = tempfile.gettempdir()
|
||||
# Save v3 checkpoints
|
||||
checkpoint_v2 = os.path.join(tmpdir, 'v3.pth')
|
||||
save_checkpoint(model, checkpoint_v2)
|
||||
# Save v1 checkpoints
|
||||
setattr(model, 'norm', model.norm3)
|
||||
setattr(model.stages[0].blocks[1].attn, 'attn_mask',
|
||||
torch.zeros(64, 49, 49))
|
||||
model._version = 1
|
||||
del model.norm3
|
||||
checkpoint_v1 = os.path.join(tmpdir, 'v1.pth')
|
||||
save_checkpoint(model, checkpoint_v1)
|
||||
|
||||
# Test large arch forward
|
||||
model = SwinTransformer(arch='l')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# test load v1 checkpoint
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['use_abs_pos_embed'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint_v1, strict=True)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 1536, 7, 7)
|
||||
# test load v3 checkpoint
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['use_abs_pos_embed'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint_v2, strict=True)
|
||||
|
||||
# Test base arch with window_size=12, image_size=384
|
||||
model = SwinTransformer(
|
||||
arch='base',
|
||||
img_size=384,
|
||||
stage_cfgs=dict(block_cfgs=dict(window_size=12)))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# test load v3 checkpoint with different img_size
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['img_size'] = 384
|
||||
cfg['use_abs_pos_embed'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint_v2, strict=True)
|
||||
resized_pos_embed = timm_resize_pos_embed(
|
||||
pretrain_pos_embed, model.absolute_pos_embed, num_tokens=0)
|
||||
self.assertTrue(
|
||||
torch.allclose(model.absolute_pos_embed, resized_pos_embed))
|
||||
|
||||
imgs = torch.randn(1, 3, 384, 384)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 1024, 12, 12)
|
||||
os.remove(checkpoint_v1)
|
||||
os.remove(checkpoint_v2)
|
||||
|
||||
# Test multiple output indices
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = SwinTransformer(arch='base', out_indices=(0, 1, 2, 3))
|
||||
outs = model(imgs)
|
||||
assert outs[0].shape == (1, 256, 28, 28)
|
||||
assert outs[1].shape == (1, 512, 14, 14)
|
||||
assert outs[2].shape == (1, 1024, 7, 7)
|
||||
assert outs[3].shape == (1, 1024, 7, 7)
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
# Test base arch with with checkpoint forward
|
||||
model = SwinTransformer(arch='B', with_cp=True)
|
||||
for m in model.modules():
|
||||
if isinstance(m, SwinBlock):
|
||||
assert m.with_cp
|
||||
model.init_weights()
|
||||
model.train()
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = SwinTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 1024, 7, 7))
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
output = model(imgs)
|
||||
assert len(output) == 1
|
||||
assert output[0].shape == (1, 1024, 7, 7)
|
||||
# test with window_size=12
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['window_size'] = 12
|
||||
model = SwinTransformer(**cfg)
|
||||
outs = model(torch.randn(3, 3, 384, 384))
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 1024, 12, 12))
|
||||
with self.assertRaisesRegex(AssertionError, r'the window size \(12\)'):
|
||||
model(torch.randn(3, 3, 224, 224))
|
||||
|
||||
# test with pad_small_map=True
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['window_size'] = 12
|
||||
cfg['pad_small_map'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
outs = model(torch.randn(3, 3, 224, 224))
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 1024, 7, 7))
|
||||
|
||||
def test_structure():
|
||||
# Test small with use_abs_pos_embed = True
|
||||
model = SwinTransformer(arch='small', use_abs_pos_embed=True)
|
||||
assert model.absolute_pos_embed.shape == (1, 3136, 96)
|
||||
# test multiple output indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = (0, 1, 2, 3)
|
||||
model = SwinTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 4)
|
||||
for stride, out in zip([2, 4, 8, 8], outs):
|
||||
self.assertEqual(out.shape,
|
||||
(3, 128 * stride, 56 // stride, 56 // stride))
|
||||
|
||||
# Test small with use_abs_pos_embed = False
|
||||
model = SwinTransformer(arch='small', use_abs_pos_embed=False)
|
||||
assert not hasattr(model, 'absolute_pos_embed')
|
||||
# test with checkpoint forward
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cp'] = True
|
||||
model = SwinTransformer(**cfg)
|
||||
for m in model.modules():
|
||||
if isinstance(m, SwinBlock):
|
||||
self.assertTrue(m.with_cp)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
# Test small with auto_pad = True
|
||||
model = SwinTransformer(
|
||||
arch='small',
|
||||
auto_pad=True,
|
||||
stage_cfgs=dict(
|
||||
block_cfgs={'window_size': 7},
|
||||
downsample_cfg={
|
||||
'kernel_size': (3, 2),
|
||||
}))
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
self.assertEqual(feat.shape, (3, 1024, 7, 7))
|
||||
|
||||
# stage 1
|
||||
input_h = int(224 / 4 / 3)
|
||||
expect_h = ceil(input_h / 7) * 7
|
||||
input_w = int(224 / 4 / 2)
|
||||
expect_w = ceil(input_w / 7) * 7
|
||||
assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
|
||||
assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w
|
||||
# test with dynamic input shape
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = SwinTransformer(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
feat = outs[-1]
|
||||
expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
|
||||
math.ceil(imgs.shape[3] / 32))
|
||||
self.assertEqual(feat.shape, (3, 1024, *expect_feat_shape))
|
||||
|
||||
# stage 2
|
||||
input_h = int(224 / 4 / 3 / 3)
|
||||
# input_h is smaller than window_size, shrink the window_size to input_h.
|
||||
expect_h = input_h
|
||||
input_w = int(224 / 4 / 2 / 2)
|
||||
expect_w = ceil(input_w / input_h) * input_h
|
||||
assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
|
||||
assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w
|
||||
def test_structure(self):
|
||||
# test drop_path_rate decay
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['drop_path_rate'] = 0.2
|
||||
model = SwinTransformer(**cfg)
|
||||
depths = model.arch_settings['depths']
|
||||
blocks = chain(*[stage.blocks for stage in model.stages])
|
||||
for i, block in enumerate(blocks):
|
||||
expect_prob = 0.2 / (sum(depths) - 1) * i
|
||||
self.assertAlmostEqual(block.ffn.dropout_layer.drop_prob,
|
||||
expect_prob)
|
||||
self.assertAlmostEqual(block.attn.drop.drop_prob, expect_prob)
|
||||
|
||||
# stage 3
|
||||
input_h = int(224 / 4 / 3 / 3 / 3)
|
||||
expect_h = input_h
|
||||
input_w = int(224 / 4 / 2 / 2 / 2)
|
||||
expect_w = ceil(input_w / input_h) * input_h
|
||||
assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
|
||||
assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w
|
||||
# test Swin-Transformer with norm_eval=True
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['norm_eval'] = True
|
||||
cfg['norm_cfg'] = dict(type='BN')
|
||||
cfg['stage_cfgs'] = dict(block_cfgs=dict(norm_cfg=dict(type='BN')))
|
||||
model = SwinTransformer(**cfg)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
self.assertTrue(check_norm_state(model.modules(), False))
|
||||
|
||||
# Test small with auto_pad = False
|
||||
with pytest.raises(AssertionError):
|
||||
model = SwinTransformer(
|
||||
arch='small',
|
||||
auto_pad=False,
|
||||
stage_cfgs=dict(
|
||||
block_cfgs={'window_size': 7},
|
||||
downsample_cfg={
|
||||
'kernel_size': (3, 2),
|
||||
}))
|
||||
# test Swin-Transformer with first stage frozen.
|
||||
cfg = deepcopy(self.cfg)
|
||||
frozen_stages = 0
|
||||
cfg['frozen_stages'] = frozen_stages
|
||||
cfg['out_indices'] = (0, 1, 2, 3)
|
||||
model = SwinTransformer(**cfg)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
# Test drop_path_rate decay
|
||||
model = SwinTransformer(
|
||||
arch='small',
|
||||
drop_path_rate=0.2,
|
||||
)
|
||||
depths = model.arch_settings['depths']
|
||||
pos = 0
|
||||
for i, depth in enumerate(depths):
|
||||
for j in range(depth):
|
||||
block = model.stages[i].blocks[j]
|
||||
expect_prob = 0.2 / (sum(depths) - 1) * pos
|
||||
assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
|
||||
assert np.isclose(block.attn.drop.drop_prob, expect_prob)
|
||||
pos += 1
|
||||
# the patch_embed and first stage should not require grad.
|
||||
self.assertFalse(model.patch_embed.training)
|
||||
for param in model.patch_embed.parameters():
|
||||
self.assertFalse(param.requires_grad)
|
||||
for i in range(frozen_stages + 1):
|
||||
stage = model.stages[i]
|
||||
for param in stage.parameters():
|
||||
self.assertFalse(param.requires_grad)
|
||||
for param in model.norm0.parameters():
|
||||
self.assertFalse(param.requires_grad)
|
||||
|
||||
# Test Swin-Transformer with norm_eval=True
|
||||
model = SwinTransformer(
|
||||
arch='small',
|
||||
norm_eval=True,
|
||||
norm_cfg=dict(type='BN'),
|
||||
stage_cfgs=dict(block_cfgs=dict(norm_cfg=dict(type='BN'))),
|
||||
)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test Swin-Transformer with first stage frozen.
|
||||
frozen_stages = 0
|
||||
model = SwinTransformer(
|
||||
arch='small', frozen_stages=frozen_stages, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
assert model.patch_embed.training is False
|
||||
for param in model.patch_embed.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(frozen_stages + 1):
|
||||
stage = model.stages[i]
|
||||
for param in stage.parameters():
|
||||
assert param.requires_grad is False
|
||||
for param in model.norm0.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# the second stage should require grad.
|
||||
stage = model.stages[1]
|
||||
for param in stage.parameters():
|
||||
assert param.requires_grad is True
|
||||
for param in model.norm1.parameters():
|
||||
assert param.requires_grad is True
|
||||
|
||||
|
||||
def test_load_checkpoint():
|
||||
model = SwinTransformer(arch='tiny')
|
||||
ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
|
||||
|
||||
assert model._version == 2
|
||||
|
||||
# test load v2 checkpoint
|
||||
save_checkpoint(model, ckpt_path)
|
||||
load_checkpoint(model, ckpt_path, strict=True)
|
||||
|
||||
# test load v1 checkpoint
|
||||
setattr(model, 'norm', model.norm3)
|
||||
model._version = 1
|
||||
del model.norm3
|
||||
save_checkpoint(model, ckpt_path)
|
||||
model = SwinTransformer(arch='tiny')
|
||||
load_checkpoint(model, ckpt_path, strict=True)
|
||||
# the second stage should require grad.
|
||||
for i in range(frozen_stages + 1, 4):
|
||||
stage = model.stages[i]
|
||||
for param in stage.parameters():
|
||||
self.assertTrue(param.requires_grad)
|
||||
norm = getattr(model, f'norm{i}')
|
||||
for param in norm.parameters():
|
||||
self.assertTrue(param.requires_grad)
|
||||
|
@ -1,84 +1,157 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import T2T_ViT
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
||||
|
||||
def is_norm(modules):
|
||||
"""Check if is one of the norms."""
|
||||
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||
return True
|
||||
return False
|
||||
class TestT2TViT(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
t2t_cfg=dict(
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
),
|
||||
num_layers=14,
|
||||
drop_path_rate=0.1)
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_vit_backbone():
|
||||
|
||||
cfg_ori = dict(
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
t2t_cfg=dict(
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
),
|
||||
num_layers=14,
|
||||
layer_cfgs=dict(
|
||||
num_heads=6,
|
||||
feedforward_channels=3 * 384, # mlp_ratio = 3
|
||||
),
|
||||
drop_path_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
])
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
# test if use performer
|
||||
cfg = deepcopy(cfg_ori)
|
||||
def test_structure(self):
|
||||
# The performer hasn't been implemented
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['t2t_cfg']['use_performer'] = True
|
||||
T2T_ViT(**cfg)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
T2T_ViT(**cfg)
|
||||
|
||||
# Test T2T-ViT model with input size of 224
|
||||
model = T2T_ViT(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# Test out_indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = {1: 1}
|
||||
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
|
||||
T2T_ViT(**cfg)
|
||||
cfg['out_indices'] = [0, 15]
|
||||
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 15'):
|
||||
T2T_ViT(**cfg)
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
# Test model structure
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
self.assertEqual(len(model.encoder), 14)
|
||||
dpr_inc = 0.1 / (14 - 1)
|
||||
dpr = 0
|
||||
for layer in model.encoder:
|
||||
self.assertEqual(layer.attn.embed_dims, 384)
|
||||
# The default mlp_ratio is 3
|
||||
self.assertEqual(layer.ffn.feedforward_channels, 384 * 3)
|
||||
self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
|
||||
self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
|
||||
dpr += dpr_inc
|
||||
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
patch_token, cls_token = model(imgs)[-1]
|
||||
assert cls_token.shape == (3, 384)
|
||||
assert patch_token.shape == (3, 384, 14, 14)
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [dict(type='TruncNormal', layer='Linear', std=.02)]
|
||||
model = T2T_ViT(**cfg)
|
||||
ori_weight = model.tokens_to_token.project.weight.clone().detach()
|
||||
|
||||
# Test custom arch T2T-ViT without output cls token
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['embed_dims'] = 256
|
||||
cfg['num_layers'] = 16
|
||||
cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024)
|
||||
cfg['output_cls_token'] = False
|
||||
model.init_weights()
|
||||
initialized_weight = model.tokens_to_token.project.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
|
||||
model = T2T_ViT(**cfg)
|
||||
patch_token = model(imgs)[-1]
|
||||
assert patch_token.shape == (3, 256, 14, 14)
|
||||
# test load checkpoint
|
||||
pretrain_pos_embed = model.pos_embed.clone().detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
save_checkpoint(model, checkpoint)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
|
||||
|
||||
# Test T2T_ViT with multi out indices
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = T2T_ViT(**cfg)
|
||||
for out in model(imgs):
|
||||
assert out[0].shape == (3, 384, 14, 14)
|
||||
assert out[1].shape == (3, 384)
|
||||
# test load checkpoint with different img_size
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['img_size'] = 384
|
||||
model = T2T_ViT(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
|
||||
model.pos_embed)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
|
||||
|
||||
os.remove(checkpoint)
|
||||
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
# test with_cls_token=False
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = True
|
||||
with self.assertRaisesRegex(AssertionError, 'but got False'):
|
||||
T2T_ViT(**cfg)
|
||||
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = False
|
||||
model = T2T_ViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
|
||||
# test with output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
|
||||
# test without output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['output_cls_token'] = False
|
||||
model = T2T_ViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
|
||||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = T2T_ViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
for out in outs:
|
||||
patch_token, cls_token = out
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
|
||||
# Test forward with dynamic input size
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
|
||||
math.ceil(imgs.shape[3] / 16))
|
||||
self.assertEqual(patch_token.shape, (3, 384, *expect_feat_shape))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
|
@ -3,160 +3,181 @@ import math
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import VisionTransformer
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
||||
|
||||
def is_norm(modules):
|
||||
"""Check if is one of the norms."""
|
||||
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||
return True
|
||||
return False
|
||||
class TestVisionTransformer(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(
|
||||
arch='b', img_size=224, patch_size=16, drop_path_rate=0.1)
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
def test_structure(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
VisionTransformer(**cfg)
|
||||
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 4096
|
||||
}
|
||||
VisionTransformer(**cfg)
|
||||
|
||||
def test_vit_backbone():
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 128,
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 1024
|
||||
}
|
||||
model = VisionTransformer(**cfg)
|
||||
self.assertEqual(model.embed_dims, 128)
|
||||
self.assertEqual(model.num_layers, 24)
|
||||
for layer in model.layers:
|
||||
self.assertEqual(layer.attn.num_heads, 16)
|
||||
self.assertEqual(layer.ffn.feedforward_channels, 1024)
|
||||
|
||||
cfg_ori = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.1,
|
||||
init_cfg=[
|
||||
# Test out_indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = {1: 1}
|
||||
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
|
||||
VisionTransformer(**cfg)
|
||||
cfg['out_indices'] = [0, 13]
|
||||
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
|
||||
VisionTransformer(**cfg)
|
||||
|
||||
# Test model structure
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
self.assertEqual(len(model.layers), 12)
|
||||
dpr_inc = 0.1 / (12 - 1)
|
||||
dpr = 0
|
||||
for layer in model.layers:
|
||||
self.assertEqual(layer.attn.embed_dims, 768)
|
||||
self.assertEqual(layer.attn.num_heads, 12)
|
||||
self.assertEqual(layer.ffn.feedforward_channels, 3072)
|
||||
self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
|
||||
self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
|
||||
dpr += dpr_inc
|
||||
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
])
|
||||
]
|
||||
model = VisionTransformer(**cfg)
|
||||
ori_weight = model.patch_embed.projection.weight.clone().detach()
|
||||
# The pos_embed is all zero before initialize
|
||||
self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test invalid arch
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = 'unknown'
|
||||
VisionTransformer(**cfg)
|
||||
model.init_weights()
|
||||
initialized_weight = model.patch_embed.projection.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test arch without essential keys
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = {
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 4096
|
||||
}
|
||||
VisionTransformer(**cfg)
|
||||
# test load checkpoint
|
||||
pretrain_pos_embed = model.pos_embed.clone().detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
save_checkpoint(model, checkpoint)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))
|
||||
|
||||
# Test ViT base model with input size of 224
|
||||
# and patch size of 16
|
||||
model = VisionTransformer(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
# test load checkpoint with different img_size
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['img_size'] = 384
|
||||
model = VisionTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
|
||||
model.pos_embed)
|
||||
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
os.remove(checkpoint)
|
||||
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
patch_token, cls_token = model(imgs)[-1]
|
||||
assert cls_token.shape == (3, 768)
|
||||
assert patch_token.shape == (3, 768, 14, 14)
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
|
||||
# Test custom arch ViT without output cls token
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 128,
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 1024
|
||||
}
|
||||
cfg['output_cls_token'] = False
|
||||
model = VisionTransformer(**cfg)
|
||||
patch_token = model(imgs)[-1]
|
||||
assert patch_token.shape == (3, 128, 14, 14)
|
||||
# test with_cls_token=False
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = True
|
||||
with self.assertRaisesRegex(AssertionError, 'but got False'):
|
||||
VisionTransformer(**cfg)
|
||||
|
||||
# Test ViT with multi out indices
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = VisionTransformer(**cfg)
|
||||
for out in model(imgs):
|
||||
assert out[0].shape == (3, 768, 14, 14)
|
||||
assert out[1].shape == (3, 768)
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cls_token'] = False
|
||||
cfg['output_cls_token'] = False
|
||||
model = VisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
|
||||
# test with output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
|
||||
def timm_resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
||||
# Timm version pos embed resize function.
|
||||
# Refers to https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # noqa:E501
|
||||
ntok_new = posemb_new.shape[1]
|
||||
if num_tokens:
|
||||
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0,
|
||||
num_tokens:]
|
||||
ntok_new -= num_tokens
|
||||
else:
|
||||
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
if not len(gs_new): # backwards compatibility
|
||||
gs_new = [int(math.sqrt(ntok_new))] * 2
|
||||
assert len(gs_new) >= 2
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
|
||||
-1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(
|
||||
posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3,
|
||||
1).reshape(1, gs_new[0] * gs_new[1], -1)
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
return posemb
|
||||
# test without output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['output_cls_token'] = False
|
||||
model = VisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
|
||||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = VisionTransformer(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
for out in outs:
|
||||
patch_token, cls_token = out
|
||||
self.assertEqual(patch_token.shape, (3, 768, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
|
||||
def test_vit_weight_init():
|
||||
# test weight init cfg
|
||||
pretrain_cfg = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
init_cfg=[dict(type='Constant', val=1., layer='Conv2d')])
|
||||
pretrain_model = VisionTransformer(**pretrain_cfg)
|
||||
pretrain_model.init_weights()
|
||||
assert torch.allclose(pretrain_model.patch_embed.projection.weight,
|
||||
torch.tensor(1.))
|
||||
assert pretrain_model.pos_embed.abs().sum() > 0
|
||||
|
||||
pos_embed_weight = pretrain_model.pos_embed.detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
torch.save(pretrain_model.state_dict(), checkpoint)
|
||||
|
||||
# test load checkpoint
|
||||
finetune_cfg = dict(
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=checkpoint))
|
||||
finetune_model = VisionTransformer(**finetune_cfg)
|
||||
finetune_model.init_weights()
|
||||
assert torch.allclose(finetune_model.pos_embed, pos_embed_weight)
|
||||
|
||||
# test load checkpoint with different img_size
|
||||
finetune_cfg = dict(
|
||||
arch='b',
|
||||
img_size=384,
|
||||
patch_size=16,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=checkpoint))
|
||||
finetune_model = VisionTransformer(**finetune_cfg)
|
||||
finetune_model.init_weights()
|
||||
resized_pos_embed = timm_resize_pos_embed(pos_embed_weight,
|
||||
finetune_model.pos_embed)
|
||||
assert torch.allclose(finetune_model.pos_embed, resized_pos_embed)
|
||||
|
||||
os.remove(checkpoint)
|
||||
# Test forward with dynamic input size
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
|
||||
math.ceil(imgs.shape[3] / 16))
|
||||
self.assertEqual(patch_token.shape, (3, 768, *expect_feat_shape))
|
||||
self.assertEqual(cls_token.shape, (3, 768))
|
||||
|
31
tests/test_models/test_backbones/utils.py
Normal file
31
tests/test_models/test_backbones/utils.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def timm_resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
||||
"""Timm version pos embed resize function.
|
||||
|
||||
copied from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
""" # noqa:E501
|
||||
ntok_new = posemb_new.shape[1]
|
||||
if num_tokens:
|
||||
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0,
|
||||
num_tokens:]
|
||||
ntok_new -= num_tokens
|
||||
else:
|
||||
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
if not len(gs_new): # backwards compatibility
|
||||
gs_new = [int(math.sqrt(ntok_new))] * 2
|
||||
assert len(gs_new) >= 2
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
|
||||
-1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(
|
||||
posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3,
|
||||
1).reshape(1, gs_new[0] * gs_new[1], -1)
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
return posemb
|
@ -4,10 +4,8 @@ import tempfile
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv import ConfigDict
|
||||
from mmcv.runner.base_module import BaseModule
|
||||
|
||||
from mmcls.models import CLASSIFIERS
|
||||
from mmcls.models.classifiers import ImageClassifier
|
||||
@ -87,13 +85,10 @@ def test_image_classifier():
|
||||
torch.testing.assert_allclose(soft_pred, torch.softmax(pred, dim=1))
|
||||
|
||||
# test pretrained
|
||||
# TODO remove deprecated pretrained
|
||||
with pytest.warns(UserWarning):
|
||||
model_cfg_ = deepcopy(model_cfg)
|
||||
model_cfg_['pretrained'] = 'checkpoint'
|
||||
model = CLASSIFIERS.build(model_cfg_)
|
||||
assert model.init_cfg == dict(
|
||||
type='Pretrained', checkpoint='checkpoint')
|
||||
model_cfg_ = deepcopy(model_cfg)
|
||||
model_cfg_['pretrained'] = 'checkpoint'
|
||||
model = CLASSIFIERS.build(model_cfg_)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint='checkpoint')
|
||||
|
||||
# test show_result
|
||||
img = np.random.randint(0, 256, (224, 224, 3)).astype(np.uint8)
|
||||
@ -137,17 +132,6 @@ def test_image_classifier_with_mixup():
|
||||
losses = img_classifier.forward_train(imgs, label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# Considering BC-breaking
|
||||
# TODO remove deprecated mixup usage.
|
||||
model_cfg['train_cfg'] = dict(mixup=dict(alpha=1.0, num_classes=10))
|
||||
img_classifier = ImageClassifier(**model_cfg)
|
||||
img_classifier.init_weights()
|
||||
imgs = torch.randn(16, 3, 32, 32)
|
||||
label = torch.randint(0, 10, (16, ))
|
||||
|
||||
losses = img_classifier.forward_train(imgs, label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
|
||||
def test_image_classifier_with_cutmix():
|
||||
|
||||
@ -177,18 +161,6 @@ def test_image_classifier_with_cutmix():
|
||||
losses = img_classifier.forward_train(imgs, label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# Considering BC-breaking
|
||||
# TODO remove deprecated mixup usage.
|
||||
model_cfg['train_cfg'] = dict(
|
||||
cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0))
|
||||
img_classifier = ImageClassifier(**model_cfg)
|
||||
img_classifier.init_weights()
|
||||
imgs = torch.randn(16, 3, 32, 32)
|
||||
label = torch.randint(0, 10, (16, ))
|
||||
|
||||
losses = img_classifier.forward_train(imgs, label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
|
||||
def test_image_classifier_with_augments():
|
||||
|
||||
@ -266,59 +238,6 @@ def test_image_classifier_with_augments():
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
|
||||
def test_image_classifier_return_tuple():
|
||||
model_cfg = ConfigDict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNet_CIFAR',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch',
|
||||
return_tuple=False),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=10,
|
||||
in_channels=2048,
|
||||
loss=dict(type='CrossEntropyLoss')))
|
||||
|
||||
imgs = torch.randn(16, 3, 32, 32)
|
||||
|
||||
model_cfg_ = deepcopy(model_cfg)
|
||||
with pytest.warns(DeprecationWarning):
|
||||
model = CLASSIFIERS.build(model_cfg_)
|
||||
|
||||
# test backbone return tensor
|
||||
feat = model.extract_feat(imgs)
|
||||
assert isinstance(feat, torch.Tensor)
|
||||
|
||||
# test backbone return tuple
|
||||
model_cfg_ = deepcopy(model_cfg)
|
||||
model_cfg_.backbone.return_tuple = True
|
||||
model = CLASSIFIERS.build(model_cfg_)
|
||||
|
||||
feat = model.extract_feat(imgs)
|
||||
assert isinstance(feat, tuple)
|
||||
|
||||
# test warning if backbone return tensor
|
||||
class ToyBackbone(BaseModule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 16, 3)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
model_cfg_ = deepcopy(model_cfg)
|
||||
model_cfg_.backbone.return_tuple = True
|
||||
model = CLASSIFIERS.build(model_cfg_)
|
||||
model.backbone = ToyBackbone()
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
model.extract_feat(imgs)
|
||||
|
||||
|
||||
def test_classifier_extract_feat():
|
||||
model_cfg = ConfigDict(
|
||||
type='ImageClassifier',
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user