fix ClasEvalution too much memory cost

Compute total number of correct predictions on each batch avoiding keeping all predicted logits, which will cost too much memory when the number of classes is large

#503
pull/504/head
liaoxingyu 2021-06-07 15:48:47 +08:00
parent 8f8cbf9411
commit de81b3dbaa
1 changed files with 13 additions and 15 deletions

View File

@ -46,11 +46,14 @@ class ClasEvaluator(DatasetEvaluator):
self._predictions = []
def process(self, inputs, outputs):
predictions = {
"logits": outputs.to(self._cpu_device, torch.float32),
"labels": inputs["targets"].to(self._cpu_device),
}
self._predictions.append(predictions)
pred_logits = outputs.to(self._cpu_device, torch.float32)
labels = inputs["targets"].to(self._cpu_device)
# measure accuracy
acc1, = accuracy(pred_logits, labels, topk=(1,))
num_correct_acc1 = acc1 * labels.size(0) / 100
self._predictions.append({"num_correct": num_correct_acc1, "num_samples": labels.size(0)})
def evaluate(self):
if comm.get_world_size() > 1:
@ -63,21 +66,16 @@ class ClasEvaluator(DatasetEvaluator):
else:
predictions = self._predictions
pred_logits = []
labels = []
total_correct_num = 0
total_samples = 0
for prediction in predictions:
pred_logits.append(prediction['logits'])
labels.append(prediction['labels'])
total_correct_num += prediction["num_correct"]
total_samples += prediction["num_samples"]
pred_logits = torch.cat(pred_logits, dim=0)
labels = torch.cat(labels, dim=0)
# measure accuracy and record loss
acc1, = accuracy(pred_logits, labels, topk=(1,))
acc1 = total_correct_num / total_samples * 100
self._results = OrderedDict()
self._results["Acc@1"] = acc1
self._results["metric"] = acc1
return copy.deepcopy(self._results)