fix ClasEvaluator excessive memory cost

Compute the total number of correct predictions on each batch instead of keeping all predicted logits, which costs too much memory when the number of classes is large.

#503
pull/504/head
liaoxingyu 2021-06-07 15:48:47 +08:00
parent 8f8cbf9411
commit de81b3dbaa
1 changed file with 13 additions and 15 deletions

View File

@ -46,11 +46,14 @@ class ClasEvaluator(DatasetEvaluator):
self._predictions = [] self._predictions = []
def process(self, inputs, outputs):
    """Accumulate per-batch accuracy statistics.

    Rather than storing the full ``(batch, num_classes)`` logits tensor —
    which is memory-hungry when the class count is large — keep only the
    number of correct top-1 predictions and the batch size for each batch.

    Args:
        inputs: batch dict; must contain ``"targets"`` with ground-truth
            class labels, shape ``(batch,)``. (assumed from the indexing
            here — confirm against the dataloader)
        outputs: model logits for the batch, assumed shape
            ``(batch, num_classes)``.
    """
    pred_logits = outputs.to(self._cpu_device, torch.float32)
    labels = inputs["targets"].to(self._cpu_device)

    # Count exact top-1 matches directly. This avoids the lossy
    # float round-trip of computing a percentage with accuracy() and
    # converting it back to a count via `acc1 * batch_size / 100`.
    num_correct = (pred_logits.argmax(dim=1) == labels).sum()
    self._predictions.append({
        "num_correct": num_correct,
        "num_samples": labels.size(0),
    })
def evaluate(self): def evaluate(self):
if comm.get_world_size() > 1: if comm.get_world_size() > 1:
@ -63,21 +66,16 @@ class ClasEvaluator(DatasetEvaluator):
else: else:
predictions = self._predictions predictions = self._predictions
pred_logits = [] total_correct_num = 0
labels = [] total_samples = 0
for prediction in predictions: for prediction in predictions:
pred_logits.append(prediction['logits']) total_correct_num += prediction["num_correct"]
labels.append(prediction['labels']) total_samples += prediction["num_samples"]
pred_logits = torch.cat(pred_logits, dim=0) acc1 = total_correct_num / total_samples * 100
labels = torch.cat(labels, dim=0)
# measure accuracy and record loss
acc1, = accuracy(pred_logits, labels, topk=(1,))
self._results = OrderedDict() self._results = OrderedDict()
self._results["Acc@1"] = acc1 self._results["Acc@1"] = acc1
self._results["metric"] = acc1 self._results["metric"] = acc1
return copy.deepcopy(self._results) return copy.deepcopy(self._results)