[Reproduce] Reproduce RepVGG Training Accuracy. (#1264)

* reproduce repvgg

* add VisionRRC

* update repvgg configs

* add BCD series cfgs

* add cv backend config

* add vision configs

* add D2se configs

* add ra configs

* add num-workers configs

* configs

* update README

* rm extra config

* reset un-needed changes

* update

* reset pbn

* update readme

* update code

* update code

* refine doc
Ezra-Yu 2022-12-30 15:49:56 +08:00 committed by GitHub
parent e0e6a1f1ae
commit 88e5ba28db
32 changed files with 351 additions and 236 deletions

View File

# RepVGG

> [RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697)

<!-- [ALGORITHM] -->

## Introduction

RepVGG is a VGG-style convolutional architecture. It has the following advantages:

1. The model has a VGG-like plain (a.k.a. feed-forward) topology without any branches, i.e., every layer takes the output of its only preceding layer as input and feeds its output into its only following layer.
2. The model's body uses only 3 × 3 conv and ReLU.
3. The concrete architecture (including the specific depth and layer widths) is instantiated with no automatic search, manual refinement, compound scaling, or other heavy designs.

<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/142573223-f7f14d32-ea08-43a1-81ad-5a6a83ee0122.png" width="60%"/>
</div>
## Abstract

<details>

<summary>Show the paper's abstract</summary>

<br>
We present a simple but powerful architecture of convolutional neural network, which has a VGG-like inference-time body composed of nothing but a stack of 3x3 convolution and ReLU, while the training-time model has a multi-branch topology. Such decoupling of the training-time and inference-time architecture is realized by a structural re-parameterization technique so that the model is named RepVGG. On ImageNet, RepVGG reaches over 80% top-1 accuracy, which is the first time for a plain model, to the best of our knowledge. On NVIDIA 1080Ti GPU, RepVGG models run 83% faster than ResNet-50 or 101% faster than ResNet-101 with higher accuracy and show favorable accuracy-speed trade-off compared to the state-of-the-art models like EfficientNet and RegNet.
</br>

</details>
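The structural re-parameterization mentioned in the abstract boils down to folding each branch's BatchNorm into its convolution and summing the resulting kernels and biases into a single 3x3 conv. Below is a minimal, self-contained PyTorch sketch of that equivalence on a toy two-branch block (illustrative only, not mmcls code; the identity branch is omitted for brevity):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    """Fold BN into the preceding conv: y = gamma * (conv(x) - mu) / std + beta."""
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                        # per-channel gamma / std
    weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = bn.bias - bn.running_mean * scale
    return weight, bias

# A toy two-branch block: 3x3 conv-BN plus 1x1 conv-BN, summed.
c = 8
conv3, bn3 = nn.Conv2d(c, c, 3, padding=1, bias=False), nn.BatchNorm2d(c)
conv1, bn1 = nn.Conv2d(c, c, 1, bias=False), nn.BatchNorm2d(c)
for bn in (bn3, bn1):                              # fake some running statistics
    bn.running_mean.uniform_(-0.1, 0.1)
    bn.running_var.uniform_(0.9, 1.1)
    bn.eval()

x = torch.rand(1, c, 32, 32)
y_multi_branch = bn3(conv3(x)) + bn1(conv1(x))     # training-time topology

w3, b3 = fuse_conv_bn(conv3, bn3)
w1, b1 = fuse_conv_bn(conv1, bn1)
w = w3 + F.pad(w1, [1, 1, 1, 1])                   # zero-pad the 1x1 kernel to 3x3
b = b3 + b1
y_single_conv = F.conv2d(x, w, b, padding=1)       # inference-time topology

assert torch.allclose(y_multi_branch, y_single_conv, atol=1e-5)
```

This equivalence is why the deploy-mode models below reproduce the training-time predictions up to floating-point error while being cheaper to run.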
## How to use

The checkpoints provided are all `training-time` models. Use the reparameterize tool or the `switch_to_deploy` interface to switch them to the more efficient `inference-time` architecture, which not only has fewer parameters but also requires fewer computations.

<!-- [TABS-BEGIN] -->

**Predict image**

Use the `classifier.backbone.switch_to_deploy()` interface to switch the RepVGG models into inference mode.
```python
>>> import torch
>>> from mmcls.apis import init_model, inference_model
>>>
>>> model = init_model('configs/repvgg/repvgg-A0_8xb32_in1k.py', 'https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_8xb32_in1k_20221213-60ae8e23.pth')
>>> results = inference_model(model, 'demo/demo.JPEG')
>>> print( (results['pred_class'], results['pred_score']) )
('sea snake', 0.8338906168937683)
>>>
>>> # switch to deploy mode
>>> model.backbone.switch_to_deploy()
>>> results = inference_model(model, 'demo/demo.JPEG')
>>> print( (results['pred_class'], results['pred_score']) )
('sea snake', 0.7883061170578003)
```
**Use the model**
```python
>>> import torch
>>> from mmcls.apis import get_model
>>>
>>> model = get_model("repvgg-a0_8xb32_in1k", pretrained=True)
>>> model.eval()
>>> inputs = torch.rand(1, 3, 224, 224).to(model.data_preprocessor.device)
>>> # To get classification scores.
>>> out = model(inputs)
>>> print(out.shape)
torch.Size([1, 1000])
>>> # To extract features.
>>> outs = model.extract_feat(inputs)
>>> print(outs[0].shape)
torch.Size([1, 1280])
>>>
>>> # switch to deploy mode
>>> model.backbone.switch_to_deploy()
>>> out_deploy = model(inputs)
>>> print(out_deploy.shape)
torch.Size([1, 1000])
>>> assert torch.allclose(out, out_deploy, rtol=1e-4, atol=1e-5) # pass without error
```
**Train/Test Command**
Place the ImageNet dataset in the `data/imagenet/` directory, or prepare datasets according to the [docs](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py configs/repvgg/repvgg-a0_8xb32_in1k.py
```
Download Checkpoint:
```shell
wget https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_8xb32_in1k_20221213-60ae8e23.pth
```
Test with the unfused model:
```shell
python tools/test.py configs/repvgg/repvgg-a0_8xb32_in1k.py repvgg-A0_8xb32_in1k_20221213-60ae8e23.pth
```
Reparameterize checkpoint:
```shell
python ./tools/convert_models/reparameterize_model.py configs/repvgg/repvgg-a0_8xb32_in1k.py repvgg-A0_8xb32_in1k_20221213-60ae8e23.pth repvgg_A0_deploy.pth
```
Test with the fused model:
```shell
python tools/test.py configs/repvgg/repvgg-A0_8xb32_in1k.py repvgg_A0_deploy.pth --cfg-options model.backbone.deploy=True
```
or
```shell
python tools/test.py configs/repvgg/repvgg-A0_deploy_in1k.py repvgg_A0_deploy.pth
```
<!-- [TABS-END] -->
For more configurable parameters, please refer to the [API](https://mmclassification.readthedocs.io/en/1.x/api/generated/mmcls.models.backbones.RepVGG.html#mmcls.models.backbones.RepVGG).
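If you only need the backbone rather than the whole classifier, it can be built from a config dict and switched in place. A short sketch, assuming the `build_backbone` helper (used by earlier versions of this document) is still exported by your mmcls version:

```python
import torch
from mmcls.models import build_backbone  # assumption: still exported in your version

backbone = build_backbone(dict(type='RepVGG', arch='A0'))
backbone.eval()

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    feat_train = backbone(x)[-1]   # multi-branch, training-time forward
    backbone.switch_to_deploy()    # fuse each block into a single 3x3 conv
    feat_deploy = backbone(x)[-1]

assert torch.allclose(feat_train, feat_deploy, rtol=1e-4, atol=1e-5)
```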
<details>
<summary><b>How to use the reparameterization tool</b> (click to show)</summary>
<br>
Use the provided tool to reparameterize the given model and save the checkpoint:
```shell
python tools/convert_models/reparameterize_model.py ${CFG_PATH} ${SRC_CKPT_PATH} ${TARGET_CKPT_PATH}
```

`${CFG_PATH}` is the config file path, `${SRC_CKPT_PATH}` is the source checkpoint file path, and `${TARGET_CKPT_PATH}` is the target deploy weight file path.
For example:

```shell
# download the weight
wget https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_8xb32_in1k_20221213-60ae8e23.pth
# reparameterize unfused weight to fused weight
python ./tools/convert_models/reparameterize_model.py configs/repvgg/repvgg-a0_8xb32_in1k.py repvgg-A0_8xb32_in1k_20221213-60ae8e23.pth repvgg-A0_deploy.pth
```

To use the reparameterized weights, the config file must switch to **the deploy config files**, as in [the deploy_A0 example](./repvgg-A0_deploy_in1k.py), or add `--cfg-options model.backbone.deploy=True` to the command.

For example, to use the reparameterized weights above:

```shell
python ./tools/test.py ./configs/repvgg/repvgg-A0_deploy_in1k.py repvgg-A0_deploy.pth
```

You can get other deploy configs by modifying the [A0_deploy example](./repvgg-A0_deploy_in1k.py):

```text
# in repvgg-A0_deploy_in1k.py
_base_ = '../repvgg-A0_8xb32_in1k.py'  # basic A0 config

model = dict(backbone=dict(deploy=True))  # switch the model into deploy mode
```

or add `--cfg-options model.backbone.deploy=True` to the command, as follows:

```shell
python tools/test.py configs/repvgg/repvgg-A0_8xb32_in1k.py repvgg_A0_deploy.pth --cfg-options model.backbone.deploy=True
```

</br>

</details>
## Results and models
### ImageNet-1k
| Model | Pretrain | <p> Params(M) <br>(train\|deploy) </p> | <p> Flops(G) <br>(train\|deploy) </p> | Top-1 (%) | Top-5 (%) | Config | Download |
| :-------------------------: | :----------: | :-------------------------------------: | :--------------------------------------: | :-------: | :-------: | :-----------------------------: | :-------------------------------: |
| repvgg-A0_8xb32_in1k | From scratch | 9.11 \| 8.31 | 1.53 \| 1.36 | 72.37 | 90.56 | [config](./repvgg-A0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_8xb32_in1k_20221213-60ae8e23.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_8xb32_in1k_20221213-60ae8e23.log) |
| repvgg-A1_8xb32_in1k | From scratch | 14.09 \| 12.79 | 2.65 \| 2.37 | 74.47 | 91.85 | [config](./repvgg-A1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_8xb32_in1k_20221213-f81bf3df.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_8xb32_in1k_20221213-f81bf3df.log) |
| repvgg-A2_8xb32_in1k | From scratch | 28.21 \| 25.5 | 5.72 \| 5.12 | 76.49 | 93.09 | [config](./repvgg-A2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_8xb32_in1k_20221213-a8767caf.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_8xb32_in1k_20221213-a8767caf.log) |
| repvgg-B0_8xb32_in1k | From scratch | 15.82 \| 14.34 | 3.43 \| 3.06 | 75.27 | 92.21 | [config](./repvgg-B0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_8xb32_in1k_20221213-5091ecc7.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_8xb32_in1k_20221213-5091ecc7.log) |
| repvgg-B1_8xb32_in1k | From scratch | 57.42 \| 51.83 | 13.20 \| 11.81 | 78.19 | 94.04 | [config](./repvgg-B1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_8xb32_in1k_20221213-d17c45e7.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_8xb32_in1k_20221213-d17c45e7.log) |
| repvgg-B1g2_8xb32_in1k | From scratch | 45.78 \| 41.36 | 9.86 \| 8.80 | 77.87 | 93.99 | [config](./repvgg-B1g2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_8xb32_in1k_20221213-ae6428fd.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_8xb32_in1k_20221213-ae6428fd.log) |
| repvgg-B1g4_8xb32_in1k | From scratch | 39.97 \| 36.13 | 8.19 \| 7.30 | 77.81 | 93.77 | [config](./repvgg-B1g4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_8xb32_in1k_20221213-a7a4aaea.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_8xb32_in1k_20221213-a7a4aaea.log) |
| repvgg-B2_8xb32_in1k | From scratch | 89.02 \| 80.32 | 20.5 \| 18.4 | 78.58 | 94.23 | [config](./repvgg-B2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_8xb32_in1k_20221213-d8b420ef.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_8xb32_in1k_20221213-d8b420ef.log) |
| repvgg-B2g4_8xb32_in1k | From scratch | 61.76 \| 55.78 | 12.7 \| 11.3 | 79.44 | 94.72 | [config](./repvgg-B2g4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_8xb32_in1k_20221213-0c1990eb.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_8xb32_in1k_20221213-0c1990eb.log) |
| repvgg-B3_8xb32_in1k | From scratch | 123.09 \| 110.96 | 29.2 \| 26.2 | 80.58 | 95.33 | [config](./repvgg-B3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_8xb32_in1k_20221213-927a329a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_8xb32_in1k_20221213-927a329a.log) |
| repvgg-B3g4_8xb32_in1k | From scratch | 83.83 \| 75.63 | 18.0 \| 16.1 | 80.26 | 95.15 | [config](./repvgg-B3g4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_8xb32_in1k_20221213-e01cb280.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_8xb32_in1k_20221213-e01cb280.log) |
| repvgg-D2se_3rdparty_in1k\* | From scratch | 133.33 \| 120.39 | 36.6 \| 32.8 | 81.81 | 95.94 | [config](./repvgg-D2se_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth) |
*Models with * are converted from the [official repo](https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L250). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
## Citation

```bibtex
@inproceedings{ding2021repvgg,
  title={Repvgg: Making vgg-style convnets great again},
  author={Ding, Xiaohan and Zhang, Xiangyu and Ma, Ningning and Han, Jungong and Ding, Guiguang and Sun, Jian},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={13733--13742},
  year={2021}
}
```

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-A0_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-A1_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-A2_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-B0_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-B1_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-B1g2_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-B1g4_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-B2_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-B2g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = '../repvgg-D2se_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -14,57 +14,48 @@ Collections:
    Version: v0.16.0

Models:
  - Name: repvgg-A0_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-A0_8xb32_in1k.py
    Metadata:
      FLOPs: 1360233728
      Parameters: 8309384
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 72.37
          Top 5 Accuracy: 90.56
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_8xb32_in1k_20221213-60ae8e23.pth
  - Name: repvgg-A1_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-A1_8xb32_in1k.py
    Metadata:
      FLOPs: 2362750208
      Parameters: 12789864
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 74.23
          Top 5 Accuracy: 91.80
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_8xb32_in1k_20221213-f81bf3df.pth
  - Name: repvgg-A2_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-A2_8xb32_in1k.py
    Metadata:
      FLOPs: 5115612544
      Parameters: 25499944
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 76.49
          Top 5 Accuracy: 93.09
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_8xb32_in1k_20221213-a8767caf.pth
  - Name: repvgg-B0_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B0_8xb32_in1k.py
    Metadata:
      FLOPs: 15820000000
      Parameters: 3420000
@ -72,130 +63,106 @@ Models:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 75.27
          Top 5 Accuracy: 92.21
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_8xb32_in1k_20221213-5091ecc7.pth
  - Name: repvgg-B1_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B1_8xb32_in1k.py
    Metadata:
      FLOPs: 11813537792
      Parameters: 51829480
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 78.19
          Top 5 Accuracy: 94.04
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_8xb32_in1k_20221213-d17c45e7.pth
  - Name: repvgg-B1g2_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B1g2_8xb32_in1k.py
    Metadata:
      FLOPs: 8807794688
      Parameters: 41360104
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 77.87
          Top 5 Accuracy: 93.99
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_8xb32_in1k_20221213-ae6428fd.pth
  - Name: repvgg-B1g4_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B1g4_8xb32_in1k.py
    Metadata:
      FLOPs: 7304923136
      Parameters: 36125416
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 77.81
          Top 5 Accuracy: 93.77
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_8xb32_in1k_20221213-a7a4aaea.pth
  - Name: repvgg-B2_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B2_8xb32_in1k.py
    Metadata:
      FLOPs: 18374175232
      Parameters: 80315112
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 78.58
          Top 5 Accuracy: 94.23
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_8xb32_in1k_20221213-d8b420ef.pth
  - Name: repvgg-B2g4_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B2g4_8xb32_in1k.py
    Metadata:
      FLOPs: 11329464832
      Parameters: 55777512
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 79.44
          Top 5 Accuracy: 94.72
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_8xb32_in1k_20221213-0c1990eb.pth
  - Name: repvgg-B3_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B3_8xb32_in1k.py
    Metadata:
      FLOPs: 26206448128
      Parameters: 110960872
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 80.58
          Top 5 Accuracy: 95.33
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_8xb32_in1k_20221213-927a329a.pth
  - Name: repvgg-B3g4_8xb32_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B3g4_8xb32_in1k.py
    Metadata:
      FLOPs: 16062065152
      Parameters: 75626728
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 80.26
          Top 5 Accuracy: 95.15
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_8xb32_in1k_20221213-e01cb280.pth
  - Name: repvgg-D2se_3rdparty_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-D2se_8xb32_in1k.py
    Metadata:
      FLOPs: 32838581760
      Parameters: 120387572
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification

View File

@ -1,12 +0,0 @@
_base_ = [
'../_base_/models/repvgg-A0_in1k.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]
# schedule settings
param_scheduler = dict(
type='CosineAnnealingLR', T_max=120, by_epoch=True, begin=0, end=120)
train_cfg = dict(by_epoch=True, max_epochs=120)

View File

@ -0,0 +1,33 @@
_base_ = [
'../_base_/models/repvgg-A0_in1k.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]
val_dataloader = dict(batch_size=256)
test_dataloader = dict(batch_size=256)
# schedule settings
optim_wrapper = dict(
paramwise_cfg=dict(
bias_decay_mult=0.0,
custom_keys={
'branch_3x3.norm': dict(decay_mult=0.0),
'branch_1x1.norm': dict(decay_mult=0.0),
'branch_norm.bias': dict(decay_mult=0.0),
}))
# schedule settings
param_scheduler = dict(
type='CosineAnnealingLR',
T_max=120,
by_epoch=True,
begin=0,
end=120,
convert_to_iter_based=True)
train_cfg = dict(by_epoch=True, max_epochs=120)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
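The `paramwise_cfg` above turns weight decay off for every bias (`bias_decay_mult=0.0`) and for the norm layers inside the re-parameterizable branches (`decay_mult=0.0`). In plain PyTorch, the effect corresponds to building two optimizer parameter groups, roughly as in this hedged sketch (not the mmengine implementation; the `lr`/`momentum` values are placeholders):

```python
import torch

def make_optimizer(model: torch.nn.Module, lr=0.1, momentum=0.9, weight_decay=1e-4):
    """Mimic bias_decay_mult=0.0 and the custom_keys entries with decay_mult=0.0."""
    no_decay_keys = ('branch_3x3.norm', 'branch_1x1.norm', 'branch_norm.bias')
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.endswith('.bias') or any(k in name for k in no_decay_keys):
            no_decay.append(param)     # decay multiplier 0.0, i.e. no weight decay
        else:
            decay.append(param)
    return torch.optim.SGD(
        [{'params': decay, 'weight_decay': weight_decay},
         {'params': no_decay, 'weight_decay': 0.0}],
        lr=lr, momentum=momentum)
```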

View File

@ -0,0 +1,3 @@
_base_ = '../repvgg-A0_8xb32_in1k.py'
model = dict(backbone=dict(deploy=True))

View File

@ -1,3 +0,0 @@
_base_ = './repvgg-A0_4xb64-coslr-120e_in1k.py'
model = dict(backbone=dict(arch='A1'))

View File

@ -0,0 +1,3 @@
_base_ = './repvgg-A0_8xb32_in1k.py'
model = dict(backbone=dict(arch='A1'))

View File

@ -1,3 +1,3 @@
_base_ = './repvgg-A0_8xb32_in1k.py'
model = dict(backbone=dict(arch='A2'), head=dict(in_channels=1408))

View File

@ -1,3 +1,3 @@
_base_ = './repvgg-A0_8xb32_in1k.py'
model = dict(backbone=dict(arch='B0'), head=dict(in_channels=1280))

View File

@ -1,3 +1,3 @@
_base_ = './repvgg-A0_8xb32_in1k.py'
model = dict(backbone=dict(arch='B1'), head=dict(in_channels=2048))

View File

@ -1,3 +1,3 @@
_base_ = './repvgg-A0_8xb32_in1k.py'
model = dict(backbone=dict(arch='B1g2'), head=dict(in_channels=2048))

View File

@ -1,3 +1,3 @@
_base_ = './repvgg-A0_8xb32_in1k.py'
model = dict(backbone=dict(arch='B1g4'), head=dict(in_channels=2048))

View File

@ -1,3 +1,3 @@
_base_ = './repvgg-A0_8xb32_in1k.py'
model = dict(backbone=dict(arch='B2'), head=dict(in_channels=2560))

View File

@ -1,3 +0,0 @@
_base_ = './repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'
model = dict(backbone=dict(arch='B2g4'))

View File

@ -0,0 +1,3 @@
_base_ = './repvgg-B3_8xb32_in1k.py'
model = dict(backbone=dict(arch='B2g4'), head=dict(in_channels=2560))

View File

@ -1,10 +1,20 @@
_base_ = [
    '../_base_/models/repvgg-B3_lbs-mixup_in1k.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(
    paramwise_cfg=dict(
        bias_decay_mult=0.0,
        custom_keys={
            'branch_3x3.norm': dict(decay_mult=0.0),
            'branch_1x1.norm': dict(decay_mult=0.0),
            'branch_norm.bias': dict(decay_mult=0.0),
        }))
data_preprocessor = dict(
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
@ -21,8 +31,12 @@ train_pipeline = [
    dict(type='RandomResizedCrop', scale=224, backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=7,
        magnitude_std=0.5,
        hparams=dict(pad_val=[round(x) for x in bgr_mean])),
    dict(type='PackClsInputs'),
]
@ -37,3 +51,17 @@ test_pipeline = [
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
# schedule settings
param_scheduler = dict(
type='CosineAnnealingLR',
T_max=200,
by_epoch=True,
begin=0,
end=200,
convert_to_iter_based=True)
train_cfg = dict(by_epoch=True, max_epochs=200)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
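Compared with the previous version of this config, `AutoAugment` with the `imagenet` policy is replaced by `RandAugment` with the `timm_increasing` policy set: `num_policies=2` transforms are applied per image, and each transform's strength is drawn around `magnitude_level=7` out of `total_level=10` with jitter `magnitude_std=0.5`. A rough sketch of that sampling rule (my reading of the timm-style convention, not the mmcls source):

```python
import random

def sample_strength(magnitude_level=7, magnitude_std=0.5, total_level=10):
    """Draw one transform's strength in [0, 1], jittered around level/total."""
    level = random.gauss(magnitude_level, magnitude_std)
    level = min(max(level, 0.0), total_level)   # clip to the valid level range
    return level / total_level                  # fraction of the op's max strength

# two ops are applied per image, matching num_policies=2
strengths = [sample_strength() for _ in range(2)]
```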

View File

@ -1,3 +0,0 @@
_base_ = './repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'
model = dict(backbone=dict(arch='B3g4'))

View File

@ -0,0 +1,3 @@
_base_ = './repvgg-B3_8xb32_in1k.py'
model = dict(backbone=dict(arch='B3g4'))

View File

@ -1,3 +0,0 @@
_base_ = './repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'
model = dict(backbone=dict(arch='D2se'))

View File

@ -0,0 +1,28 @@
_base_ = './repvgg-B3_8xb32_in1k.py'
model = dict(backbone=dict(arch='D2se'), head=dict(in_channels=2560))
param_scheduler = [
# warm up learning rate scheduler
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=5,
# update by iter
convert_to_iter_based=True),
# main learning rate scheduler
dict(
type='CosineAnnealingLR',
T_max=295,
eta_min=1.0e-6,
by_epoch=True,
begin=5,
end=300)
]
train_cfg = dict(by_epoch=True, max_epochs=300)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
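Putting the two schedulers together: the learning rate ramps linearly from `0.0001 x base_lr` up to `base_lr` over the first 5 epochs (updated per iteration), then follows a cosine decay down to `eta_min` from epoch 5 to 300. A sketch of the resulting curve at epoch granularity (`base_lr=0.1` is an assumption inherited from the base schedule, not stated in this file):

```python
import math

def lr_at(epoch, base_lr=0.1, warmup_epochs=5, t_max=295, eta_min=1e-6,
          start_factor=1e-4):
    if epoch < warmup_epochs:          # LinearLR: factor goes start_factor -> 1
        factor = start_factor + (1 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    t = epoch - warmup_epochs          # CosineAnnealingLR over the remaining epochs
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))

for e in (0, 4, 5, 150, 299):          # warm-up start/end, peak, midpoint, tail
    print(e, f'{lr_at(e):.6f}')
```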