fix dist err (#1621)

* fix dist err

* fix init

* fix init
pull/1600/head^2
littletomatodonkey 2022-01-06 09:46:52 +08:00 committed by GitHub
parent aea712cc87
commit e0a6e5bf38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 6 additions and 15 deletions

View File

@@ -28,7 +28,6 @@ from ppcls.utils import logger
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.arch.slim import prune_model, quantize_model
__all__ = ["build_model", "RecModel", "DistillationModel"]
@@ -82,13 +81,11 @@ class RecModel(TheseusLayer):
out["backbone"] = x
if self.neck is not None:
x = self.neck(x)
out["neck"] = x
out["features"] = x
if self.head is not None:
y = self.head(x, label)
out["neck"] = x
else:
y = None
out["logits"] = y
out["logits"] = y
return out

View File

@@ -1,5 +1,4 @@
# global configs
# global configs
Global:
checkpoints: null
pretrained_model: null
@@ -85,11 +84,6 @@ Loss:
key: "logits"
model_name_pairs:
- ["Student", "Teacher"]
- DistillationDMLLoss:
weight: 1.0
key: "logits"
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- DistillationGTCELoss:
weight: 1.0

View File

@@ -57,7 +57,7 @@ Optimizer:
momentum: 0.9
lr:
name: Cosine
learning_rate: 1.3
learning_rate: 0.65
warmup_epoch: 5
regularizer:
name: 'L2'

View File

@@ -69,7 +69,7 @@ class DistillationGTCELoss(CELoss):
def forward(self, predicts, batch):
loss_dict = dict()
for _, name in enumerate(self.model_names):
for name in self.model_names:
out = predicts[name]
if self.key is not None:
out = out[self.key]

View File

@@ -42,8 +42,8 @@ class DMLLoss(nn.Layer):
def forward(self, x, target):
if self.act is not None:
x = F.softmax(x)
target = F.softmax(target)
x = self.act(x)
target = self.act(target)
loss = self._kldiv(x, target) + self._kldiv(target, x)
loss = loss / 2
loss = paddle.mean(loss)