[Feature] Add FBKD algorithm and torch_connectors (#248)

* Add FBKD

* Add torch_connector and its unit tests; revise README and FBKD config

* Revise unit tests for torch_connectors

* Rewrite NonLocalBlock as a subclass of NonLocal2d from mmcv.cnn
zhongyu zhang 2022-08-29 10:05:32 +08:00 committed by GitHub
14 changed files with 802 additions and 15 deletions


@@ -14,9 +14,9 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a
### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dfad_logits_r34_r18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
## Citation


@@ -20,9 +20,9 @@ Performing knowledge transfer from a large teacher network to a smaller student
### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./zskt_backbone_logits_r34_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
## Citation


@@ -0,0 +1,37 @@
# Improve Object Detection with Feature-Based Knowledge Distillation: Towards Accurate and Efficient Detectors (FBKD)
> [Improve Object Detection with Feature-Based Knowledge Distillation: Towards Accurate and Efficient Detectors](https://openreview.net/pdf?id=uKhGRvM8QNH)
<!-- [ALGORITHM] -->
## Abstract
Knowledge distillation, in which a student model is trained to mimic a teacher model, has been proved as an effective technique for model compression and model accuracy boosting. However, most knowledge distillation methods, designed for image classification, have failed on more challenging tasks, such as object detection. In this paper, we suggest that the failure of knowledge distillation on object detection is mainly caused by two reasons: (1) the imbalance between pixels of foreground and background and (2) lack of distillation on the relation between different pixels. Observing the above reasons, we propose attention-guided distillation and non-local distillation to address the two problems, respectively. Attention-guided distillation is proposed to find the crucial pixels of foreground objects with attention mechanism and then make the students take more effort to learn their features. Non-local distillation is proposed to enable students to learn not only the feature of an individual pixel but also the relation between different pixels captured by non-local modules. Experiments show that our methods achieve excellent AP improvements on both one-stage and two-stage, both anchor-based and anchor-free detectors. For example, Faster RCNN (ResNet101 backbone) with our distillation achieves 43.9 AP on COCO2017, which is 4.1 higher than the baseline.
![pipeline](/docs/en/imgs/model_zoo/fbkd/pipeline.png)
## Results and models
### Detection
| Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download |
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :--------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| neck | COCO | [fasterrcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [fasterrcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.1 | 39.4 | 37.8 | [config](./fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](<>) \| [log](<>) |
## Citation
```latex
@inproceedings{DBLP:conf/iclr/ZhangM21,
author = {Linfeng Zhang and Kaisheng Ma},
title = {Improve Object Detection with Feature-based Knowledge Distillation:
Towards Accurate and Efficient Detectors},
booktitle = {9th International Conference on Learning Representations, {ICLR} 2021,
Virtual Event, Austria, May 3-7, 2021},
publisher = {OpenReview.net},
year = {2021},
url = {https://openreview.net/forum?id=uKhGRvM8QNH},
timestamp = {Wed, 23 Jun 2021 17:36:39 +0200},
biburl = {https://dblp.org/rec/conf/iclr/ZhangM21.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
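To make the abstract's two ideas concrete, here is a minimal hand-rolled sketch of attention-guided feature weighting. It is illustrative only: tensor shapes are assumed, and the actual implementation lives in `FBKDStudentConnector`/`FBKDLoss` below.

```python
import torch

# Hypothetical student/teacher FPN features of shape (N, C, H, W).
s_feat = torch.randn(2, 256, 32, 32)
t_feat = torch.randn(2, 256, 32, 32)

# Spatial attention: mean absolute activation over channels, softened by a
# temperature T; rescaling by H*W keeps the mask's mean value at 1.
T = 0.5
attn = torch.mean(torch.abs(t_feat), dim=1, keepdim=True)  # (N, 1, H, W)
attn = torch.softmax(attn.flatten(1) / T, dim=1).view_as(attn) * 32 * 32

# Foreground pixels (high attention) dominate the distillation loss.
loss = torch.sum(attn * (s_feat - t_feat) ** 2) ** 0.5
```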


@@ -0,0 +1,125 @@
_base_ = [
'mmdet::_base_/datasets/coco_detection.py',
'mmdet::_base_/schedules/schedule_1x.py',
'mmdet::_base_/default_runtime.py'
]
model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
architecture=dict(
cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
pretrained=True),
teacher=dict(
cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py',
pretrained=False),
teacher_ckpt='faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth',
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
neck_s0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
neck_s1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
neck_s2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
neck_s3=dict(type='ModuleOutputs',
source='neck.fpn_convs.3.conv')),
teacher_recorders=dict(
neck_s0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
neck_s1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
neck_s2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
neck_s3=dict(type='ModuleOutputs',
source='neck.fpn_convs.3.conv')),
distill_losses=dict(
loss_s0=dict(type='FBKDLoss'),
loss_s1=dict(type='FBKDLoss'),
loss_s2=dict(type='FBKDLoss'),
loss_s3=dict(type='FBKDLoss')),
connectors=dict(
loss_s0_sfeat=dict(
type='FBKDStudentConnector',
in_channels=256,
reduction=4,
mode='dot_product',
sub_sample=True,
maxpool_stride=8),
loss_s0_tfeat=dict(
type='FBKDTeacherConnector',
in_channels=256,
reduction=4,
mode='dot_product',
sub_sample=True,
maxpool_stride=8),
loss_s1_sfeat=dict(
type='FBKDStudentConnector',
in_channels=256,
reduction=4,
mode='dot_product',
sub_sample=True,
maxpool_stride=4),
loss_s1_tfeat=dict(
type='FBKDTeacherConnector',
in_channels=256,
reduction=4,
mode='dot_product',
sub_sample=True,
maxpool_stride=4),
loss_s2_sfeat=dict(
type='FBKDStudentConnector',
in_channels=256,
mode='dot_product',
sub_sample=True),
loss_s2_tfeat=dict(
type='FBKDTeacherConnector',
in_channels=256,
mode='dot_product',
sub_sample=True),
loss_s3_sfeat=dict(
type='FBKDStudentConnector',
in_channels=256,
mode='dot_product',
sub_sample=True),
loss_s3_tfeat=dict(
type='FBKDTeacherConnector',
in_channels=256,
mode='dot_product',
sub_sample=True)),
loss_forward_mappings=dict(
loss_s0=dict(
s_input=dict(
from_student=True,
recorder='neck_s0',
connector='loss_s0_sfeat'),
t_input=dict(
from_student=False,
recorder='neck_s0',
connector='loss_s0_tfeat')),
loss_s1=dict(
s_input=dict(
from_student=True,
recorder='neck_s1',
connector='loss_s1_sfeat'),
t_input=dict(
from_student=False,
recorder='neck_s1',
connector='loss_s1_tfeat')),
loss_s2=dict(
s_input=dict(
from_student=True,
recorder='neck_s2',
connector='loss_s2_sfeat'),
t_input=dict(
from_student=False,
recorder='neck_s2',
connector='loss_s2_tfeat')),
loss_s3=dict(
s_input=dict(
from_student=True,
recorder='neck_s3',
connector='loss_s3_sfeat'),
t_input=dict(
from_student=False,
recorder='neck_s3',
connector='loss_s3_tfeat')))))
find_unused_parameters = True
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
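The `loss_forward_mappings` above wire each recorded FPN level through its student/teacher connector pair into an `FBKDLoss`. A rough, self-contained sketch of that data flow for one level (hypothetical tensors; not the actual `ConfigurableDistiller` internals):

```python
import torch
from mmrazor.models import (FBKDLoss, FBKDStudentConnector,
                            FBKDTeacherConnector)

# Stand-ins for the recorded outputs of neck.fpn_convs.0.conv.
s_neck0 = torch.randn(2, 256, 64, 64)
t_neck0 = torch.randn(2, 256, 64, 64)

common = dict(in_channels=256, reduction=4, mode='dot_product',
              sub_sample=True, maxpool_stride=8)
s_conn = FBKDStudentConnector(**common)
t_conn = FBKDTeacherConnector(**common)

# Connectors turn raw features into the tuples FBKDLoss expects.
loss_s0 = FBKDLoss()(s_conn(s_neck0), t_conn(t_neck0))
```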

New binary file: docs/en/imgs/model_zoo/fbkd/pipeline.png (561 KiB)


@@ -2,5 +2,11 @@
from .byot_connector import BYOTConnector
from .convmodule_connector import ConvModuleConncetor
from .factor_transfer_connectors import Paraphraser, Translator
from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector
from .torch_connector import TorchFunctionalConnector, TorchNNConnector
__all__ = ['ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector']
__all__ = [
'ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector',
'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector',
'TorchNNConnector'
]


@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Tuple, Union
import torch
from mmengine.model import BaseModule
@@ -32,7 +32,9 @@ class BaseConnector(BaseModule, metaclass=ABCMeta):
return self.forward_train(feature)
@abstractmethod
def forward_train(self, feature) -> torch.Tensor:
def forward_train(
self, feature: torch.Tensor
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
"""Abstract train computation.
Args:


@@ -67,7 +67,7 @@ class BYOTConnector(BaseConnector):
self.scala = nn.Sequential(*scala)
self.fc = nn.Linear(out_channel * expansion, num_classes)
def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
def forward_train(self, feature: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""Forward computation.
Args:


@@ -0,0 +1,298 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import NonLocal2d
from mmrazor.registry import MODELS
from .base_connector import BaseConnector
class NonLocal2dMaxpoolNstride(NonLocal2d):
"""Nonlocal block for 2-dimension inputs, with a configurable
maxpool_stride.
This module is proposed in
"Non-local Neural Networks"
Paper reference: https://arxiv.org/abs/1711.07971
Code reference: https://github.com/AlexHex7/Non-local_pytorch
Args:
in_channels (int): Channels of the input feature map.
reduction (int): Channel reduction ratio. Defaults to 2.
conv_cfg (dict): The config dict for convolution layers.
Defaults to `nn.Conv2d`.
norm_cfg (dict): The config dict for normalization layers.
Defaults to `BN`. (This parameter is only applicable to conv_out.)
mode (str): Options are `gaussian`, `concatenation`,
`embedded_gaussian` and `dot_product`. Defaults to `embedded_gaussian`.
sub_sample (bool): Whether to apply max pooling after pairwise
function (Note that the `sub_sample` is applied on spatial only).
Default: False.
maxpool_stride (int): The stride of the maxpooling module.
Defaults to 2.
zeros_init (bool): Whether to use zero to initialize weights of
`conv_out`. Defaults to True.
"""
def __init__(self,
in_channels: int,
reduction: int = 2,
conv_cfg: Dict = dict(type='Conv2d'),
norm_cfg: Dict = dict(type='BN'),
mode: str = 'embedded_gaussian',
sub_sample: bool = False,
maxpool_stride: int = 2,
zeros_init: bool = True,
**kwargs) -> None:
"""Inits the NonLocal2dMaxpoolNstride module."""
super().__init__(
in_channels=in_channels,
sub_sample=sub_sample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
reduction=reduction,
mode=mode,
zeros_init=zeros_init,
**kwargs)
self.norm_cfg = norm_cfg
if sub_sample:
max_pool_layer = nn.MaxPool2d(
kernel_size=(maxpool_stride, maxpool_stride))
self.g: nn.Sequential = nn.Sequential(self.g, max_pool_layer)
if self.mode != 'gaussian':
self.phi: nn.Sequential = nn.Sequential(
self.phi, max_pool_layer)
else:
self.phi = max_pool_layer
@MODELS.register_module()
class FBKDStudentConnector(BaseConnector):
"""Improve Object Detection with Feature-based Knowledge Distillation:
Towards Accurate and Efficient Detectors, ICLR2021.
https://openreview.net/pdf?id=uKhGRvM8QNH.
Student connector for FBKD.
Args:
in_channels (int): Channels of the input feature map.
reduction (int): Channel reduction ratio. Defaults to 2.
conv_cfg (dict): The config dict for convolution layers.
Defaults to `nn.Conv2d`.
norm_cfg (dict): The config dict for normalization layers.
Defaults to `BN`. (This parameter is only applicable to conv_out.)
mode (str): Options are `gaussian`, `concatenation`,
`embedded_gaussian` and `dot_product`. Default: dot_product.
sub_sample (bool): Whether to apply max pooling after pairwise
function (Note that the `sub_sample` is applied on spatial only).
Default: False.
maxpool_stride (int): The stride of the maxpooling module.
Defaults to 2.
zeros_init (bool): Whether to use zero to initialize weights of
`conv_out`. Defaults to True.
spatial_T (float): Temperature used in spatial-wise pooling.
Defaults to 0.5.
channel_T (float): Temperature used in channel-wise pooling.
Defaults to 0.5.
init_cfg (dict, optional): The config to control the initialization.
"""
def __init__(self,
in_channels: int,
reduction: int = 2,
conv_cfg: Dict = dict(type='Conv2d'),
norm_cfg: Dict = dict(type='BN'),
mode: str = 'dot_product',
sub_sample: bool = False,
maxpool_stride: int = 2,
zeros_init: bool = True,
spatial_T: float = 0.5,
channel_T: float = 0.5,
init_cfg: Optional[Dict] = None,
**kwargs) -> None:
"""Inits the FBKDStuConnector."""
super().__init__(init_cfg)
self.channel_wise_adaptation = nn.Linear(in_channels, in_channels)
self.spatial_wise_adaptation = nn.Conv2d(
1, 1, kernel_size=3, stride=1, padding=1)
self.adaptation_layers = nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.student_non_local = NonLocal2dMaxpoolNstride(
in_channels=in_channels,
reduction=reduction,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
mode=mode,
sub_sample=sub_sample,
maxpool_stride=maxpool_stride,
zeros_init=zeros_init,
**kwargs)
self.non_local_adaptation = nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.in_channels = in_channels
self.spatial_T = spatial_T
self.channel_T = channel_T
def forward_train(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""Frorward function for training.
Args:
x (torch.Tensor): Input student features.
Returns:
s_spatial_mask (torch.Tensor): Student spatial-wise mask.
s_channel_mask (torch.Tensor): Student channel-wise mask.
s_feat_adapt (torch.Tensor): Adapted student feature.
s_channel_pool_adapt (torch.Tensor): Student feature after
    channel-wise pooling and ``channel_wise_adaptation``.
s_spatial_pool_adapt (torch.Tensor): Student feature after
    spatial-wise pooling and ``spatial_wise_adaptation``.
s_relation_adapt (torch.Tensor): Adapted student relation feature.
"""
# Calculate spatial-wise mask.
s_spatial_mask = torch.mean(torch.abs(x), [1], keepdim=True)
size = s_spatial_mask.size()
s_spatial_mask = s_spatial_mask.view(x.size(0), -1)
# Soften or sharpen the spatial-wise mask by temperature.
s_spatial_mask = torch.softmax(
s_spatial_mask / self.spatial_T, dim=1) * size[-1] * size[-2]
s_spatial_mask = s_spatial_mask.view(size)
# Calculate channel-wise mask.
s_channel_mask = torch.mean(torch.abs(x), [2, 3], keepdim=True)
channel_mask_size = s_channel_mask.size()
s_channel_mask = s_channel_mask.view(x.size(0), -1)
# Soften or sharpen the channel-wise mask by temperature.
s_channel_mask = torch.softmax(
s_channel_mask / self.channel_T, dim=1) * self.in_channels
s_channel_mask = s_channel_mask.view(channel_mask_size)
# Adapt and pool the student feature channel-wise.
s_feat_adapt = self.adaptation_layers(x)
s_channel_pool_adapt = self.channel_wise_adaptation(
torch.mean(x, [2, 3]))
# Adapt and pool the student feature spatial-wise.
s_spatial_pool = torch.mean(x, [1]).view(
x.size(0), 1, x.size(2), x.size(3))
s_spatial_pool_adapt = self.spatial_wise_adaptation(s_spatial_pool)
# Calculate non_local_adaptation.
s_relation = self.student_non_local(x)
s_relation_adapt = self.non_local_adaptation(s_relation)
return (s_spatial_mask, s_channel_mask, s_channel_pool_adapt,
s_spatial_pool_adapt, s_relation_adapt, s_feat_adapt)
@MODELS.register_module()
class FBKDTeacherConnector(BaseConnector):
"""Improve Object Detection with Feature-based Knowledge Distillation:
Towards Accurate and Efficient Detectors, ICLR2021.
https://openreview.net/pdf?id=uKhGRvM8QNH.
Teacher connector for FBKD.
Args:
in_channels (int): Channels of the input feature map.
reduction (int): Channel reduction ratio. Defaults to 2.
conv_cfg (dict): The config dict for convolution layers.
Defaults to `nn.Conv2d`.
norm_cfg (dict): The config dict for normalization layers.
Defaults to `BN`. (This parameter is only applicable to conv_out.)
mode (str): Options are `gaussian`, `concatenation`,
`embedded_gaussian` and `dot_product`. Default: dot_product.
sub_sample (bool): Whether to apply max pooling after pairwise
function (Note that the `sub_sample` is applied on spatial only).
Default: False.
maxpool_stride (int): The stride of the maxpooling module.
Defaults to 2.
zeros_init (bool): Whether to use zero to initialize weights of
`conv_out`. Defaults to True.
spatial_T (float): Temperature used in spatial-wise pooling.
Defaults to 0.5.
channel_T (float): Temperature used in channel-wise pooling.
Defaults to 0.5.
init_cfg (dict, optional): The config to control the initialization.
"""
def __init__(self,
in_channels: int,
reduction: int = 2,
conv_cfg: Dict = dict(type='Conv2d'),
norm_cfg: Dict = dict(type='BN'),
mode: str = 'dot_product',
sub_sample: bool = False,
maxpool_stride: int = 2,
zeros_init: bool = True,
spatial_T: float = 0.5,
channel_T: float = 0.5,
init_cfg: Optional[Dict] = None,
**kwargs) -> None:
super().__init__(init_cfg)
self.teacher_non_local = NonLocal2dMaxpoolNstride(
in_channels=in_channels,
reduction=reduction,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
mode=mode,
sub_sample=sub_sample,
maxpool_stride=maxpool_stride,
zeros_init=zeros_init,
**kwargs)
self.in_channels = in_channels
self.spatial_T = spatial_T
self.channel_T = channel_T
def forward_train(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""Frorward function for training.
Args:
x (torch.Tensor): Input teacher features.
Returns:
t_spatial_mask (torch.Tensor): Teacher spatial-wise mask.
t_channel_mask (torch.Tensor): Teacher channel-wise mask.
t_spatial_pool (torch.Tensor): Teacher feature after
    spatial-wise pooling.
t_relation (torch.Tensor): Teacher relation matrix.
t_feat (torch.Tensor): Unmodified teacher feature.
"""
# Calculate spatial-wise mask.
t_spatial_mask = torch.mean(torch.abs(x), [1], keepdim=True)
size = t_spatial_mask.size()
t_spatial_mask = t_spatial_mask.view(x.size(0), -1)
# Soften or sharpen the spatial-wise mask by temperature.
t_spatial_mask = torch.softmax(
t_spatial_mask / self.spatial_T, dim=1) * size[-1] * size[-2]
t_spatial_mask = t_spatial_mask.view(size)
# Calculate channel-wise mask.
t_channel_mask = torch.mean(torch.abs(x), [2, 3], keepdim=True)
channel_mask_size = t_channel_mask.size()
t_channel_mask = t_channel_mask.view(x.size(0), -1)
# Soften or sharpen the channel-wise mask by temperature.
t_channel_mask = torch.softmax(
t_channel_mask / self.channel_T, dim=1) * self.in_channels
t_channel_mask = t_channel_mask.view(channel_mask_size)
# Pool the teacher feature spatial-wise.
t_spatial_pool = torch.mean(x, [1]).view(
x.size(0), 1, x.size(2), x.size(3))
# Calculate non_local relation.
t_relation = self.teacher_non_local(x)
return (t_spatial_mask, t_channel_mask, t_spatial_pool, t_relation, x)
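For reference, a plain-math reading of the two attention masks computed in `forward_train` above, for a feature map x of shape (N, C, H, W); T_s and T_c are `spatial_T` and `channel_T` (my notation, transcribed from the code):

```latex
M_{\mathrm{spatial}} = HW \cdot \operatorname{softmax}_{h,w}\Big(\frac{1}{T_s}\cdot\frac{1}{C}\sum_{c}|x_{c,h,w}|\Big), \qquad
M_{\mathrm{channel}} = C \cdot \operatorname{softmax}_{c}\Big(\frac{1}{T_c}\cdot\frac{1}{HW}\sum_{h,w}|x_{c,h,w}|\Big)
```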


@@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmrazor.registry import MODELS
from .base_connector import BaseConnector
FUNCTION_LIST = [
'adaptive_avg_pool2d',
'adaptive_max_pool2d',
'avg_pool2d',
'dropout',
'dropout2d',
'max_pool2d',
'normalize',
'relu',
'softmax',
'interpolate',
]
@MODELS.register_module()
class TorchFunctionalConnector(BaseConnector):
"""TorchFunctionalConnector: Call function in torch.nn.functional
to process input data
usage:
tensor1 = torch.rand(3,3,16,16)
pool_connector = TorchFunctionalConnector(
function_name='avg_pool2d',
func_args=dict(kernel_size=4),
)
tensor2 = pool_connector.forward_train(tensor1)
tensor2.size()
# torch.Size([3, 3, 4, 4])
which is equal to torch.nn.functional.avg_pool2d(kernel_size=4)
Args:
function_name (str, optional): Name of a function in
    ``torch.nn.functional``. Must be one of ``FUNCTION_LIST`` above.
    Defaults to None.
func_args (dict): Args passed to the function. Defaults to {}.
init_cfg (dict, optional): The config to control the initialization.
"""
def __init__(self,
function_name: Optional[str] = None,
func_args: Dict = {},
init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg)
assert function_name is not None, 'Arg `function_name` cannot be None'
if function_name not in FUNCTION_LIST:
raise ValueError(
'Arg `function_name` is not available. See this list:',
FUNCTION_LIST)
self.func = getattr(F, function_name)
self.func_args = func_args
def forward_train(self, x: torch.Tensor) -> torch.Tensor:
"""Frorward function for training.
Args:
x (torch.Tensor): Input features.
"""
x = self.func(x, **self.func_args)
return x
MODULE_LIST = [
'AdaptiveAvgPool2d',
'AdaptiveMaxPool2d',
'AvgPool2d',
'BatchNorm2d',
'Conv2d',
'Dropout',
'Dropout2d',
'Linear',
'MaxPool2d',
'ReLU',
'Softmax',
]
@MODELS.register_module()
class TorchNNConnector(BaseConnector):
"""TorchNNConnector: create nn.module in torch.nn to process input data
usage:
tensor1 = torch.rand(3,3,16,16)
pool_connector = TorchNNConnector(
module_name='AvgPool2d',
module_args=dict(kernel_size=4),
)
tensor2 = pool_connector.forward_train(tensor1)
tensor2.size()
# torch.Size([3, 3, 4, 4])
which is equal to torch.nn.AvgPool2d(kernel_size=4)
Args:
module_name (str, optional): Name of a module class in ``torch.nn``.
    Must be one of ``MODULE_LIST``: ['AdaptiveAvgPool2d',
    'AdaptiveMaxPool2d', 'AvgPool2d', 'BatchNorm2d', 'Conv2d',
    'Dropout', 'Dropout2d', 'Linear', 'MaxPool2d', 'ReLU',
    'Softmax']. Defaults to None.
module_args (dict): Args passed to ``nn.Module().__init__()``.
    Defaults to {}.
init_cfg (dict, optional): The config to control the initialization.
"""
def __init__(self,
module_name: Optional[str] = None,
module_args: Dict = {},
init_cfg: Optional[Dict] = None) -> None:
super().__init__(init_cfg)
assert module_name is not None, 'Arg `module_name` cannot be None'
if module_name not in MODULE_LIST:
raise ValueError(
'Arg `module_name` is not available. See this list:',
MODULE_LIST)
self.func = getattr(nn, module_name)(**module_args)
def forward_train(self, x: torch.Tensor) -> torch.Tensor:
"""Frorward function for training.
Args:
x (torch.Tensor): Input features.
"""
x = self.func(x)
return x
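Since both connectors are registered in `MODELS`, they can be referenced from a distiller config wherever a connector dict is expected. A hypothetical fragment (the `loss_x_sfeat`/`loss_x_tfeat` names are placeholders, not from this PR):

```python
# Hypothetical connector entries for a ConfigurableDistiller config:
# average-pool the student feature and batch-normalize the teacher feature
# before they reach a distillation loss.
connectors = dict(
    loss_x_sfeat=dict(
        type='TorchFunctionalConnector',
        function_name='avg_pool2d',
        func_args=dict(kernel_size=2)),
    loss_x_tfeat=dict(
        type='TorchNNConnector',
        module_name='BatchNorm2d',
        module_args=dict(num_features=256)))
```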


@@ -5,6 +5,7 @@ from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
from .factor_transfer_loss import FTLoss
from .fbkd_loss import FBKDLoss
from .kd_soft_ce_loss import KDSoftCELoss
from .kl_divergence import KLDivergence
from .l1_loss import L1Loss
@@ -15,5 +16,6 @@ from .weighted_soft_label_distillation import WSLD
__all__ = [
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'L1Loss'
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'L1Loss',
'FBKDLoss'
]


@@ -0,0 +1,120 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.nn as nn
from mmrazor.registry import MODELS
def mask_l2_loss(
tensor_a: torch.Tensor,
tensor_b: torch.Tensor,
spatial_attention_mask: Optional[torch.Tensor] = None,
channel_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""L2 loss with two attention mask, which used to weight the feature
distillation loss in FBKD.
Args:
tensor_a (torch.Tensor): Student featuremap.
tensor_b (torch.Tensor): Teacher featuremap.
spatial_attention_mask (torch.Tensor, optional): Mask of spatial-wise
attention. Defaults to None.
channel_attention_mask (torch.Tensor, optional): Mask of channel-wise
attention. Defaults to None.
Returns:
diff (torch.Tensor): L2 loss weighted by the two attention masks.
"""
diff = (tensor_a - tensor_b)**2
if spatial_attention_mask is not None:
    diff = diff * spatial_attention_mask
if channel_attention_mask is not None:
diff = diff * channel_attention_mask
diff = torch.sum(diff)**0.5
return diff
@MODELS.register_module()
class FBKDLoss(nn.Module):
"""Loss For FBKD, which includs feat_loss, channel_loss, spatial_loss and
nonlocal_loss.
Source code:
https://github.com/ArchipLab-LinfengZhang/Object-Detection-Knowledge-
Distillation-ICLR2021
Args:
mask_l2_weight (float): The weight of the mask l2 loss.
Defaults to 7e-5, which is the default value in source code.
channel_weight (float): The weight of the channel loss.
Defaults to 4e-3, which is the default value in source code.
spatial_weight (float): The weight of the spatial loss.
Defaults to 4e-3, which is the default value in source code.
nonlocal_weight (float): The weight of the nonlocal loss.
Defaults to 7e-5, which is the default value in source code.
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(self,
mask_l2_weight: float = 7e-5,
channel_weight: float = 4e-3,
spatial_weight: float = 4e-3,
nonlocal_weight: float = 7e-5,
loss_weight: float = 1.0) -> None:
"""Inits FBKDLoss."""
super().__init__()
self.mask_l2_weight = mask_l2_weight
self.channel_weight = channel_weight
self.spatial_weight = spatial_weight
self.nonlocal_weight = nonlocal_weight
self.loss_weight = loss_weight
def forward(self, s_input: Tuple[torch.Tensor, ...],
t_input: Tuple[torch.Tensor, ...]) -> torch.Tensor:
"""Forward function of FBKDLoss, including feat_loss, channel_loss,
spatial_loss and nonlocal_loss.
Args:
s_input (Tuple[torch.Tensor, ...]): Student input which is the
output of ``'FBKDStudentConnector'``.
t_input (Tuple[torch.Tensor, ...]): Teacher input which is the
output of ``'FBKDTeacherConnector'``.
"""
losses = 0.0
(s_spatial_mask, s_channel_mask, s_channel_pool_adapt,
s_spatial_pool_adapt, s_relation_adapt, s_feat_adapt) = s_input
(t_spatial_mask, t_channel_mask, t_spatial_pool, t_relation,
t_feat) = t_input
# Spatial-wise mask.
spatial_sum_mask = (t_spatial_mask + s_spatial_mask) / 2
spatial_sum_mask = spatial_sum_mask.detach()
# Channel-wise mask, but not used in the FBKD source code.
channel_sum_mask = (t_channel_mask + s_channel_mask) / 2
channel_sum_mask = channel_sum_mask.detach()
# feat_loss with mask
losses += mask_l2_loss(
t_feat,
s_feat_adapt,
spatial_attention_mask=spatial_sum_mask,
channel_attention_mask=None) * self.mask_l2_weight
# channel_loss
losses += torch.dist(torch.mean(t_feat, [2, 3]),
s_channel_pool_adapt) * self.channel_weight
# spatial_loss
losses += torch.dist(t_spatial_pool,
s_spatial_pool_adapt) * self.spatial_weight
# nonlocal_loss
losses += torch.dist(
t_relation, s_relation_adapt, p=2) * self.nonlocal_weight
return self.loss_weight * losses
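A quick numeric sanity check of `mask_l2_loss` (the import path is an assumption based on this file's location under `mmrazor/models/losses/`):

```python
import torch
from mmrazor.models.losses.fbkd_loss import mask_l2_loss

a = torch.ones(1, 2, 2, 2)
b = torch.zeros(1, 2, 2, 2)
mask = torch.full((1, 1, 2, 2), 0.25)  # broadcasts over channels

# diff = (a - b)^2 = 1 everywhere; sum(diff * mask) = 8 * 0.25 = 2,
# so the loss is sqrt(2) ~= 1.4142.
loss = mask_l2_loss(a, b, spatial_attention_mask=mask)
assert torch.isclose(loss, torch.tensor(2.0).sqrt())
```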


@@ -3,8 +3,10 @@ from unittest import TestCase
import torch
from mmrazor.models import (BYOTConnector, ConvModuleConncetor, Paraphraser,
Translator)
from mmrazor.models import (BYOTConnector, ConvModuleConncetor,
FBKDStudentConnector, FBKDTeacherConnector,
Paraphraser, TorchFunctionalConnector,
TorchNNConnector, Translator)
class TestConnector(TestCase):
@@ -68,3 +70,46 @@ class TestConnector(TestCase):
output, logits = byot_connector.forward_train(s_feat)
assert output.size() == t_feat.size()
assert logits.size() == labels.size()
def test_fbkd_connector(self):
fbkd_stuconnector_cfg = dict(
in_channels=16, reduction=2, sub_sample=True)
fbkd_stuconnector = FBKDStudentConnector(**fbkd_stuconnector_cfg)
fbkd_teaconnector_cfg = dict(
in_channels=16, reduction=2, sub_sample=True)
fbkd_teaconnector = FBKDTeacherConnector(**fbkd_teaconnector_cfg)
s_feat = torch.randn(1, 16, 8, 8)
t_feat = torch.randn(1, 16, 8, 8)
s_output = fbkd_stuconnector(s_feat)
t_output = fbkd_teaconnector(t_feat)
assert len(s_output) == 6
assert len(t_output) == 5
assert torch.equal(t_output[-1], t_feat)
def test_torch_connector(self):
tensor1 = torch.rand(3, 3, 16, 16)
functional_pool_connector = TorchFunctionalConnector(
function_name='avg_pool2d', func_args=dict(kernel_size=4))
tensor2 = functional_pool_connector.forward_train(tensor1)
assert tensor2.shape == torch.Size([3, 3, 4, 4])
with self.assertRaises(AssertionError):
functional_pool_connector = TorchFunctionalConnector()
with self.assertRaises(ValueError):
functional_pool_connector = TorchFunctionalConnector(
function_name='fake')
nn_pool_connector = TorchNNConnector(
module_name='AvgPool2d', module_args=dict(kernel_size=4))
tensor3 = nn_pool_connector.forward_train(tensor1)
assert tensor3.shape == torch.Size([3, 3, 4, 4])
assert torch.equal(tensor2, tensor3)
with self.assertRaises(AssertionError):
    nn_pool_connector = TorchNNConnector()
with self.assertRaises(ValueError):
    nn_pool_connector = TorchNNConnector(module_name='fake')


@@ -4,8 +4,8 @@ from unittest import TestCase
import torch
from mmrazor import digit_version
from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, DKDLoss, FTLoss,
InformationEntropyLoss, KDSoftCELoss,
from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, DKDLoss, FBKDLoss,
FTLoss, InformationEntropyLoss, KDSoftCELoss,
OnehotLikeLoss)
@@ -113,3 +113,20 @@ class TestLosses(TestCase):
self.normal_test_1d(at_loss)
self.normal_test_2d(at_loss)
self.normal_test_3d(at_loss)
def test_fbkdloss(self):
fbkdloss_cfg = dict(loss_weight=1.0)
fbkdloss = FBKDLoss(**fbkdloss_cfg)
spatial_mask = torch.randn(1, 1, 3, 3)
channel_mask = torch.randn(1, 4, 1, 1)
channel_pool_adapt = torch.randn(1, 4)
relation_adapt = torch.randn(1, 4, 3, 3)
s_input = (spatial_mask, channel_mask, channel_pool_adapt,
           spatial_mask, channel_mask, relation_adapt)
t_input = (spatial_mask, channel_mask, spatial_mask, channel_mask,
           relation_adapt)
fbkd_loss = fbkdloss(s_input, t_input)
self.assertTrue(fbkd_loss.numel() == 1)