fix: fix key error in distillation

This commit is contained in:
gaotingquan 2021-12-15 07:53:09 +00:00 committed by Tingquan Gao
parent 26f329c20a
commit a9a2e2d372
2 changed files with 9 additions and 2 deletions

View File

@ -78,7 +78,13 @@ def classification_eval(engine, epoch_id=0):
labels = paddle.concat(label_list, 0)
if isinstance(out, dict):
if "logits" in out:
out = out["logits"]
elif "Student" in out:
out = out["Student"]
else:
msg = "Error: Wrong key in out!"
raise Exception(msg)
if isinstance(out, list):
pred = []
for x in out:

View File

@ -222,6 +222,7 @@ class DistillationTopkAcc(TopkAcc):
self.feature_key = feature_key
def forward(self, x, label):
if isinstance(x, dict):
x = x[self.model_key]
if self.feature_key is not None:
x = x[self.feature_key]