[Feature]Support Relational Knowledge Distillation (#127)
* add rkd * add rkd pytest * add rkd configs * fix readme * fix rkd * split rkd loss to distance-wise and angle-wise losses * rename rkd losses * add rkd metaflie * add rkd related links * rename rkd metafile and add to model index * delete cifar100 Co-authored-by: caoweihan <caoweihan@sensetime.com> Co-authored-by: pppppM <gjf_mail@126.com>pull/135/head
parent
f9920a403e
commit
de4dd13cc6
|
@ -0,0 +1,42 @@
|
|||
# RKD
|
||||
|
||||
|
||||
|
||||
> [Relational Knowledge Distillation](https://arxiv.org/abs/1904.05068)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
## Abstract
|
||||
Knowledge distillation aims at transferring knowledge acquired
|
||||
in one model (a teacher) to another model (a student) that is
|
||||
typically smaller. Previous approaches can be expressed as
|
||||
a form of training the student to mimic output activations of
|
||||
individual data examples represented by the teacher. We introduce
|
||||
a novel approach, dubbed relational knowledge distillation (RKD),
|
||||
that transfers mutual relations of data examples instead.
|
||||
For concrete realizations of RKD, we propose distance-wise and
|
||||
angle-wise distillation losses that penalize structural differences
|
||||
in relations. Experiments conducted on different tasks show that the
|
||||
proposed method improves educated student models with a significant margin.
|
||||
In particular for metric learning, it allows students to outperform their
|
||||
teachers' performance, achieving the state of the arts on standard benchmark datasets.
|
||||
|
||||

|
||||
|
||||
## Results and models
|
||||
### Classification
|
||||
|Location|Dataset|Teacher|Student|Acc|Acc(T)|Acc(S)|Config | Download |
|
||||
:--------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:------:|:---------|
|
||||
| neck |ImageNet|[resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py)|[resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py)| 70.23 | 73.62 | 69.90 |[config](./rkd_neck_resnet34_resnet18_8xb32_in1k.py)|[teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) |[model](https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_acc-70.23_20220401-f25700ac.pth) | [log](https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_20220312_130419.log.json)|
|
||||
|
||||
|
||||
|
||||
## Citation
|
||||
```latex
|
||||
@inproceedings{park2019relational,
|
||||
title={Relational knowledge distillation},
|
||||
author={Park, Wonpyo and Kim, Dongju and Lu, Yan and Cho, Minsu},
|
||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3967--3976},
|
||||
year={2019}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,31 @@
|
|||
Collections:
|
||||
- Name: RKD
|
||||
Metadata:
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/1904.05068
|
||||
Title: Relational Knowledge Distillation
|
||||
README: configs/distill/rkd/README.md
|
||||
Code:
|
||||
URL: https://github.com/open-mmlab/mmrazor/blob/v0.3.0/mmrazor/models/losses/relation_kd.py
|
||||
Version: v0.3.0
|
||||
Converted From:
|
||||
Code: https://github.com/lenscloth/RKD
|
||||
Models:
|
||||
- Name: rkd_neck_resnet34_resnet18_8xb32_in1k
|
||||
In Collection: RKD
|
||||
Metadata:
|
||||
Location: neck
|
||||
Student: R-18
|
||||
Teacher: R-34
|
||||
Teacher Checkpoint: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 70.23
|
||||
Top 1 Accuracy:(S): 69.90
|
||||
Top 1 Accuracy:(T): 73.62
|
||||
Config: configs/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k.py
|
||||
Weights: https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_acc-70.23_20220401-f25700ac.pth
|
|
@ -0,0 +1,79 @@
|
|||
_base_ = [
|
||||
'../../_base_/datasets/mmcls/imagenet_bs32.py',
|
||||
'../../_base_/schedules/mmcls/imagenet_bs256.py',
|
||||
'../../_base_/mmcls_runtime.py'
|
||||
]
|
||||
|
||||
# model settings
|
||||
student = dict(
|
||||
type='mmcls.ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=18,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
||||
|
||||
# teacher settings
|
||||
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa: E501
|
||||
|
||||
teacher = dict(
|
||||
type='mmcls.ImageClassifier',
|
||||
init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=34,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
||||
|
||||
# algorithm setting
|
||||
algorithm = dict(
|
||||
type='GeneralDistill',
|
||||
architecture=dict(
|
||||
type='MMClsArchitecture',
|
||||
model=student,
|
||||
),
|
||||
with_student_loss=True,
|
||||
with_teacher_loss=False,
|
||||
distiller=dict(
|
||||
type='SingleTeacherDistiller',
|
||||
teacher=teacher,
|
||||
teacher_trainable=False,
|
||||
teacher_norm_eval=True,
|
||||
components=[
|
||||
dict(
|
||||
student_module='neck.gap',
|
||||
teacher_module='neck.gap',
|
||||
losses=[
|
||||
dict(
|
||||
type='DistanceWiseRKD',
|
||||
name='distance_wise_loss',
|
||||
loss_weight=25.0,
|
||||
with_l2_norm=True),
|
||||
dict(
|
||||
type='AngleWiseRKD',
|
||||
name='angle_wise_loss',
|
||||
loss_weight=50.0,
|
||||
with_l2_norm=True),
|
||||
])
|
||||
]),
|
||||
)
|
||||
|
||||
find_unused_parameters = True
|
Binary file not shown.
After Width: | Height: | Size: 47 KiB |
|
@ -1,6 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .cwd import ChannelWiseDivergence
|
||||
from .kl_divergence import KLDivergence
|
||||
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
|
||||
from .weighted_soft_label_distillation import WSLD
|
||||
|
||||
__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD']
|
||||
__all__ = [
|
||||
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
|
||||
'WSLD'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
|
||||
|
||||
def euclidean_distance(pred, squared=False, eps=1e-12):
|
||||
"""Calculate the Euclidean distance between the two examples in the output
|
||||
representation space.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction of the teacher or student with
|
||||
shape (N, C).
|
||||
squared (bool): Whether to calculate the squared Euclidean
|
||||
distance. Defaults to False.
|
||||
eps (float): The minimum Euclidean distance between the two
|
||||
examples. Defaults to 1e-12.
|
||||
"""
|
||||
pred_square = pred.pow(2).sum(dim=-1) # (N, )
|
||||
prod = torch.mm(pred, pred.t()) # (N, N)
|
||||
distance = (pred_square.unsqueeze(1) + pred_square.unsqueeze(0) -
|
||||
2 * prod).clamp(min=eps) # (N, N)
|
||||
|
||||
if not squared:
|
||||
distance = distance.sqrt()
|
||||
|
||||
distance = distance.clone()
|
||||
distance[range(len(prod)), range(len(prod))] = 0
|
||||
return distance
|
||||
|
||||
|
||||
def angle(pred):
|
||||
"""Calculate the angle-wise relational potential which measures the angle
|
||||
formed by the three examples in the output representation space.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction of the teacher or student with
|
||||
shape (N, C).
|
||||
"""
|
||||
pred_vec = pred.unsqueeze(0) - pred.unsqueeze(1) # (N, N, C)
|
||||
norm_pred_vec = F.normalize(pred_vec, p=2, dim=2)
|
||||
angle = torch.bmm(norm_pred_vec,
|
||||
norm_pred_vec.transpose(1, 2)).view(-1) # (N*N*N, )
|
||||
return angle
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class DistanceWiseRKD(nn.Module):
|
||||
"""PyTorch version of distance-wise loss of `Relational Knowledge
|
||||
Distillation.
|
||||
|
||||
<https://arxiv.org/abs/1904.05068>`_.
|
||||
|
||||
Args:
|
||||
loss_weight (float): Weight of distance-wise distillation loss.
|
||||
Defaults to 25.0.
|
||||
with_l2_norm (bool): Whether to normalize the model predictions before
|
||||
calculating the loss. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=25.0, with_l2_norm=True):
|
||||
super(DistanceWiseRKD, self).__init__()
|
||||
|
||||
self.loss_weight = loss_weight
|
||||
self.with_l2_norm = with_l2_norm
|
||||
|
||||
def distance_loss(self, preds_S, preds_T):
|
||||
"""Calculate distance-wise distillation loss."""
|
||||
d_T = euclidean_distance(preds_T, squared=False)
|
||||
# mean_d_T is a normalization factor for distance
|
||||
mean_d_T = d_T[d_T > 0].mean()
|
||||
d_T = d_T / mean_d_T
|
||||
|
||||
d_S = euclidean_distance(preds_S, squared=False)
|
||||
mean_d_S = d_S[d_S > 0].mean()
|
||||
d_S = d_S / mean_d_S
|
||||
|
||||
return F.smooth_l1_loss(d_S, d_T)
|
||||
|
||||
def forward(self, preds_S, preds_T):
|
||||
"""Forward computation.
|
||||
|
||||
Args:
|
||||
preds_S (torch.Tensor): The student model prediction with
|
||||
shape (N, C, H, W) or shape (N, C).
|
||||
preds_T (torch.Tensor): The teacher model prediction with
|
||||
shape (N, C, H, W) or shape (N, C).
|
||||
Return:
|
||||
torch.Tensor: The calculated loss value.
|
||||
"""
|
||||
preds_S = preds_S.view(preds_S.shape[0], -1)
|
||||
preds_T = preds_T.view(preds_T.shape[0], -1)
|
||||
if self.with_l2_norm:
|
||||
preds_S = F.normalize(preds_S, p=2, dim=1)
|
||||
preds_T = F.normalize(preds_T, p=2, dim=1)
|
||||
|
||||
loss = self.distance_loss(preds_S, preds_T) * self.loss_weight
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class AngleWiseRKD(nn.Module):
|
||||
"""PyTorch version of angle-wise loss of `Relational Knowledge
|
||||
Distillation.
|
||||
|
||||
<https://arxiv.org/abs/1904.05068>`_.
|
||||
|
||||
Args:
|
||||
loss_weight (float): Weight of angle-wise distillation loss.
|
||||
Defaults to 50.0.
|
||||
with_l2_norm (bool): Whether to normalize the model predictions before
|
||||
calculating the loss. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=50.0, with_l2_norm=True):
|
||||
super(AngleWiseRKD, self).__init__()
|
||||
|
||||
self.loss_weight = loss_weight
|
||||
self.with_l2_norm = with_l2_norm
|
||||
|
||||
def angle_loss(self, preds_S, preds_T):
|
||||
"""Calculate the angle-wise distillation loss."""
|
||||
angle_T = angle(preds_T)
|
||||
angle_S = angle(preds_S)
|
||||
return F.smooth_l1_loss(angle_S, angle_T)
|
||||
|
||||
def forward(self, preds_S, preds_T):
|
||||
"""Forward computation.
|
||||
|
||||
Args:
|
||||
preds_S (torch.Tensor): The student model prediction with
|
||||
shape (N, C, H, W) or shape (N, C).
|
||||
preds_T (torch.Tensor): The teacher model prediction with
|
||||
shape (N, C, H, W) or shape (N, C).
|
||||
Return:
|
||||
torch.Tensor: The calculated loss value.
|
||||
"""
|
||||
preds_S = preds_S.view(preds_S.shape[0], -1)
|
||||
preds_T = preds_T.view(preds_T.shape[0], -1)
|
||||
if self.with_l2_norm:
|
||||
preds_S = F.normalize(preds_S, p=2, dim=-1)
|
||||
preds_T = F.normalize(preds_T, p=2, dim=-1)
|
||||
|
||||
loss = self.angle_loss(preds_S, preds_T) * self.loss_weight
|
||||
|
||||
return loss
|
|
@ -1,6 +1,7 @@
|
|||
Import:
|
||||
- configs/distill/cwd/metafile.yml
|
||||
- configs/distill/wsld/metafile.yml
|
||||
- configs/distill/rkd/metafile.yml
|
||||
- configs/nas/darts/metafile.yml
|
||||
- configs/nas/detnas/metafile.yml
|
||||
- configs/nas/spos/metafile.yml
|
||||
|
|
|
@ -436,3 +436,113 @@ def test_cwd():
|
|||
# test algorithm train_step
|
||||
losses = algorithm.train_step(mm_inputs, None)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
|
||||
def test_rkd():
|
||||
student = dict(
|
||||
type='mmcls.ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=18,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
||||
|
||||
teacher = dict(
|
||||
type='mmcls.ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=34,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
||||
|
||||
# test RelationalKD w/ l2 norm
|
||||
algorithm_cfg = ConfigDict(
|
||||
type='GeneralDistill',
|
||||
architecture=dict(
|
||||
type='MMClsArchitecture',
|
||||
model=student,
|
||||
),
|
||||
with_student_loss=True,
|
||||
with_teacher_loss=False,
|
||||
distiller=dict(
|
||||
type='SingleTeacherDistiller',
|
||||
teacher=teacher,
|
||||
teacher_trainable=False,
|
||||
teacher_norm_eval=True,
|
||||
components=[
|
||||
dict(
|
||||
student_module='neck.gap',
|
||||
teacher_module='neck.gap',
|
||||
losses=[
|
||||
dict(
|
||||
type='DistanceWiseRKD',
|
||||
name='distance_wise_loss',
|
||||
loss_weight=25.0,
|
||||
with_l2_norm=True),
|
||||
dict(
|
||||
type='AngleWiseRKD',
|
||||
name='angle_wise_loss',
|
||||
loss_weight=50.0,
|
||||
with_l2_norm=True),
|
||||
])
|
||||
]),
|
||||
)
|
||||
|
||||
imgs = torch.randn(16, 3, 32, 32)
|
||||
label = torch.randint(0, 10, (16, ))
|
||||
|
||||
algorithm = ALGORITHMS.build(algorithm_cfg)
|
||||
|
||||
optimizer = torch.optim.SGD(algorithm.parameters(), lr=0.01)
|
||||
outputs = algorithm.train_step({'img': imgs, 'gt_label': label}, optimizer)
|
||||
assert outputs['loss'].item() > 0
|
||||
assert outputs['num_samples'] == 16
|
||||
|
||||
# test forward
|
||||
losses = algorithm(imgs, return_loss=True, gt_label=label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test RelationalKD w/o l2 norm
|
||||
algorithm_cfg.distiller.components = [
|
||||
dict(
|
||||
student_module='neck.gap',
|
||||
teacher_module='neck.gap',
|
||||
losses=[
|
||||
dict(
|
||||
type='DistanceWiseRKD',
|
||||
name='distance_wise_loss',
|
||||
loss_weight=25.0,
|
||||
with_l2_norm=False),
|
||||
dict(
|
||||
type='AngleWiseRKD',
|
||||
name='angle_wise_loss',
|
||||
loss_weight=50.0,
|
||||
with_l2_norm=False),
|
||||
])
|
||||
]
|
||||
|
||||
optimizer = torch.optim.SGD(algorithm.parameters(), lr=0.01)
|
||||
outputs = algorithm.train_step({'img': imgs, 'gt_label': label}, optimizer)
|
||||
assert outputs['loss'].item() > 0
|
||||
assert outputs['num_samples'] == 16
|
||||
|
||||
losses = algorithm(imgs, return_loss=True, gt_label=label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
|
Loading…
Reference in New Issue