[Feature] Add FBKD algorithm and torch_connectors (#248)
* 1.Add FBKD * 1.Add torch_connector and its ut. 2.Revise readme and fbkd config. * 1.Revise UT for torch_connectors * 1.Revise nonlocalblock into a subclass of NonLocal2d in mmcv.cnnpull/244/head^2
parent
f3b964c521
commit
1c0da58dae
|
@ -14,9 +14,9 @@ Knowledge Distillation (KD) has made remarkable progress in the last few years a
|
|||
|
||||
### Classification
|
||||
|
||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
|
||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.26 | 95.34 | 94.82 | [config](./dfad_logits_r34_r18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@ Performing knowledge transfer from a large teacher network to a smaller student
|
|||
|
||||
### Classification
|
||||
|
||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
|
||||
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
|
||||
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./zskt_backbone_logits_r34_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION: TOWARDS ACCURATE AND EFFICIENT DETECTORS (FBKD)
|
||||
|
||||
> [IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION: TOWARDS ACCURATE AND EFFICIENT DETECTORS](https://openreview.net/pdf?id=uKhGRvM8QNH)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Knowledge distillation, in which a student model is trained to mimic a teacher model, has been proved as an effective technique for model compression and model accuracy boosting. However, most knowledge distillation methods, designed for image classification, have failed on more challenging tasks, such as object detection. In this paper, we suggest that the failure of knowledge distillation on object detection is mainly caused by two reasons: (1) the imbalance between pixels of foreground and background and (2) lack of distillation on the relation between different pixels. Observing the above reasons, we propose attention-guided distillation and non-local distillation to address the two problems, respectively. Attention-guided distillation is proposed to find the crucial pixels of foreground objects with attention mechanism and then make the students take more effort to learn their features. Non-local distillation is proposed to enable students to learn not only the feature of an individual pixel but also the relation between different pixels captured by non-local modules. Experiments show that our methods achieve excellent AP improvements on both one-stage and two-stage, both anchor-based and anchor-free detectors. For example, Faster RCNN (ResNet101 backbone) with our distillation achieves 43.9 AP on COCO2017, which is 4.1 higher than the baseline.
|
||||
|
||||

|
||||
|
||||
## Results and models
|
||||
|
||||
### Detection
|
||||
|
||||
| Location | Dataset | Teacher | Student | box AP | box AP(T) | box AP(S) | Config | Download |
|
||||
| :------: | :-----: | :-------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | :----: | :-------: | :-------: | :--------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| neck | COCO | [fasterrcnn_resnet101](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py) | [fasterrcnn_resnet50](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py) | 39.1 | 39.4 | 37.8 | [config](./fbkd_fpn_frcnn_r101_frcnn_r50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_1x_coco/faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth) \|[model](<>) \| [log](<>) |
|
||||
|
||||
## Citation
|
||||
|
||||
```latex
|
||||
@inproceedings{DBLP:conf/iclr/ZhangM21,
|
||||
author = {Linfeng Zhang and Kaisheng Ma},
|
||||
title = {Improve Object Detection with Feature-based Knowledge Distillation:
|
||||
Towards Accurate and Efficient Detectors},
|
||||
booktitle = {9th International Conference on Learning Representations, {ICLR} 2021,
|
||||
Virtual Event, Austria, May 3-7, 2021},
|
||||
publisher = {OpenReview.net},
|
||||
year = {2021},
|
||||
url = {https://openreview.net/forum?id=uKhGRvM8QNH},
|
||||
timestamp = {Wed, 23 Jun 2021 17:36:39 +0200},
|
||||
biburl = {https://dblp.org/rec/conf/iclr/ZhangM21.bib},
|
||||
bibsource = {dblp computer science bibliography, https://dblp.org}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,125 @@
|
|||
_base_ = [
|
||||
'mmdet::_base_/datasets/coco_detection.py',
|
||||
'mmdet::_base_/schedules/schedule_1x.py',
|
||||
'mmdet::_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(
|
||||
_scope_='mmrazor',
|
||||
type='SingleTeacherDistill',
|
||||
architecture=dict(
|
||||
cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
|
||||
pretrained=True),
|
||||
teacher=dict(
|
||||
cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_1x_coco.py',
|
||||
pretrained=False),
|
||||
teacher_ckpt='faster_rcnn_r101_fpn_1x_coco_20200130-f513f705.pth',
|
||||
distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=dict(
|
||||
neck_s0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
|
||||
neck_s1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
|
||||
neck_s2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
|
||||
neck_s3=dict(type='ModuleOutputs',
|
||||
source='neck.fpn_convs.3.conv')),
|
||||
teacher_recorders=dict(
|
||||
neck_s0=dict(type='ModuleOutputs', source='neck.fpn_convs.0.conv'),
|
||||
neck_s1=dict(type='ModuleOutputs', source='neck.fpn_convs.1.conv'),
|
||||
neck_s2=dict(type='ModuleOutputs', source='neck.fpn_convs.2.conv'),
|
||||
neck_s3=dict(type='ModuleOutputs',
|
||||
source='neck.fpn_convs.3.conv')),
|
||||
distill_losses=dict(
|
||||
loss_s0=dict(type='FBKDLoss'),
|
||||
loss_s1=dict(type='FBKDLoss'),
|
||||
loss_s2=dict(type='FBKDLoss'),
|
||||
loss_s3=dict(type='FBKDLoss')),
|
||||
connectors=dict(
|
||||
loss_s0_sfeat=dict(
|
||||
type='FBKDStudentConnector',
|
||||
in_channels=256,
|
||||
reduction=4,
|
||||
mode='dot_product',
|
||||
sub_sample=True,
|
||||
maxpool_stride=8),
|
||||
loss_s0_tfeat=dict(
|
||||
type='FBKDTeacherConnector',
|
||||
in_channels=256,
|
||||
reduction=4,
|
||||
mode='dot_product',
|
||||
sub_sample=True,
|
||||
maxpool_stride=8),
|
||||
loss_s1_sfeat=dict(
|
||||
type='FBKDStudentConnector',
|
||||
in_channels=256,
|
||||
reduction=4,
|
||||
mode='dot_product',
|
||||
sub_sample=True,
|
||||
maxpool_stride=4),
|
||||
loss_s1_tfeat=dict(
|
||||
type='FBKDTeacherConnector',
|
||||
in_channels=256,
|
||||
reduction=4,
|
||||
mode='dot_product',
|
||||
sub_sample=True,
|
||||
maxpool_stride=4),
|
||||
loss_s2_sfeat=dict(
|
||||
type='FBKDStudentConnector',
|
||||
in_channels=256,
|
||||
mode='dot_product',
|
||||
sub_sample=True),
|
||||
loss_s2_tfeat=dict(
|
||||
type='FBKDTeacherConnector',
|
||||
in_channels=256,
|
||||
mode='dot_product',
|
||||
sub_sample=True),
|
||||
loss_s3_sfeat=dict(
|
||||
type='FBKDStudentConnector',
|
||||
in_channels=256,
|
||||
mode='dot_product',
|
||||
sub_sample=True),
|
||||
loss_s3_tfeat=dict(
|
||||
type='FBKDTeacherConnector',
|
||||
in_channels=256,
|
||||
mode='dot_product',
|
||||
sub_sample=True)),
|
||||
loss_forward_mappings=dict(
|
||||
loss_s0=dict(
|
||||
s_input=dict(
|
||||
from_student=True,
|
||||
recorder='neck_s0',
|
||||
connector='loss_s0_sfeat'),
|
||||
t_input=dict(
|
||||
from_student=False,
|
||||
recorder='neck_s0',
|
||||
connector='loss_s0_tfeat')),
|
||||
loss_s1=dict(
|
||||
s_input=dict(
|
||||
from_student=True,
|
||||
recorder='neck_s1',
|
||||
connector='loss_s1_sfeat'),
|
||||
t_input=dict(
|
||||
from_student=False,
|
||||
recorder='neck_s1',
|
||||
connector='loss_s1_tfeat')),
|
||||
loss_s2=dict(
|
||||
s_input=dict(
|
||||
from_student=True,
|
||||
recorder='neck_s2',
|
||||
connector='loss_s2_sfeat'),
|
||||
t_input=dict(
|
||||
from_student=False,
|
||||
recorder='neck_s2',
|
||||
connector='loss_s2_tfeat')),
|
||||
loss_s3=dict(
|
||||
s_input=dict(
|
||||
from_student=True,
|
||||
recorder='neck_s3',
|
||||
connector='loss_s3_sfeat'),
|
||||
t_input=dict(
|
||||
from_student=False,
|
||||
recorder='neck_s3',
|
||||
connector='loss_s3_tfeat')))))
|
||||
|
||||
find_unused_parameters = True
|
||||
|
||||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
|
Binary file not shown.
After Width: | Height: | Size: 561 KiB |
|
@ -2,5 +2,11 @@
|
|||
from .byot_connector import BYOTConnector
|
||||
from .convmodule_connector import ConvModuleConncetor
|
||||
from .factor_transfer_connectors import Paraphraser, Translator
|
||||
from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector
|
||||
from .torch_connector import TorchFunctionalConnector, TorchNNConnector
|
||||
|
||||
__all__ = ['ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector']
|
||||
__all__ = [
|
||||
'ConvModuleConncetor', 'Translator', 'Paraphraser', 'BYOTConnector',
|
||||
'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector',
|
||||
'TorchNNConnector'
|
||||
]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModule
|
||||
|
@ -32,7 +32,9 @@ class BaseConnector(BaseModule, metaclass=ABCMeta):
|
|||
return self.forward_train(feature)
|
||||
|
||||
@abstractmethod
|
||||
def forward_train(self, feature) -> torch.Tensor:
|
||||
def forward_train(
|
||||
self, feature: torch.Tensor
|
||||
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
"""Abstract train computation.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -67,7 +67,7 @@ class BYOTConnector(BaseConnector):
|
|||
self.scala = nn.Sequential(*scala)
|
||||
self.fc = nn.Linear(out_channel * expansion, num_classes)
|
||||
|
||||
def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
|
||||
def forward_train(self, feature: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
||||
"""Forward computation.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -0,0 +1,298 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import NonLocal2d
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
from .base_connector import BaseConnector
|
||||
|
||||
|
||||
class NonLocal2dMaxpoolNstride(NonLocal2d):
|
||||
"""Nonlocal block for 2-dimension inputs, with a configurable
|
||||
maxpool_stride.
|
||||
|
||||
This module is proposed in
|
||||
"Non-local Neural Networks"
|
||||
Paper reference: https://arxiv.org/abs/1711.07971
|
||||
Code reference: https://github.com/AlexHex7/Non-local_pytorch
|
||||
|
||||
Args:
|
||||
in_channels (int): Channels of the input feature map.
|
||||
reduction (int): Channel reduction ratio. Defaults to 2.
|
||||
conv_cfg (dict): The config dict for convolution layers.
|
||||
Defaults to `nn.Conv2d`.
|
||||
norm_cfg (dict): The config dict for normalization layers.
|
||||
Defaults to `BN`. (This parameter is only applicable to conv_out.)
|
||||
mode (str): Options are `gaussian`, `concatenation`,
|
||||
`embedded_gaussian` and `dot_product`. Default: dot_product.
|
||||
sub_sample (bool): Whether to apply max pooling after pairwise
|
||||
function (Note that the `sub_sample` is applied on spatial only).
|
||||
Default: False.
|
||||
maxpool_stride (int): The stride of the maxpooling module.
|
||||
Defaults to 2.
|
||||
zeros_init (bool): Whether to use zero to initialize weights of
|
||||
`conv_out`. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
reduction: int = 2,
|
||||
conv_cfg: Dict = dict(type='Conv2d'),
|
||||
norm_cfg: Dict = dict(type='BN'),
|
||||
mode: str = 'embedded_gaussian',
|
||||
sub_sample: bool = False,
|
||||
maxpool_stride: int = 2,
|
||||
zeros_init: bool = True,
|
||||
**kwargs) -> None:
|
||||
"""Inits the NonLocal2dMaxpoolNstride module."""
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
sub_sample=sub_sample,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
reduction=reduction,
|
||||
mode=mode,
|
||||
zeros_init=zeros_init,
|
||||
**kwargs)
|
||||
self.norm_cfg = norm_cfg
|
||||
|
||||
if sub_sample:
|
||||
max_pool_layer = nn.MaxPool2d(
|
||||
kernel_size=(maxpool_stride, maxpool_stride))
|
||||
self.g: nn.Sequential = nn.Sequential(self.g, max_pool_layer)
|
||||
if self.mode != 'gaussian':
|
||||
self.phi: nn.Sequential = nn.Sequential(
|
||||
self.phi, max_pool_layer)
|
||||
else:
|
||||
self.phi = max_pool_layer
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FBKDStudentConnector(BaseConnector):
|
||||
"""Improve Object Detection with Feature-based Knowledge Distillation:
|
||||
Towards Accurate and Efficient Detectors, ICLR2021.
|
||||
https://openreview.net/pdf?id=uKhGRvM8QNH.
|
||||
|
||||
Student connector for FBKD.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channels of the input feature map.
|
||||
reduction (int): Channel reduction ratio. Defaults to 2.
|
||||
conv_cfg (dict): The config dict for convolution layers.
|
||||
Defaults to `nn.Conv2d`.
|
||||
norm_cfg (dict): The config dict for normalization layers.
|
||||
Defaults to `BN`. (This parameter is only applicable to conv_out.)
|
||||
mode (str): Options are `gaussian`, `concatenation`,
|
||||
`embedded_gaussian` and `dot_product`. Default: dot_product.
|
||||
sub_sample (bool): Whether to apply max pooling after pairwise
|
||||
function (Note that the `sub_sample` is applied on spatial only).
|
||||
Default: False.
|
||||
maxpool_stride (int): The stride of the maxpooling module.
|
||||
Defaults to 2.
|
||||
zeros_init (bool): Whether to use zero to initialize weights of
|
||||
`conv_out`. Defaults to True.
|
||||
spatial_T (float): Temperature used in spatial-wise pooling.
|
||||
Defaults to 0.5.
|
||||
channel_T (float): Temperature used in channel-wise pooling.
|
||||
Defaults to 0.5.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
reduction: int = 2,
|
||||
conv_cfg: Dict = dict(type='Conv2d'),
|
||||
norm_cfg: Dict = dict(type='BN'),
|
||||
mode: str = 'dot_product',
|
||||
sub_sample: bool = False,
|
||||
maxpool_stride: int = 2,
|
||||
zeros_init: bool = True,
|
||||
spatial_T: float = 0.5,
|
||||
channel_T: float = 0.5,
|
||||
init_cfg: Optional[Dict] = None,
|
||||
**kwargs) -> None:
|
||||
"""Inits the FBKDStuConnector."""
|
||||
super().__init__(init_cfg)
|
||||
self.channel_wise_adaptation = nn.Linear(in_channels, in_channels)
|
||||
|
||||
self.spatial_wise_adaptation = nn.Conv2d(
|
||||
1, 1, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.adaptation_layers = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.student_non_local = NonLocal2dMaxpoolNstride(
|
||||
in_channels=in_channels,
|
||||
reduction=reduction,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
mode=mode,
|
||||
sub_sample=sub_sample,
|
||||
maxpool_stride=maxpool_stride,
|
||||
zeros_init=zeros_init,
|
||||
**kwargs)
|
||||
|
||||
self.non_local_adaptation = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.spatial_T = spatial_T
|
||||
self.channel_T = channel_T
|
||||
|
||||
def forward_train(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
||||
"""Frorward function for training.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input student features.
|
||||
|
||||
Returns:
|
||||
s_spatial_mask (torch.Tensor): Student spatial-wise mask.
|
||||
s_channel_mask (torch.Tensor): Student channel-wise mask.
|
||||
s_feat_adapt (torch.Tensor): Adaptative student feature.
|
||||
s_channel_pool_adapt (torch.Tensor): Student feature which through
|
||||
channel-wise pooling and adaptation_layers.
|
||||
s_spatial_pool_adapt (torch.Tensor): Student feature which through
|
||||
spatial-wise pooling and adaptation_layers.
|
||||
s_relation_adapt (torch.Tensor): Adaptative student relations.
|
||||
"""
|
||||
# Calculate spatial-wise mask.
|
||||
s_spatial_mask = torch.mean(torch.abs(x), [1], keepdim=True)
|
||||
size = s_spatial_mask.size()
|
||||
s_spatial_mask = s_spatial_mask.view(x.size(0), -1)
|
||||
|
||||
# Soften or sharpen the spatial-wise mask by temperature.
|
||||
s_spatial_mask = torch.softmax(
|
||||
s_spatial_mask / self.spatial_T, dim=1) * size[-1] * size[-2]
|
||||
s_spatial_mask = s_spatial_mask.view(size)
|
||||
|
||||
# Calculate channel-wise mask.
|
||||
s_channel_mask = torch.mean(torch.abs(x), [2, 3], keepdim=True)
|
||||
channel_mask_size = s_channel_mask.size()
|
||||
s_channel_mask = s_channel_mask.view(x.size(0), -1)
|
||||
|
||||
# Soften or sharpen the channel-wise mask by temperature.
|
||||
s_channel_mask = torch.softmax(
|
||||
s_channel_mask / self.channel_T, dim=1) * self.in_channels
|
||||
s_channel_mask = s_channel_mask.view(channel_mask_size)
|
||||
|
||||
# Adaptative and pool student feature through channel-wise.
|
||||
s_feat_adapt = self.adaptation_layers(x)
|
||||
s_channel_pool_adapt = self.channel_wise_adaptation(
|
||||
torch.mean(x, [2, 3]))
|
||||
|
||||
# Adaptative and pool student feature through spatial-wise.
|
||||
s_spatial_pool = torch.mean(x, [1]).view(
|
||||
x.size(0), 1, x.size(2), x.size(3))
|
||||
s_spatial_pool_adapt = self.spatial_wise_adaptation(s_spatial_pool)
|
||||
|
||||
# Calculate non_local_adaptation.
|
||||
s_relation = self.student_non_local(x)
|
||||
s_relation_adapt = self.non_local_adaptation(s_relation)
|
||||
|
||||
return (s_spatial_mask, s_channel_mask, s_channel_pool_adapt,
|
||||
s_spatial_pool_adapt, s_relation_adapt, s_feat_adapt)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FBKDTeacherConnector(BaseConnector):
|
||||
"""Improve Object Detection with Feature-based Knowledge Distillation:
|
||||
Towards Accurate and Efficient Detectors, ICLR2021.
|
||||
https://openreview.net/pdf?id=uKhGRvM8QNH.
|
||||
|
||||
Teacher connector for FBKD.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channels of the input feature map.
|
||||
reduction (int): Channel reduction ratio. Defaults to 2.
|
||||
conv_cfg (dict): The config dict for convolution layers.
|
||||
Defaults to `nn.Conv2d`.
|
||||
norm_cfg (dict): The config dict for normalization layers.
|
||||
Defaults to `BN`. (This parameter is only applicable to conv_out.)
|
||||
mode (str): Options are `gaussian`, `concatenation`,
|
||||
`embedded_gaussian` and `dot_product`. Default: dot_product.
|
||||
sub_sample (bool): Whether to apply max pooling after pairwise
|
||||
function (Note that the `sub_sample` is applied on spatial only).
|
||||
Default: False.
|
||||
maxpool_stride (int): The stride of the maxpooling module.
|
||||
Defaults to 2.
|
||||
zeros_init (bool): Whether to use zero to initialize weights of
|
||||
`conv_out`. Defaults to True.
|
||||
spatial_T (float): Temperature used in spatial-wise pooling.
|
||||
Defaults to 0.5.
|
||||
channel_T (float): Temperature used in channel-wise pooling.
|
||||
Defaults to 0.5.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
reduction=2,
|
||||
conv_cfg: Dict = dict(type='Conv2d'),
|
||||
norm_cfg: Dict = dict(type='BN'),
|
||||
mode: str = 'dot_product',
|
||||
sub_sample: bool = False,
|
||||
maxpool_stride: int = 2,
|
||||
zeros_init: bool = True,
|
||||
spatial_T: float = 0.5,
|
||||
channel_T: float = 0.5,
|
||||
init_cfg: Optional[Dict] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__(init_cfg)
|
||||
self.teacher_non_local = NonLocal2dMaxpoolNstride(
|
||||
in_channels=in_channels,
|
||||
reduction=reduction,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
mode=mode,
|
||||
sub_sample=sub_sample,
|
||||
maxpool_stride=maxpool_stride,
|
||||
zeros_init=zeros_init,
|
||||
**kwargs)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.spatial_T = spatial_T
|
||||
self.channel_T = channel_T
|
||||
|
||||
def forward_train(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
||||
"""Frorward function for training.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input teacher features.
|
||||
|
||||
Returns:
|
||||
t_spatial_mask (torch.Tensor): Teacher spatial-wise mask.
|
||||
t_channel_mask (torch.Tensor): Teacher channel-wise mask.
|
||||
t_spatial_pool (torch.Tensor): Teacher features which through
|
||||
spatial-wise pooling.
|
||||
t_relation (torch.Tensor): Teacher relation matrix.
|
||||
"""
|
||||
# Calculate spatial-wise mask.
|
||||
t_spatial_mask = torch.mean(torch.abs(x), [1], keepdim=True)
|
||||
size = t_spatial_mask.size()
|
||||
t_spatial_mask = t_spatial_mask.view(x.size(0), -1)
|
||||
|
||||
# Soften or sharpen the spatial-wise mask by temperature.
|
||||
t_spatial_mask = torch.softmax(
|
||||
t_spatial_mask / self.spatial_T, dim=1) * size[-1] * size[-2]
|
||||
t_spatial_mask = t_spatial_mask.view(size)
|
||||
|
||||
# Calculate channel-wise mask.
|
||||
t_channel_mask = torch.mean(torch.abs(x), [2, 3], keepdim=True)
|
||||
channel_mask_size = t_channel_mask.size()
|
||||
t_channel_mask = t_channel_mask.view(x.size(0), -1)
|
||||
|
||||
# Soften or sharpen the channel-wise mask by temperature.
|
||||
t_channel_mask = torch.softmax(
|
||||
t_channel_mask / self.channel_T, dim=1) * self.in_channels
|
||||
t_channel_mask = t_channel_mask.view(channel_mask_size)
|
||||
|
||||
# Adaptative and pool student feature through spatial-wise.
|
||||
t_spatial_pool = torch.mean(x, [1]).view(
|
||||
x.size(0), 1, x.size(2), x.size(3))
|
||||
|
||||
# Calculate non_local relation.
|
||||
t_relation = self.teacher_non_local(x)
|
||||
|
||||
return (t_spatial_mask, t_channel_mask, t_spatial_pool, t_relation, x)
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
from .base_connector import BaseConnector
|
||||
|
||||
FUNCTION_LIST = [
|
||||
'adaptive_avg_pool2d',
|
||||
'adaptive_max_pool2d',
|
||||
'avg_pool2d',
|
||||
'dropout',
|
||||
'dropout2d',
|
||||
'max_pool2d',
|
||||
'normalize',
|
||||
'relu',
|
||||
'softmax',
|
||||
'interpolate',
|
||||
]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class TorchFunctionalConnector(BaseConnector):
|
||||
"""TorchFunctionalConnector: Call function in torch.nn.functional
|
||||
to process input data
|
||||
|
||||
usage:
|
||||
tensor1 = torch.rand(3,3,16,16)
|
||||
pool_connector = TorchFunctionalConnector(
|
||||
function_name='avg_pool2d',
|
||||
func_args=dict(kernel_size=4),
|
||||
)
|
||||
tensor2 = pool_connector.forward_train(tensor1)
|
||||
tensor2.size()
|
||||
# torch.Size([3, 3, 4, 4])
|
||||
|
||||
which is equal to torch.nn.functional.avg_pool2d(kernel_size=4)
|
||||
Args:
|
||||
function_name (str, optional): function. Defaults to None.
|
||||
func_args (dict): args parsed to function. Defaults to {}.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
function_name: Optional[str] = None,
|
||||
func_args: Dict = {},
|
||||
init_cfg: Optional[Dict] = None) -> None:
|
||||
super().__init__(init_cfg)
|
||||
assert function_name is not None, 'Arg `function_name` cannot be None'
|
||||
if function_name not in FUNCTION_LIST:
|
||||
raise ValueError(
|
||||
' Arg `function_name` are not available, See this list',
|
||||
FUNCTION_LIST)
|
||||
self.func = getattr(F, function_name)
|
||||
self.func_args = func_args
|
||||
|
||||
def forward_train(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Frorward function for training.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input features.
|
||||
"""
|
||||
x = self.func(x, **self.func_args)
|
||||
return x
|
||||
|
||||
|
||||
MODULE_LIST = [
|
||||
'AdaptiveAvgPool2d',
|
||||
'AdaptiveMaxPool2d',
|
||||
'AvgPool2d',
|
||||
'BatchNorm2d',
|
||||
'Conv2d',
|
||||
'Dropout',
|
||||
'Dropout2d',
|
||||
'Linear',
|
||||
'MaxPool2d',
|
||||
'ReLU',
|
||||
'Softmax',
|
||||
]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class TorchNNConnector(BaseConnector):
|
||||
"""TorchNNConnector: create nn.module in torch.nn to process input data
|
||||
|
||||
usage:
|
||||
tensor1 = torch.rand(3,3,16,16)
|
||||
pool_connector = TorchNNConnector(
|
||||
module_name='AvgPool2d',
|
||||
module_args=dict(kernel_size=4),
|
||||
)
|
||||
tensor2 = pool_connector.forward_train(tensor1)
|
||||
tensor2.size()
|
||||
# torch.Size([3, 3, 4, 4])
|
||||
|
||||
which is equal to torch.nn.AvgPool2d(kernel_size=4)
|
||||
Args:
|
||||
module_name (str, optional):
|
||||
module name. Defaults to None.
|
||||
possible_values:['AvgPool2d',
|
||||
'Dropout2d',
|
||||
'AdaptiveAvgPool2d',
|
||||
'AdaptiveMaxPool2d',
|
||||
'ReLU',
|
||||
'Softmax',
|
||||
'BatchNorm2d',
|
||||
'Linear',]
|
||||
module_args (dict):
|
||||
args parsed to nn.Module().__init__(). Defaults to {}.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module_name: Optional[str] = None,
|
||||
module_args: Dict = {},
|
||||
init_cfg: Optional[Dict] = None) -> None:
|
||||
super().__init__(init_cfg)
|
||||
assert module_name is not None, 'Arg `module_name` cannot be None'
|
||||
if module_name not in MODULE_LIST:
|
||||
raise ValueError(
|
||||
' Arg `module_name` are not available, See this list',
|
||||
MODULE_LIST)
|
||||
self.func = getattr(nn, module_name)(**module_args)
|
||||
|
||||
def forward_train(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Frorward function for training.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input features.
|
||||
"""
|
||||
x = self.func(x)
|
||||
return x
|
|
@ -5,6 +5,7 @@ from .cwd import ChannelWiseDivergence
|
|||
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
|
||||
from .decoupled_kd import DKDLoss
|
||||
from .factor_transfer_loss import FTLoss
|
||||
from .fbkd_loss import FBKDLoss
|
||||
from .kd_soft_ce_loss import KDSoftCELoss
|
||||
from .kl_divergence import KLDivergence
|
||||
from .l1_loss import L1Loss
|
||||
|
@ -15,5 +16,6 @@ from .weighted_soft_label_distillation import WSLD
|
|||
__all__ = [
|
||||
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
|
||||
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
|
||||
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'L1Loss'
|
||||
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'L1Loss',
|
||||
'FBKDLoss'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
|
||||
def mask_l2_loss(
|
||||
tensor_a: torch.Tensor,
|
||||
tensor_b: torch.Tensor,
|
||||
saptial_attention_mask: Optional[torch.Tensor] = None,
|
||||
channel_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""L2 loss with two attention mask, which used to weight the feature
|
||||
distillation loss in FBKD.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): Student featuremap.
|
||||
tensor_b (torch.Tensor): Teacher featuremap.
|
||||
saptial_attention_mask (torch.Tensor, optional): Mask of spatial-wise
|
||||
attention. Defaults to None.
|
||||
channel_attention_mask (torch.Tensor, optional): Mask of channel-wise
|
||||
attention. Defaults to None.
|
||||
|
||||
Returns:
|
||||
diff (torch.Tensor): l2 loss with two attention mask.
|
||||
"""
|
||||
diff = (tensor_a - tensor_b)**2
|
||||
if saptial_attention_mask is not None:
|
||||
diff = diff * saptial_attention_mask
|
||||
if channel_attention_mask is not None:
|
||||
diff = diff * channel_attention_mask
|
||||
diff = torch.sum(diff)**0.5
|
||||
return diff
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FBKDLoss(nn.Module):
|
||||
"""Loss For FBKD, which includs feat_loss, channel_loss, spatial_loss and
|
||||
nonlocal_loss.
|
||||
|
||||
Source code:
|
||||
https://github.com/ArchipLab-LinfengZhang/Object-Detection-Knowledge-
|
||||
Distillation-ICLR2021
|
||||
|
||||
Args:
|
||||
mask_l2_weight (float): The weight of the mask l2 loss.
|
||||
Defaults to 7e-5, which is the default value in source code.
|
||||
channel_weight (float): The weight of the channel loss.
|
||||
Defaults to 4e-3, which is the default value in source code.
|
||||
spatial_weight (float): The weight of the spatial loss.
|
||||
Defaults to 4e-3, which is the default value in source code.
|
||||
nonloacl_weight (float): The weight of the nonlocal loss.
|
||||
Defaults to 7e-5, which is the default value in source code.
|
||||
loss_weight (float): Weight of loss. Defaults to 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mask_l2_weight: float = 7e-5,
|
||||
channel_weight: float = 4e-3,
|
||||
spatial_weight: float = 4e-3,
|
||||
nonloacl_weight: float = 7e-5,
|
||||
loss_weight: float = 1.0) -> None:
|
||||
"""Inits FBKDLoss."""
|
||||
super().__init__()
|
||||
|
||||
self.mask_l2_weight = mask_l2_weight
|
||||
self.channel_weight = channel_weight
|
||||
self.spatial_weight = spatial_weight
|
||||
self.nonloacl_weight = nonloacl_weight
|
||||
self.loss_weight = loss_weight
|
||||
|
||||
def forward(self, s_input: Tuple[torch.Tensor, ...],
|
||||
t_input: Tuple[torch.Tensor, ...]) -> torch.Tensor:
|
||||
"""Forward function of FBKDLoss, including feat_loss, channel_loss,
|
||||
spatial_loss and nonlocal_loss.
|
||||
|
||||
Args:
|
||||
s_input (Tuple[torch.Tensor, ...]): Student input which is the
|
||||
output of ``'FBKDStudentConnector'``.
|
||||
t_input (Tuple[torch.Tensor, ...]): Teacher input which is the
|
||||
output of ``'FBKDTeacherConnector'``.
|
||||
"""
|
||||
losses = 0.0
|
||||
|
||||
(s_spatial_mask, s_channel_mask, s_channel_pool_adapt,
|
||||
s_spatial_pool_adapt, s_relation_adapt, s_feat_adapt) = s_input
|
||||
|
||||
(t_spatial_mask, t_channel_mask, t_spatial_pool, t_relation,
|
||||
t_feat) = t_input
|
||||
|
||||
# Spatial-wise mask.
|
||||
spatial_sum_mask = (t_spatial_mask + s_spatial_mask) / 2
|
||||
spatial_sum_mask = spatial_sum_mask.detach()
|
||||
|
||||
# Channel-wise mask, but not used in the FBKD source code.
|
||||
channel_sum_mask = (t_channel_mask + s_channel_mask) / 2
|
||||
channel_sum_mask = channel_sum_mask.detach()
|
||||
|
||||
# feat_loss with mask
|
||||
losses += mask_l2_loss(
|
||||
t_feat,
|
||||
s_feat_adapt,
|
||||
saptial_attention_mask=spatial_sum_mask,
|
||||
channel_attention_mask=None) * self.mask_l2_weight
|
||||
|
||||
# channel_loss
|
||||
losses += torch.dist(torch.mean(t_feat, [2, 3]),
|
||||
s_channel_pool_adapt) * self.channel_weight
|
||||
|
||||
# spatial_loss
|
||||
losses += torch.dist(t_spatial_pool,
|
||||
s_spatial_pool_adapt) * self.spatial_weight
|
||||
|
||||
# nonlocal_loss
|
||||
losses += torch.dist(
|
||||
t_relation, s_relation_adapt, p=2) * self.nonloacl_weight
|
||||
|
||||
return self.loss_weight * losses
|
|
@ -3,8 +3,10 @@ from unittest import TestCase
|
|||
|
||||
import torch
|
||||
|
||||
from mmrazor.models import (BYOTConnector, ConvModuleConncetor, Paraphraser,
|
||||
Translator)
|
||||
from mmrazor.models import (BYOTConnector, ConvModuleConncetor,
|
||||
FBKDStudentConnector, FBKDTeacherConnector,
|
||||
Paraphraser, TorchFunctionalConnector,
|
||||
TorchNNConnector, Translator)
|
||||
|
||||
|
||||
class TestConnector(TestCase):
|
||||
|
@ -68,3 +70,46 @@ class TestConnector(TestCase):
|
|||
output, logits = byot_connector.forward_train(s_feat)
|
||||
assert output.size() == t_feat.size()
|
||||
assert logits.size() == labels.size()
|
||||
|
||||
def test_fbkd_connector(self):
|
||||
fbkd_stuconnector_cfg = dict(
|
||||
in_channels=16, reduction=2, sub_sample=True)
|
||||
fbkd_stuconnector = FBKDStudentConnector(**fbkd_stuconnector_cfg)
|
||||
|
||||
fbkd_teaconnector_cfg = dict(
|
||||
in_channels=16, reduction=2, sub_sample=True)
|
||||
fbkd_teaconnector = FBKDTeacherConnector(**fbkd_teaconnector_cfg)
|
||||
|
||||
s_feat = torch.randn(1, 16, 8, 8)
|
||||
t_feat = torch.randn(1, 16, 8, 8)
|
||||
|
||||
s_output = fbkd_stuconnector(s_feat)
|
||||
t_output = fbkd_teaconnector(t_feat)
|
||||
|
||||
assert len(s_output) == 6
|
||||
assert len(t_output) == 5
|
||||
assert torch.equal(t_output[-1], t_feat)
|
||||
|
||||
def test_torch_connector(self):
|
||||
tensor1 = torch.rand(3, 3, 16, 16)
|
||||
functional_pool_connector = TorchFunctionalConnector(
|
||||
function_name='avg_pool2d', func_args=dict(kernel_size=4))
|
||||
tensor2 = functional_pool_connector.forward_train(tensor1)
|
||||
assert tensor2.shape == torch.Size([3, 3, 4, 4])
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
functional_pool_connector = TorchFunctionalConnector()
|
||||
with self.assertRaises(ValueError):
|
||||
functional_pool_connector = TorchFunctionalConnector(
|
||||
function_name='fake')
|
||||
|
||||
nn_pool_connector = TorchNNConnector(
|
||||
module_name='AvgPool2d', module_args=dict(kernel_size=4))
|
||||
tensor3 = nn_pool_connector.forward_train(tensor1)
|
||||
assert tensor3.shape == torch.Size([3, 3, 4, 4])
|
||||
assert torch.equal(tensor2, tensor3)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
functional_pool_connector = TorchFunctionalConnector()
|
||||
with self.assertRaises(ValueError):
|
||||
functional_pool_connector = TorchNNConnector(module_name='fake')
|
||||
|
|
|
@ -4,8 +4,8 @@ from unittest import TestCase
|
|||
import torch
|
||||
|
||||
from mmrazor import digit_version
|
||||
from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, DKDLoss, FTLoss,
|
||||
InformationEntropyLoss, KDSoftCELoss,
|
||||
from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, DKDLoss, FBKDLoss,
|
||||
FTLoss, InformationEntropyLoss, KDSoftCELoss,
|
||||
OnehotLikeLoss)
|
||||
|
||||
|
||||
|
@ -113,3 +113,20 @@ class TestLosses(TestCase):
|
|||
self.normal_test_1d(at_loss)
|
||||
self.normal_test_2d(at_loss)
|
||||
self.normal_test_3d(at_loss)
|
||||
|
||||
def test_fbkdloss(self):
|
||||
fbkdloss_cfg = dict(loss_weight=1.0)
|
||||
fbkdloss = FBKDLoss(**fbkdloss_cfg)
|
||||
|
||||
spatial_mask = torch.randn(1, 1, 3, 3)
|
||||
channel_mask = torch.randn(1, 4, 1, 1)
|
||||
channel_pool_adapt = torch.randn(1, 4)
|
||||
relation_adpt = torch.randn(1, 4, 3, 3)
|
||||
|
||||
s_input = (spatial_mask, channel_mask, channel_pool_adapt,
|
||||
spatial_mask, channel_mask, relation_adpt)
|
||||
t_input = (spatial_mask, channel_mask, spatial_mask, channel_mask,
|
||||
relation_adpt)
|
||||
|
||||
fbkd_loss = fbkdloss(s_input, t_input)
|
||||
self.assertTrue(fbkd_loss.numel() == 1)
|
||||
|
|
Loading…
Reference in New Issue