[Feature] Add BYOT Distillation (#232)

* byot connector & distiller

* fix config

* fix connector

* tmpsave

* add byot & kdsoftce loss

* update dev-1.x

* fx wsld

* Update README.md

* Update README.md

* fix md

* add ut & REQUIRE REVIEW part

* fix md

* add SelfDistillValLoop UT

* fix comments

* fix comments v2

* fix comments v3

* add connector_idx=None to ConfigurableDistiller.get_record()

Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
pull/255/head
zengyi 2022-08-22 14:08:02 +08:00 committed by GitHub
parent 83240dcd8a
commit c6e8dcd209
25 changed files with 979 additions and 23 deletions

View File

@ -0,0 +1,50 @@
_base_ = ['./pipelines/auto_aug_cifar.py']
# dataset settings
dataset_type = 'CIFAR100'
preprocess_cfg = dict(
# RGB format normalization parameters
mean=[129.304, 124.070, 112.434],
std=[68.170, 65.392, 70.418],
# loaded images are already RGB format
to_rgb=False)
train_pipeline = [
dict(type='RandomCrop', crop_size=32, padding=4),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='Cutout', shape=16, pad_val=0),
dict(type='AutoAugment', policies={{_base_.policy_cifar}}),
dict(type='PackClsInputs'),
]
test_pipeline = [
dict(type='PackClsInputs'),
]
train_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar100',
test_mode=False,
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar100/',
test_mode=True,
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,125 @@
# Policy for CIFAR, refer to
# https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py
policy_cifar = [
# Group 1
[
dict(type='Invert', prob=0.1),
dict(type='Contrast', magnitude=0.5, prob=0.2)
],
[
dict(type='Rotate', angle=10., prob=0.7),
dict(type='Translate', magnitude=150 / 331, prob=0.3)
],
[
dict(type='Sharpness', magnitude=0.9, prob=0.8),
dict(type='Sharpness', magnitude=0.3, prob=0.9)
],
[
dict(
type='Shear',
magnitude=0.3 / 9 * 8,
direction='vertical',
prob=0.5),
dict(
type='Translate',
magnitude=150 / 331,
direction='vertical',
prob=0.3)
],
[dict(type='AutoContrast', prob=0.5),
dict(type='Equalize', prob=0.9)],
# Group 2
[
dict(
type='Shear',
magnitude=0.3 / 9 * 7,
direction='vertical',
prob=0.2),
dict(type='Posterize', bits=5, prob=0.3)
],
[
dict(type='ColorTransform', magnitude=0.3, prob=0.4),
dict(type='Brightness', magnitude=0.7, prob=0.7)
],
[
dict(type='Sharpness', magnitude=1.0, prob=0.3),
dict(type='Brightness', magnitude=1.0, prob=0.7)
],
[dict(type='Equalize', prob=0.6),
dict(type='Equalize', prob=0.5)],
[
dict(type='Contrast', magnitude=0.6, prob=0.6),
dict(type='Sharpness', magnitude=0.4, prob=0.8),
],
# Group 3
[
dict(type='ColorTransform', magnitude=0.6, prob=0.7),
dict(type='Translate', magnitude=150 / 331 / 9 * 7, prob=0.5)
],
[dict(type='Equalize', prob=0.3),
dict(type='AutoContrast', prob=0.4)],
[
dict(
type='Translate',
magnitude=150 / 331 / 9 * 2,
direction='vertical',
prob=0.4),
dict(type='Sharpness', magnitude=0.5, prob=0.2)
],
[
dict(type='Brightness', magnitude=0.5, prob=0.9),
dict(type='ColorTransform', magnitude=0.7, prob=0.2),
],
[
dict(type='Solarize', thr=256 / 9 * 7, prob=0.5),
dict(type='Invert', prob=0.0),
],
# Group 4
[dict(type='Equalize', prob=0.2),
dict(type='AutoContrast', prob=0.6)],
[dict(type='Equalize', prob=0.2),
dict(type='Equalize', prob=0.6)],
[
dict(type='ColorTransform', magnitude=0.9, prob=0.9),
dict(type='Equalize', prob=0.6)
],
[
dict(type='AutoContrast', prob=0.8),
dict(type='Solarize', thr=256 / 9 * 1, prob=0.2),
],
[
dict(type='Brightness', magnitude=0.3, prob=0.1),
dict(type='ColorTransform', magnitude=0.0, prob=0.7)
],
# Group 5
[
dict(type='Solarize', thr=256 / 9 * 4, prob=0.4),
dict(type='AutoContrast', prob=0.9)
],
[
dict(
type='Translate',
magnitude=150 / 331,
direction='vertical',
prob=0.9),
dict(
type='Translate',
magnitude=150 / 331,
direction='vertical',
prob=0.7)
],
[
dict(type='AutoContrast', prob=0.9),
dict(type='Solarize', thr=256 / 9 * 6, prob=0.8)
],
[dict(type='Equalize', prob=0.8),
dict(type='Invert', prob=0.1)],
[
dict(
type='Translate',
magnitude=150 / 331,
direction='vertical',
prob=0.7),
dict(type='AutoContrast', prob=0.9)
]
]

View File

@ -0,0 +1,55 @@
# Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation
## Abstract
Convolutional neural networks have been widely deployed in various application scenarios. In order to extend the applications' boundaries to some accuracy-crucial domains, researchers have been investigating approaches to boost accuracy through either deeper or wider network structures, which brings with them the exponential increment of the computational and storage cost, delaying the responding time. In this paper, we propose a general training framework named self distillation, which notably enhances the performance (accuracy) of convolutional neural networks through shrinking the size of the network rather than aggrandizing it. Different from traditional knowledge distillation - a knowledge transformation methodology among networks, which forces student neural networks to approximate the softmax layer outputs of pre-trained teacher neural networks, the proposed self distillation framework distills knowledge within network itself. The networks are firstly divided into several sections. Then the knowledge in the deeper portion of the networks is squeezed into the shallow ones. Experiments further prove the generalization of the proposed self distillation framework: enhancement of accuracy at average level is 2.65%, varying from 0.61% in ResNeXt as minimum to 4.07% in VGG19 as maximum. In addition, it can also provide flexibility of depth-wise scalable inference on resource-limited edge devices. Our codes will be released on github soon. [Unofficial code](https://github.com/luanyunteng/pytorch-be-your-own-teacher)
## Pipeline
![pipeline](../../../../docs/en/imgs/model_zoo/byot/byot.png)
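Each shallow branch is supervised with three terms: the ordinary label loss, a KL term towards the deepest classifier's softened outputs, and an L2 hint pulling the branch feature towards the deepest feature. Below is a rough, framework-agnostic sketch of the per-branch objective (hypothetical helper; the weights 0.7 / 0.3 / 0.03 and tau=3 mirror the config added in this PR, which instead wires them up through MMRazor's CrossEntropyLoss, KLDivergence and L2Loss):
```python
import torch.nn.functional as F

def byot_branch_loss(branch_logits, branch_feat, final_logits, final_feat,
                     labels, tau=3.0, label_weight=0.7, kd_weight=0.3,
                     feat_weight=0.03):
    """Per-branch self-distillation objective (illustrative sketch only)."""
    # (1) Ordinary cross entropy with the ground-truth labels.
    ce = F.cross_entropy(branch_logits, labels)
    # (2) KL divergence towards the deepest classifier's softened outputs.
    kd = F.kl_div(F.log_softmax(branch_logits / tau, dim=-1),
                  F.softmax(final_logits.detach() / tau, dim=-1),
                  reduction='batchmean') * tau ** 2
    # (3) L2 hint loss pulling the branch feature towards the deepest feature.
    hint = F.mse_loss(branch_feat, final_feat.detach())
    return label_weight * ce + kd_weight * kd + feat_weight * hint
```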
## Results and models
### Classification
| Location | Dataset | Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
| :------: | :------: | :-------------------------------------------------------: | :-------: | :------: | :-------: | :-------: | :--------------------------------------------------------------------------------------------------: |
| logits | CIFAR100 | [R18_BYOT](./byot_logits_resnet18_cifar100_8xb16_in1k.py) | 11.22 | 0.56 | 80.66 | 95.76 | [model & log](https://autolink.sensetime.com/pages/model/share/08ad706f-b3d4-4854-8019-e0b43607f001) |
## Citation
```latex
@ARTICLE{2019arXiv190508094Z,
author = {{Zhang}, Linfeng and {Song}, Jiebo and {Gao}, Anni and {Chen}, Jingwei and {Bao}, Chenglong and {Ma}, Kaisheng},
title = {Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation},
journal = {arXiv e-prints},
keywords = {Computer Science - Machine Learning, Statistics - Machine Learning},
year = 2019,
month = may,
eid = {arXiv:1905.08094},
pages = {arXiv:1905.08094},
archivePrefix = {arXiv},
eprint = {1905.08094},
primaryClass = {cs.LG},
adsurl = {https://ui.adsabs.harvard.edu/abs/2019arXiv190508094Z},
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
```
## Get Started
### Distillation training
```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
    configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py \
$WORK_DIR
```
### Test
```bash
sh tools/slurm_test.sh $PARTITION $JOB_NAME \
    configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py \
    $WORK_DIR/latest.pth --eval $EVAL_SETTING
```
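If Slurm is not available, the same config should also run with the standard single-machine entry points (assuming the usual OpenMMLab `tools/train.py` / `tools/test.py` layout; flags may differ between versions):
```bash
python tools/train.py \
    configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py \
    --work-dir $WORK_DIR
```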

View File

@ -0,0 +1,155 @@
_base_ = [
'../../../_base_/datasets/mmcls/cifar100_bs16_auto_aug.py',
'mmcls::_base_/schedules/cifar10_bs128.py',
'mmcls::_base_/default_runtime.py'
]
optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005))
param_scheduler = dict(
type='MultiStepLR', by_epoch=True, milestones=[80, 160, 240], gamma=0.1)
train_cfg = dict(by_epoch=True, max_epochs=250, val_interval=1)
model = dict(
_scope_='mmrazor',
type='SelfDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[129.304, 124.070, 112.434],
std=[68.170, 65.392, 70.418],
# convert image from BGR to RGB
bgr_to_rgb=False),
architecture=dict(
type='mmcls.ImageClassifier',
backbone=dict(
type='mmcls.ResNet_CIFAR',
depth=18,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='mmcls.GlobalAveragePooling'),
head=dict(
type='mmcls.LinearClsHead',
num_classes=100,
in_channels=512,
loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0))),
distiller=dict(
type='BYOTDistiller',
student_recorders=dict(
bb_s1=dict(type='ModuleOutputs', source='backbone.layer1.1.relu'),
bb_s2=dict(type='ModuleOutputs', source='backbone.layer2.1.relu'),
bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc'),
neck_gap=dict(type='ModuleOutputs', source='neck.gap'),
gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
distill_losses=dict(
loss_fet_1=dict(
type='L2Loss', normalize=False, loss_weight=0.03, dist=True),
loss_label_1=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7),
loss_softl_1=dict(type='KLDivergence', tau=3, loss_weight=0.3),
loss_fet_2=dict(
type='L2Loss', normalize=False, loss_weight=0.03, dist=True),
loss_label_2=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7),
loss_softl_2=dict(type='KLDivergence', tau=3, loss_weight=0.3),
loss_fet_3=dict(
type='L2Loss', normalize=False, loss_weight=0., dist=True),
loss_label_3=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7),
loss_softl_3=dict(type='KLDivergence', tau=3, loss_weight=0.3)),
connectors=dict(
loss_s1_sfeat=dict(
type='BYOTConnector',
in_channel=64,
out_channel=512,
expansion=1,
kernel_size=3,
stride=2,
num_classes=100),
loss_s2_sfeat=dict(
type='BYOTConnector',
in_channel=128,
out_channel=512,
expansion=1,
kernel_size=3,
stride=2,
num_classes=100),
loss_s3_sfeat=dict(
type='BYOTConnector',
in_channel=256,
out_channel=512,
expansion=1,
kernel_size=3,
stride=2,
num_classes=100)),
loss_forward_mappings=dict(
loss_fet_1=dict(
s_feature=dict(
recorder='bb_s1',
from_student=True,
connector='loss_s1_sfeat',
connector_idx=0),
t_feature=dict(recorder='neck_gap', from_student=False)),
loss_label_1=dict(
cls_score=dict(
recorder='bb_s1',
from_student=True,
connector='loss_s1_sfeat',
connector_idx=1),
label=dict(
recorder='gt_labels', from_student=False, data_idx=1)),
loss_softl_1=dict(
preds_S=dict(
recorder='bb_s1',
from_student=True,
connector='loss_s1_sfeat',
connector_idx=1),
preds_T=dict(recorder='fc', from_student=False)),
loss_fet_2=dict(
s_feature=dict(
recorder='bb_s2',
from_student=True,
connector='loss_s2_sfeat',
connector_idx=0),
t_feature=dict(recorder='neck_gap', from_student=False)),
loss_label_2=dict(
cls_score=dict(
recorder='bb_s2',
from_student=True,
connector='loss_s2_sfeat',
connector_idx=1),
label=dict(
recorder='gt_labels', from_student=False, data_idx=1)),
loss_softl_2=dict(
preds_S=dict(
recorder='bb_s2',
from_student=True,
connector='loss_s2_sfeat',
connector_idx=1),
preds_T=dict(recorder='fc', from_student=False)),
loss_fet_3=dict(
s_feature=dict(
recorder='bb_s3',
from_student=True,
connector='loss_s3_sfeat',
connector_idx=0),
t_feature=dict(recorder='neck_gap', from_student=False)),
loss_label_3=dict(
cls_score=dict(
recorder='bb_s3',
from_student=True,
connector='loss_s3_sfeat',
connector_idx=1),
label=dict(
recorder='gt_labels', from_student=False, data_idx=1)),
loss_softl_3=dict(
preds_S=dict(
recorder='bb_s3',
from_student=True,
connector='loss_s3_sfeat',
connector_idx=1),
preds_T=dict(recorder='fc', from_student=False)))))
find_unused_parameters = True
val_cfg = dict(_delete_=True, type='mmrazor.SelfDistillValLoop')

Binary file not shown (new image, 274 KiB)

View File

@ -3,12 +3,12 @@ from .hooks import DumpSubnetHook
from .optimizers import SeparateOptimWrapperConstructor
from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
GreedySamplerTrainLoop, SingleTeacherDistillValLoop,
SlimmableValLoop)
GreedySamplerTrainLoop, SelfDistillValLoop,
SingleTeacherDistillValLoop, SlimmableValLoop)
__all__ = [
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop'
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
]

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .autoslim_val_loop import AutoSlimValLoop
from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
from .distill_val_loop import SingleTeacherDistillValLoop
from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
from .evolution_search_loop import EvolutionSearchLoop
from .slimmable_val_loop import SlimmableValLoop
from .subnet_sampler_loop import GreedySamplerTrainLoop
@ -9,5 +9,5 @@ from .subnet_sampler_loop import GreedySamplerTrainLoop
__all__ = [
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop'
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
]

View File

@ -100,3 +100,42 @@ class SingleTeacherDistillValLoop(ValLoop):
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
@LOOPS.register_module()
class SelfDistillValLoop(ValLoop):
    """Knowledge distillation validation loop. Only the student is validated.

    Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 validation. Defaults to
False.
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False) -> None:
super().__init__(runner, dataloader, evaluator, fp16)
def run(self):
"""Launch validation."""
self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch')
self.runner.model.eval()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
# compute student metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
student_metrics = dict()
for key, value in metrics.items():
student_key = 'student.' + key
student_metrics[student_key] = value
self.runner.call_hook('after_val_epoch', metrics=student_metrics)
self.runner.call_hook('after_val')
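In effect the loop only relabels the evaluator's metrics so they are reported under the student, e.g. (hypothetical values):
```python
metrics = {'accuracy/top1': 80.66, 'accuracy/top5': 95.76}  # hypothetical values
student_metrics = {f'student.{k}': v for k, v in metrics.items()}
# -> {'student.accuracy/top1': 80.66, 'student.accuracy/top5': 95.76}
```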

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseAlgorithm
from .distill import FpnTeacherDistill, SingleTeacherDistill
from .distill import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
from .pruning import SlimmableNetwork, SlimmableNetworkDDP
__all__ = [
'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
'Darts', 'DartsDDP'
'Darts', 'DartsDDP', 'SelfDistill'
]

View File

@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .configurable import FpnTeacherDistill, SingleTeacherDistill
from .configurable import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill']
__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill']

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .fpn_teacher_distill import FpnTeacherDistill
from .self_distill import SelfDistill
from .single_teacher_distill import SingleTeacherDistill
__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill']
__all__ = ['SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill']

View File

@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from mmengine import BaseDataElement
from torch import nn
from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import BaseAlgorithm, LossResults
@MODELS.register_module()
class SelfDistill(BaseAlgorithm):
    """``SelfDistill`` can be used to develop distillation algorithms without a
    separate teacher model.

    Args:
        distiller (dict): The config dict for the distiller. In self
            distillation, the teacher recorders also hook into the student.
student_trainable (bool): Whether the student is trainable. Defaults
to True.
calculate_student_loss (bool): Whether to calculate student loss
(original task loss) to update student model. Defaults to True.
"""
def __init__(self,
distiller: dict,
student_trainable: bool = True,
calculate_student_loss: bool = True,
**kwargs) -> None:
super().__init__(**kwargs)
self.distiller = MODELS.build(distiller)
# The student model will not calculate gradients and update parameters
# in some pretraining process.
self.student_trainable = student_trainable
# The student loss will not be updated into ``losses`` in some
# pretraining process.
self.calculate_student_loss = calculate_student_loss
        # In ``ConfigurableDistiller``, the recorder manager is just
        # constructed, but not really initialized yet.
self.distiller.prepare_from_student(self.student)
        # Still prepare from the self-teacher. The teacher recorders of the
        # self-distiller hook into self.student but require ``detach()``.
self.distiller.prepare_from_teacher(self.student)
@property
def student(self) -> nn.Module:
"""Alias for ``architecture``."""
return self.architecture
def loss(
self,
batch_inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
) -> LossResults:
"""Calculate losses from a batch of inputs and data samples."""
losses = dict()
# If the `override_data` of a delivery is True, the delivery will
# override the origin data with the recorded data.
self.distiller.set_deliveries_override(True)
# Original task loss will not be used during some pretraining process.
if self.calculate_student_loss:
# teacher_recorders hook from student
with self.distiller.student_recorders, \
self.distiller.teacher_recorders, \
self.distiller.deliveries:
student_losses = self.student(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))
else:
with self.distiller.student_recorders, \
self.distiller.teacher_recorders, \
self.distiller.deliveries:
if self.student_trainable:
_ = self.student(batch_inputs, data_samples, mode='loss')
else:
with torch.no_grad():
_ = self.student(
batch_inputs, data_samples, mode='loss')
# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))
return losses

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .byot_connector import BYOTConnector
from .convmodule_connector import ConvModuleConncetor
__all__ = ['ConvModuleConncetor']
__all__ = ['ConvModuleConncetor', 'BYOTConnector']

View File

@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from math import log
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmrazor.registry import MODELS
from ...ops.darts_series import DartsSepConv
from .base_connector import BaseConnector
@MODELS.register_module()
class BYOTConnector(BaseConnector):
    """``BYOTConnector`` applies a ``DartsSepConv``-based self-attention to the
    input feature, then downscales it and predicts branch logits.

    Args:
        in_channel (int): The input channel of the DartsSepConv.
            The actual input tensor has ``in_channel * expansion`` channels.
        out_channel (int): The output channel of the DartsSepConv.
            The actual output tensor has ``out_channel * expansion`` channels.
        num_classes (int): The number of classification classes.
        expansion (int): Expansion of DartsSepConv. Defaults to 4.
        pool_size (int | tuple[int]): Average 2D pool size. Defaults to 4.
        kernel_size (int | tuple[int]): Size of the convolving kernel in
            DartsSepConv. Same as that in ``nn._ConvNd``. Defaults to 3.
        stride (int | tuple[int]): Stride of the first layer in DartsSepConv.
            Same as that in ``nn._ConvNd``. Defaults to 1.
init_cfg (dict, optional): The config to control the initialization.
"""
def __init__(
self,
in_channel: int,
out_channel: int,
num_classes: int,
expansion: int = 4,
pool_size: Union[int, Tuple[int]] = 4,
kernel_size: Union[int, Tuple[int]] = 3,
stride: Union[int, Tuple[int]] = 1,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(init_cfg)
self.attention = nn.Sequential(
DartsSepConv(
in_channels=in_channel * expansion,
out_channels=in_channel * expansion,
kernel_size=kernel_size,
stride=stride), nn.BatchNorm2d(in_channel * expansion),
nn.ReLU(), nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Sigmoid())
scala_num = log(out_channel / in_channel, 2)
assert scala_num.is_integer()
scala = []
_in_channel = in_channel
for _ in range(int(scala_num)):
scala.append(
DartsSepConv(
in_channels=_in_channel * expansion,
out_channels=_in_channel * 2 * expansion,
kernel_size=kernel_size,
stride=stride))
_in_channel *= 2
scala.append(nn.AvgPool2d(pool_size))
self.scala = nn.Sequential(*scala)
self.fc = nn.Linear(out_channel * expansion, num_classes)
def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
"""Forward computation.
Args:
feature (torch.Tensor): Input feature.
"""
feat = self.attention(feature)
feat = feat * feature
feat = self.scala(feat)
feat = feat.view(feature.size(0), -1)
logits = self.fc(feat)
return (feat, logits)
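As a quick shape check, mirroring the unit test added later in this PR (values are illustrative, not part of the connector itself):
```python
import torch
from mmrazor.models import BYOTConnector

connector = BYOTConnector(in_channel=16, out_channel=32, num_classes=10,
                          expansion=4, pool_size=4, kernel_size=3, stride=2)
# The attention branch gates the input feature; the scala branch downsamples
# it to the teacher feature size; the fc head predicts branch logits.
feat, logits = connector.forward_train(torch.randn(1, 16 * 4, 8, 8))
assert feat.shape == (1, 32 * 4)   # flattened, matches the teacher neck feature
assert logits.shape == (1, 10)     # branch classification logits
```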

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_distiller import BaseDistiller
from .byot_distiller import BYOTDistiller
from .configurable_distiller import ConfigurableDistiller
__all__ = ['ConfigurableDistiller', 'BaseDistiller']
__all__ = ['ConfigurableDistiller', 'BaseDistiller', 'BYOTDistiller']

View File

@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
from mmrazor.registry import MODELS
from .configurable_distiller import ConfigurableDistiller
@MODELS.register_module()
class BYOTDistiller(ConfigurableDistiller):
    """``BYOTDistiller`` inherits ``ConfigurableDistiller`` and only overrides
    its ``get_record()`` function.

    In ``BYOTDistiller``, ``self.teacher_recorders`` records self-teacher data,
    which requires ``detach()``.
"""
def get_record(self,
recorder: str,
from_student: bool,
record_idx: int = 0,
data_idx: Optional[int] = None,
connector: Optional[str] = None,
connector_idx: Optional[int] = None) -> List:
"""According to each item in ``record_infos``, get the corresponding
record in ``recorder_manager``.
Detach teacher_record.
"""
if from_student:
recorder_ = self.student_recorders.get_recorder(recorder)
else:
recorder_ = self.teacher_recorders.get_recorder(recorder)
record_data = recorder_.get_record_data(record_idx, data_idx)
if connector:
record_data = self.connectors[connector](record_data)
if connector_idx is not None:
record_data = record_data[connector_idx]
        # Detach the self-teacher output from the computation graph; the
        # hooked record is assumed to be a Tensor here.
if not from_student:
record_data = record_data.detach()
return record_data
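For illustration, the new ``connector_idx`` simply indexes into a connector's output before it is handed to a loss; a ``BYOTConnector`` returns a ``(feat, logits)`` tuple, so index 0 feeds the feature losses and index 1 the label/soft-label losses in the config above. A minimal, self-contained sketch:
```python
from typing import Optional, Tuple


def pick(record_data: Tuple, connector_idx: Optional[int]):
    """Mimic the connector_idx handling in get_record() (illustrative only)."""
    return record_data if connector_idx is None else record_data[connector_idx]


feat_and_logits = ('feat', 'logits')          # what a BYOTConnector returns
assert pick(feat_and_logits, 0) == 'feat'     # used by loss_fet_*
assert pick(feat_and_logits, 1) == 'logits'   # used by loss_label_* / loss_softl_*
assert pick(feat_and_logits, None) == feat_and_logits
```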

View File

@ -180,7 +180,8 @@ class ConfigurableDistiller(BaseDistiller):
from_student: bool,
record_idx: int = 0,
data_idx: Optional[int] = None,
connector: Optional[str] = None) -> List:
connector: Optional[str] = None,
connector_idx: Optional[int] = None) -> List:
"""According to each item in ``record_infos``, get the corresponding
record in ``recorder_manager``."""
@ -192,6 +193,8 @@ class ConfigurableDistiller(BaseDistiller):
if connector:
record_data = self.connectors[connector](record_data)
if connector_idx is not None:
record_data = record_data[connector_idx]
return record_data
@ -235,9 +238,10 @@ class ConfigurableDistiller(BaseDistiller):
f'instance, but got {type(forward_mappings)}')
loss_module = losses[loss_name]
loss_forward_keys = signature(
loss_module.forward).parameters.keys()
assert len(loss_forward_keys) == len(forward_mappings.keys())
loss_forward_params = signature(loss_module.forward).parameters
loss_forward_keys = loss_forward_params.keys()
            # Allow parameters with default values to be omitted: check the
            # non-default parameters instead of comparing lengths.
for forward_key, record_info in forward_mappings.items():
assert forward_key in loss_forward_keys, \
@ -245,6 +249,11 @@ class ConfigurableDistiller(BaseDistiller):
{type(loss_module).__name__} forward, \
please check your config.'
if (loss_forward_params[forward_key].default ==
loss_forward_params[forward_key].empty):
# default params without check
continue
assert 'recorder' in record_info, \
'Each item of loss_forward_mappings should have ' \
'"recorder", pls check your config.'

View File

@ -2,6 +2,7 @@
from .ab_loss import ABLoss
from .cwd import ChannelWiseDivergence
from .decoupled_kd import DKDLoss
from .kd_soft_ce_loss import KDSoftCELoss
from .kl_divergence import KLDivergence
from .l2_loss import L2Loss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
@ -9,5 +10,5 @@ from .weighted_soft_label_distillation import WSLD
__all__ = [
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss'
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss'
]

View File

@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.models.losses.cross_entropy_loss import soft_cross_entropy
from mmrazor.registry import MODELS
@MODELS.register_module()
class KDSoftCELoss(nn.Module):
"""Distilling the Knowledge in a Neural Network, NIPS2014. Based on Soft
Cross Entropy criterion.
https://arxiv.org/pdf/1503.02531.pdf
Args:
        tau (float, optional): Temperature. Defaults to 1.0.
        reduction (str): Specifies the reduction to apply to the loss:
            ``'none'`` | ``'sum'`` | ``'mean'``.
            ``'none'``: no reduction will be applied,
            ``'sum'``: the output will be summed,
            ``'mean'``: the output will be divided by the number of
                elements in the output.
            Defaults to ``'mean'``.
mult_tem_square (bool, optional): Multiply square of temperature
or not. Defaults to True.
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(
self,
tau: float = 1.0,
reduction: str = 'mean',
mult_tem_square: bool = True,
loss_weight: float = 1.0,
) -> None:
super().__init__()
self.tau = tau
self.mult_tem_square = mult_tem_square
self.loss_weight = loss_weight
self.cls_criterion = soft_cross_entropy
accept_reduction = {None, 'none', 'mean', 'sum'}
assert reduction in accept_reduction, \
            f'KDSoftCELoss supports reduction {accept_reduction}, ' \
f'but gets {reduction}.'
self.reduction = reduction
def forward(
self,
preds_S: torch.Tensor,
preds_T: torch.Tensor,
weight: torch.Tensor = None,
avg_factor: int = None,
reduction_override: str = None,
) -> torch.Tensor:
"""Forward computation.
Args:
preds_S (torch.Tensor): The student model prediction with
shape (N, C).
preds_T (torch.Tensor): The teacher model prediction with
shape (N, C).
weight (torch.Tensor, optional): Sample-wise loss weight with
shape (N, C). Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
            reduction_override (str, optional): Override the reduction method
                used in forward. Defaults to None.

        Returns:
torch.Tensor: The calculated loss value.
"""
reduction = (
reduction_override if reduction_override else self.reduction)
preds_S = preds_S / self.tau
soft_label = F.softmax((preds_T / self.tau), dim=-1)
loss_cls = self.loss_weight * self.cls_criterion(
preds_S,
soft_label,
weight,
reduction=reduction,
avg_factor=avg_factor)
if self.mult_tem_square:
loss_cls *= (self.tau**2)
return loss_cls
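For reference, a plain-PyTorch sketch of the quantity this loss computes with mean reduction, assuming the ``soft_cross_entropy`` semantics of mmcls, i.e. ``-(label * log_softmax(pred)).sum(-1)`` (illustrative only):
```python
import torch
import torch.nn.functional as F


def kd_soft_ce(preds_S, preds_T, tau=3.0, mult_tem_square=True):
    """Temperature-scaled soft cross entropy (illustrative sketch only)."""
    soft_label = F.softmax(preds_T / tau, dim=-1)
    loss = -(soft_label * F.log_softmax(preds_S / tau, dim=-1)).sum(-1).mean()
    return loss * tau**2 if mult_tem_square else loss


student, teacher = torch.randn(4, 100), torch.randn(4, 100)
print(kd_soft_ce(student, teacher))
```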

View File

@ -15,6 +15,8 @@ class L2Loss(nn.Module):
mult (float): Multiplier for feature normalization. Defaults to 1.0.
div_element (bool): Whether to divide the loss by element-wise.
Defaults to False.
        dist (bool): Whether to compute the loss as the two-norm distance, as
            in ``torch.dist(p=2)``. Defaults to False.
"""
def __init__(
@ -23,12 +25,14 @@ class L2Loss(nn.Module):
normalize: bool = True,
mult: float = 1.0,
div_element: bool = False,
dist: bool = False,
) -> None:
super().__init__()
self.loss_weight = loss_weight
self.normalize = normalize
self.mult = mult
self.div_element = div_element
self.dist = dist
def forward(
self,
@ -49,10 +53,14 @@ class L2Loss(nn.Module):
loss = torch.sum(torch.pow(torch.sub(s_feature, t_feature), 2))
if self.div_element:
loss = loss / s_feature.numel()
        # Compute the loss as a two-norm (Euclidean) distance.
if self.dist:
loss = torch.sqrt(loss)
else:
loss = loss / s_feature.size(0)
if self.div_element:
loss = loss / s_feature.numel()
else:
loss = loss / s_feature.size(0)
return self.loss_weight * loss
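With ``normalize=False`` and the new ``dist=True``, the loss reduces to the plain Euclidean distance between the two features, i.e. ``torch.dist(s, t, p=2)`` scaled by ``loss_weight``. A quick check (tensors assumed):
```python
import torch

s, t = torch.randn(2, 512), torch.randn(2, 512)
loss = torch.sqrt(torch.sum(torch.pow(torch.sub(s, t), 2)))
assert torch.allclose(loss, torch.dist(s, t, p=2))
```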

View File

@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmcv import ConfigDict
from mmrazor.models import SelfDistill
class TestSelfDistill(TestCase):
def test_init(self):
student_recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
teacher_recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
alg_kwargs = ConfigDict(
architecture=dict(type='ToyStudent'),
distiller=dict(
type='BYOTDistiller',
student_recorders=student_recorders_cfg,
teacher_recorders=teacher_recorders_cfg,
distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_toy=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='conv')))))
_ = SelfDistill(**alg_kwargs)
def test_loss(self):
student_recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
teacher_recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
alg_kwargs = ConfigDict(
architecture=dict(type='ToyStudent'),
distiller=dict(
type='BYOTDistiller',
student_recorders=student_recorders_cfg,
teacher_recorders=teacher_recorders_cfg,
distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_toy=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='conv')))))
img = torch.randn(1, 3, 1, 1)
alg = SelfDistill(**alg_kwargs)
losses = alg(img, mode='loss')
self.assertIn('distill.loss_toy', losses)
self.assertIn('student.loss', losses)

View File

@ -3,7 +3,7 @@ from unittest import TestCase
import torch
from mmrazor.models import ConvModuleConncetor
from mmrazor.models import BYOTConnector, ConvModuleConncetor
class TestConnector(TestCase):
@ -36,3 +36,23 @@ class TestConnector(TestCase):
convmodule_connector_cfg['conv_cfg'] = 'conv2d'
with self.assertRaises(AssertionError):
_ = ConvModuleConncetor(**convmodule_connector_cfg)
def test_byot_connector(self):
byot_connector_cfg = dict(
in_channel=16,
out_channel=32,
num_classes=10,
expansion=4,
pool_size=4,
kernel_size=3,
stride=2,
init_cfg=None)
byot_connector = BYOTConnector(**byot_connector_cfg)
s_feat = torch.randn(1, 16 * 4, 8, 8)
t_feat = torch.randn(1, 32 * 4)
labels = torch.randn(1, 10)
output, logits = byot_connector.forward_train(s_feat)
assert output.size() == t_feat.size()
assert logits.size() == labels.size()

View File

@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
from mmcv import ConfigDict
from mmrazor.models import BYOTDistiller
class TestBYOTDistiller(TestCase):
def test_init(self):
student_recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
teacher_recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
distiller_kwargs = ConfigDict(
student_recorders=student_recorders_cfg,
teacher_recorders=teacher_recorders_cfg,
distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_toy=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='conv'),
)),
)
_ = BYOTDistiller(**distiller_kwargs)
distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
distiller_kwargs_['distill_losses'] = None
with self.assertRaisesRegex(AssertionError,
'"loss_toy" is not in distill'):
_ = BYOTDistiller(**distiller_kwargs_)
distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
distiller_kwargs_['distill_losses'] = dict(
toy=dict(type='ToyDistillLoss'))
distiller_kwargs_['loss_forward_mappings'] = dict(
toy=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='conv')))
with self.assertWarnsRegex(UserWarning, 'Warning: If toy is a'):
_ = BYOTDistiller(**distiller_kwargs_)
distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
distiller_kwargs_['loss_forward_mappings'] = None
_ = BYOTDistiller(**distiller_kwargs_)
distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
distiller_kwargs_['loss_forward_mappings'] = list('AAA')
with self.assertRaisesRegex(TypeError,
'loss_forward_mappings should be '):
_ = BYOTDistiller(**distiller_kwargs_)
distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
distiller_kwargs_['loss_forward_mappings']['loss_toy'] = list()
with self.assertRaisesRegex(
TypeError, 'Each item of loss_forward_mappings should be '):
_ = BYOTDistiller(**distiller_kwargs_)
distiller_kwargs_ = copy.deepcopy(distiller_kwargs)
distiller_kwargs_.loss_forward_mappings.loss_toy.arg1.from_student = ''
with self.assertRaisesRegex(TypeError,
'from_student should be a bool'):
_ = BYOTDistiller(**distiller_kwargs_)

View File

@ -3,7 +3,7 @@ from unittest import TestCase
import torch
from mmrazor.models import ABLoss, DKDLoss
from mmrazor.models import ABLoss, DKDLoss, KDSoftCELoss
class TestLosses(TestCase):
@ -50,3 +50,9 @@ class TestLosses(TestCase):
dkd_loss = DKDLoss(**dkd_loss_cfg)
# dkd requires label logits
self.normal_test_1d(dkd_loss, labels=True)
def test_kdSoftce_loss(self):
kdSoftce_loss_cfg = dict(loss_weight=1.0)
kdSoftce_loss = KDSoftCELoss(**kdSoftce_loss_cfg)
# kd soft ce loss requires label logits
self.normal_test_1d(kdSoftce_loss, labels=True)

View File

@ -13,7 +13,8 @@ from mmengine.model import BaseModel
from mmengine.runner import Runner
from torch.utils.data import Dataset
from mmrazor.engine import SingleTeacherDistillValLoop # noqa: F401
from mmrazor.engine import SelfDistillValLoop # noqa: F401
from mmrazor.engine import SingleTeacherDistillValLoop
from mmrazor.registry import DATASETS, METRICS, MODELS
@ -125,3 +126,54 @@ class TestSingleTeacherDistillValLoop(TestCase):
runner.val()
self.assertIn('val/teacher.acc', runner.message_hub.log_scalars.keys())
class TestSelfDistillValLoop(TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
val_dataloader = dict(
dataset=dict(type='ToyDataset_DistillValLoop'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0)
val_evaluator = dict(type='ToyMetric_DistillValLoop')
val_loop_cfg = dict(
default_scope='mmrazor',
model=dict(type='ToyModel_DistillValLoop'),
work_dir=self.temp_dir,
val_dataloader=val_dataloader,
val_evaluator=val_evaluator,
val_cfg=dict(type='SelfDistillValLoop'),
custom_hooks=[],
default_hooks=dict(
runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', interval=1, by_epoch=True),
sampler_seed=dict(type='DistSamplerSeedHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
)
self.val_loop_cfg = Config(val_loop_cfg)
def tearDown(self):
shutil.rmtree(self.temp_dir)
def test_init(self):
cfg = copy.deepcopy(self.val_loop_cfg)
cfg.experiment_name = 'test_init_self'
runner = Runner.from_cfg(cfg)
loop = runner.build_val_loop(cfg.val_cfg)
self.assertIsInstance(loop, SelfDistillValLoop)
def test_run(self):
cfg = copy.deepcopy(self.val_loop_cfg)
cfg.experiment_name = 'test_run_self'
runner = Runner.from_cfg(cfg)
runner.val()