[Enhance] Reproduce mobileone training accuracy. (#1191)

* add switch hook and UTs

* update doc

* update doc

* fix lint

* fix ci

* fix ci

* fix typo

* fix ci

* update configs names

* update configs

* update configs

* update links

* update readme

* update vis_scheduler

* update metafile

* update configs

* rebase

* fix ci

* rebase
pull/1143/head
Ezra-Yu 2022-11-21 10:43:34 +08:00 committed by GitHub
parent 629f6447ef
commit 4969830c8a
23 changed files with 460 additions and 227 deletions

View File

@ -1,19 +1,6 @@
_base_ = [ # optimizer
'../_base_/models/mobileone/mobileone_s0.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/default_runtime.py'
]
# dataset settings
train_dataloader = dict(batch_size=128)
val_dataloader = dict(batch_size=128)
test_dataloader = dict(batch_size=128)
# schedule settings
optim_wrapper = dict( optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001), optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))
paramwise_cfg=dict(bias_decay_mult=0., norm_decay_mult=0.),
)
# learning policy # learning policy
param_scheduler = [ param_scheduler = [
@ -50,7 +37,4 @@ test_cfg = dict()
# NOTE: `auto_scale_lr` is for automatically scaling LR, # NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size. # based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=1024) auto_scale_lr = dict(base_batch_size=256)
# runtime setting
custom_hooks = [dict(type='EMAHook', momentum=5e-4, priority='ABOVE_NORMAL')]
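
The change of `base_batch_size` from 1024 to 256 matters because automatic LR scaling is usually a linear rule relative to that reference batch size. The sketch below illustrates the arithmetic only; `scale_lr` is a hypothetical helper, not an MMEngine function, and the base LR of 0.1 is taken from the old optimizer settings above purely for illustration.

```python
def scale_lr(base_lr: float, base_batch_size: int,
             num_gpus: int, samples_per_gpu: int) -> float:
    """Linear scaling rule: LR grows in proportion to the total batch size.

    Hypothetical helper for illustration; assumes the framework applies
    lr * (actual_batch_size / base_batch_size) when auto scaling is enabled.
    """
    actual_batch_size = num_gpus * samples_per_gpu
    return base_lr * actual_batch_size / base_batch_size


# 8 GPUs x 32 samples matches base_batch_size=256, so the LR is unchanged.
lr_8xb32 = scale_lr(0.1, base_batch_size=256, num_gpus=8, samples_per_gpu=32)

# The same config on 8 GPUs x 128 samples would be scaled up four times.
lr_8xb128 = scale_lr(0.1, base_batch_size=256, num_gpus=8, samples_per_gpu=128)
```

With the `8xb32` naming used by the new configs, the total batch size equals the reference value, so enabling automatic scaling would leave the configured LR as it is.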

View File

@ -4,35 +4,121 @@
<!-- [ALGORITHM] --> <!-- [ALGORITHM] -->
## Abstract ## Introduction
Efficient neural network backbones for mobile devices are often optimized for metrics such as FLOPs or parameter count. However, these metrics may not correlate well with latency of the network when deployed on a mobile device. Therefore, we perform extensive analysis of different metrics by deploying several mobile-friendly networks on a mobile device. We identify and analyze architectural and optimization bottlenecks in recent efficient neural networks and provide ways to mitigate these bottlenecks. To this end, we design an efficient backbone MobileOne, with variants achieving an inference time under 1 ms on an iPhone12 with 75.9% top-1 accuracy on ImageNet. We show that MobileOne achieves state-of-the-art performance within the efficient architectures while being many times faster on mobile. Our best model obtains similar performance on ImageNet as MobileFormer while being 38x faster. Our model obtains 2.3% better top-1 accuracy on ImageNet than EfficientNet at similar latency. Furthermore, we show that our model generalizes to multiple tasks - image classification, object detection, and semantic segmentation with significant improvements in latency and accuracy as compared to existing efficient architectures when deployed on a mobile device. MobileOne is proposed by Apple and based on reparameterization. On Apple chips, the accuracy of the model is close to 0.76 on the ImageNet dataset when the latency is less than 1ms. Its main improvements over [RepVGG](../repvgg) are the following:
- Reparameterization using Depthwise convolution and Pointwise convolution instead of normal convolution.
- Removal of the residual structure which is not friendly to access memory.
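
As a rough illustration of what reparameterization means here, the sketch below folds a batch-norm layer into the preceding convolution for a single channel with a 1x1 kernel, so the "convolution" reduces to a scalar multiply-add. It shows the standard BN-folding identity only; the function names are hypothetical and this is not the `MobileOne` backbone code.

```python
import math


def conv_then_bn(x, weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Training-time path: 1x1 single-channel conv followed by batch norm."""
    y = weight * x + bias
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fuse_conv_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold the BN statistics into the conv weight and bias (inference-time)."""
    scale = gamma / math.sqrt(var + eps)
    return weight * scale, (bias - mean) * scale + beta


# Illustrative numbers only.
params = dict(weight=0.7, bias=0.1, gamma=1.3, beta=-0.2, mean=0.05, var=0.9)
fused_w, fused_b = fuse_conv_bn(**params)

x = 2.5
y_train = conv_then_bn(x, **params)
y_deploy = fused_w * x + fused_b  # a single multiply-add, no BN at inference
```

Parallel branches that share the same input can then be merged by summing their fused weights and biases, which is why the deploy model ends up with fewer parameters and less computation for the same output.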
<div align=center> <div align=center>
<img src="https://user-images.githubusercontent.com/18586273/183552452-74657532-f461-48f7-9aa7-c23f006cdb07.png" width="40%"/> <img src="https://user-images.githubusercontent.com/18586273/183552452-74657532-f461-48f7-9aa7-c23f006cdb07.png" width="40%"/>
</div> </div>
## Results and models ## Abstract
### ImageNet-1k <details>
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | <summary>Show the paper's abstract</summary>
| :------------: | :-----------------------------: | :----------------------------: | :-------: | :-------: | :--------------------------------------------------: | :-----------------------------------------------------: |
| MobileOne-s0\* | 5.29 (train) \| 2.08 (deploy) | 1.09 (train) \| 0.28 (deploy) | 71.36 | 89.87 | [config (train)](./mobileone-s0_8xb128_in1k.py) \| [config (deploy)](./deploy/mobileone-s0_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220915-007ae971.pth) |
| MobileOne-s1\* | 4.83 (train) \| 4.76 (deploy) | 0.86 (train) \| 0.84 (deploy) | 75.76 | 92.77 | [config (train)](./mobileone-s1_8xb128_in1k.py) \| [config (deploy)](./deploy/mobileone-s1_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_3rdparty_in1k_20220915-473c8469.pth) |
| MobileOne-s2\* | 7.88 (train) \| 7.88 (deploy) | 1.34 (train) \| 1.31 (deploy) | 77.39 | 93.63 | [config (train)](./mobileone-s2_8xb128_in1k.py) \|[config (deploy)](./deploy/mobileone-s2_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_3rdparty_in1k_20220915-ed2e4c30.pth) |
| MobileOne-s3\* | 10.17 (train) \| 10.08 (deploy) | 1.95 (train) \| 1.91 (deploy) | 77.93 | 93.89 | [config (train)](./mobileone-s3_8xb128_in1k.py) \|[config (deploy)](./deploy/mobileone-s3_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_3rdparty_in1k_20220915-84d6a02c.pth) |
| MobileOne-s4\* | 14.95 (train) \| 14.84 (deploy) | 3.05 (train) \| 3.00 (deploy) | 79.30 | 94.37 | [config (train)](./mobileone-s4_8xb128_in1k.py) \|[config (deploy)](./deploy/mobileone-s4_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_3rdparty_in1k_20220915-ce9509ee.pth) |
*Models with * are converted from the [official repo](https://github.com/apple/ml-mobileone). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* <br>
Efficient neural network backbones for mobile devices are often optimized for metrics such as FLOPs or parameter count. However, these metrics may not correlate well with latency of the network when deployed on a mobile device. Therefore, we perform extensive analysis of different metrics by deploying several mobile-friendly networks on a mobile device. We identify and analyze architectural and optimization bottlenecks in recent efficient neural networks and provide ways to mitigate these bottlenecks. To this end, we design an efficient backbone MobileOne, with variants achieving an inference time under 1 ms on an iPhone12 with 75.9% top-1 accuracy on ImageNet. We show that MobileOne achieves state-of-the-art performance within the efficient architectures while being many times faster on mobile. Our best model obtains similar performance on ImageNet as MobileFormer while being 38x faster. Our model obtains 2.3% better top-1 accuracy on ImageNet than EfficientNet at similar latency. Furthermore, we show that our model generalizes to multiple tasks - image classification, object detection, and semantic segmentation with significant improvements in latency and accuracy as compared to existing efficient architectures when deployed on a mobile device.
</br>
*Because the [official repo](https://github.com/apple/ml-mobileone) does not give a strategy for training and testing, the test data pipeline of [RepVGG](https://github.com/open-mmlab/mmclassification/tree/master/configs/repvgg) is used here, and the result is about 0.1 lower than that in the paper. Refer to [this issue](https://github.com/apple/ml-mobileone/issues/2).* </details>
## How to use ## How to use
The checkpoints provided are all `training-time` models. Use the reparameterize tool to switch them to the more efficient `inference-time` architecture, which has both fewer parameters and fewer computations. The checkpoints provided are all `training-time` models. Use the reparameterize tool or the `switch_to_deploy` interface to switch them to the more efficient `inference-time` architecture, which has both fewer parameters and fewer computations.
### Use tool <!-- [TABS-BEGIN] -->
**Predict image**
Use the `classifier.backbone.switch_to_deploy()` interface to switch MobileOne to inference mode.
```python
>>> import torch
>>> from mmcls.apis import init_model, inference_model
>>>
>>> model = init_model('configs/mobileone/mobileone-s0_8xb32_in1k.py', 'https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth')
>>> predict = inference_model(model, 'demo/demo.JPEG')
>>> print(predict['pred_class'])
sea snake
>>> print(predict['pred_score'])
0.4539405107498169
>>>
>>> # switch to deploy mode
>>> model.backbone.switch_to_deploy()
>>> predict_deploy = inference_model(model, 'demo/demo.JPEG')
>>> print(predict_deploy['pred_class'])
sea snake
>>> print(predict_deploy['pred_score'])
0.4539395272731781
```
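
The two scores printed above differ in the seventh decimal place, which is expected: fusing branches changes the order of floating-point operations. A small check with the standard library shows that the difference is within a relative tolerance of 1e-5 (the default `rtol` of `torch.allclose`) but not within a much stricter one. The variable names are illustrative.

```python
import math

# Scores copied from the example output above.
score_train = 0.4539405107498169
score_deploy = 0.4539395272731781

abs_diff = abs(score_train - score_deploy)

# Passes at a relative tolerance of 1e-5.
close_loose = math.isclose(score_train, score_deploy, rel_tol=1e-5)

# Fails at a relative tolerance of 1e-7, so exact equality should not be expected.
close_strict = math.isclose(score_train, score_deploy, rel_tol=1e-7)

print(f"abs diff = {abs_diff:.3e}, loose = {close_loose}, strict = {close_strict}")
```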
**Use the model**
```python
>>> import torch
>>> from mmcls.apis import init_model
>>>
>>> model = init_model('configs/mobileone/mobileone-s0_8xb32_in1k.py', 'https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth')
>>> inputs = torch.rand(1, 3, 224, 224).to(model.data_preprocessor.device)
>>> # To get classification scores.
>>> out = model(inputs)
>>> print(out.shape)
torch.Size([1, 1000])
>>> # To extract features.
>>> outs = model.extract_feat(inputs)
>>> print(outs[0].shape)
torch.Size([1, 768])
>>>
>>> # switch to deploy mode
>>> model.backbone.switch_to_deploy()
>>> out_deploy = model(inputs)
>>> print(out_deploy.shape)
torch.Size([1, 1000])
>>> assert torch.allclose(out, out_deploy) # pass without error
```
**Train/Test Command**
Place the ImageNet dataset in the `data/imagenet/` directory, or prepare datasets according to the [docs](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py configs/mobileone/mobileone-s0_8xb32_in1k.py
```
Download Checkpoint:
```shell
wget https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth
```
Test with the unfused model:
```shell
python tools/test.py configs/mobileone/mobileone-s0_8xb32_in1k.py mobileone-s0_8xb32_in1k_20221110-0bc94952.pth
```
Reparameterize checkpoint:
```shell
python ./tools/convert_models/reparameterize_model.py ./configs/mobileone/mobileone-s0_8xb32_in1k.py mobileone-s0_8xb32_in1k_20221110-0bc94952.pth mobileone_s0_deploy.pth
```
Test with the fused model:
```shell
python tools/test.py configs/mobileone/deploy/mobileone-s0_deploy_8xb32_in1k.py mobileone_s0_deploy.pth
```
<!-- [TABS-END] -->
### Reparameterize Tool
Use provided tool to reparameterize the given model and save the checkpoint: Use provided tool to reparameterize the given model and save the checkpoint:
@ -45,80 +131,35 @@ python tools/convert_models/reparameterize_model.py ${CFG_PATH} ${SRC_CKPT_PATH}
For example: For example:
```shell ```shell
python ./tools/convert_models/reparameterize_model.py ./configs/mobileone/mobileone-s0_8xb128_in1k.py https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220811-db5ce29b.pth ./mobileone_s0_deploy.pth wget https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth
python ./tools/convert_models/reparameterize_model.py ./configs/mobileone/mobileone-s0_8xb32_in1k.py mobileone-s0_8xb32_in1k_20221110-0bc94952.pth mobileone_s0_deploy.pth
``` ```
To use reparameterized weights, the config file must switch to **the deploy config files**. To use reparameterized weights, the config file must switch to [**the deploy config files**](./deploy/).
```bash ```bash
python tools/test.py ${Deploy_CFG} ${Deploy_Checkpoint} --metrics accuracy python tools/test.py ${Deploy_CFG} ${Deploy_Checkpoint}
``` ```
For example of using the reparameterized weights above: For example of using the reparameterized weights above:
```shell ```shell
python ./tools/test.py ./configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py mobileone_s0_deploy.pth --metrics accuracy python ./tools/test.py ./configs/mobileone/deploy/mobileone-s0_deploy_8xb32_in1k.py mobileone_s0_deploy.pth
``` ```
### In the code For more configurable parameters, please refer to the [API](https://mmclassification.readthedocs.io/en/1.x/api/generated/mmcls.models.backbones.MobileOne.html#mmcls.models.backbones.MobileOne).
Use the API `switch_to_deploy` of the `MobileOne` backbone to switch to the deploy mode. Usually called like `backbone.switch_to_deploy()` or `classifier.backbone.switch_to_deploy()`. ## Results and models
For Backbones: ### ImageNet-1k
```python | Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
from mmcls.models import build_backbone | :----------: | :-----------------------------: | :----------------------------: | :-------: | :-------: | :---------------------------------------------------: | :------------------------------------------------------: |
import torch | MobileOne-s0 | 5.29 (train) \| 2.08 (deploy) | 1.09 (train) \| 0.28 (deploy) | 71.34 | 89.87 | [config (train)](./mobileone-s0_8xb32_in1k.py) \| [config (deploy)](./deploy/mobileone-s0_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.json) |
| MobileOne-s1 | 4.83 (train) \| 4.76 (deploy) | 0.86 (train) \| 0.84 (deploy) | 75.72 | 92.54 | [config (train)](./mobileone-s1_8xb32_in1k.py) \| [config (deploy)](./deploy/mobileone-s1_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_8xb32_in1k_20221110-ceeef467.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_8xb32_in1k_20221110-ceeef467.json) |
x = torch.randn( (1, 3, 224, 224) ) | MobileOne-s2 | 7.88 (train) \| 7.88 (deploy) | 1.34 (train) \| 1.31 (deploy) | 77.37 | 93.34 | [config (train)](./mobileone-s2_8xb32_in1k.py) \|[config (deploy)](./deploy/mobileone-s2_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_8xb32_in1k_20221110-9c7ecb97.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_8xb32_in1k_20221110-9c7ecb97.json) |
backbone_cfg=dict(type='MobileOne', arch='s0') | MobileOne-s3 | 10.17 (train) \| 10.08 (deploy) | 1.95 (train) \| 1.91 (deploy) | 78.06 | 93.83 | [config (train)](./mobileone-s3_8xb32_in1k.py) \|[config (deploy)](./deploy/mobileone-s3_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_8xb32_in1k_20221110-c95eb3bf.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_8xb32_in1k_20221110-c95eb3bf.pth) |
backbone = build_backbone(backbone_cfg) | MobileOne-s4 | 14.95 (train) \| 14.84 (deploy) | 3.05 (train) \| 3.00 (deploy) | 79.69 | 94.46 | [config (train)](./mobileone-s4_8xb32_in1k.py) \|[config (deploy)](./deploy/mobileone-s4_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth) |
backbone.init_weights()
backbone.eval()
outs_ori = backbone(x)
backbone.switch_to_deploy()
outs_dep = backbone(x)
for out1, out2 in zip(outs_ori, outs_dep):
assert torch.allclose(out1, out2)
```
For ImageClassifiers:
```python
from mmcls.models import build_classifier
import torch
import numpy as np
cfg = dict(
type='ImageClassifier',
backbone=dict(
type='MobileOne',
arch='s0',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
x = torch.randn( (1, 3, 224, 224) )
classifier = build_classifier(cfg)
classifier.init_weights()
classifier.eval()
y_ori = classifier(x, return_loss=False)
classifier.backbone.switch_to_deploy()
y_dep = classifier(x, return_loss=False)
for y1, y2 in zip(y_ori, y_dep):
assert np.allclose(y1, y2)
```
## Citation ## Citation

View File

@ -1,3 +0,0 @@
_base_ = ['../mobileone-s0_8xb128_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = ['../mobileone-s0_8xb32_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = ['../mobileone-s1_8xb128_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = ['../mobileone-s1_8xb32_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = ['../mobileone-s2_8xb128_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = ['../mobileone-s2_8xb32_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = ['../mobileone-s3_8xb128_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = ['../mobileone-s3_8xb32_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = ['../mobileone-s4_8xb128_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -0,0 +1,3 @@
_base_ = ['../mobileone-s4_8xb32_in1k.py']
model = dict(backbone=dict(deploy=True))

View File

@ -16,83 +16,68 @@ Collections:
Version: v1.0.0rc1 Version: v1.0.0rc1
Models: Models:
- Name: mobileone-s0_3rdparty_8xb128_in1k - Name: mobileone-s0_8xb32_in1k
In Collection: MobileOne In Collection: MobileOne
Config: configs/mobileone/mobileone-s0_8xb128_in1k.py Config: configs/mobileone/mobileone-s0_8xb32_in1k.py
Metadata: Metadata:
FLOPs: 1091227648 # 1.09G FLOPs: 274136576 # 0.27G
Parameters: 5293272 # 5.29M Parameters: 2078504 # 2.08M
Results: Results:
- Dataset: ImageNet-1k - Dataset: ImageNet-1k
Task: Image Classification Task: Image Classification
Metrics: Metrics:
Top 1 Accuracy: 71.36 Top 1 Accuracy: 71.34
Top 5 Accuracy: 89.87 Top 5 Accuracy: 89.87
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220915-007ae971.pth Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth
Converted From: - Name: mobileone-s1_8xb32_in1k
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
- Name: mobileone-s1_3rdparty_8xb128_in1k
In Collection: MobileOne In Collection: MobileOne
Config: configs/mobileone/mobileone-s1_8xb128_in1k.py Config: configs/mobileone/mobileone-s1_8xb32_in1k.py
Metadata: Metadata:
FLOPs: 863491328 # 0.86G FLOPs: 823839744 # 0.82G
Parameters: 4825192 # 4.83M Parameters: 4764840 # 4.76M
Results: Results:
- Dataset: ImageNet-1k - Dataset: ImageNet-1k
Task: Image Classification Task: Image Classification
Metrics: Metrics:
Top 1 Accuracy: 75.76 Top 1 Accuracy: 75.72
Top 5 Accuracy: 92.77 Top 5 Accuracy: 92.54
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_3rdparty_in1k_20220915-473c8469.pth Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_8xb32_in1k_20221110-ceeef467.pth
Converted From: - Name: mobileone-s2_8xb32_in1k
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
- Name: mobileone-s2_3rdparty_8xb128_in1k
In Collection: MobileOne In Collection: MobileOne
Config: configs/mobileone/mobileone-s2_8xb128_in1k.py Config: configs/mobileone/mobileone-s2_8xb32_in1k.py
Metadata: Metadata:
FLOPs: 1344083328 FLOPs: 1296478848
Parameters: 7884648 Parameters: 7808168
Results: Results:
- Dataset: ImageNet-1k - Dataset: ImageNet-1k
Task: Image Classification Task: Image Classification
Metrics: Metrics:
Top 1 Accuracy: 77.39 Top 1 Accuracy: 77.37
Top 5 Accuracy: 93.63 Top 5 Accuracy: 93.34
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_3rdparty_in1k_20220915-ed2e4c30.pth Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_8xb32_in1k_20221110-9c7ecb97.pth
Converted From: - Name: mobileone-s3_8xb32_in1k
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
- Name: mobileone-s3_3rdparty_8xb128_in1k
In Collection: MobileOne In Collection: MobileOne
Config: configs/mobileone/mobileone-s3_8xb128_in1k.py Config: configs/mobileone/mobileone-s3_8xb32_in1k.py
Metadata: Metadata:
FLOPs: 1951043584 FLOPs: 1893842944
Parameters: 10170600 Parameters: 10078312
Results: Results:
- Dataset: ImageNet-1k - Dataset: ImageNet-1k
Task: Image Classification Task: Image Classification
Metrics: Metrics:
Top 1 Accuracy: 77.93 Top 1 Accuracy: 78.06
Top 5 Accuracy: 93.89 Top 5 Accuracy: 93.83
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_3rdparty_in1k_20220915-84d6a02c.pth Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_8xb32_in1k_20221110-c95eb3bf.pth
Converted From: - Name: mobileone-s4_8xb32_in1k
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
- Name: mobileone-s4_3rdparty_8xb128_in1k
In Collection: MobileOne In Collection: MobileOne
Config: configs/mobileone/mobileone-s4_8xb128_in1k.py Config: configs/mobileone/mobileone-s4_8xb32_in1k.py
Metadata: Metadata:
FLOPs: 3052580688 FLOPs: 2979222528
Parameters: 14951248 Parameters: 14838352
Results: Results:
- Dataset: ImageNet-1k - Dataset: ImageNet-1k
Task: Image Classification Task: Image Classification
Metrics: Metrics:
Top 1 Accuracy: 79.30 Top 1 Accuracy: 79.69
Top 5 Accuracy: 94.37 Top 5 Accuracy: 94.46
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_3rdparty_in1k_20220915-ce9509ee.pth Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone

View File

@ -0,0 +1,20 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s0.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py',
'../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.))
val_dataloader = dict(batch_size=256)
test_dataloader = dict(batch_size=256)
custom_hooks = [
dict(
type='EMAHook',
momentum=5e-4,
priority='ABOVE_NORMAL',
update_buffers=True)
]
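
For readers unfamiliar with the `EMAHook` settings above, the sketch below shows the exponential moving average update that a momentum of `5e-4` implies, assuming the common convention `ema = (1 - momentum) * ema + momentum * param`. It uses plain floats; `ema_update` is a hypothetical helper, not the hook's actual implementation.

```python
def ema_update(averaged: float, source: float, momentum: float = 5e-4) -> float:
    """One EMA step; a small momentum means the average moves slowly."""
    return (1.0 - momentum) * averaged + momentum * source


# Track a parameter that sits at 1.0 while the average starts from 0.0.
averaged = 0.0
for _ in range(1000):
    averaged = ema_update(averaged, source=1.0)

# After n steps the average equals 1 - (1 - momentum) ** n.
expected = 1.0 - (1.0 - 5e-4) ** 1000
```

With this momentum, roughly 1,000 iterations are needed before the average covers about 39% of the distance to the current weights, which is why the averaged model changes smoothly over training.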

View File

@ -1,15 +0,0 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s1.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]
# dataset settings
train_dataloader = dict(batch_size=128)
val_dataloader = dict(batch_size=128)
test_dataloader = dict(batch_size=128)
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=1024)

View File

@ -0,0 +1,60 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s1.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py',
'../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.))
val_dataloader = dict(batch_size=256)
test_dataloader = dict(batch_size=256)
bgr_mean = _base_.data_preprocessor['mean'][::-1]
base_train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=7,
magnitude_std=0.5,
hparams=dict(pad_val=[round(x) for x in bgr_mean])),
dict(type='PackClsInputs')
]
import copy # noqa: E402
# modify start epoch's RandomResizedCrop.scale to 160
train_pipeline_1e = copy.deepcopy(base_train_pipeline)
train_pipeline_1e[1]['scale'] = 160
train_pipeline_1e[3]['magnitude_level'] *= 0.1
_base_.train_dataloader.dataset.pipeline = train_pipeline_1e
# modify 37 epoch's RandomResizedCrop.scale to 192
train_pipeline_37e = copy.deepcopy(base_train_pipeline)
train_pipeline_37e[1]['scale'] = 192
train_pipeline_1e[3]['magnitude_level'] *= 0.2
# modify 112 epoch's RandomResizedCrop.scale to 224
train_pipeline_112e = copy.deepcopy(base_train_pipeline)
train_pipeline_112e[1]['scale'] = 224
train_pipeline_1e[3]['magnitude_level'] *= 0.3
custom_hooks = [
dict(
type='SwitchRecipeHook',
schedule=[
dict(action_epoch=37, pipeline=train_pipeline_37e),
dict(action_epoch=112, pipeline=train_pipeline_112e),
]),
dict(
type='EMAHook',
momentum=5e-4,
priority='ABOVE_NORMAL',
update_buffers=True)
]
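
The `SwitchRecipeHook` schedule above amounts to a progressive-resizing plan: the training pipeline in use depends on the current epoch. The sketch below mimics that selection logic with pipeline names instead of real pipelines. It assumes a switch takes effect from `action_epoch` onwards; `active_recipe` is a hypothetical helper, not part of the hook.

```python
def active_recipe(epoch: int, schedule: list, initial: str) -> str:
    """Return the name of the pipeline that would be active at `epoch`."""
    current = initial
    for entry in sorted(schedule, key=lambda item: item["action_epoch"]):
        if epoch >= entry["action_epoch"]:
            current = entry["pipeline"]
    return current


# Same action epochs as the config; strings stand in for the pipelines.
schedule = [
    dict(action_epoch=37, pipeline="train_pipeline_37e"),    # crop scale 192
    dict(action_epoch=112, pipeline="train_pipeline_112e"),  # crop scale 224
]

for epoch in (0, 36, 37, 111, 112, 299):
    print(epoch, active_recipe(epoch, schedule, initial="train_pipeline_1e"))
```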

View File

@ -1,15 +0,0 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s2.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]
# dataset settings
train_dataloader = dict(batch_size=128)
val_dataloader = dict(batch_size=128)
test_dataloader = dict(batch_size=128)
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=1024)

View File

@ -0,0 +1,65 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s2.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py',
'../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.))
val_dataloader = dict(batch_size=256)
test_dataloader = dict(batch_size=256)
import copy # noqa: E402
bgr_mean = _base_.data_preprocessor['mean'][::-1]
base_train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=7,
magnitude_std=0.5,
hparams=dict(pad_val=[round(x) for x in bgr_mean])),
dict(type='PackClsInputs')
]
# modify start epoch RandomResizedCrop.scale to 160
# and RA.magnitude_level * 0.3
train_pipeline_1e = copy.deepcopy(base_train_pipeline)
train_pipeline_1e[1]['scale'] = 160
train_pipeline_1e[3]['magnitude_level'] *= 0.3
_base_.train_dataloader.dataset.pipeline = train_pipeline_1e
# modify 37 epoch's RandomResizedCrop.scale to 192
# and RA.magnitude_level * 0.7
train_pipeline_37e = copy.deepcopy(base_train_pipeline)
train_pipeline_37e[1]['scale'] = 192
train_pipeline_37e[3]['magnitude_level'] *= 0.7
# modify 112 epoch's RandomResizedCrop.scale to 224
# and RA.magnitude_level * 1.0
train_pipeline_112e = copy.deepcopy(base_train_pipeline)
train_pipeline_112e[1]['scale'] = 224
train_pipeline_112e[3]['magnitude_level'] *= 1.0
custom_hooks = [
dict(
type='SwitchRecipeHook',
schedule=[
dict(action_epoch=37, pipeline=train_pipeline_37e),
dict(action_epoch=112, pipeline=train_pipeline_112e),
]),
dict(
type='EMAHook',
momentum=5e-4,
priority='ABOVE_NORMAL',
update_buffers=True)
]

View File

@ -1,15 +0,0 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s3.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]
# dataset settings
train_dataloader = dict(batch_size=128)
val_dataloader = dict(batch_size=128)
test_dataloader = dict(batch_size=128)
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=1024)

View File

@ -0,0 +1,65 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s3.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py',
'../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.))
val_dataloader = dict(batch_size=256)
test_dataloader = dict(batch_size=256)
import copy # noqa: E402
bgr_mean = _base_.data_preprocessor['mean'][::-1]
base_train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=7,
magnitude_std=0.5,
hparams=dict(pad_val=[round(x) for x in bgr_mean])),
dict(type='PackClsInputs')
]
# modify start epoch RandomResizedCrop.scale to 160
# and RA.magnitude_level * 0.3
train_pipeline_1e = copy.deepcopy(base_train_pipeline)
train_pipeline_1e[1]['scale'] = 160
train_pipeline_1e[3]['magnitude_level'] *= 0.3
_base_.train_dataloader.dataset.pipeline = train_pipeline_1e
# modify 37 epoch's RandomResizedCrop.scale to 192
# and RA.magnitude_level * 0.7
train_pipeline_37e = copy.deepcopy(base_train_pipeline)
train_pipeline_37e[1]['scale'] = 192
train_pipeline_37e[3]['magnitude_level'] *= 0.7
# modify 112 epoch's RandomResizedCrop.scale to 224
# and RA.magnitude_level * 1.0
train_pipeline_112e = copy.deepcopy(base_train_pipeline)
train_pipeline_112e[1]['scale'] = 224
train_pipeline_112e[3]['magnitude_level'] *= 1.0
custom_hooks = [
dict(
type='SwitchRecipeHook',
schedule=[
dict(action_epoch=37, pipeline=train_pipeline_37e),
dict(action_epoch=112, pipeline=train_pipeline_112e),
]),
dict(
type='EMAHook',
momentum=5e-4,
priority='ABOVE_NORMAL',
update_buffers=True)
]

View File

@ -1,15 +0,0 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s4.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]
# dataset settings
train_dataloader = dict(batch_size=128)
val_dataloader = dict(batch_size=128)
test_dataloader = dict(batch_size=128)
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=1024)

View File

@ -0,0 +1,63 @@
_base_ = [
'../_base_/models/mobileone/mobileone_s4.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py',
'../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.))
val_dataloader = dict(batch_size=256)
test_dataloader = dict(batch_size=256)
bgr_mean = _base_.data_preprocessor['mean'][::-1]
base_train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=7,
magnitude_std=0.5,
hparams=dict(pad_val=[round(x) for x in bgr_mean])),
dict(type='PackClsInputs')
]
import copy # noqa: E402
# modify start epoch RandomResizedCrop.scale to 160
# and RA.magnitude_level * 0.3
train_pipeline_1e = copy.deepcopy(base_train_pipeline)
train_pipeline_1e[1]['scale'] = 160
train_pipeline_1e[3]['magnitude_level'] *= 0.3
_base_.train_dataloader.dataset.pipeline = train_pipeline_1e
# modify 37 epoch's RandomResizedCrop.scale to 192
# and RA.magnitude_level * 0.7
train_pipeline_37e = copy.deepcopy(base_train_pipeline)
train_pipeline_37e[1]['scale'] = 192
train_pipeline_37e[3]['magnitude_level'] *= 0.7
# modify 112 epoch's RandomResizedCrop.scale to 224
# and RA.magnitude_level * 1.0
train_pipeline_112e = copy.deepcopy(base_train_pipeline)
train_pipeline_112e[1]['scale'] = 224
train_pipeline_112e[3]['magnitude_level'] *= 1.0
custom_hooks = [
dict(
type='SwitchRecipeHook',
schedule=[
dict(action_epoch=37, pipeline=train_pipeline_37e),
dict(action_epoch=112, pipeline=train_pipeline_112e),
]),
dict(
type='EMAHook',
momentum=5e-4,
priority='ABOVE_NORMAL',
update_buffers=True)
]

View File

@ -41,6 +41,7 @@ class ParamRecordHook(Hook):
self.by_epoch = by_epoch self.by_epoch = by_epoch
self.lr_list = [] self.lr_list = []
self.momentum_list = [] self.momentum_list = []
self.wd_list = []
self.task_id = 0 self.task_id = 0
self.progress = Progress(BarColumn(), MofNCompleteColumn(), self.progress = Progress(BarColumn(), MofNCompleteColumn(),
TextColumn('{task.description}')) TextColumn('{task.description}'))
@ -66,6 +67,8 @@ class ParamRecordHook(Hook):
self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0]) self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0])
self.momentum_list.append( self.momentum_list.append(
runner.optim_wrapper.get_momentum()['momentum'][0]) runner.optim_wrapper.get_momentum()['momentum'][0])
self.wd_list.append(
runner.optim_wrapper.param_groups[0]['weight_decay'])
def after_train(self, runner): def after_train(self, runner):
self.progress.stop() self.progress.stop()
@ -80,9 +83,9 @@ def parse_args():
'--parameter', '--parameter',
type=str, type=str,
default='lr', default='lr',
choices=['lr', 'momentum'], choices=['lr', 'momentum', 'wd'],
help='The parameter to visualize its change curve, choose from ' help='The parameter to visualize its change curve, choose from '
'"lr" and "momentum". Defaults to "lr".') '"lr", "wd" and "momentum". Defaults to "lr".')
parser.add_argument( parser.add_argument(
'-d', '-d',
'--dataset-size', '--dataset-size',
@ -192,7 +195,12 @@ def simulate_train(data_loader, cfg, by_epoch):
runner.train() runner.train()
return param_record_hook.lr_list, param_record_hook.momentum_list param_dict = dict(
lr=param_record_hook.lr_list,
momentum=param_record_hook.momentum_list,
wd=param_record_hook.wd_list)
return param_dict
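
The new `wd` entry makes it possible to plot a weight-decay schedule next to the LR curve. As a reference for what such a curve might look like, the sketch below computes a cosine-annealed value per step. The schedule name `coslr_coswd` in the configs suggests cosine annealing, but the start and end values used here are hypothetical and `cosine_anneal` is an illustrative helper.

```python
import math


def cosine_anneal(step: int, total_steps: int, start: float, end: float) -> float:
    """Cosine annealing from `start` (step 0) to `end` (step == total_steps)."""
    cos_factor = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))
    return end + (start - end) * cos_factor


# Hypothetical values: decay the weight decay from 1e-4 to 1e-5 over 300 epochs.
wd_curve = [cosine_anneal(epoch, 300, start=1e-4, end=1e-5) for epoch in range(301)]
```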
def main(): def main():
@ -250,13 +258,15 @@ def main():
rich.print(dataset_info + '\n') rich.print(dataset_info + '\n')
# simulation training process # simulation training process
lr_list, momentum_list = simulate_train(data_loader, cfg, by_epoch) param_dict = simulate_train(data_loader, cfg, by_epoch)
if args.parameter == 'lr': param_list = param_dict[args.parameter]
param_list = lr_list
else:
param_list = momentum_list
param_name = 'Learning Rate' if args.parameter == 'lr' else 'Momentum' if args.parameter == 'lr':
param_name = 'Learning Rate'
elif args.parameter == 'momentum':
param_name = 'Momentum'
else:
param_name = 'Weight Decay'
plot_curve(param_list, args, param_name, len(data_loader), by_epoch) plot_curve(param_list, args, param_name, len(data_loader), by_epoch)
if args.save_path: if args.save_path: