pull/586/merge
Cbtor 2023-09-22 06:28:17 +00:00 committed by GitHub
commit e95c0363c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 485 additions and 4 deletions

View File

@ -0,0 +1,31 @@
# Learning Student Networks in the Wild (DFND)
> [Learning Student Networks in the Wild](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf)
<!-- [ALGORITHM] -->
## Abstract
Data-free learning for student networks is a new paradigm for solving users anxiety caused by the privacy problem of using original training data. Since the architectures of modern convolutional neural networks (CNNs) are compact and sophisticated, the alternative images or meta-data generated from the teacher network are often broken. Thus, the student network cannot achieve the comparable performance to that of the pre-trained teacher network especially on the large-scale image dataset. Different to previous works, we present to maximally utilize the massive available unlabeled data in the wild. Specifically, we first thoroughly analyze the output differences between teacher and student network on the original data and develop a data collection method. Then, a noisy knowledge distillation algorithm is proposed for achieving the performance of the student network. In practice, an adaptation matrix is learned with the student network for correcting the label noise produced by the teacher network on the collected unlabeled images. The effectiveness of our DFND (DataFree Noisy Distillation) method is then verified on several benchmarks to demonstrate its superiority over state-of-theart data-free distillation methods. Experiments on various datasets demonstrate that the student networks learned by the proposed method can achieve comparable performance with those using the original dataset.
<img width="910" alt="pipeline" src="./dfnd.PNG">
## Results and models
### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | |
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 94.78 | 95.34 | 94.82 | [config](./dfnd_logits_resnet34_resnet18_8xb32_cifar10.py) | [student](https://drive.google.com/file/d/1_MekfTkCsEl68meWPqtdNZIxdJO2R2Eb/view?usp=drive_link) |
## Citation
```latex
@inproceedings{chen2021learning,
title={Learning student networks in the wild},
author={Chen, Hanting and Guo, Tianyu and Xu, Chang and Li, Wenshuo and Xu, Chunjing and Xu, Chao and Wang, Yunhe},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={6428--6437},
year={2021}
}
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 644 KiB

View File

@ -0,0 +1,100 @@
_base_ = ['mmcls::_base_/default_runtime.py']
# optimizer
optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
type='MultiStepLR', by_epoch=True, milestones=[320, 640], gamma=0.1)
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=800, val_interval=1)
test_cfg = dict()
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=128)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=32),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
]
train_dataloader = dict(
batch_size=256,
num_workers=5,
dataset=dict(
type='ImageNet',
data_root='/cache/data/imagenet/',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
test_pipeline = [
dict(type='PackClsInputs'),
]
val_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type='CIFAR10',
data_prefix='/cache/data/cifar',
test_mode=True,
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))
test_dataloader = val_dataloader
test_evaluator = val_evaluator
teacher_ckpt = '/cache/models/resnet_model.pth' # noqa: E501
model = dict(
_scope_='mmrazor',
type='DFNDDistill',
calculate_student_loss=False,
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
val_data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
# convert image from BGR to RGB
bgr_to_rgb=False),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
teacher=dict(
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=False),
teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_kl=dict(
type='DFNDLoss',
tau=4,
loss_weight=1,
num_classes=10,
batch_select=0.5)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='fc')))))
find_unused_parameters = True
val_cfg = dict(type='mmrazor.DFNDValLoop')

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .autoslim_greedy_search_loop import AutoSlimGreedySearchLoop
from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
from .distill_val_loop import (DFNDValLoop, SelfDistillValLoop,
SingleTeacherDistillValLoop)
from .evolution_search_loop import EvolutionSearchLoop
from .iteprune_val_loop import ItePruneValLoop
from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop,
@ -15,5 +16,5 @@ __all__ = [
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop',
'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop',
'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop'
'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop', 'DFNDValLoop'
]

View File

@ -125,3 +125,38 @@ class SelfDistillValLoop(ValLoop):
self.runner.call_hook('after_val_epoch', metrics=student_metrics)
self.runner.call_hook('after_val')
@LOOPS.register_module()
class DFNDValLoop(SingleTeacherDistillValLoop):
"""Validation loop for DFND. DFND requires different dataset for training
and validation.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 validation. Defaults to
False.
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False) -> None:
super().__init__(runner, dataloader, evaluator, fp16)
if self.runner.distributed:
assert hasattr(self.runner.model.module, 'teacher')
# TODO: remove hard code after mmcls add data_preprocessor
data_preprocessor = self.runner.model.module.val_data_preprocessor
self.teacher = self.runner.model.module.teacher
self.teacher.data_preprocessor = data_preprocessor
else:
assert hasattr(self.runner.model, 'teacher')
# TODO: remove hard code after mmcls add data_preprocessor
data_preprocessor = self.runner.model.val_data_preprocessor
self.teacher = self.runner.model.teacher
self.teacher.data_preprocessor = data_preprocessor

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .datafree_distillation import (DAFLDataFreeDistillation,
DataFreeDistillation)
from .dfnd_distill import DFNDDistill
from .fpn_teacher_distill import FpnTeacherDistill
from .overhaul_feature_distillation import OverhaulFeatureDistillation
from .self_distill import SelfDistill
@ -9,5 +10,5 @@ from .single_teacher_distill import SingleTeacherDistill
__all__ = [
'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill',
'DataFreeDistillation', 'DAFLDataFreeDistillation',
'OverhaulFeatureDistillation'
'OverhaulFeatureDistillation', 'DFNDDistill'
]

View File

@ -0,0 +1,198 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from mmengine.structures import BaseDataElement
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import BaseAlgorithm, LossResults
@MODELS.register_module()
class DFNDDistill(BaseAlgorithm):
"""``DFNDDistill`` algorithm for training student model in the wild dataset.
https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf
Args:
distiller (dict): The config dict for built distiller.
teacher (dict | BaseModel): The config dict for teacher model or built
teacher model.
val_data_preprocessor (Union[Dict, nn.Module]): Data preprocessor for
evaluation dataset. Defaults to None.
teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None.
teacher_trainable (bool): Whether the teacher is trainable. Defaults
to False.
teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
mode, namely, freeze running stats (mean and var). Note: Effect on
Batch Norm and its variants only. Defaults to True.
student_trainable (bool): Whether the student is trainable. Defaults
to True.
calculate_student_loss (bool): Whether to calculate student loss
(original task loss) to update student model. Defaults to True.
teacher_module_inplace(bool): Whether to allow teacher module inplace
attribute True. Defaults to False.
"""
def __init__(self,
distiller: dict,
teacher: Union[BaseModel, Dict],
val_data_preprocessor: Optional[Union[Dict,
nn.Module]] = None,
teacher_ckpt: Optional[str] = None,
teacher_trainable: bool = False,
teacher_norm_eval: bool = True,
student_trainable: bool = True,
calculate_student_loss: bool = True,
teacher_module_inplace: bool = False,
**kwargs) -> None:
super().__init__(**kwargs)
self.distiller = MODELS.build(distiller)
if isinstance(teacher, Dict):
teacher = MODELS.build(teacher)
if not isinstance(teacher, BaseModel):
raise TypeError('teacher should be a `dict` or '
f'`BaseModel` instance, but got '
f'{type(teacher)}')
self.teacher = teacher
# Find all nn.Modules in the model that contain the 'inplace' attribute
# and set them to False.
self.teacher_module_inplace = teacher_module_inplace
if not self.teacher_module_inplace:
self.set_module_inplace_false(teacher, 'self.teacher')
if teacher_ckpt:
_ = load_checkpoint(self.teacher, teacher_ckpt)
# avoid loaded parameters be overwritten
self.teacher._is_init = True
self.teacher_trainable = teacher_trainable
if not self.teacher_trainable:
for param in self.teacher.parameters():
param.requires_grad = False
self.teacher_norm_eval = teacher_norm_eval
# The student model will not calculate gradients and update parameters
# in some pretraining process.
self.student_trainable = student_trainable
# The student loss will not be updated into ``losses`` in some
# pretraining process.
self.calculate_student_loss = calculate_student_loss
# In ``ConfigurableDistller``, the recorder manager is just
# constructed, but not really initialized yet.
self.distiller.prepare_from_student(self.student)
self.distiller.prepare_from_teacher(self.teacher)
# may be modified by stop distillation hook
self.distillation_stopped = False
if val_data_preprocessor is None:
val_data_preprocessor = dict(type='BaseDataPreprocessor')
if isinstance(val_data_preprocessor, nn.Module):
self.val_data_preprocessor = val_data_preprocessor
elif isinstance(val_data_preprocessor, dict):
self.val_data_preprocessor = MODELS.build(val_data_preprocessor)
else:
raise TypeError('val_data_preprocessor should be a `dict` or '
f'`nn.Module` instance, but got '
f'{type(val_data_preprocessor)}')
@property
def student(self) -> nn.Module:
"""Alias for ``architecture``."""
return self.architecture
def loss(
self,
batch_inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
) -> LossResults:
"""Calculate losses from a batch of inputs and data samples."""
losses = dict()
# If the `override_data` of a delivery is False, the delivery will
# record the origin data.
self.distiller.set_deliveries_override(False)
if self.teacher_trainable:
with self.distiller.teacher_recorders, self.distiller.deliveries:
teacher_losses = self.teacher(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(teacher_losses, 'teacher'))
else:
with self.distiller.teacher_recorders, self.distiller.deliveries:
with torch.no_grad():
_ = self.teacher(batch_inputs, data_samples, mode='tensor')
# If the `override_data` of a delivery is True, the delivery will
# override the origin data with the recorded data.
self.distiller.set_deliveries_override(True)
# Original task loss will not be used during some pretraining process.
if self.calculate_student_loss:
with self.distiller.student_recorders, self.distiller.deliveries:
student_losses = self.student(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))
else:
with self.distiller.student_recorders, self.distiller.deliveries:
if self.student_trainable:
_ = self.student(batch_inputs, data_samples, mode='tensor')
else:
with torch.no_grad():
_ = self.student(
batch_inputs, data_samples, mode='tensor')
if not self.distillation_stopped:
# Automatically compute distill losses based on
# `loss_forward_mappings`.
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))
return losses
def train(self, mode: bool = True) -> None:
"""Set distiller's forward mode."""
super().train(mode)
if mode and self.teacher_norm_eval:
for m in self.teacher.modules():
if isinstance(m, _BatchNorm):
m.eval()
def val_step(self, data: Union[tuple, dict, list]) -> list:
"""Gets the predictions of given data.
Calls ``self.val_data_preprocessor(data, False)`` and
``self(inputs, data_sample, mode='predict')`` in order. Return the
predictions which will be passed to evaluator.
Args:
data (dict or tuple or list): Data sampled from dataset.
Returns:
list: The predictions of given data.
"""
data = self.val_data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def test_step(self, data: Union[dict, tuple, list]) -> list:
"""``BaseModel`` implements ``test_step`` the same as ``val_step``.
Args:
data (dict or tuple or list): Data sampled from dataset.
Returns:
list: The predictions of given data.
"""
data = self.val_data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore

View File

@ -6,6 +6,7 @@ from .cross_entropy_loss import CrossEntropyLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
from .dfnd_loss import DFNDLoss
from .dist_loss import DISTLoss
from .factor_transfer_loss import FTLoss
from .fbkd_loss import FBKDLoss
@ -24,5 +25,5 @@ __all__ = [
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss',
'DISTLoss'
'DISTLoss', 'DFNDLoss'
]

View File

@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmrazor.registry import MODELS
@MODELS.register_module()
class DFNDLoss(nn.Module):
"""Loss function for DFND.
https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf
Args:
tau (float): Temperature coefficient. Defaults to 1.0.
reduction (str): Specifies the reduction to apply to the loss:
``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
``'none'``: no reduction will be applied,
``'batchmean'``: the sum of the output will be divided by
the batchsize,
``'sum'``: the output will be summed,
``'mean'``: the output will be divided by the number of
elements in the output.
Default: ``'batchmean'``
loss_weight (float): Weight of loss. Defaults to 1.0.
teacher_detach (bool): Whether to detach the teacher model prediction.
Will set to ``'False'`` in some data-free distillation algorithms.
Defaults to True.
num_classes (int): Number of classes.
teacher_acc (float): The performance of teacher network in the target
dataset.
batch_select (float): ratio of data in the wild dataset to participate
in training.
"""
def __init__(
self,
tau: float = 1.0,
reduction: str = 'batchmean',
loss_weight: float = 1.0,
teacher_detach: bool = True,
num_classes: int = 1000,
teacher_acc: float = 0.95,
batch_select: float = 0.5,
):
super(DFNDLoss, self).__init__()
self.tau = tau
self.loss_weight = loss_weight
self.teacher_detach = teacher_detach
accept_reduction = {'none', 'batchmean', 'sum', 'mean'}
assert reduction in accept_reduction, \
f'KLDivergence supports reduction {accept_reduction}, ' \
f'but gets {reduction}.'
self.reduction = reduction
self.noisy_adaptation = torch.nn.Parameter(
torch.zeros(num_classes, num_classes - 1))
self.teacher_acc = teacher_acc
self.num_classes = num_classes
self.nll_loss = torch.nn.NLLLoss()
self.ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
self.batch_select = batch_select
def noisy(self):
noise_adaptation_softmax = torch.nn.functional.softmax(
self.noisy_adaptation, dim=1) * (1 - self.teacher_acc)
noise_adaptation_layer = torch.zeros(self.num_classes,
self.num_classes).to(
self.noisy_adaptation.device)
tc = torch.FloatTensor([self.teacher_acc
]).to(noise_adaptation_softmax.device)
for i in range(self.num_classes):
if i == 0:
noise_adaptation_layer[i] = \
torch.cat([tc, noise_adaptation_softmax[i][i:]])
if i == self.num_classes - 1:
noise_adaptation_layer[i] = \
torch.cat([noise_adaptation_softmax[i][:i], tc])
else:
noise_adaptation_layer[i] = \
torch.cat([noise_adaptation_softmax[i][:i], tc,
noise_adaptation_softmax[i][i:]])
return noise_adaptation_layer
def forward(self, preds_S, preds_T):
"""Forward computation.
Args:
preds_S (torch.Tensor): The student model prediction with
shape (N, C, H, W) or shape (N, C).
preds_T (torch.Tensor): The teacher model prediction with
shape (N, C, H, W) or shape (N, C).
Return:
torch.Tensor: The calculated loss value.
"""
if self.teacher_detach:
preds_T = preds_T.detach()
pred = preds_T.data.max(1)[1]
loss_t = self.ce_loss(preds_T, pred)
positive_loss_idx = loss_t.topk(
int(self.batch_select * preds_S.shape[0]), largest=False)[1]
softmax_pred_T = F.softmax(preds_T / self.tau, dim=1)
log_softmax_preds_S = F.log_softmax(preds_S / self.tau, dim=1)
softmax_preds_S_adaptation = torch.matmul(
F.softmax(preds_S, dim=1), self.noisy())
loss = (self.tau**2) * (
torch.sum(
F.kl_div(
log_softmax_preds_S[positive_loss_idx],
softmax_pred_T[positive_loss_idx],
reduction='none')) / preds_S.shape[0])
loss += self.nll_loss(torch.log(softmax_preds_S_adaptation), pred)
return self.loss_weight * loss