[Feature] Support Relational Knowledge Distillation (#127)
* add rkd
* add rkd pytest
* add rkd configs
* fix readme
* fix rkd
* split rkd loss to distance-wise and angle-wise losses
* rename rkd losses
* add rkd metafile
* add rkd related links
* rename rkd metafile and add to model index
* delete cifar100

Co-authored-by: caoweihan <caoweihan@sensetime.com>
Co-authored-by: pppppM <gjf_mail@126.com>

parent f9920a403e
commit de4dd13cc6
configs/distill/rkd/README.md
@@ -0,0 +1,42 @@
# RKD

> [Relational Knowledge Distillation](https://arxiv.org/abs/1904.05068)

<!-- [ALGORITHM] -->

## Abstract
Knowledge distillation aims at transferring knowledge acquired in one model (a teacher) to another model (a student) that is typically smaller. Previous approaches can be expressed as a form of training the student to mimic output activations of individual data examples represented by the teacher. We introduce a novel approach, dubbed relational knowledge distillation (RKD), that transfers mutual relations of data examples instead. For concrete realizations of RKD, we propose distance-wise and angle-wise distillation losses that penalize structural differences in relations. Experiments conducted on different tasks show that the proposed method improves educated student models with a significant margin. In particular for metric learning, it allows students to outperform their teachers' performance, achieving the state of the arts on standard benchmark datasets.
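The distance-wise and angle-wise losses added in this PR instantiate the paper's relational potentials; as a summary sketch (notation from the paper: $t_i$ and $s_i$ are teacher and student embeddings, $l_\delta$ is the Huber/smooth-L1 loss):

```latex
% Distance-wise potential, normalized by the mean pairwise distance \mu
\psi_D(t_i, t_j) = \frac{1}{\mu} \lVert t_i - t_j \rVert_2, \quad
\mu = \frac{1}{|\mathcal{X}^2|} \sum_{(x_i, x_j) \in \mathcal{X}^2} \lVert t_i - t_j \rVert_2

% Angle-wise potential over triplets of examples
\psi_A(t_i, t_j, t_k) = \cos \angle t_i t_j t_k =
  \left\langle \frac{t_i - t_j}{\lVert t_i - t_j \rVert_2},
               \frac{t_k - t_j}{\lVert t_k - t_j \rVert_2} \right\rangle

% Both losses match student potentials to teacher potentials via l_\delta
\mathcal{L}_{RKD\text{-}D} = \sum_{(x_i, x_j) \in \mathcal{X}^2}
  l_\delta \bigl( \psi_D(t_i, t_j),\, \psi_D(s_i, s_j) \bigr), \quad
\mathcal{L}_{RKD\text{-}A} = \sum_{(x_i, x_j, x_k) \in \mathcal{X}^3}
  l_\delta \bigl( \psi_A(t_i, t_j, t_k),\, \psi_A(s_i, s_j, s_k) \bigr)
```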

## Results and models

### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :-----: | :-----: | :-----: | :--: | :----: | :----: | :----: | :------- |
| neck | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.23 | 73.62 | 69.90 | [config](./rkd_neck_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_acc-70.23_20220401-f25700ac.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_20220312_130419.log.json) |

## Citation

```latex
@inproceedings{park2019relational,
  title={Relational knowledge distillation},
  author={Park, Wonpyo and Kim, Dongju and Lu, Yan and Cho, Minsu},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={3967--3976},
  year={2019}
}
```

configs/distill/rkd/metafile.yml
@@ -0,0 +1,31 @@
Collections:
  - Name: RKD
    Metadata:
      Training Data:
        - ImageNet-1k
    Paper:
      URL: https://arxiv.org/abs/1904.05068
      Title: Relational Knowledge Distillation
    README: configs/distill/rkd/README.md
    Code:
      URL: https://github.com/open-mmlab/mmrazor/blob/v0.3.0/mmrazor/models/losses/relational_kd.py
      Version: v0.3.0
    Converted From:
      Code: https://github.com/lenscloth/RKD
Models:
  - Name: rkd_neck_resnet34_resnet18_8xb32_in1k
    In Collection: RKD
    Metadata:
      Location: neck
      Student: R-18
      Teacher: R-34
      Teacher Checkpoint: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 70.23
          Top 1 Accuracy (S): 69.90
          Top 1 Accuracy (T): 73.62
    Config: configs/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k.py
    Weights: https://download.openmmlab.com/mmrazor/v0.3/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k_acc-70.23_20220401-f25700ac.pth

configs/distill/rkd/rkd_neck_resnet34_resnet18_8xb32_in1k.py
@@ -0,0 +1,79 @@
_base_ = [
    '../../_base_/datasets/mmcls/imagenet_bs32.py',
    '../../_base_/schedules/mmcls/imagenet_bs256.py',
    '../../_base_/mmcls_runtime.py'
]

# model settings
student = dict(
    type='mmcls.ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

# teacher settings
teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'  # noqa: E501

teacher = dict(
    type='mmcls.ImageClassifier',
    init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
    backbone=dict(
        type='ResNet',
        depth=34,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

# algorithm setting
algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMClsArchitecture',
        model=student,
    ),
    with_student_loss=True,
    with_teacher_loss=False,
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=teacher,
        teacher_trainable=False,
        teacher_norm_eval=True,
        components=[
            dict(
                student_module='neck.gap',
                teacher_module='neck.gap',
                losses=[
                    dict(
                        type='DistanceWiseRKD',
                        name='distance_wise_loss',
                        loss_weight=25.0,
                        with_l2_norm=True),
                    dict(
                        type='AngleWiseRKD',
                        name='angle_wise_loss',
                        loss_weight=50.0,
                        with_l2_norm=True),
                ])
        ]),
)

find_unused_parameters = True
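For reference, here is a minimal standalone sketch of what the `components` entry above wires up (my own illustration, assuming mmrazor v0.3 is installed and its registries are importable as `mmrazor.models.builder`): the two loss configs are built from the registry and applied to `(N, C)` pooled features, which is the shape the distiller records at `neck.gap`.

```python
import torch
from mmrazor.models.builder import LOSSES  # registry used by @LOSSES.register_module()

# Build both RKD losses exactly as configured in components[0].losses.
dist_loss = LOSSES.build(
    dict(type='DistanceWiseRKD', loss_weight=25.0, with_l2_norm=True))
angle_loss = LOSSES.build(
    dict(type='AngleWiseRKD', loss_weight=50.0, with_l2_norm=True))

# Stand-ins for the student/teacher neck.gap outputs: (N, C) pooled features.
preds_S = torch.randn(32, 512)
preds_T = torch.randn(32, 512)

total = dist_loss(preds_S, preds_T) + angle_loss(preds_S, preds_T)
print(total.item())
```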
Binary file not shown (new image for the README, 47 KiB).

mmrazor/models/losses/__init__.py
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cwd import ChannelWiseDivergence
 from .kl_divergence import KLDivergence
+from .relational_kd import AngleWiseRKD, DistanceWiseRKD
 from .weighted_soft_label_distillation import WSLD

-__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD']
+__all__ = [
+    'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
+    'WSLD'
+]

mmrazor/models/losses/relational_kd.py
@@ -0,0 +1,149 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


def euclidean_distance(pred, squared=False, eps=1e-12):
    """Calculate the Euclidean distance between the two examples in the
    output representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).
        squared (bool): Whether to calculate the squared Euclidean
            distance. Defaults to False.
        eps (float): The minimum Euclidean distance between the two
            examples. Defaults to 1e-12.
    """
    pred_square = pred.pow(2).sum(dim=-1)  # (N, )
    prod = torch.mm(pred, pred.t())  # (N, N)
    distance = (pred_square.unsqueeze(1) + pred_square.unsqueeze(0) -
                2 * prod).clamp(min=eps)  # (N, N)

    if not squared:
        distance = distance.sqrt()

    distance = distance.clone()
    distance[range(len(prod)), range(len(prod))] = 0
    return distance
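A quick sanity sketch (not part of the PR): the expand-and-clamp formulation above agrees with `torch.cdist` up to floating-point error, since both zero the diagonal.

```python
import torch

# Assumes `euclidean_distance` from above is in scope.
x = torch.randn(8, 16)
d_manual = euclidean_distance(x)   # clamped, diagonal zeroed
d_cdist = torch.cdist(x, x, p=2)   # reference pairwise distances
assert torch.allclose(d_manual, d_cdist, atol=1e-4)
```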


def angle(pred):
    """Calculate the angle-wise relational potential which measures the
    angle formed by the three examples in the output representation space.

    Args:
        pred (torch.Tensor): The prediction of the teacher or student with
            shape (N, C).
    """
    pred_vec = pred.unsqueeze(0) - pred.unsqueeze(1)  # (N, N, C)
    norm_pred_vec = F.normalize(pred_vec, p=2, dim=2)
    angle = torch.bmm(norm_pred_vec,
                      norm_pred_vec.transpose(1, 2)).view(-1)  # (N*N*N, )
    return angle
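Note that `angle` returns one cosine per ordered (i, j, k) triplet, so the output (and the intermediate (N, N, C) tensor) grows cubically with the batch size N; an illustrative check (not part of the PR):

```python
import torch

# Assumes `angle` from above is in scope.
feats = torch.randn(4, 8)
cosines = angle(feats)
assert cosines.shape == (4 * 4 * 4, )  # N^3 angle potentials
```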


@LOSSES.register_module()
class DistanceWiseRKD(nn.Module):
    """PyTorch version of the distance-wise loss of `Relational Knowledge
    Distillation <https://arxiv.org/abs/1904.05068>`_.

    Args:
        loss_weight (float): Weight of distance-wise distillation loss.
            Defaults to 25.0.
        with_l2_norm (bool): Whether to normalize the model predictions
            before calculating the loss. Defaults to True.
    """

    def __init__(self, loss_weight=25.0, with_l2_norm=True):
        super(DistanceWiseRKD, self).__init__()

        self.loss_weight = loss_weight
        self.with_l2_norm = with_l2_norm

    def distance_loss(self, preds_S, preds_T):
        """Calculate distance-wise distillation loss."""
        d_T = euclidean_distance(preds_T, squared=False)
        # mean_d_T is a normalization factor for distance
        mean_d_T = d_T[d_T > 0].mean()
        d_T = d_T / mean_d_T

        d_S = euclidean_distance(preds_S, squared=False)
        mean_d_S = d_S[d_S > 0].mean()
        d_S = d_S / mean_d_S

        return F.smooth_l1_loss(d_S, d_T)

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W) or shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W) or shape (N, C).
        Return:
            torch.Tensor: The calculated loss value.
        """
        preds_S = preds_S.view(preds_S.shape[0], -1)
        preds_T = preds_T.view(preds_T.shape[0], -1)
        if self.with_l2_norm:
            preds_S = F.normalize(preds_S, p=2, dim=1)
            preds_T = F.normalize(preds_T, p=2, dim=1)

        loss = self.distance_loss(preds_S, preds_T) * self.loss_weight

        return loss


@LOSSES.register_module()
class AngleWiseRKD(nn.Module):
    """PyTorch version of the angle-wise loss of `Relational Knowledge
    Distillation <https://arxiv.org/abs/1904.05068>`_.

    Args:
        loss_weight (float): Weight of angle-wise distillation loss.
            Defaults to 50.0.
        with_l2_norm (bool): Whether to normalize the model predictions
            before calculating the loss. Defaults to True.
    """

    def __init__(self, loss_weight=50.0, with_l2_norm=True):
        super(AngleWiseRKD, self).__init__()

        self.loss_weight = loss_weight
        self.with_l2_norm = with_l2_norm

    def angle_loss(self, preds_S, preds_T):
        """Calculate the angle-wise distillation loss."""
        angle_T = angle(preds_T)
        angle_S = angle(preds_S)
        return F.smooth_l1_loss(angle_S, angle_T)

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W) or shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W) or shape (N, C).
        Return:
            torch.Tensor: The calculated loss value.
        """
        preds_S = preds_S.view(preds_S.shape[0], -1)
        preds_T = preds_T.view(preds_T.shape[0], -1)
        if self.with_l2_norm:
            preds_S = F.normalize(preds_S, p=2, dim=-1)
            preds_T = F.normalize(preds_T, p=2, dim=-1)

        loss = self.angle_loss(preds_S, preds_T) * self.loss_weight

        return loss
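One property of `distance_loss` worth noting (an illustrative check of mine, assuming the classes are exported from `mmrazor.models.losses` as the `__init__.py` change above suggests): dividing by the mean pairwise distance makes the loss invariant to a global rescaling of either feature set.

```python
import torch
from mmrazor.models.losses import DistanceWiseRKD

loss_fn = DistanceWiseRKD(loss_weight=25.0, with_l2_norm=False)
s = torch.randn(16, 512)
t = torch.randn(16, 512)
# The mean-distance normalization cancels any global scale factor.
assert torch.allclose(loss_fn(s, t), loss_fn(3.0 * s, t), atol=1e-5)
```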

model-index.yml
@@ -1,6 +1,7 @@
 Import:
   - configs/distill/cwd/metafile.yml
   - configs/distill/wsld/metafile.yml
+  - configs/distill/rkd/metafile.yml
   - configs/nas/darts/metafile.yml
   - configs/nas/detnas/metafile.yml
   - configs/nas/spos/metafile.yml

@@ -436,3 +436,113 @@ def test_cwd():
    # test algorithm train_step
    losses = algorithm.train_step(mm_inputs, None)
    assert losses['loss'].item() > 0


def test_rkd():
    student = dict(
        type='mmcls.ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=18,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        ))

    teacher = dict(
        type='mmcls.ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=34,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        ))

    # test RelationalKD w/ l2 norm
    algorithm_cfg = ConfigDict(
        type='GeneralDistill',
        architecture=dict(
            type='MMClsArchitecture',
            model=student,
        ),
        with_student_loss=True,
        with_teacher_loss=False,
        distiller=dict(
            type='SingleTeacherDistiller',
            teacher=teacher,
            teacher_trainable=False,
            teacher_norm_eval=True,
            components=[
                dict(
                    student_module='neck.gap',
                    teacher_module='neck.gap',
                    losses=[
                        dict(
                            type='DistanceWiseRKD',
                            name='distance_wise_loss',
                            loss_weight=25.0,
                            with_l2_norm=True),
                        dict(
                            type='AngleWiseRKD',
                            name='angle_wise_loss',
                            loss_weight=50.0,
                            with_l2_norm=True),
                    ])
            ]),
    )

    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    algorithm = ALGORITHMS.build(algorithm_cfg)

    optimizer = torch.optim.SGD(algorithm.parameters(), lr=0.01)
    outputs = algorithm.train_step({'img': imgs, 'gt_label': label}, optimizer)
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == 16

    # test forward
    losses = algorithm(imgs, return_loss=True, gt_label=label)
    assert losses['loss'].item() > 0

    # test RelationalKD w/o l2 norm
    algorithm_cfg.distiller.components = [
        dict(
            student_module='neck.gap',
            teacher_module='neck.gap',
            losses=[
                dict(
                    type='DistanceWiseRKD',
                    name='distance_wise_loss',
                    loss_weight=25.0,
                    with_l2_norm=False),
                dict(
                    type='AngleWiseRKD',
                    name='angle_wise_loss',
                    loss_weight=50.0,
                    with_l2_norm=False),
            ])
    ]
    # rebuild the algorithm so the modified loss settings actually take effect
    algorithm = ALGORITHMS.build(algorithm_cfg)

    optimizer = torch.optim.SGD(algorithm.parameters(), lr=0.01)
    outputs = algorithm.train_step({'img': imgs, 'gt_label': label}, optimizer)
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == 16

    losses = algorithm(imgs, return_loss=True, gt_label=label)
    assert losses['loss'].item() > 0