diff --git a/configs/distill/mmcls/abloss/README.md b/configs/distill/mmcls/abloss/README.md index 5120ed18..1e6e1951 100644 --- a/configs/distill/mmcls/abloss/README.md +++ b/configs/distill/mmcls/abloss/README.md @@ -14,9 +14,10 @@ An activation boundary for a neuron refers to a separating hyperplane that deter ### Classification -| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | -| :----------------------------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :--------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- | -| backbone (pretrain) & logits (train) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.58 | 76.55 | 69.90 | [pretrain_config](./abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain) [train_config](./abloss_head_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) | +| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | +| :-----------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| backbone (pretrain) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | | 76.55 | 69.90 | [pretrain_config](./abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.json) | +| logits (train) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 69.94 | 76.55 | 69.90 | [train_config](./abloss_logits_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.json) | ## Citation @@ -43,31 +44,26 @@ An activation boundary for a neuron refers to a separating hyperplane that deter ## Getting Started -### ABConnectors and Student pre-training. +### Pre-training. ```bash -sh tools/slurm_train.sh $PARTITION $JOB_NAME \ - configs/distill/mmcls/abloss/abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py\ -$PRETRAIN_WORK_DIR - +sh tools/dist_train.sh configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py 8 ``` ### Modify Distillation training config -open file 'configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py' +open file 'configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py' ```python -# modify init_cfg in model settings -# pretrain_work_dir is same as the PRETRAIN_WORK_DIR in pre-training. +# Modify init_cfg in model settings. +# 'pretrain_work_dir' is same as the 'work_dir of pre-training'. +# 'last_epoch' defaults to 'epoch_20' in ABLoss. init_cfg=dict( - type='Pretrained', checkpoint='pretrain_work_dir/last_chechpoint.pth'), + type='Pretrained', checkpoint='pretrain_work_dir/last_epoch.pth'), ``` ### Distillation training. ```bash -sh tools/slurm_train.sh $PARTITION $JOB_NAME \ - configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py\ -$DISTILLATION_WORK_DIR - +sh tools/dist_train.sh configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py 8 ``` diff --git a/configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py b/configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py similarity index 83% rename from configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py rename to configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py index 0a2e8f85..e1e9c3a7 100644 --- a/configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py +++ b/configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py @@ -4,6 +4,8 @@ _base_ = [ 'mmcls::_base_/default_runtime.py' ] +# Modify pretrain_checkpoint before training. +pretrain_checkpoint = 'work_dir_of_abloss_pretrain/last_epoch.pth' model = dict( _scope_='mmrazor', type='SingleTeacherDistill', @@ -18,8 +20,7 @@ model = dict( cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), teacher=dict( cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False), - init_cfg=dict( - type='Pretrained', checkpoint='pretrain_work_dir/last_chechpoint.pth'), + init_cfg=dict(type='Pretrained', checkpoint=pretrain_checkpoint), distiller=dict( type='ConfigurableDistiller', student_recorders=dict( @@ -28,7 +29,7 @@ model = dict( fc=dict(type='ModuleOutputs', source='head.fc')), distill_losses=dict( loss_kl=dict( - type='KLDivergence', loss_weight=200, reduction='mean')), + type='KLDivergence', loss_weight=6.25, reduction='mean')), loss_forward_mappings=dict( loss_kl=dict( preds_S=dict(from_student=True, recorder='fc'), diff --git a/configs/distill/mmcls/abloss/abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py b/configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py similarity index 95% rename from configs/distill/mmcls/abloss/abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py rename to configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py index 2a580800..d24ef0ea 100644 --- a/configs/distill/mmcls/abloss/abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py +++ b/configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py @@ -4,8 +4,7 @@ _base_ = [ 'mmcls::_base_/default_runtime.py' ] -train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1) - +teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 model = dict( _scope_='mmrazor', type='SingleTeacherDistill', @@ -20,7 +19,7 @@ model = dict( cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), teacher=dict( cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True), - teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth', + teacher_ckpt=teacher_ckpt, calculate_student_loss=False, distiller=dict( type='ConfigurableDistiller', @@ -94,4 +93,5 @@ model = dict( find_unused_parameters = True +train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1) val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') diff --git a/configs/distill/mmcls/abloss/metafile.yml b/configs/distill/mmcls/abloss/metafile.yml new file mode 100644 index 00000000..f19f42c6 --- /dev/null +++ b/configs/distill/mmcls/abloss/metafile.yml @@ -0,0 +1,61 @@ +Collections: + - Name: ABLoss + Metadata: + Training Data: + - ImageNet-1k + Paper: + URL: https://arxiv.org/pdf/1811.03233.pdf + Title: Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons + README: configs/distill/mmcls/abloss/README.md + Converted From: + Code: + URL: https://github.com/bhheo/AB_distillation +Models: + - Name: abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k + In Collection: ABLoss + Metadata: + inference time (ms/im): + - value: 0.21 + hardware: NVIDIA A100-SXM4-80GB + backend: PyTorch + batch size: 32 + mode: FP32 + resolution: (224, 224) + Location: backbone + Student: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Top 5 Accuracy: 89.43 + Teacher: + Config: mmcls::resnet/resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth + Metrics: + Top 1 Accuracy: 76.55 + Top 5 Accuracy: 93.06 + Config: configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.pth + - Name: abloss_logits_resnet50_resnet18_8xb32_in1k + In Collection: ABLoss + Metadata: + Location: logits + Student: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Top 5 Accuracy: 89.43 + Teacher: + Config: mmcls::resnet/resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth + Metrics: + Top 1 Accuracy: 76.55 + Top 5 Accuracy: 93.06 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.94 + Config: configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.pth diff --git a/configs/distill/mmcls/dafl/README.md b/configs/distill/mmcls/dafl/README.md index 47b7856a..76bf6dee 100644 --- a/configs/distill/mmcls/dafl/README.md +++ b/configs/distill/mmcls/dafl/README.md @@ -14,9 +14,9 @@ Learning portable neural networks is very essential for computer vision for the ### Classification -| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | -| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- | -| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.11 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) | +| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | +| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.27 | 95.34 | 94.82 | [config](./dafl_logits_resnet34_resnet18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.json) | ## Citation @@ -36,7 +36,3 @@ Learning portable neural networks is very essential for computer vision for the biburl = {https://dblp.org/rec/conf/iccv/ChenW0YLSXX019.bib}, bibsource = {dblp computer science bibliography, https://dblp.org} ``` - -## Acknowledgement - -Shout out to Davidgzx. diff --git a/configs/distill/mmcls/dafl/dafl_logits_r34_r18_8xb256_cifar10.py b/configs/distill/mmcls/dafl/dafl_logits_resnet34_resnet18_8xb256_cifar10.py similarity index 95% rename from configs/distill/mmcls/dafl/dafl_logits_r34_r18_8xb256_cifar10.py rename to configs/distill/mmcls/dafl/dafl_logits_resnet34_resnet18_8xb256_cifar10.py index 43bd8482..5a3a3f26 100644 --- a/configs/distill/mmcls/dafl/dafl_logits_r34_r18_8xb256_cifar10.py +++ b/configs/distill/mmcls/dafl/dafl_logits_resnet34_resnet18_8xb256_cifar10.py @@ -4,6 +4,7 @@ _base_ = [ 'mmcls::_base_/default_runtime.py' ] +res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501 model = dict( _scope_='mmrazor', type='DAFLDataFreeDistillation', @@ -21,7 +22,7 @@ model = dict( build_cfg=dict( cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=True), - ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')), + ckpt_path=res34_ckpt_path)), generator=dict( type='DAFLGenerator', img_size=32, diff --git a/configs/distill/mmcls/dafl/metafile.yml b/configs/distill/mmcls/dafl/metafile.yml new file mode 100644 index 00000000..9438b878 --- /dev/null +++ b/configs/distill/mmcls/dafl/metafile.yml @@ -0,0 +1,43 @@ +Collections: + - Name: DAFL + Metadata: + Training Data: + - CIFAR-10 + Paper: + URL: https://doi.org/10.1109/ICCV.2019.00361 + Title: Data-Free Learning of Student Networks + README: configs/distill/mmcls/dafl/README.md + Converted From: + Code: + URL: https://github.com/huawei-noah/Efficient-Computing/tree/master/Data-Efficient-Model-Compression/DAFL +Models: + - Name: dafl_logits_resnet34_resnet18_8xb256_cifar10 + In Collection: DAFL + Metadata: + inference time (ms/im): + - value: 0.34 + hardware: NVIDIA A100-SXM4-80GB + backend: PyTorch + batch size: 256 + mode: FP32 + resolution: (32, 32) + Location: logits + Student: + Config: mmcls::resnet/resnet18_8xb16_cifar10.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth + Metrics: + Top 1 Accuracy: 94.82 + Top 5 Accuracy: 99.87 + Teacher: + Config: mmcls::resnet/resnet34_8xb16_cifar10.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth + Metrics: + Top 1 Accuracy: 95.34 + Top 5 Accuracy: 99.87 + Results: + - Task: Image Classification + Dataset: CIFAR-10 + Metrics: + Top 1 Accuracy: 93.27 + Config: configs/distill/mmcls/dafl/dafl_logits_resnet34_resnet18_8xb256_cifar10.py + Weights: https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.pth diff --git a/configs/distill/mmcls/dfad/README.md b/configs/distill/mmcls/dfad/README.md index ac86adb2..4f81fcc4 100644 --- a/configs/distill/mmcls/dfad/README.md +++ b/configs/distill/mmcls/dfad/README.md @@ -14,9 +14,9 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a ### Classification -| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | -| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- | -| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dfad_logits_r34_r18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) | +| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | +| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :--------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 92.80 | 95.34 | 94.82 | [config](./dfad_logits_resnet34_resnet18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.json) | ## Citation @@ -28,7 +28,3 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a year={2019} } ``` - -## Acknowledgement - -Appreciate Davidgzx's contribution. diff --git a/configs/distill/mmcls/dfad/dfad_logits_r34_r18_8xb32_cifar10.py b/configs/distill/mmcls/dfad/dfad_logits_resnet34_resnet18_8xb32_cifar10.py similarity index 95% rename from configs/distill/mmcls/dfad/dfad_logits_r34_r18_8xb32_cifar10.py rename to configs/distill/mmcls/dfad/dfad_logits_resnet34_resnet18_8xb32_cifar10.py index b654b195..59bc4d32 100644 --- a/configs/distill/mmcls/dfad/dfad_logits_r34_r18_8xb32_cifar10.py +++ b/configs/distill/mmcls/dfad/dfad_logits_resnet34_resnet18_8xb32_cifar10.py @@ -4,6 +4,7 @@ _base_ = [ 'mmcls::_base_/default_runtime.py' ] +res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501 model = dict( _scope_='mmrazor', type='DataFreeDistillation', @@ -21,7 +22,7 @@ model = dict( build_cfg=dict( cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=True), - ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')), + ckpt_path=res34_ckpt_path)), generator=dict( type='DAFLGenerator', img_size=32, diff --git a/configs/distill/mmcls/dfad/metafile.yml b/configs/distill/mmcls/dfad/metafile.yml new file mode 100644 index 00000000..0601f895 --- /dev/null +++ b/configs/distill/mmcls/dfad/metafile.yml @@ -0,0 +1,43 @@ +Collections: + - Name: DFAD + Metadata: + Training Data: + - CIFAR-10 + Paper: + URL: https://arxiv.org/pdf/1912.11006.pdf + Title: Data-Free Adversarial Distillation + README: configs/distill/mmcls/dfad/README.md + Converted From: + Code: + URL: https://github.com/VainF/Data-Free-Adversarial-Distillation +Models: + - Name: dfad_logits_resnet34_resnet18_8xb32_cifar10 + In Collection: DFAD + Metadata: + inference time (ms/im): + - value: 0.38 + hardware: NVIDIA A100-SXM4-80GB + backend: PyTorch + batch size: 32 + mode: FP32 + resolution: (32, 32) + Location: logits + Student: + Config: mmcls::resnet/resnet18_8xb16_cifar10.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth + Metrics: + Top 1 Accuracy: 94.82 + Top 5 Accuracy: 99.87 + Teacher: + Config: mmcls::resnet/resnet34_8xb16_cifar10.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth + Metrics: + Top 1 Accuracy: 95.34 + Top 5 Accuracy: 99.87 + Results: + - Task: Image Classification + Dataset: CIFAR-10 + Metrics: + Top 1 Accuracy: 92.80 + Config: configs/distill/mmcls/dfad/dfad_logits_resnet34_resnet18_8xb32_cifar10.py + Weights: https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.pth diff --git a/configs/distill/mmcls/fitnet/README.md b/configs/distill/mmcls/fitnets/README.md similarity index 72% rename from configs/distill/mmcls/fitnet/README.md rename to configs/distill/mmcls/fitnets/README.md index f9616ea9..8a6515a3 100644 --- a/configs/distill/mmcls/fitnet/README.md +++ b/configs/distill/mmcls/fitnets/README.md @@ -26,9 +26,9 @@ almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher ### Classification -| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | -| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- | -| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) | +| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | +| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.58 | 76.55 | 69.90 | [config](./fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.json) | ## Citation diff --git a/configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py b/configs/distill/mmcls/fitnets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py similarity index 94% rename from configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py rename to configs/distill/mmcls/fitnets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py index 529aa751..3ee855ae 100644 --- a/configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py +++ b/configs/distill/mmcls/fitnets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py @@ -4,6 +4,7 @@ _base_ = [ 'mmcls::_base_/default_runtime.py' ] +teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 model = dict( _scope_='mmrazor', type='SingleTeacherDistill', @@ -18,7 +19,7 @@ model = dict( cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), teacher=dict( cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True), - teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth', + teacher_ckpt=teacher_ckpt, distiller=dict( type='ConfigurableDistiller', student_recorders=dict( diff --git a/configs/distill/mmcls/fitnets/metafile.yml b/configs/distill/mmcls/fitnets/metafile.yml new file mode 100644 index 00000000..dc48811d --- /dev/null +++ b/configs/distill/mmcls/fitnets/metafile.yml @@ -0,0 +1,40 @@ +Collections: + - Name: FitNets + Metadata: + Training Data: + - ImageNet-1k + Paper: + URL: https://arxiv.org/abs/1412.6550 + Title: FitNets- Hints for Thin Deep Nets + README: configs/distill/mmcls/fitnet/README.md +Models: + - Name: fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k + In Collection: FitNets + Metadata: + inference time (ms/im): + - value: 0.18 + hardware: NVIDIA A100-SXM4-80GB + backend: PyTorch + batch size: 32 + mode: FP32 + resolution: (224, 224) + Location: backbone & logits + Student: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Top 5 Accuracy: 89.43 + Teacher: + Config: mmcls::resnet/resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth + Metrics: + Top 1 Accuracy: 76.55 + Top 5 Accuracy: 93.06 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.58 + Config: configs/distill/mmcls/fitnet/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.pth diff --git a/configs/distill/mmcls/zskt/README.md b/configs/distill/mmcls/zskt/README.md index 51ca46b0..fff12361 100644 --- a/configs/distill/mmcls/zskt/README.md +++ b/configs/distill/mmcls/zskt/README.md @@ -20,9 +20,9 @@ Performing knowledge transfer from a large teacher network to a smaller student ### Classification -| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | -| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- | -| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./zskt_backbone_logits_r34_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) | +| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | +| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.05 | 95.34 | 94.82 | [config](./zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.json) | ## Citation @@ -35,7 +35,3 @@ Performing knowledge transfer from a large teacher network to a smaller student year={2019} } ``` - -## Acknowledgement - -Appreciate Davidgzx's contribution. diff --git a/configs/distill/mmcls/zskt/metafile.yml b/configs/distill/mmcls/zskt/metafile.yml new file mode 100644 index 00000000..54494fa7 --- /dev/null +++ b/configs/distill/mmcls/zskt/metafile.yml @@ -0,0 +1,43 @@ +Collections: + - Name: ZSKT + Metadata: + Training Data: + - CIFAR-10 + Paper: + URL: https://arxiv.org/abs/1905.09768 + Title: Zero-shot Knowledge Transfer via Adversarial Belief Matching + README: configs/distill/mmcls/zskt/README.md + Converted From: + Code: + URL: https://github.com/polo5/ZeroShotKnowledgeTransfer +Models: + - Name: zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10 + In Collection: ZSKT + Metadata: + inference time (ms/im): + - value: 0.12 + hardware: NVIDIA A100-SXM4-80GB + backend: PyTorch + batch size: 16 + mode: FP32 + resolution: (32, 32) + Location: backbone & logits + Student: + Config: mmcls::resnet/resnet18_8xb16_cifar10.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth + Metrics: + Top 1 Accuracy: 94.82 + Top 5 Accuracy: 99.87 + Teacher: + Config: mmcls::resnet/resnet34_8xb16_cifar10.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth + Metrics: + Top 1 Accuracy: 95.34 + Top 5 Accuracy: 99.87 + Results: + - Task: Image Classification + Dataset: CIFAR-10 + Metrics: + Top 1 Accuracy: 93.05 + Config: configs/distill/mmcls/zskt/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py + Weights: https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.pth diff --git a/configs/distill/mmcls/zskt/zskt_backbone_logits_r34_r18_8xb16_cifar10.py b/configs/distill/mmcls/zskt/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py similarity index 78% rename from configs/distill/mmcls/zskt/zskt_backbone_logits_r34_r18_8xb16_cifar10.py rename to configs/distill/mmcls/zskt/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py index e7b4c4c8..5c1ab424 100644 --- a/configs/distill/mmcls/zskt/zskt_backbone_logits_r34_r18_8xb16_cifar10.py +++ b/configs/distill/mmcls/zskt/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py @@ -4,6 +4,7 @@ _base_ = [ 'mmcls::_base_/default_runtime.py' ] +res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501 model = dict( _scope_='mmrazor', type='DataFreeDistillation', @@ -17,11 +18,11 @@ model = dict( architecture=dict( cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False), teachers=dict( - r34=dict( + res34=dict( build_cfg=dict( cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=True), - ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')), + ckpt_path=res34_ckpt_path)), generator=dict( type='ZSKTGenerator', img_size=32, latent_dim=256, hidden_channels=128), @@ -34,15 +35,15 @@ model = dict( bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'), fc=dict(type='ModuleOutputs', source='head.fc')), teacher_recorders=dict( - r34_bb_s1=dict( - type='ModuleOutputs', source='r34.backbone.layer1.2.relu'), - r34_bb_s2=dict( - type='ModuleOutputs', source='r34.backbone.layer2.3.relu'), - r34_bb_s3=dict( - type='ModuleOutputs', source='r34.backbone.layer3.5.relu'), - r34_bb_s4=dict( - type='ModuleOutputs', source='r34.backbone.layer4.2.relu'), - r34_fc=dict(type='ModuleOutputs', source='r34.head.fc')), + res34_bb_s1=dict( + type='ModuleOutputs', source='res34.backbone.layer1.2.relu'), + res34_bb_s2=dict( + type='ModuleOutputs', source='res34.backbone.layer2.3.relu'), + res34_bb_s3=dict( + type='ModuleOutputs', source='res34.backbone.layer3.5.relu'), + res34_bb_s4=dict( + type='ModuleOutputs', source='res34.backbone.layer4.2.relu'), + res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')), distill_losses=dict( loss_s1=dict(type='ATLoss', loss_weight=250.0), loss_s2=dict(type='ATLoss', loss_weight=250.0), @@ -55,31 +56,31 @@ model = dict( s_feature=dict( from_student=True, recorder='bb_s1', record_idx=1), t_feature=dict( - from_student=False, recorder='r34_bb_s1', record_idx=1)), + from_student=False, recorder='res34_bb_s1', record_idx=1)), loss_s2=dict( s_feature=dict( from_student=True, recorder='bb_s2', record_idx=1), t_feature=dict( - from_student=False, recorder='r34_bb_s2', record_idx=1)), + from_student=False, recorder='res34_bb_s2', record_idx=1)), loss_s3=dict( s_feature=dict( from_student=True, recorder='bb_s3', record_idx=1), t_feature=dict( - from_student=False, recorder='r34_bb_s3', record_idx=1)), + from_student=False, recorder='res34_bb_s3', record_idx=1)), loss_s4=dict( s_feature=dict( from_student=True, recorder='bb_s4', record_idx=1), t_feature=dict( - from_student=False, recorder='r34_bb_s4', record_idx=1)), + from_student=False, recorder='res34_bb_s4', record_idx=1)), loss_kl=dict( preds_S=dict(from_student=True, recorder='fc'), - preds_T=dict(from_student=False, recorder='r34_fc')))), + preds_T=dict(from_student=False, recorder='res34_fc')))), generator_distiller=dict( type='ConfigurableDistiller', student_recorders=dict( fc=dict(type='ModuleOutputs', source='head.fc')), teacher_recorders=dict( - r34_fc=dict(type='ModuleOutputs', source='r34.head.fc')), + res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')), distill_losses=dict( loss_kl=dict( type='KLDivergence', @@ -89,7 +90,7 @@ model = dict( loss_forward_mappings=dict( loss_kl=dict( preds_S=dict(from_student=True, recorder='fc'), - preds_T=dict(from_student=False, recorder='r34_fc')))), + preds_T=dict(from_student=False, recorder='res34_fc')))), student_iter=10) # model wrapper diff --git a/configs/distill/mmdet/fbkd/README.md b/configs/distill/mmdet/fbkd/README.md index 99e6cc64..cded5983 100644 --- a/configs/distill/mmdet/fbkd/README.md +++ b/configs/distill/mmdet/fbkd/README.md @@ -14,9 +14,9 @@ Knowledge distillation, in which a student model is trained to mimic a teacher m ### Detection -| Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download | -| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :--------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| neck | COCO | [fasterrcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [fasterrcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.1 | 39.4 | 37.8 | [config](./fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](<>) \| [log](<>) | +| Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download | +| :------: | :-----: | :--------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| neck | COCO | [faster-rcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [faster-rcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.3 | 39.4 | 37.4 | [config](./fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.json) | ## Citation diff --git a/configs/distill/mmdet/fbkd/fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py b/configs/distill/mmdet/fbkd/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco.py similarity index 93% rename from configs/distill/mmdet/fbkd/fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py rename to configs/distill/mmdet/fbkd/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco.py index ef72b880..3203e3fe 100644 --- a/configs/distill/mmdet/fbkd/fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py +++ b/configs/distill/mmdet/fbkd/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco.py @@ -4,16 +4,17 @@ _base_ = [ 'mmdet::_base_/default_runtime.py' ] +teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth' # noqa: E501 model = dict( _scope_='mmrazor', type='SingleTeacherDistill', architecture=dict( - cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py', + cfg_path='mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True), teacher=dict( - cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py', + cfg_path='mmdet::faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py', pretrained=False), - teacher_ckpt='faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth', + teacher_ckpt=teacher_ckpt, distiller=dict( type='ConfigurableDistiller', student_recorders=dict( diff --git a/configs/distill/mmdet/fbkd/metafile.yml b/configs/distill/mmdet/fbkd/metafile.yml new file mode 100644 index 00000000..29925787 --- /dev/null +++ b/configs/distill/mmdet/fbkd/metafile.yml @@ -0,0 +1,41 @@ +Collections: + - Name: FBKD + Metadata: + Training Data: + - COCO + Paper: + URL: https://openreview.net/pdf?id=uKhGRvM8QNH + Title: IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION- TOWARDS ACCURATE AND EFFICIENT DETECTORS + README: configs/distill/mmdet/fbkd/README.md + Converted From: + Code: + URL: https://github.com/ArchipLab-LinfengZhang/Object-Detection-Knowledge-Distillation-ICLR2021 +Models: + - Name: fbkd_fpn_faster-rcnn_resnet101_faster-rcnn_resnet50_1x_coco + In Collection: FBKD + Metadata: + inference time (ms/im): + - value: 0.32 + hardware: NVIDIA A100-SXM4-80GB + backend: PyTorch + batch size: 2 + mode: FP32 + resolution: (1333, 800) + Location: fpn + Student: + Metrics: + box AP: 37.4 + Config: mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth + Teacher: + Metrics: + box AP: 39.4 + Config: mmdet::faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 39.3 + Config: configs/distill/mmdet/fbkd/fbkd_fpn_faster-rcnn_resnet101_faster-rcnn_resnet50_1x_coco.py + Weights: https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.pth diff --git a/mmrazor/models/algorithms/distill/configurable/datafree_distillation.py b/mmrazor/models/algorithms/distill/configurable/datafree_distillation.py index 2efb8d81..1aff807c 100644 --- a/mmrazor/models/algorithms/distill/configurable/datafree_distillation.py +++ b/mmrazor/models/algorithms/distill/configurable/datafree_distillation.py @@ -78,12 +78,12 @@ class DataFreeDistillation(BaseAlgorithm): """Alias for ``architecture``.""" return self.architecture - def train_step(self, data: List[dict], + def train_step(self, data: Dict[str, List[dict]], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: """Train step for DataFreeDistillation. Args: - data (List[dict]): Data sampled by dataloader. + data (Dict[str, List[dict]]): Data sampled by dataloader. optim_wrapper (OptimWrapper): A wrapper of optimizer to update parameters. """ @@ -107,16 +107,16 @@ class DataFreeDistillation(BaseAlgorithm): return log_vars def train_student( - self, data: List[dict], optimizer: OPTIMIZERS + self, data: Dict[str, List[dict]], optimizer: OPTIMIZERS ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Train step for the student model. Args: - data (List[dict]): Data sampled by dataloader. + data (Dict[str, List[dict]]): Data sampled by dataloader. optimizer (OPTIMIZERS): The optimizer to update student. """ log_vars = dict() - batch_size = len(data) + batch_size = len(data['inputs']) for _ in range(self.student_iter): fakeimg_init = torch.randn( @@ -124,13 +124,14 @@ class DataFreeDistillation(BaseAlgorithm): fakeimg = self.generator(fakeimg_init, batch_size).detach() with optimizer.optim_context(self.student): - _, data_samples = self.data_preprocessor(data, True) + pseudo_data = self.data_preprocessor(data, True) + pseudo_data_samples = pseudo_data['data_samples'] # recorde the needed information with self.distiller.student_recorders: - _ = self.student(fakeimg, data_samples, mode='loss') + _ = self.student(fakeimg, pseudo_data_samples, mode='loss') with self.distiller.teacher_recorders, torch.no_grad(): for _, teacher in self.teachers.items(): - _ = teacher(fakeimg, data_samples, mode='loss') + _ = teacher(fakeimg, pseudo_data_samples, mode='loss') loss_distill = self.distiller.compute_distill_losses() distill_loss, distill_log_vars = self.parse_losses(loss_distill) @@ -140,15 +141,15 @@ class DataFreeDistillation(BaseAlgorithm): return distill_loss, log_vars def train_generator( - self, data: List[dict], optimizer: OPTIMIZERS + self, data: Dict[str, List[dict]], optimizer: OPTIMIZERS ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Train step for the generator. Args: - data (List[dict]): Data sampled by dataloader. + data (Dict[str, List[dict]]): Data sampled by dataloader. optimizer (OPTIMIZERS): The optimizer to update generator. """ - batch_size = len(data) + batch_size = len(data['inputs']) fakeimg_init = torch.randn( (batch_size, self.generator.module.latent_dim)) fakeimg = self.generator(fakeimg_init, batch_size) @@ -174,17 +175,17 @@ class DataFreeDistillation(BaseAlgorithm): @MODELS.register_module() class DAFLDataFreeDistillation(DataFreeDistillation): - def train_step(self, data: List[dict], + def train_step(self, data: Dict[str, List[dict]], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: """DAFL train step. Args: - data (List[dict]): Data sampled by dataloader. + data (Dict[str, List[dict]): Data sampled by dataloader. optim_wrapper (OptimWrapper): A wrapper of optimizer to update parameters. """ log_vars = dict() - batch_size = len(data) + batch_size = len(data['inputs']) for _, teacher in self.teachers.items(): teacher.eval() diff --git a/tests/test_models/test_losses/test_distillation_losses.py b/tests/test_models/test_losses/test_distillation_losses.py index 3efe1314..c3ab1694 100644 --- a/tests/test_models/test_losses/test_distillation_losses.py +++ b/tests/test_models/test_losses/test_distillation_losses.py @@ -107,7 +107,6 @@ class TestLosses(TestCase): ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=True) ie_loss.world_size = 2 - # TODO: configure circle CI to test UT under multi torch versions. if digit_version(torch.__version__) >= digit_version('1.8.0'): with self.assertRaisesRegex( RuntimeError,