[Docs] Add metafiles for KD algos (ABLoss, DAFL, DFAD, FBKD, FitNets, ZSKT) (#266)
1. Add metafiles for 6 KD algos. 2. Add model and log links. 3. Revise data_samples in datafreedistillation for the new feature of mmengine.
pull/270/head
parent
24e106ba1d
commit
f69aeabc69
|
@ -14,9 +14,10 @@ An activation boundary for a neuron refers to a separating hyperplane that deter
|
||||||
|
|
||||||
### Classification
|
### Classification
|
||||||
|
|
||||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||||
| :----------------------------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :--------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- |
|
| :-----------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| backbone (pretrain) & logits (train) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.58 | 76.55 | 69.90 | [pretrain_config](./abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain) [train_config](./abloss_head_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) |
|
| backbone (pretrain) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | | 76.55 | 69.90 | [pretrain_config](./abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.json) |
|
||||||
|
| logits (train) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 69.94 | 76.55 | 69.90 | [train_config](./abloss_logits_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.json) |
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
@ -43,31 +44,26 @@ An activation boundary for a neuron refers to a separating hyperplane that deter
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
### ABConnectors and Student pre-training.
|
### Pre-training.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
|
sh tools/dist_train.sh configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py 8
|
||||||
configs/distill/mmcls/abloss/abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py\
|
|
||||||
$PRETRAIN_WORK_DIR
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Modify Distillation training config
|
### Modify Distillation training config
|
||||||
|
|
||||||
open file 'configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py'
|
Open the file 'configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py'.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# modify init_cfg in model settings
|
# Modify init_cfg in model settings.
|
||||||
# pretrain_work_dir is same as the PRETRAIN_WORK_DIR in pre-training.
|
# 'pretrain_work_dir' is the same as the work_dir used for pre-training.
|
||||||
|
# 'last_epoch' defaults to 'epoch_20' in ABLoss.
|
||||||
init_cfg=dict(
|
init_cfg=dict(
|
||||||
type='Pretrained', checkpoint='pretrain_work_dir/last_chechpoint.pth'),
|
type='Pretrained', checkpoint='pretrain_work_dir/last_epoch.pth'),
|
||||||
```
|
```
|
||||||
|
|
||||||
### Distillation training.
|
### Distillation training.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
|
sh tools/dist_train.sh configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py 8
|
||||||
configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py\
|
|
||||||
$DISTILLATION_WORK_DIR
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -4,6 +4,8 @@ _base_ = [
|
||||||
'mmcls::_base_/default_runtime.py'
|
'mmcls::_base_/default_runtime.py'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Modify pretrain_checkpoint before training.
|
||||||
|
pretrain_checkpoint = 'work_dir_of_abloss_pretrain/last_epoch.pth'
|
||||||
model = dict(
|
model = dict(
|
||||||
_scope_='mmrazor',
|
_scope_='mmrazor',
|
||||||
type='SingleTeacherDistill',
|
type='SingleTeacherDistill',
|
||||||
|
@ -18,8 +20,7 @@ model = dict(
|
||||||
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
|
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
|
||||||
teacher=dict(
|
teacher=dict(
|
||||||
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
|
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
|
||||||
init_cfg=dict(
|
init_cfg=dict(type='Pretrained', checkpoint=pretrain_checkpoint),
|
||||||
type='Pretrained', checkpoint='pretrain_work_dir/last_chechpoint.pth'),
|
|
||||||
distiller=dict(
|
distiller=dict(
|
||||||
type='ConfigurableDistiller',
|
type='ConfigurableDistiller',
|
||||||
student_recorders=dict(
|
student_recorders=dict(
|
||||||
|
@ -28,7 +29,7 @@ model = dict(
|
||||||
fc=dict(type='ModuleOutputs', source='head.fc')),
|
fc=dict(type='ModuleOutputs', source='head.fc')),
|
||||||
distill_losses=dict(
|
distill_losses=dict(
|
||||||
loss_kl=dict(
|
loss_kl=dict(
|
||||||
type='KLDivergence', loss_weight=200, reduction='mean')),
|
type='KLDivergence', loss_weight=6.25, reduction='mean')),
|
||||||
loss_forward_mappings=dict(
|
loss_forward_mappings=dict(
|
||||||
loss_kl=dict(
|
loss_kl=dict(
|
||||||
preds_S=dict(from_student=True, recorder='fc'),
|
preds_S=dict(from_student=True, recorder='fc'),
|
|
@ -4,8 +4,7 @@ _base_ = [
|
||||||
'mmcls::_base_/default_runtime.py'
|
'mmcls::_base_/default_runtime.py'
|
||||||
]
|
]
|
||||||
|
|
||||||
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
|
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501
|
||||||
|
|
||||||
model = dict(
|
model = dict(
|
||||||
_scope_='mmrazor',
|
_scope_='mmrazor',
|
||||||
type='SingleTeacherDistill',
|
type='SingleTeacherDistill',
|
||||||
|
@ -20,7 +19,7 @@ model = dict(
|
||||||
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
|
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
|
||||||
teacher=dict(
|
teacher=dict(
|
||||||
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
|
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
|
||||||
teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth',
|
teacher_ckpt=teacher_ckpt,
|
||||||
calculate_student_loss=False,
|
calculate_student_loss=False,
|
||||||
distiller=dict(
|
distiller=dict(
|
||||||
type='ConfigurableDistiller',
|
type='ConfigurableDistiller',
|
||||||
|
@ -94,4 +93,5 @@ model = dict(
|
||||||
|
|
||||||
find_unused_parameters = True
|
find_unused_parameters = True
|
||||||
|
|
||||||
|
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
|
||||||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
|
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
|
|
@ -0,0 +1,61 @@
|
||||||
|
Collections:
|
||||||
|
- Name: ABLoss
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- ImageNet-1k
|
||||||
|
Paper:
|
||||||
|
URL: https://arxiv.org/pdf/1811.03233.pdf
|
||||||
|
Title: Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
|
||||||
|
README: configs/distill/mmcls/abloss/README.md
|
||||||
|
Converted From:
|
||||||
|
Code:
|
||||||
|
URL: https://github.com/bhheo/AB_distillation
|
||||||
|
Models:
|
||||||
|
- Name: abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k
|
||||||
|
In Collection: ABLoss
|
||||||
|
Metadata:
|
||||||
|
inference time (ms/im):
|
||||||
|
- value: 0.21
|
||||||
|
hardware: NVIDIA A100-SXM4-80GB
|
||||||
|
backend: PyTorch
|
||||||
|
batch size: 32
|
||||||
|
mode: FP32
|
||||||
|
resolution: (224, 224)
|
||||||
|
Location: backbone
|
||||||
|
Student:
|
||||||
|
Config: mmcls::resnet/resnet18_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 69.90
|
||||||
|
Top 5 Accuracy: 89.43
|
||||||
|
Teacher:
|
||||||
|
Config: mmcls::resnet/resnet50_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 76.55
|
||||||
|
Top 5 Accuracy: 93.06
|
||||||
|
Config: configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.pth
|
||||||
|
- Name: abloss_logits_resnet50_resnet18_8xb32_in1k
|
||||||
|
In Collection: ABLoss
|
||||||
|
Metadata:
|
||||||
|
Location: logits
|
||||||
|
Student:
|
||||||
|
Config: mmcls::resnet/resnet18_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 69.90
|
||||||
|
Top 5 Accuracy: 89.43
|
||||||
|
Teacher:
|
||||||
|
Config: mmcls::resnet/resnet50_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 76.55
|
||||||
|
Top 5 Accuracy: 93.06
|
||||||
|
Results:
|
||||||
|
- Task: Image Classification
|
||||||
|
Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 69.94
|
||||||
|
Config: configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.pth
|
|
@ -14,9 +14,9 @@ Learning portable neural networks is very essential for computer vision for the
|
||||||
|
|
||||||
### Classification
|
### Classification
|
||||||
|
|
||||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||||
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
|
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.11 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
|
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.27 | 95.34 | 94.82 | [config](./dafl_logits_resnet34_resnet18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.json) |
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
@ -36,7 +36,3 @@ Learning portable neural networks is very essential for computer vision for the
|
||||||
biburl = {https://dblp.org/rec/conf/iccv/ChenW0YLSXX019.bib},
|
biburl = {https://dblp.org/rec/conf/iccv/ChenW0YLSXX019.bib},
|
||||||
bibsource = {dblp computer science bibliography, https://dblp.org}
|
bibsource = {dblp computer science bibliography, https://dblp.org}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Acknowledgement
|
|
||||||
|
|
||||||
Shout out to Davidgzx.
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ _base_ = [
|
||||||
'mmcls::_base_/default_runtime.py'
|
'mmcls::_base_/default_runtime.py'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501
|
||||||
model = dict(
|
model = dict(
|
||||||
_scope_='mmrazor',
|
_scope_='mmrazor',
|
||||||
type='DAFLDataFreeDistillation',
|
type='DAFLDataFreeDistillation',
|
||||||
|
@ -21,7 +22,7 @@ model = dict(
|
||||||
build_cfg=dict(
|
build_cfg=dict(
|
||||||
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
|
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
|
||||||
pretrained=True),
|
pretrained=True),
|
||||||
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')),
|
ckpt_path=res34_ckpt_path)),
|
||||||
generator=dict(
|
generator=dict(
|
||||||
type='DAFLGenerator',
|
type='DAFLGenerator',
|
||||||
img_size=32,
|
img_size=32,
|
|
@ -0,0 +1,43 @@
|
||||||
|
Collections:
|
||||||
|
- Name: DAFL
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- CIFAR-10
|
||||||
|
Paper:
|
||||||
|
URL: https://doi.org/10.1109/ICCV.2019.00361
|
||||||
|
Title: Data-Free Learning of Student Networks
|
||||||
|
README: configs/distill/mmcls/dafl/README.md
|
||||||
|
Converted From:
|
||||||
|
Code:
|
||||||
|
URL: https://github.com/huawei-noah/Efficient-Computing/tree/master/Data-Efficient-Model-Compression/DAFL
|
||||||
|
Models:
|
||||||
|
- Name: dafl_logits_resnet34_resnet18_8xb256_cifar10
|
||||||
|
In Collection: DAFL
|
||||||
|
Metadata:
|
||||||
|
inference time (ms/im):
|
||||||
|
- value: 0.34
|
||||||
|
hardware: NVIDIA A100-SXM4-80GB
|
||||||
|
backend: PyTorch
|
||||||
|
batch size: 256
|
||||||
|
mode: FP32
|
||||||
|
resolution: (32, 32)
|
||||||
|
Location: logits
|
||||||
|
Student:
|
||||||
|
Config: mmcls::resnet/resnet18_8xb16_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 94.82
|
||||||
|
Top 5 Accuracy: 99.87
|
||||||
|
Teacher:
|
||||||
|
Config: mmcls::resnet/resnet34_8xb16_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 95.34
|
||||||
|
Top 5 Accuracy: 99.87
|
||||||
|
Results:
|
||||||
|
- Task: Image Classification
|
||||||
|
Dataset: CIFAR-10
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 93.27
|
||||||
|
Config: configs/distill/mmcls/dafl/dafl_logits_resnet34_resnet18_8xb256_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.pth
|
|
@ -14,9 +14,9 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a
|
||||||
|
|
||||||
### Classification
|
### Classification
|
||||||
|
|
||||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||||
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
|
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :--------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dfad_logits_r34_r18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
|
| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 92.80 | 95.34 | 94.82 | [config](./dfad_logits_resnet34_resnet18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.json) |
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
@ -28,7 +28,3 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a
|
||||||
year={2019}
|
year={2019}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Acknowledgement
|
|
||||||
|
|
||||||
Appreciate Davidgzx's contribution.
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ _base_ = [
|
||||||
'mmcls::_base_/default_runtime.py'
|
'mmcls::_base_/default_runtime.py'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501
|
||||||
model = dict(
|
model = dict(
|
||||||
_scope_='mmrazor',
|
_scope_='mmrazor',
|
||||||
type='DataFreeDistillation',
|
type='DataFreeDistillation',
|
||||||
|
@ -21,7 +22,7 @@ model = dict(
|
||||||
build_cfg=dict(
|
build_cfg=dict(
|
||||||
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
|
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
|
||||||
pretrained=True),
|
pretrained=True),
|
||||||
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')),
|
ckpt_path=res34_ckpt_path)),
|
||||||
generator=dict(
|
generator=dict(
|
||||||
type='DAFLGenerator',
|
type='DAFLGenerator',
|
||||||
img_size=32,
|
img_size=32,
|
|
@ -0,0 +1,43 @@
|
||||||
|
Collections:
|
||||||
|
- Name: DFAD
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- CIFAR-10
|
||||||
|
Paper:
|
||||||
|
URL: https://arxiv.org/pdf/1912.11006.pdf
|
||||||
|
Title: Data-Free Adversarial Distillation
|
||||||
|
README: configs/distill/mmcls/dfad/README.md
|
||||||
|
Converted From:
|
||||||
|
Code:
|
||||||
|
URL: https://github.com/VainF/Data-Free-Adversarial-Distillation
|
||||||
|
Models:
|
||||||
|
- Name: dfad_logits_resnet34_resnet18_8xb32_cifar10
|
||||||
|
In Collection: DFAD
|
||||||
|
Metadata:
|
||||||
|
inference time (ms/im):
|
||||||
|
- value: 0.38
|
||||||
|
hardware: NVIDIA A100-SXM4-80GB
|
||||||
|
backend: PyTorch
|
||||||
|
batch size: 32
|
||||||
|
mode: FP32
|
||||||
|
resolution: (32, 32)
|
||||||
|
Location: logits
|
||||||
|
Student:
|
||||||
|
Config: mmcls::resnet/resnet18_8xb16_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 94.82
|
||||||
|
Top 5 Accuracy: 99.87
|
||||||
|
Teacher:
|
||||||
|
Config: mmcls::resnet/resnet34_8xb16_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 95.34
|
||||||
|
Top 5 Accuracy: 99.87
|
||||||
|
Results:
|
||||||
|
- Task: Image Classification
|
||||||
|
Dataset: CIFAR-10
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 92.80
|
||||||
|
Config: configs/distill/mmcls/dfad/dfad_logits_resnet34_resnet18_8xb32_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.pth
|
|
@ -26,9 +26,9 @@ almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher
|
||||||
|
|
||||||
### Classification
|
### Classification
|
||||||
|
|
||||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||||
| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- |
|
| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) |
|
| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.58 | 76.55 | 69.90 | [config](./fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.json) |
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
|
@ -4,6 +4,7 @@ _base_ = [
|
||||||
'mmcls::_base_/default_runtime.py'
|
'mmcls::_base_/default_runtime.py'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501
|
||||||
model = dict(
|
model = dict(
|
||||||
_scope_='mmrazor',
|
_scope_='mmrazor',
|
||||||
type='SingleTeacherDistill',
|
type='SingleTeacherDistill',
|
||||||
|
@ -18,7 +19,7 @@ model = dict(
|
||||||
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
|
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
|
||||||
teacher=dict(
|
teacher=dict(
|
||||||
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
|
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
|
||||||
teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth',
|
teacher_ckpt=teacher_ckpt,
|
||||||
distiller=dict(
|
distiller=dict(
|
||||||
type='ConfigurableDistiller',
|
type='ConfigurableDistiller',
|
||||||
student_recorders=dict(
|
student_recorders=dict(
|
|
@ -0,0 +1,40 @@
|
||||||
|
Collections:
|
||||||
|
- Name: FitNets
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- ImageNet-1k
|
||||||
|
Paper:
|
||||||
|
URL: https://arxiv.org/abs/1412.6550
|
||||||
|
Title: 'FitNets: Hints for Thin Deep Nets'
|
||||||
|
README: configs/distill/mmcls/fitnet/README.md
|
||||||
|
Models:
|
||||||
|
- Name: fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k
|
||||||
|
In Collection: FitNets
|
||||||
|
Metadata:
|
||||||
|
inference time (ms/im):
|
||||||
|
- value: 0.18
|
||||||
|
hardware: NVIDIA A100-SXM4-80GB
|
||||||
|
backend: PyTorch
|
||||||
|
batch size: 32
|
||||||
|
mode: FP32
|
||||||
|
resolution: (224, 224)
|
||||||
|
Location: backbone & logits
|
||||||
|
Student:
|
||||||
|
Config: mmcls::resnet/resnet18_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 69.90
|
||||||
|
Top 5 Accuracy: 89.43
|
||||||
|
Teacher:
|
||||||
|
Config: mmcls::resnet/resnet50_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 76.55
|
||||||
|
Top 5 Accuracy: 93.06
|
||||||
|
Results:
|
||||||
|
- Task: Image Classification
|
||||||
|
Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 70.58
|
||||||
|
Config: configs/distill/mmcls/fitnet/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py
|
||||||
|
Weights: https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.pth
|
|
@ -20,9 +20,9 @@ Performing knowledge transfer from a large teacher network to a smaller student
|
||||||
|
|
||||||
### Classification
|
### Classification
|
||||||
|
|
||||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||||
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
|
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./zskt_backbone_logits_r34_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
|
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.05 | 95.34 | 94.82 | [config](./zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.json) |
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
@ -35,7 +35,3 @@ Performing knowledge transfer from a large teacher network to a smaller student
|
||||||
year={2019}
|
year={2019}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Acknowledgement
|
|
||||||
|
|
||||||
Appreciate Davidgzx's contribution.
|
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
Collections:
|
||||||
|
- Name: ZSKT
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- CIFAR-10
|
||||||
|
Paper:
|
||||||
|
URL: https://arxiv.org/abs/1905.09768
|
||||||
|
Title: Zero-shot Knowledge Transfer via Adversarial Belief Matching
|
||||||
|
README: configs/distill/mmcls/zskt/README.md
|
||||||
|
Converted From:
|
||||||
|
Code:
|
||||||
|
URL: https://github.com/polo5/ZeroShotKnowledgeTransfer
|
||||||
|
Models:
|
||||||
|
- Name: zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10
|
||||||
|
In Collection: ZSKT
|
||||||
|
Metadata:
|
||||||
|
inference time (ms/im):
|
||||||
|
- value: 0.12
|
||||||
|
hardware: NVIDIA A100-SXM4-80GB
|
||||||
|
backend: PyTorch
|
||||||
|
batch size: 16
|
||||||
|
mode: FP32
|
||||||
|
resolution: (32, 32)
|
||||||
|
Location: backbone & logits
|
||||||
|
Student:
|
||||||
|
Config: mmcls::resnet/resnet18_8xb16_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 94.82
|
||||||
|
Top 5 Accuracy: 99.87
|
||||||
|
Teacher:
|
||||||
|
Config: mmcls::resnet/resnet34_8xb16_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 95.34
|
||||||
|
Top 5 Accuracy: 99.87
|
||||||
|
Results:
|
||||||
|
- Task: Image Classification
|
||||||
|
Dataset: CIFAR-10
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 93.05
|
||||||
|
Config: configs/distill/mmcls/zskt/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py
|
||||||
|
Weights: https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.pth
|
|
@ -4,6 +4,7 @@ _base_ = [
|
||||||
'mmcls::_base_/default_runtime.py'
|
'mmcls::_base_/default_runtime.py'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501
|
||||||
model = dict(
|
model = dict(
|
||||||
_scope_='mmrazor',
|
_scope_='mmrazor',
|
||||||
type='DataFreeDistillation',
|
type='DataFreeDistillation',
|
||||||
|
@ -17,11 +18,11 @@ model = dict(
|
||||||
architecture=dict(
|
architecture=dict(
|
||||||
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
|
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
|
||||||
teachers=dict(
|
teachers=dict(
|
||||||
r34=dict(
|
res34=dict(
|
||||||
build_cfg=dict(
|
build_cfg=dict(
|
||||||
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
|
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
|
||||||
pretrained=True),
|
pretrained=True),
|
||||||
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')),
|
ckpt_path=res34_ckpt_path)),
|
||||||
generator=dict(
|
generator=dict(
|
||||||
type='ZSKTGenerator', img_size=32, latent_dim=256,
|
type='ZSKTGenerator', img_size=32, latent_dim=256,
|
||||||
hidden_channels=128),
|
hidden_channels=128),
|
||||||
|
@ -34,15 +35,15 @@ model = dict(
|
||||||
bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
|
bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
|
||||||
fc=dict(type='ModuleOutputs', source='head.fc')),
|
fc=dict(type='ModuleOutputs', source='head.fc')),
|
||||||
teacher_recorders=dict(
|
teacher_recorders=dict(
|
||||||
r34_bb_s1=dict(
|
res34_bb_s1=dict(
|
||||||
type='ModuleOutputs', source='r34.backbone.layer1.2.relu'),
|
type='ModuleOutputs', source='res34.backbone.layer1.2.relu'),
|
||||||
r34_bb_s2=dict(
|
res34_bb_s2=dict(
|
||||||
type='ModuleOutputs', source='r34.backbone.layer2.3.relu'),
|
type='ModuleOutputs', source='res34.backbone.layer2.3.relu'),
|
||||||
r34_bb_s3=dict(
|
res34_bb_s3=dict(
|
||||||
type='ModuleOutputs', source='r34.backbone.layer3.5.relu'),
|
type='ModuleOutputs', source='res34.backbone.layer3.5.relu'),
|
||||||
r34_bb_s4=dict(
|
res34_bb_s4=dict(
|
||||||
type='ModuleOutputs', source='r34.backbone.layer4.2.relu'),
|
type='ModuleOutputs', source='res34.backbone.layer4.2.relu'),
|
||||||
r34_fc=dict(type='ModuleOutputs', source='r34.head.fc')),
|
res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
|
||||||
distill_losses=dict(
|
distill_losses=dict(
|
||||||
loss_s1=dict(type='ATLoss', loss_weight=250.0),
|
loss_s1=dict(type='ATLoss', loss_weight=250.0),
|
||||||
loss_s2=dict(type='ATLoss', loss_weight=250.0),
|
loss_s2=dict(type='ATLoss', loss_weight=250.0),
|
||||||
|
@ -55,31 +56,31 @@ model = dict(
|
||||||
s_feature=dict(
|
s_feature=dict(
|
||||||
from_student=True, recorder='bb_s1', record_idx=1),
|
from_student=True, recorder='bb_s1', record_idx=1),
|
||||||
t_feature=dict(
|
t_feature=dict(
|
||||||
from_student=False, recorder='r34_bb_s1', record_idx=1)),
|
from_student=False, recorder='res34_bb_s1', record_idx=1)),
|
||||||
loss_s2=dict(
|
loss_s2=dict(
|
||||||
s_feature=dict(
|
s_feature=dict(
|
||||||
from_student=True, recorder='bb_s2', record_idx=1),
|
from_student=True, recorder='bb_s2', record_idx=1),
|
||||||
t_feature=dict(
|
t_feature=dict(
|
||||||
from_student=False, recorder='r34_bb_s2', record_idx=1)),
|
from_student=False, recorder='res34_bb_s2', record_idx=1)),
|
||||||
loss_s3=dict(
|
loss_s3=dict(
|
||||||
s_feature=dict(
|
s_feature=dict(
|
||||||
from_student=True, recorder='bb_s3', record_idx=1),
|
from_student=True, recorder='bb_s3', record_idx=1),
|
||||||
t_feature=dict(
|
t_feature=dict(
|
||||||
from_student=False, recorder='r34_bb_s3', record_idx=1)),
|
from_student=False, recorder='res34_bb_s3', record_idx=1)),
|
||||||
loss_s4=dict(
|
loss_s4=dict(
|
||||||
s_feature=dict(
|
s_feature=dict(
|
||||||
from_student=True, recorder='bb_s4', record_idx=1),
|
from_student=True, recorder='bb_s4', record_idx=1),
|
||||||
t_feature=dict(
|
t_feature=dict(
|
||||||
from_student=False, recorder='r34_bb_s4', record_idx=1)),
|
from_student=False, recorder='res34_bb_s4', record_idx=1)),
|
||||||
loss_kl=dict(
|
loss_kl=dict(
|
||||||
preds_S=dict(from_student=True, recorder='fc'),
|
preds_S=dict(from_student=True, recorder='fc'),
|
||||||
preds_T=dict(from_student=False, recorder='r34_fc')))),
|
preds_T=dict(from_student=False, recorder='res34_fc')))),
|
||||||
generator_distiller=dict(
|
generator_distiller=dict(
|
||||||
type='ConfigurableDistiller',
|
type='ConfigurableDistiller',
|
||||||
student_recorders=dict(
|
student_recorders=dict(
|
||||||
fc=dict(type='ModuleOutputs', source='head.fc')),
|
fc=dict(type='ModuleOutputs', source='head.fc')),
|
||||||
teacher_recorders=dict(
|
teacher_recorders=dict(
|
||||||
r34_fc=dict(type='ModuleOutputs', source='r34.head.fc')),
|
res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
|
||||||
distill_losses=dict(
|
distill_losses=dict(
|
||||||
loss_kl=dict(
|
loss_kl=dict(
|
||||||
type='KLDivergence',
|
type='KLDivergence',
|
||||||
|
@ -89,7 +90,7 @@ model = dict(
|
||||||
loss_forward_mappings=dict(
|
loss_forward_mappings=dict(
|
||||||
loss_kl=dict(
|
loss_kl=dict(
|
||||||
preds_S=dict(from_student=True, recorder='fc'),
|
preds_S=dict(from_student=True, recorder='fc'),
|
||||||
preds_T=dict(from_student=False, recorder='r34_fc')))),
|
preds_T=dict(from_student=False, recorder='res34_fc')))),
|
||||||
student_iter=10)
|
student_iter=10)
|
||||||
|
|
||||||
# model wrapper
|
# model wrapper
|
|
@ -14,9 +14,9 @@ Knowledge distillation, in which a student model is trained to mimic a teacher m
|
||||||
|
|
||||||
### Detection
|
### Detection
|
||||||
|
|
||||||
| Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download |
|
| Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download |
|
||||||
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :--------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| :------: | :-----: | :--------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| neck | COCO | [fasterrcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [fasterrcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.1 | 39.4 | 37.8 | [config](./fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](<>) \| [log](<>) |
|
| neck | COCO | [faster-rcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [faster-rcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.3 | 39.4 | 37.4 | [config](./fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.json) |
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
|
|
@ -4,16 +4,17 @@ _base_ = [
|
||||||
'mmdet::_base_/default_runtime.py'
|
'mmdet::_base_/default_runtime.py'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth' # noqa: E501
|
||||||
model = dict(
|
model = dict(
|
||||||
_scope_='mmrazor',
|
_scope_='mmrazor',
|
||||||
type='SingleTeacherDistill',
|
type='SingleTeacherDistill',
|
||||||
architecture=dict(
|
architecture=dict(
|
||||||
cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
|
cfg_path='mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py',
|
||||||
pretrained=True),
|
pretrained=True),
|
||||||
teacher=dict(
|
teacher=dict(
|
||||||
cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py',
|
cfg_path='mmdet::faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py',
|
||||||
pretrained=False),
|
pretrained=False),
|
||||||
teacher_ckpt='faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth',
|
teacher_ckpt=teacher_ckpt,
|
||||||
distiller=dict(
|
distiller=dict(
|
||||||
type='ConfigurableDistiller',
|
type='ConfigurableDistiller',
|
||||||
student_recorders=dict(
|
student_recorders=dict(
|
|
@ -0,0 +1,41 @@
|
||||||
|
Collections:
|
||||||
|
- Name: FBKD
|
||||||
|
Metadata:
|
||||||
|
Training Data:
|
||||||
|
- COCO
|
||||||
|
Paper:
|
||||||
|
URL: https://openreview.net/pdf?id=uKhGRvM8QNH
|
||||||
|
Title: IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION- TOWARDS ACCURATE AND EFFICIENT DETECTORS
|
||||||
|
README: configs/distill/mmdet/fbkd/README.md
|
||||||
|
Converted From:
|
||||||
|
Code:
|
||||||
|
URL: https://github.com/ArchipLab-LinfengZhang/Object-Detection-Knowledge-Distillation-ICLR2021
|
||||||
|
Models:
|
||||||
|
- Name: fbkd_fpn_faster-rcnn_resnet101_faster-rcnn_resnet50_1x_coco
|
||||||
|
In Collection: FBKD
|
||||||
|
Metadata:
|
||||||
|
inference time (ms/im):
|
||||||
|
- value: 0.32
|
||||||
|
hardware: NVIDIA A100-SXM4-80GB
|
||||||
|
backend: PyTorch
|
||||||
|
batch size: 2
|
||||||
|
mode: FP32
|
||||||
|
resolution: (1333, 800)
|
||||||
|
Location: fpn
|
||||||
|
Student:
|
||||||
|
Metrics:
|
||||||
|
box AP: 37.4
|
||||||
|
Config: mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py
|
||||||
|
Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
|
||||||
|
Teacher:
|
||||||
|
Metrics:
|
||||||
|
box AP: 39.4
|
||||||
|
Config: mmdet::faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py
|
||||||
|
Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth
|
||||||
|
Results:
|
||||||
|
- Task: Object Detection
|
||||||
|
Dataset: COCO
|
||||||
|
Metrics:
|
||||||
|
box AP: 39.3
|
||||||
|
Config: configs/distill/mmdet/fbkd/fbkd_fpn_faster-rcnn_resnet101_faster-rcnn_resnet50_1x_coco.py
|
||||||
|
Weights: https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.pth
|
|
@ -78,12 +78,12 @@ class DataFreeDistillation(BaseAlgorithm):
|
||||||
"""Alias for ``architecture``."""
|
"""Alias for ``architecture``."""
|
||||||
return self.architecture
|
return self.architecture
|
||||||
|
|
||||||
def train_step(self, data: List[dict],
|
def train_step(self, data: Dict[str, List[dict]],
|
||||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||||
"""Train step for DataFreeDistillation.
|
"""Train step for DataFreeDistillation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (List[dict]): Data sampled by dataloader.
|
data (Dict[str, List[dict]]): Data sampled by dataloader.
|
||||||
optim_wrapper (OptimWrapper): A wrapper of optimizer to
|
optim_wrapper (OptimWrapper): A wrapper of optimizer to
|
||||||
update parameters.
|
update parameters.
|
||||||
"""
|
"""
|
||||||
|
@ -107,16 +107,16 @@ class DataFreeDistillation(BaseAlgorithm):
|
||||||
return log_vars
|
return log_vars
|
||||||
|
|
||||||
def train_student(
|
def train_student(
|
||||||
self, data: List[dict], optimizer: OPTIMIZERS
|
self, data: Dict[str, List[dict]], optimizer: OPTIMIZERS
|
||||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||||
"""Train step for the student model.
|
"""Train step for the student model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (List[dict]): Data sampled by dataloader.
|
data (Dict[str, List[dict]]): Data sampled by dataloader.
|
||||||
optimizer (OPTIMIZERS): The optimizer to update student.
|
optimizer (OPTIMIZERS): The optimizer to update student.
|
||||||
"""
|
"""
|
||||||
log_vars = dict()
|
log_vars = dict()
|
||||||
batch_size = len(data)
|
batch_size = len(data['inputs'])
|
||||||
|
|
||||||
for _ in range(self.student_iter):
|
for _ in range(self.student_iter):
|
||||||
fakeimg_init = torch.randn(
|
fakeimg_init = torch.randn(
|
||||||
|
@ -124,13 +124,14 @@ class DataFreeDistillation(BaseAlgorithm):
|
||||||
fakeimg = self.generator(fakeimg_init, batch_size).detach()
|
fakeimg = self.generator(fakeimg_init, batch_size).detach()
|
||||||
|
|
||||||
with optimizer.optim_context(self.student):
|
with optimizer.optim_context(self.student):
|
||||||
_, data_samples = self.data_preprocessor(data, True)
|
pseudo_data = self.data_preprocessor(data, True)
|
||||||
|
pseudo_data_samples = pseudo_data['data_samples']
|
||||||
# record the needed information
|
# record the needed information
|
||||||
with self.distiller.student_recorders:
|
with self.distiller.student_recorders:
|
||||||
_ = self.student(fakeimg, data_samples, mode='loss')
|
_ = self.student(fakeimg, pseudo_data_samples, mode='loss')
|
||||||
with self.distiller.teacher_recorders, torch.no_grad():
|
with self.distiller.teacher_recorders, torch.no_grad():
|
||||||
for _, teacher in self.teachers.items():
|
for _, teacher in self.teachers.items():
|
||||||
_ = teacher(fakeimg, data_samples, mode='loss')
|
_ = teacher(fakeimg, pseudo_data_samples, mode='loss')
|
||||||
loss_distill = self.distiller.compute_distill_losses()
|
loss_distill = self.distiller.compute_distill_losses()
|
||||||
|
|
||||||
distill_loss, distill_log_vars = self.parse_losses(loss_distill)
|
distill_loss, distill_log_vars = self.parse_losses(loss_distill)
|
||||||
|
@ -140,15 +141,15 @@ class DataFreeDistillation(BaseAlgorithm):
|
||||||
return distill_loss, log_vars
|
return distill_loss, log_vars
|
||||||
|
|
||||||
def train_generator(
|
def train_generator(
|
||||||
self, data: List[dict], optimizer: OPTIMIZERS
|
self, data: Dict[str, List[dict]], optimizer: OPTIMIZERS
|
||||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||||
"""Train step for the generator.
|
"""Train step for the generator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (List[dict]): Data sampled by dataloader.
|
data (Dict[str, List[dict]]): Data sampled by dataloader.
|
||||||
optimizer (OPTIMIZERS): The optimizer to update generator.
|
optimizer (OPTIMIZERS): The optimizer to update generator.
|
||||||
"""
|
"""
|
||||||
batch_size = len(data)
|
batch_size = len(data['inputs'])
|
||||||
fakeimg_init = torch.randn(
|
fakeimg_init = torch.randn(
|
||||||
(batch_size, self.generator.module.latent_dim))
|
(batch_size, self.generator.module.latent_dim))
|
||||||
fakeimg = self.generator(fakeimg_init, batch_size)
|
fakeimg = self.generator(fakeimg_init, batch_size)
|
||||||
|
@ -174,17 +175,17 @@ class DataFreeDistillation(BaseAlgorithm):
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class DAFLDataFreeDistillation(DataFreeDistillation):
|
class DAFLDataFreeDistillation(DataFreeDistillation):
|
||||||
|
|
||||||
def train_step(self, data: List[dict],
|
def train_step(self, data: Dict[str, List[dict]],
|
||||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||||
"""DAFL train step.
|
"""DAFL train step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data (List[dict]): Data sampled by dataloader.
|
data (Dict[str, List[dict]]): Data sampled by dataloader.
|
||||||
optim_wrapper (OptimWrapper): A wrapper of optimizer to
|
optim_wrapper (OptimWrapper): A wrapper of optimizer to
|
||||||
update parameters.
|
update parameters.
|
||||||
"""
|
"""
|
||||||
log_vars = dict()
|
log_vars = dict()
|
||||||
batch_size = len(data)
|
batch_size = len(data['inputs'])
|
||||||
|
|
||||||
for _, teacher in self.teachers.items():
|
for _, teacher in self.teachers.items():
|
||||||
teacher.eval()
|
teacher.eval()
|
||||||
|
|
|
@ -107,7 +107,6 @@ class TestLosses(TestCase):
|
||||||
ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=True)
|
ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=True)
|
||||||
ie_loss.world_size = 2
|
ie_loss.world_size = 2
|
||||||
|
|
||||||
# TODO: configure circle CI to test UT under multi torch versions.
|
|
||||||
if digit_version(torch.__version__) >= digit_version('1.8.0'):
|
if digit_version(torch.__version__) >= digit_version('1.8.0'):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
|
|
Loading…
Reference in New Issue