[Docs] Add metafiles for KD algos (ABLoss, DAFL, DFAD, FBKD, FitNets, ZSKT) (#266)

1.Add metafiles for 6 kd algos.
2.Add model and log links. 
3.Revise data_samples in datafreedistillation for new feature of mmengine.
pull/270/head
zhongyu zhang 2022-08-31 22:12:25 +08:00 committed by GitHub
parent 24e106ba1d
commit f69aeabc69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 349 additions and 88 deletions

View File

@ -14,9 +14,10 @@ An activation boundary for a neuron refers to a separating hyperplane that deter
### Classification ### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | | Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :----------------------------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :--------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- | | :-----------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone (pretrain) & logits (train) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.58 | 76.55 | 69.90 | [pretrain_config](./abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain) [train_config](./abloss_head_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) | | backbone (pretrain) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | | 76.55 | 69.90 | [pretrain_config](./abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.json) |
| logits (train) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 69.94 | 76.55 | 69.90 | [train_config](./abloss_logits_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.json) |
## Citation ## Citation
@ -43,31 +44,26 @@ An activation boundary for a neuron refers to a separating hyperplane that deter
## Getting Started ## Getting Started
### ABConnectors and Student pre-training. ### Pre-training.
```bash ```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \ sh tools/dist_train.sh configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py 8
configs/distill/mmcls/abloss/abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py\
$PRETRAIN_WORK_DIR
``` ```
### Modify Distillation training config ### Modify Distillation training config
open file 'configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py' open file 'configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py'
```python ```python
# modify init_cfg in model settings # Modify init_cfg in model settings.
# pretrain_work_dir is same as the PRETRAIN_WORK_DIR in pre-training. # 'pretrain_work_dir' is same as the 'work_dir of pre-training'.
# 'last_epoch' defaults to 'epoch_20' in ABLoss.
init_cfg=dict( init_cfg=dict(
type='Pretrained', checkpoint='pretrain_work_dir/last_chechpoint.pth'), type='Pretrained', checkpoint='pretrain_work_dir/last_epoch.pth'),
``` ```
### Distillation training. ### Distillation training.
```bash ```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \ sh tools/dist_train.sh configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py 8
configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py\
$DISTILLATION_WORK_DIR
``` ```

View File

@ -4,6 +4,8 @@ _base_ = [
'mmcls::_base_/default_runtime.py' 'mmcls::_base_/default_runtime.py'
] ]
# Modify pretrain_checkpoint before training.
pretrain_checkpoint = 'work_dir_of_abloss_pretrain/last_epoch.pth'
model = dict( model = dict(
_scope_='mmrazor', _scope_='mmrazor',
type='SingleTeacherDistill', type='SingleTeacherDistill',
@ -18,8 +20,7 @@ model = dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict( teacher=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False), cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
init_cfg=dict( init_cfg=dict(type='Pretrained', checkpoint=pretrain_checkpoint),
type='Pretrained', checkpoint='pretrain_work_dir/last_chechpoint.pth'),
distiller=dict( distiller=dict(
type='ConfigurableDistiller', type='ConfigurableDistiller',
student_recorders=dict( student_recorders=dict(
@ -28,7 +29,7 @@ model = dict(
fc=dict(type='ModuleOutputs', source='head.fc')), fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict( distill_losses=dict(
loss_kl=dict( loss_kl=dict(
type='KLDivergence', loss_weight=200, reduction='mean')), type='KLDivergence', loss_weight=6.25, reduction='mean')),
loss_forward_mappings=dict( loss_forward_mappings=dict(
loss_kl=dict( loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'), preds_S=dict(from_student=True, recorder='fc'),

View File

@ -4,8 +4,7 @@ _base_ = [
'mmcls::_base_/default_runtime.py' 'mmcls::_base_/default_runtime.py'
] ]
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1) teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501
model = dict( model = dict(
_scope_='mmrazor', _scope_='mmrazor',
type='SingleTeacherDistill', type='SingleTeacherDistill',
@ -20,7 +19,7 @@ model = dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict( teacher=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True), cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth', teacher_ckpt=teacher_ckpt,
calculate_student_loss=False, calculate_student_loss=False,
distiller=dict( distiller=dict(
type='ConfigurableDistiller', type='ConfigurableDistiller',
@ -94,4 +93,5 @@ model = dict(
find_unused_parameters = True find_unused_parameters = True
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

View File

@ -0,0 +1,61 @@
Collections:
- Name: ABLoss
Metadata:
Training Data:
- ImageNet-1k
Paper:
URL: https://arxiv.org/pdf/1811.03233.pdf
Title: Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
README: configs/distill/mmcls/abloss/README.md
Converted From:
Code:
URL: https://github.com/bhheo/AB_distillation
Models:
- Name: abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k
In Collection: ABLoss
Metadata:
inference time (ms/im):
- value: 0.21
hardware: NVIDIA A100-SXM4-80GB
backend: PyTorch
batch size: 32
mode: FP32
resolution: (224, 224)
Location: backbone
Student:
Config: mmcls::resnet/resnet18_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
Metrics:
Top 1 Accuracy: 69.90
Top 5 Accuracy: 89.43
Teacher:
Config: mmcls::resnet/resnet50_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
Metrics:
Top 1 Accuracy: 76.55
Top 5 Accuracy: 93.06
Config: configs/distill/mmcls/abloss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k_20220830_165724-a6284e9f.pth
- Name: abloss_logits_resnet50_resnet18_8xb32_in1k
In Collection: ABLoss
Metadata:
Location: logits
Student:
Config: mmcls::resnet/resnet18_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
Metrics:
Top 1 Accuracy: 69.90
Top 5 Accuracy: 89.43
Teacher:
Config: mmcls::resnet/resnet50_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
Metrics:
Top 1 Accuracy: 76.55
Top 5 Accuracy: 93.06
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 69.94
Config: configs/distill/mmcls/abloss/abloss_logits_resnet50_resnet18_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmrazor/v1/ABLoss/abloss_logits_resnet50_resnet18_8xb32_in1k_20220830_202129-f35edde8.pth

View File

@ -14,9 +14,9 @@ Learning portable neural networks is very essential for computer vision for the
### Classification ### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | | Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- | | :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.11 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) | | backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.27 | 95.34 | 94.82 | [config](./dafl_logits_resnet34_resnet18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.json) |
## Citation ## Citation
@ -36,7 +36,3 @@ Learning portable neural networks is very essential for computer vision for the
biburl = {https://dblp.org/rec/conf/iccv/ChenW0YLSXX019.bib}, biburl = {https://dblp.org/rec/conf/iccv/ChenW0YLSXX019.bib},
bibsource = {dblp computer science bibliography, https://dblp.org} bibsource = {dblp computer science bibliography, https://dblp.org}
``` ```
## Acknowledgement
Shout out to Davidgzx.

View File

@ -4,6 +4,7 @@ _base_ = [
'mmcls::_base_/default_runtime.py' 'mmcls::_base_/default_runtime.py'
] ]
res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501
model = dict( model = dict(
_scope_='mmrazor', _scope_='mmrazor',
type='DAFLDataFreeDistillation', type='DAFLDataFreeDistillation',
@ -21,7 +22,7 @@ model = dict(
build_cfg=dict( build_cfg=dict(
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
pretrained=True), pretrained=True),
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')), ckpt_path=res34_ckpt_path)),
generator=dict( generator=dict(
type='DAFLGenerator', type='DAFLGenerator',
img_size=32, img_size=32,

View File

@ -0,0 +1,43 @@
Collections:
- Name: DAFL
Metadata:
Training Data:
- CIFAR-10
Paper:
URL: https://doi.org/10.1109/ICCV.2019.00361
Title: Data-Free Learning of Student Networks
README: configs/distill/mmcls/dafl/README.md
Converted From:
Code:
URL: https://github.com/huawei-noah/Efficient-Computing/tree/master/Data-Efficient-Model-Compression/DAFL
Models:
- Name: dafl_logits_resnet34_resnet18_8xb256_cifar10
In Collection: DAFL
Metadata:
inference time (ms/im):
- value: 0.34
hardware: NVIDIA A100-SXM4-80GB
backend: PyTorch
batch size: 256
mode: FP32
resolution: (32, 32)
Location: logits
Student:
Config: mmcls::resnet/resnet18_8xb16_cifar10.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth
Metrics:
Top 1 Accuracy: 94.82
Top 5 Accuracy: 99.87
Teacher:
Config: mmcls::resnet/resnet34_8xb16_cifar10.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth
Metrics:
Top 1 Accuracy: 95.34
Top 5 Accuracy: 99.87
Results:
- Task: Image Classification
Dataset: CIFAR-10
Metrics:
Top 1 Accuracy: 93.27
Config: configs/distill/mmcls/dafl/dafl_logits_resnet34_resnet18_8xb256_cifar10.py
Weights: https://download.openmmlab.com/mmrazor/v1/DAFL/dafl_logits_resnet34_resnet18_8xb256_cifar10_20220815_202654-67142167.pth

View File

@ -14,9 +14,9 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a
### Classification ### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | | Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- | | :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :--------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dfad_logits_r34_r18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) | | logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 92.80 | 95.34 | 94.82 | [config](./dfad_logits_resnet34_resnet18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.json) |
## Citation ## Citation
@ -28,7 +28,3 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a
year={2019} year={2019}
} }
``` ```
## Acknowledgement
Appreciate Davidgzx's contribution.

View File

@ -4,6 +4,7 @@ _base_ = [
'mmcls::_base_/default_runtime.py' 'mmcls::_base_/default_runtime.py'
] ]
res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501
model = dict( model = dict(
_scope_='mmrazor', _scope_='mmrazor',
type='DataFreeDistillation', type='DataFreeDistillation',
@ -21,7 +22,7 @@ model = dict(
build_cfg=dict( build_cfg=dict(
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
pretrained=True), pretrained=True),
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')), ckpt_path=res34_ckpt_path)),
generator=dict( generator=dict(
type='DAFLGenerator', type='DAFLGenerator',
img_size=32, img_size=32,

View File

@ -0,0 +1,43 @@
Collections:
- Name: DFAD
Metadata:
Training Data:
- CIFAR-10
Paper:
URL: https://arxiv.org/pdf/1912.11006.pdf
Title: Data-Free Adversarial Distillation
README: configs/distill/mmcls/dfad/README.md
Converted From:
Code:
URL: https://github.com/VainF/Data-Free-Adversarial-Distillation
Models:
- Name: dfad_logits_resnet34_resnet18_8xb32_cifar10
In Collection: DFAD
Metadata:
inference time (ms/im):
- value: 0.38
hardware: NVIDIA A100-SXM4-80GB
backend: PyTorch
batch size: 32
mode: FP32
resolution: (32, 32)
Location: logits
Student:
Config: mmcls::resnet/resnet18_8xb16_cifar10.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth
Metrics:
Top 1 Accuracy: 94.82
Top 5 Accuracy: 99.87
Teacher:
Config: mmcls::resnet/resnet34_8xb16_cifar10.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth
Metrics:
Top 1 Accuracy: 95.34
Top 5 Accuracy: 99.87
Results:
- Task: Image Classification
Dataset: CIFAR-10
Metrics:
Top 1 Accuracy: 92.80
Config: configs/distill/mmcls/dfad/dfad_logits_resnet34_resnet18_8xb32_cifar10.py
Weights: https://download.openmmlab.com/mmrazor/v1/DFAD/dfad_logits_resnet34_resnet18_8xb32_cifar10_20220819_051141-961a5b09.pth

View File

@ -26,9 +26,9 @@ almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher
### Classification ### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | | Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- | | :---------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb16_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) | | backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.58 | 76.55 | 69.90 | [config](./fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.json) |
## Citation ## Citation

View File

@ -4,6 +4,7 @@ _base_ = [
'mmcls::_base_/default_runtime.py' 'mmcls::_base_/default_runtime.py'
] ]
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501
model = dict( model = dict(
_scope_='mmrazor', _scope_='mmrazor',
type='SingleTeacherDistill', type='SingleTeacherDistill',
@ -18,7 +19,7 @@ model = dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict( teacher=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True), cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth', teacher_ckpt=teacher_ckpt,
distiller=dict( distiller=dict(
type='ConfigurableDistiller', type='ConfigurableDistiller',
student_recorders=dict( student_recorders=dict(

View File

@ -0,0 +1,40 @@
Collections:
- Name: FitNets
Metadata:
Training Data:
- ImageNet-1k
Paper:
URL: https://arxiv.org/abs/1412.6550
Title: FitNets- Hints for Thin Deep Nets
README: configs/distill/mmcls/fitnet/README.md
Models:
- Name: fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k
In Collection: FitNets
Metadata:
inference time (ms/im):
- value: 0.18
hardware: NVIDIA A100-SXM4-80GB
backend: PyTorch
batch size: 32
mode: FP32
resolution: (224, 224)
Location: backbone & logits
Student:
Config: mmcls::resnet/resnet18_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
Metrics:
Top 1 Accuracy: 69.90
Top 5 Accuracy: 89.43
Teacher:
Config: mmcls::resnet/resnet50_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
Metrics:
Top 1 Accuracy: 76.55
Top 5 Accuracy: 93.06
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 70.58
Config: configs/distill/mmcls/fitnet/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py
Weights: https://download.openmmlab.com/mmrazor/v1/FieNets/fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k_20220830_155608-00ccdbe2.pth

View File

@ -20,9 +20,9 @@ Performing knowledge transfer from a large teacher network to a smaller student
### Classification ### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | | Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- | | :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./zskt_backbone_logits_r34_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) | | backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.05 | 95.34 | 94.82 | [config](./zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.json) |
## Citation ## Citation
@ -35,7 +35,3 @@ Performing knowledge transfer from a large teacher network to a smaller student
year={2019} year={2019}
} }
``` ```
## Acknowledgement
Appreciate Davidgzx's contribution.

View File

@ -0,0 +1,43 @@
Collections:
- Name: ZSKT
Metadata:
Training Data:
- CIFAR-10
Paper:
URL: https://arxiv.org/abs/1905.09768
Title: Zero-shot Knowledge Transfer via Adversarial Belief Matching
README: configs/distill/mmcls/zskt/README.md
Converted From:
Code:
URL: https://github.com/polo5/ZeroShotKnowledgeTransfer
Models:
- Name: zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10
In Collection: ZSKT
Metadata:
inference time (ms/im):
- value: 0.12
hardware: NVIDIA A100-SXM4-80GB
backend: PyTorch
batch size: 16
mode: FP32
resolution: (32, 32)
Location: backbone & logits
Student:
Config: mmcls::resnet/resnet18_8xb16_cifar10.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth
Metrics:
Top 1 Accuracy: 94.82
Top 5 Accuracy: 99.87
Teacher:
Config: mmcls::resnet/resnet34_8xb16_cifar10.py
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth
Metrics:
Top 1 Accuracy: 95.34
Top 5 Accuracy: 99.87
Results:
- Task: Image Classification
Dataset: CIFAR-10
Metrics:
Top 1 Accuracy: 93.05
Config: configs/distill/mmcls/zskt/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py
Weights: https://download.openmmlab.com/mmrazor/v1/ZSKT/zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10_20220823_114006-28584c2e.pth

View File

@ -4,6 +4,7 @@ _base_ = [
'mmcls::_base_/default_runtime.py' 'mmcls::_base_/default_runtime.py'
] ]
res34_ckpt_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth' # noqa: E501
model = dict( model = dict(
_scope_='mmrazor', _scope_='mmrazor',
type='DataFreeDistillation', type='DataFreeDistillation',
@ -17,11 +18,11 @@ model = dict(
architecture=dict( architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False), cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
teachers=dict( teachers=dict(
r34=dict( res34=dict(
build_cfg=dict( build_cfg=dict(
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
pretrained=True), pretrained=True),
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')), ckpt_path=res34_ckpt_path)),
generator=dict( generator=dict(
type='ZSKTGenerator', img_size=32, latent_dim=256, type='ZSKTGenerator', img_size=32, latent_dim=256,
hidden_channels=128), hidden_channels=128),
@ -34,15 +35,15 @@ model = dict(
bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'), bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
fc=dict(type='ModuleOutputs', source='head.fc')), fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict( teacher_recorders=dict(
r34_bb_s1=dict( res34_bb_s1=dict(
type='ModuleOutputs', source='r34.backbone.layer1.2.relu'), type='ModuleOutputs', source='res34.backbone.layer1.2.relu'),
r34_bb_s2=dict( res34_bb_s2=dict(
type='ModuleOutputs', source='r34.backbone.layer2.3.relu'), type='ModuleOutputs', source='res34.backbone.layer2.3.relu'),
r34_bb_s3=dict( res34_bb_s3=dict(
type='ModuleOutputs', source='r34.backbone.layer3.5.relu'), type='ModuleOutputs', source='res34.backbone.layer3.5.relu'),
r34_bb_s4=dict( res34_bb_s4=dict(
type='ModuleOutputs', source='r34.backbone.layer4.2.relu'), type='ModuleOutputs', source='res34.backbone.layer4.2.relu'),
r34_fc=dict(type='ModuleOutputs', source='r34.head.fc')), res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
distill_losses=dict( distill_losses=dict(
loss_s1=dict(type='ATLoss', loss_weight=250.0), loss_s1=dict(type='ATLoss', loss_weight=250.0),
loss_s2=dict(type='ATLoss', loss_weight=250.0), loss_s2=dict(type='ATLoss', loss_weight=250.0),
@ -55,31 +56,31 @@ model = dict(
s_feature=dict( s_feature=dict(
from_student=True, recorder='bb_s1', record_idx=1), from_student=True, recorder='bb_s1', record_idx=1),
t_feature=dict( t_feature=dict(
from_student=False, recorder='r34_bb_s1', record_idx=1)), from_student=False, recorder='res34_bb_s1', record_idx=1)),
loss_s2=dict( loss_s2=dict(
s_feature=dict( s_feature=dict(
from_student=True, recorder='bb_s2', record_idx=1), from_student=True, recorder='bb_s2', record_idx=1),
t_feature=dict( t_feature=dict(
from_student=False, recorder='r34_bb_s2', record_idx=1)), from_student=False, recorder='res34_bb_s2', record_idx=1)),
loss_s3=dict( loss_s3=dict(
s_feature=dict( s_feature=dict(
from_student=True, recorder='bb_s3', record_idx=1), from_student=True, recorder='bb_s3', record_idx=1),
t_feature=dict( t_feature=dict(
from_student=False, recorder='r34_bb_s3', record_idx=1)), from_student=False, recorder='res34_bb_s3', record_idx=1)),
loss_s4=dict( loss_s4=dict(
s_feature=dict( s_feature=dict(
from_student=True, recorder='bb_s4', record_idx=1), from_student=True, recorder='bb_s4', record_idx=1),
t_feature=dict( t_feature=dict(
from_student=False, recorder='r34_bb_s4', record_idx=1)), from_student=False, recorder='res34_bb_s4', record_idx=1)),
loss_kl=dict( loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'), preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='r34_fc')))), preds_T=dict(from_student=False, recorder='res34_fc')))),
generator_distiller=dict( generator_distiller=dict(
type='ConfigurableDistiller', type='ConfigurableDistiller',
student_recorders=dict( student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')), fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict( teacher_recorders=dict(
r34_fc=dict(type='ModuleOutputs', source='r34.head.fc')), res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
distill_losses=dict( distill_losses=dict(
loss_kl=dict( loss_kl=dict(
type='KLDivergence', type='KLDivergence',
@ -89,7 +90,7 @@ model = dict(
loss_forward_mappings=dict( loss_forward_mappings=dict(
loss_kl=dict( loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'), preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='r34_fc')))), preds_T=dict(from_student=False, recorder='res34_fc')))),
student_iter=10) student_iter=10)
# model wrapper # model wrapper

View File

@ -14,9 +14,9 @@ Knowledge distillation, in which a student model is trained to mimic a teacher m
### Detection ### Detection
| Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download | | Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download |
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :--------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | :------: | :-----: | :--------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| neck | COCO | [fasterrcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [fasterrcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.1 | 39.4 | 37.8 | [config](./fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](<>) \| [log](<>) | | neck | COCO | [faster-rcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [faster-rcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.3 | 39.4 | 37.4 | [config](./fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.json) |
## Citation ## Citation

View File

@ -4,16 +4,17 @@ _base_ = [
'mmdet::_base_/default_runtime.py' 'mmdet::_base_/default_runtime.py'
] ]
teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth' # noqa: E501
model = dict( model = dict(
_scope_='mmrazor', _scope_='mmrazor',
type='SingleTeacherDistill', type='SingleTeacherDistill',
architecture=dict( architecture=dict(
cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py', cfg_path='mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py',
pretrained=True), pretrained=True),
teacher=dict( teacher=dict(
cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py', cfg_path='mmdet::faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py',
pretrained=False), pretrained=False),
teacher_ckpt='faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth', teacher_ckpt=teacher_ckpt,
distiller=dict( distiller=dict(
type='ConfigurableDistiller', type='ConfigurableDistiller',
student_recorders=dict( student_recorders=dict(

View File

@ -0,0 +1,41 @@
Collections:
- Name: FBKD
Metadata:
Training Data:
- COCO
Paper:
URL: https://openreview.net/pdf?id=uKhGRvM8QNH
Title: IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION- TOWARDS ACCURATE AND EFFICIENT DETECTORS
README: configs/distill/mmdet/fbkd/README.md
Converted From:
Code:
URL: https://github.com/ArchipLab-LinfengZhang/Object-Detection-Knowledge-Distillation-ICLR2021
Models:
- Name: fbkd_fpn_faster-rcnn_resnet101_faster-rcnn_resnet50_1x_coco
In Collection: FBKD
Metadata:
inference time (ms/im):
- value: 0.32
hardware: NVIDIA A100-SXM4-80GB
backend: PyTorch
batch size: 2
mode: FP32
resolution: (1333, 800)
Location: fpn
Student:
Metrics:
box AP: 37.4
Config: mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
Teacher:
Metrics:
box AP: 39.4
Config: mmdet::faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 39.3
Config: configs/distill/mmdet/fbkd/fbkd_fpn_faster-rcnn_resnet101_faster-rcnn_resnet50_1x_coco.py
Weights: https://download.openmmlab.com/mmrazor/v1/FBKD/fbkd_fpn_frcnn_resnet101_frcnn_resnet50_1x_coco_20220830_121522-8d7e11df.pth

View File

@ -78,12 +78,12 @@ class DataFreeDistillation(BaseAlgorithm):
"""Alias for ``architecture``.""" """Alias for ``architecture``."""
return self.architecture return self.architecture
def train_step(self, data: List[dict], def train_step(self, data: Dict[str, List[dict]],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""Train step for DataFreeDistillation. """Train step for DataFreeDistillation.
Args: Args:
data (List[dict]): Data sampled by dataloader. data (Dict[str, List[dict]]): Data sampled by dataloader.
optim_wrapper (OptimWrapper): A wrapper of optimizer to optim_wrapper (OptimWrapper): A wrapper of optimizer to
update parameters. update parameters.
""" """
@ -107,16 +107,16 @@ class DataFreeDistillation(BaseAlgorithm):
return log_vars return log_vars
def train_student( def train_student(
self, data: List[dict], optimizer: OPTIMIZERS self, data: Dict[str, List[dict]], optimizer: OPTIMIZERS
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Train step for the student model. """Train step for the student model.
Args: Args:
data (List[dict]): Data sampled by dataloader. data (Dict[str, List[dict]]): Data sampled by dataloader.
optimizer (OPTIMIZERS): The optimizer to update student. optimizer (OPTIMIZERS): The optimizer to update student.
""" """
log_vars = dict() log_vars = dict()
batch_size = len(data) batch_size = len(data['inputs'])
for _ in range(self.student_iter): for _ in range(self.student_iter):
fakeimg_init = torch.randn( fakeimg_init = torch.randn(
@ -124,13 +124,14 @@ class DataFreeDistillation(BaseAlgorithm):
fakeimg = self.generator(fakeimg_init, batch_size).detach() fakeimg = self.generator(fakeimg_init, batch_size).detach()
with optimizer.optim_context(self.student): with optimizer.optim_context(self.student):
_, data_samples = self.data_preprocessor(data, True) pseudo_data = self.data_preprocessor(data, True)
pseudo_data_samples = pseudo_data['data_samples']
# recorde the needed information # recorde the needed information
with self.distiller.student_recorders: with self.distiller.student_recorders:
_ = self.student(fakeimg, data_samples, mode='loss') _ = self.student(fakeimg, pseudo_data_samples, mode='loss')
with self.distiller.teacher_recorders, torch.no_grad(): with self.distiller.teacher_recorders, torch.no_grad():
for _, teacher in self.teachers.items(): for _, teacher in self.teachers.items():
_ = teacher(fakeimg, data_samples, mode='loss') _ = teacher(fakeimg, pseudo_data_samples, mode='loss')
loss_distill = self.distiller.compute_distill_losses() loss_distill = self.distiller.compute_distill_losses()
distill_loss, distill_log_vars = self.parse_losses(loss_distill) distill_loss, distill_log_vars = self.parse_losses(loss_distill)
@ -140,15 +141,15 @@ class DataFreeDistillation(BaseAlgorithm):
return distill_loss, log_vars return distill_loss, log_vars
def train_generator( def train_generator(
self, data: List[dict], optimizer: OPTIMIZERS self, data: Dict[str, List[dict]], optimizer: OPTIMIZERS
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Train step for the generator. """Train step for the generator.
Args: Args:
data (List[dict]): Data sampled by dataloader. data (Dict[str, List[dict]]): Data sampled by dataloader.
optimizer (OPTIMIZERS): The optimizer to update generator. optimizer (OPTIMIZERS): The optimizer to update generator.
""" """
batch_size = len(data) batch_size = len(data['inputs'])
fakeimg_init = torch.randn( fakeimg_init = torch.randn(
(batch_size, self.generator.module.latent_dim)) (batch_size, self.generator.module.latent_dim))
fakeimg = self.generator(fakeimg_init, batch_size) fakeimg = self.generator(fakeimg_init, batch_size)
@ -174,17 +175,17 @@ class DataFreeDistillation(BaseAlgorithm):
@MODELS.register_module() @MODELS.register_module()
class DAFLDataFreeDistillation(DataFreeDistillation): class DAFLDataFreeDistillation(DataFreeDistillation):
def train_step(self, data: List[dict], def train_step(self, data: Dict[str, List[dict]],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""DAFL train step. """DAFL train step.
Args: Args:
data (List[dict]): Data sampled by dataloader. data (Dict[str, List[dict]): Data sampled by dataloader.
optim_wrapper (OptimWrapper): A wrapper of optimizer to optim_wrapper (OptimWrapper): A wrapper of optimizer to
update parameters. update parameters.
""" """
log_vars = dict() log_vars = dict()
batch_size = len(data) batch_size = len(data['inputs'])
for _, teacher in self.teachers.items(): for _, teacher in self.teachers.items():
teacher.eval() teacher.eval()

View File

@ -107,7 +107,6 @@ class TestLosses(TestCase):
ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=True) ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=True)
ie_loss.world_size = 2 ie_loss.world_size = 2
# TODO: configure circle CI to test UT under multi torch versions.
if digit_version(torch.__version__) >= digit_version('1.8.0'): if digit_version(torch.__version__) >= digit_version('1.8.0'):
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,