[Reproduce] Update ConvNeXt config files. (#1256)
* Update ConvNeXt training configs.
* Update ConvNeXt network.
* Update metafile and README.
parent 0e4163668f
commit b63515111b
@@ -19,5 +19,11 @@ model = dict(
         type='LinearClsHead',
         num_classes=1000,
         in_channels=1024,
-        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-    ))
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
+    ]),
+)

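The same hunk repeats once per ConvNeXt size below; only `in_channels` differs (this first one, with 1024-dim features, matches the base arch). For reference, a sketch of the merged model config after this change; the `head` and `train_cfg` parts come straight from the hunk, while the wrapper and backbone dicts are assumptions:

```python
# Sketch of the merged config (convnext-base shown). Only head/train_cfg are
# taken from the diff; type/backbone are assumed for self-containedness.
model = dict(
    type='ImageClassifier',                       # assumption
    backbone=dict(type='ConvNeXt', arch='base'),  # assumption
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        # CrossEntropyLoss is replaced by label smoothing ...
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
    ),
    # ... and Mixup/CutMix run as batch augmentations during training.
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0),
    ]),
)
```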
@@ -19,5 +19,11 @@ model = dict(
         type='LinearClsHead',
         num_classes=1000,
         in_channels=1536,
-        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-    ))
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
+    ]),
+)

@@ -19,5 +19,11 @@ model = dict(
         type='LinearClsHead',
         num_classes=1000,
         in_channels=768,
-        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-    ))
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
+    ]),
+)

@@ -19,5 +19,11 @@ model = dict(
         type='LinearClsHead',
         num_classes=1000,
         in_channels=768,
-        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-    ))
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
+    ]),
+)

@@ -19,5 +19,11 @@ model = dict(
         type='LinearClsHead',
         num_classes=1000,
         in_channels=2048,
-        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-    ))
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
+    ]),
+)

@@ -10,6 +10,7 @@ optim_wrapper = dict(
     paramwise_cfg=dict(
         norm_decay_mult=0.0,
         bias_decay_mult=0.0,
+        flat_decay_mult=0.0,
         custom_keys={
             '.absolute_pos_embed': dict(decay_mult=0.0),
             '.relative_position_bias_table': dict(decay_mult=0.0)

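`flat_decay_mult` is the new multiplier here; in MMEngine's default optimizer constructor it scales the weight decay of one-dimensional ("flat") parameters such as norm weights and scale factors. A rough sketch of the multiplier logic; the helper below is hypothetical, not MMEngine code:

```python
import torch

def effective_weight_decay(name: str, param: torch.Tensor,
                           base_decay: float = 0.05) -> float:
    """Hypothetical helper mirroring the paramwise_cfg above."""
    if name.endswith('.bias'):
        return base_decay * 0.0  # bias_decay_mult=0.0
    if param.ndim == 1:
        return base_decay * 0.0  # flat_decay_mult=0.0 (1-D params)
    return base_decay            # regular weights keep the full decay
```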
@@ -36,9 +36,9 @@ The "Roaring 20s" of visual recognition began with the introduction of Vision Tr
 ```python
 >>> import torch
->>> from mmcls.apis import init_model, inference_model
+>>> from mmcls.apis import get_model, inference_model
 >>>
->>> model = init_model('configs/convnext/convnext-tiny_32xb128_in1k.py', 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128-noema_in1k_20220222-2908964a.pth')
+>>> model = get_model('convnext-tiny_32xb128_in1k', pretrained=True)
 >>> predict = inference_model(model, 'demo/demo.JPEG')
 >>> print(predict['pred_class'])
 sea snake

@@ -50,10 +50,10 @@ sea snake
 ```python
 >>> import torch
->>> from mmcls.apis import init_model
+>>> from mmcls.apis import get_model
 >>>
->>> model = init_model('configs/convnext/convnext-tiny_32xb128_in1k.py', 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128-noema_in1k_20220222-2908964a.pth')
->>> inputs = torch.rand(1, 3, 224, 224).to(model.data_preprocessor.device)
+>>> model = get_model('convnext-tiny_32xb128_in1k', pretrained=True)
+>>> inputs = torch.rand(1, 3, 224, 224)
 >>> # To get classification scores.
 >>> out = model(inputs)
 >>> print(out.shape)

@@ -85,35 +85,37 @@ For more configurable parameters, please refer to the [API](https://mmclassifica
 
 ## Results and models
 
-### ImageNet-1k
-
-| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
-| :-----------: | :----------: | :-------: | :------: | :-------: | :-------: | :----------------------------------------: | :------------------------------------------------------------------------------------------------: |
-| ConvNeXt-T\* | From scratch | 28.59 | 4.46 | 82.05 | 95.86 | [config](./convnext-tiny_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128_in1k_20220124-18abde00.pth) |
-| ConvNeXt-S\* | From scratch | 50.22 | 8.69 | 83.13 | 96.44 | [config](./convnext-small_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128_in1k_20220124-d39b5192.pth) |
-| ConvNeXt-B\* | From scratch | 88.59 | 15.36 | 83.85 | 96.74 | [config](./convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128_in1k_20220124-d0915162.pth) |
-| ConvNeXt-B\* | ImageNet-21k | 88.59 | 15.36 | 85.81 | 97.86 | [config](./convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_in21k-pre-3rdparty_32xb128_in1k_20220124-eb2d6ada.pth) |
-| ConvNeXt-L\* | From scratch | 197.77 | 34.37 | 84.30 | 96.89 | [config](./convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth) |
-| ConvNeXt-L\* | ImageNet-21k | 197.77 | 34.37 | 86.61 | 98.04 | [config](./convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_in21k-pre-3rdparty_64xb64_in1k_20220124-2412403d.pth) |
-| ConvNeXt-XL\* | ImageNet-21k | 350.20 | 60.93 | 86.97 | 98.20 | [config](./convnext-xlarge_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k_20220124-76b6863d.pth) |
-
-*Models with * are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
-
 ### Pre-trained Models
 
 The pre-trained models on ImageNet-1k or ImageNet-21k are used to fine-tune on the downstream tasks.
 
-| Model | Training Data | Params(M) | Flops(G) | Download |
-| :-----------: | :-----------: | :-------: | :------: | :-----------------------------------------------------------------------------------------------------------------------------------: |
-| ConvNeXt-T\* | ImageNet-1k | 28.59 | 4.46 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128-noema_in1k_20220222-2908964a.pth) |
-| ConvNeXt-S\* | ImageNet-1k | 50.22 | 8.69 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128-noema_in1k_20220222-fa001ca5.pth) |
-| ConvNeXt-B\* | ImageNet-1k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128-noema_in1k_20220222-dba4f95f.pth) |
-| ConvNeXt-B\* | ImageNet-21k | 88.59 | 15.36 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth) |
-| ConvNeXt-L\* | ImageNet-21k | 197.77 | 34.37 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth) |
-| ConvNeXt-XL\* | ImageNet-21k | 350.20 | 60.93 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth) |
+| Model | Training Data | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
+| :------------------------------------------------- | :-----------: | :-------: | :------: | :-------: | :-------: | :----------------------------------------------------------------------------------------------------: |
+| ConvNeXt-T (`convnext-tiny_32xb128-noema_in1k`) | ImageNet-1k | 28.59 | 4.46 | 81.95 | 95.89 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_32xb128-noema_in1k_20221208-5d4509c7.pth) |
+| ConvNeXt-S (`convnext-small_32xb128-noema_in1k`) | ImageNet-1k | 50.22 | 8.69 | 83.21 | 96.48 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_32xb128-noema_in1k_20221208-4a618995.pth) |
+| ConvNeXt-B (`convnext-base_32xb128-noema_in1k`) | ImageNet-1k | 88.59 | 15.36 | 83.64 | 96.61 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_32xb128-noema_in1k_20221208-f8182678.pth) |
+| ConvNeXt-B (`convnext-base_3rdparty-noema_in1k`)\* | ImageNet-1k | 88.59 | 15.36 | 83.71 | 96.60 | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128-noema_in1k_20220222-dba4f95f.pth) |
+| ConvNeXt-B (`convnext-base_3rdparty_in21k`)\* | ImageNet-21k | 88.59 | 15.36 | N/A | N/A | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth) |
+| ConvNeXt-L (`convnext-large_3rdparty_in21k`)\* | ImageNet-21k | 197.77 | 34.37 | N/A | N/A | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth) |
+| ConvNeXt-XL (`convnext-xlarge_3rdparty_in21k`)\* | ImageNet-21k | 350.20 | 60.93 | N/A | N/A | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth) |
+
+*Models with * are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt).*
+
+### ImageNet-1k
+
+| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+| :----------------------------------------------------- | :----------: | :-------: | :------: | :-------: | :-------: | :----------------------------------------: | :-------------------------------------------------------: |
+| ConvNeXt-T (`convnext-tiny_32xb128_in1k`) | From scratch | 28.59 | 4.46 | 82.14 | 96.06 | [config](./convnext-tiny_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_32xb128_in1k_20221207-998cf3e9.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_32xb128_in1k_20221207-998cf3e9.log.json) |
+| ConvNeXt-S (`convnext-small_32xb128_in1k`) | From scratch | 50.22 | 8.69 | 83.16 | 96.56 | [config](./convnext-small_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_32xb128_in1k_20221207-4ab7052c.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_32xb128_in1k_20221207-4ab7052c.log.json) |
+| ConvNeXt-B (`convnext-base_32xb128_in1k`) | From scratch | 88.59 | 15.36 | 83.66 | 96.74 | [config](./convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_32xb128_in1k_20221207-fbdb5eb9.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_32xb128_in1k_20221207-fbdb5eb9.log.json) |
+| ConvNeXt-B (`convnext-base_3rdparty_in1k`)\* | From scratch | 88.59 | 15.36 | 83.85 | 96.74 | [config](./convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_32xb128_in1k_20220124-d0915162.pth) |
+| ConvNeXt-B (`convnext-base_in21k-pre_3rdparty_in1k`)\* | ImageNet-21k | 88.59 | 15.36 | 85.81 | 97.86 | [config](./convnext-base_32xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_in21k-pre-3rdparty_32xb128_in1k_20220124-eb2d6ada.pth) |
+| ConvNeXt-L (`convnext-large_3rdparty_in1k`)\* | From scratch | 197.77 | 34.37 | 84.30 | 96.89 | [config](./convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth) |
+| ConvNeXt-L (`convnext-large_in21k-pre_3rdparty_in1k`)\* | ImageNet-21k | 197.77 | 34.37 | 86.61 | 98.04 | [config](./convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_in21k-pre-3rdparty_64xb64_in1k_20220124-2412403d.pth) |
+| ConvNeXt-XL (`convnext-xlarge_in21k-pre_3rdparty_in1k`)\* | ImageNet-21k | 350.20 | 60.93 | 86.97 | 98.20 | [config](./convnext-xlarge_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k_20220124-76b6863d.pth) |
+
+*Models with * are converted from the [official repo](https://github.com/facebookresearch/ConvNeXt). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
 
 ## Citation
 
 ```bibtex

@@ -11,11 +11,11 @@ train_dataloader = dict(batch_size=128)
 # schedule setting
 optim_wrapper = dict(
     optimizer=dict(lr=4e-3),
-    clip_grad=dict(max_norm=5.0),
+    clip_grad=None,
 )
 
 # runtime setting
-custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
+custom_hooks = [dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')]
 
 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.

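Both changed lines are schedule-level: gradient clipping is dropped, and the EMA momentum rises from 4e-5 to 1e-4. Assuming `EMAHook` drives MMEngine's `ExponentialMovingAverage`, each optimizer step updates the shadow weights roughly like this sketch:

```python
import torch

@torch.no_grad()
def ema_step(avg: torch.Tensor, src: torch.Tensor,
             momentum: float = 1e-4) -> None:
    # averaged = (1 - momentum) * averaged + momentum * source,
    # so the larger momentum lets the EMA weights track training faster.
    avg.mul_(1 - momentum).add_(src, alpha=momentum)
```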
@@ -0,0 +1,24 @@
+_base_ = [
+    '../_base_/models/convnext/convnext-base.py',
+    '../_base_/datasets/imagenet21k_bs128.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py',
+]
+
+# model setting
+model = dict(head=dict(num_classes=21841))
+
+# dataset setting
+data_preprocessor = dict(num_classes=21841)
+train_dataloader = dict(batch_size=128)
+
+# schedule setting
+optim_wrapper = dict(
+    optimizer=dict(lr=4e-3),
+    clip_grad=dict(max_norm=5.0),
+)
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (32 GPUs) x (128 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)

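`auto_scale_lr.base_batch_size=4096` feeds the runner's linear LR scaling rule (when that feature is enabled): the configured `lr=4e-3` is multiplied by `actual_batch_size / 4096`. A sketch of the arithmetic:

```python
def scaled_lr(base_lr: float, num_gpus: int, samples_per_gpu: int,
              base_batch_size: int = 4096) -> float:
    # Linear scaling rule applied by MMEngine's auto_scale_lr when enabled.
    return base_lr * (num_gpus * samples_per_gpu) / base_batch_size

# 32 GPUs x 128 samples matches the base batch, so the LR stays 4e-3:
assert scaled_lr(4e-3, 32, 128) == 4e-3
# Training on 8 GPUs x 128 samples would scale it down to 1e-3.
assert scaled_lr(4e-3, 8, 128) == 1e-3
```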
@@ -11,11 +11,11 @@ train_dataloader = dict(batch_size=64)
 # schedule setting
 optim_wrapper = dict(
     optimizer=dict(lr=4e-3),
-    clip_grad=dict(max_norm=5.0),
+    clip_grad=None,
 )
 
 # runtime setting
-custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
+custom_hooks = [dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')]
 
 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.

@@ -0,0 +1,24 @@
+_base_ = [
+    '../_base_/models/convnext/convnext-large.py',
+    '../_base_/datasets/imagenet21k_bs128.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py',
+]
+
+# model setting
+model = dict(head=dict(num_classes=21841))
+
+# dataset setting
+data_preprocessor = dict(num_classes=21841)
+train_dataloader = dict(batch_size=64)
+
+# schedule setting
+optim_wrapper = dict(
+    optimizer=dict(lr=4e-3),
+    clip_grad=dict(max_norm=5.0),
+)
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)

@@ -11,11 +11,11 @@ train_dataloader = dict(batch_size=128)
 # schedule setting
 optim_wrapper = dict(
     optimizer=dict(lr=4e-3),
-    clip_grad=dict(max_norm=5.0),
+    clip_grad=None,
 )
 
 # runtime setting
-custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
+custom_hooks = [dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')]
 
 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.

@@ -11,11 +11,11 @@ train_dataloader = dict(batch_size=128)
 # schedule setting
 optim_wrapper = dict(
     optimizer=dict(lr=4e-3),
-    clip_grad=dict(max_norm=5.0),
+    clip_grad=None,
 )
 
 # runtime setting
-custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
+custom_hooks = [dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')]
 
 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.

@@ -11,11 +11,11 @@ train_dataloader = dict(batch_size=64)
 # schedule setting
 optim_wrapper = dict(
     optimizer=dict(lr=4e-3),
-    clip_grad=dict(max_norm=5.0),
+    clip_grad=None,
 )
 
 # runtime setting
-custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
+custom_hooks = [dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')]
 
 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.

@@ -0,0 +1,24 @@
+_base_ = [
+    '../_base_/models/convnext/convnext-xlarge.py',
+    '../_base_/datasets/imagenet21k_bs128.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py',
+]
+
+# model setting
+model = dict(head=dict(num_classes=21841))
+
+# dataset setting
+data_preprocessor = dict(num_classes=21841)
+train_dataloader = dict(batch_size=64)
+
+# schedule setting
+optim_wrapper = dict(
+    optimizer=dict(lr=4e-3),
+    clip_grad=dict(max_norm=5.0),
+)
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (64 GPUs) x (64 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)

@@ -14,7 +14,7 @@ Collections:
       URL: https://github.com/open-mmlab/mmclassification/blob/v0.20.1/mmcls/models/backbones/convnext.py
 
 Models:
-  - Name: convnext-tiny_3rdparty_32xb128_in1k
+  - Name: convnext-tiny_32xb128_in1k
     Metadata:
       FLOPs: 4457472768
       Parameters: 28589128

@@ -22,15 +22,12 @@ Models:
     Results:
       - Dataset: ImageNet-1k
         Metrics:
-          Top 1 Accuracy: 82.05
-          Top 5 Accuracy: 95.86
+          Top 1 Accuracy: 82.14
+          Top 5 Accuracy: 96.06
         Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128_in1k_20220124-18abde00.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_32xb128_in1k_20221207-998cf3e9.pth
     Config: configs/convnext/convnext-tiny_32xb128_in1k.py
-    Converted From:
-      Weights: https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
-      Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-tiny_3rdparty_32xb128-noema_in1k
+  - Name: convnext-tiny_32xb128-noema_in1k
     Metadata:
       Training Data: ImageNet-1k
       FLOPs: 4457472768

@@ -39,15 +36,12 @@ Models:
     Results:
       - Dataset: ImageNet-1k
         Metrics:
-          Top 1 Accuracy: 81.81
-          Top 5 Accuracy: 95.67
+          Top 1 Accuracy: 81.95
+          Top 5 Accuracy: 95.89
         Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_3rdparty_32xb128-noema_in1k_20220222-2908964a.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_32xb128-noema_in1k_20221208-5d4509c7.pth
     Config: configs/convnext/convnext-tiny_32xb128_in1k.py
-    Converted From:
-      Weights: https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224.pth
-      Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-small_3rdparty_32xb128_in1k
+  - Name: convnext-small_32xb128_in1k
     Metadata:
       Training Data: ImageNet-1k
       FLOPs: 8687008512

@@ -56,15 +50,12 @@ Models:
     Results:
       - Dataset: ImageNet-1k
         Metrics:
-          Top 1 Accuracy: 83.13
-          Top 5 Accuracy: 96.44
+          Top 1 Accuracy: 83.16
+          Top 5 Accuracy: 96.56
         Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128_in1k_20220124-d39b5192.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_32xb128_in1k_20221207-4ab7052c.pth
     Config: configs/convnext/convnext-small_32xb128_in1k.py
-    Converted From:
-      Weights: https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
-      Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-small_3rdparty_32xb128-noema_in1k
+  - Name: convnext-small_32xb128-noema_in1k
     Metadata:
       Training Data: ImageNet-1k
       FLOPs: 8687008512

@@ -73,15 +64,40 @@ Models:
     Results:
       - Dataset: ImageNet-1k
         Metrics:
-          Top 1 Accuracy: 83.11
-          Top 5 Accuracy: 96.34
+          Top 1 Accuracy: 83.21
+          Top 5 Accuracy: 96.48
         Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_3rdparty_32xb128-noema_in1k_20220222-fa001ca5.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_32xb128-noema_in1k_20221208-4a618995.pth
     Config: configs/convnext/convnext-small_32xb128_in1k.py
-    Converted From:
-      Weights: https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224.pth
-      Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-base_3rdparty_32xb128_in1k
+  - Name: convnext-base_32xb128_in1k
+    Metadata:
+      Training Data: ImageNet-1k
+      FLOPs: 15359124480
+      Parameters: 88591464
+    In Collection: ConvNeXt
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 83.66
+          Top 5 Accuracy: 96.74
+        Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_32xb128_in1k_20221207-fbdb5eb9.pth
+    Config: configs/convnext/convnext-base_32xb128_in1k.py
+  - Name: convnext-base_32xb128-noema_in1k
+    Metadata:
+      Training Data: ImageNet-1k
+      FLOPs: 15359124480
+      Parameters: 88591464
+    In Collection: ConvNeXt
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 83.64
+          Top 5 Accuracy: 96.61
+        Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_32xb128-noema_in1k_20221208-f8182678.pth
+    Config: configs/convnext/convnext-base_32xb128_in1k.py
+  - Name: convnext-base_3rdparty_in1k
     Metadata:
       Training Data: ImageNet-1k
       FLOPs: 15359124480

@@ -98,7 +114,7 @@ Models:
     Converted From:
       Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
       Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-base_3rdparty_32xb128-noema_in1k
+  - Name: convnext-base_3rdparty-noema_in1k
     Metadata:
       Training Data: ImageNet-1k
       FLOPs: 15359124480

@@ -123,10 +139,11 @@ Models:
     In Collection: ConvNeXt
     Results: null
     Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_3rdparty_in21k_20220124-13b83eec.pth
+    Config: configs/convnext/convnext-base_32xb128_in21k.py
     Converted From:
       Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
       Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-base_in21k-pre-3rdparty_32xb128_in1k
+  - Name: convnext-base_in21k-pre_3rdparty_in1k
     Metadata:
       Training Data:
         - ImageNet-21k

@@ -145,7 +162,7 @@ Models:
     Converted From:
       Weights: https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth
       Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-large_3rdparty_64xb64_in1k
+  - Name: convnext-large_3rdparty_in1k
     Metadata:
       Training Data: ImageNet-1k
       FLOPs: 34368026112

@@ -170,10 +187,11 @@ Models:
     In Collection: ConvNeXt
     Results: null
     Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_in21k_20220124-41b5a79f.pth
+    Config: configs/convnext/convnext-large_64xb64_in21k.py
     Converted From:
       Weights: https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
       Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-large_in21k-pre-3rdparty_64xb64_in1k
+  - Name: convnext-large_in21k-pre_3rdparty_in1k
     Metadata:
       Training Data:
         - ImageNet-21k

@@ -200,10 +218,11 @@ Models:
     In Collection: ConvNeXt
     Results: null
     Weights: https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_3rdparty_in21k_20220124-f909bad7.pth
+    Config: configs/convnext/convnext-xlarge_64xb64_in21k.py
     Converted From:
       Weights: https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
       Code: https://github.com/facebookresearch/ConvNeXt
-  - Name: convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k
+  - Name: convnext-xlarge_in21k-pre_3rdparty_in1k
     Metadata:
       Training Data:
         - ImageNet-21k

@@ -481,7 +481,7 @@ visualizer = dict(
 )
 ```
 
-New field **`default_scope`**: The start point to search module for all registries. The `default_scope` in MMClassification is `mmcls`. See {external+mmengine:doc}`the registry tutorial <tutorials/registry>` for more details.
+New field **`default_scope`**: The start point to search module for all registries. The `default_scope` in MMClassification is `mmcls`. See {external+mmengine:doc}`the registry tutorial <advanced_tutorials/registry>` for more details.
 
 ## Packages

@@ -2,7 +2,7 @@
 
 To manage various configurations in a deep-learning experiment, we use a kind of config file to record all of
 these configurations. This config system has a modular and inheritance design, and more details can be found in
-{external+mmengine:doc}`the tutorial in MMEngine <tutorials/config>`.
+{external+mmengine:doc}`the tutorial in MMEngine <advanced_tutorials/config>`.
 
 Usually, we use python files as config file. All configuration files are placed under the [`configs`](https://github.com/open-mmlab/mmclassification/tree/1.x/configs) folder, and the directory structure is as follows:

@@ -64,7 +64,7 @@ This primitive config file includes a dict variable `model`, which mainly includ
 
 ```{note}
 Usually, we use the `type` field to specify the class of the component and use other fields to pass
-the initialization arguments of the class. The {external+mmengine:doc}`registry tutorial <tutorials/registry>` describes it in detail.
+the initialization arguments of the class. The {external+mmengine:doc}`registry tutorial <advanced_tutorials/registry>` describes it in detail.
 ```
 
 Following is the model primitive config of the ResNet50 config file in [`configs/_base_/models/resnet50.py`](https://github.com/open-mmlab/mmclassification/blob/1.x/configs/_base_/models/resnet50.py):

@@ -348,7 +348,7 @@ test_dataloader = dict(dataset=dict(pipeline=val_pipeline))
 
 ### Ignore some fields in the base configs
 
-Sometimes, you need to set `_delete_=True` to ignore some domain content in the basic configuration file. You can refer to the {external+mmengine:doc}`documentation in MMEngine <tutorials/config>` for more instructions.
+Sometimes, you need to set `_delete_=True` to ignore some domain content in the basic configuration file. You can refer to the {external+mmengine:doc}`documentation in MMEngine <advanced_tutorials/config>` for more instructions.
 
 The following is an example. If you want to use cosine schedule in the above ResNet50 case, just using inheritance and directly modifying it will report `get unexpected keyword 'step'` error, because the `'step'` field of the basic config in `param_scheduler` domain information is reserved, and you need to add `_delete_=True` to ignore the content of `param_scheduler` related fields in the basic configuration file:

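As a hedged illustration of the `_delete_` pattern this doc hunk describes (the base file name below is hypothetical; the `param_scheduler` line itself appears in the next hunk's context):

```python
# Sketch only: assumes a base config whose `param_scheduler` is a
# MultiStepLR-style dict with a `step` field that CosineAnnealingLR rejects.
_base_ = ['./resnet50_8xb32_in1k.py']  # hypothetical base config

# `_delete_=True` discards every inherited `param_scheduler` key before the
# new keys are applied, so no stale `step` field survives the config merge.
param_scheduler = dict(type='CosineAnnealingLR', by_epoch=True, _delete_=True)
```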
@@ -361,7 +361,7 @@ param_scheduler = dict(type='CosineAnnealingLR', by_epoch=True, _delete_=True)
 
 ### Use some fields in the base configs
 
-Sometimes, you may refer to some fields in the `_base_` config, to avoid duplication of definitions. You can refer to {external+mmengine:doc}`MMEngine <tutorials/config>` for some more instructions.
+Sometimes, you may refer to some fields in the `_base_` config, to avoid duplication of definitions. You can refer to {external+mmengine:doc}`MMEngine <advanced_tutorials/config>` for some more instructions.
 
 The following is an example of using auto augment in the training data preprocessing pipeline, refer to [`configs/resnest/resnest50_32xb64_in1k.py`](https://github.com/open-mmlab/mmclassification/blob/1.x/configs/resnest/resnest50_32xb64_in1k.py). When defining `train_pipeline`, just add the definition file name of auto augment to `_base_`, and then use `_base_.auto_increasing_policies` to reference the variables in the primitive config:

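A hedged sketch of that `_base_` variable reference; the policy file name and the surrounding pipeline steps are assumptions, only the `_base_.auto_increasing_policies` access pattern comes from the quoted docs:

```python
# Sketch: reuse a variable from a _base_ file instead of copying its contents.
_base_ = [
    './pipelines/auto_aug.py',  # hypothetical file defining auto_increasing_policies
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    # Reference the policy list defined in the base config.
    dict(type='AutoAugment', policies=_base_.auto_increasing_policies),
    dict(type='PackClsInputs'),
]
```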
@@ -31,12 +31,20 @@ class LayerNorm2d(nn.LayerNorm):
         super().__init__(num_channels, **kwargs)
         self.num_channels = self.normalized_shape[0]
 
-    def forward(self, x):
+    def forward(self, x, data_format='channel_first'):
         assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \
             f'(N, C, H, W), but got tensor with shape {x.shape}'
-        return F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight,
-            self.bias, self.eps).permute(0, 3, 1, 2)
+        if data_format == 'channel_last':
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
+                             self.eps)
+        elif data_format == 'channel_first':
+            x = x.permute(0, 2, 3, 1)
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
+                             self.eps)
+            # If the output is discontiguous, it may cause some unexpected
+            # problem in the downstream tasks
+            x = x.permute(0, 3, 1, 2).contiguous()
+        return x
 
 
 class ConvNeXtBlock(BaseModule):

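A usage sketch for the new `data_format` switch; the import path is an assumption based on where this class sits in the diff:

```python
import torch
from mmcls.models.backbones.convnext import LayerNorm2d  # assumed import path

norm = LayerNorm2d(64)
nchw = torch.rand(2, 64, 56, 56)

out = norm(nchw)  # default 'channel_first': permute -> layer_norm -> permute back
assert out.is_contiguous()  # the explicit .contiguous() guarantees this

# If the tensor is already channels-last, both permutes can be skipped:
nhwc = nchw.permute(0, 2, 3, 1)
out_last = norm(nhwc, data_format='channel_last')
```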
@@ -113,10 +121,10 @@ class ConvNeXtBlock(BaseModule):
         def _inner_forward(x):
             shortcut = x
             x = self.depthwise_conv(x)
-            x = self.norm(x)
 
             if self.linear_pw_conv:
                 x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+                x = self.norm(x, data_format='channel_last')
 
             x = self.pointwise_conv1(x)
             x = self.act(x)

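For orientation, a comment-only sketch of the data layout through the updated `_inner_forward` on the `linear_pw_conv=True` path; the pointwise layers acting as `nn.Linear` over the last dimension is an assumption about the unchanged parts of the block:

```python
# Layout walk-through (linear_pw_conv=True path):
#   x: (N, C, H, W)  --depthwise_conv-->        (N, C, H, W)
#   --permute(0, 2, 3, 1)-->                    (N, H, W, C)
#   --norm(data_format='channel_last')-->       (N, H, W, C)  # no extra permutes
#   --pointwise_conv1 / act (assumed nn.Linear on the last dim)
#   --permute back and residual add happen later in the block (not in this hunk)
```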
@@ -284,7 +292,7 @@ class ConvNeXt(BaseBackbone):
 
             if i >= 1:
                 downsample_layer = nn.Sequential(
-                    LayerNorm2d(self.channels[i - 1]),
+                    build_norm_layer(norm_cfg, self.channels[i - 1])[1],
                     nn.Conv2d(
                         self.channels[i - 1],
                         channels,

@@ -324,9 +332,7 @@ class ConvNeXt(BaseBackbone):
                 gap = x.mean([-2, -1], keepdim=True)
                 outs.append(norm_layer(gap).flatten(1))
             else:
-                # The output of LayerNorm2d may be discontiguous, which
-                # may cause some problem in the downstream tasks
-                outs.append(norm_layer(x).contiguous())
+                outs.append(norm_layer(x))
 
         return tuple(outs)