From 4969830c8ae9e05857297f625f00ebdc5f511e79 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:43:34 +0800 Subject: [PATCH] [Enhance] Reproduce mobileone training accuracy. (#1191) * add switch hook and UTs * update doc * update doc * fix lint * fix ci * fix ci * fix typo * fix ci * update configs names * update configs * update configs * update links * update readme * update vis_scheduler * update metafile * update configs * rebase * fix ci * rebase --- .../imagenet_bs256_coslr_coswd_300e.py} | 22 +- configs/mobileone/README.md | 191 +++++++++++------- .../deploy/mobileone-s0_deploy_8xb128_in1k.py | 3 - .../deploy/mobileone-s0_deploy_8xb32_in1k.py | 3 + .../deploy/mobileone-s1_deploy_8xb128_in1k.py | 3 - .../deploy/mobileone-s1_deploy_8xb32_in1k.py | 3 + .../deploy/mobileone-s2_deploy_8xb128_in1k.py | 3 - .../deploy/mobileone-s2_deploy_8xb32_in1k.py | 3 + .../deploy/mobileone-s3_deploy_8xb128_in1k.py | 3 - .../deploy/mobileone-s3_deploy_8xb32_in1k.py | 3 + .../deploy/mobileone-s4_deploy_8xb128_in1k.py | 3 - .../deploy/mobileone-s4_deploy_8xb32_in1k.py | 3 + configs/mobileone/metafile.yml | 83 ++++---- configs/mobileone/mobileone-s0_8xb32_in1k.py | 20 ++ configs/mobileone/mobileone-s1_8xb128_in1k.py | 15 -- configs/mobileone/mobileone-s1_8xb32_in1k.py | 60 ++++++ configs/mobileone/mobileone-s2_8xb128_in1k.py | 15 -- configs/mobileone/mobileone-s2_8xb32_in1k.py | 65 ++++++ configs/mobileone/mobileone-s3_8xb128_in1k.py | 15 -- configs/mobileone/mobileone-s3_8xb32_in1k.py | 65 ++++++ configs/mobileone/mobileone-s4_8xb128_in1k.py | 15 -- configs/mobileone/mobileone-s4_8xb32_in1k.py | 63 ++++++ tools/visualizations/vis_scheduler.py | 28 ++- 23 files changed, 460 insertions(+), 227 deletions(-) rename configs/{mobileone/mobileone-s0_8xb128_in1k.py => _base_/schedules/imagenet_bs256_coslr_coswd_300e.py} (63%) delete mode 100644 configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py create mode 100644 configs/mobileone/deploy/mobileone-s0_deploy_8xb32_in1k.py delete mode 100644 configs/mobileone/deploy/mobileone-s1_deploy_8xb128_in1k.py create mode 100644 configs/mobileone/deploy/mobileone-s1_deploy_8xb32_in1k.py delete mode 100644 configs/mobileone/deploy/mobileone-s2_deploy_8xb128_in1k.py create mode 100644 configs/mobileone/deploy/mobileone-s2_deploy_8xb32_in1k.py delete mode 100644 configs/mobileone/deploy/mobileone-s3_deploy_8xb128_in1k.py create mode 100644 configs/mobileone/deploy/mobileone-s3_deploy_8xb32_in1k.py delete mode 100644 configs/mobileone/deploy/mobileone-s4_deploy_8xb128_in1k.py create mode 100644 configs/mobileone/deploy/mobileone-s4_deploy_8xb32_in1k.py create mode 100644 configs/mobileone/mobileone-s0_8xb32_in1k.py delete mode 100644 configs/mobileone/mobileone-s1_8xb128_in1k.py create mode 100644 configs/mobileone/mobileone-s1_8xb32_in1k.py delete mode 100644 configs/mobileone/mobileone-s2_8xb128_in1k.py create mode 100644 configs/mobileone/mobileone-s2_8xb32_in1k.py delete mode 100644 configs/mobileone/mobileone-s3_8xb128_in1k.py create mode 100644 configs/mobileone/mobileone-s3_8xb32_in1k.py delete mode 100644 configs/mobileone/mobileone-s4_8xb128_in1k.py create mode 100644 configs/mobileone/mobileone-s4_8xb32_in1k.py diff --git a/configs/mobileone/mobileone-s0_8xb128_in1k.py b/configs/_base_/schedules/imagenet_bs256_coslr_coswd_300e.py similarity index 63% rename from configs/mobileone/mobileone-s0_8xb128_in1k.py rename to configs/_base_/schedules/imagenet_bs256_coslr_coswd_300e.py index ceeb21b7..318e0315 100644 --- a/configs/mobileone/mobileone-s0_8xb128_in1k.py +++ b/configs/_base_/schedules/imagenet_bs256_coslr_coswd_300e.py @@ -1,19 +1,6 @@ -_base_ = [ - '../_base_/models/mobileone/mobileone_s0.py', - '../_base_/datasets/imagenet_bs32_pil_resize.py', - '../_base_/default_runtime.py' -] - -# dataset settings -train_dataloader = dict(batch_size=128) -val_dataloader = dict(batch_size=128) -test_dataloader = dict(batch_size=128) - -# schedule settings +# optimizer optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001), - paramwise_cfg=dict(bias_decay_mult=0., norm_decay_mult=0.), -) + optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)) # learning policy param_scheduler = [ @@ -50,7 +37,4 @@ test_cfg = dict() # NOTE: `auto_scale_lr` is for automatically scaling LR, # based on the actual training batch size. -auto_scale_lr = dict(base_batch_size=1024) - -# runtime setting -custom_hooks = [dict(type='EMAHook', momentum=5e-4, priority='ABOVE_NORMAL')] +auto_scale_lr = dict(base_batch_size=256) diff --git a/configs/mobileone/README.md b/configs/mobileone/README.md index 80fbb148..ce187e26 100644 --- a/configs/mobileone/README.md +++ b/configs/mobileone/README.md @@ -4,35 +4,121 @@ -## Abstract +## Introduction -Efficient neural network backbones for mobile devices are often optimized for metrics such as FLOPs or parameter count. However, these metrics may not correlate well with latency of the network when deployed on a mobile device. Therefore, we perform extensive analysis of different metrics by deploying several mobile-friendly networks on a mobile device. We identify and analyze architectural and optimization bottlenecks in recent efficient neural networks and provide ways to mitigate these bottlenecks. To this end, we design an efficient backbone MobileOne, with variants achieving an inference time under 1 ms on an iPhone12 with 75.9% top-1 accuracy on ImageNet. We show that MobileOne achieves state-of-the-art performance within the efficient architectures while being many times faster on mobile. Our best model obtains similar performance on ImageNet as MobileFormer while being 38x faster. Our model obtains 2.3% better top-1 accuracy on ImageNet than EfficientNet at similar latency. Furthermore, we show that our model generalizes to multiple tasks - image classification, object detection, and semantic segmentation with significant improvements in latency and accuracy as compared to existing efficient architectures when deployed on a mobile device. +Mobileone is proposed by apple and based on reparameterization. On the apple chips, the accuracy of the model is close to 0.76 on the ImageNet dataset when the latency is less than 1ms. Its main improvements based on [RepVGG](../repvgg) are fllowing: + +- Reparameterization using Depthwise convolution and Pointwise convolution instead of normal convolution. +- Removal of the residual structure which is not friendly to access memory.
-## Results and models +## Abstract -### ImageNet-1k +
-| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | -| :------------: | :-----------------------------: | :----------------------------: | :-------: | :-------: | :--------------------------------------------------: | :-----------------------------------------------------: | -| MobileOne-s0\* | 5.29(train) \| 2.08 (deploy) | 1.09 (train) \| 0.28 (deploy) | 71.36 | 89.87 | [config (train)](./mobileone-s0_8xb128_in1k.py) \| [config (deploy)](./deploy/mobileone-s0_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220915-007ae971.pth) | -| MobileOne-s1\* | 4.83 (train) \| 4.76 (deploy) | 0.86 (train) \| 0.84 (deploy) | 75.76 | 92.77 | [config (train)](./mobileone-s1_8xb128_in1k.py) \| [config (deploy)](./deploy/mobileone-s1_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_3rdparty_in1k_20220915-473c8469.pth) | -| MobileOne-s2\* | 7.88 (train) \| 7.88 (deploy) | 1.34 (train) \| 1.31 (deploy) | 77.39 | 93.63 | [config (train)](./mobileone-s2_8xb128_in1k.py) \|[config (deploy)](./deploy/mobileone-s2_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_3rdparty_in1k_20220915-ed2e4c30.pth) | -| MobileOne-s3\* | 10.17 (train) \| 10.08 (deploy) | 1.95 (train) \| 1.91 (deploy) | 77.93 | 93.89 | [config (train)](./mobileone-s3_8xb128_in1k.py) \|[config (deploy)](./deploy/mobileone-s3_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_3rdparty_in1k_20220915-84d6a02c.pth) | -| MobileOne-s4\* | 14.95 (train) \| 14.84 (deploy) | 3.05 (train) \| 3.00 (deploy) | 79.30 | 94.37 | [config (train)](./mobileone-s4_8xb128_in1k.py) \|[config (deploy)](./deploy/mobileone-s4_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_3rdparty_in1k_20220915-ce9509ee.pth) | +Show the paper's abstract -*Models with * are converted from the [official repo](https://github.com/apple/ml-mobileone). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* +
+Efficient neural network backbones for mobile devices are often optimized for metrics such as FLOPs or parameter count. However, these metrics may not correlate well with latency of the network when deployed on a mobile device. Therefore, we perform extensive analysis of different metrics by deploying several mobile-friendly networks on a mobile device. We identify and analyze architectural and optimization bottlenecks in recent efficient neural networks and provide ways to mitigate these bottlenecks. To this end, we design an efficient backbone MobileOne, with variants achieving an inference time under 1 ms on an iPhone12 with 75.9% top-1 accuracy on ImageNet. We show that MobileOne achieves state-of-the-art performance within the efficient architectures while being many times faster on mobile. Our best model obtains similar performance on ImageNet as MobileFormer while being 38x faster. Our model obtains 2.3% better top-1 accuracy on ImageNet than EfficientNet at similar latency. Furthermore, we show that our model generalizes to multiple tasks - image classification, object detection, and semantic segmentation with significant improvements in latency and accuracy as compared to existing efficient architectures when deployed on a mobile device. +
-*Because the [official repo.](https://github.com/apple/ml-mobileone) does not give a strategy for training and testing, the test data pipline of [RepVGG](https://github.com/open-mmlab/mmclassification/tree/master/configs/repvgg) is used here, and the result is about 0.1 lower than that in the paper. Refer to [this issue](https://github.com/apple/ml-mobileone/issues/2).* +
## How to use -The checkpoints provided are all `training-time` models. Use the reparameterize tool to switch them to more efficient `inference-time` architecture, which not only has fewer parameters but also less calculations. +The checkpoints provided are all `training-time` models. Use the reparameterize tool or `switch_to_deploy` interface to switch them to more efficient `inference-time` architecture, which not only has fewer parameters but also less calculations. -### Use tool + + +**Predict image** + +Use `classifier.backbone.switch_to_deploy()` interface to switch the MobileOne to a inference mode. + +```python +>>> import torch +>>> from mmcls.apis import init_model, inference_model +>>> +>>> model = init_model('configs/mobileone/mobileone-s0_8xb32_in1k.py', 'https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth') +>>> predict = inference_model(model, 'demo/demo.JPEG') +>>> print(predict['pred_class']) +sea snake +>>> print(predict['pred_score']) +0.4539405107498169 +>>> +>>> # switch to deploy mode +>>> model.backbone.switch_to_deploy() +>>> predict_deploy = inference_model(model, 'demo/demo.JPEG') +>>> print(predict_deploy['pred_class']) +sea snake +>>> print(predict_deploy['pred_score']) +0.4539395272731781 +``` + +**Use the model** + +```python +>>> import torch +>>> from mmcls.apis import init_model +>>> +>>> model = init_model('configs/mobileone/mobileone-s0_8xb32_in1k.py', 'https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth') +>>> inputs = torch.rand(1, 3, 224, 224).to(model.data_preprocessor.device) +>>> # To get classification scores. +>>> out = model(inputs) +>>> print(out.shape) +torch.Size([1, 1000]) +>>> # To extract features. +>>> outs = model.extract_feat(inputs) +>>> print(outs[0].shape) +torch.Size([1, 768]) +>>> +>>> # switch to deploy mode +>>> model.backbone.switch_to_deploy() +>>> out_deploy = model(inputs) +>>> print(out.shape) +torch.Size([1, 1000]) +>>> assert torch.allclose(out, out_deploy) # pass without error +``` + +**Train/Test Command** + +Place the ImageNet dataset to the `data/imagenet/` directory, or prepare datasets according to the [docs](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html#prepare-dataset). + +Train: + +```shell +python tools/train.py configs/mobileone/mobileone-s0_8xb32_in1k.py +``` + +Download Checkpoint: + +```shell +wget https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth +``` + +Test use unfused model: + +```shell +python tools/test.py configs/mobileone/mobileone-s0_8xb32_in1k.py mobileone-s0_8xb32_in1k_20221110-0bc94952.pth +``` + +Reparameterize checkpoint: + +```shell +python ./tools/convert_models/reparameterize_model.py ./configs/mobileone/mobileone-s0_8xb32_in1k.py mobileone-s0_8xb32_in1k_20221110-0bc94952.pth mobileone_s0_deploy.pth +``` + +Test use fused model: + +```shell +python tools/test.py configs/mobileone/deploy/mobileone-s0_deploy_8xb32_in1k.py mobileone_s0_deploy.pth +``` + + + +### Reparameterize Tool Use provided tool to reparameterize the given model and save the checkpoint: @@ -45,80 +131,35 @@ python tools/convert_models/reparameterize_model.py ${CFG_PATH} ${SRC_CKPT_PATH} For example: ```shell -python ./tools/convert_models/reparameterize_model.py ./configs/mobileone/mobileone-s0_8xb128_in1k.py https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220811-db5ce29b.pth ./mobileone_s0_deploy.pth +wget https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth +python ./tools/convert_models/reparameterize_model.py ./configs/mobileone/mobileone-s0_8xb32_in1k.py mobileone-s0_8xb32_in1k_20221110-0bc94952.pth mobileone_s0_deploy.pth ``` -To use reparameterized weights, the config file must switch to **the deploy config files**. +To use reparameterized weights, the config file must switch to [**the deploy config files**](./deploy/). ```bash -python tools/test.py ${Deploy_CFG} ${Deploy_Checkpoint} --metrics accuracy +python tools/test.py ${Deploy_CFG} ${Deploy_Checkpoint} ``` For example of using the reparameterized weights above: ```shell -python ./tools/test.py ./configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py mobileone_s0_deploy.pth --metrics accuracy +python ./tools/test.py ./configs/mobileone/deploy/mobileone-s0_deploy_8xb32_in1k.py mobileone_s0_deploy.pth ``` -### In the code +For more configurable parameters, please refer to the [API](https://mmclassification.readthedocs.io/en/1.x/api/generated/mmcls.models.backbones.MobileOne.html#mmcls.models.backbones.MobileOne). -Use the API `switch_to_deploy` of `MobileOne` backbone to to switch to the deploy mode. Usually called like `backbone.switch_to_deploy()` or `classificer.backbone.switch_to_deploy()`. +## Results and models -For Backbones: +### ImageNet-1k -```python -from mmcls.models import build_backbone -import torch - -x = torch.randn( (1, 3, 224, 224) ) -backbone_cfg=dict(type='MobileOne', arch='s0') -backbone = build_backbone(backbone_cfg) -backbone.init_weights() -backbone.eval() -outs_ori = backbone(x) - -backbone.switch_to_deploy() -outs_dep = backbone(x) - -for out1, out2 in zip(outs_ori, outs_dep): - assert torch.allclose(out1, out2) -``` - -For ImageClassifiers: - -```python -from mmcls.models import build_classifier -import torch -import numpy as np - -cfg = dict( - type='ImageClassifier', - backbone=dict( - type='MobileOne', - arch='s0', - out_indices=(3, ), - ), - neck=dict(type='GlobalAveragePooling'), - head=dict( - type='LinearClsHead', - num_classes=1000, - in_channels=1024, - loss=dict(type='CrossEntropyLoss', loss_weight=1.0), - topk=(1, 5), - )) - -x = torch.randn( (1, 3, 224, 224) ) -classifier = build_classifier(cfg) -classifier.init_weights() -classifier.eval() -y_ori = classifier(x, return_loss=False) - -classifier.backbone.switch_to_deploy() -y_dep = classifier(x, return_loss=False) - -for y1, y2 in zip(y_ori, y_dep): - assert np.allclose(y1, y2) -``` +| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | +| :----------: | :-----------------------------: | :----------------------------: | :-------: | :-------: | :---------------------------------------------------: | :------------------------------------------------------: | +| MobileOne-s0 | 5.29(train) \| 2.08 (deploy) | 1.09 (train) \| 0.28 (deploy) | 71.34 | 89.87 | [config (train)](./mobileone-s0_8xb32_in1k.py) \| [config (deploy)](./deploy/mobileone-s0_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.json) | +| MobileOne-s1 | 4.83 (train) \| 4.76 (deploy) | 0.86 (train) \| 0.84 (deploy) | 75.72 | 92.54 | [config (train)](./mobileone-s1_8xb32_in1k.py) \| [config (deploy)](./deploy/mobileone-s1_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_8xb32_in1k_20221110-ceeef467.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_8xb32_in1k_20221110-ceeef467.json) | +| MobileOne-s2 | 7.88 (train) \| 7.88 (deploy) | 1.34 (train) \| 1.31 (deploy) | 77.37 | 93.34 | [config (train)](./mobileone-s2_8xb32_in1k.py) \|[config (deploy)](./deploy/mobileone-s2_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_8xb32_in1k_20221110-9c7ecb97.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_8xb32_in1k_20221110-9c7ecb97.json) | +| MobileOne-s3 | 10.17 (train) \| 10.08 (deploy) | 1.95 (train) \| 1.91 (deploy) | 78.06 | 93.83 | [config (train)](./mobileone-s3_8xb32_in1k.py) \|[config (deploy)](./deploy/mobileone-s3_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_8xb32_in1k_20221110-c95eb3bf.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_8xb32_in1k_20221110-c95eb3bf.pth) | +| MobileOne-s4 | 14.95 (train) \| 14.84 (deploy) | 3.05 (train) \| 3.00 (deploy) | 79.69 | 94.46 | [config (train)](./mobileone-s4_8xb32_in1k.py) \|[config (deploy)](./deploy/mobileone-s4_deploy_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth) | ## Citation diff --git a/configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py b/configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py deleted file mode 100644 index 8902483c..00000000 --- a/configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['../mobileone-s0_8xb128_in1k.py'] - -model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s0_deploy_8xb32_in1k.py b/configs/mobileone/deploy/mobileone-s0_deploy_8xb32_in1k.py new file mode 100644 index 00000000..145f3f4e --- /dev/null +++ b/configs/mobileone/deploy/mobileone-s0_deploy_8xb32_in1k.py @@ -0,0 +1,3 @@ +_base_ = ['../mobileone-s0_8xb32_in1k.py'] + +model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s1_deploy_8xb128_in1k.py b/configs/mobileone/deploy/mobileone-s1_deploy_8xb128_in1k.py deleted file mode 100644 index 7bcf3211..00000000 --- a/configs/mobileone/deploy/mobileone-s1_deploy_8xb128_in1k.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['../mobileone-s1_8xb128_in1k.py'] - -model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s1_deploy_8xb32_in1k.py b/configs/mobileone/deploy/mobileone-s1_deploy_8xb32_in1k.py new file mode 100644 index 00000000..8602c31c --- /dev/null +++ b/configs/mobileone/deploy/mobileone-s1_deploy_8xb32_in1k.py @@ -0,0 +1,3 @@ +_base_ = ['../mobileone-s1_8xb32_in1k.py'] + +model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s2_deploy_8xb128_in1k.py b/configs/mobileone/deploy/mobileone-s2_deploy_8xb128_in1k.py deleted file mode 100644 index 5d64d519..00000000 --- a/configs/mobileone/deploy/mobileone-s2_deploy_8xb128_in1k.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['../mobileone-s2_8xb128_in1k.py'] - -model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s2_deploy_8xb32_in1k.py b/configs/mobileone/deploy/mobileone-s2_deploy_8xb32_in1k.py new file mode 100644 index 00000000..97aaddd0 --- /dev/null +++ b/configs/mobileone/deploy/mobileone-s2_deploy_8xb32_in1k.py @@ -0,0 +1,3 @@ +_base_ = ['../mobileone-s2_8xb32_in1k.py'] + +model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s3_deploy_8xb128_in1k.py b/configs/mobileone/deploy/mobileone-s3_deploy_8xb128_in1k.py deleted file mode 100644 index 8c710f78..00000000 --- a/configs/mobileone/deploy/mobileone-s3_deploy_8xb128_in1k.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['../mobileone-s3_8xb128_in1k.py'] - -model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s3_deploy_8xb32_in1k.py b/configs/mobileone/deploy/mobileone-s3_deploy_8xb32_in1k.py new file mode 100644 index 00000000..0d335a7b --- /dev/null +++ b/configs/mobileone/deploy/mobileone-s3_deploy_8xb32_in1k.py @@ -0,0 +1,3 @@ +_base_ = ['../mobileone-s3_8xb32_in1k.py'] + +model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s4_deploy_8xb128_in1k.py b/configs/mobileone/deploy/mobileone-s4_deploy_8xb128_in1k.py deleted file mode 100644 index 6ca4d18e..00000000 --- a/configs/mobileone/deploy/mobileone-s4_deploy_8xb128_in1k.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = ['../mobileone-s4_8xb128_in1k.py'] - -model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/deploy/mobileone-s4_deploy_8xb32_in1k.py b/configs/mobileone/deploy/mobileone-s4_deploy_8xb32_in1k.py new file mode 100644 index 00000000..b82f5a9a --- /dev/null +++ b/configs/mobileone/deploy/mobileone-s4_deploy_8xb32_in1k.py @@ -0,0 +1,3 @@ +_base_ = ['../mobileone-s4_8xb32_in1k.py'] + +model = dict(backbone=dict(deploy=True)) diff --git a/configs/mobileone/metafile.yml b/configs/mobileone/metafile.yml index 04eaceff..2a480dcd 100644 --- a/configs/mobileone/metafile.yml +++ b/configs/mobileone/metafile.yml @@ -16,83 +16,68 @@ Collections: Version: v1.0.0rc1 Models: - - Name: mobileone-s0_3rdparty_8xb128_in1k + - Name: mobileone-s0_8xb32_in1k In Collection: MobileOne - Config: configs/mobileone/mobileone-s0_8xb128_in1k.py + Config: configs/mobileone/mobileone-s0_8xb32_in1k.py Metadata: - FLOPs: 1091227648 # 1.09G - Parameters: 5293272 # 5.29M + FLOPs: 274136576 # 0.27G + Parameters: 2078504 # 2.08M Results: - Dataset: ImageNet-1k Task: Image Classification Metrics: - Top 1 Accuracy: 71.36 + Top 1 Accuracy: 71.34 Top 5 Accuracy: 89.87 - Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220915-007ae971.pth - Converted From: - Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar - Code: https://github.com/apple/ml-mobileone - - Name: mobileone-s1_3rdparty_8xb128_in1k + Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth + - Name: mobileone-s1_8xb32_in1k In Collection: MobileOne - Config: configs/mobileone/mobileone-s1_8xb128_in1k.py + Config: configs/mobileone/mobileone-s1_8xb32_in1k.py Metadata: - FLOPs: 863491328 # 8.6G - Parameters: 4825192 # 4.82M + FLOPs: 823839744 # 8.6G + Parameters: 4764840 # 4.82M Results: - Dataset: ImageNet-1k Task: Image Classification Metrics: - Top 1 Accuracy: 75.76 - Top 5 Accuracy: 92.77 - Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_3rdparty_in1k_20220915-473c8469.pth - Converted From: - Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar - Code: https://github.com/apple/ml-mobileone - - Name: mobileone-s2_3rdparty_8xb128_in1k + Top 1 Accuracy: 75.72 + Top 5 Accuracy: 92.54 + Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_8xb32_in1k_20221110-ceeef467.pth + - Name: mobileone-s2_8xb32_in1k In Collection: MobileOne - Config: configs/mobileone/mobileone-s2_8xb128_in1k.py + Config: configs/mobileone/mobileone-s2_8xb32_in1k.py Metadata: - FLOPs: 1344083328 - Parameters: 7884648 + FLOPs: 1296478848 + Parameters: 7808168 Results: - Dataset: ImageNet-1k Task: Image Classification Metrics: - Top 1 Accuracy: 77.39 - Top 5 Accuracy: 93.63 - Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_3rdparty_in1k_20220915-ed2e4c30.pth - Converted From: - Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar - Code: https://github.com/apple/ml-mobileone - - Name: mobileone-s3_3rdparty_8xb128_in1k + Top 1 Accuracy: 77.37 + Top 5 Accuracy: 93.34 + Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_8xb32_in1k_20221110-9c7ecb97.pth + - Name: mobileone-s3_8xb32_in1k In Collection: MobileOne - Config: configs/mobileone/mobileone-s3_8xb128_in1k.py + Config: configs/mobileone/mobileone-s3_8xb32_in1k.py Metadata: - FLOPs: 1951043584 - Parameters: 10170600 + FLOPs: 1893842944 + Parameters: 10078312 Results: - Dataset: ImageNet-1k Task: Image Classification Metrics: - Top 1 Accuracy: 77.93 - Top 5 Accuracy: 93.89 - Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_3rdparty_in1k_20220915-84d6a02c.pth - Converted From: - Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar - Code: https://github.com/apple/ml-mobileone - - Name: mobileone-s4_3rdparty_8xb128_in1k + Top 1 Accuracy: 78.06 + Top 5 Accuracy: 93.83 + Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_8xb32_in1k_20221110-c95eb3bf.pth + - Name: mobileone-s4_8xb32_in1k In Collection: MobileOne - Config: configs/mobileone/mobileone-s4_8xb128_in1k.py + Config: configs/mobileone/mobileone-s4_8xb32_in1k.py Metadata: - FLOPs: 3052580688 - Parameters: 14951248 + FLOPs: 2979222528 + Parameters: 14838352 Results: - Dataset: ImageNet-1k Task: Image Classification Metrics: - Top 1 Accuracy: 79.30 - Top 5 Accuracy: 94.37 - Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_3rdparty_in1k_20220915-ce9509ee.pth - Converted From: - Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar - Code: https://github.com/apple/ml-mobileone + Top 1 Accuracy: 79.69 + Top 5 Accuracy: 94.46 + Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth diff --git a/configs/mobileone/mobileone-s0_8xb32_in1k.py b/configs/mobileone/mobileone-s0_8xb32_in1k.py new file mode 100644 index 00000000..be56b86c --- /dev/null +++ b/configs/mobileone/mobileone-s0_8xb32_in1k.py @@ -0,0 +1,20 @@ +_base_ = [ + '../_base_/models/mobileone/mobileone_s0.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py', + '../_base_/default_runtime.py' +] + +# schedule settings +optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.)) + +val_dataloader = dict(batch_size=256) +test_dataloader = dict(batch_size=256) + +custom_hooks = [ + dict( + type='EMAHook', + momentum=5e-4, + priority='ABOVE_NORMAL', + update_buffers=True) +] diff --git a/configs/mobileone/mobileone-s1_8xb128_in1k.py b/configs/mobileone/mobileone-s1_8xb128_in1k.py deleted file mode 100644 index b14c7c17..00000000 --- a/configs/mobileone/mobileone-s1_8xb128_in1k.py +++ /dev/null @@ -1,15 +0,0 @@ -_base_ = [ - '../_base_/models/mobileone/mobileone_s1.py', - '../_base_/datasets/imagenet_bs32_pil_resize.py', - '../_base_/schedules/imagenet_bs256_coslr.py', - '../_base_/default_runtime.py' -] - -# dataset settings -train_dataloader = dict(batch_size=128) -val_dataloader = dict(batch_size=128) -test_dataloader = dict(batch_size=128) - -# NOTE: `auto_scale_lr` is for automatically scaling LR, -# based on the actual training batch size. -auto_scale_lr = dict(base_batch_size=1024) diff --git a/configs/mobileone/mobileone-s1_8xb32_in1k.py b/configs/mobileone/mobileone-s1_8xb32_in1k.py new file mode 100644 index 00000000..52c8442e --- /dev/null +++ b/configs/mobileone/mobileone-s1_8xb32_in1k.py @@ -0,0 +1,60 @@ +_base_ = [ + '../_base_/models/mobileone/mobileone_s1.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py', + '../_base_/default_runtime.py' +] + +# schedule settings +optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.)) + +val_dataloader = dict(batch_size=256) +test_dataloader = dict(batch_size=256) + +bgr_mean = _base_.data_preprocessor['mean'][::-1] +base_train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomResizedCrop', scale=224, backend='pillow'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=7, + magnitude_std=0.5, + hparams=dict(pad_val=[round(x) for x in bgr_mean])), + dict(type='PackClsInputs') +] + +import copy # noqa: E402 + +# modify start epoch's RandomResizedCrop.scale to 160 +train_pipeline_1e = copy.deepcopy(base_train_pipeline) +train_pipeline_1e[1]['scale'] = 160 +train_pipeline_1e[3]['magnitude_level'] *= 0.1 +_base_.train_dataloader.dataset.pipeline = train_pipeline_1e + +# modify 37 epoch's RandomResizedCrop.scale to 192 +train_pipeline_37e = copy.deepcopy(base_train_pipeline) +train_pipeline_37e[1]['scale'] = 192 +train_pipeline_1e[3]['magnitude_level'] *= 0.2 + +# modify 112 epoch's RandomResizedCrop.scale to 224 +train_pipeline_112e = copy.deepcopy(base_train_pipeline) +train_pipeline_112e[1]['scale'] = 224 +train_pipeline_1e[3]['magnitude_level'] *= 0.3 + +custom_hooks = [ + dict( + type='SwitchRecipeHook', + schedule=[ + dict(action_epoch=37, pipeline=train_pipeline_37e), + dict(action_epoch=112, pipeline=train_pipeline_112e), + ]), + dict( + type='EMAHook', + momentum=5e-4, + priority='ABOVE_NORMAL', + update_buffers=True) +] diff --git a/configs/mobileone/mobileone-s2_8xb128_in1k.py b/configs/mobileone/mobileone-s2_8xb128_in1k.py deleted file mode 100644 index dca0d4d3..00000000 --- a/configs/mobileone/mobileone-s2_8xb128_in1k.py +++ /dev/null @@ -1,15 +0,0 @@ -_base_ = [ - '../_base_/models/mobileone/mobileone_s2.py', - '../_base_/datasets/imagenet_bs32_pil_resize.py', - '../_base_/schedules/imagenet_bs256_coslr.py', - '../_base_/default_runtime.py' -] - -# dataset settings -train_dataloader = dict(batch_size=128) -val_dataloader = dict(batch_size=128) -test_dataloader = dict(batch_size=128) - -# NOTE: `auto_scale_lr` is for automatically scaling LR, -# based on the actual training batch size. -auto_scale_lr = dict(base_batch_size=1024) diff --git a/configs/mobileone/mobileone-s2_8xb32_in1k.py b/configs/mobileone/mobileone-s2_8xb32_in1k.py new file mode 100644 index 00000000..547ae995 --- /dev/null +++ b/configs/mobileone/mobileone-s2_8xb32_in1k.py @@ -0,0 +1,65 @@ +_base_ = [ + '../_base_/models/mobileone/mobileone_s2.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py', + '../_base_/default_runtime.py' +] + +# schedule settings +optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.)) + +val_dataloader = dict(batch_size=256) +test_dataloader = dict(batch_size=256) + +import copy # noqa: E402 + +bgr_mean = _base_.data_preprocessor['mean'][::-1] +base_train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomResizedCrop', scale=224, backend='pillow'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=7, + magnitude_std=0.5, + hparams=dict(pad_val=[round(x) for x in bgr_mean])), + dict(type='PackClsInputs') +] + +# modify start epoch RandomResizedCrop.scale to 160 +# and RA.magnitude_level * 0.3 +train_pipeline_1e = copy.deepcopy(base_train_pipeline) +train_pipeline_1e[1]['scale'] = 160 +train_pipeline_1e[3]['magnitude_level'] *= 0.3 +_base_.train_dataloader.dataset.pipeline = train_pipeline_1e + +import copy # noqa: E402 + +# modify 137 epoch's RandomResizedCrop.scale to 192 +# and RA.magnitude_level * 0.7 +train_pipeline_37e = copy.deepcopy(base_train_pipeline) +train_pipeline_37e[1]['scale'] = 192 +train_pipeline_37e[3]['magnitude_level'] *= 0.7 + +# modify 112 epoch's RandomResizedCrop.scale to 224 +# and RA.magnitude_level * 1.0 +train_pipeline_112e = copy.deepcopy(base_train_pipeline) +train_pipeline_112e[1]['scale'] = 224 +train_pipeline_112e[3]['magnitude_level'] *= 1.0 + +custom_hooks = [ + dict( + type='SwitchRecipeHook', + schedule=[ + dict(action_epoch=37, pipeline=train_pipeline_37e), + dict(action_epoch=112, pipeline=train_pipeline_112e), + ]), + dict( + type='EMAHook', + momentum=5e-4, + priority='ABOVE_NORMAL', + update_buffers=True) +] diff --git a/configs/mobileone/mobileone-s3_8xb128_in1k.py b/configs/mobileone/mobileone-s3_8xb128_in1k.py deleted file mode 100644 index 89343d5d..00000000 --- a/configs/mobileone/mobileone-s3_8xb128_in1k.py +++ /dev/null @@ -1,15 +0,0 @@ -_base_ = [ - '../_base_/models/mobileone/mobileone_s3.py', - '../_base_/datasets/imagenet_bs64_pil_resize.py', - '../_base_/schedules/imagenet_bs256_coslr.py', - '../_base_/default_runtime.py' -] - -# dataset settings -train_dataloader = dict(batch_size=128) -val_dataloader = dict(batch_size=128) -test_dataloader = dict(batch_size=128) - -# NOTE: `auto_scale_lr` is for automatically scaling LR, -# based on the actual training batch size. -auto_scale_lr = dict(base_batch_size=1024) diff --git a/configs/mobileone/mobileone-s3_8xb32_in1k.py b/configs/mobileone/mobileone-s3_8xb32_in1k.py new file mode 100644 index 00000000..b0ef4164 --- /dev/null +++ b/configs/mobileone/mobileone-s3_8xb32_in1k.py @@ -0,0 +1,65 @@ +_base_ = [ + '../_base_/models/mobileone/mobileone_s3.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py', + '../_base_/default_runtime.py' +] + +# schedule settings +optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.)) + +val_dataloader = dict(batch_size=256) +test_dataloader = dict(batch_size=256) + +import copy # noqa: E402 + +bgr_mean = _base_.data_preprocessor['mean'][::-1] +base_train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomResizedCrop', scale=224, backend='pillow'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=7, + magnitude_std=0.5, + hparams=dict(pad_val=[round(x) for x in bgr_mean])), + dict(type='PackClsInputs') +] + +# modify start epoch RandomResizedCrop.scale to 160 +# and RA.magnitude_level * 0.3 +train_pipeline_1e = copy.deepcopy(base_train_pipeline) +train_pipeline_1e[1]['scale'] = 160 +train_pipeline_1e[3]['magnitude_level'] *= 0.3 +_base_.train_dataloader.dataset.pipeline = train_pipeline_1e + +import copy # noqa: E402 + +# modify 137 epoch's RandomResizedCrop.scale to 192 +# and RA.magnitude_level * 0.7 +train_pipeline_37e = copy.deepcopy(base_train_pipeline) +train_pipeline_37e[1]['scale'] = 192 +train_pipeline_37e[3]['magnitude_level'] *= 0.7 + +# modify 112 epoch's RandomResizedCrop.scale to 224 +# and RA.magnitude_level * 1.0 +train_pipeline_112e = copy.deepcopy(base_train_pipeline) +train_pipeline_112e[1]['scale'] = 224 +train_pipeline_112e[3]['magnitude_level'] *= 1.0 + +custom_hooks = [ + dict( + type='SwitchRecipeHook', + schedule=[ + dict(action_epoch=37, pipeline=train_pipeline_37e), + dict(action_epoch=112, pipeline=train_pipeline_112e), + ]), + dict( + type='EMAHook', + momentum=5e-4, + priority='ABOVE_NORMAL', + update_buffers=True) +] diff --git a/configs/mobileone/mobileone-s4_8xb128_in1k.py b/configs/mobileone/mobileone-s4_8xb128_in1k.py deleted file mode 100644 index 1984ef35..00000000 --- a/configs/mobileone/mobileone-s4_8xb128_in1k.py +++ /dev/null @@ -1,15 +0,0 @@ -_base_ = [ - '../_base_/models/mobileone/mobileone_s4.py', - '../_base_/datasets/imagenet_bs64_pil_resize.py', - '../_base_/schedules/imagenet_bs256_coslr.py', - '../_base_/default_runtime.py' -] - -# dataset settings -train_dataloader = dict(batch_size=128) -val_dataloader = dict(batch_size=128) -test_dataloader = dict(batch_size=128) - -# NOTE: `auto_scale_lr` is for automatically scaling LR, -# based on the actual training batch size. -auto_scale_lr = dict(base_batch_size=1024) diff --git a/configs/mobileone/mobileone-s4_8xb32_in1k.py b/configs/mobileone/mobileone-s4_8xb32_in1k.py new file mode 100644 index 00000000..8c31f240 --- /dev/null +++ b/configs/mobileone/mobileone-s4_8xb32_in1k.py @@ -0,0 +1,63 @@ +_base_ = [ + '../_base_/models/mobileone/mobileone_s4.py', + '../_base_/datasets/imagenet_bs32_pil_resize.py', + '../_base_/schedules/imagenet_bs256_coslr_coswd_300e.py', + '../_base_/default_runtime.py' +] + +# schedule settings +optim_wrapper = dict(paramwise_cfg=dict(norm_decay_mult=0.)) + +val_dataloader = dict(batch_size=256) +test_dataloader = dict(batch_size=256) + +bgr_mean = _base_.data_preprocessor['mean'][::-1] +base_train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomResizedCrop', scale=224, backend='pillow'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=7, + magnitude_std=0.5, + hparams=dict(pad_val=[round(x) for x in bgr_mean])), + dict(type='PackClsInputs') +] + +import copy # noqa: E402 + +# modify start epoch RandomResizedCrop.scale to 160 +# and RA.magnitude_level * 0.3 +train_pipeline_1e = copy.deepcopy(base_train_pipeline) +train_pipeline_1e[1]['scale'] = 160 +train_pipeline_1e[3]['magnitude_level'] *= 0.3 +_base_.train_dataloader.dataset.pipeline = train_pipeline_1e + +# modify 137 epoch's RandomResizedCrop.scale to 192 +# and RA.magnitude_level * 0.7 +train_pipeline_37e = copy.deepcopy(base_train_pipeline) +train_pipeline_37e[1]['scale'] = 192 +train_pipeline_37e[3]['magnitude_level'] *= 0.7 + +# modify 112 epoch's RandomResizedCrop.scale to 224 +# and RA.magnitude_level * 1.0 +train_pipeline_112e = copy.deepcopy(base_train_pipeline) +train_pipeline_112e[1]['scale'] = 224 +train_pipeline_112e[3]['magnitude_level'] *= 1.0 + +custom_hooks = [ + dict( + type='SwitchRecipeHook', + schedule=[ + dict(action_epoch=37, pipeline=train_pipeline_37e), + dict(action_epoch=112, pipeline=train_pipeline_112e), + ]), + dict( + type='EMAHook', + momentum=5e-4, + priority='ABOVE_NORMAL', + update_buffers=True) +] diff --git a/tools/visualizations/vis_scheduler.py b/tools/visualizations/vis_scheduler.py index 87d076dc..5e08af39 100644 --- a/tools/visualizations/vis_scheduler.py +++ b/tools/visualizations/vis_scheduler.py @@ -41,6 +41,7 @@ class ParamRecordHook(Hook): self.by_epoch = by_epoch self.lr_list = [] self.momentum_list = [] + self.wd_list = [] self.task_id = 0 self.progress = Progress(BarColumn(), MofNCompleteColumn(), TextColumn('{task.description}')) @@ -66,6 +67,8 @@ class ParamRecordHook(Hook): self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0]) self.momentum_list.append( runner.optim_wrapper.get_momentum()['momentum'][0]) + self.wd_list.append( + runner.optim_wrapper.param_groups[0]['weight_decay']) def after_train(self, runner): self.progress.stop() @@ -80,9 +83,9 @@ def parse_args(): '--parameter', type=str, default='lr', - choices=['lr', 'momentum'], + choices=['lr', 'momentum', 'wd'], help='The parameter to visualize its change curve, choose from' - '"lr" and "momentum". Defaults to "lr".') + '"lr", "wd" and "momentum". Defaults to "lr".') parser.add_argument( '-d', '--dataset-size', @@ -192,7 +195,12 @@ def simulate_train(data_loader, cfg, by_epoch): runner.train() - return param_record_hook.lr_list, param_record_hook.momentum_list + param_dict = dict( + lr=param_record_hook.lr_list, + momentum=param_record_hook.momentum_list, + wd=param_record_hook.wd_list) + + return param_dict def main(): @@ -250,13 +258,15 @@ def main(): rich.print(dataset_info + '\n') # simulation training process - lr_list, momentum_list = simulate_train(data_loader, cfg, by_epoch) - if args.parameter == 'lr': - param_list = lr_list - else: - param_list = momentum_list + param_dict = simulate_train(data_loader, cfg, by_epoch) + param_list = param_dict[args.parameter] - param_name = 'Learning Rate' if args.parameter == 'lr' else 'Momentum' + if args.parameter == 'lr': + param_name = 'Learning Rate' + elif args.parameter == 'momentum': + param_name = 'Momentum' + else: + param_name = 'Weight Decay' plot_curve(param_list, args, param_name, len(data_loader), by_epoch) if args.save_path: