mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix: fix key error in distillation
This commit is contained in:
parent
26f329c20a
commit
a9a2e2d372
@ -78,7 +78,13 @@ def classification_eval(engine, epoch_id=0):
|
||||
labels = paddle.concat(label_list, 0)
|
||||
|
||||
if isinstance(out, dict):
|
||||
if "logits" in out:
|
||||
out = out["logits"]
|
||||
elif "Student" in out:
|
||||
out = out["Student"]
|
||||
else:
|
||||
msg = "Error: Wrong key in out!"
|
||||
raise Exception(msg)
|
||||
if isinstance(out, list):
|
||||
pred = []
|
||||
for x in out:
|
||||
|
@ -222,6 +222,7 @@ class DistillationTopkAcc(TopkAcc):
|
||||
self.feature_key = feature_key
|
||||
|
||||
def forward(self, x, label):
|
||||
if isinstance(x, dict):
|
||||
x = x[self.model_key]
|
||||
if self.feature_key is not None:
|
||||
x = x[self.feature_key]
|
||||
|
Loading…
x
Reference in New Issue
Block a user