fix: fix loss calculation bug in distributed evaluation

pull/1833/head
gaotingquan 2022-04-12 06:56:44 +00:00
parent 9d2bcbd703
commit c46189bad0
1 changed file with 51 additions and 49 deletions
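
The diff below shows the fix: previously the evaluation loss was computed per rank on the raw local batch `(out, batch[1])` before anything was gathered, while the metric was computed on the all-gathered, trimmed predictions, so under multi-GPU evaluation loss and metric could disagree whenever DistributedBatchSampler padded the last batch with repeated samples. After the change, predictions and labels are gathered and trimmed first, and both loss and metric are computed on the same tensors. A minimal sketch of that corrected flow, assuming a plain cross-entropy eval loss; the helper name eval_loss_on_gathered and the loss choice are illustrative, not PaddleClas API:

import paddle
import paddle.nn.functional as F

def eval_loss_on_gathered(out, label, total_samples, accum_samples,
                          current_samples):
    # Hypothetical helper mirroring the fixed flow: gather predictions and
    # labels from every rank, trim the samples that DistributedBatchSampler
    # repeated to pad the last batch, then compute the loss once on the
    # deduplicated tensors.
    if paddle.distributed.get_world_size() > 1:
        pred_list, label_list = [], []
        paddle.distributed.all_gather(pred_list, out)
        paddle.distributed.all_gather(label_list, label)
        preds = paddle.concat(pred_list, 0)
        labels = paddle.concat(label_list, 0)
        if accum_samples > total_samples:
            # drop samples repeated by DistributedBatchSampler padding
            keep = total_samples + current_samples - accum_samples
            preds, labels = preds[:keep], labels[:keep]
    else:
        preds, labels = out, label
    return F.cross_entropy(preds, labels)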
ppcls/engine/evaluation


@@ -66,68 +66,70 @@ def classification_eval(engine, epoch_id=0):
                     },
                     level=amp_level):
                 out = engine.model(batch[0])
-                # calc loss
-                if engine.eval_loss_func is not None:
-                    loss_dict = engine.eval_loss_func(out, batch[1])
-                    for key in loss_dict:
-                        if key not in output_info:
-                            output_info[key] = AverageMeter(key, '7.5f')
-                        output_info[key].update(loss_dict[key].numpy()[0],
-                                                batch_size)
         else:
             out = engine.model(batch[0])
-            # calc loss
-            if engine.eval_loss_func is not None:
-                loss_dict = engine.eval_loss_func(out, batch[1])
-                for key in loss_dict:
-                    if key not in output_info:
-                        output_info[key] = AverageMeter(key, '7.5f')
-                    output_info[key].update(loss_dict[key].numpy()[0],
-                                            batch_size)
 
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
         accum_samples += current_samples
 
-        # calc metric
-        if engine.eval_metric_func is not None:
-            if paddle.distributed.get_world_size() > 1:
-                label_list = []
-                paddle.distributed.all_gather(label_list, batch[1])
-                labels = paddle.concat(label_list, 0)
-                if isinstance(out, dict):
-                    if "Student" in out:
-                        out = out["Student"]
-                        if isinstance(out, dict):
-                            out = out["logits"]
-                    elif "logits" in out:
-                        out = out["logits"]
-                    else:
-                        msg = "Error: Wrong key in out!"
-                        raise Exception(msg)
-                if isinstance(out, list):
-                    pred = []
-                    for x in out:
-                        pred_list = []
-                        paddle.distributed.all_gather(pred_list, x)
-                        pred_x = paddle.concat(pred_list, 0)
-                        pred.append(pred_x)
-                else:
-                    pred_list = []
-                    paddle.distributed.all_gather(pred_list, out)
-                    pred = paddle.concat(pred_list, 0)
-                if accum_samples > total_samples and not engine.use_dali:
-                    pred = pred[:total_samples + current_samples -
-                                accum_samples]
-                    labels = labels[:total_samples + current_samples -
-                                    accum_samples]
-                    current_samples = total_samples + current_samples - accum_samples
-                metric_dict = engine.eval_metric_func(pred, labels)
-            else:
-                metric_dict = engine.eval_metric_func(out, batch[1])
+        # gather Tensor when distributed
+        if paddle.distributed.get_world_size() > 1:
+            label_list = []
+            paddle.distributed.all_gather(label_list, batch[1])
+            labels = paddle.concat(label_list, 0)
+            if isinstance(out, dict):
+                if "Student" in out:
+                    out = out["Student"]
+                    if isinstance(out, dict):
+                        out = out["logits"]
+                elif "logits" in out:
+                    out = out["logits"]
+                else:
+                    msg = "Error: Wrong key in out!"
+                    raise Exception(msg)
+            if isinstance(out, list):
+                preds = []
+                for x in out:
+                    pred_list = []
+                    paddle.distributed.all_gather(pred_list, x)
+                    pred_x = paddle.concat(pred_list, 0)
+                    preds.append(pred_x)
+            else:
+                pred_list = []
+                paddle.distributed.all_gather(pred_list, out)
+                preds = paddle.concat(pred_list, 0)
+            if accum_samples > total_samples and not engine.use_dali:
+                preds = preds[:total_samples + current_samples - accum_samples]
+                labels = labels[:total_samples + current_samples -
+                                accum_samples]
+                current_samples = total_samples + current_samples - accum_samples
+        else:
+            labels = batch[1]
+            preds = out
+
+        # calc loss
+        if engine.eval_loss_func is not None:
+            if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
+                amp_level = engine.config['AMP'].get("level", "O1").upper()
+                with paddle.amp.auto_cast(
+                        custom_black_list={
+                            "flatten_contiguous_range", "greater_than"
+                        },
+                        level=amp_level):
+                    loss_dict = engine.eval_loss_func(preds, labels)
+            else:
+                loss_dict = engine.eval_loss_func(preds, labels)
+            for key in loss_dict:
+                if key not in output_info:
+                    output_info[key] = AverageMeter(key, '7.5f')
+                output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+
+        # calc metric
+        if engine.eval_metric_func is not None:
+            metric_dict = engine.eval_metric_func(preds, labels)
         for key in metric_dict:
             if metric_key is None:
                 metric_key = key
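
The slicing arithmetic in the gather branch, preds[:total_samples + current_samples - accum_samples], is easier to see with concrete numbers. A worked example with made-up sizes (1000 eval samples, 4 ranks, batch size 32, so the sampler pads 1000 up to 1024):

total_samples = 1000           # real size of the eval set
batch_size, world_size = 32, 4
accum_samples = 0
for step in range(8):          # 8 steps * 128 = 1024 padded samples
    current_samples = batch_size * world_size
    accum_samples += current_samples
    if accum_samples > total_samples:
        # keep only the first (total + current - accum) gathered samples
        current_samples = total_samples + current_samples - accum_samples
    print(step, current_samples)  # steps 0-6 keep 128, step 7 keeps 104

Steps 0 through 6 contribute 7 * 128 = 896 samples; the trim leaves 104 of the final 128, for exactly 1000 in total, so the repeated padding samples never reach the loss or the metric.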