# Distill RTM Detectors Based on MMRazor

## Description

To further improve model accuracy without introducing much additional computation cost, we apply feature-based distillation during the training phase of these RTM detectors. In summary, our distillation strategy is threefold:

(1) Inspired by PKD, we first normalize the intermediate feature maps to have zero mean and unit variance before calculating the distillation loss.

(2) Inspired by CWD, we adopt the channel-wise distillation paradigm, which pays more attention to the most salient regions of each channel.

(3) Inspired by DAMO-YOLO, the distillation process is split into two stages. 1) The teacher distills the student during the first stage (280 epochs) in the strong mosaic domain. 2) The student fine-tunes itself in the no-mosaic domain during the second stage (20 epochs).
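
To make steps (1) and (2) concrete, here is a minimal PyTorch sketch of the feature normalization and the channel-wise divergence loss. The function names are illustrative only; this is not MMRazor's actual implementation.

```python
import torch
import torch.nn.functional as F


def normalize_feature(feat: torch.Tensor) -> torch.Tensor:
    """PKD-style: normalize each feature map to zero mean and unit variance."""
    n, c, h, w = feat.shape
    feat = feat.reshape(n, c, -1)
    mean = feat.mean(dim=-1, keepdim=True)
    std = feat.std(dim=-1, keepdim=True)
    return ((feat - mean) / (std + 1e-6)).reshape(n, c, h, w)


def channel_wise_divergence(preds_S: torch.Tensor, preds_T: torch.Tensor,
                            tau: float = 1.0, loss_weight: float = 1.0):
    """CWD-style: KL divergence between per-channel spatial distributions."""
    n, c, _, _ = preds_S.shape
    # Each channel's spatial map is turned into a probability distribution.
    softmax_T = F.softmax(preds_T.reshape(n * c, -1) / tau, dim=-1)
    log_S = F.log_softmax(preds_S.reshape(n * c, -1) / tau, dim=-1)
    log_T = F.log_softmax(preds_T.reshape(n * c, -1) / tau, dim=-1)
    loss = (softmax_T * (log_T - log_S)).sum() * tau**2 / (n * c)
    return loss_weight * loss
```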

## Results and Models

| Location | Dataset | Teacher  |   Student   |     mAP     | mAP(T) | mAP(S) | Config |         Download         |
| :------: | :-----: | :------: | :---------: | :---------: | :----: | :----: | :----: | :----------------------: |
|   FPN    |  COCO   | RTMDet-s | RTMDet-tiny | 41.8 (+0.8) |  44.6  |  41.0  | config | teacher \| model \| log  |
|   FPN    |  COCO   | RTMDet-m |  RTMDet-s   | 45.7 (+1.1) |  49.3  |  44.6  | config | teacher \| model \| log  |
|   FPN    |  COCO   | RTMDet-l |  RTMDet-m   | 50.2 (+0.9) |  51.4  |  49.3  | config | teacher \| model \| log  |
|   FPN    |  COCO   | RTMDet-x |  RTMDet-l   | 52.3 (+0.9) |  52.8  |  51.4  | config | teacher \| model \| log  |

## Usage

### Prerequisites

Install MMRazor from source

```shell
git clone -b dev-1.x https://github.com/open-mmlab/mmrazor.git
cd mmrazor
# Install MMRazor
mim install -v -e .
```

### Training commands

In MMYOLO's root directory, run the following command to train RTMDet-tiny on 8 GPUs, using RTMDet-s as the teacher:

```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py 8
```

### Testing commands

In MMYOLO's root directory, run the following command to test the model:

```shell
CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_test.sh configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py ${CHECKPOINT_PATH} 1
```

### Getting student-only checkpoint

After training, the checkpoint contains the parameters of both the student and the teacher. Run the following command to convert it to a student-only checkpoint:

```shell
python ./tools/model_converters/convert_kd_ckpt_to_student.py ${CHECKPOINT_PATH} --out-path ${OUTPUT_CHECKPOINT_PATH}
```
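
Under the hood, the conversion keeps the student's weights and drops the teacher's. A rough sketch of the idea, assuming MMRazor's convention of storing student parameters under an `architecture.` prefix (the real logic lives in the converter script above, and the paths here are hypothetical):

```python
import torch

ckpt = torch.load('kd_tiny_rtmdet_s_neck_300e_coco.pth', map_location='cpu')
prefix = 'architecture.'  # assumed student prefix; teacher keys are dropped
student_state_dict = {
    k[len(prefix):]: v
    for k, v in ckpt['state_dict'].items() if k.startswith(prefix)
}
torch.save({'state_dict': student_state_dict}, 'student_only.pth')
```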

## Configs

Here we provide detection configs and models for MMRazor in MMYOLO. For clarity, we take `./kd_tiny_rtmdet_s_neck_300e_coco.py` as an example to show how to distill an RTM detector based on MMRazor.

Here is the main part of `./kd_tiny_rtmdet_s_neck_300e_coco.py`.

```python
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

distiller = dict(
    type='ConfigurableDistiller',
    # Record the outputs of the three FPN output convs of the student ...
    student_recorders=dict(
        fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
        fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
        fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
    # ... and of the teacher.
    teacher_recorders=dict(
        fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
        fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
        fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
    # Student connectors project 96 -> 128 channels and normalize;
    # teacher connectors only normalize.
    connectors=dict(
        fpn0_s=dict(
            type='ConvModuleConnector', in_channel=96, out_channel=128,
            bias=False, norm_cfg=norm_cfg, act_cfg=None),
        fpn0_t=dict(type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
        fpn1_s=dict(
            type='ConvModuleConnector', in_channel=96, out_channel=128,
            bias=False, norm_cfg=norm_cfg, act_cfg=None),
        fpn1_t=dict(type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
        fpn2_s=dict(
            type='ConvModuleConnector', in_channel=96, out_channel=128,
            bias=False, norm_cfg=norm_cfg, act_cfg=None),
        fpn2_t=dict(type='NormConnector', in_channels=128, norm_cfg=norm_cfg)),
    # One channel-wise divergence loss per FPN level.
    distill_losses=dict(
        loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
        loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
        loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
    # Feed the recorded (connector-adapted) features into each loss.
    loss_forward_mappings=dict(
        loss_fpn0=dict(
            preds_S=dict(from_student=True, recorder='fpn0', connector='fpn0_s'),
            preds_T=dict(from_student=False, recorder='fpn0', connector='fpn0_t')),
        loss_fpn1=dict(
            preds_S=dict(from_student=True, recorder='fpn1', connector='fpn1_s'),
            preds_T=dict(from_student=False, recorder='fpn1', connector='fpn1_t')),
        loss_fpn2=dict(
            preds_S=dict(from_student=True, recorder='fpn2', connector='fpn2_s'),
            preds_T=dict(from_student=False, recorder='fpn2', connector='fpn2_t'))))
```

`recorders` are used to record various intermediate results during the model's forward pass. In this example, they record the outputs of three `nn.Module`s of the teacher and the student. Details are listed in Recorder and MMRazor Distillation (the latter is in Chinese).
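
Conceptually, a `ModuleOutputs` recorder registers a forward hook on the named submodule and caches its output on every iteration. A simplified illustration (not MMRazor's actual `Recorder` implementation):

```python
import torch.nn as nn


class ModuleOutputsRecorder:
    """Cache the output of one submodule, e.g. 'neck.out_layers.0.conv'."""

    def __init__(self, model: nn.Module, source: str):
        self.data = None
        # Look up the submodule by its dotted path and hook its forward pass.
        model.get_submodule(source).register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        self.data = output  # later consumed by the distillation loss
```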

`connectors` are adaptive layers that usually map the teacher's and the student's features to the same dimension.
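
In this config, the student connectors roughly amount to a 1x1 convolution projecting RTMDet-tiny's 96 neck channels to RTMDet-s's 128, followed by the affine-free BN from `norm_cfg`, while the teacher connectors apply that normalization only. A hand-rolled equivalent for one level (illustrative, not MMRazor's `ConvModuleConnector`/`NormConnector`):

```python
import torch.nn as nn


def norm(channels: int) -> nn.Module:
    # Matches norm_cfg: BN without affine parameters or running statistics.
    return nn.BatchNorm2d(channels, affine=False, track_running_stats=False)


# Student side: project 96 -> 128 channels, then normalize.
fpn0_s = nn.Sequential(nn.Conv2d(96, 128, kernel_size=1, bias=False), norm(128))
# Teacher side: normalize only.
fpn0_t = norm(128)
```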

`distill_losses` are configs for multiple distillation losses.

`loss_forward_mappings` are mappings between the forward arguments of each distillation loss and the recorded features.
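
At every training step, the distiller resolves each `loss_forward_mappings` entry into keyword arguments for the corresponding loss. A simplified sketch of that resolution (the real logic lives in MMRazor's `ConfigurableDistiller`):

```python
def compute_distill_loss(mapping, student_recorders, teacher_recorders,
                         connectors, loss_fn):
    """Resolve one mapping entry, e.g. loss_fpn0, into a loss value."""
    kwargs = {}
    for arg_name, cfg in mapping.items():  # arg_name: 'preds_S' / 'preds_T'
        recorders = (student_recorders
                     if cfg['from_student'] else teacher_recorders)
        feat = recorders[cfg['recorder']].data  # recorded feature map
        kwargs[arg_name] = connectors[cfg['connector']](feat)
    return loss_fn(**kwargs)
```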

In addition, the student fine-tunes itself in the no-mosaic domain during the last 20 epochs, so we add a new hook named `StopDistillHook` to stop distillation on schedule. We need to add this hook to the `custom_hooks` list like this:

```python
custom_hooks = [..., dict(type='mmrazor.StopDistillHook', detach_epoch=280)]
```
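
For reference, a minimal sketch of what such a hook might look like; the `distillation_stopped` flag is an assumption for illustration, and the real implementation is `mmrazor.StopDistillHook`:

```python
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper


class StopDistillHook(Hook):
    """Switch to student-only fine-tuning once `detach_epoch` is reached."""

    def __init__(self, detach_epoch: int):
        self.detach_epoch = detach_epoch

    def before_train_epoch(self, runner):
        if runner.epoch >= self.detach_epoch:
            model = runner.model
            if is_model_wrapper(model):  # unwrap DDP and similar wrappers
                model = model.module
            # Assumed flag: the distillation algorithm skips its distill
            # losses while this is set.
            model.distillation_stopped = True
```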