[Feature] Add DAFL Distillation ()

* 1. Add DAFL, including config, DAFLLoss and README. 2. Add DataFreeDistillation. 3. Add Generator, including base_generator and dafl_generator. 4. Add get_module_device and set_requires_grad functions in utils.

* 1. Fix the files that reported errors in the mypy test under py37, including gather_tensors, datafree_distillation and base_generator. 2. Fix other linting errors.

* 1. Revise some docstrings.

* 1. Add UT for DataFreeDistillation. 2. Add all type hints.

* 1. Add UT for generators and gather_tensors.

* 1. Add an assert on batch_size in base_generator.

* 1. Apply isort.

Co-authored-by: zhangzhongyu.vendor <zhangzhongyu.vendor@sensetime.com>
pull/230/head
zhongyu zhang 2022-08-23 10:47:34 +08:00 committed by GitHub
parent 72c11751cb
commit 57aec1f730
23 changed files with 1092 additions and 18 deletions


@ -0,0 +1,42 @@
# Data-Free Learning of Student Networks (DAFL)
> [Data-Free Learning of Student Networks](https://doi.org/10.1109/ICCV.2019.00361)
<!-- [ALGORITHM] -->
## Abstract
Learning portable neural networks is very essential for computer vision for the purpose that pre-trained heavy deep models can be well applied on edge devices such as mobile phones and micro sensors. Most existing deep neural network compression and speed-up methods are very effective for training compact deep models, when we can directly access the training dataset. However, training data for the given deep network are often unavailable due to some practice problems (e.g. privacy, legal issue, and transmission), and the architecture of the given network are also unknown except some interfaces. To this end, we propose a novel framework for training efficient deep neural networks by exploiting generative adversarial networks (GANs). To be specific, the pre-trained teacher networks are regarded as a fixed discriminator and the generator is utilized for deviating training samples which can obtain the maximum response on the discriminator. Then, an efficient network with smaller model size and computational complexity is trained using the generated data and the teacher network, simultaneously. Efficient student networks learned using the proposed Data-Free Learning (DAFL) method achieve 92.22% and 74.47% accuracies using ResNet-18 without any training data on the CIFAR-10 and CIFAR-100 datasets, respectively. Meanwhile, our student network obtains an 80.56% accuracy on the CelebA benchmark.
![pipeline](/docs/en/imgs/model_zoo/dafl/pipeline.png)
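The generator in this implementation is trained against the frozen teacher with three data-free objectives: a one-hot-likeness loss on the teacher logits, an information-entropy loss that encourages class balance, and an activation loss on the teacher features. As a rough sketch in our own notation (the weights $w_{oh}=0.05$, $w_{ie}=5$, $w_{a}=0.01$ are taken from the config in this folder, and the base-10 log mirrors `InformationEntropyLoss`), the combined generator loss is approximately

```latex
\mathcal{L}_{G} =
  w_{oh}\,\frac{1}{n}\sum_{i=1}^{n}
    \mathrm{CE}\bigl(y_i^{T},\, \arg\max\nolimits_{j} y_{ij}^{T}\bigr)
  + w_{ie}\sum_{j}\bar{p}_{j}\log_{10}\bar{p}_{j}
  - w_{a}\,\frac{1}{n}\sum_{i=1}^{n}\operatorname{mean}\bigl(\lvert f_{i}^{T}\rvert\bigr),
\qquad
\bar{p} = \frac{1}{n}\sum_{i=1}^{n}\operatorname{softmax}\bigl(y_{i}^{T}\bigr)
```

where $y_i^{T}$ and $f_i^{T}$ are the teacher logits and pooled features for the $i$-th generated image. The student is then distilled on the same generated batch with a plain KL-divergence loss (`loss_kl` in the config).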
## Results and models
### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :----------------------------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone (pretrain) & logits (train) | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.11 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
## Citation
```latex
@inproceedings{DBLP:conf/iccv/ChenW0YLSXX019,
author = {Hanting Chen and Yunhe Wang and Chang Xu and Zhaohui Yang and
Chuanjian Liu and Boxin Shi and Chunjing Xu and Chao Xu and Qi Tian},
title = {Data-Free Learning of Student Networks},
booktitle = {2019 {IEEE/CVF} International Conference on Computer Vision, {ICCV}
2019, Seoul, Korea (South), October 27 - November 2, 2019},
pages = {3513--3521},
publisher = {{IEEE}},
year = {2019},
url = {https://doi.org/10.1109/ICCV.2019.00361},
doi = {10.1109/ICCV.2019.00361},
timestamp = {Mon, 17 May 2021 08:18:18 +0200},
biburl = {https://dblp.org/rec/conf/iccv/ChenW0YLSXX019.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
## Acknowledgement
Shout out to Davidgzx.


@ -0,0 +1,104 @@
_base_ = [
'mmcls::_base_/datasets/cifar10_bs16.py',
'mmcls::_base_/schedules/cifar10_bs128.py',
'mmcls::_base_/default_runtime.py'
]
model = dict(
_scope_='mmrazor',
type='DAFLDataFreeDistillation',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
# loaded images are already in RGB, so no BGR-to-RGB conversion is needed
bgr_to_rgb=False),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
teachers=dict(
res34=dict(
build_cfg=dict(
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
pretrained=True),
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')),
generator=dict(
type='DAFLGenerator',
img_size=32,
latent_dim=1000,
hidden_channels=128),
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
distill_losses=dict(
loss_kl=dict(type='KLDivergence', tau=6, loss_weight=1)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='res34_fc')))),
generator_distiller=dict(
type='ConfigurableDistiller',
teacher_recorders=dict(
res34_neck_gap=dict(type='ModuleOutputs', source='res34.neck.gap'),
res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
distill_losses=dict(
loss_res34_oh=dict(type='OnehotLikeLoss', loss_weight=0.05),
loss_res34_ie=dict(type='InformationEntropyLoss', loss_weight=5),
loss_res34_ac=dict(type='ActivationLoss', loss_weight=0.01)),
loss_forward_mappings=dict(
loss_res34_oh=dict(
preds_T=dict(from_student=False, recorder='res34_fc')),
loss_res34_ie=dict(
preds_T=dict(from_student=False, recorder='res34_fc')),
loss_res34_ac=dict(
feat_T=dict(from_student=False, recorder='res34_neck_gap')))))
# model wrapper
model_wrapper_cfg = dict(
type='mmengine.MMSeparateDistributedDataParallel',
broadcast_buffers=False,
find_unused_parameters=False)
find_unused_parameters = True
# optimizer wrapper
optim_wrapper = dict(
_delete_=True,
constructor='mmrazor.SeparateOptimWrapperConstructor',
architecture=dict(optimizer=dict(type='AdamW', lr=1e-1)),
generator=dict(optimizer=dict(type='AdamW', lr=1e-3)))
auto_scale_lr = dict(base_batch_size=256)
param_scheduler = dict(
_delete_=True,
architecture=[
dict(type='LinearLR', end=500, by_epoch=False, start_factor=0.0001),
dict(
type='MultiStepLR',
begin=500,
milestones=[100 * 120, 200 * 120],
by_epoch=False)
],
generator=dict(
type='LinearLR', end=500, by_epoch=False, start_factor=0.0001))
train_cfg = dict(
_delete_=True, by_epoch=False, max_iters=250 * 120, val_interval=150)
train_dataloader = dict(
batch_size=256, sampler=dict(type='InfiniteSampler', shuffle=True))
val_dataloader = dict(batch_size=256)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
default_hooks = dict(
logger=dict(type='LoggerHook', interval=75, log_metric_by_epoch=False),
checkpoint=dict(
type='CheckpointHook', by_epoch=False, interval=150, max_keep_ckpts=2))
log_processor = dict(by_epoch=False)
# Must set diff_rank_seed to True!
randomness = dict(seed=None, diff_rank_seed=True)
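For reference, a minimal sketch (not part of this PR) of how a config like the one above is typically consumed with MMEngine; the config path and `work_dir` below are assumptions based on the usual `configs/distill/mmcls/dafl/` layout.

```python
# Hypothetical launch sketch: build the DAFL algorithm, dataloaders and both
# optim wrappers from the config above, then start data-free distillation.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/distill/mmcls/dafl/dafl_logits_r34_r18_8xb256_cifar10.py')
cfg.work_dir = './work_dirs/dafl_logits_r34_r18_8xb256_cifar10'

runner = Runner.from_cfg(cfg)  # builds model, dataloaders and optim wrappers
runner.train()  # DAFLDataFreeDistillation.train_step runs every iteration
```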

(New binary file added: the DAFL pipeline image referenced in the README above, 422 KiB; not shown.)


@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseAlgorithm
from .distill import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
FpnTeacherDistill, SelfDistill, SingleTeacherDistill)
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
from .pruning import SlimmableNetwork, SlimmableNetworkDDP
__all__ = [
'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
'Darts', 'DartsDDP', 'SelfDistill'
'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
'DAFLDataFreeDistillation'
]


@ -1,4 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .configurable import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
from .configurable import (DAFLDataFreeDistillation, DataFreeDistillation,
FpnTeacherDistill, SelfDistill,
SingleTeacherDistill)
__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill']
__all__ = [
'SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill',
'DataFreeDistillation', 'DAFLDataFreeDistillation'
]


@ -1,6 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .datafree_distillation import (DAFLDataFreeDistillation,
DataFreeDistillation)
from .fpn_teacher_distill import FpnTeacherDistill
from .self_distill import SelfDistill
from .single_teacher_distill import SingleTeacherDistill
__all__ = ['SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill']
__all__ = [
'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill',
'DataFreeDistillation', 'DAFLDataFreeDistillation'
]


@ -0,0 +1,224 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint
from mmengine.optim import OPTIMIZERS, OptimWrapper
from mmrazor.models.utils import add_prefix, set_requires_grad
from mmrazor.registry import MODELS
from ...base import BaseAlgorithm
@MODELS.register_module()
class DataFreeDistillation(BaseAlgorithm):
"""Algorithm for data-free teacher-student distillation Typically, the
teacher is a pretrained model and the student is a small model trained on
the generator's output. The student is trained to mimic the behavior of the
teacher. The generator is trained to generate images that are similar to
the real images.
Args:
distiller (dict): The config dict for the built distiller.
generator_distiller (dict): The distiller collecting outputs & losses
to update the generator.
teachers (dict[str, dict]): The dict of config dicts for the teacher
models and their optional ``ckpt_path``.
generator (dict): The config dict for the built generator.
student_iter (int): The number of student steps in train_step().
Defaults to 1.
student_train_first (bool): Whether to train the student first.
Defaults to False.
"""
def __init__(self,
distiller: dict,
generator_distiller: dict,
teachers: Dict[str, Dict[str, dict]],
generator: dict,
student_iter: int = 1,
student_train_first: bool = False,
**kwargs) -> None:
super().__init__(**kwargs)
self.student_iter = student_iter
self.student_train_first = student_train_first
self.distiller = MODELS.build(distiller)
self.generator_distiller = MODELS.build(generator_distiller)
if not isinstance(teachers, Dict):
raise TypeError('teacher should be a `dict` but got '
f'{type(teachers)}')
self.teachers = nn.ModuleDict()
for teacher_name, cfg in teachers.items():
self.teachers[teacher_name] = MODELS.build(cfg['build_cfg'])
if 'ckpt_path' in cfg:
# avoid the loaded parameters being overwritten
self.teachers[teacher_name].init_weights()
_ = load_checkpoint(self.teachers[teacher_name],
cfg['ckpt_path'])
self.teachers[teacher_name].eval()
set_requires_grad(self.teachers[teacher_name], False)
if not isinstance(generator, Dict):
raise TypeError('generator should be a `dict` instance, but got '
f'{type(generator)}')
self.generator = MODELS.build(generator)
# In ``ConfigurableDistiller``, the recorder manager is just
# constructed, but not really initialized yet.
self.distiller.prepare_from_student(self.student)
self.distiller.prepare_from_teacher(self.teachers)
self.generator_distiller.prepare_from_student(self.student)
self.generator_distiller.prepare_from_teacher(self.teachers)
@property
def student(self) -> nn.Module:
"""Alias for ``architecture``."""
return self.architecture
def train_step(self, data: List[dict],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""Train step for DataFreeDistillation.
Args:
data (List[dict]): Data sampled by dataloader.
optim_wrapper (OptimWrapper): A wrapper of optimizer to
update parameters.
"""
log_vars = dict()
for _, teacher in self.teachers.items():
teacher.eval()
if self.student_train_first:
_, dis_log_vars = self.train_student(data,
optim_wrapper['architecture'])
_, generator_loss_vars = self.train_generator(
data, optim_wrapper['generator'])
else:
_, generator_loss_vars = self.train_generator(
data, optim_wrapper['generator'])
_, dis_log_vars = self.train_student(data,
optim_wrapper['architecture'])
log_vars.update(dis_log_vars)
log_vars.update(generator_loss_vars)
return log_vars
def train_student(
self, data: List[dict], optimizer: OPTIMIZERS
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Train step for the student model.
Args:
data (List[dict]): Data sampled by dataloader.
optimizer (OPTIMIZERS): The optimizer to update student.
"""
log_vars = dict()
batch_size = len(data)
for _ in range(self.student_iter):
fakeimg_init = torch.randn(
(batch_size, self.generator.module.latent_dim))
fakeimg = self.generator(fakeimg_init, batch_size).detach()
with optimizer.optim_context(self.student):
_, data_samples = self.data_preprocessor(data, True)
# record the needed information
with self.distiller.student_recorders:
_ = self.student(fakeimg, data_samples, mode='loss')
with self.distiller.teacher_recorders, torch.no_grad():
for _, teacher in self.teachers.items():
_ = teacher(fakeimg, data_samples, mode='loss')
loss_distill = self.distiller.compute_distill_losses()
distill_loss, distill_log_vars = self.parse_losses(loss_distill)
optimizer.update_params(distill_loss)
log_vars = dict(add_prefix(distill_log_vars, 'distill'))
return distill_loss, log_vars
def train_generator(
self, data: List[dict], optimizer: OPTIMIZERS
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Train step for the generator.
Args:
data (List[dict]): Data sampled by dataloader.
optimizer (OPTIMIZERS): The optimizer to update generator.
"""
batch_size = len(data)
fakeimg_init = torch.randn(
(batch_size, self.generator.module.latent_dim))
fakeimg = self.generator(fakeimg_init, batch_size)
with optimizer.optim_context(self.generator):
_, data_samples = self.data_preprocessor(data, True)
# record the needed information
with self.generator_distiller.student_recorders:
_ = self.student(fakeimg, data_samples, mode='loss')
with self.generator_distiller.teacher_recorders:
for _, teacher in self.teachers.items():
_ = teacher(fakeimg, data_samples, mode='loss')
loss_generator = self.generator_distiller.compute_distill_losses()
generator_loss, generator_loss_vars = self.parse_losses(loss_generator)
optimizer.update_params(generator_loss)
log_vars = dict(add_prefix(generator_loss_vars, 'generator'))
return generator_loss, log_vars
@MODELS.register_module()
class DAFLDataFreeDistillation(DataFreeDistillation):
def train_step(self, data: List[dict],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""DAFL train step.
Args:
data (List[dict]): Data sampled by dataloader.
optim_wrapper (OptimWrapper): A wrapper of optimizer to
update parameters.
"""
log_vars = dict()
batch_size = len(data)
for _, teacher in self.teachers.items():
teacher.eval()
# fakeimg initialization and revised by generator.
fakeimg_init = torch.randn(
(batch_size, self.generator.module.latent_dim))
fakeimg = self.generator(fakeimg_init, batch_size)
_, data_samples = self.data_preprocessor(data, True)
with optim_wrapper['generator'].optim_context(self.generator):
# record the needed information
with self.generator_distiller.student_recorders:
_ = self.student(fakeimg, data_samples, mode='loss')
with self.generator_distiller.teacher_recorders:
for _, teacher in self.teachers.items():
_ = teacher(fakeimg, data_samples, mode='loss')
loss_generator = self.generator_distiller.compute_distill_losses()
generator_loss, generator_loss_vars = self.parse_losses(loss_generator)
log_vars.update(add_prefix(generator_loss_vars, 'generator'))
with optim_wrapper['architecture'].optim_context(self.student):
# record the needed information
with self.distiller.student_recorders:
_ = self.student(fakeimg.detach(), data_samples, mode='loss')
with self.distiller.teacher_recorders, torch.no_grad():
for _, teacher in self.teachers.items():
_ = teacher(fakeimg.detach(), data_samples, mode='loss')
loss_distill = self.distiller.compute_distill_losses()
distill_loss, distill_log_vars = self.parse_losses(loss_distill)
log_vars.update(add_prefix(distill_log_vars, 'distill'))
optim_wrapper['generator'].update_params(generator_loss)
optim_wrapper['architecture'].update_params(distill_loss)
return log_vars


@ -2,4 +2,5 @@
from .backbones import * # noqa: F401,F403
from .connectors import * # noqa: F401,F403
from .dynamic_op import * # noqa: F401,F403
from .generators import * # noqa: F401,F403
from .heads import * # noqa: F401,F403


@ -23,7 +23,7 @@ class BaseConnector(BaseModule, metaclass=ABCMeta):
def __init__(self, init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg=init_cfg)
def forward(self, feature: torch.Tensor) -> None:
def forward(self, feature: torch.Tensor) -> torch.Tensor:
"""Forward computation.
Args:


@ -15,13 +15,13 @@ class ConvModuleConncetor(BaseConnector):
Args:
in_channel (int): The input channel of the connector.
out_channel (int): The output channel of the connector.
kernel_size (int | tuple[int]): Size of the convolving kernel.
kernel_size (int | tuple[int, int]): Size of the convolving kernel.
Same as that in ``nn._ConvNd``.
stride (int | tuple[int]): Stride of the convolution.
stride (int | tuple[int, int]): Stride of the convolution.
Same as that in ``nn._ConvNd``.
padding (int | tuple[int]): Zero-padding added to both sides of
padding (int | tuple[int, int]): Zero-padding added to both sides of
the input. Same as that in ``nn._ConvNd``.
dilation (int | tuple[int]): Spacing between kernel elements.
dilation (int | tuple[int, int]): Spacing between kernel elements.
Same as that in ``nn._ConvNd``.
groups (int): Number of blocked connections from input channels to
output channels. Same as that in ``nn._ConvNd``.
@ -53,10 +53,10 @@ class ConvModuleConncetor(BaseConnector):
self,
in_channel: int,
out_channel: int,
kernel_size: Union[int, Tuple[int]] = 1,
stride: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int]] = 0,
dilation: Union[int, Tuple[int]] = 1,
kernel_size: Union[int, Tuple[int, int]] = 1,
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: Union[str, bool] = 'auto',
conv_cfg: Optional[Dict] = None,


@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dafl_generator import DAFLGenerator
__all__ = ['DAFLGenerator']


@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
import torch
from mmcv.runner import BaseModule
from mmrazor.models.utils import get_module_device
class BaseGenerator(BaseModule):
"""The base class for generating images.
Args:
img_size (int): The size of generated image.
latent_dim (int): The dimension of latent data.
hidden_channels (int): The dimension of hidden channels.
init_cfg (dict, optional): The config to control the initialization.
Defaults to None.
"""
def __init__(self,
img_size: int,
latent_dim: int,
hidden_channels: int,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.img_size = img_size
self.latent_dim = latent_dim
self.hidden_channels = hidden_channels
def process_latent(self,
latent_data: Optional[torch.Tensor] = None,
batch_size: int = 1) -> torch.Tensor:
"""Generate the latent data if the input is None. Put the latent data
into the current gpu.
Args:
latent_data (torch.Tensor, optional): The latent data. Defaults to
None.
batch_size (int): The batch size of the latent data. Defaults to 1.
"""
if isinstance(latent_data, torch.Tensor):
assert latent_data.shape[1] == self.latent_dim, \
'Second dimension of the input must be equal to "latent_dim",'\
f'but got {latent_data.shape[1]} != {self.latent_dim}.'
if latent_data.ndim == 2:
batch_data = latent_data
else:
raise ValueError('The noise should be in shape of (n, c), '
f'but got {latent_data.shape}')
elif latent_data is None:
assert batch_size > 0, \
'"batch_size" should larger than zero when "latent_data" is '\
f'None, but got {batch_size}.'
batch_data = torch.randn((batch_size, self.latent_dim))
# putting data on the right device
batch_data = batch_data.to(get_module_device(self))
return batch_data
def forward(self) -> None:
"""Forward function."""
raise NotImplementedError
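To illustrate the contract of `process_latent` above, a hedged sketch using `DAFLGenerator` (added later in this PR) as the concrete subclass, since `BaseGenerator.forward` is abstract:

```python
import torch

from mmrazor.models import DAFLGenerator  # concrete BaseGenerator subclass

gen = DAFLGenerator(img_size=32, latent_dim=10, hidden_channels=32)

noise = gen.process_latent(batch_size=4)       # sampled from N(0, 1)
assert noise.shape == (4, 10)

same = gen.process_latent(torch.randn(4, 10))  # (n, latent_dim) passes through
assert same.shape == (4, 10)
```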


@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmrazor.registry import MODELS
from .base_generator import BaseGenerator
@MODELS.register_module()
class DAFLGenerator(BaseGenerator):
"""Generator for DAFL.
Args:
img_size (int): The size of generated image.
latent_dim (int): The dimension of latent data.
hidden_channels (int): The dimension of hidden channels.
scale_factor (int, optional): The scale factor for F.interpolate.
Defaults to 2.
bn_eps (float, optional): The eps param in bn. Defaults to 0.8.
leaky_slope (float, optional): The slope param in leaky relu. Defaults
to 0.2.
init_cfg (dict, optional): The config to control the initialization.
"""
def __init__(
self,
img_size: int,
latent_dim: int,
hidden_channels: int,
scale_factor: int = 2,
bn_eps: float = 0.8,
leaky_slope: float = 0.2,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(
img_size, latent_dim, hidden_channels, init_cfg=init_cfg)
self.init_size = self.img_size // (scale_factor**2)
self.scale_factor = scale_factor
self.linear = nn.Linear(self.latent_dim,
self.hidden_channels * self.init_size**2)
self.bn1 = nn.BatchNorm2d(self.hidden_channels)
self.conv_blocks1 = nn.Sequential(
nn.Conv2d(
self.hidden_channels,
self.hidden_channels,
3,
stride=1,
padding=1),
nn.BatchNorm2d(self.hidden_channels, eps=bn_eps),
nn.LeakyReLU(leaky_slope, inplace=True),
)
self.conv_blocks2 = nn.Sequential(
nn.Conv2d(
self.hidden_channels,
self.hidden_channels // 2,
3,
stride=1,
padding=1),
nn.BatchNorm2d(self.hidden_channels // 2, eps=bn_eps),
nn.LeakyReLU(leaky_slope, inplace=True),
nn.Conv2d(self.hidden_channels // 2, 3, 3, stride=1, padding=1),
nn.Tanh(), nn.BatchNorm2d(3, affine=False))
def forward(self,
data: Optional[torch.Tensor] = None,
batch_size: int = 0) -> torch.Tensor:
"""Forward function for generator.
Args:
data (torch.Tensor, optional): The input data. Defaults to None.
batch_size (int): Batch size. Defaults to 0.
"""
batch_data = self.process_latent(data, batch_size)
img = self.linear(batch_data)
img = img.view(img.shape[0], self.hidden_channels, self.init_size,
self.init_size)
img = self.bn1(img)
img = F.interpolate(img, scale_factor=self.scale_factor)
img = self.conv_blocks1(img)
img = F.interpolate(img, scale_factor=self.scale_factor)
img = self.conv_blocks2(img)
return img


@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ab_loss import ABLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
from .kd_soft_ce_loss import KDSoftCELoss
from .kl_divergence import KLDivergence
@ -10,5 +11,6 @@ from .weighted_soft_label_distillation import WSLD
__all__ = [
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss'
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss'
]


@ -0,0 +1,124 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import get_dist_info
from mmrazor.registry import MODELS
from ..ops import GatherTensors
class DAFLLoss(nn.Module):
"""Base class for DAFL losses.
paper link: https://arxiv.org/pdf/1904.01186.pdf
Args:
loss_weight (float): Weight of the loss. Defaults to 1.0.
"""
def __init__(self, loss_weight=1.0) -> None:
super().__init__()
self.loss_weight = loss_weight
def forward(self, preds_T: torch.Tensor) -> torch.Tensor:
"""Forward function for the DAFLLoss.
Args:
preds_T (torch.Tensor): The predictions of teacher.
"""
return self.loss_weight * self.forward_train(preds_T)
def forward_train(self, preds_T: torch.Tensor) -> torch.Tensor:
"""Forward function during training.
Args:
preds_T (torch.Tensor): The predictions of teacher.
"""
raise NotImplementedError
@MODELS.register_module()
class OnehotLikeLoss(DAFLLoss):
"""The loss function for measuring the one-hot-likeness of the target
logits."""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def forward_train(self, preds_T: torch.Tensor) -> torch.Tensor:
"""Forward function in training for the OnehotLikeLoss.
Args:
preds_T (torch.Tensor): The predictions of teacher.
"""
fake_label = preds_T.data.max(1)[1]
return F.cross_entropy(preds_T, fake_label)
@MODELS.register_module()
class InformationEntropyLoss(DAFLLoss):
"""The loss function for measuring the class balance of the target logits.
Args:
gather (bool, optional): Whether to gather tensors from multiple
GPUs before computing the loss. Defaults to True.
"""
def __init__(self, gather=True, **kwargs) -> None:
super().__init__(**kwargs)
self.gather = gather
_, self.world_size = get_dist_info()
def forward_train(self, preds_T: torch.Tensor) -> torch.Tensor:
"""Forward function in training for the InformationEntropyLoss.
Args:
preds_T (torch.Tensor): The predictions of teacher.
"""
# Gather predictions from all GPUs to calibrate the loss function.
if self.gather and self.world_size > 1:
preds_T = torch.cat(GatherTensors.apply(preds_T), dim=0)
class_prob = F.softmax(preds_T, dim=1).mean(dim=0)
info_entropy = class_prob * torch.log10(class_prob)
return info_entropy.sum()
@MODELS.register_module()
class ActivationLoss(nn.Module):
"""The loss function for measuring the activation of the target featuremap.
It is negative of the norm of the target featuremap.
Args:
loss_weight (float): Weight of the loss. Defaults to 1.0.
norm_type (str, optional):The type of the norm. Defaults to 'abs'.
"""
def __init__(self, loss_weight=1.0, norm_type='abs') -> None:
super().__init__()
self.loss_weight = loss_weight
assert norm_type in ['norm', 'abs'], \
'"norm_type" must be "norm" or "abs"'
self.norm_type = norm_type
if self.norm_type == 'norm':
self.norm_fn = lambda x: -x.norm()
elif self.norm_type == 'abs':
self.norm_fn = lambda x: -x.abs().mean()
def forward(self, feat_T: torch.Tensor) -> torch.Tensor:
"""Forward function for the ActivationLoss.
Args:
feat_T (torch.Tensor): The featuremap of teacher.
"""
return self.loss_weight * self.forward_train(feat_T)
def forward_train(self, feat_T: torch.Tensor) -> torch.Tensor:
"""Forward function in training for the ActivationLoss.
Args:
feat_T (torch.Tensor): The featuremap of teacher.
"""
feat_T = feat_T.view(feat_T.size(0), -1)
return self.norm_fn(feat_T)
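As a quick, hedged sanity check (not part of the PR) of the three losses on random tensors, with weights mirroring the DAFL config and `gather=False` so that no process group is needed:

```python
import torch

from mmrazor.models import (ActivationLoss, InformationEntropyLoss,
                            OnehotLikeLoss)

logits = torch.randn(8, 10)  # stand-in for teacher logits
feats = torch.randn(8, 64)   # stand-in for pooled teacher features

loss_oh = OnehotLikeLoss(loss_weight=0.05)(logits)
loss_ie = InformationEntropyLoss(loss_weight=5, gather=False)(logits)
loss_ac = ActivationLoss(loss_weight=0.01, norm_type='abs')(feats)

# Each loss is a scalar tensor; the generator minimizes their sum.
print(loss_oh.item(), loss_ie.item(), loss_ac.item())
```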


@ -3,11 +3,12 @@ from .common import Identity
from .darts_series import (DartsDilConv, DartsPoolBN, DartsSepConv,
DartsSkipConnect, DartsZero)
from .efficientnet_series import ConvBnAct, DepthwiseSeparableConv
from .gather_tensors import GatherTensors
from .mobilenet_series import MBBlock
from .shufflenet_series import ShuffleBlock, ShuffleXception
__all__ = [
'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv',
'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity',
'ConvBnAct', 'DepthwiseSeparableConv'
'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors'
]


@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple
import torch
import torch.distributed as dist
class GatherTensors(torch.autograd.Function):
"""Gather tensors from all GPUS, supporting backward propagation.
See more details in torch.distributed.all_gather and
torch.distributed.all_reduce.
"""
@staticmethod
def forward(ctx: Any, input: torch.Tensor) -> Tuple[Any, ...]:
"""Forward function.
It must accept a context ctx as the first argument.
The context can be used to store tensors that can be then retrieved
during the backward pass.
Args:
ctx (Any): Context to be used for forward propagation.
input (torch.Tensor): Tensor to be gathered from the current process.
"""
output = [
torch.empty_like(input) for _ in range(dist.get_world_size())
]
dist.all_gather(output, input)
return tuple(output)
@staticmethod
def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
"""Backward function.
It must accept a context :attr:`ctx` as the first argument, followed by
as many outputs as :func:`forward` returned, and it should return as
many tensors as there were inputs to :func:`forward`. Each argument is
the gradient w.r.t the given output, and each returned value should be
the gradient w.r.t. the corresponding input.
The context can be used to retrieve tensors saved during the forward
pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
of booleans representing whether each input needs gradient. E.g.,
:func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
first input to :func:`forward` needs gradient computed w.r.t. the
output.
Args:
ctx (Any): Context to be used for backward propagation.
grads (torch.Tensor): Gradients to be merged across processes.
"""
rank = dist.get_rank()
merged = torch.stack(grads)
dist.all_reduce(merged)
return merged[rank]
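For context, a hedged sketch of how `GatherTensors` is meant to be used in a loss that needs logits from every rank, as `InformationEntropyLoss` does when `gather=True`. The import path is inferred from the relative import in `dafl_loss.py`, and the gather branch only runs under an initialized process group (e.g. a `torchrun` launch):

```python
import torch
import torch.distributed as dist

from mmrazor.models.ops import GatherTensors  # path inferred, see dafl_loss.py


def gather_logits(local_logits: torch.Tensor) -> torch.Tensor:
    """Concatenate logits from all ranks; a no-op outside distributed runs."""
    if dist.is_available() and dist.is_initialized() \
            and dist.get_world_size() > 1:
        # Gradients still flow back to the local shard via the custom backward.
        return torch.cat(GatherTensors.apply(local_logits), dim=0)
    return local_logits
```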


@ -2,7 +2,9 @@
from .make_divisible import make_divisible
from .misc import add_prefix
from .optim_wrapper import reinitialize_optim_wrapper_count_status
from .utils import get_module_device, set_requires_grad
__all__ = [
'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible'
'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible',
'get_module_device', 'set_requires_grad'
]


@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
import torch.nn as nn
def get_module_device(module: nn.Module) -> torch.device:
"""Get the device of a module.
Args:
module (nn.Module): A module that contains parameters.
"""
try:
next(module.parameters())
except StopIteration as e:
raise ValueError('The input module should contain parameters.') from e
if next(module.parameters()).is_cuda:
return next(module.parameters()).get_device()
return torch.device('cpu')
def set_requires_grad(nets: Union[nn.Module, List[nn.Module]],
requires_grad: bool = False) -> None:
"""Set requires_grad for all the networks.
Args:
nets (nn.Module | list[nn.Module]): A list of networks or a single
network.
requires_grad (bool): Whether the networks require gradients or not.
Defaults to False.
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
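A small, hedged usage sketch for the two helpers above, mirroring how `DataFreeDistillation` freezes its teachers and how `BaseGenerator` moves latent noise onto the right device:

```python
import torch
import torch.nn as nn

from mmrazor.models.utils import get_module_device, set_requires_grad

teacher = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

# Freeze the teacher, as DataFreeDistillation does after loading its weights.
set_requires_grad(teacher, False)
assert all(not p.requires_grad for p in teacher.parameters())

# Place generated noise on the same device as the (possibly GPU) module.
noise = torch.randn(4, 3, 32, 32).to(get_module_device(teacher))
```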


@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import torch
from mmcv import ConfigDict
from mmengine.optim import build_optim_wrapper
from mmrazor.models import DAFLDataFreeDistillation, DataFreeDistillation
class TestDataFreeDistill(TestCase):
def test_init(self):
recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
alg_kwargs = ConfigDict(
architecture=dict(type='ToyStudent'),
teachers=dict(
tea1=dict(build_cfg=dict(type='ToyTeacher')),
tea2=dict(build_cfg=dict(type='ToyTeacher'))),
generator=dict(type='ToyGenerator'),
distiller=dict(
type='ConfigurableDistiller',
student_recorders=recorders_cfg,
teacher_recorders=dict(
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv')),
distill_losses=dict(loss_dis=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_dis=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='tea1_conv')))),
generator_distiller=dict(
type='ConfigurableDistiller',
student_recorders=recorders_cfg,
teacher_recorders=dict(
tea2_conv=dict(type='ModuleOutputs', source='tea2.conv')),
distill_losses=dict(loss_gen=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_gen=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='tea2_conv')))),
)
alg = DataFreeDistillation(**alg_kwargs)
self.assertEqual(len(alg.teachers), len(alg_kwargs['teachers']))
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['teachers'] = 'ToyTeacher'
with self.assertRaisesRegex(TypeError,
'teacher should be a `dict` but got '):
alg = DataFreeDistillation(**alg_kwargs_)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['generator'] = 'ToyGenerator'
with self.assertRaisesRegex(
TypeError, 'generator should be a `dict` instance, but got '):
_ = DataFreeDistillation(**alg_kwargs_)
def test_loss(self):
recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
alg_kwargs = ConfigDict(
architecture=dict(type='ToyStudent'),
teachers=dict(
tea1=dict(build_cfg=dict(type='ToyTeacher')),
tea2=dict(build_cfg=dict(type='ToyTeacher'))),
generator=dict(type='ToyGenerator'),
distiller=dict(
type='ConfigurableDistiller',
student_recorders=recorders_cfg,
teacher_recorders=dict(
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv')),
distill_losses=dict(loss_dis=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_dis=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='tea1_conv')))),
generator_distiller=dict(
type='ConfigurableDistiller',
student_recorders=recorders_cfg,
teacher_recorders=dict(
tea2_conv=dict(type='ModuleOutputs', source='tea2.conv')),
distill_losses=dict(loss_gen=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_gen=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='tea2_conv')))),
)
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD', lr=0.1, weight_decay=0.01, momentum=0.9))
data = [dict(inputs=torch.randn(3, 1, 1)) for _ in range(4)]
alg = DataFreeDistillation(**alg_kwargs)
optim_wrapper = build_optim_wrapper(alg, optim_wrapper_cfg)
optim_wrapper_dict = dict(
architecture=optim_wrapper, generator=optim_wrapper)
losses = alg.train_step(data, optim_wrapper_dict)
self.assertIn('distill.loss_dis', losses)
self.assertIn('distill.loss', losses)
self.assertIn('generator.loss_gen', losses)
self.assertIn('generator.loss', losses)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['student_iter'] = 5
alg = DataFreeDistillation(**alg_kwargs_)
losses = alg.train_step(data, optim_wrapper_dict)
self.assertIn('distill.loss_dis', losses)
self.assertIn('distill.loss', losses)
self.assertIn('generator.loss_gen', losses)
self.assertIn('generator.loss', losses)
class TestDAFLDataFreeDistill(TestCase):
def test_init(self):
recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
alg_kwargs = ConfigDict(
architecture=dict(type='ToyStudent'),
teachers=dict(
tea1=dict(build_cfg=dict(type='ToyTeacher')),
tea2=dict(build_cfg=dict(type='ToyTeacher'))),
generator=dict(type='ToyGenerator'),
distiller=dict(
type='ConfigurableDistiller',
student_recorders=recorders_cfg,
teacher_recorders=dict(
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv')),
distill_losses=dict(loss_dis=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_dis=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='tea1_conv')))),
generator_distiller=dict(
type='ConfigurableDistiller',
student_recorders=recorders_cfg,
teacher_recorders=dict(
tea2_conv=dict(type='ModuleOutputs', source='tea2.conv')),
distill_losses=dict(loss_gen=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_gen=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='tea2_conv')))))
alg = DAFLDataFreeDistillation(**alg_kwargs)
self.assertEqual(len(alg.teachers), len(alg_kwargs['teachers']))
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['teachers'] = 'ToyTeacher'
with self.assertRaisesRegex(TypeError,
'teacher should be a `dict` but got '):
alg = DAFLDataFreeDistillation(**alg_kwargs_)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['generator'] = 'ToyGenerator'
with self.assertRaisesRegex(
TypeError, 'generator should be a `dict` instance, but got '):
_ = DAFLDataFreeDistillation(**alg_kwargs_)
def test_loss(self):
recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
alg_kwargs = ConfigDict(
architecture=dict(type='ToyStudent'),
teachers=dict(
tea1=dict(build_cfg=dict(type='ToyTeacher')),
tea2=dict(build_cfg=dict(type='ToyTeacher'))),
generator=dict(type='ToyGenerator'),
distiller=dict(
type='ConfigurableDistiller',
student_recorders=recorders_cfg,
teacher_recorders=dict(
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv')),
distill_losses=dict(loss_dis=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_dis=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='tea1_conv')))),
generator_distiller=dict(
type='ConfigurableDistiller',
student_recorders=recorders_cfg,
teacher_recorders=dict(
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv'),
tea2_conv=dict(type='ModuleOutputs', source='tea2.conv')),
distill_losses=dict(loss_gen=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_gen=dict(
arg1=dict(from_student=False, recorder='tea1_conv'),
arg2=dict(from_student=False, recorder='tea2_conv')))))
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD', lr=0.1, weight_decay=0.01, momentum=0.9))
data = [dict(inputs=torch.randn(3, 1, 1)) for _ in range(4)]
alg = DAFLDataFreeDistillation(**alg_kwargs)
optim_wrapper = build_optim_wrapper(alg, optim_wrapper_cfg)
optim_wrapper_dict = dict(
architecture=optim_wrapper, generator=optim_wrapper)
losses = alg.train_step(data, optim_wrapper_dict)
self.assertIn('distill.loss_dis', losses)
self.assertIn('distill.loss', losses)
self.assertIn('generator.loss_gen', losses)
self.assertIn('generator.loss', losses)


@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
import torch
from mmengine.model import BaseModel
from torch import nn
@ -32,6 +34,29 @@ class ToyTeacher(ToyStudent):
super().__init__()
@dataclass(frozen=True)
class Data:
latent_dim: int = 1
@MODELS.register_module()
class ToyGenerator(BaseModel):
def __init__(self, latent_dim=4, out_channel=3):
super().__init__(data_preprocessor=None, init_cfg=None)
self.latent_dim = latent_dim
self.out_channel = out_channel
self.conv = nn.Conv2d(self.latent_dim, self.out_channel, 1)
# Imitate the structure of a generator wrapped by a separate model wrapper,
# so that ``self.module.latent_dim`` is accessible.
self.module = Data(latent_dim=self.latent_dim)
def forward(self, data=None, batch_size=4):
fakeimg_init = torch.randn(batch_size, self.latent_dim, 1, 1)
fakeimg = self.conv(fakeimg_init)
return fakeimg
@MODELS.register_module()
class ToyDistillLoss(nn.Module):


@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmrazor.models import DAFLGenerator
def test_dafl_generator():
dafl_generator = DAFLGenerator(
img_size=32, latent_dim=10, hidden_channels=32)
z_batch = torch.randn(8, 10)
fake_img = dafl_generator(z_batch)
assert fake_img.size() == torch.Size([8, 3, 32, 32])
with pytest.raises(AssertionError):
z_batch = torch.randn(8, 11)
fake_img = dafl_generator(z_batch)
with pytest.raises(ValueError):
z_batch = torch.randn(8, 10, 1, 1)
fake_img = dafl_generator(z_batch)
fake_img = dafl_generator(batch_size=8)
assert fake_img.size() == torch.Size([8, 3, 32, 32])
# scale_factor = 4
dafl_generator = DAFLGenerator(
img_size=32, latent_dim=10, hidden_channels=32, scale_factor=4)
z_batch = torch.randn(8, 10)
fake_img = dafl_generator(z_batch)
assert fake_img.size() == torch.Size([8, 3, 32, 32])
# hidden_channels=64
dafl_generator = DAFLGenerator(
img_size=32, latent_dim=10, hidden_channels=64)
z_batch = torch.randn(8, 10)
fake_img = dafl_generator(z_batch)
assert fake_img.size() == torch.Size([8, 3, 32, 32])
with pytest.raises(AssertionError):
fake_img = dafl_generator(data=None, batch_size=0)


@ -3,7 +3,9 @@ from unittest import TestCase
import torch
from mmrazor.models import ABLoss, DKDLoss, KDSoftCELoss
from mmrazor.models import (ABLoss, ActivationLoss, DKDLoss,
InformationEntropyLoss, KDSoftCELoss,
OnehotLikeLoss)
class TestLosses(TestCase):
@ -51,6 +53,32 @@ class TestLosses(TestCase):
# dkd requires label logits
self.normal_test_1d(dkd_loss, labels=True)
def test_dafl_loss(self):
dafl_loss_cfg = dict(loss_weight=1.0)
ac_loss = ActivationLoss(**dafl_loss_cfg, norm_type='abs')
oh_loss = OnehotLikeLoss(**dafl_loss_cfg)
ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=False)
# normal test with only one input
loss_ac = ac_loss.forward(self.feats_1d)
self.assertTrue(loss_ac.numel() == 1)
loss_oh = oh_loss.forward(self.feats_1d)
self.assertTrue(loss_oh.numel() == 1)
loss_ie = ie_loss.forward(self.feats_1d)
self.assertTrue(loss_ie.numel() == 1)
with self.assertRaisesRegex(AssertionError,
'"norm_type" must be "norm" or "abs"'):
_ = ActivationLoss(**dafl_loss_cfg, norm_type='random')
# test gather_tensors
ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=True)
ie_loss.world_size = 2
with self.assertRaisesRegex(
RuntimeError,
'Default process group has not been initialized'):
loss_ie = ie_loss.forward(self.feats_1d)
def test_kdSoftce_loss(self):
kdSoftce_loss_cfg = dict(loss_weight=1.0)
kdSoftce_loss = KDSoftCELoss(**kdSoftce_loss_cfg)