mirror of https://github.com/JDAI-CV/fast-reid.git
fix ClasEvalution too much memory cost
Compute total number of correct predictions on each batch avoiding keeping all predicted logits, which will cost too much memory when the number of classes is large #503pull/504/head
parent
8f8cbf9411
commit
de81b3dbaa
|
@ -46,11 +46,14 @@ class ClasEvaluator(DatasetEvaluator):
|
||||||
self._predictions = []
|
self._predictions = []
|
||||||
|
|
||||||
def process(self, inputs, outputs):
|
def process(self, inputs, outputs):
|
||||||
predictions = {
|
pred_logits = outputs.to(self._cpu_device, torch.float32)
|
||||||
"logits": outputs.to(self._cpu_device, torch.float32),
|
labels = inputs["targets"].to(self._cpu_device)
|
||||||
"labels": inputs["targets"].to(self._cpu_device),
|
|
||||||
}
|
# measure accuracy
|
||||||
self._predictions.append(predictions)
|
acc1, = accuracy(pred_logits, labels, topk=(1,))
|
||||||
|
num_correct_acc1 = acc1 * labels.size(0) / 100
|
||||||
|
|
||||||
|
self._predictions.append({"num_correct": num_correct_acc1, "num_samples": labels.size(0)})
|
||||||
|
|
||||||
def evaluate(self):
|
def evaluate(self):
|
||||||
if comm.get_world_size() > 1:
|
if comm.get_world_size() > 1:
|
||||||
|
@ -63,21 +66,16 @@ class ClasEvaluator(DatasetEvaluator):
|
||||||
else:
|
else:
|
||||||
predictions = self._predictions
|
predictions = self._predictions
|
||||||
|
|
||||||
pred_logits = []
|
total_correct_num = 0
|
||||||
labels = []
|
total_samples = 0
|
||||||
for prediction in predictions:
|
for prediction in predictions:
|
||||||
pred_logits.append(prediction['logits'])
|
total_correct_num += prediction["num_correct"]
|
||||||
labels.append(prediction['labels'])
|
total_samples += prediction["num_samples"]
|
||||||
|
|
||||||
pred_logits = torch.cat(pred_logits, dim=0)
|
acc1 = total_correct_num / total_samples * 100
|
||||||
labels = torch.cat(labels, dim=0)
|
|
||||||
|
|
||||||
# measure accuracy and record loss
|
|
||||||
acc1, = accuracy(pred_logits, labels, topk=(1,))
|
|
||||||
|
|
||||||
self._results = OrderedDict()
|
self._results = OrderedDict()
|
||||||
self._results["Acc@1"] = acc1
|
self._results["Acc@1"] = acc1
|
||||||
|
|
||||||
self._results["metric"] = acc1
|
self._results["metric"] = acc1
|
||||||
|
|
||||||
return copy.deepcopy(self._results)
|
return copy.deepcopy(self._results)
|
||||||
|
|
Loading…
Reference in New Issue