mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
add dist of rec model (#1574)
* add distillation loss func and rec distillation
This commit is contained in:
parent
0aa85d4f7f
commit
aea712cc87
@ -77,14 +77,19 @@ class RecModel(TheseusLayer):
|
||||
self.head = None
|
||||
|
||||
def forward(self, x, label=None):
|
||||
out = dict()
|
||||
x = self.backbone(x)
|
||||
out["backbone"] = x
|
||||
if self.neck is not None:
|
||||
x = self.neck(x)
|
||||
out["features"] = x
|
||||
if self.head is not None:
|
||||
y = self.head(x, label)
|
||||
out["neck"] = x
|
||||
else:
|
||||
y = None
|
||||
return {"features": x, "logits": y}
|
||||
out["logits"] = y
|
||||
return out
|
||||
|
||||
|
||||
class DistillationModel(nn.Layer):
|
||||
|
@ -196,7 +196,10 @@ class MobileNetV3(TheseusLayer):
|
||||
bias_attr=False)
|
||||
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
|
||||
if dropout_prob is not None:
|
||||
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
|
||||
else:
|
||||
self.dropout = None
|
||||
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
|
||||
|
||||
self.fc = Linear(self.class_expand, class_num)
|
||||
@ -210,7 +213,8 @@ class MobileNetV3(TheseusLayer):
|
||||
x = self.avg_pool(x)
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
|
||||
|
@ -0,0 +1,194 @@
|
||||
# global configs
|
||||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
save_interval: 1
|
||||
eval_during_train: true
|
||||
eval_interval: 1
|
||||
epochs: 100
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
eval_mode: retrieval
|
||||
use_dali: False
|
||||
to_static: False
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: "DistillationModel"
|
||||
infer_output_key: features
|
||||
infer_add_softmax: False
|
||||
is_rec: True
|
||||
infer_model_name: "Student"
|
||||
# if not null, its length should be the same as that of models
|
||||
pretrained_list:
|
||||
# if not null, its length should be the same as that of models
|
||||
freeze_params_list:
|
||||
- False
|
||||
- False
|
||||
models:
|
||||
- Teacher:
|
||||
name: RecModel
|
||||
infer_output_key: features
|
||||
infer_add_softmax: False
|
||||
Backbone:
|
||||
name: PPLCNet_x2_5
|
||||
pretrained: True
|
||||
use_ssld: True
|
||||
BackboneStopLayer:
|
||||
name: "flatten"
|
||||
Neck:
|
||||
name: FC
|
||||
embedding_size: 1280
|
||||
class_num: 512
|
||||
Head:
|
||||
name: ArcMargin
|
||||
embedding_size: 512
|
||||
class_num: 185341
|
||||
margin: 0.2
|
||||
scale: 30
|
||||
- Student:
|
||||
name: RecModel
|
||||
infer_output_key: features
|
||||
infer_add_softmax: False
|
||||
Backbone:
|
||||
name: PPLCNet_x2_5
|
||||
pretrained: True
|
||||
use_ssld: True
|
||||
BackboneStopLayer:
|
||||
name: "flatten"
|
||||
Neck:
|
||||
name: FC
|
||||
embedding_size: 1280
|
||||
class_num: 512
|
||||
Head:
|
||||
name: ArcMargin
|
||||
embedding_size: 512
|
||||
class_num: 185341
|
||||
margin: 0.2
|
||||
scale: 30
|
||||
|
||||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- DistillationGTCELoss:
|
||||
weight: 1.0
|
||||
key: "logits"
|
||||
model_names: ["Student", "Teacher"]
|
||||
- DistillationDMLLoss:
|
||||
weight: 1.0
|
||||
key: "logits"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
- DistillationDMLLoss:
|
||||
weight: 1.0
|
||||
key: "logits"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
Eval:
|
||||
- DistillationGTCELoss:
|
||||
weight: 1.0
|
||||
model_names: ["Student"]
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.02
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.00001
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/
|
||||
cls_label_path: ./dataset/train_reg_all_data.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
Query:
|
||||
dataset:
|
||||
name: VeriWild
|
||||
image_root: ./dataset/Aliproduct/
|
||||
cls_label_path: ./dataset/Aliproduct/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Gallery:
|
||||
dataset:
|
||||
name: VeriWild
|
||||
image_root: ./dataset/Aliproduct/
|
||||
cls_label_path: ./dataset/Aliproduct/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Metric:
|
||||
Eval:
|
||||
- Recallk:
|
||||
topk: [1, 5]
|
@ -0,0 +1,193 @@
|
||||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
save_interval: 1
|
||||
eval_during_train: true
|
||||
eval_interval: 1
|
||||
epochs: 100
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
eval_mode: retrieval
|
||||
use_dali: False
|
||||
to_static: False
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: "DistillationModel"
|
||||
infer_output_key: features
|
||||
infer_add_softmax: False
|
||||
is_rec: True
|
||||
infer_model_name: "Student"
|
||||
# if not null, its length should be the same as that of models
|
||||
pretrained_list:
|
||||
# if not null, its length should be the same as that of models
|
||||
freeze_params_list:
|
||||
- False
|
||||
- False
|
||||
models:
|
||||
- Teacher:
|
||||
name: RecModel
|
||||
infer_output_key: features
|
||||
infer_add_softmax: False
|
||||
Backbone:
|
||||
name: PPLCNet_x2_5
|
||||
pretrained: True
|
||||
use_ssld: True
|
||||
BackboneStopLayer:
|
||||
name: "flatten"
|
||||
Neck:
|
||||
name: FC
|
||||
embedding_size: 1280
|
||||
class_num: 512
|
||||
Head:
|
||||
name: ArcMargin
|
||||
embedding_size: 512
|
||||
class_num: 185341
|
||||
margin: 0.2
|
||||
scale: 30
|
||||
- Student:
|
||||
name: RecModel
|
||||
infer_output_key: features
|
||||
infer_add_softmax: False
|
||||
Backbone:
|
||||
name: PPLCNet_x2_5
|
||||
pretrained: True
|
||||
use_ssld: True
|
||||
BackboneStopLayer:
|
||||
name: "flatten"
|
||||
Neck:
|
||||
name: FC
|
||||
embedding_size: 1280
|
||||
class_num: 512
|
||||
Head:
|
||||
name: ArcMargin
|
||||
embedding_size: 512
|
||||
class_num: 185341
|
||||
margin: 0.2
|
||||
scale: 30
|
||||
|
||||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- DistillationGTCELoss:
|
||||
weight: 1.0
|
||||
key: "logits"
|
||||
model_names: ["Student", "Teacher"]
|
||||
- DistillationDMLLoss:
|
||||
weight: 1.0
|
||||
key: "logits"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
- DistillationDistanceLoss:
|
||||
weight: 1.0
|
||||
key: "backbone"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
Eval:
|
||||
- DistillationGTCELoss:
|
||||
weight: 1.0
|
||||
model_names: ["Student"]
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.02
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.00001
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/
|
||||
cls_label_path: ./dataset/train_reg_all_data.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
Query:
|
||||
dataset:
|
||||
name: VeriWild
|
||||
image_root: ./dataset/Aliproduct/
|
||||
cls_label_path: ./dataset/Aliproduct/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Gallery:
|
||||
dataset:
|
||||
name: VeriWild
|
||||
image_root: ./dataset/Aliproduct/
|
||||
cls_label_path: ./dataset/Aliproduct/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Metric:
|
||||
Eval:
|
||||
- Recallk:
|
||||
topk: [1, 5]
|
@ -13,6 +13,7 @@ Global:
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: "./inference"
|
||||
use_dali: false
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
@ -29,9 +30,11 @@ Arch:
|
||||
name: MobileNetV3_large_x1_0
|
||||
pretrained: True
|
||||
use_ssld: True
|
||||
dropout_prob: null
|
||||
- Student:
|
||||
name: MobileNetV3_small_x1_0
|
||||
pretrained: False
|
||||
dropout_prob: null
|
||||
|
||||
infer_model_name: "Student"
|
||||
|
||||
@ -76,7 +79,6 @@ DataLoader:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- AutoAugment:
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
@ -85,7 +87,7 @@ DataLoader:
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 512
|
||||
batch_size: 256
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
@ -112,7 +114,7 @@ DataLoader:
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
batch_size: 128
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
|
@ -53,7 +53,8 @@ class Engine(object):
|
||||
self.config = config
|
||||
self.eval_mode = self.config["Global"].get("eval_mode",
|
||||
"classification")
|
||||
if "Head" in self.config["Arch"]:
|
||||
if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
|
||||
False):
|
||||
self.is_rec = True
|
||||
else:
|
||||
self.is_rec = False
|
||||
@ -357,7 +358,9 @@ class Engine(object):
|
||||
out = self.model(batch_tensor)
|
||||
if isinstance(out, list):
|
||||
out = out[0]
|
||||
if isinstance(out, dict):
|
||||
if isinstance(out, dict) and "logits" in out:
|
||||
out = out["logits"]
|
||||
if isinstance(out, dict) and "output" in out:
|
||||
out = out["output"]
|
||||
result = self.postprocess_func(out, image_file_list)
|
||||
print(result)
|
||||
|
@ -78,10 +78,10 @@ def classification_eval(engine, epoch_id=0):
|
||||
labels = paddle.concat(label_list, 0)
|
||||
|
||||
if isinstance(out, dict):
|
||||
if "logits" in out:
|
||||
out = out["logits"]
|
||||
elif "Student" in out:
|
||||
if "Student" in out:
|
||||
out = out["Student"]
|
||||
elif "logits" in out:
|
||||
out = out["logits"]
|
||||
else:
|
||||
msg = "Error: Wrong key in out!"
|
||||
raise Exception(msg)
|
||||
@ -106,6 +106,7 @@ def classification_eval(engine, epoch_id=0):
|
||||
metric_dict = engine.eval_metric_func(pred, labels)
|
||||
else:
|
||||
metric_dict = engine.eval_metric_func(out, batch[1])
|
||||
|
||||
for key in metric_dict:
|
||||
if metric_key is None:
|
||||
metric_key = key
|
||||
|
@ -123,6 +123,8 @@ def cal_feature(engine, name='gallery'):
|
||||
has_unique_id = True
|
||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||
out = engine.model(batch[0], batch[1])
|
||||
if "Student" in out:
|
||||
out = out["Student"]
|
||||
batch_feas = out["features"]
|
||||
|
||||
# do norm
|
||||
|
@ -20,6 +20,8 @@ from .distanceloss import DistanceLoss
|
||||
from .distillationloss import DistillationCELoss
|
||||
from .distillationloss import DistillationGTCELoss
|
||||
from .distillationloss import DistillationDMLLoss
|
||||
from .distillationloss import DistillationDistanceLoss
|
||||
from .distillationloss import DistillationRKDLoss
|
||||
from .multilabelloss import MultiLabelLoss
|
||||
|
||||
from .deephashloss import DSHSDLoss, LCDSHLoss
|
||||
|
@ -18,6 +18,7 @@ import paddle.nn as nn
|
||||
from .celoss import CELoss
|
||||
from .dmlloss import DMLLoss
|
||||
from .distanceloss import DistanceLoss
|
||||
from .rkdloss import RKdAngle, RkdDistance
|
||||
|
||||
|
||||
class DistillationCELoss(CELoss):
|
||||
@ -68,7 +69,7 @@ class DistillationGTCELoss(CELoss):
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, name in enumerate(self.model_names):
|
||||
for _, name in enumerate(self.model_names):
|
||||
out = predicts[name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
@ -84,7 +85,7 @@ class DistillationDMLLoss(DMLLoss):
|
||||
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
act=None,
|
||||
act="softmax",
|
||||
key=None,
|
||||
name="loss_dml"):
|
||||
super().__init__(act=act)
|
||||
@ -125,7 +126,7 @@ class DistillationDistanceLoss(DistanceLoss):
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name + "_l2"
|
||||
self.name = name + mode
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
@ -139,3 +140,35 @@ class DistillationDistanceLoss(DistanceLoss):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[key]
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationRKDLoss(nn.Layer):
    """Apply the RKD angle and distance losses to pairs of model outputs.

    For every ``(m1, m2)`` entry of ``model_name_pairs`` and every key pair
    obtained by zipping ``student_keepkeys`` with ``teacher_keepkeys``, the
    tensors ``predicts[m1][student_key]`` and ``predicts[m2][teacher_key]``
    are passed to an ``RKdAngle`` and an ``RkdDistance`` instance.

    Args:
        target_size: Forwarded unchanged to both ``RKdAngle`` and
            ``RkdDistance``.
        model_name_pairs: Iterable of two-element name pairs; each name is
            used as a key into ``predicts``.
        student_keepkeys: Keys read from the first model's output dict.
            ``None`` (the default) means no keys.
        teacher_keepkeys: Keys read from the second model's output dict.
            ``None`` (the default) means no keys. Must have the same length
            as ``student_keepkeys``.
    """

    def __init__(self,
                 target_size=None,
                 model_name_pairs=(["Student", "Teacher"], ),
                 student_keepkeys=None,
                 teacher_keepkeys=None):
        super().__init__()
        # ``None`` stands in for the former ``[]`` defaults so that instances
        # built with default arguments no longer share one list object.
        self.student_keepkeys = (student_keepkeys
                                 if student_keepkeys is not None else [])
        self.teacher_keepkeys = (teacher_keepkeys
                                 if teacher_keepkeys is not None else [])
        self.model_name_pairs = model_name_pairs
        # Keys are consumed pairwise in __call__, so the lengths must agree.
        assert len(self.student_keepkeys) == len(self.teacher_keepkeys)

        self.rkd_angle_loss = RKdAngle(target_size=target_size)
        self.rkd_dist_loss = RkdDistance(target_size=target_size)

    def __call__(self, predicts, batch):
        """Compute all configured RKD losses.

        Args:
            predicts: Mapping from model name to that model's output mapping.
            batch: Not used by this method.

        Returns:
            dict: One ``loss_angle_{idx}_{m1}_{m2}`` and one
            ``loss_dist_{idx}_{m1}_{m2}`` entry per model pair and key pair,
            where ``idx`` is the position of the key pair.
        """
        loss_dict = {}
        for m1, m2 in self.model_name_pairs:
            key_pairs = zip(self.student_keepkeys, self.teacher_keepkeys)
            for idx, (student_name, teacher_name) in enumerate(key_pairs):
                student_out = predicts[m1][student_name]
                teacher_out = predicts[m2][teacher_name]

                loss_dict[f"loss_angle_{idx}_{m1}_{m2}"] = self.rkd_angle_loss(
                    student_out, teacher_out)
                loss_dict[f"loss_dist_{idx}_{m1}_{m2}"] = self.rkd_dist_loss(
                    student_out, teacher_out)

        return loss_dict
|
||||
|
@ -22,7 +22,7 @@ class DMLLoss(nn.Layer):
|
||||
DMLLoss
|
||||
"""
|
||||
|
||||
def __init__(self, act="softmax"):
|
||||
def __init__(self, act="softmax", eps=1e-12):
|
||||
super().__init__()
|
||||
if act is not None:
|
||||
assert act in ["softmax", "sigmoid"]
|
||||
@ -32,15 +32,19 @@ class DMLLoss(nn.Layer):
|
||||
self.act = nn.Sigmoid()
|
||||
else:
|
||||
self.act = None
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, out1, out2):
|
||||
def _kldiv(self, x, target):
|
||||
class_num = x.shape[-1]
|
||||
cost = target * paddle.log(
|
||||
(target + self.eps) / (x + self.eps)) * class_num
|
||||
return cost
|
||||
|
||||
def forward(self, x, target):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
x = F.softmax(x)
|
||||
target = F.softmax(target)
|
||||
loss = self._kldiv(x, target) + self._kldiv(target, x)
|
||||
loss = loss / 2
|
||||
loss = paddle.mean(loss)
|
||||
return {"DMLLoss": loss}
|
||||
|
97
ppcls/loss/rkdloss.py
Normal file
97
ppcls/loss/rkdloss.py
Normal file
@ -0,0 +1,97 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
def pdist(e, squared=False, eps=1e-12):
    """Return the matrix of pairwise distances between the rows of ``e``.

    Args:
        e: Tensor of row vectors (presumably 2-D ``[N, D]``, since it is
            multiplied with its own transpose -- confirm against callers).
        squared: When True, return squared distances; otherwise take the
            element-wise square root.
        eps: Lower bound applied to the squared distances before the
            optional square root.

    Returns:
        Tensor of pairwise (squared) distances, clipped below at ``eps``
        (or ``sqrt(eps)`` when ``squared`` is False).
    """
    sq_norms = e.pow(2).sum(axis=1)
    gram = paddle.mm(e, e.t())
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>
    dist_sq = sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0) - 2 * gram
    dist_sq = dist_sq.clip(min=eps)

    if squared:
        return dist_sq
    return dist_sq.sqrt()
|
||||
|
||||
|
||||
class RKdAngle(nn.Layer):
    """Angle-wise relational distillation loss.

    Compares, with a smooth L1 loss, the angle terms computed from the
    student features against those computed from the teacher features.
    """

    # reference: https://github.com/lenscloth/RKD/blob/master/metric/loss.py
    def __init__(self, target_size=None):
        """Create the loss.

        Args:
            target_size: Output size for an ``AdaptiveAvgPool2D`` applied to
                both inputs before flattening; ``None`` disables pooling.
        """
        super().__init__()
        self.avgpool = (paddle.nn.AdaptiveAvgPool2D(target_size)
                        if target_size is not None else None)

    @staticmethod
    def _angle_terms(feats):
        """Return the flattened angle terms for ``feats`` of shape [bs, dim]."""
        # Differences between every pair of samples, then unit-normalised
        # along the feature axis.
        pair_diff = feats.unsqueeze(0) - feats.unsqueeze(1)
        unit_diff = F.normalize(pair_diff, p=2, axis=2)
        cosines = paddle.bmm(unit_diff, unit_diff.transpose([0, 2, 1]))
        return cosines.reshape([-1, 1])

    def forward(self, student, teacher):
        """Compute the angle loss between ``student`` and ``teacher``.

        Args:
            student: Student features; the first axis is used as the batch
                size.
            teacher: Teacher features, flattened with the student's batch
                size.

        Returns:
            The mean-reduced smooth L1 loss between the two sets of angle
            terms.
        """
        # Optional pooling to reduce memory before flattening.
        if self.avgpool is not None:
            student = self.avgpool(student)
            teacher = self.avgpool(teacher)

        # Flatten each sample to a vector for feature-map distillation.
        batch_size = student.shape[0]
        student = student.reshape([batch_size, -1])
        teacher = teacher.reshape([batch_size, -1])

        t_angle = self._angle_terms(teacher)
        s_angle = self._angle_terms(student)
        return F.smooth_l1_loss(s_angle, t_angle, reduction='mean')
|
||||
|
||||
|
||||
class RkdDistance(nn.Layer):
    """Distance-wise relational distillation loss.

    Compares, with a smooth L1 loss, the mean-normalised pairwise distance
    matrix of the student features against that of the teacher features.
    """

    # reference: https://github.com/lenscloth/RKD/blob/master/metric/loss.py
    def __init__(self, eps=1e-12, target_size=1):
        """Create the loss.

        Args:
            eps: Added to the mean distance before dividing by it.
            target_size: Output size for an ``AdaptiveAvgPool2D`` applied to
                both inputs before flattening; ``None`` disables pooling.
        """
        super().__init__()
        self.eps = eps
        self.avgpool = (paddle.nn.AdaptiveAvgPool2D(target_size)
                        if target_size is not None else None)

    def _scaled_distances(self, feats):
        """Return pairwise distances of ``feats`` divided by their mean."""
        dist = pdist(feats, squared=False)
        return dist / (dist.mean() + self.eps)

    def forward(self, student, teacher):
        """Compute the distance loss between ``student`` and ``teacher``.

        Args:
            student: Student features; the first axis is used as the batch
                size.
            teacher: Teacher features, flattened with the student's batch
                size.

        Returns:
            The mean-reduced smooth L1 loss between the two normalised
            distance matrices.
        """
        # Optional pooling to reduce memory before flattening.
        if self.avgpool is not None:
            student = self.avgpool(student)
            teacher = self.avgpool(teacher)

        batch_size = student.shape[0]
        student = student.reshape([batch_size, -1])
        teacher = teacher.reshape([batch_size, -1])

        teacher_dist = self._scaled_distances(teacher)
        student_dist = self._scaled_distances(student)
        return F.smooth_l1_loss(student_dist, teacher_dist, reduction="mean")
|
Loading…
x
Reference in New Issue
Block a user