Fix distributed evaluation bug in classification eval
parent
14c43838a6
commit
80209e0218
|
@ -34,6 +34,10 @@ def classification_eval(engine, epoch_id=0):
|
|||
|
||||
metric_key = None
|
||||
tic = time.time()
|
||||
accum_samples = 0
|
||||
total_samples = len(
|
||||
engine.eval_dataloader.
|
||||
dataset) if not engine.use_dali else engine.eval_dataloader.size
|
||||
max_iter = len(engine.eval_dataloader) - 1 if platform.system(
|
||||
) == "Windows" else len(engine.eval_dataloader)
|
||||
for iter_id, batch in enumerate(engine.eval_dataloader):
|
||||
|
@ -61,15 +65,37 @@ def classification_eval(engine, epoch_id=0):
|
|||
if key not in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
|
||||
|
||||
# DistributedBatchSampler may repeat samples; count the samples seen across all ranks so the surplus can be trimmed before computing metrics
|
||||
current_samples = batch_size * paddle.distributed.get_world_size()
|
||||
accum_samples += current_samples
|
||||
|
||||
# calc metric
|
||||
if engine.eval_metric_func is not None:
|
||||
metric_dict = engine.eval_metric_func(out, batch[1])
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
pred_list = []
|
||||
label_list = []
|
||||
if isinstance(out, dict):
|
||||
out = out["logits"]
|
||||
paddle.distributed.all_gather(pred_list, out)
|
||||
paddle.distributed.all_gather(label_list, batch[1])
|
||||
pred = paddle.concat(pred_list, 0)
|
||||
labels = paddle.concat(label_list, 0)
|
||||
if accum_samples > total_samples:
|
||||
pred = pred[:total_samples + current_samples -
|
||||
accum_samples]
|
||||
labels = labels[:total_samples + current_samples -
|
||||
accum_samples]
|
||||
current_samples = total_samples + current_samples - accum_samples
|
||||
metric_dict = engine.eval_metric_func(pred, labels)
|
||||
|
||||
for key in metric_dict:
|
||||
paddle.distributed.all_reduce(
|
||||
metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
|
||||
metric_dict[key] = metric_dict[
|
||||
key] / paddle.distributed.get_world_size()
|
||||
else:
|
||||
metric_dict = engine.eval_metric_func(out, batch[1])
|
||||
for key in metric_dict:
|
||||
if metric_key is None:
|
||||
metric_key = key
|
||||
|
|
Loading…
Reference in New Issue