[Feature] Support ConvNeXt (#670)
* Support ConvNeXt * Add configs of ConvNeXt * Update dev scripts * Update docs. * Use new style README * Add unit tests. * Update README * Imporve according to comments * Modify refers to timm. * Imporve according to commentspull/679/head
parent
8488a784f0
commit
dc456a0c2c
|
@ -128,18 +128,19 @@ def inference(config_file, checkpoint, classes, args):
|
|||
|
||||
if args.flops:
|
||||
from mmcv.cnn.utils import get_model_complexity_info
|
||||
if hasattr(model, 'extract_feat'):
|
||||
model.forward = model.extract_feat
|
||||
flops, params = get_model_complexity_info(
|
||||
model,
|
||||
input_shape=(3, ) + resolution,
|
||||
print_per_layer_stat=False,
|
||||
as_strings=args.flops_str)
|
||||
result['flops'] = flops if args.flops_str else int(flops)
|
||||
result['params'] = params if args.flops_str else int(params)
|
||||
else:
|
||||
result['flops'] = ''
|
||||
result['params'] = ''
|
||||
with torch.no_grad():
|
||||
if hasattr(model, 'extract_feat'):
|
||||
model.forward = model.extract_feat
|
||||
flops, params = get_model_complexity_info(
|
||||
model,
|
||||
input_shape=(3, ) + resolution,
|
||||
print_per_layer_stat=False,
|
||||
as_strings=args.flops_str)
|
||||
result['flops'] = flops if args.flops_str else int(flops)
|
||||
result['params'] = params if args.flops_str else int(params)
|
||||
else:
|
||||
result['flops'] = ''
|
||||
result['params'] = ''
|
||||
|
||||
return result
|
||||
|
||||
|
@ -199,6 +200,9 @@ def main(args):
|
|||
summary_data = {}
|
||||
for model_name, model_info in models.items():
|
||||
|
||||
if model_info.config is None:
|
||||
continue
|
||||
|
||||
config = Path(model_info.config)
|
||||
assert config.exists(), f'{model_name}: {config} not found.'
|
||||
|
||||
|
|
|
@ -163,6 +163,10 @@ def test(args):
|
|||
|
||||
preview_script = ''
|
||||
for model_info in models.values():
|
||||
|
||||
if model_info.results is None:
|
||||
continue
|
||||
|
||||
script_path = create_test_job_batch(commands, model_info, args, port,
|
||||
script_name)
|
||||
preview_script = script_path or preview_script
|
||||
|
@ -288,6 +292,9 @@ def summary(args):
|
|||
summary_data = {}
|
||||
for model_name, model_info in models.items():
|
||||
|
||||
if model_info.results is None:
|
||||
continue
|
||||
|
||||
# Skip if not found result file.
|
||||
result_file = work_dir / model_name / 'result.pkl'
|
||||
if not result_file.exists():
|
||||
|
|
|
@ -79,6 +79,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
|
|||
- [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/master/configs/t2t_vit)
|
||||
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/master/configs/twins)
|
||||
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
|
||||
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
|
||||
- [ ] HRNet
|
||||
|
||||
</details>
|
||||
|
|
|
@ -76,8 +76,9 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
|
|||
- [x] [DeiT](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit)
|
||||
- [x] [Conformer](https://github.com/open-mmlab/mmclassification/tree/master/configs/conformer)
|
||||
- [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/master/configs/t2t_vit)
|
||||
- [ ] EfficientNet
|
||||
- [ ] Twins
|
||||
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/master/configs/twins)
|
||||
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
|
||||
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
|
||||
- [ ] HRNet
|
||||
|
||||
</details>
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
# Model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ConvNeXt',
|
||||
arch='base',
|
||||
out_indices=(3, ),
|
||||
drop_path_rate=0.5,
|
||||
gap_before_final_norm=True,
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='TruncNormal',
|
||||
layer=['Conv2d', 'Linear'],
|
||||
std=.02,
|
||||
bias=0.),
|
||||
dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),
|
||||
]),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
))
|
|
@ -0,0 +1,23 @@
|
|||
# Model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ConvNeXt',
|
||||
arch='large',
|
||||
out_indices=(3, ),
|
||||
drop_path_rate=0.5,
|
||||
gap_before_final_norm=True,
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='TruncNormal',
|
||||
layer=['Conv2d', 'Linear'],
|
||||
std=.02,
|
||||
bias=0.),
|
||||
dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),
|
||||
]),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1536,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
))
|
|
@ -0,0 +1,23 @@
|
|||
# Model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ConvNeXt',
|
||||
arch='small',
|
||||
out_indices=(3, ),
|
||||
drop_path_rate=0.4,
|
||||
gap_before_final_norm=True,
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='TruncNormal',
|
||||
layer=['Conv2d', 'Linear'],
|
||||
std=.02,
|
||||
bias=0.),
|
||||
dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),
|
||||
]),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
))
|
|
@ -0,0 +1,23 @@
|
|||
# Model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ConvNeXt',
|
||||
arch='tiny',
|
||||
out_indices=(3, ),
|
||||
drop_path_rate=0.1,
|
||||
gap_before_final_norm=True,
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='TruncNormal',
|
||||
layer=['Conv2d', 'Linear'],
|
||||
std=.02,
|
||||
bias=0.),
|
||||
dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),
|
||||
]),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
))
|
|
@ -0,0 +1,23 @@
|
|||
# Model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ConvNeXt',
|
||||
arch='xlarge',
|
||||
out_indices=(3, ),
|
||||
drop_path_rate=0.5,
|
||||
gap_before_final_norm=True,
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='TruncNormal',
|
||||
layer=['Conv2d', 'Linear'],
|
||||
std=.02,
|
||||
bias=0.),
|
||||
dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),
|
||||
]),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=2048,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
))
|
|
@ -24,7 +24,7 @@ lr_config = dict(
|
|||
min_lr_ratio=1e-2,
|
||||
warmup='linear',
|
||||
warmup_ratio=1e-3,
|
||||
warmup_iters=20 * 1252,
|
||||
warmup_by_epoch=False)
|
||||
warmup_iters=20,
|
||||
warmup_by_epoch=True)
|
||||
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=300)
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# ConvNeXt
|
||||
|
||||
> [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545v1)
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
<!-- [ABSTRACT] -->
|
||||
The "Roaring 20s" of visual recognition began with the introduction of Vision Transformers (ViTs), which quickly superseded ConvNets as the state-of-the-art image classification model. A vanilla ViT, on the other hand, faces difficulties when applied to general computer vision tasks such as object detection and semantic segmentation. It is the hierarchical Transformers (e.g., Swin Transformers) that reintroduced several ConvNet priors, making Transformers practically viable as a generic vision backbone and demonstrating remarkable performance on a wide variety of vision tasks. However, the effectiveness of such hybrid approaches is still largely credited to the intrinsic superiority of Transformers, rather than the inherent inductive biases of convolutions. In this work, we reexamine the design spaces and test the limits of what a pure ConvNet can achieve. We gradually "modernize" a standard ResNet toward the design of a vision Transformer, and discover several key components that contribute to the performance difference along the way. The outcome of this exploration is a family of pure ConvNet models dubbed ConvNeXt. Constructed entirely from standard ConvNet modules, ConvNeXts compete favorably with Transformers in terms of accuracy and scalability, achieving 87.8% ImageNet top-1 accuracy and outperforming Swin Transformers on COCO detection and ADE20K segmentation, while maintaining the simplicity and efficiency of standard ConvNets.
|
||||
|
||||
<!-- [IMAGE] -->
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/8370623/148624004-e9581042-ea4d-4e10-b3bd-42c92b02053b.png" width="100%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:---------------:|:------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
|
||||
| ConvNeXt-T\* | From scratch | 28.59 | 4.46 | 82.05 | 95.86 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-tiny_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128_in1k_20220124-18abde00.pth) |
|
||||
| ConvNeXt-S\* | From scratch | 50.22 | 8.69 | 83.13 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-small_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128_in1k_20220124-d39b5192.pth) |
|
||||
| ConvNeXt-B\* | From scratch | 88.59 | 15.36 | 83.85 | 96.74 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128_in1k_20220124-d0915162.pth) |
|
||||
| ConvNeXt-B\* | ImageNet-21k | 88.59 | 15.36 | 85.81 | 97.86 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_in21k-pre-3rdparty_32xb128_in1k_20220124-eb2d6ada.pth) |
|
||||
| ConvNeXt-L\* | From scratch | 197.77 | 34.37 | 84.30 | 96.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth) |
|
||||
| ConvNeXt-L\* | ImageNet-21k | 197.77 | 34.37 | 86.61 | 98.04 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_in21k-pre-3rdparty_64xb64_in1k_20220124-2412403d.pth) |
|
||||
| ConvNeXt-XL\* | ImageNet-21k | 350.20 | 60.93 | 86.97 | 98.20 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-xlarge_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k_20220124-76b6863d.pth) |
|
||||
|
||||
*Models with \* are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
### ImageNet-21k
|
||||
|
||||
The pre-trained models on ImageNet-21k are used to fine-tune, and therefore don't have evaluation results.
|
||||
|
||||
| Model | Params(M) | Flops(G) | Download |
|
||||
|:--------------------------------:|:---------:|:--------:|:--------:|
|
||||
| convnext-base_3rdparty_in21k\* | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth) |
|
||||
| convnext-large_3rdparty_in21k\* | 197.77 | 34.37 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth) |
|
||||
| convnext-xlarge_3rdparty_in21k\* | 350.20 | 60.93 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth) |
|
||||
|
||||
*Models with \* are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt).*
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@Article{liu2022convnet,
|
||||
author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
|
||||
title = {A ConvNet for the 2020s},
|
||||
journal = {arXiv preprint arXiv:2201.03545},
|
||||
year = {2022},
|
||||
}
|
||||
```
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/convnext/convnext-base.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=128)
|
||||
|
||||
optimizer = dict(lr=4e-3)
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/convnext/convnext-large.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=64)
|
||||
|
||||
optimizer = dict(lr=4e-3)
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/convnext/convnext-small.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=128)
|
||||
|
||||
optimizer = dict(lr=4e-3)
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/convnext/convnext-tiny.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=128)
|
||||
|
||||
optimizer = dict(lr=4e-3)
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../_base_/models/convnext/convnext-xlarge.py',
|
||||
'../_base_/datasets/imagenet_bs64_swin_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data = dict(samples_per_gpu=64)
|
||||
|
||||
optimizer = dict(lr=4e-3)
|
||||
|
||||
custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
|
|
@ -0,0 +1,167 @@
|
|||
Collections:
|
||||
- Name: ConvNeXt
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
Architecture:
|
||||
- 1x1 Convolution
|
||||
- LayerScale
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2201.03545v1
|
||||
Title: A ConvNet for the 2020s
|
||||
README: configs/convnext/README.md
|
||||
Code:
|
||||
Version: v0.20.0
|
||||
URL: https://github.com/open-mmlab/mmclassification/blob/v0.20.0/mmcls/models/backbones/convnext.py
|
||||
|
||||
Models:
|
||||
- Name: convnext-tiny_3rdparty_32xb128_in1k
|
||||
Metadata:
|
||||
FLOPs: 4457472768
|
||||
Parameters: 28589128
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.05
|
||||
Top 5 Accuracy: 95.86
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128_in1k_20220124-18abde00.pth
|
||||
Config: configs/convnext/convnext-tiny_32xb128_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-small_3rdparty_32xb128_in1k
|
||||
Metadata:
|
||||
FLOPs: 8687008512
|
||||
Parameters: 50223688
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.13
|
||||
Top 5 Accuracy: 96.44
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128_in1k_20220124-d39b5192.pth
|
||||
Config: configs/convnext/convnext-small_32xb128_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-base_3rdparty_32xb128_in1k
|
||||
Metadata:
|
||||
FLOPs: 15359124480
|
||||
Parameters: 88591464
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.85
|
||||
Top 5 Accuracy: 96.74
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128_in1k_20220124-d0915162.pth
|
||||
Config: configs/convnext/convnext-base_32xb128_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-base_3rdparty_in21k
|
||||
Metadata:
|
||||
Training Data: ImageNet-21k
|
||||
FLOPs: 15359124480
|
||||
Parameters: 88591464
|
||||
In Collections: ConvNeXt
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-base_in21k-pre-3rdparty_32xb128_in1k
|
||||
Metadata:
|
||||
Training Data:
|
||||
- ImageNet-21k
|
||||
- ImageNet-1k
|
||||
FLOPs: 15359124480
|
||||
Parameters: 88591464
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 85.81
|
||||
Top 5 Accuracy: 97.86
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_in21k-pre-3rdparty_32xb128_in1k_20220124-eb2d6ada.pth
|
||||
Config: configs/convnext/convnext-base_32xb128_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-large_3rdparty_64xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 34368026112
|
||||
Parameters: 197767336
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.30
|
||||
Top 5 Accuracy: 96.89
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth
|
||||
Config: configs/convnext/convnext-large_64xb64_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-large_3rdparty_in21k
|
||||
Metadata:
|
||||
Training Data: ImageNet-21k
|
||||
FLOPs: 34368026112
|
||||
Parameters: 197767336
|
||||
In Collections: ConvNeXt
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-large_in21k-pre-3rdparty_64xb64_in1k
|
||||
Metadata:
|
||||
Training Data:
|
||||
- ImageNet-21k
|
||||
- ImageNet-1k
|
||||
FLOPs: 34368026112
|
||||
Parameters: 197767336
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 86.61
|
||||
Top 5 Accuracy: 98.04
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_in21k-pre-3rdparty_64xb64_in1k_20220124-2412403d.pth
|
||||
Config: configs/convnext/convnext-large_64xb64_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-xlarge_3rdparty_in21k
|
||||
Metadata:
|
||||
Training Data: ImageNet-21k
|
||||
FLOPs: 60929820672
|
||||
Parameters: 350196968
|
||||
In Collections: ConvNeXt
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
||||
- Name: convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k
|
||||
Metadata:
|
||||
Training Data:
|
||||
- ImageNet-21k
|
||||
- ImageNet-1k
|
||||
FLOPs: 60929820672
|
||||
Parameters: 350196968
|
||||
In Collections: ConvNeXt
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 86.97
|
||||
Top 5 Accuracy: 98.20
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k_20220124-76b6863d.pth
|
||||
Config: configs/convnext/convnext-xlarge_64xb64_in1k.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth
|
||||
Code: https://github.com/facebookresearch/ConvNeXt
|
|
@ -112,6 +112,13 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
|||
| EfficientNet-B7 (AA)\* | 66.35 | 0.72 | 84.38 | 96.88 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b7_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa_in1k_20220119-bf03951c.pth) |
|
||||
| EfficientNet-B7 (AA + AdvProp)\* | 66.35 | 0.72 | 85.14 | 97.23 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b7_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k_20220119-c6dbff10.pth) |
|
||||
| EfficientNet-B8 (AA + AdvProp)\* | 87.41 | 1.09 | 85.38 | 97.28 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b8_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b8_3rdparty_8xb32-aa-advprop_in1k_20220119-297ce1b7.pth) |
|
||||
| ConvNeXt-T\* | 28.59 | 4.46 | 82.05 | 95.86 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-tiny_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128_in1k_20220124-18abde00.pth) |
|
||||
| ConvNeXt-S\* | 50.22 | 8.69 | 83.13 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-small_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128_in1k_20220124-d39b5192.pth) |
|
||||
| ConvNeXt-B\* | 88.59 | 15.36 | 83.85 | 96.74 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128_in1k_20220124-d0915162.pth) |
|
||||
| ConvNeXt-B\* | 88.59 | 15.36 | 85.81 | 97.86 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_in21k-pre-3rdparty_32xb128_in1k_20220124-eb2d6ada.pth) |
|
||||
| ConvNeXt-L\* | 197.77 | 34.37 | 84.30 | 96.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth) |
|
||||
| ConvNeXt-L\* | 197.77 | 34.37 | 86.61 | 98.04 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_in21k-pre-3rdparty_64xb64_in1k_20220124-2412403d.pth) |
|
||||
| ConvNeXt-XL\* | 350.20 | 60.93 | 86.97 | 98.20 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-xlarge_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k_20220124-76b6863d.pth) |
|
||||
|
||||
*Models with \* are converted from other repos, others are trained by ourselves.*
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .alexnet import AlexNet
|
||||
from .conformer import Conformer
|
||||
from .convnext import ConvNeXt
|
||||
from .deit import DistilledVisionTransformer
|
||||
from .efficientnet import EfficientNet
|
||||
from .lenet import LeNet5
|
||||
|
@ -32,5 +33,5 @@ __all__ = [
|
|||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
|
||||
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
|
||||
'EfficientNet'
|
||||
'EfficientNet', 'ConvNeXt'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,331 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
|
||||
build_norm_layer)
|
||||
from mmcv.runner import BaseModule
|
||||
from mmcv.runner.base_module import ModuleList, Sequential
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
@NORM_LAYERS.register_module('LN2d')
|
||||
class LayerNorm2d(nn.LayerNorm):
|
||||
"""LayerNorm on channels for 2d images.
|
||||
|
||||
Args:
|
||||
num_channels (int): The number of channels of the input tensor.
|
||||
eps (float): a value added to the denominator for numerical stability.
|
||||
Defaults to 1e-5.
|
||||
elementwise_affine (bool): a boolean value that when set to ``True``,
|
||||
this module has learnable per-element affine parameters initialized
|
||||
to ones (for weights) and zeros (for biases). Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels: int, **kwargs) -> None:
|
||||
super().__init__(num_channels, **kwargs)
|
||||
self.num_channels = self.normalized_shape[0]
|
||||
|
||||
def forward(self, x):
|
||||
assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \
|
||||
f'(N, C, H, W), but got tensor with shape {x.shape}'
|
||||
return F.layer_norm(
|
||||
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight,
|
||||
self.bias, self.eps).permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class ConvNeXtBlock(BaseModule):
|
||||
"""ConvNeXt Block.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
norm_cfg (dict): The config dict for norm layers.
|
||||
Defaults to ``dict(type='LN2d', eps=1e-6)``.
|
||||
act_cfg (dict): The config dict for activation between pointwise
|
||||
convolution. Defaults to ``dict(type='GELU')``.
|
||||
mlp_ratio (float): The expansion ratio in both pointwise convolution.
|
||||
Defaults to 4.
|
||||
linear_pw_conv (bool): Whether to use linear layer to do pointwise
|
||||
convolution. More details can be found in the note.
|
||||
Defaults to True.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
layer_scale_init_value (float): Init value for Layer Scale.
|
||||
Defaults to 1e-6.
|
||||
|
||||
Note:
|
||||
There are two equivalent implementations:
|
||||
|
||||
1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
|
||||
all outputs are in (N, C, H, W).
|
||||
2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU
|
||||
-> Linear; Permute back
|
||||
|
||||
As default, we use the second to align with the official repository.
|
||||
And it may be slightly faster.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
norm_cfg=dict(type='LN2d', eps=1e-6),
|
||||
act_cfg=dict(type='GELU'),
|
||||
mlp_ratio=4.,
|
||||
linear_pw_conv=True,
|
||||
drop_path_rate=0.,
|
||||
layer_scale_init_value=1e-6):
|
||||
super().__init__()
|
||||
self.depthwise_conv = nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
groups=in_channels)
|
||||
|
||||
self.linear_pw_conv = linear_pw_conv
|
||||
self.norm = build_norm_layer(norm_cfg, in_channels)[1]
|
||||
|
||||
mid_channels = int(mlp_ratio * in_channels)
|
||||
if self.linear_pw_conv:
|
||||
# Use linear layer to do pointwise conv.
|
||||
pw_conv = nn.Linear
|
||||
else:
|
||||
pw_conv = partial(nn.Conv2d, kernel_size=1)
|
||||
|
||||
self.pointwise_conv1 = pw_conv(in_channels, mid_channels)
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
self.pointwise_conv2 = pw_conv(mid_channels, in_channels)
|
||||
|
||||
self.gamma = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((in_channels)),
|
||||
requires_grad=True) if layer_scale_init_value > 0 else None
|
||||
|
||||
self.drop_path = DropPath(
|
||||
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = x
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if self.linear_pw_conv:
|
||||
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
||||
|
||||
x = self.pointwise_conv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
if self.linear_pw_conv:
|
||||
x = x.permute(0, 3, 1, 2) # permute back
|
||||
|
||||
if self.gamma is not None:
|
||||
x = x.mul(self.gamma.view(1, -1, 1, 1))
|
||||
|
||||
x = shortcut + self.drop_path(x)
|
||||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ConvNeXt(BaseBackbone):
|
||||
"""ConvNeXt.
|
||||
|
||||
A PyTorch implementation of : `A ConvNet for the 2020s
|
||||
<https://arxiv.org/pdf/2201.03545.pdf>`_
|
||||
|
||||
Modified from the `official repo
|
||||
<https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py>`_
|
||||
and `timm
|
||||
<https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py>`_.
|
||||
|
||||
Args:
|
||||
arch (str | dict): The model's architecture. If string, it should be
|
||||
one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
|
||||
should include the following two keys:
|
||||
|
||||
- depths (list[int]): Number of blocks at each stage.
|
||||
- channels (list[int]): The number of channels at each stage.
|
||||
|
||||
Defaults to 'tiny'.
|
||||
in_channels (int): Number of input image channels. Defaults to 3.
|
||||
stem_patch_size (int): The size of one patch in the stem layer.
|
||||
Defaults to 4.
|
||||
norm_cfg (dict): The config dict for norm layers.
|
||||
Defaults to ``dict(type='LN2d', eps=1e-6)``.
|
||||
act_cfg (dict): The config dict for activation between pointwise
|
||||
convolution. Defaults to ``dict(type='GELU')``.
|
||||
linear_pw_conv (bool): Whether to use linear layer to do pointwise
|
||||
convolution. Defaults to True.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
layer_scale_init_value (float): Init value for Layer Scale.
|
||||
Defaults to 1e-6.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Defaults to 0, which means not freezing any parameters.
|
||||
gap_before_final_norm (bool): Whether to globally average the feature
|
||||
map before the final norm layer. In the official repo, it's only
|
||||
used in classification task. Defaults to True.
|
||||
init_cfg (dict, optional): Initialization config dict
|
||||
""" # noqa: E501
|
||||
arch_settings = {
|
||||
'tiny': {
|
||||
'depths': [3, 3, 9, 3],
|
||||
'channels': [96, 192, 384, 768]
|
||||
},
|
||||
'small': {
|
||||
'depths': [3, 3, 27, 3],
|
||||
'channels': [96, 192, 384, 768]
|
||||
},
|
||||
'base': {
|
||||
'depths': [3, 3, 27, 3],
|
||||
'channels': [128, 256, 512, 1024]
|
||||
},
|
||||
'large': {
|
||||
'depths': [3, 3, 27, 3],
|
||||
'channels': [192, 384, 768, 1536]
|
||||
},
|
||||
'xlarge': {
|
||||
'depths': [3, 3, 27, 3],
|
||||
'channels': [256, 512, 1024, 2048]
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch='tiny',
|
||||
in_channels=3,
|
||||
stem_patch_size=4,
|
||||
norm_cfg=dict(type='LN2d', eps=1e-6),
|
||||
act_cfg=dict(type='GELU'),
|
||||
linear_pw_conv=True,
|
||||
drop_path_rate=0.,
|
||||
layer_scale_init_value=1e-6,
|
||||
out_indices=-1,
|
||||
frozen_stages=0,
|
||||
gap_before_final_norm=True,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
assert arch in self.arch_settings, \
|
||||
f'Unavailable arch, please choose from ' \
|
||||
f'({set(self.arch_settings)}) or pass a dict.'
|
||||
arch = self.arch_settings[arch]
|
||||
elif isinstance(arch, dict):
|
||||
assert 'depths' in arch and 'channels' in arch, \
|
||||
f'The arch dict must have "depths" and "channels", ' \
|
||||
f'but got {list(arch.keys())}.'
|
||||
|
||||
self.depths = arch['depths']
|
||||
self.channels = arch['channels']
|
||||
assert (isinstance(self.depths, Sequence)
|
||||
and isinstance(self.channels, Sequence)
|
||||
and len(self.depths) == len(self.channels)), \
|
||||
f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
|
||||
'should be both sequence with the same length.'
|
||||
|
||||
self.num_stages = len(self.depths)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = 4 + index
|
||||
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
|
||||
self.frozen_stages = frozen_stages
|
||||
self.gap_before_final_norm = gap_before_final_norm
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, drop_path_rate, sum(self.depths))
|
||||
]
|
||||
block_idx = 0
|
||||
|
||||
# 4 downsample layers between stages, including the stem layer.
|
||||
self.downsample_layers = ModuleList()
|
||||
stem = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
self.channels[0],
|
||||
kernel_size=stem_patch_size,
|
||||
stride=stem_patch_size),
|
||||
build_norm_layer(norm_cfg, self.channels[0])[1],
|
||||
)
|
||||
self.downsample_layers.append(stem)
|
||||
|
||||
# 4 feature resolution stages, each consisting of multiple residual
|
||||
# blocks
|
||||
self.stages = nn.ModuleList()
|
||||
|
||||
for i in range(self.num_stages):
|
||||
depth = self.depths[i]
|
||||
channels = self.channels[i]
|
||||
|
||||
if i >= 1:
|
||||
downsample_layer = nn.Sequential(
|
||||
LayerNorm2d(self.channels[i - 1]),
|
||||
nn.Conv2d(
|
||||
self.channels[i - 1],
|
||||
channels,
|
||||
kernel_size=2,
|
||||
stride=2),
|
||||
)
|
||||
self.downsample_layers.append(downsample_layer)
|
||||
|
||||
stage = Sequential(*[
|
||||
ConvNeXtBlock(
|
||||
in_channels=channels,
|
||||
drop_path_rate=dpr[block_idx + j],
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
linear_pw_conv=linear_pw_conv,
|
||||
layer_scale_init_value=layer_scale_init_value)
|
||||
for j in range(depth)
|
||||
])
|
||||
block_idx += depth
|
||||
|
||||
self.stages.append(stage)
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = build_norm_layer(norm_cfg, channels)[1]
|
||||
self.add_module(f'norm{i}', norm_layer)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = self.downsample_layers[i](x)
|
||||
x = stage(x)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
if self.gap_before_final_norm:
|
||||
gap = x.mean([-2, -1], keepdim=True)
|
||||
outs.append(norm_layer(gap).flatten(1))
|
||||
else:
|
||||
outs.append(norm_layer(x))
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def _freeze_stages(self):
|
||||
for i in range(self.frozen_stages):
|
||||
downsample_layer = self.downsample_layers[i]
|
||||
stage = self.stages[i]
|
||||
downsample_layer.eval()
|
||||
stage.eval()
|
||||
for param in chain(downsample_layer.parameters(),
|
||||
stage.parameters()):
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super(ConvNeXt, self).train(mode)
|
||||
self._freeze_stages()
|
|
@ -18,3 +18,4 @@ Import:
|
|||
- configs/deit/metafile.yml
|
||||
- configs/twins/metafile.yml
|
||||
- configs/efficientnet/metafile.yml
|
||||
- configs/convnext/metafile.yml
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.backbones import ConvNeXt
|
||||
|
||||
|
||||
def test_assertion():
|
||||
with pytest.raises(AssertionError):
|
||||
ConvNeXt(arch='unknown')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# ConvNeXt arch dict should include 'embed_dims',
|
||||
ConvNeXt(arch=dict(channels=[2, 3, 4, 5]))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# ConvNeXt arch dict should include 'embed_dims',
|
||||
ConvNeXt(arch=dict(depths=[2, 3, 4], channels=[2, 3, 4, 5]))
|
||||
|
||||
|
||||
def test_convnext():
|
||||
|
||||
# Test forward
|
||||
model = ConvNeXt(arch='tiny', out_indices=-1)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 1
|
||||
assert feat[0].shape == torch.Size([1, 768])
|
||||
|
||||
# Test forward with multiple outputs
|
||||
model = ConvNeXt(arch='small', out_indices=(0, 1, 2, 3))
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 96])
|
||||
assert feat[1].shape == torch.Size([1, 192])
|
||||
assert feat[2].shape == torch.Size([1, 384])
|
||||
assert feat[3].shape == torch.Size([1, 768])
|
||||
|
||||
# Test with custom arch
|
||||
model = ConvNeXt(
|
||||
arch={
|
||||
'depths': [2, 3, 4, 5, 6],
|
||||
'channels': [16, 32, 64, 128, 256]
|
||||
},
|
||||
out_indices=(0, 1, 2, 3, 4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == torch.Size([1, 16])
|
||||
assert feat[1].shape == torch.Size([1, 32])
|
||||
assert feat[2].shape == torch.Size([1, 64])
|
||||
assert feat[3].shape == torch.Size([1, 128])
|
||||
assert feat[4].shape == torch.Size([1, 256])
|
||||
|
||||
# Test without gap before final norm
|
||||
model = ConvNeXt(
|
||||
arch='small', out_indices=(0, 1, 2, 3), gap_before_final_norm=False)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 96, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 192, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 384, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 768, 7, 7])
|
||||
|
||||
# Test frozen_stages
|
||||
model = ConvNeXt(arch='small', out_indices=(0, 1, 2, 3), frozen_stages=2)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for i in range(2):
|
||||
assert not model.downsample_layers[i].training
|
||||
assert not model.stages[i].training
|
||||
|
||||
for i in range(2, 4):
|
||||
assert model.downsample_layers[i].training
|
||||
assert model.stages[i].training
|
Loading…
Reference in New Issue