mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
add dist of rec model (#1574)
* add distillation loss func and rec distillation
This commit is contained in:
parent
0aa85d4f7f
commit
aea712cc87
@ -77,14 +77,19 @@ class RecModel(TheseusLayer):
|
|||||||
self.head = None
|
self.head = None
|
||||||
|
|
||||||
def forward(self, x, label=None):
|
def forward(self, x, label=None):
|
||||||
|
out = dict()
|
||||||
x = self.backbone(x)
|
x = self.backbone(x)
|
||||||
|
out["backbone"] = x
|
||||||
if self.neck is not None:
|
if self.neck is not None:
|
||||||
x = self.neck(x)
|
x = self.neck(x)
|
||||||
|
out["features"] = x
|
||||||
if self.head is not None:
|
if self.head is not None:
|
||||||
y = self.head(x, label)
|
y = self.head(x, label)
|
||||||
|
out["neck"] = x
|
||||||
else:
|
else:
|
||||||
y = None
|
y = None
|
||||||
return {"features": x, "logits": y}
|
out["logits"] = y
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DistillationModel(nn.Layer):
|
class DistillationModel(nn.Layer):
|
||||||
|
@ -196,7 +196,10 @@ class MobileNetV3(TheseusLayer):
|
|||||||
bias_attr=False)
|
bias_attr=False)
|
||||||
|
|
||||||
self.hardswish = nn.Hardswish()
|
self.hardswish = nn.Hardswish()
|
||||||
|
if dropout_prob is not None:
|
||||||
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
|
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
|
||||||
|
else:
|
||||||
|
self.dropout = None
|
||||||
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
|
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
|
||||||
|
|
||||||
self.fc = Linear(self.class_expand, class_num)
|
self.fc = Linear(self.class_expand, class_num)
|
||||||
@ -210,6 +213,7 @@ class MobileNetV3(TheseusLayer):
|
|||||||
x = self.avg_pool(x)
|
x = self.avg_pool(x)
|
||||||
x = self.last_conv(x)
|
x = self.last_conv(x)
|
||||||
x = self.hardswish(x)
|
x = self.hardswish(x)
|
||||||
|
if self.dropout is not None:
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
x = self.flatten(x)
|
x = self.flatten(x)
|
||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
|
@ -0,0 +1,194 @@
|
|||||||
|
# global configs
|
||||||
|
# global configs
|
||||||
|
Global:
|
||||||
|
checkpoints: null
|
||||||
|
pretrained_model: null
|
||||||
|
output_dir: ./output/
|
||||||
|
device: gpu
|
||||||
|
save_interval: 1
|
||||||
|
eval_during_train: true
|
||||||
|
eval_interval: 1
|
||||||
|
epochs: 100
|
||||||
|
print_batch_step: 10
|
||||||
|
use_visualdl: False
|
||||||
|
# used for static mode and model export
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
save_inference_dir: ./inference
|
||||||
|
eval_mode: retrieval
|
||||||
|
use_dali: False
|
||||||
|
to_static: False
|
||||||
|
|
||||||
|
# model architecture
|
||||||
|
Arch:
|
||||||
|
name: "DistillationModel"
|
||||||
|
infer_output_key: features
|
||||||
|
infer_add_softmax: False
|
||||||
|
is_rec: True
|
||||||
|
infer_model_name: "Student"
|
||||||
|
# if not null, its lengths should be same as models
|
||||||
|
pretrained_list:
|
||||||
|
# if not null, its lengths should be same as models
|
||||||
|
freeze_params_list:
|
||||||
|
- False
|
||||||
|
- False
|
||||||
|
models:
|
||||||
|
- Teacher:
|
||||||
|
name: RecModel
|
||||||
|
infer_output_key: features
|
||||||
|
infer_add_softmax: False
|
||||||
|
Backbone:
|
||||||
|
name: PPLCNet_x2_5
|
||||||
|
pretrained: True
|
||||||
|
use_ssld: True
|
||||||
|
BackboneStopLayer:
|
||||||
|
name: "flatten"
|
||||||
|
Neck:
|
||||||
|
name: FC
|
||||||
|
embedding_size: 1280
|
||||||
|
class_num: 512
|
||||||
|
Head:
|
||||||
|
name: ArcMargin
|
||||||
|
embedding_size: 512
|
||||||
|
class_num: 185341
|
||||||
|
margin: 0.2
|
||||||
|
scale: 30
|
||||||
|
- Student:
|
||||||
|
name: RecModel
|
||||||
|
infer_output_key: features
|
||||||
|
infer_add_softmax: False
|
||||||
|
Backbone:
|
||||||
|
name: PPLCNet_x2_5
|
||||||
|
pretrained: True
|
||||||
|
use_ssld: True
|
||||||
|
BackboneStopLayer:
|
||||||
|
name: "flatten"
|
||||||
|
Neck:
|
||||||
|
name: FC
|
||||||
|
embedding_size: 1280
|
||||||
|
class_num: 512
|
||||||
|
Head:
|
||||||
|
name: ArcMargin
|
||||||
|
embedding_size: 512
|
||||||
|
class_num: 185341
|
||||||
|
margin: 0.2
|
||||||
|
scale: 30
|
||||||
|
|
||||||
|
# loss function config for traing/eval process
|
||||||
|
Loss:
|
||||||
|
Train:
|
||||||
|
- DistillationGTCELoss:
|
||||||
|
weight: 1.0
|
||||||
|
key: "logits"
|
||||||
|
model_names: ["Student", "Teacher"]
|
||||||
|
- DistillationDMLLoss:
|
||||||
|
weight: 1.0
|
||||||
|
key: "logits"
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
- DistillationDMLLoss:
|
||||||
|
weight: 1.0
|
||||||
|
key: "logits"
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
Eval:
|
||||||
|
- DistillationGTCELoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_names: ["Student"]
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Momentum
|
||||||
|
momentum: 0.9
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.02
|
||||||
|
warmup_epoch: 5
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
coeff: 0.00001
|
||||||
|
|
||||||
|
|
||||||
|
# data loader for train and eval
|
||||||
|
DataLoader:
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: ImageNetDataset
|
||||||
|
image_root: ./dataset/
|
||||||
|
cls_label_path: ./dataset/train_reg_all_data.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 64
|
||||||
|
drop_last: False
|
||||||
|
shuffle: True
|
||||||
|
loader:
|
||||||
|
num_workers: 4
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
Query:
|
||||||
|
dataset:
|
||||||
|
name: VeriWild
|
||||||
|
image_root: ./dataset/Aliproduct/
|
||||||
|
cls_label_path: ./dataset/Aliproduct/val_list.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 0.00392157
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 64
|
||||||
|
drop_last: False
|
||||||
|
shuffle: False
|
||||||
|
loader:
|
||||||
|
num_workers: 4
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Gallery:
|
||||||
|
dataset:
|
||||||
|
name: VeriWild
|
||||||
|
image_root: ./dataset/Aliproduct/
|
||||||
|
cls_label_path: ./dataset/Aliproduct/val_list.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 0.00392157
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 64
|
||||||
|
drop_last: False
|
||||||
|
shuffle: False
|
||||||
|
loader:
|
||||||
|
num_workers: 4
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
Eval:
|
||||||
|
- Recallk:
|
||||||
|
topk: [1, 5]
|
@ -0,0 +1,193 @@
|
|||||||
|
# global configs
|
||||||
|
Global:
|
||||||
|
checkpoints: null
|
||||||
|
pretrained_model: null
|
||||||
|
output_dir: ./output/
|
||||||
|
device: gpu
|
||||||
|
save_interval: 1
|
||||||
|
eval_during_train: true
|
||||||
|
eval_interval: 1
|
||||||
|
epochs: 100
|
||||||
|
print_batch_step: 10
|
||||||
|
use_visualdl: False
|
||||||
|
# used for static mode and model export
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
save_inference_dir: ./inference
|
||||||
|
eval_mode: retrieval
|
||||||
|
use_dali: False
|
||||||
|
to_static: False
|
||||||
|
|
||||||
|
# model architecture
|
||||||
|
Arch:
|
||||||
|
name: "DistillationModel"
|
||||||
|
infer_output_key: features
|
||||||
|
infer_add_softmax: False
|
||||||
|
is_rec: True
|
||||||
|
infer_model_name: "Student"
|
||||||
|
# if not null, its lengths should be same as models
|
||||||
|
pretrained_list:
|
||||||
|
# if not null, its lengths should be same as models
|
||||||
|
freeze_params_list:
|
||||||
|
- False
|
||||||
|
- False
|
||||||
|
models:
|
||||||
|
- Teacher:
|
||||||
|
name: RecModel
|
||||||
|
infer_output_key: features
|
||||||
|
infer_add_softmax: False
|
||||||
|
Backbone:
|
||||||
|
name: PPLCNet_x2_5
|
||||||
|
pretrained: True
|
||||||
|
use_ssld: True
|
||||||
|
BackboneStopLayer:
|
||||||
|
name: "flatten"
|
||||||
|
Neck:
|
||||||
|
name: FC
|
||||||
|
embedding_size: 1280
|
||||||
|
class_num: 512
|
||||||
|
Head:
|
||||||
|
name: ArcMargin
|
||||||
|
embedding_size: 512
|
||||||
|
class_num: 185341
|
||||||
|
margin: 0.2
|
||||||
|
scale: 30
|
||||||
|
- Student:
|
||||||
|
name: RecModel
|
||||||
|
infer_output_key: features
|
||||||
|
infer_add_softmax: False
|
||||||
|
Backbone:
|
||||||
|
name: PPLCNet_x2_5
|
||||||
|
pretrained: True
|
||||||
|
use_ssld: True
|
||||||
|
BackboneStopLayer:
|
||||||
|
name: "flatten"
|
||||||
|
Neck:
|
||||||
|
name: FC
|
||||||
|
embedding_size: 1280
|
||||||
|
class_num: 512
|
||||||
|
Head:
|
||||||
|
name: ArcMargin
|
||||||
|
embedding_size: 512
|
||||||
|
class_num: 185341
|
||||||
|
margin: 0.2
|
||||||
|
scale: 30
|
||||||
|
|
||||||
|
# loss function config for traing/eval process
|
||||||
|
Loss:
|
||||||
|
Train:
|
||||||
|
- DistillationGTCELoss:
|
||||||
|
weight: 1.0
|
||||||
|
key: "logits"
|
||||||
|
model_names: ["Student", "Teacher"]
|
||||||
|
- DistillationDMLLoss:
|
||||||
|
weight: 1.0
|
||||||
|
key: "logits"
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
- DistillationDistanceLoss:
|
||||||
|
weight: 1.0
|
||||||
|
key: "backbone"
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
Eval:
|
||||||
|
- DistillationGTCELoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_names: ["Student"]
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Momentum
|
||||||
|
momentum: 0.9
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.02
|
||||||
|
warmup_epoch: 5
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
coeff: 0.00001
|
||||||
|
|
||||||
|
|
||||||
|
# data loader for train and eval
|
||||||
|
DataLoader:
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: ImageNetDataset
|
||||||
|
image_root: ./dataset/
|
||||||
|
cls_label_path: ./dataset/train_reg_all_data.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 64
|
||||||
|
drop_last: False
|
||||||
|
shuffle: True
|
||||||
|
loader:
|
||||||
|
num_workers: 4
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
Query:
|
||||||
|
dataset:
|
||||||
|
name: VeriWild
|
||||||
|
image_root: ./dataset/Aliproduct/
|
||||||
|
cls_label_path: ./dataset/Aliproduct/val_list.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 0.00392157
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 64
|
||||||
|
drop_last: False
|
||||||
|
shuffle: False
|
||||||
|
loader:
|
||||||
|
num_workers: 4
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Gallery:
|
||||||
|
dataset:
|
||||||
|
name: VeriWild
|
||||||
|
image_root: ./dataset/Aliproduct/
|
||||||
|
cls_label_path: ./dataset/Aliproduct/val_list.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 0.00392157
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 64
|
||||||
|
drop_last: False
|
||||||
|
shuffle: False
|
||||||
|
loader:
|
||||||
|
num_workers: 4
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
Eval:
|
||||||
|
- Recallk:
|
||||||
|
topk: [1, 5]
|
@ -13,6 +13,7 @@ Global:
|
|||||||
# used for static mode and model export
|
# used for static mode and model export
|
||||||
image_shape: [3, 224, 224]
|
image_shape: [3, 224, 224]
|
||||||
save_inference_dir: "./inference"
|
save_inference_dir: "./inference"
|
||||||
|
use_dali: false
|
||||||
|
|
||||||
# model architecture
|
# model architecture
|
||||||
Arch:
|
Arch:
|
||||||
@ -29,9 +30,11 @@ Arch:
|
|||||||
name: MobileNetV3_large_x1_0
|
name: MobileNetV3_large_x1_0
|
||||||
pretrained: True
|
pretrained: True
|
||||||
use_ssld: True
|
use_ssld: True
|
||||||
|
dropout_prob: null
|
||||||
- Student:
|
- Student:
|
||||||
name: MobileNetV3_small_x1_0
|
name: MobileNetV3_small_x1_0
|
||||||
pretrained: False
|
pretrained: False
|
||||||
|
dropout_prob: null
|
||||||
|
|
||||||
infer_model_name: "Student"
|
infer_model_name: "Student"
|
||||||
|
|
||||||
@ -76,7 +79,6 @@ DataLoader:
|
|||||||
size: 224
|
size: 224
|
||||||
- RandFlipImage:
|
- RandFlipImage:
|
||||||
flip_code: 1
|
flip_code: 1
|
||||||
- AutoAugment:
|
|
||||||
- NormalizeImage:
|
- NormalizeImage:
|
||||||
scale: 0.00392157
|
scale: 0.00392157
|
||||||
mean: [0.485, 0.456, 0.406]
|
mean: [0.485, 0.456, 0.406]
|
||||||
@ -85,7 +87,7 @@ DataLoader:
|
|||||||
|
|
||||||
sampler:
|
sampler:
|
||||||
name: DistributedBatchSampler
|
name: DistributedBatchSampler
|
||||||
batch_size: 512
|
batch_size: 256
|
||||||
drop_last: False
|
drop_last: False
|
||||||
shuffle: True
|
shuffle: True
|
||||||
loader:
|
loader:
|
||||||
@ -112,7 +114,7 @@ DataLoader:
|
|||||||
order: ''
|
order: ''
|
||||||
sampler:
|
sampler:
|
||||||
name: DistributedBatchSampler
|
name: DistributedBatchSampler
|
||||||
batch_size: 64
|
batch_size: 128
|
||||||
drop_last: False
|
drop_last: False
|
||||||
shuffle: False
|
shuffle: False
|
||||||
loader:
|
loader:
|
||||||
|
@ -53,7 +53,8 @@ class Engine(object):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.eval_mode = self.config["Global"].get("eval_mode",
|
self.eval_mode = self.config["Global"].get("eval_mode",
|
||||||
"classification")
|
"classification")
|
||||||
if "Head" in self.config["Arch"]:
|
if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
|
||||||
|
False):
|
||||||
self.is_rec = True
|
self.is_rec = True
|
||||||
else:
|
else:
|
||||||
self.is_rec = False
|
self.is_rec = False
|
||||||
@ -357,7 +358,9 @@ class Engine(object):
|
|||||||
out = self.model(batch_tensor)
|
out = self.model(batch_tensor)
|
||||||
if isinstance(out, list):
|
if isinstance(out, list):
|
||||||
out = out[0]
|
out = out[0]
|
||||||
if isinstance(out, dict):
|
if isinstance(out, dict) and "logits" in out:
|
||||||
|
out = out["logits"]
|
||||||
|
if isinstance(out, dict) and "output" in out:
|
||||||
out = out["output"]
|
out = out["output"]
|
||||||
result = self.postprocess_func(out, image_file_list)
|
result = self.postprocess_func(out, image_file_list)
|
||||||
print(result)
|
print(result)
|
||||||
|
@ -78,10 +78,10 @@ def classification_eval(engine, epoch_id=0):
|
|||||||
labels = paddle.concat(label_list, 0)
|
labels = paddle.concat(label_list, 0)
|
||||||
|
|
||||||
if isinstance(out, dict):
|
if isinstance(out, dict):
|
||||||
if "logits" in out:
|
if "Student" in out:
|
||||||
out = out["logits"]
|
|
||||||
elif "Student" in out:
|
|
||||||
out = out["Student"]
|
out = out["Student"]
|
||||||
|
elif "logits" in out:
|
||||||
|
out = out["logits"]
|
||||||
else:
|
else:
|
||||||
msg = "Error: Wrong key in out!"
|
msg = "Error: Wrong key in out!"
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
@ -106,6 +106,7 @@ def classification_eval(engine, epoch_id=0):
|
|||||||
metric_dict = engine.eval_metric_func(pred, labels)
|
metric_dict = engine.eval_metric_func(pred, labels)
|
||||||
else:
|
else:
|
||||||
metric_dict = engine.eval_metric_func(out, batch[1])
|
metric_dict = engine.eval_metric_func(out, batch[1])
|
||||||
|
|
||||||
for key in metric_dict:
|
for key in metric_dict:
|
||||||
if metric_key is None:
|
if metric_key is None:
|
||||||
metric_key = key
|
metric_key = key
|
||||||
|
@ -123,6 +123,8 @@ def cal_feature(engine, name='gallery'):
|
|||||||
has_unique_id = True
|
has_unique_id = True
|
||||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||||
out = engine.model(batch[0], batch[1])
|
out = engine.model(batch[0], batch[1])
|
||||||
|
if "Student" in out:
|
||||||
|
out = out["Student"]
|
||||||
batch_feas = out["features"]
|
batch_feas = out["features"]
|
||||||
|
|
||||||
# do norm
|
# do norm
|
||||||
|
@ -20,6 +20,8 @@ from .distanceloss import DistanceLoss
|
|||||||
from .distillationloss import DistillationCELoss
|
from .distillationloss import DistillationCELoss
|
||||||
from .distillationloss import DistillationGTCELoss
|
from .distillationloss import DistillationGTCELoss
|
||||||
from .distillationloss import DistillationDMLLoss
|
from .distillationloss import DistillationDMLLoss
|
||||||
|
from .distillationloss import DistillationDistanceLoss
|
||||||
|
from .distillationloss import DistillationRKDLoss
|
||||||
from .multilabelloss import MultiLabelLoss
|
from .multilabelloss import MultiLabelLoss
|
||||||
|
|
||||||
from .deephashloss import DSHSDLoss, LCDSHLoss
|
from .deephashloss import DSHSDLoss, LCDSHLoss
|
||||||
|
@ -18,6 +18,7 @@ import paddle.nn as nn
|
|||||||
from .celoss import CELoss
|
from .celoss import CELoss
|
||||||
from .dmlloss import DMLLoss
|
from .dmlloss import DMLLoss
|
||||||
from .distanceloss import DistanceLoss
|
from .distanceloss import DistanceLoss
|
||||||
|
from .rkdloss import RKdAngle, RkdDistance
|
||||||
|
|
||||||
|
|
||||||
class DistillationCELoss(CELoss):
|
class DistillationCELoss(CELoss):
|
||||||
@ -68,7 +69,7 @@ class DistillationGTCELoss(CELoss):
|
|||||||
|
|
||||||
def forward(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
for idx, name in enumerate(self.model_names):
|
for _, name in enumerate(self.model_names):
|
||||||
out = predicts[name]
|
out = predicts[name]
|
||||||
if self.key is not None:
|
if self.key is not None:
|
||||||
out = out[self.key]
|
out = out[self.key]
|
||||||
@ -84,7 +85,7 @@ class DistillationDMLLoss(DMLLoss):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_name_pairs=[],
|
model_name_pairs=[],
|
||||||
act=None,
|
act="softmax",
|
||||||
key=None,
|
key=None,
|
||||||
name="loss_dml"):
|
name="loss_dml"):
|
||||||
super().__init__(act=act)
|
super().__init__(act=act)
|
||||||
@ -125,7 +126,7 @@ class DistillationDistanceLoss(DistanceLoss):
|
|||||||
assert isinstance(model_name_pairs, list)
|
assert isinstance(model_name_pairs, list)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_name_pairs = model_name_pairs
|
self.model_name_pairs = model_name_pairs
|
||||||
self.name = name + "_l2"
|
self.name = name + mode
|
||||||
|
|
||||||
def forward(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
@ -139,3 +140,35 @@ class DistillationDistanceLoss(DistanceLoss):
|
|||||||
for key in loss:
|
for key in loss:
|
||||||
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[key]
|
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[key]
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationRKDLoss(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
target_size=None,
|
||||||
|
model_name_pairs=(["Student", "Teacher"], ),
|
||||||
|
student_keepkeys=[],
|
||||||
|
teacher_keepkeys=[]):
|
||||||
|
super().__init__()
|
||||||
|
self.student_keepkeys = student_keepkeys
|
||||||
|
self.teacher_keepkeys = teacher_keepkeys
|
||||||
|
self.model_name_pairs = model_name_pairs
|
||||||
|
assert len(self.student_keepkeys) == len(self.teacher_keepkeys)
|
||||||
|
|
||||||
|
self.rkd_angle_loss = RKdAngle(target_size=target_size)
|
||||||
|
self.rkd_dist_loss = RkdDistance(target_size=target_size)
|
||||||
|
|
||||||
|
def __call__(self, predicts, batch):
|
||||||
|
loss_dict = {}
|
||||||
|
for m1, m2 in self.model_name_pairs:
|
||||||
|
for idx, (
|
||||||
|
student_name, teacher_name
|
||||||
|
) in enumerate(zip(self.student_keepkeys, self.teacher_keepkeys)):
|
||||||
|
student_out = predicts[m1][student_name]
|
||||||
|
teacher_out = predicts[m2][teacher_name]
|
||||||
|
|
||||||
|
loss_dict[f"loss_angle_{idx}_{m1}_{m2}"] = self.rkd_angle_loss(
|
||||||
|
student_out, teacher_out)
|
||||||
|
loss_dict[f"loss_dist_{idx}_{m1}_{m2}"] = self.rkd_dist_loss(
|
||||||
|
student_out, teacher_out)
|
||||||
|
|
||||||
|
return loss_dict
|
||||||
|
@ -22,7 +22,7 @@ class DMLLoss(nn.Layer):
|
|||||||
DMLLoss
|
DMLLoss
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, act="softmax"):
|
def __init__(self, act="softmax", eps=1e-12):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if act is not None:
|
if act is not None:
|
||||||
assert act in ["softmax", "sigmoid"]
|
assert act in ["softmax", "sigmoid"]
|
||||||
@ -32,15 +32,19 @@ class DMLLoss(nn.Layer):
|
|||||||
self.act = nn.Sigmoid()
|
self.act = nn.Sigmoid()
|
||||||
else:
|
else:
|
||||||
self.act = None
|
self.act = None
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
def forward(self, out1, out2):
|
def _kldiv(self, x, target):
|
||||||
|
class_num = x.shape[-1]
|
||||||
|
cost = target * paddle.log(
|
||||||
|
(target + self.eps) / (x + self.eps)) * class_num
|
||||||
|
return cost
|
||||||
|
|
||||||
|
def forward(self, x, target):
|
||||||
if self.act is not None:
|
if self.act is not None:
|
||||||
out1 = self.act(out1)
|
x = F.softmax(x)
|
||||||
out2 = self.act(out2)
|
target = F.softmax(target)
|
||||||
|
loss = self._kldiv(x, target) + self._kldiv(target, x)
|
||||||
log_out1 = paddle.log(out1)
|
loss = loss / 2
|
||||||
log_out2 = paddle.log(out2)
|
loss = paddle.mean(loss)
|
||||||
loss = (F.kl_div(
|
|
||||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
|
||||||
log_out2, out1, reduction='batchmean')) / 2.0
|
|
||||||
return {"DMLLoss": loss}
|
return {"DMLLoss": loss}
|
||||||
|
97
ppcls/loss/rkdloss.py
Normal file
97
ppcls/loss/rkdloss.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def pdist(e, squared=False, eps=1e-12):
|
||||||
|
e_square = e.pow(2).sum(axis=1)
|
||||||
|
prod = paddle.mm(e, e.t())
|
||||||
|
res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clip(
|
||||||
|
min=eps)
|
||||||
|
|
||||||
|
if not squared:
|
||||||
|
res = res.sqrt()
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class RKdAngle(nn.Layer):
|
||||||
|
# reference: https://github.com/lenscloth/RKD/blob/master/metric/loss.py
|
||||||
|
def __init__(self, target_size=None):
|
||||||
|
super().__init__()
|
||||||
|
if target_size is not None:
|
||||||
|
self.avgpool = paddle.nn.AdaptiveAvgPool2D(target_size)
|
||||||
|
else:
|
||||||
|
self.avgpool = None
|
||||||
|
|
||||||
|
def forward(self, student, teacher):
|
||||||
|
# GAP to reduce memory
|
||||||
|
if self.avgpool is not None:
|
||||||
|
# NxC1xH1xW1 -> NxC1x1x1
|
||||||
|
student = self.avgpool(student)
|
||||||
|
# NxC2xH2xW2 -> NxC2x1x1
|
||||||
|
teacher = self.avgpool(teacher)
|
||||||
|
|
||||||
|
# reshape for feature map distillation
|
||||||
|
bs = student.shape[0]
|
||||||
|
student = student.reshape([bs, -1])
|
||||||
|
teacher = teacher.reshape([bs, -1])
|
||||||
|
|
||||||
|
td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
|
||||||
|
norm_td = F.normalize(td, p=2, axis=2)
|
||||||
|
t_angle = paddle.bmm(norm_td, norm_td.transpose([0, 2, 1])).reshape(
|
||||||
|
[-1, 1])
|
||||||
|
|
||||||
|
sd = (student.unsqueeze(0) - student.unsqueeze(1))
|
||||||
|
norm_sd = F.normalize(sd, p=2, axis=2)
|
||||||
|
s_angle = paddle.bmm(norm_sd, norm_sd.transpose([0, 2, 1])).reshape(
|
||||||
|
[-1, 1])
|
||||||
|
loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean')
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class RkdDistance(nn.Layer):
|
||||||
|
# reference: https://github.com/lenscloth/RKD/blob/master/metric/loss.py
|
||||||
|
def __init__(self, eps=1e-12, target_size=1):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
if target_size is not None:
|
||||||
|
self.avgpool = paddle.nn.AdaptiveAvgPool2D(target_size)
|
||||||
|
else:
|
||||||
|
self.avgpool = None
|
||||||
|
|
||||||
|
def forward(self, student, teacher):
|
||||||
|
# GAP to reduce memory
|
||||||
|
if self.avgpool is not None:
|
||||||
|
# NxC1xH1xW1 -> NxC1x1x1
|
||||||
|
student = self.avgpool(student)
|
||||||
|
# NxC2xH2xW2 -> NxC2x1x1
|
||||||
|
teacher = self.avgpool(teacher)
|
||||||
|
|
||||||
|
bs = student.shape[0]
|
||||||
|
student = student.reshape([bs, -1])
|
||||||
|
teacher = teacher.reshape([bs, -1])
|
||||||
|
|
||||||
|
t_d = pdist(teacher, squared=False)
|
||||||
|
mean_td = t_d.mean()
|
||||||
|
t_d = t_d / (mean_td + self.eps)
|
||||||
|
|
||||||
|
d = pdist(student, squared=False)
|
||||||
|
mean_d = d.mean()
|
||||||
|
d = d / (mean_d + self.eps)
|
||||||
|
|
||||||
|
loss = F.smooth_l1_loss(d, t_d, reduction="mean")
|
||||||
|
return loss
|
Loading…
x
Reference in New Issue
Block a user