mirror of https://github.com/JDAI-CV/fast-reid.git
parent
7e652fea2a
commit
ced654431b
|
@ -22,6 +22,8 @@ class AttrEvaluator(DatasetEvaluator):
|
|||
self.thres = thres
|
||||
self._output_dir = output_dir
|
||||
|
||||
self._cpu_device = torch.device("cpu")
|
||||
|
||||
self.pred_logits = []
|
||||
self.gt_labels = []
|
||||
|
||||
|
@ -30,8 +32,8 @@ class AttrEvaluator(DatasetEvaluator):
|
|||
self.gt_labels = []
|
||||
|
||||
def process(self, inputs, outputs):
|
||||
self.gt_labels.extend(inputs["targets"])
|
||||
self.pred_logits.extend(outputs.cpu())
|
||||
self.gt_labels.extend(inputs["targets"].to(self._cpu_device))
|
||||
self.pred_logits.extend(outputs.to(self._cpu_device, torch.float32))
|
||||
|
||||
@staticmethod
|
||||
def get_attr_metrics(gt_labels, pred_logits, thres):
|
||||
|
|
Loading…
Reference in New Issue