[Feature] Add DAFL Distillation (#235)
* 1.Add DAFL, including config, DAFLLoss and readme. 2.Add DataFreeDistillationtillation. 3.Add Generator, including base_generator and dafl_generator. 4.Add get_module_device and set_requires_grad functions in utils. * 1.Amend the file that report error in mypy test under py37, including gather_tensors, datafree_distillation, base_generator. 2.Revise other linting error. * 1.Revise some docstrings. * 1.Add UT for datafreedistillation. 2.Add all typing.hints. * 1.Add UT for generators and gather_tensors. * 1.Add assert of batch_size in base_generator * 1.Isort Co-authored-by: zhangzhongyu.vendor < zhangzhongyu.vendor@sensetime.com>pull/230/head
parent
72c11751cb
commit
57aec1f730
configs/distill/mmcls/dafl
docs/en/imgs/model_zoo/dafl
mmrazor/models
algorithms
distill
configurable
architectures
losses
utils
tests/test_models
test_algorithms
test_architectures/test_generators
test_losses
|
@ -0,0 +1,42 @@
|
|||
# Data-Free Learning of Student Networks (DAFL)
|
||||
|
||||
> [Data-Free Learning of Student Networks](https://doi.org/10.1109/ICCV.2019.00361)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Learning portable neural networks is very essential for computer vision for the purpose that pre-trained heavy deep models can be well applied on edge devices such as mobile phones and micro sensors. Most existing deep neural network compression and speed-up methods are very effective for training compact deep models, when we can directly access the training dataset. However, training data for the given deep network are often unavailable due to some practice problems (e.g. privacy, legal issue, and transmission), and the architecture of the given network are also unknown except some interfaces. To this end, we propose a novel framework for training efficient deep neural networks by exploiting generative adversarial networks (GANs). To be specific, the pre-trained teacher networks are regarded as a fixed discriminator and the generator is utilized for deviating training samples which can obtain the maximum response on the discriminator. Then, an efficient network with smaller model size and computational complexity is trained using the generated data and the teacher network, simultaneously. Efficient student networks learned using the pro- posed Data-Free Learning (DAFL) method achieve 92.22% and 74.47% accuracies using ResNet-18 without any training data on the CIFAR-10 and CIFAR-100 datasets, respectively. Meanwhile, our student network obtains an 80.56% accuracy on the CelebA benchmark.
|
||||
|
||||

|
||||
|
||||
## Results and models
|
||||
|
||||
### Classification
|
||||
|
||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||
| :----------------------------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| backbone (pretrain) & logits (train) | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.11 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
|
||||
|
||||
## Citation
|
||||
|
||||
```latex
|
||||
@inproceedings{DBLP:conf/iccv/ChenW0YLSXX019,
|
||||
author = {Hanting Chen, Yunhe Wang, Chang Xu, Zhaohui Yang, Chuanjian Liu,
|
||||
Boxin Shi, Chunjing Xu, Chao Xu and Qi Tian},
|
||||
title = {Data-Free Learning of Student Networks},
|
||||
booktitle = {2019 {IEEE/CVF} International Conference on Computer Vision, {ICCV}
|
||||
2019, Seoul, Korea (South), October 27 - November 2, 2019},
|
||||
pages = {3513--3521},
|
||||
publisher = {{IEEE}},
|
||||
year = {2019},
|
||||
url = {https://doi.org/10.1109/ICCV.2019.00361},
|
||||
doi = {10.1109/ICCV.2019.00361},
|
||||
timestamp = {Mon, 17 May 2021 08:18:18 +0200},
|
||||
biburl = {https://dblp.org/rec/conf/iccv/ChenW0YLSXX019.bib},
|
||||
bibsource = {dblp computer science bibliography, https://dblp.org}
|
||||
```
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
Shout out to Davidgzx.
|
|
@ -0,0 +1,104 @@
|
|||
_base_ = [
|
||||
'mmcls::_base_/datasets/cifar10_bs16.py',
|
||||
'mmcls::_base_/schedules/cifar10_bs128.py',
|
||||
'mmcls::_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(
|
||||
_scope_='mmrazor',
|
||||
type='DAFLDataFreeDistillation',
|
||||
data_preprocessor=dict(
|
||||
type='ImgDataPreprocessor',
|
||||
# RGB format normalization parameters
|
||||
mean=[125.307, 122.961, 113.8575],
|
||||
std=[51.5865, 50.847, 51.255],
|
||||
# convert image from BGR to RGB
|
||||
bgr_to_rgb=False),
|
||||
architecture=dict(
|
||||
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
|
||||
teachers=dict(
|
||||
res34=dict(
|
||||
build_cfg=dict(
|
||||
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
|
||||
pretrained=True),
|
||||
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')),
|
||||
generator=dict(
|
||||
type='DAFLGenerator',
|
||||
img_size=32,
|
||||
latent_dim=1000,
|
||||
hidden_channels=128),
|
||||
distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=dict(
|
||||
fc=dict(type='ModuleOutputs', source='head.fc')),
|
||||
teacher_recorders=dict(
|
||||
res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
|
||||
distill_losses=dict(
|
||||
loss_kl=dict(type='KLDivergence', tau=6, loss_weight=1)),
|
||||
loss_forward_mappings=dict(
|
||||
loss_kl=dict(
|
||||
preds_S=dict(from_student=True, recorder='fc'),
|
||||
preds_T=dict(from_student=False, recorder='res34_fc')))),
|
||||
generator_distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
teacher_recorders=dict(
|
||||
res34_neck_gap=dict(type='ModuleOutputs', source='res34.neck.gap'),
|
||||
res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
|
||||
distill_losses=dict(
|
||||
loss_res34_oh=dict(type='OnehotLikeLoss', loss_weight=0.05),
|
||||
loss_res34_ie=dict(type='InformationEntropyLoss', loss_weight=5),
|
||||
loss_res34_ac=dict(type='ActivationLoss', loss_weight=0.01)),
|
||||
loss_forward_mappings=dict(
|
||||
loss_res34_oh=dict(
|
||||
preds_T=dict(from_student=False, recorder='res34_fc')),
|
||||
loss_res34_ie=dict(
|
||||
preds_T=dict(from_student=False, recorder='res34_fc')),
|
||||
loss_res34_ac=dict(
|
||||
feat_T=dict(from_student=False, recorder='res34_neck_gap')))))
|
||||
|
||||
# model wrapper
|
||||
model_wrapper_cfg = dict(
|
||||
type='mmengine.MMSeparateDistributedDataParallel',
|
||||
broadcast_buffers=False,
|
||||
find_unused_parameters=False)
|
||||
|
||||
find_unused_parameters = True
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
constructor='mmrazor.SeparateOptimWrapperConstructor',
|
||||
architecture=dict(optimizer=dict(type='AdamW', lr=1e-1)),
|
||||
generator=dict(optimizer=dict(type='AdamW', lr=1e-3)))
|
||||
|
||||
auto_scale_lr = dict(base_batch_size=256)
|
||||
|
||||
param_scheduler = dict(
|
||||
_delete_=True,
|
||||
architecture=[
|
||||
dict(type='LinearLR', end=500, by_epoch=False, start_factor=0.0001),
|
||||
dict(
|
||||
type='MultiStepLR',
|
||||
begin=500,
|
||||
milestones=[100 * 120, 200 * 120],
|
||||
by_epoch=False)
|
||||
],
|
||||
generator=dict(
|
||||
type='LinearLR', end=500, by_epoch=False, start_factor=0.0001))
|
||||
|
||||
train_cfg = dict(
|
||||
_delete_=True, by_epoch=False, max_iters=250 * 120, val_interval=150)
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=256, sampler=dict(type='InfiniteSampler', shuffle=True))
|
||||
val_dataloader = dict(batch_size=256)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', interval=75, log_metric_by_epoch=False),
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook', by_epoch=False, interval=150, max_keep_ckpts=2))
|
||||
|
||||
log_processor = dict(by_epoch=False)
|
||||
# Must set diff_rank_seed to True!
|
||||
randomness = dict(seed=None, diff_rank_seed=True)
|
Binary file not shown.
After Width: | Height: | Size: 422 KiB |
|
@ -1,11 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base import BaseAlgorithm
|
||||
from .distill import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
|
||||
from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
|
||||
FpnTeacherDistill, SelfDistill, SingleTeacherDistill)
|
||||
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
|
||||
from .pruning import SlimmableNetwork, SlimmableNetworkDDP
|
||||
|
||||
__all__ = [
|
||||
'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
|
||||
'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
|
||||
'Darts', 'DartsDDP', 'SelfDistill'
|
||||
'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
|
||||
'DAFLDataFreeDistillation'
|
||||
]
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .configurable import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
|
||||
from .configurable import (DAFLDataFreeDistillation, DataFreeDistillation,
|
||||
FpnTeacherDistill, SelfDistill,
|
||||
SingleTeacherDistill)
|
||||
|
||||
__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill']
|
||||
__all__ = [
|
||||
'SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill',
|
||||
'DataFreeDistillation', 'DAFLDataFreeDistillation'
|
||||
]
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .datafree_distillation import (DAFLDataFreeDistillation,
|
||||
DataFreeDistillation)
|
||||
from .fpn_teacher_distill import FpnTeacherDistill
|
||||
from .self_distill import SelfDistill
|
||||
from .single_teacher_distill import SingleTeacherDistill
|
||||
|
||||
__all__ = ['SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill']
|
||||
__all__ = [
|
||||
'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill',
|
||||
'DataFreeDistillation', 'DAFLDataFreeDistillation'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,224 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import load_checkpoint
|
||||
from mmengine.optim import OPTIMIZERS, OptimWrapper
|
||||
|
||||
from mmrazor.models.utils import add_prefix, set_requires_grad
|
||||
from mmrazor.registry import MODELS
|
||||
from ...base import BaseAlgorithm
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DataFreeDistillation(BaseAlgorithm):
|
||||
"""Algorithm for data-free teacher-student distillation Typically, the
|
||||
teacher is a pretrained model and the student is a small model trained on
|
||||
the generator's output. The student is trained to mimic the behavior of the
|
||||
teacher. The generator is trained to generate images that are similar to
|
||||
the real images.
|
||||
|
||||
Args:
|
||||
distiller (dict): The config dict for built distiller.
|
||||
generator_distiller (dict): The distiller collecting outputs & losses
|
||||
to update the generator.
|
||||
teachers (dict[str, dict]): The dict of config dict for teacher models
|
||||
and their ckpt_path (optional).
|
||||
generator (dictl): The config dict for built distiller generator.
|
||||
student_iter (int): The number of student steps in train_step().
|
||||
Defaults to 1.
|
||||
student_train_first (bool): Whether to train student in first place.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
distiller: dict,
|
||||
generator_distiller: dict,
|
||||
teachers: Dict[str, Dict[str, dict]],
|
||||
generator: dict,
|
||||
student_iter: int = 1,
|
||||
student_train_first: bool = False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.student_iter = student_iter
|
||||
self.student_train_first = student_train_first
|
||||
self.distiller = MODELS.build(distiller)
|
||||
self.generator_distiller = MODELS.build(generator_distiller)
|
||||
|
||||
if not isinstance(teachers, Dict):
|
||||
raise TypeError('teacher should be a `dict` but got '
|
||||
f'{type(teachers)}')
|
||||
|
||||
self.teachers = nn.ModuleDict()
|
||||
for teacher_name, cfg in teachers.items():
|
||||
self.teachers[teacher_name] = MODELS.build(cfg['build_cfg'])
|
||||
if 'ckpt_path' in cfg:
|
||||
# avoid loaded parameters be overwritten
|
||||
self.teachers[teacher_name].init_weights()
|
||||
_ = load_checkpoint(self.teachers[teacher_name],
|
||||
cfg['ckpt_path'])
|
||||
self.teachers[teacher_name].eval()
|
||||
set_requires_grad(self.teachers[teacher_name], False)
|
||||
|
||||
if not isinstance(generator, Dict):
|
||||
raise TypeError('generator should be a `dict` instance, but got '
|
||||
f'{type(generator)}')
|
||||
self.generator = MODELS.build(generator)
|
||||
|
||||
# In ``DataFreeDistiller``, the recorder manager is just
|
||||
# constructed, but not really initialized yet.
|
||||
self.distiller.prepare_from_student(self.student)
|
||||
self.distiller.prepare_from_teacher(self.teachers)
|
||||
self.generator_distiller.prepare_from_student(self.student)
|
||||
self.generator_distiller.prepare_from_teacher(self.teachers)
|
||||
|
||||
@property
|
||||
def student(self) -> nn.Module:
|
||||
"""Alias for ``architecture``."""
|
||||
return self.architecture
|
||||
|
||||
def train_step(self, data: List[dict],
|
||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||
"""Train step for DataFreeDistillation.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
optim_wrapper (OptimWrapper): A wrapper of optimizer to
|
||||
update parameters.
|
||||
"""
|
||||
log_vars = dict()
|
||||
for _, teacher in self.teachers.items():
|
||||
teacher.eval()
|
||||
|
||||
if self.student_train_first:
|
||||
_, dis_log_vars = self.train_student(data,
|
||||
optim_wrapper['architecture'])
|
||||
_, generator_loss_vars = self.train_generator(
|
||||
data, optim_wrapper['generator'])
|
||||
else:
|
||||
_, generator_loss_vars = self.train_generator(
|
||||
data, optim_wrapper['generator'])
|
||||
_, dis_log_vars = self.train_student(data,
|
||||
optim_wrapper['architecture'])
|
||||
|
||||
log_vars.update(dis_log_vars)
|
||||
log_vars.update(generator_loss_vars)
|
||||
return log_vars
|
||||
|
||||
def train_student(
|
||||
self, data: List[dict], optimizer: OPTIMIZERS
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
"""Train step for the student model.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
optimizer (OPTIMIZERS): The optimizer to update student.
|
||||
"""
|
||||
log_vars = dict()
|
||||
batch_size = len(data)
|
||||
|
||||
for _ in range(self.student_iter):
|
||||
fakeimg_init = torch.randn(
|
||||
(batch_size, self.generator.module.latent_dim))
|
||||
fakeimg = self.generator(fakeimg_init, batch_size).detach()
|
||||
|
||||
with optimizer.optim_context(self.student):
|
||||
_, data_samples = self.data_preprocessor(data, True)
|
||||
# recorde the needed information
|
||||
with self.distiller.student_recorders:
|
||||
_ = self.student(fakeimg, data_samples, mode='loss')
|
||||
with self.distiller.teacher_recorders, torch.no_grad():
|
||||
for _, teacher in self.teachers.items():
|
||||
_ = teacher(fakeimg, data_samples, mode='loss')
|
||||
loss_distill = self.distiller.compute_distill_losses()
|
||||
|
||||
distill_loss, distill_log_vars = self.parse_losses(loss_distill)
|
||||
optimizer.update_params(distill_loss)
|
||||
log_vars = dict(add_prefix(distill_log_vars, 'distill'))
|
||||
|
||||
return distill_loss, log_vars
|
||||
|
||||
def train_generator(
|
||||
self, data: List[dict], optimizer: OPTIMIZERS
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
"""Train step for the generator.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
optimizer (OPTIMIZERS): The optimizer to update generator.
|
||||
"""
|
||||
batch_size = len(data)
|
||||
fakeimg_init = torch.randn(
|
||||
(batch_size, self.generator.module.latent_dim))
|
||||
fakeimg = self.generator(fakeimg_init, batch_size)
|
||||
|
||||
with optimizer.optim_context(self.generator):
|
||||
_, data_samples = self.data_preprocessor(data, True)
|
||||
# recorde the needed information
|
||||
with self.generator_distiller.student_recorders:
|
||||
_ = self.student(fakeimg, data_samples, mode='loss')
|
||||
with self.generator_distiller.teacher_recorders:
|
||||
for _, teacher in self.teachers.items():
|
||||
_ = teacher(fakeimg, data_samples, mode='loss')
|
||||
loss_generator = self.generator_distiller.compute_distill_losses()
|
||||
|
||||
generator_loss, generator_loss_vars = self.parse_losses(loss_generator)
|
||||
optimizer.update_params(generator_loss)
|
||||
log_vars = dict(add_prefix(generator_loss_vars, 'generator'))
|
||||
|
||||
return generator_loss, log_vars
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DAFLDataFreeDistillation(DataFreeDistillation):
|
||||
|
||||
def train_step(self, data: List[dict],
|
||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||
"""DAFL train step.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
optim_wrapper (OptimWrapper): A wrapper of optimizer to
|
||||
update parameters.
|
||||
"""
|
||||
log_vars = dict()
|
||||
batch_size = len(data)
|
||||
|
||||
for _, teacher in self.teachers.items():
|
||||
teacher.eval()
|
||||
|
||||
# fakeimg initialization and revised by generator.
|
||||
fakeimg_init = torch.randn(
|
||||
(batch_size, self.generator.module.latent_dim))
|
||||
fakeimg = self.generator(fakeimg_init, batch_size)
|
||||
_, data_samples = self.data_preprocessor(data, True)
|
||||
|
||||
with optim_wrapper['generator'].optim_context(self.generator):
|
||||
# recorde the needed information
|
||||
with self.generator_distiller.student_recorders:
|
||||
_ = self.student(fakeimg, data_samples, mode='loss')
|
||||
with self.generator_distiller.teacher_recorders:
|
||||
for _, teacher in self.teachers.items():
|
||||
_ = teacher(fakeimg, data_samples, mode='loss')
|
||||
loss_generator = self.generator_distiller.compute_distill_losses()
|
||||
|
||||
generator_loss, generator_loss_vars = self.parse_losses(loss_generator)
|
||||
log_vars.update(add_prefix(generator_loss_vars, 'generator'))
|
||||
|
||||
with optim_wrapper['architecture'].optim_context(self.student):
|
||||
# recorde the needed information
|
||||
with self.distiller.student_recorders:
|
||||
_ = self.student(fakeimg.detach(), data_samples, mode='loss')
|
||||
with self.distiller.teacher_recorders, torch.no_grad():
|
||||
for _, teacher in self.teachers.items():
|
||||
_ = teacher(fakeimg.detach(), data_samples, mode='loss')
|
||||
loss_distill = self.distiller.compute_distill_losses()
|
||||
|
||||
distill_loss, distill_log_vars = self.parse_losses(loss_distill)
|
||||
log_vars.update(add_prefix(distill_log_vars, 'distill'))
|
||||
|
||||
optim_wrapper['generator'].update_params(generator_loss)
|
||||
optim_wrapper['architecture'].update_params(distill_loss)
|
||||
|
||||
return log_vars
|
|
@ -2,4 +2,5 @@
|
|||
from .backbones import * # noqa: F401,F403
|
||||
from .connectors import * # noqa: F401,F403
|
||||
from .dynamic_op import * # noqa: F401,F403
|
||||
from .generators import * # noqa: F401,F403
|
||||
from .heads import * # noqa: F401,F403
|
||||
|
|
|
@ -23,7 +23,7 @@ class BaseConnector(BaseModule, metaclass=ABCMeta):
|
|||
def __init__(self, init_cfg: Optional[Dict] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
def forward(self, feature: torch.Tensor) -> None:
|
||||
def forward(self, feature: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward computation.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -15,13 +15,13 @@ class ConvModuleConncetor(BaseConnector):
|
|||
Args:
|
||||
in_channel (int): The input channel of the connector.
|
||||
out_channel (int): The output channel of the connector.
|
||||
kernel_size (int | tuple[int]): Size of the convolving kernel.
|
||||
kernel_size (int | tuple[int, int]): Size of the convolving kernel.
|
||||
Same as that in ``nn._ConvNd``.
|
||||
stride (int | tuple[int]): Stride of the convolution.
|
||||
stride (int | tuple[int, int]): Stride of the convolution.
|
||||
Same as that in ``nn._ConvNd``.
|
||||
padding (int | tuple[int]): Zero-padding added to both sides of
|
||||
padding (int | tuple[int, int]): Zero-padding added to both sides of
|
||||
the input. Same as that in ``nn._ConvNd``.
|
||||
dilation (int | tuple[int]): Spacing between kernel elements.
|
||||
dilation (int | tuple[int, int]): Spacing between kernel elements.
|
||||
Same as that in ``nn._ConvNd``.
|
||||
groups (int): Number of blocked connections from input channels to
|
||||
output channels. Same as that in ``nn._ConvNd``.
|
||||
|
@ -53,10 +53,10 @@ class ConvModuleConncetor(BaseConnector):
|
|||
self,
|
||||
in_channel: int,
|
||||
out_channel: int,
|
||||
kernel_size: Union[int, Tuple[int]] = 1,
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
padding: Union[int, Tuple[int]] = 0,
|
||||
dilation: Union[int, Tuple[int]] = 1,
|
||||
kernel_size: Union[int, Tuple[int, int]] = 1,
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: Union[str, bool] = 'auto',
|
||||
conv_cfg: Optional[Dict] = None,
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .dafl_generator import DAFLGenerator
|
||||
|
||||
__all__ = ['DAFLGenerator']
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmrazor.models.utils import get_module_device
|
||||
|
||||
|
||||
class BaseGenerator(BaseModule):
|
||||
"""The base class for generating images.
|
||||
|
||||
Args:
|
||||
img_size (int): The size of generated image.
|
||||
latent_dim (int): The dimension of latent data.
|
||||
hidden_channels (int): The dimension of hidden channels.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
latent_dim: int,
|
||||
hidden_channels: int,
|
||||
init_cfg: Optional[Dict] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.img_size = img_size
|
||||
self.latent_dim = latent_dim
|
||||
self.hidden_channels = hidden_channels
|
||||
|
||||
def process_latent(self,
|
||||
latent_data: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1) -> torch.Tensor:
|
||||
"""Generate the latent data if the input is None. Put the latent data
|
||||
into the current gpu.
|
||||
|
||||
Args:
|
||||
latent_data (torch.Tensor, optional): The latent data. Defaults to
|
||||
None.
|
||||
batch_size (int): The batch size of the latent data. Defaults to 1.
|
||||
"""
|
||||
if isinstance(latent_data, torch.Tensor):
|
||||
assert latent_data.shape[1] == self.latent_dim, \
|
||||
'Second dimension of the input must be equal to "latent_dim",'\
|
||||
f'but got {latent_data.shape[1]} != {self.latent_dim}.'
|
||||
if latent_data.ndim == 2:
|
||||
batch_data = latent_data
|
||||
else:
|
||||
raise ValueError('The noise should be in shape of (n, c)'
|
||||
f'but got {latent_data.shape}')
|
||||
elif latent_data is None:
|
||||
assert batch_size > 0, \
|
||||
'"batch_size" should larger than zero when "latent_data" is '\
|
||||
f'None, but got {batch_size}.'
|
||||
batch_data = torch.randn((batch_size, self.latent_dim))
|
||||
|
||||
# putting data on the right device
|
||||
batch_data = batch_data.to(get_module_device(self))
|
||||
return batch_data
|
||||
|
||||
def forward(self) -> None:
|
||||
"""Forward function."""
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
from .base_generator import BaseGenerator
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DAFLGenerator(BaseGenerator):
|
||||
"""Generator for DAFL.
|
||||
|
||||
Args:
|
||||
img_size (int): The size of generated image.
|
||||
latent_dim (int): The dimension of latent data.
|
||||
hidden_channels (int): The dimension of hidden channels.
|
||||
scale_factor (int, optional): The scale factor for F.interpolate.
|
||||
Defaults to 2.
|
||||
bn_eps (float, optional): The eps param in bn. Defaults to 0.8.
|
||||
leaky_slope (float, optional): The slope param in leaky relu. Defaults
|
||||
to 0.2.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int,
|
||||
latent_dim: int,
|
||||
hidden_channels: int,
|
||||
scale_factor: int = 2,
|
||||
bn_eps: float = 0.8,
|
||||
leaky_slope: float = 0.2,
|
||||
init_cfg: Optional[Dict] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
img_size, latent_dim, hidden_channels, init_cfg=init_cfg)
|
||||
self.init_size = self.img_size // (scale_factor**2)
|
||||
self.scale_factor = scale_factor
|
||||
self.linear = nn.Linear(self.latent_dim,
|
||||
self.hidden_channels * self.init_size**2)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(self.hidden_channels)
|
||||
self.conv_blocks1 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
self.hidden_channels,
|
||||
self.hidden_channels,
|
||||
3,
|
||||
stride=1,
|
||||
padding=1),
|
||||
nn.BatchNorm2d(self.hidden_channels, eps=bn_eps),
|
||||
nn.LeakyReLU(leaky_slope, inplace=True),
|
||||
)
|
||||
self.conv_blocks2 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
self.hidden_channels,
|
||||
self.hidden_channels // 2,
|
||||
3,
|
||||
stride=1,
|
||||
padding=1),
|
||||
nn.BatchNorm2d(self.hidden_channels // 2, eps=bn_eps),
|
||||
nn.LeakyReLU(leaky_slope, inplace=True),
|
||||
nn.Conv2d(self.hidden_channels // 2, 3, 3, stride=1, padding=1),
|
||||
nn.Tanh(), nn.BatchNorm2d(3, affine=False))
|
||||
|
||||
def forward(self,
|
||||
data: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 0) -> torch.Tensor:
|
||||
"""Forward function for generator.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor, optional): The input data. Defaults to None.
|
||||
batch_size (int): Batch size. Defaults to 0.
|
||||
"""
|
||||
batch_data = self.process_latent(data, batch_size)
|
||||
img = self.linear(batch_data)
|
||||
img = img.view(img.shape[0], self.hidden_channels, self.init_size,
|
||||
self.init_size)
|
||||
img = self.bn1(img)
|
||||
img = F.interpolate(img, scale_factor=self.scale_factor)
|
||||
img = self.conv_blocks1(img)
|
||||
img = F.interpolate(img, scale_factor=self.scale_factor)
|
||||
img = self.conv_blocks2(img)
|
||||
return img
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ab_loss import ABLoss
|
||||
from .cwd import ChannelWiseDivergence
|
||||
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
|
||||
from .decoupled_kd import DKDLoss
|
||||
from .kd_soft_ce_loss import KDSoftCELoss
|
||||
from .kl_divergence import KLDivergence
|
||||
|
@ -10,5 +11,6 @@ from .weighted_soft_label_distillation import WSLD
|
|||
|
||||
__all__ = [
|
||||
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
|
||||
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss'
|
||||
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
|
||||
'OnehotLikeLoss', 'InformationEntropyLoss'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.runner import get_dist_info
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
from ..ops import GatherTensors
|
||||
|
||||
|
||||
class DAFLLoss(nn.Module):
|
||||
"""Base class for DAFL losses.
|
||||
|
||||
paper link: https://arxiv.org/pdf/1904.01186.pdf
|
||||
|
||||
Args:
|
||||
loss_weight (float): Weight of the loss. Defaults to 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=1.0) -> None:
|
||||
super().__init__()
|
||||
self.loss_weight = loss_weight
|
||||
|
||||
def forward(self, preds_T: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function for the DAFLLoss.
|
||||
|
||||
Args:
|
||||
preds_T (torch.Tensor): The predictions of teacher.
|
||||
"""
|
||||
return self.loss_weight * self.forward_train(preds_T)
|
||||
|
||||
def forward_train(self, preds_T: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function during training.
|
||||
|
||||
Args:
|
||||
preds_T (torch.Tensor): The predictions of teacher.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OnehotLikeLoss(DAFLLoss):
|
||||
"""The loss function for measuring the one-hot-likeness of the target
|
||||
logits."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def forward_train(self, preds_T: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function in training for the OnehotLikeLoss.
|
||||
|
||||
Args:
|
||||
preds_T (torch.Tensor): The predictions of teacher.
|
||||
"""
|
||||
fake_label = preds_T.data.max(1)[1]
|
||||
return F.cross_entropy(preds_T, fake_label)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class InformationEntropyLoss(DAFLLoss):
|
||||
"""The loss function for measuring the class balance of the target logits.
|
||||
|
||||
Args:
|
||||
gather (bool, optional): The switch controlling whether
|
||||
collecting tensors from multiple gpus. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, gather=True, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.gather = gather
|
||||
_, self.world_size = get_dist_info()
|
||||
|
||||
def forward_train(self, preds_T: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function in training for the InformationEntropyLoss.
|
||||
|
||||
Args:
|
||||
preds_T (torch.Tensor): The predictions of teacher.
|
||||
"""
|
||||
# Gather predictions from all GPUS to calibrate the loss function.
|
||||
if self.gather and self.world_size > 1:
|
||||
preds_T = torch.cat(GatherTensors.apply(preds_T), dim=0)
|
||||
class_prob = F.softmax(preds_T, dim=1).mean(dim=0)
|
||||
info_entropy = class_prob * torch.log10(class_prob)
|
||||
return info_entropy.sum()
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ActivationLoss(nn.Module):
|
||||
"""The loss function for measuring the activation of the target featuremap.
|
||||
It is negative of the norm of the target featuremap.
|
||||
|
||||
Args:
|
||||
loss_weight (float): Weight of the loss. Defaults to 1.0.
|
||||
norm_type (str, optional):The type of the norm. Defaults to 'abs'.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=1.0, norm_type='abs') -> None:
|
||||
super().__init__()
|
||||
self.loss_weight = loss_weight
|
||||
assert norm_type in ['norm', 'abs'], \
|
||||
'"norm_type" must be "norm" or "abs"'
|
||||
self.norm_type = norm_type
|
||||
|
||||
if self.norm_type == 'norm':
|
||||
self.norm_fn = lambda x: -x.norm()
|
||||
elif self.norm_type == 'abs':
|
||||
self.norm_fn = lambda x: -x.abs().mean()
|
||||
|
||||
def forward(self, feat_T: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function for the ActivationLoss.
|
||||
|
||||
Args:
|
||||
feat_T (torch.Tensor): The featuremap of teacher.
|
||||
"""
|
||||
return self.loss_weight * self.forward_train(feat_T)
|
||||
|
||||
def forward_train(self, feat_T: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function in training for the ActivationLoss.
|
||||
|
||||
Args:
|
||||
feat_T (torch.Tensor): The featuremap of teacher.
|
||||
"""
|
||||
feat_T = feat_T.view(feat_T.size(0), -1)
|
||||
return self.norm_fn(feat_T)
|
|
@ -3,11 +3,12 @@ from .common import Identity
|
|||
from .darts_series import (DartsDilConv, DartsPoolBN, DartsSepConv,
|
||||
DartsSkipConnect, DartsZero)
|
||||
from .efficientnet_series import ConvBnAct, DepthwiseSeparableConv
|
||||
from .gather_tensors import GatherTensors
|
||||
from .mobilenet_series import MBBlock
|
||||
from .shufflenet_series import ShuffleBlock, ShuffleXception
|
||||
|
||||
__all__ = [
|
||||
'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv',
|
||||
'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity',
|
||||
'ConvBnAct', 'DepthwiseSeparableConv'
|
||||
'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class GatherTensors(torch.autograd.Function):
|
||||
"""Gather tensors from all GPUS, supporting backward propagation.
|
||||
|
||||
See more details in torch.distributed.all_gather and
|
||||
torch.distributed.all_reduce.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx: Any, input: torch.Tensor) -> Tuple[Any, ...]:
|
||||
"""Forward function.
|
||||
|
||||
It must accept a context ctx as the first argument.
|
||||
|
||||
The context can be used to store tensors that can be then retrieved
|
||||
during the backward pass.
|
||||
|
||||
Args:
|
||||
ctx (Any): Context to be used for forward propagation.
|
||||
input (torch.Tensor): Tensor to be broadcast from current process.
|
||||
"""
|
||||
output = [
|
||||
torch.empty_like(input) for _ in range(dist.get_world_size())
|
||||
]
|
||||
dist.all_gather(output, input)
|
||||
return tuple(output)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
|
||||
"""Backward function.
|
||||
|
||||
It must accept a context :attr:`ctx` as the first argument, followed by
|
||||
as many outputs did :func:`forward` return, and it should return as
|
||||
many tensors, as there were inputs to :func:`forward`. Each argument is
|
||||
the gradient w.r.t the given output, and each returned value should be
|
||||
the gradient w.r.t. the corresponding input.
|
||||
|
||||
The context can be used to retrieve tensors saved during the forward
|
||||
pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
|
||||
of booleans representing whether each input needs gradient. E.g.,
|
||||
:func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
|
||||
first input to :func:`forward` needs gradient computated w.r.t. the
|
||||
output.
|
||||
|
||||
Args:
|
||||
ctx (Any): Context to be used for forward propagation.
|
||||
grads (torch.Tensor): Grads to be merged from current process.
|
||||
"""
|
||||
rank = dist.get_rank()
|
||||
merged = torch.stack(grads)
|
||||
dist.all_reduce(merged)
|
||||
return merged[rank]
|
|
@ -2,7 +2,9 @@
|
|||
from .make_divisible import make_divisible
|
||||
from .misc import add_prefix
|
||||
from .optim_wrapper import reinitialize_optim_wrapper_count_status
|
||||
from .utils import get_module_device, set_requires_grad
|
||||
|
||||
__all__ = [
|
||||
'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible'
|
||||
'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible',
|
||||
'get_module_device', 'set_requires_grad'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_module_device(module: nn.Module) -> torch.device:
|
||||
"""Get the device of a module.
|
||||
|
||||
Args:
|
||||
module (nn.Module): A module contains the parameters.
|
||||
"""
|
||||
try:
|
||||
next(module.parameters())
|
||||
except StopIteration as e:
|
||||
raise ValueError('The input module should contain parameters.') from e
|
||||
|
||||
if next(module.parameters()).is_cuda:
|
||||
return next(module.parameters()).get_device()
|
||||
|
||||
return torch.device('cpu')
|
||||
|
||||
|
||||
def set_requires_grad(nets: Union[nn.Module, List[nn.Module]],
|
||||
requires_grad: bool = False) -> None:
|
||||
"""Set requires_grad for all the networks.
|
||||
|
||||
Args:
|
||||
nets (nn.Module | list[nn.Module]): A list of networks or a single
|
||||
network.
|
||||
requires_grad (bool): Whether the networks require gradients or not
|
||||
"""
|
||||
if not isinstance(nets, list):
|
||||
nets = [nets]
|
||||
for net in nets:
|
||||
if net is not None:
|
||||
for param in net.parameters():
|
||||
param.requires_grad = requires_grad
|
|
@ -0,0 +1,220 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmcv import ConfigDict
|
||||
from mmengine.optim import build_optim_wrapper
|
||||
|
||||
from mmrazor.models import DAFLDataFreeDistillation, DataFreeDistillation
|
||||
|
||||
|
||||
class TestDataFreeDistill(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
|
||||
recorders_cfg = ConfigDict(
|
||||
conv=dict(type='ModuleOutputs', source='conv'))
|
||||
|
||||
alg_kwargs = ConfigDict(
|
||||
architecture=dict(type='ToyStudent'),
|
||||
teachers=dict(
|
||||
tea1=dict(build_cfg=dict(type='ToyTeacher')),
|
||||
tea2=dict(build_cfg=dict(type='ToyTeacher'))),
|
||||
generator=dict(type='ToyGenerator'),
|
||||
distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=recorders_cfg,
|
||||
teacher_recorders=dict(
|
||||
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv')),
|
||||
distill_losses=dict(loss_dis=dict(type='ToyDistillLoss')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_dis=dict(
|
||||
arg1=dict(from_student=True, recorder='conv'),
|
||||
arg2=dict(from_student=False, recorder='tea1_conv')))),
|
||||
generator_distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=recorders_cfg,
|
||||
teacher_recorders=dict(
|
||||
tea2_conv=dict(type='ModuleOutputs', source='tea2.conv')),
|
||||
distill_losses=dict(loss_gen=dict(type='ToyDistillLoss')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_gen=dict(
|
||||
arg1=dict(from_student=True, recorder='conv'),
|
||||
arg2=dict(from_student=False, recorder='tea2_conv')))),
|
||||
)
|
||||
|
||||
alg = DataFreeDistillation(**alg_kwargs)
|
||||
self.assertEquals(len(alg.teachers), len(alg_kwargs['teachers']))
|
||||
|
||||
alg_kwargs_ = copy.deepcopy(alg_kwargs)
|
||||
alg_kwargs_['teachers'] = 'ToyTeacher'
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'teacher should be a `dict` but got '):
|
||||
alg = DataFreeDistillation(**alg_kwargs_)
|
||||
|
||||
alg_kwargs_ = copy.deepcopy(alg_kwargs)
|
||||
alg_kwargs_['generator'] = 'ToyGenerator'
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'generator should be a `dict` instance, but got '):
|
||||
_ = DataFreeDistillation(**alg_kwargs_)
|
||||
|
||||
def test_loss(self):
|
||||
|
||||
recorders_cfg = ConfigDict(
|
||||
conv=dict(type='ModuleOutputs', source='conv'))
|
||||
|
||||
alg_kwargs = ConfigDict(
|
||||
architecture=dict(type='ToyStudent'),
|
||||
teachers=dict(
|
||||
tea1=dict(build_cfg=dict(type='ToyTeacher')),
|
||||
tea2=dict(build_cfg=dict(type='ToyTeacher'))),
|
||||
generator=dict(type='ToyGenerator'),
|
||||
distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=recorders_cfg,
|
||||
teacher_recorders=dict(
|
||||
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv')),
|
||||
distill_losses=dict(loss_dis=dict(type='ToyDistillLoss')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_dis=dict(
|
||||
arg1=dict(from_student=True, recorder='conv'),
|
||||
arg2=dict(from_student=False, recorder='tea1_conv')))),
|
||||
generator_distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=recorders_cfg,
|
||||
teacher_recorders=dict(
|
||||
tea2_conv=dict(type='ModuleOutputs', source='tea2.conv')),
|
||||
distill_losses=dict(loss_gen=dict(type='ToyDistillLoss')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_gen=dict(
|
||||
arg1=dict(from_student=True, recorder='conv'),
|
||||
arg2=dict(from_student=False, recorder='tea2_conv')))),
|
||||
)
|
||||
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD', lr=0.1, weight_decay=0.01, momentum=0.9))
|
||||
|
||||
data = [dict(inputs=torch.randn(3, 1, 1)) for _ in range(4)]
|
||||
|
||||
alg = DataFreeDistillation(**alg_kwargs)
|
||||
optim_wrapper = build_optim_wrapper(alg, optim_wrapper_cfg)
|
||||
optim_wrapper_dict = dict(
|
||||
architecture=optim_wrapper, generator=optim_wrapper)
|
||||
|
||||
losses = alg.train_step(data, optim_wrapper_dict)
|
||||
self.assertIn('distill.loss_dis', losses)
|
||||
self.assertIn('distill.loss', losses)
|
||||
self.assertIn('generator.loss_gen', losses)
|
||||
self.assertIn('generator.loss', losses)
|
||||
|
||||
alg_kwargs_ = copy.deepcopy(alg_kwargs)
|
||||
alg_kwargs_['student_iter'] = 5
|
||||
alg = DataFreeDistillation(**alg_kwargs_)
|
||||
losses = alg.train_step(data, optim_wrapper_dict)
|
||||
self.assertIn('distill.loss_dis', losses)
|
||||
self.assertIn('distill.loss', losses)
|
||||
self.assertIn('generator.loss_gen', losses)
|
||||
self.assertIn('generator.loss', losses)
|
||||
|
||||
|
||||
class TestDAFLDataFreeDistill(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
|
||||
recorders_cfg = ConfigDict(
|
||||
conv=dict(type='ModuleOutputs', source='conv'))
|
||||
|
||||
alg_kwargs = ConfigDict(
|
||||
architecture=dict(type='ToyStudent'),
|
||||
teachers=dict(
|
||||
tea1=dict(build_cfg=dict(type='ToyTeacher')),
|
||||
tea2=dict(build_cfg=dict(type='ToyTeacher'))),
|
||||
generator=dict(type='ToyGenerator'),
|
||||
distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=recorders_cfg,
|
||||
teacher_recorders=dict(
|
||||
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv')),
|
||||
distill_losses=dict(loss_dis=dict(type='ToyDistillLoss')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_dis=dict(
|
||||
arg1=dict(from_student=True, recorder='conv'),
|
||||
arg2=dict(from_student=False, recorder='tea1_conv')))),
|
||||
generator_distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=recorders_cfg,
|
||||
teacher_recorders=dict(
|
||||
tea2_conv=dict(type='ModuleOutputs', source='tea2.conv')),
|
||||
distill_losses=dict(loss_gen=dict(type='ToyDistillLoss')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_gen=dict(
|
||||
arg1=dict(from_student=True, recorder='conv'),
|
||||
arg2=dict(from_student=False, recorder='tea2_conv')))))
|
||||
|
||||
alg = DAFLDataFreeDistillation(**alg_kwargs)
|
||||
self.assertEquals(len(alg.teachers), len(alg_kwargs['teachers']))
|
||||
|
||||
alg_kwargs_ = copy.deepcopy(alg_kwargs)
|
||||
alg_kwargs_['teachers'] = 'ToyTeacher'
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'teacher should be a `dict` but got '):
|
||||
alg = DAFLDataFreeDistillation(**alg_kwargs_)
|
||||
|
||||
alg_kwargs_ = copy.deepcopy(alg_kwargs)
|
||||
alg_kwargs_['generator'] = 'ToyGenerator'
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'generator should be a `dict` instance, but got '):
|
||||
_ = DAFLDataFreeDistillation(**alg_kwargs_)
|
||||
|
||||
def test_loss(self):
|
||||
|
||||
recorders_cfg = ConfigDict(
|
||||
conv=dict(type='ModuleOutputs', source='conv'))
|
||||
|
||||
alg_kwargs = ConfigDict(
|
||||
architecture=dict(type='ToyStudent'),
|
||||
teachers=dict(
|
||||
tea1=dict(build_cfg=dict(type='ToyTeacher')),
|
||||
tea2=dict(build_cfg=dict(type='ToyTeacher'))),
|
||||
generator=dict(type='ToyGenerator'),
|
||||
distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=recorders_cfg,
|
||||
teacher_recorders=dict(
|
||||
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv')),
|
||||
distill_losses=dict(loss_dis=dict(type='ToyDistillLoss')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_dis=dict(
|
||||
arg1=dict(from_student=True, recorder='conv'),
|
||||
arg2=dict(from_student=False, recorder='tea1_conv')))),
|
||||
generator_distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=recorders_cfg,
|
||||
teacher_recorders=dict(
|
||||
tea1_conv=dict(type='ModuleOutputs', source='tea1.conv'),
|
||||
tea2_conv=dict(type='ModuleOutputs', source='tea2.conv')),
|
||||
distill_losses=dict(loss_gen=dict(type='ToyDistillLoss')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_gen=dict(
|
||||
arg1=dict(from_student=False, recorder='tea1_conv'),
|
||||
arg2=dict(from_student=False, recorder='tea2_conv')))))
|
||||
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD', lr=0.1, weight_decay=0.01, momentum=0.9))
|
||||
|
||||
data = [dict(inputs=torch.randn(3, 1, 1)) for _ in range(4)]
|
||||
|
||||
alg = DAFLDataFreeDistillation(**alg_kwargs)
|
||||
optim_wrapper = build_optim_wrapper(alg, optim_wrapper_cfg)
|
||||
optim_wrapper_dict = dict(
|
||||
architecture=optim_wrapper, generator=optim_wrapper)
|
||||
losses = alg.train_step(data, optim_wrapper_dict)
|
||||
self.assertIn('distill.loss_dis', losses)
|
||||
self.assertIn('distill.loss', losses)
|
||||
self.assertIn('generator.loss_gen', losses)
|
||||
self.assertIn('generator.loss', losses)
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModel
|
||||
from torch import nn
|
||||
|
||||
|
@ -32,6 +34,29 @@ class ToyTeacher(ToyStudent):
|
|||
super().__init__()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Data:
|
||||
latent_dim: int = 1
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ToyGenerator(BaseModel):
|
||||
|
||||
def __init__(self, latent_dim=4, out_channel=3):
|
||||
super().__init__(data_preprocessor=None, init_cfg=None)
|
||||
self.latent_dim = latent_dim
|
||||
self.out_channel = out_channel
|
||||
self.conv = nn.Conv2d(self.latent_dim, self.out_channel, 1)
|
||||
|
||||
# Imitate the structure of generator in separate model_wrapper.
|
||||
self.module = Data(latent_dim=self.latent_dim)
|
||||
|
||||
def forward(self, data=None, batch_size=4):
|
||||
fakeimg_init = torch.randn(batch_size, self.latent_dim, 1, 1)
|
||||
fakeimg = self.conv(fakeimg_init)
|
||||
return fakeimg
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ToyDistillLoss(nn.Module):
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmrazor.models import DAFLGenerator
|
||||
|
||||
|
||||
def test_dafl_generator():
|
||||
dafl_generator = DAFLGenerator(
|
||||
img_size=32, latent_dim=10, hidden_channels=32)
|
||||
z_batch = torch.randn(8, 10)
|
||||
fake_img = dafl_generator(z_batch)
|
||||
assert fake_img.size() == torch.Size([8, 3, 32, 32])
|
||||
with pytest.raises(AssertionError):
|
||||
z_batch = torch.randn(8, 11)
|
||||
fake_img = dafl_generator(z_batch)
|
||||
with pytest.raises(ValueError):
|
||||
z_batch = torch.randn(8, 10, 1, 1)
|
||||
fake_img = dafl_generator(z_batch)
|
||||
|
||||
fake_img = dafl_generator(batch_size=8)
|
||||
assert fake_img.size() == torch.Size([8, 3, 32, 32])
|
||||
|
||||
# scale_factor = 4
|
||||
dafl_generator = DAFLGenerator(
|
||||
img_size=32, latent_dim=10, hidden_channels=32, scale_factor=4)
|
||||
z_batch = torch.randn(8, 10)
|
||||
fake_img = dafl_generator(z_batch)
|
||||
assert fake_img.size() == torch.Size([8, 3, 32, 32])
|
||||
|
||||
# hidden_channels=64
|
||||
dafl_generator = DAFLGenerator(
|
||||
img_size=32, latent_dim=10, hidden_channels=64)
|
||||
z_batch = torch.randn(8, 10)
|
||||
fake_img = dafl_generator(z_batch)
|
||||
assert fake_img.size() == torch.Size([8, 3, 32, 32])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
fake_img = dafl_generator(data=None, batch_size=0)
|
|
@ -3,7 +3,9 @@ from unittest import TestCase
|
|||
|
||||
import torch
|
||||
|
||||
from mmrazor.models import ABLoss, DKDLoss, KDSoftCELoss
|
||||
from mmrazor.models import (ABLoss, ActivationLoss, DKDLoss,
|
||||
InformationEntropyLoss, KDSoftCELoss,
|
||||
OnehotLikeLoss)
|
||||
|
||||
|
||||
class TestLosses(TestCase):
|
||||
|
@ -51,6 +53,32 @@ class TestLosses(TestCase):
|
|||
# dkd requires label logits
|
||||
self.normal_test_1d(dkd_loss, labels=True)
|
||||
|
||||
def test_dafl_loss(self):
|
||||
dafl_loss_cfg = dict(loss_weight=1.0)
|
||||
ac_loss = ActivationLoss(**dafl_loss_cfg, norm_type='abs')
|
||||
oh_loss = OnehotLikeLoss(**dafl_loss_cfg)
|
||||
ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=False)
|
||||
|
||||
# normal test with only one input
|
||||
loss_ac = ac_loss.forward(self.feats_1d)
|
||||
self.assertTrue(loss_ac.numel() == 1)
|
||||
loss_oh = oh_loss.forward(self.feats_1d)
|
||||
self.assertTrue(loss_oh.numel() == 1)
|
||||
loss_ie = ie_loss.forward(self.feats_1d)
|
||||
self.assertTrue(loss_ie.numel() == 1)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'"norm_type" must be "norm" or "abs"'):
|
||||
_ = ActivationLoss(**dafl_loss_cfg, norm_type='random')
|
||||
|
||||
# test gather_tensors
|
||||
ie_loss = InformationEntropyLoss(**dafl_loss_cfg, gather=True)
|
||||
ie_loss.world_size = 2
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'Default process group has not been initialized'):
|
||||
loss_ie = ie_loss.forward(self.feats_1d)
|
||||
|
||||
def test_kdSoftce_loss(self):
|
||||
kdSoftce_loss_cfg = dict(loss_weight=1.0)
|
||||
kdSoftce_loss = KDSoftCELoss(**kdSoftce_loss_cfg)
|
||||
|
|
Loading…
Reference in New Issue