diff --git a/projects/FastAttr/fastattr/attr_evaluation.py b/projects/FastAttr/fastattr/attr_evaluation.py index 1725818..20dad10 100644 --- a/projects/FastAttr/fastattr/attr_evaluation.py +++ b/projects/FastAttr/fastattr/attr_evaluation.py @@ -22,6 +22,8 @@ class AttrEvaluator(DatasetEvaluator): self.thres = thres self._output_dir = output_dir + self._cpu_device = torch.device("cpu") + self.pred_logits = [] self.gt_labels = [] @@ -30,8 +32,8 @@ class AttrEvaluator(DatasetEvaluator): self.gt_labels = [] def process(self, inputs, outputs): - self.gt_labels.extend(inputs["targets"]) - self.pred_logits.extend(outputs.cpu()) + self.gt_labels.extend(inputs["targets"].to(self._cpu_device)) + self.pred_logits.extend(outputs.to(self._cpu_device, torch.float32)) @staticmethod def get_attr_metrics(gt_labels, pred_logits, thres):