mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Add RTMDet distillation cfg (#544)
* add rtm distillation cfg
* rename the cfg file
* use norm connector
* fix cfg
* fix cfg
* support rtm distillation
* fix readme and cfgs
* fix readme
* add docstring
* add links of ckpts and logs
* Update configs/rtmdet/README.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* fix cfgs
* rename stop distillation hook
* rename stop_epoch
* fix cfg
* add model converter
* add metafile and tta results
* fix metafile
* fix readme
* mv distillation/metafile to metafile

---------

Co-authored-by: RangiLyu <lyuchqi@gmail.com>
parent d06de6d36f
commit 6f38b781bd
configs/rtmdet/README.md

@@ -23,19 +23,24 @@ RTMDet-l model structure
## Object Detection

| Model | size | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | box AP | TTA box AP | Config | Download |
| :------------: | :--: | :-------: | :------: | :------------------: | :---------: | :---------: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| RTMDet-tiny | 640 | 4.8 | 8.1 | 0.98 | 41.0 | 42.7 | [config](./rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117.log.json) |
| RTMDet-tiny \* | 640 | 4.8 | 8.1 | 0.98 | 41.8 (+0.8) | 43.2 (+0.5) | [config](./distillation/kd_tiny_rtmdet_s_neck_300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_tiny_rtmdet_s_neck_300e_coco/kd_tiny_rtmdet_s_neck_300e_coco_20230213_104240-e1e4197c.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_tiny_rtmdet_s_neck_300e_coco/kd_tiny_rtmdet_s_neck_300e_coco_20230213_104240-176901d8.json) |
| RTMDet-s | 640 | 8.89 | 14.8 | 1.22 | 44.6 | 45.8 | [config](./rtmdet_s_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329.log.json) |
| RTMDet-s \* | 640 | 8.89 | 14.8 | 1.22 | 45.7 (+1.1) | 47.3 (+1.5) | [config](./distillation/kd_s_rtmdet_m_neck_300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_s_rtmdet_m_neck_300e_coco/kd_s_rtmdet_m_neck_300e_coco_20230220_140647-446ff003.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_s_rtmdet_m_neck_300e_coco/kd_s_rtmdet_m_neck_300e_coco_20230220_140647-89862269.json) |
| RTMDet-m | 640 | 24.71 | 39.27 | 1.62 | 49.3 | 50.9 | [config](./rtmdet_m_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952.log.json) |
| RTMDet-m \* | 640 | 24.71 | 39.27 | 1.62 | 50.2 (+0.9) | 51.9 (+1.0) | [config](./distillation/kd_m_rtmdet_l_neck_300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_m_rtmdet_l_neck_300e_coco/kd_m_rtmdet_l_neck_300e_coco_20230220_141313-b806f503.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_m_rtmdet_l_neck_300e_coco/kd_m_rtmdet_l_neck_300e_coco_20230220_141313-bd028fd3.json) |
| RTMDet-l | 640 | 52.3 | 80.23 | 2.44 | 51.4 | 53.1 | [config](./rtmdet_l_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928.log.json) |
| RTMDet-l \* | 640 | 52.3 | 80.23 | 2.44 | 52.3 (+0.9) | 53.7 (+0.6) | [config](./distillation/kd_l_rtmdet_x_neck_300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_l_rtmdet_x_neck_300e_coco/kd_l_rtmdet_x_neck_300e_coco_20230220_141912-c9979722.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_l_rtmdet_x_neck_300e_coco/kd_l_rtmdet_x_neck_300e_coco_20230220_141912-c5c4e17b.json) |
| RTMDet-x | 640 | 94.86 | 141.67 | 3.10 | 52.8 | 54.2 | [config](./rtmdet_x_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345-b85cd476.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345.log.json) |

**Note**:

1. The inference speed of RTMDet is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS.
2. For a fair comparison, the config of bbox postprocessing is changed to be consistent with YOLOv5/6/7 after [PR#9494](https://github.com/open-mmlab/mmdetection/pull/9494), bringing about 0.1~0.3% AP improvement.
3. `TTA` means Test Time Augmentation. It performs 3 multi-scale transformations on the image, followed by 2 flipping transformations (flipped and not flipped). You only need to specify `--tta` when testing to enable it (a usage example follows this list); see [TTA](https://github.com/open-mmlab/mmyolo/blob/dev/docs/en/common_usage/tta.md) for details.
4. \* means checkpoints are trained with knowledge distillation. More details can be found in [RTMDet distillation](./distillation).
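
For example, a single-GPU TTA evaluation could look like the following; this is a usage sketch that assumes the RTMDet-l checkpoint from the table above has been downloaded to the working directory:

```bash
# Test with test-time augmentation enabled (--tta)
python tools/test.py configs/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py \
    rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth --tta
```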

## Citation
configs/rtmdet/distillation/README.md

@@ -0,0 +1,146 @@
# Distill RTM Detectors Based on MMRazor

## Description

To further improve the model accuracy while not introducing much additional computation cost, we apply feature-based distillation to the training phase of these RTM detectors. In summary, our distillation strategy is threefold:

(1) Inspired by [PKD](https://arxiv.org/abs/2207.02039), we first normalize the intermediate feature maps to have zero mean and unit variance before calculating the distillation loss.

(2) Inspired by [CWD](https://arxiv.org/abs/2011.13256), we adopt the channel-wise distillation paradigm, which pays more attention to the most salient regions of each channel.

(3) Inspired by [DAMO-YOLO](https://arxiv.org/abs/2211.15444), the distillation process is split into two stages: 1) the teacher distills the student during the first stage (280 epochs) on the strong mosaic domain; 2) the student fine-tunes itself on the non-mosaic domain during the second stage (20 epochs).
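
Below is a minimal, self-contained sketch of what strategies (1) and (2) amount to numerically. It is only an illustration of the idea with made-up feature shapes, not the actual `ChannelWiseDivergence` implementation used in the configs:

```python
import torch
import torch.nn.functional as F


def channel_wise_distill_loss(feat_s: torch.Tensor,
                              feat_t: torch.Tensor,
                              tau: float = 1.0) -> torch.Tensor:
    """Toy channel-wise distillation loss on normalized feature maps."""

    # (1) PKD-inspired step: normalize each feature map to zero mean and
    # unit variance over its spatial dimensions.
    def normalize(x):
        mean = x.mean(dim=(2, 3), keepdim=True)
        std = x.std(dim=(2, 3), keepdim=True) + 1e-6
        return (x - mean) / std

    feat_s, feat_t = normalize(feat_s), normalize(feat_t)

    # (2) CWD-inspired step: treat each channel as a distribution over
    # spatial locations and match the student to the teacher with a KL term.
    n, c, h, w = feat_s.shape
    log_p_s = F.log_softmax(feat_s.view(n, c, -1) / tau, dim=-1)
    p_t = F.softmax(feat_t.view(n, c, -1) / tau, dim=-1)
    kl = (p_t * (torch.log(p_t + 1e-12) - log_p_s)).sum(dim=-1)  # (n, c)
    return kl.mean()


# Fake FPN outputs of identical shape; in the configs below a connector first
# maps the student feature to the teacher's channel dimension.
feat_s = torch.randn(2, 128, 40, 40)
feat_t = torch.randn(2, 128, 40, 40)
print(channel_wise_distill_loss(feat_s, feat_t))
```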

## Results and Models

| Location | Dataset | Teacher | Student | mAP | mAP(T) | mAP(S) | Config | Download |
| :------: | :-----: | :---------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------: | :---------: | :----: | :----: | :------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| FPN | COCO | [RTMDet-s](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py) | [RTMDet-tiny](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py) | 41.8 (+0.8) | 44.6 | 41.0 | [config](kd_tiny_rtmdet_s_neck_300e_coco.py) | [teacher](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_tiny_rtmdet_s_neck_300e_coco/kd_tiny_rtmdet_s_neck_300e_coco_20230213_104240-e1e4197c.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_tiny_rtmdet_s_neck_300e_coco/kd_tiny_rtmdet_s_neck_300e_coco_20230213_104240-176901d8.json) |
| FPN | COCO | [RTMDet-m](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py) | [RTMDet-s](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py) | 45.7 (+1.1) | 49.3 | 44.6 | [config](kd_s_rtmdet_m_neck_300e_coco.py) | [teacher](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_s_rtmdet_m_neck_300e_coco/kd_s_rtmdet_m_neck_300e_coco_20230220_140647-446ff003.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_s_rtmdet_m_neck_300e_coco/kd_s_rtmdet_m_neck_300e_coco_20230220_140647-89862269.json) |
| FPN | COCO | [RTMDet-l](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py) | [RTMDet-m](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py) | 50.2 (+0.9) | 51.4 | 49.3 | [config](kd_m_rtmdet_l_neck_300e_coco.py) | [teacher](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_m_rtmdet_l_neck_300e_coco/kd_m_rtmdet_l_neck_300e_coco_20230220_141313-b806f503.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_m_rtmdet_l_neck_300e_coco/kd_m_rtmdet_l_neck_300e_coco_20230220_141313-bd028fd3.json) |
| FPN | COCO | [RTMDet-x](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco.py) | [RTMDet-l](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py) | 52.3 (+0.9) | 52.8 | 51.4 | [config](kd_l_rtmdet_x_neck_300e_coco.py) | [teacher](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345-b85cd476.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_l_rtmdet_x_neck_300e_coco/kd_l_rtmdet_x_neck_300e_coco_20230220_141912-c9979722.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_l_rtmdet_x_neck_300e_coco/kd_l_rtmdet_x_neck_300e_coco_20230220_141912-c5c4e17b.json) |

## Usage

### Prerequisites

- [MMRazor dev-1.x](https://github.com/open-mmlab/mmrazor/tree/dev-1.x)

Install MMRazor from source:

```bash
git clone -b dev-1.x https://github.com/open-mmlab/mmrazor.git
cd mmrazor
# Install MMRazor
mim install -v -e .
```

### Training commands

In MMYOLO's root directory, run the following command to train RTMDet-tiny with 8 GPUs, using RTMDet-s as the teacher:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py 8
```

### Testing commands

In MMYOLO's root directory, run the following command to test the model with 1 GPU:

```bash
CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_test.sh configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py ${CHECKPOINT_PATH} 1
```

### Getting student-only checkpoint

After training, the checkpoint contains parameters for both the student and the teacher models. Run the following command to convert it to a student-only checkpoint:

```bash
python ./tools/model_converters/convert_kd_ckpt_to_student.py ${CHECKPOINT_PATH} --out-path ${OUTPUT_CHECKPOINT_PATH}
```

## Configs

Here we provide detection configs and models for MMRazor in MMYOLO. For clarity, we take `./kd_tiny_rtmdet_s_neck_300e_coco.py` as an example to show how to distill an RTM detector based on MMRazor.

Here is the main part of `./kd_tiny_rtmdet_s_neck_300e_coco.py`:
```python
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

distiller=dict(
    type='ConfigurableDistiller',
    student_recorders=dict(
        fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
        fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
        fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv'),
    ),
    teacher_recorders=dict(
        fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
        fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
        fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
    connectors=dict(
        fpn0_s=dict(type='ConvModuleConnector', in_channel=96,
                    out_channel=128, bias=False, norm_cfg=norm_cfg,
                    act_cfg=None),
        fpn0_t=dict(type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
        fpn1_s=dict(type='ConvModuleConnector', in_channel=96,
                    out_channel=128, bias=False, norm_cfg=norm_cfg,
                    act_cfg=None),
        fpn1_t=dict(type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
        fpn2_s=dict(type='ConvModuleConnector', in_channel=96,
                    out_channel=128, bias=False, norm_cfg=norm_cfg,
                    act_cfg=None),
        fpn2_t=dict(type='NormConnector', in_channels=128, norm_cfg=norm_cfg)),
    distill_losses=dict(
        loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
        loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
        loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
    loss_forward_mappings=dict(
        loss_fpn0=dict(
            preds_S=dict(from_student=True, recorder='fpn0', connector='fpn0_s'),
            preds_T=dict(from_student=False, recorder='fpn0', connector='fpn0_t')),
        loss_fpn1=dict(
            preds_S=dict(from_student=True, recorder='fpn1', connector='fpn1_s'),
            preds_T=dict(from_student=False, recorder='fpn1', connector='fpn1_t')),
        loss_fpn2=dict(
            preds_S=dict(from_student=True, recorder='fpn2', connector='fpn2_s'),
            preds_T=dict(from_student=False, recorder='fpn2', connector='fpn2_t'))))
```

`recorders` are used to record various intermediate results during the model forward. In this example, they record the outputs of 3 `nn.Module`s of the teacher and the student. Details are given in [Recorder](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/docs/en/advanced_guides/recorder.md) and [MMRazor Distillation](https://zhuanlan.zhihu.com/p/596582609) (in Chinese).

`connectors` are adaptive layers which usually map the teacher's and student's features to the same dimension.
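
As a rough illustration of what the `fpn0_s` / `fpn0_t` pair above amounts to (a sketch under the assumption that the student connector is essentially a 1x1 convolution followed by the non-affine norm, not MMRazor's actual `ConvModuleConnector` code):

```python
import torch
import torch.nn as nn

# Sketch of the `fpn0_s` connector: a 1x1 conv projects the student's
# 96-channel FPN output to the teacher's 128 channels (no bias), then a
# BatchNorm with affine=False / track_running_stats=False normalizes the
# projected feature to zero mean and unit variance.
student_connector = nn.Sequential(
    nn.Conv2d(96, 128, kernel_size=1, bias=False),
    nn.BatchNorm2d(128, affine=False, track_running_stats=False),
)

x = torch.randn(2, 96, 80, 80)      # a fake student FPN feature map
print(student_connector(x).shape)   # -> torch.Size([2, 128, 80, 80])
```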

`distill_losses` are configs for the distillation losses, one per FPN level.

`loss_forward_mappings` are mappings between the forward arguments of each distillation loss and the recorded features.
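
Conceptually, one entry of `loss_forward_mappings` wires a recorded student feature and a recorded teacher feature through their connectors into one distillation loss. The toy snippet below uses made-up tensors and placeholder modules (`nn.Conv2d`, `nn.Identity` and `nn.MSELoss` standing in for `ConvModuleConnector`, `NormConnector` and `ChannelWiseDivergence`); it only illustrates the data flow described by the `loss_fpn0` mapping, not MMRazor's internals:

```python
import torch
import torch.nn as nn

# Made-up recorded FPN outputs (student: 96 channels, teacher: 128 channels).
recorded = {
    ('student', 'fpn0'): torch.randn(2, 96, 80, 80),
    ('teacher', 'fpn0'): torch.randn(2, 128, 80, 80),
}
# Placeholder connectors and loss.
connectors = {
    'fpn0_s': nn.Conv2d(96, 128, kernel_size=1, bias=False),
    'fpn0_t': nn.Identity(),
}
loss_fpn0 = nn.MSELoss()

# loss_fpn0: preds_S comes from the student recorder via `fpn0_s`,
# preds_T from the teacher recorder via `fpn0_t`.
preds_S = connectors['fpn0_s'](recorded[('student', 'fpn0')])
preds_T = connectors['fpn0_t'](recorded[('teacher', 'fpn0')])
print(loss_fpn0(preds_S, preds_T))
```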

In addition, the student fine-tunes itself on the non-mosaic domain during the last 20 epochs, so we add a new hook named `StopDistillHook` to stop distillation on time. We need to add this hook to the `custom_hooks` list:

```python
custom_hooks = [..., dict(type='mmrazor.StopDistillHook', stop_epoch=280)]
```
configs/rtmdet/distillation/kd_l_rtmdet_x_neck_300e_coco.py

@@ -0,0 +1,99 @@
_base_ = '../rtmdet_l_syncbn_fast_8xb32-300e_coco.py'

teacher_ckpt = 'https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345-b85cd476.pth'  # noqa: E501

norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='FpnTeacherDistill',
    architecture=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py'),
    teacher=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco.py'),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # `recorders` are used to record various intermediate results during
        # the model forward.
        student_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv'),
        ),
        teacher_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
        # `connectors` are adaptive layers which usually map teacher's and
        # students features to the same dimension.
        connectors=dict(
            fpn0_s=dict(
                type='ConvModuleConnector',
                in_channel=256,
                out_channel=320,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn0_t=dict(
                type='NormConnector', in_channels=320, norm_cfg=norm_cfg),
            fpn1_s=dict(
                type='ConvModuleConnector',
                in_channel=256,
                out_channel=320,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn1_t=dict(
                type='NormConnector', in_channels=320, norm_cfg=norm_cfg),
            fpn2_s=dict(
                type='ConvModuleConnector',
                in_channel=256,
                out_channel=320,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn2_t=dict(
                type='NormConnector', in_channels=320, norm_cfg=norm_cfg)),
        distill_losses=dict(
            loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
        # `loss_forward_mappings` are mappings between distill loss forward
        # arguments and records.
        loss_forward_mappings=dict(
            loss_fpn0=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn0', connector='fpn0_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn0', connector='fpn0_t')),
            loss_fpn1=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn1', connector='fpn1_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn1', connector='fpn1_t')),
            loss_fpn2=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn2', connector='fpn2_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn2',
                    connector='fpn2_t')))))

find_unused_parameters = True

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=_base_.max_epochs - _base_.num_epochs_stage2,
        switch_pipeline=_base_.train_pipeline_stage2),
    # stop distillation after the 280th epoch
    dict(type='mmrazor.StopDistillHook', stop_epoch=280)
]
configs/rtmdet/distillation/kd_m_rtmdet_l_neck_300e_coco.py

@@ -0,0 +1,99 @@
_base_ = '../rtmdet_m_syncbn_fast_8xb32-300e_coco.py'

teacher_ckpt = 'https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth'  # noqa: E501

norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='FpnTeacherDistill',
    architecture=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py'),
    teacher=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py'),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # `recorders` are used to record various intermediate results during
        # the model forward.
        student_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv'),
        ),
        teacher_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
        # `connectors` are adaptive layers which usually map teacher's and
        # students features to the same dimension.
        connectors=dict(
            fpn0_s=dict(
                type='ConvModuleConnector',
                in_channel=192,
                out_channel=256,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn0_t=dict(
                type='NormConnector', in_channels=256, norm_cfg=norm_cfg),
            fpn1_s=dict(
                type='ConvModuleConnector',
                in_channel=192,
                out_channel=256,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn1_t=dict(
                type='NormConnector', in_channels=256, norm_cfg=norm_cfg),
            fpn2_s=dict(
                type='ConvModuleConnector',
                in_channel=192,
                out_channel=256,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn2_t=dict(
                type='NormConnector', in_channels=256, norm_cfg=norm_cfg)),
        distill_losses=dict(
            loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
        # `loss_forward_mappings` are mappings between distill loss forward
        # arguments and records.
        loss_forward_mappings=dict(
            loss_fpn0=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn0', connector='fpn0_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn0', connector='fpn0_t')),
            loss_fpn1=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn1', connector='fpn1_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn1', connector='fpn1_t')),
            loss_fpn2=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn2', connector='fpn2_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn2',
                    connector='fpn2_t')))))

find_unused_parameters = True

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=_base_.max_epochs - _base_.num_epochs_stage2,
        switch_pipeline=_base_.train_pipeline_stage2),
    # stop distillation after the 280th epoch
    dict(type='mmrazor.StopDistillHook', stop_epoch=280)
]
configs/rtmdet/distillation/kd_s_rtmdet_m_neck_300e_coco.py

@@ -0,0 +1,99 @@
_base_ = '../rtmdet_s_syncbn_fast_8xb32-300e_coco.py'

teacher_ckpt = 'https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth'  # noqa: E501

norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='FpnTeacherDistill',
    architecture=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py'),
    teacher=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py'),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # `recorders` are used to record various intermediate results during
        # the model forward.
        student_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv'),
        ),
        teacher_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
        # `connectors` are adaptive layers which usually map teacher's and
        # students features to the same dimension.
        connectors=dict(
            fpn0_s=dict(
                type='ConvModuleConnector',
                in_channel=128,
                out_channel=192,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn0_t=dict(
                type='NormConnector', in_channels=192, norm_cfg=norm_cfg),
            fpn1_s=dict(
                type='ConvModuleConnector',
                in_channel=128,
                out_channel=192,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn1_t=dict(
                type='NormConnector', in_channels=192, norm_cfg=norm_cfg),
            fpn2_s=dict(
                type='ConvModuleConnector',
                in_channel=128,
                out_channel=192,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn2_t=dict(
                type='NormConnector', in_channels=192, norm_cfg=norm_cfg)),
        distill_losses=dict(
            loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
        # `loss_forward_mappings` are mappings between distill loss forward
        # arguments and records.
        loss_forward_mappings=dict(
            loss_fpn0=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn0', connector='fpn0_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn0', connector='fpn0_t')),
            loss_fpn1=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn1', connector='fpn1_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn1', connector='fpn1_t')),
            loss_fpn2=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn2', connector='fpn2_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn2',
                    connector='fpn2_t')))))

find_unused_parameters = True

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=_base_.max_epochs - _base_.num_epochs_stage2,
        switch_pipeline=_base_.train_pipeline_stage2),
    # stop distillation after the 280th epoch
    dict(type='mmrazor.StopDistillHook', stop_epoch=280)
]
configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py

@@ -0,0 +1,99 @@
_base_ = '../rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py'

teacher_ckpt = 'https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth'  # noqa: E501

norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='FpnTeacherDistill',
    architecture=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py'),
    teacher=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py'),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # `recorders` are used to record various intermediate results during
        # the model forward.
        student_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv'),
        ),
        teacher_recorders=dict(
            fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
            fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
            fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
        # `connectors` are adaptive layers which usually map teacher's and
        # students features to the same dimension.
        connectors=dict(
            fpn0_s=dict(
                type='ConvModuleConnector',
                in_channel=96,
                out_channel=128,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn0_t=dict(
                type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
            fpn1_s=dict(
                type='ConvModuleConnector',
                in_channel=96,
                out_channel=128,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn1_t=dict(
                type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
            fpn2_s=dict(
                type='ConvModuleConnector',
                in_channel=96,
                out_channel=128,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=None),
            fpn2_t=dict(
                type='NormConnector', in_channels=128, norm_cfg=norm_cfg)),
        distill_losses=dict(
            loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
            loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
        # `loss_forward_mappings` are mappings between distill loss forward
        # arguments and records.
        loss_forward_mappings=dict(
            loss_fpn0=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn0', connector='fpn0_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn0', connector='fpn0_t')),
            loss_fpn1=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn1', connector='fpn1_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn1', connector='fpn1_t')),
            loss_fpn2=dict(
                preds_S=dict(
                    from_student=True, recorder='fpn2', connector='fpn2_s'),
                preds_T=dict(
                    from_student=False, recorder='fpn2',
                    connector='fpn2_t')))))

find_unused_parameters = True

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=_base_.max_epochs - _base_.num_epochs_stage2,
        switch_pipeline=_base_.train_pipeline_stage2),
    # stop distillation after the 280th epoch
    dict(type='mmrazor.StopDistillHook', stop_epoch=280)
]
configs/rtmdet/metafile.yml

@@ -28,6 +28,19 @@ Models:
          box AP: 41.0
    Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth

  - Name: kd_tiny_rtmdet_s_neck_300e_coco
    In Collection: RTMDet
    Config: configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py
    Metadata:
      Training Memory (GB): 11.9
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 41.8
    Weights: https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_tiny_rtmdet_s_neck_300e_coco/kd_tiny_rtmdet_s_neck_300e_coco_20230213_104240-e1e4197c.pth

  - Name: rtmdet_s_syncbn_fast_8xb32-300e_coco
    In Collection: RTMDet
    Config: configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py
@@ -41,6 +54,19 @@ Models:
          box AP: 44.6
    Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth

  - Name: kd_s_rtmdet_m_neck_300e_coco
    In Collection: RTMDet
    Config: configs/rtmdet/distillation/kd_s_rtmdet_m_neck_300e_coco.py
    Metadata:
      Training Memory (GB): 16.3
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 45.7
    Weights: https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_s_rtmdet_m_neck_300e_coco/kd_s_rtmdet_m_neck_300e_coco_20230220_140647-446ff003.pth

  - Name: rtmdet_m_syncbn_fast_8xb32-300e_coco
    In Collection: RTMDet
    Config: configs/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py
@@ -54,6 +80,19 @@ Models:
          box AP: 49.3
    Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth

  - Name: kd_m_rtmdet_l_neck_300e_coco
    In Collection: RTMDet
    Config: configs/rtmdet/distillation/kd_m_rtmdet_l_neck_300e_coco.py
    Metadata:
      Training Memory (GB): 29.0
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 50.2
    Weights: https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_m_rtmdet_l_neck_300e_coco/kd_m_rtmdet_l_neck_300e_coco_20230220_141313-b806f503.pth

  - Name: rtmdet_l_syncbn_fast_8xb32-300e_coco
    In Collection: RTMDet
    Config: configs/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py
@@ -67,6 +106,19 @@ Models:
          box AP: 51.4
    Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth

  - Name: kd_l_rtmdet_x_neck_300e_coco
    In Collection: RTMDet
    Config: configs/rtmdet/distillation/kd_l_rtmdet_x_neck_300e_coco.py
    Metadata:
      Training Memory (GB): 45.2
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 52.3
    Weights: https://download.openmmlab.com/mmrazor/v1/rtmdet_distillation/kd_l_rtmdet_x_neck_300e_coco/kd_l_rtmdet_x_neck_300e_coco_20230220_141912-c9979722.pth

  - Name: rtmdet_x_syncbn_fast_8xb32-300e_coco
    In Collection: RTMDet
    Config: configs/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco.py
tools/model_converters/convert_kd_ckpt_to_student.py

@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path

import torch


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert KD checkpoint to student-only checkpoint')
    parser.add_argument('checkpoint', help='input checkpoint filename')
    parser.add_argument('--out-path', help='save checkpoint path')
    parser.add_argument(
        '--inplace', action='store_true', help='replace origin ckpt')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    new_state_dict = dict()
    new_meta = checkpoint['meta']

    # Keep only the student weights, which are stored under the
    # `architecture.` prefix in the distillation checkpoint.
    for key, value in checkpoint['state_dict'].items():
        if key.startswith('architecture.'):
            new_key = key.replace('architecture.', '')
            new_state_dict[new_key] = value

    checkpoint = dict()
    checkpoint['meta'] = new_meta
    checkpoint['state_dict'] = new_state_dict

    if args.inplace:
        torch.save(checkpoint, args.checkpoint)
    else:
        ckpt_path = Path(args.checkpoint)
        ckpt_name = ckpt_path.stem
        if args.out_path:
            ckpt_dir = Path(args.out_path)
        else:
            ckpt_dir = ckpt_path.parent
        new_ckpt_path = ckpt_dir / f'{ckpt_name}_student.pth'
        torch.save(checkpoint, new_ckpt_path)


if __name__ == '__main__':
    main()