[Feature] Add BYOT Distillation (#232)
* byot connector & distiller
* fix config
* fix connector
* tmpsave
* add byot & kdsoftce loss
* update dev-1.x
* fx wsld
* Update README.md
* Update README.md
* fix md
* add ut & REQUIRE REVIEW part
* fix md
* add SelfDistillValLoop UT
* fix comments
* fix comments v2
* fix comments v3
* add connector_idx=None to ConfigurableDistiller.get_record()

Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
parent 83240dcd8a · commit c6e8dcd209
@@ -0,0 +1,50 @@
_base_ = ['./pipelines/auto_aug_cifar.py']

# dataset settings
dataset_type = 'CIFAR100'
preprocess_cfg = dict(
    # RGB format normalization parameters
    mean=[129.304, 124.070, 112.434],
    std=[68.170, 65.392, 70.418],
    # loaded images are already RGB format
    to_rgb=False)

train_pipeline = [
    dict(type='RandomCrop', crop_size=32, padding=4),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='Cutout', shape=16, pad_val=0),
    dict(type='AutoAugment', policies={{_base_.policy_cifar}}),
    dict(type='PackClsInputs'),
]

test_pipeline = [
    dict(type='PackClsInputs'),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_prefix='data/cifar100',
        test_mode=False,
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_prefix='data/cifar100/',
        test_mode=True,
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

test_dataloader = val_dataloader
test_evaluator = val_evaluator
@@ -0,0 +1,125 @@
# Policy for CIFAR, refer to
# https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py
policy_cifar = [
    # Group 1
    [
        dict(type='Invert', prob=0.1),
        dict(type='Contrast', magnitude=0.5, prob=0.2)
    ],
    [
        dict(type='Rotate', angle=10., prob=0.7),
        dict(type='Translate', magnitude=150 / 331, prob=0.3)
    ],
    [
        dict(type='Sharpness', magnitude=0.9, prob=0.8),
        dict(type='Sharpness', magnitude=0.3, prob=0.9)
    ],
    [
        dict(
            type='Shear',
            magnitude=0.3 / 9 * 8,
            direction='vertical',
            prob=0.5),
        dict(
            type='Translate',
            magnitude=150 / 331,
            direction='vertical',
            prob=0.3)
    ],
    [dict(type='AutoContrast', prob=0.5),
     dict(type='Equalize', prob=0.9)],
    # Group 2
    [
        dict(
            type='Shear',
            magnitude=0.3 / 9 * 7,
            direction='vertical',
            prob=0.2),
        dict(type='Posterize', bits=5, prob=0.3)
    ],
    [
        dict(type='ColorTransform', magnitude=0.3, prob=0.4),
        dict(type='Brightness', magnitude=0.7, prob=0.7)
    ],
    [
        dict(type='Sharpness', magnitude=1.0, prob=0.3),
        dict(type='Brightness', magnitude=1.0, prob=0.7)
    ],
    [dict(type='Equalize', prob=0.6),
     dict(type='Equalize', prob=0.5)],
    [
        dict(type='Contrast', magnitude=0.6, prob=0.6),
        dict(type='Sharpness', magnitude=0.4, prob=0.8),
    ],
    # Group 3
    [
        dict(type='ColorTransform', magnitude=0.6, prob=0.7),
        dict(type='Translate', magnitude=150 / 331 / 9 * 7, prob=0.5)
    ],
    [dict(type='Equalize', prob=0.3),
     dict(type='AutoContrast', prob=0.4)],
    [
        dict(
            type='Translate',
            magnitude=150 / 331 / 9 * 2,
            direction='vertical',
            prob=0.4),
        dict(type='Sharpness', magnitude=0.5, prob=0.2)
    ],
    [
        dict(type='Brightness', magnitude=0.5, prob=0.9),
        dict(type='ColorTransform', magnitude=0.7, prob=0.2),
    ],
    [
        dict(type='Solarize', thr=256 / 9 * 7, prob=0.5),
        dict(type='Invert', prob=0.0),
    ],
    # Group 4
    [dict(type='Equalize', prob=0.2),
     dict(type='AutoContrast', prob=0.6)],
    [dict(type='Equalize', prob=0.2),
     dict(type='Equalize', prob=0.6)],
    [
        dict(type='ColorTransform', magnitude=0.9, prob=0.9),
        dict(type='Equalize', prob=0.6)
    ],
    [
        dict(type='AutoContrast', prob=0.8),
        dict(type='Solarize', thr=256 / 9 * 1, prob=0.2),
    ],
    [
        dict(type='Brightness', magnitude=0.3, prob=0.1),
        dict(type='ColorTransform', magnitude=0.0, prob=0.7)
    ],
    # Group 5
    [
        dict(type='Solarize', thr=256 / 9 * 4, prob=0.4),
        dict(type='AutoContrast', prob=0.9)
    ],
    [
        dict(
            type='Translate',
            magnitude=150 / 331,
            direction='vertical',
            prob=0.9),
        dict(
            type='Translate',
            magnitude=150 / 331,
            direction='vertical',
            prob=0.7)
    ],
    [
        dict(type='AutoContrast', prob=0.9),
        dict(type='Solarize', thr=256 / 9 * 6, prob=0.8)
    ],
    [dict(type='Equalize', prob=0.8),
     dict(type='Invert', prob=0.1)],
    [
        dict(
            type='Translate',
            magnitude=150 / 331,
            direction='vertical',
            prob=0.7),
        dict(type='AutoContrast', prob=0.9)
    ]
]
@@ -0,0 +1,55 @@
# Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation

## Abstract

Convolutional neural networks have been widely deployed in various application scenarios. In order to extend the applications' boundaries to some accuracy-crucial domains, researchers have been investigating approaches to boost accuracy through either deeper or wider network structures, which brings with them the exponential increment of the computational and storage cost, delaying the responding time. In this paper, we propose a general training framework named self distillation, which notably enhances the performance (accuracy) of convolutional neural networks through shrinking the size of the network rather than aggrandizing it. Different from traditional knowledge distillation - a knowledge transformation methodology among networks, which forces student neural networks to approximate the softmax layer outputs of pre-trained teacher neural networks, the proposed self distillation framework distills knowledge within the network itself. The networks are firstly divided into several sections. Then the knowledge in the deeper portion of the networks is squeezed into the shallow ones. Experiments further prove the generalization of the proposed self distillation framework: enhancement of accuracy at average level is 2.65%, varying from 0.61% in ResNeXt as minimum to 4.07% in VGG19 as maximum. In addition, it can also provide flexibility of depth-wise scalable inference on resource-limited edge devices. Our codes will be released on github soon. [Unofficial code](https://github.com/luanyunteng/pytorch-be-your-own-teacher)

## Pipeline



## Results and models

### Classification

| Location | Dataset  |                           Model                           | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) |                                               Download                                               |
| :------: | :------: | :-------------------------------------------------------: | :-------: | :------: | :-------: | :-------: | :---------------------------------------------------------------------------------------------------: |
|  logits  | CIFAR100 | [R18_BYOT](./byot_logits_resnet18_cifar100_8xb16_in1k.py) |   11.22   |   0.56   |   80.66   |   95.76   | [model & log](https://autolink.sensetime.com/pages/model/share/08ad706f-b3d4-4854-8019-e0b43607f001) |

## Citation

```latex
@ARTICLE{2019arXiv190508094Z,
    author = {{Zhang}, Linfeng and {Song}, Jiebo and {Gao}, Anni and {Chen}, Jingwei and {Bao}, Chenglong and {Ma}, Kaisheng},
    title = {Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation},
    journal = {arXiv e-prints},
    keywords = {Computer Science - Machine Learning, Statistics - Machine Learning},
    year = 2019,
    month = may,
    eid = {arXiv:1905.08094},
    pages = {arXiv:1905.08094},
    archivePrefix = {arXiv},
    eprint = {1905.08094},
    primaryClass = {cs.LG},
    adsurl = {https://ui.adsabs.harvard.edu/abs/2019arXiv190508094Z},
    adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
```

## Get Started

### Distillation training

```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
  configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py \
  $WORK_DIR
```

### Test

```bash
sh tools/slurm_test.sh $PARTITION $JOB_NAME \
  configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py \
  $WORK_DIR/latest.pth --eval $EVAL_SETTING
```
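The next file in the diff is the full BYOT config. As a rough, hypothetical sketch (not mmrazor code) of the loss each shallow exit optimizes: the weights and temperature below are the ones used in that config, and the soft term mirrors its `KLDivergence(tau=3)` entry (the `KDSoftCELoss` added by this PR is an alternative soft-target loss).

```python
import torch
import torch.nn.functional as F


def byot_exit_loss(logits_s, feat_s, logits_t, feat_t, labels,
                   tau=3.0, w_label=0.7, w_soft=0.3, w_feat=0.03):
    """Per-exit loss sketch: hard CE + temperature-scaled KL to the deepest
    classifier + L2 distance to the deepest feature (self-teacher detached)."""
    loss_label = F.cross_entropy(logits_s, labels)
    loss_soft = F.kl_div(
        F.log_softmax(logits_s / tau, dim=-1),
        F.softmax(logits_t.detach() / tau, dim=-1),
        reduction='batchmean') * tau**2
    loss_feat = torch.dist(feat_s, feat_t.detach(), p=2)
    return w_label * loss_label + w_soft * loss_soft + w_feat * loss_feat


# toy usage with random tensors
logits_s, logits_t = torch.randn(4, 100), torch.randn(4, 100)
feat_s, feat_t = torch.randn(4, 512), torch.randn(4, 512)
labels = torch.randint(0, 100, (4, ))
print(byot_exit_loss(logits_s, feat_s, logits_t, feat_t, labels))
```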
@@ -0,0 +1,155 @@
_base_ = [
    '../../../_base_/datasets/mmcls/cifar100_bs16_auto_aug.py',
    'mmcls::_base_/schedules/cifar10_bs128.py',
    'mmcls::_base_/default_runtime.py'
]

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005))
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[80, 160, 240], gamma=0.1)
train_cfg = dict(by_epoch=True, max_epochs=250, val_interval=1)

model = dict(
    _scope_='mmrazor',
    type='SelfDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[129.304, 124.070, 112.434],
        std=[68.170, 65.392, 70.418],
        # loaded images are already RGB format, so no BGR-to-RGB conversion
        bgr_to_rgb=False),
    architecture=dict(
        type='mmcls.ImageClassifier',
        backbone=dict(
            type='mmcls.ResNet_CIFAR',
            depth=18,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='mmcls.GlobalAveragePooling'),
        head=dict(
            type='mmcls.LinearClsHead',
            num_classes=100,
            in_channels=512,
            loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0))),
    distiller=dict(
        type='BYOTDistiller',
        student_recorders=dict(
            bb_s1=dict(type='ModuleOutputs', source='backbone.layer1.1.relu'),
            bb_s2=dict(type='ModuleOutputs', source='backbone.layer2.1.relu'),
            bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc'),
            neck_gap=dict(type='ModuleOutputs', source='neck.gap'),
            gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
        distill_losses=dict(
            loss_fet_1=dict(
                type='L2Loss', normalize=False, loss_weight=0.03, dist=True),
            loss_label_1=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7),
            loss_softl_1=dict(type='KLDivergence', tau=3, loss_weight=0.3),
            loss_fet_2=dict(
                type='L2Loss', normalize=False, loss_weight=0.03, dist=True),
            loss_label_2=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7),
            loss_softl_2=dict(type='KLDivergence', tau=3, loss_weight=0.3),
            loss_fet_3=dict(
                type='L2Loss', normalize=False, loss_weight=0., dist=True),
            loss_label_3=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7),
            loss_softl_3=dict(type='KLDivergence', tau=3, loss_weight=0.3)),
        connectors=dict(
            loss_s1_sfeat=dict(
                type='BYOTConnector',
                in_channel=64,
                out_channel=512,
                expansion=1,
                kernel_size=3,
                stride=2,
                num_classes=100),
            loss_s2_sfeat=dict(
                type='BYOTConnector',
                in_channel=128,
                out_channel=512,
                expansion=1,
                kernel_size=3,
                stride=2,
                num_classes=100),
            loss_s3_sfeat=dict(
                type='BYOTConnector',
                in_channel=256,
                out_channel=512,
                expansion=1,
                kernel_size=3,
                stride=2,
                num_classes=100)),
        loss_forward_mappings=dict(
            loss_fet_1=dict(
                s_feature=dict(
                    recorder='bb_s1',
                    from_student=True,
                    connector='loss_s1_sfeat',
                    connector_idx=0),
                t_feature=dict(recorder='neck_gap', from_student=False)),
            loss_label_1=dict(
                cls_score=dict(
                    recorder='bb_s1',
                    from_student=True,
                    connector='loss_s1_sfeat',
                    connector_idx=1),
                label=dict(
                    recorder='gt_labels', from_student=False, data_idx=1)),
            loss_softl_1=dict(
                preds_S=dict(
                    recorder='bb_s1',
                    from_student=True,
                    connector='loss_s1_sfeat',
                    connector_idx=1),
                preds_T=dict(recorder='fc', from_student=False)),
            loss_fet_2=dict(
                s_feature=dict(
                    recorder='bb_s2',
                    from_student=True,
                    connector='loss_s2_sfeat',
                    connector_idx=0),
                t_feature=dict(recorder='neck_gap', from_student=False)),
            loss_label_2=dict(
                cls_score=dict(
                    recorder='bb_s2',
                    from_student=True,
                    connector='loss_s2_sfeat',
                    connector_idx=1),
                label=dict(
                    recorder='gt_labels', from_student=False, data_idx=1)),
            loss_softl_2=dict(
                preds_S=dict(
                    recorder='bb_s2',
                    from_student=True,
                    connector='loss_s2_sfeat',
                    connector_idx=1),
                preds_T=dict(recorder='fc', from_student=False)),
            loss_fet_3=dict(
                s_feature=dict(
                    recorder='bb_s3',
                    from_student=True,
                    connector='loss_s3_sfeat',
                    connector_idx=0),
                t_feature=dict(recorder='neck_gap', from_student=False)),
            loss_label_3=dict(
                cls_score=dict(
                    recorder='bb_s3',
                    from_student=True,
                    connector='loss_s3_sfeat',
                    connector_idx=1),
                label=dict(
                    recorder='gt_labels', from_student=False, data_idx=1)),
            loss_softl_3=dict(
                preds_S=dict(
                    recorder='bb_s3',
                    from_student=True,
                    connector='loss_s3_sfeat',
                    connector_idx=1),
                preds_T=dict(recorder='fc', from_student=False)))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SelfDistillValLoop')
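A note on the `connector_idx` entries in the mappings above: `BYOTConnector.forward_train` (added later in this diff) returns a `(feature, logits)` tuple, and `connector_idx` selects which element of that tuple feeds a given loss. A tiny, hypothetical stand-in to make the convention concrete:

```python
import torch


def toy_connector(x: torch.Tensor):
    # Hypothetical stand-in for BYOTConnector.forward_train, which likewise
    # returns a (feature, logits) tuple.
    feat = x.mean(dim=(2, 3))                       # pooled feature, (N, C)
    logits = feat @ torch.randn(feat.size(1), 100)  # toy classifier head
    return feat, logits


record = toy_connector(torch.randn(2, 64, 32, 32))
s_feature = record[0]  # connector_idx=0 -> fed to loss_fet_*
cls_score = record[1]  # connector_idx=1 -> fed to loss_label_* / loss_softl_*
```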
Binary file not shown (image, 274 KiB).
@@ -3,12 +3,12 @@ from .hooks import DumpSubnetHook
 from .optimizers import SeparateOptimWrapperConstructor
 from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
                      DartsIterBasedTrainLoop, EvolutionSearchLoop,
-                     GreedySamplerTrainLoop, SingleTeacherDistillValLoop,
-                     SlimmableValLoop)
+                     GreedySamplerTrainLoop, SelfDistillValLoop,
+                     SingleTeacherDistillValLoop, SlimmableValLoop)

 __all__ = [
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
 ]
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .autoslim_val_loop import AutoSlimValLoop
 from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
-from .distill_val_loop import SingleTeacherDistillValLoop
+from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
 from .evolution_search_loop import EvolutionSearchLoop
 from .slimmable_val_loop import SlimmableValLoop
 from .subnet_sampler_loop import GreedySamplerTrainLoop

@@ -9,5 +9,5 @@ from .subnet_sampler_loop import GreedySamplerTrainLoop
 __all__ = [
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
 ]
@@ -100,3 +100,42 @@
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)


@LOOPS.register_module()
class SelfDistillValLoop(ValLoop):
    """Knowledge distillation loop for validation. Only the student is
    validated.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 validation. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader, evaluator, fp16)

    def run(self):
        """Launch validation."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()

        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # compute student metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        student_metrics = dict()
        for key, value in metrics.items():
            student_key = 'student.' + key
            student_metrics[student_key] = value

        self.runner.call_hook('after_val_epoch', metrics=student_metrics)
        self.runner.call_hook('after_val')
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base import BaseAlgorithm
-from .distill import FpnTeacherDistill, SingleTeacherDistill
+from .distill import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
 from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
 from .pruning import SlimmableNetwork, SlimmableNetworkDDP

 __all__ = [
     'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
     'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
-    'Darts', 'DartsDDP'
+    'Darts', 'DartsDDP', 'SelfDistill'
 ]
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .configurable import FpnTeacherDistill, SingleTeacherDistill
+from .configurable import FpnTeacherDistill, SelfDistill, SingleTeacherDistill

-__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill']
+__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill']
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .fpn_teacher_distill import FpnTeacherDistill
+from .self_distill import SelfDistill
 from .single_teacher_distill import SingleTeacherDistill

-__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill']
+__all__ = ['SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill']
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
from mmengine import BaseDataElement
from torch import nn

from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import BaseAlgorithm, LossResults


@MODELS.register_module()
class SelfDistill(BaseAlgorithm):
    """``SelfDistill`` can be used to develop distillation algorithms without
    a separate teacher model.

    Args:
        distiller (dict): The config dict of the distiller to build. The
            distiller may still define a (self-)teacher side.
        student_trainable (bool): Whether the student is trainable. Defaults
            to True.
        calculate_student_loss (bool): Whether to calculate student loss
            (original task loss) to update student model. Defaults to True.
    """

    def __init__(self,
                 distiller: dict,
                 student_trainable: bool = True,
                 calculate_student_loss: bool = True,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.distiller = MODELS.build(distiller)
        # The student model will not calculate gradients and update parameters
        # in some pretraining process.
        self.student_trainable = student_trainable

        # The student loss will not be updated into ``losses`` in some
        # pretraining process.
        self.calculate_student_loss = calculate_student_loss

        # In ``ConfigurableDistiller``, the recorder manager is just
        # constructed, but not really initialized yet.
        self.distiller.prepare_from_student(self.student)
        # Still prepare from the self-teacher. Teacher recorders of the
        # self-distiller hook into ``self.student`` but require detach().
        self.distiller.prepare_from_teacher(self.student)

    @property
    def student(self) -> nn.Module:
        """Alias for ``architecture``."""
        return self.architecture

    def loss(
        self,
        batch_inputs: torch.Tensor,
        data_samples: Optional[List[BaseDataElement]] = None,
    ) -> LossResults:
        """Calculate losses from a batch of inputs and data samples."""

        losses = dict()

        # If the `override_data` of a delivery is True, the delivery will
        # override the origin data with the recorded data.
        self.distiller.set_deliveries_override(True)
        # Original task loss will not be used during some pretraining process.
        if self.calculate_student_loss:
            # teacher_recorders also hook into the student
            with self.distiller.student_recorders, \
                    self.distiller.teacher_recorders, \
                    self.distiller.deliveries:
                student_losses = self.student(
                    batch_inputs, data_samples, mode='loss')
            losses.update(add_prefix(student_losses, 'student'))
        else:
            with self.distiller.student_recorders, \
                    self.distiller.teacher_recorders, \
                    self.distiller.deliveries:
                if self.student_trainable:
                    _ = self.student(batch_inputs, data_samples, mode='loss')
                else:
                    with torch.no_grad():
                        _ = self.student(
                            batch_inputs, data_samples, mode='loss')

        # Automatically compute distill losses based on
        # `loss_forward_mappings`; the required data already exists in the
        # recorders.
        distill_losses = self.distiller.compute_distill_losses()
        losses.update(add_prefix(distill_losses, 'distill'))

        return losses
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .byot_connector import BYOTConnector
 from .convmodule_connector import ConvModuleConncetor

-__all__ = ['ConvModuleConncetor']
+__all__ = ['ConvModuleConncetor', 'BYOTConnector']
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from math import log
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from mmrazor.registry import MODELS
from ...ops.darts_series import DartsSepConv
from .base_connector import BaseConnector


@MODELS.register_module()
class BYOTConnector(BaseConnector):
    """BYOT connector that adds a self-attention branch built on DartsSepConv.

    Args:
        in_channel (int): The input channel of the DartsSepConv. The actual
            input tensor channel is ``in_channel * expansion``.
        out_channel (int): The output channel of the DartsSepConv. The actual
            output tensor channel is ``out_channel * expansion``.
        num_classes (int): The number of classification classes.
        expansion (int): Expansion of DartsSepConv. Defaults to 4.
        pool_size (int | tuple[int]): Average 2D pool size. Defaults to 4.
        kernel_size (int | tuple[int]): Size of the convolving kernel in
            DartsSepConv. Same as that in ``nn._ConvNd``. Defaults to 3.
        stride (int | tuple[int]): Stride of the first layer in DartsSepConv.
            Same as that in ``nn._ConvNd``. Defaults to 1.
        init_cfg (dict, optional): The config to control the initialization.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        num_classes: int,
        expansion: int = 4,
        pool_size: Union[int, Tuple[int]] = 4,
        kernel_size: Union[int, Tuple[int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        init_cfg: Optional[Dict] = None,
    ) -> None:
        super().__init__(init_cfg)
        self.attention = nn.Sequential(
            DartsSepConv(
                in_channels=in_channel * expansion,
                out_channels=in_channel * expansion,
                kernel_size=kernel_size,
                stride=stride), nn.BatchNorm2d(in_channel * expansion),
            nn.ReLU(), nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Sigmoid())
        scala_num = log(out_channel / in_channel, 2)
        assert scala_num.is_integer()
        scala = []

        _in_channel = in_channel

        for _ in range(int(scala_num)):
            scala.append(
                DartsSepConv(
                    in_channels=_in_channel * expansion,
                    out_channels=_in_channel * 2 * expansion,
                    kernel_size=kernel_size,
                    stride=stride))
            _in_channel *= 2
        scala.append(nn.AvgPool2d(pool_size))
        self.scala = nn.Sequential(*scala)
        self.fc = nn.Linear(out_channel * expansion, num_classes)

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Forward computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        feat = self.attention(feature)
        feat = feat * feature

        feat = self.scala(feat)
        feat = feat.view(feature.size(0), -1)
        logits = self.fc(feat)
        return (feat, logits)
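A quick sanity check of the channel-doubling `scala` branch, using the numbers from the BYOT config earlier in this diff (stage-1 feature with 64 channels, `expansion=1`, target 512 channels). This is an illustration, not part of the PR:

```python
from math import log

# How many channel-doubling DartsSepConv blocks the connector builds before
# the average pool and the auxiliary fc head.
in_channel, out_channel = 64, 512
scala_num = log(out_channel / in_channel, 2)
assert scala_num.is_integer() and int(scala_num) == 3  # 64 -> 128 -> 256 -> 512
```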
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_distiller import BaseDistiller
+from .byot_distiller import BYOTDistiller
 from .configurable_distiller import ConfigurableDistiller

-__all__ = ['ConfigurableDistiller', 'BaseDistiller']
+__all__ = ['ConfigurableDistiller', 'BaseDistiller', 'BYOTDistiller']
@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmrazor.registry import MODELS
from .configurable_distiller import ConfigurableDistiller


@MODELS.register_module()
class BYOTDistiller(ConfigurableDistiller):
    """``BYOTDistiller`` inherits ``ConfigurableDistiller`` and only overrides
    ``get_record()`` so that a record can additionally be indexed with
    ``connector_idx``.

    In ``BYOTDistiller``, ``self.teacher_recorders`` records self-teacher data
    which requires detach().
    """

    def get_record(self,
                   recorder: str,
                   from_student: bool,
                   record_idx: int = 0,
                   data_idx: Optional[int] = None,
                   connector: Optional[str] = None,
                   connector_idx: Optional[int] = None) -> List:
        """According to each item in ``record_infos``, get the corresponding
        record in ``recorder_manager``.

        Detach teacher_record.
        """

        if from_student:
            recorder_ = self.student_recorders.get_recorder(recorder)
        else:
            recorder_ = self.teacher_recorders.get_recorder(recorder)
        record_data = recorder_.get_record_data(record_idx, data_idx)

        if connector:
            record_data = self.connectors[connector](record_data)
        if connector_idx is not None:
            record_data = record_data[connector_idx]
        # Detach the self-teacher output tensor, since the teacher records
        # are hooked from the student model itself.
        if not from_student:
            record_data = record_data.detach()

        return record_data
@@ -180,7 +180,8 @@ class ConfigurableDistiller(BaseDistiller):
                    from_student: bool,
                    record_idx: int = 0,
                    data_idx: Optional[int] = None,
-                   connector: Optional[str] = None) -> List:
+                   connector: Optional[str] = None,
+                   connector_idx: Optional[int] = None) -> List:
         """According to each item in ``record_infos``, get the corresponding
         record in ``recorder_manager``."""

@@ -192,6 +193,8 @@ class ConfigurableDistiller(BaseDistiller):

         if connector:
             record_data = self.connectors[connector](record_data)
+        if connector_idx is not None:
+            record_data = record_data[connector_idx]

         return record_data


@@ -235,9 +238,10 @@ class ConfigurableDistiller(BaseDistiller):
                 f'instance, but got {type(forward_mappings)}')

             loss_module = losses[loss_name]
-            loss_forward_keys = signature(
-                loss_module.forward).parameters.keys()
-            assert len(loss_forward_keys) == len(forward_mappings.keys())
+            loss_forward_params = signature(loss_module.forward).parameters
+            loss_forward_keys = loss_forward_params.keys()
+            # Allow default params.
+            # Check non-default params, not len(params).

             for forward_key, record_info in forward_mappings.items():
                 assert forward_key in loss_forward_keys, \

@@ -245,6 +249,11 @@ class ConfigurableDistiller(BaseDistiller):
                     {type(loss_module).__name__} forward, \
                     please check your config.'

+                if (loss_forward_params[forward_key].default ==
+                        loss_forward_params[forward_key].empty):
+                    # default params without check
+                    continue
+
                 assert 'recorder' in record_info, \
                     'Each item of loss_forward_mappings should have ' \
                     '"recorder", pls check your config.'
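For readers unfamiliar with the `inspect` machinery this relaxed check relies on, here is a small illustration (toy function, not mmrazor code) of how `signature(...).parameters` distinguishes loss-forward arguments with and without default values:

```python
from inspect import signature


def toy_loss_forward(preds_S, preds_T, weight=None):
    return None


params = signature(toy_loss_forward).parameters
has_default = {name: p.default is not p.empty for name, p in params.items()}
print(has_default)  # {'preds_S': False, 'preds_T': False, 'weight': True}
```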
@@ -2,6 +2,7 @@
 from .ab_loss import ABLoss
 from .cwd import ChannelWiseDivergence
 from .decoupled_kd import DKDLoss
+from .kd_soft_ce_loss import KDSoftCELoss
 from .kl_divergence import KLDivergence
 from .l2_loss import L2Loss
 from .relational_kd import AngleWiseRKD, DistanceWiseRKD

@@ -9,5 +10,5 @@ from .weighted_soft_label_distillation import WSLD
 __all__ = [
     'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
-    'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss'
+    'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss'
 ]
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.models.losses.cross_entropy_loss import soft_cross_entropy

from mmrazor.registry import MODELS


@MODELS.register_module()
class KDSoftCELoss(nn.Module):
    """Distilling the Knowledge in a Neural Network, NIPS2014. Based on Soft
    Cross Entropy criterion.

    https://arxiv.org/pdf/1503.02531.pdf

    Args:
        tau (float, optional): Temperature. Defaults to 1.0.
        reduction (str): Specifies the reduction to apply to the loss:
            ``'none'`` | ``'sum'`` | ``'mean'``.
            ``'none'``: no reduction will be applied,
            ``'sum'``: the output will be summed,
            ``'mean'``: the output will be divided by the number of
                elements in the output.
            Defaults to ``'mean'``.
        mult_tem_square (bool, optional): Multiply square of temperature
            or not. Defaults to True.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(
        self,
        tau: float = 1.0,
        reduction: str = 'mean',
        mult_tem_square: bool = True,
        loss_weight: float = 1.0,
    ) -> None:
        super().__init__()
        self.tau = tau
        self.mult_tem_square = mult_tem_square
        self.loss_weight = loss_weight
        self.cls_criterion = soft_cross_entropy

        accept_reduction = {None, 'none', 'mean', 'sum'}
        assert reduction in accept_reduction, \
            f'KDSoftCELoss supports reduction {accept_reduction}, ' \
            f'but gets {reduction}.'
        self.reduction = reduction

    def forward(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
        weight: torch.Tensor = None,
        avg_factor: int = None,
        reduction_override: str = None,
    ) -> torch.Tensor:
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C).
            weight (torch.Tensor, optional): Sample-wise loss weight with
                shape (N, C). Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): Override reduction in forward.
                Defaults to None.

        Return:
            torch.Tensor: The calculated loss value.
        """
        reduction = (
            reduction_override if reduction_override else self.reduction)

        preds_S = preds_S / self.tau
        soft_label = F.softmax((preds_T / self.tau), dim=-1)
        loss_cls = self.loss_weight * self.cls_criterion(
            preds_S,
            soft_label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor)
        if self.mult_tem_square:
            loss_cls *= (self.tau**2)
        return loss_cls
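A minimal usage sketch for the new loss (toy tensors; assumes an environment with mmrazor and mmcls installed, and uses the same `mmrazor.models` import path as the unit test further down this diff):

```python
import torch

from mmrazor.models import KDSoftCELoss

criterion = KDSoftCELoss(tau=3.0, loss_weight=0.3)
preds_S = torch.randn(4, 100)  # student logits
preds_T = torch.randn(4, 100)  # (self-)teacher logits, typically detached
loss = criterion(preds_S, preds_T)  # scalar, scaled by tau**2 and loss_weight
print(loss)
```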
@@ -15,6 +15,8 @@ class L2Loss(nn.Module):
         mult (float): Multiplier for feature normalization. Defaults to 1.0.
         div_element (bool): Whether to divide the loss by element-wise.
             Defaults to False.
+        dist (bool): Whether to conduct two-norm dist as torch.dist(p=2).
+            Defaults to False.
     """

     def __init__(

@@ -23,12 +25,14 @@ class L2Loss(nn.Module):
         normalize: bool = True,
         mult: float = 1.0,
         div_element: bool = False,
+        dist: bool = False,
     ) -> None:
         super().__init__()
         self.loss_weight = loss_weight
         self.normalize = normalize
         self.mult = mult
         self.div_element = div_element
+        self.dist = dist

     def forward(
         self,

@@ -49,10 +53,14 @@ class L2Loss(nn.Module):

         loss = torch.sum(torch.pow(torch.sub(s_feature, t_feature), 2))

-        if self.div_element:
-            loss = loss / s_feature.numel()
-        else:
-            loss = loss / s_feature.size(0)
+        # Calculate l2_loss as dist.
+        if self.dist:
+            loss = torch.sqrt(loss)
+        else:
+            if self.div_element:
+                loss = loss / s_feature.numel()
+            else:
+                loss = loss / s_feature.size(0)

         return self.loss_weight * loss
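A quick check (a sketch, not part of the PR) that the new `dist=True` branch computes the two-norm distance, i.e. the same value as `torch.dist(s, t, p=2)`:

```python
import torch

s, t = torch.randn(4, 512), torch.randn(4, 512)
loss = torch.sqrt(torch.sum(torch.pow(torch.sub(s, t), 2)))  # dist=True branch
assert torch.allclose(loss, torch.dist(s, t, p=2))
```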
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from mmcv import ConfigDict

from mmrazor.models import SelfDistill


class TestSelfDistill(TestCase):

    def test_init(self):

        student_recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))
        teacher_recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))

        alg_kwargs = ConfigDict(
            architecture=dict(type='ToyStudent'),
            distiller=dict(
                type='BYOTDistiller',
                student_recorders=student_recorders_cfg,
                teacher_recorders=teacher_recorders_cfg,
                distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
                loss_forward_mappings=dict(
                    loss_toy=dict(
                        arg1=dict(from_student=True, recorder='conv'),
                        arg2=dict(from_student=False, recorder='conv')))))

        _ = SelfDistill(**alg_kwargs)

    def test_loss(self):

        student_recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))
        teacher_recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))

        alg_kwargs = ConfigDict(
            architecture=dict(type='ToyStudent'),
            distiller=dict(
                type='BYOTDistiller',
                student_recorders=student_recorders_cfg,
                teacher_recorders=teacher_recorders_cfg,
                distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
                loss_forward_mappings=dict(
                    loss_toy=dict(
                        arg1=dict(from_student=True, recorder='conv'),
                        arg2=dict(from_student=False, recorder='conv')))))

        img = torch.randn(1, 3, 1, 1)

        alg = SelfDistill(**alg_kwargs)
        losses = alg(img, mode='loss')
        self.assertIn('distill.loss_toy', losses)
        self.assertIn('student.loss', losses)
@@ -3,7 +3,7 @@ from unittest import TestCase

 import torch

-from mmrazor.models import ConvModuleConncetor
+from mmrazor.models import BYOTConnector, ConvModuleConncetor


 class TestConnector(TestCase):

@@ -36,3 +36,23 @@ class TestConnector(TestCase):
         convmodule_connector_cfg['conv_cfg'] = 'conv2d'
         with self.assertRaises(AssertionError):
             _ = ConvModuleConncetor(**convmodule_connector_cfg)

     def test_byot_connector(self):
         byot_connector_cfg = dict(
             in_channel=16,
             out_channel=32,
             num_classes=10,
             expansion=4,
             pool_size=4,
             kernel_size=3,
             stride=2,
             init_cfg=None)
         byot_connector = BYOTConnector(**byot_connector_cfg)

         s_feat = torch.randn(1, 16 * 4, 8, 8)
         t_feat = torch.randn(1, 32 * 4)
         labels = torch.randn(1, 10)

         output, logits = byot_connector.forward_train(s_feat)
         assert output.size() == t_feat.size()
         assert logits.size() == labels.size()
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase

from mmcv import ConfigDict

from mmrazor.models import BYOTDistiller


class TestBYOTDistiller(TestCase):

    def test_init(self):

        student_recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))
        teacher_recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))

        distiller_kwargs = ConfigDict(
            student_recorders=student_recorders_cfg,
            teacher_recorders=teacher_recorders_cfg,
            distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
            loss_forward_mappings=dict(
                loss_toy=dict(
                    arg1=dict(from_student=True, recorder='conv'),
                    arg2=dict(from_student=False, recorder='conv'),
                )),
        )

        _ = BYOTDistiller(**distiller_kwargs)

        distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
        distiller_kwargs_['distill_losses'] = None
        with self.assertRaisesRegex(AssertionError,
                                    '"loss_toy" is not in distill'):
            _ = BYOTDistiller(**distiller_kwargs_)

        distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
        distiller_kwargs_['distill_losses'] = dict(
            toy=dict(type='ToyDistillLoss'))
        distiller_kwargs_['loss_forward_mappings'] = dict(
            toy=dict(
                arg1=dict(from_student=True, recorder='conv'),
                arg2=dict(from_student=False, recorder='conv')))
        with self.assertWarnsRegex(UserWarning, 'Warning: If toy is a'):
            _ = BYOTDistiller(**distiller_kwargs_)

        distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
        distiller_kwargs_['loss_forward_mappings'] = None
        _ = BYOTDistiller(**distiller_kwargs_)

        distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
        distiller_kwargs_['loss_forward_mappings'] = list('AAA')

        with self.assertRaisesRegex(TypeError,
                                    'loss_forward_mappings should be '):
            _ = BYOTDistiller(**distiller_kwargs_)

        distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
        distiller_kwargs_['loss_forward_mappings']['loss_toy'] = list()
        with self.assertRaisesRegex(
                TypeError, 'Each item of loss_forward_mappings should be '):
            _ = BYOTDistiller(**distiller_kwargs_)

        distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
        distiller_kwargs_.loss_forward_mappings.loss_toy.arg1.from_student = ''
        with self.assertRaisesRegex(TypeError,
                                    'from_student should be a bool'):
            _ = BYOTDistiller(**distiller_kwargs_)
@@ -3,7 +3,7 @@ from unittest import TestCase

 import torch

-from mmrazor.models import ABLoss, DKDLoss
+from mmrazor.models import ABLoss, DKDLoss, KDSoftCELoss


 class TestLosses(TestCase):

@@ -50,3 +50,9 @@ class TestLosses(TestCase):
         dkd_loss = DKDLoss(**dkd_loss_cfg)
         # dkd requires label logits
         self.normal_test_1d(dkd_loss, labels=True)

     def test_kdSoftce_loss(self):
         kdSoftce_loss_cfg = dict(loss_weight=1.0)
         kdSoftce_loss = KDSoftCELoss(**kdSoftce_loss_cfg)
         # kd soft ce loss requires label logits
         self.normal_test_1d(kdSoftce_loss, labels=True)
@@ -13,7 +13,8 @@ from mmengine.model import BaseModel
 from mmengine.runner import Runner
 from torch.utils.data import Dataset

-from mmrazor.engine import SingleTeacherDistillValLoop  # noqa: F401
+from mmrazor.engine import SelfDistillValLoop  # noqa: F401
+from mmrazor.engine import SingleTeacherDistillValLoop
 from mmrazor.registry import DATASETS, METRICS, MODELS


@@ -125,3 +126,54 @@ class TestSingleTeacherDistillValLoop(TestCase):
        runner.val()

        self.assertIn('val/teacher.acc', runner.message_hub.log_scalars.keys())


class TestSelfDistillValLoop(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

        val_dataloader = dict(
            dataset=dict(type='ToyDataset_DistillValLoop'),
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=3,
            num_workers=0)
        val_evaluator = dict(type='ToyMetric_DistillValLoop')

        val_loop_cfg = dict(
            default_scope='mmrazor',
            model=dict(type='ToyModel_DistillValLoop'),
            work_dir=self.temp_dir,
            val_dataloader=val_dataloader,
            val_evaluator=val_evaluator,
            val_cfg=dict(type='SelfDistillValLoop'),
            custom_hooks=[],
            default_hooks=dict(
                runtime_info=dict(type='RuntimeInfoHook'),
                timer=dict(type='IterTimerHook'),
                logger=dict(type='LoggerHook'),
                param_scheduler=dict(type='ParamSchedulerHook'),
                checkpoint=dict(
                    type='CheckpointHook', interval=1, by_epoch=True),
                sampler_seed=dict(type='DistSamplerSeedHook')),
            launcher='none',
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
        )
        self.val_loop_cfg = Config(val_loop_cfg)

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_init(self):
        cfg = copy.deepcopy(self.val_loop_cfg)
        cfg.experiment_name = 'test_init_self'
        runner = Runner.from_cfg(cfg)
        loop = runner.build_val_loop(cfg.val_cfg)

        self.assertIsInstance(loop, SelfDistillValLoop)

    def test_run(self):
        cfg = copy.deepcopy(self.val_loop_cfg)
        cfg.experiment_name = 'test_run_self'
        runner = Runner.from_cfg(cfg)
        runner.val()