fix bug in attribute evaluation

# 535
pull/566/head
Sherlock 2021-08-12 14:17:44 +08:00
parent 7e652fea2a
commit ced654431b
1 changed files with 4 additions and 2 deletions

View File

@ -22,6 +22,8 @@ class AttrEvaluator(DatasetEvaluator):
self.thres = thres
self._output_dir = output_dir
self._cpu_device = torch.device("cpu")
self.pred_logits = []
self.gt_labels = []
@ -30,8 +32,8 @@ class AttrEvaluator(DatasetEvaluator):
self.gt_labels = []
def process(self, inputs, outputs):
self.gt_labels.extend(inputs["targets"])
self.pred_logits.extend(outputs.cpu())
self.gt_labels.extend(inputs["targets"].to(self._cpu_device))
self.pred_logits.extend(outputs.to(self._cpu_device, torch.float32))
@staticmethod
def get_attr_metrics(gt_labels, pred_logits, thres):