mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix: fix key error in distillation
This commit is contained in:
parent
26f329c20a
commit
a9a2e2d372
@ -78,7 +78,13 @@ def classification_eval(engine, epoch_id=0):
|
|||||||
labels = paddle.concat(label_list, 0)
|
labels = paddle.concat(label_list, 0)
|
||||||
|
|
||||||
if isinstance(out, dict):
|
if isinstance(out, dict):
|
||||||
|
if "logits" in out:
|
||||||
out = out["logits"]
|
out = out["logits"]
|
||||||
|
elif "Student" in out:
|
||||||
|
out = out["Student"]
|
||||||
|
else:
|
||||||
|
msg = "Error: Wrong key in out!"
|
||||||
|
raise Exception(msg)
|
||||||
if isinstance(out, list):
|
if isinstance(out, list):
|
||||||
pred = []
|
pred = []
|
||||||
for x in out:
|
for x in out:
|
||||||
|
@ -222,6 +222,7 @@ class DistillationTopkAcc(TopkAcc):
|
|||||||
self.feature_key = feature_key
|
self.feature_key = feature_key
|
||||||
|
|
||||||
def forward(self, x, label):
|
def forward(self, x, label):
|
||||||
|
if isinstance(x, dict):
|
||||||
x = x[self.model_key]
|
x = x[self.model_key]
|
||||||
if self.feature_key is not None:
|
if self.feature_key is not None:
|
||||||
x = x[self.feature_key]
|
x = x[self.feature_key]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user