fix: fix bug about calc loss in dist
parent
9d2bcbd703
commit
c46189bad0
ppcls/engine/evaluation
|
@ -66,68 +66,70 @@ def classification_eval(engine, epoch_id=0):
|
|||
},
|
||||
level=amp_level):
|
||||
out = engine.model(batch[0])
|
||||
# calc loss
|
||||
if engine.eval_loss_func is not None:
|
||||
loss_dict = engine.eval_loss_func(out, batch[1])
|
||||
for key in loss_dict:
|
||||
if key not in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
output_info[key].update(loss_dict[key].numpy()[0],
|
||||
batch_size)
|
||||
else:
|
||||
out = engine.model(batch[0])
|
||||
# calc loss
|
||||
if engine.eval_loss_func is not None:
|
||||
loss_dict = engine.eval_loss_func(out, batch[1])
|
||||
for key in loss_dict:
|
||||
if key not in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
output_info[key].update(loss_dict[key].numpy()[0],
|
||||
batch_size)
|
||||
|
||||
# just for DistributedBatchSampler issue: repeat sampling
|
||||
current_samples = batch_size * paddle.distributed.get_world_size()
|
||||
accum_samples += current_samples
|
||||
|
||||
# calc metric
|
||||
if engine.eval_metric_func is not None:
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
label_list = []
|
||||
paddle.distributed.all_gather(label_list, batch[1])
|
||||
labels = paddle.concat(label_list, 0)
|
||||
# gather Tensor when distributed
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
label_list = []
|
||||
paddle.distributed.all_gather(label_list, batch[1])
|
||||
labels = paddle.concat(label_list, 0)
|
||||
|
||||
if isinstance(out, dict):
|
||||
if "Student" in out:
|
||||
out = out["Student"]
|
||||
if isinstance(out, dict):
|
||||
out = out["logits"]
|
||||
elif "logits" in out:
|
||||
if isinstance(out, dict):
|
||||
if "Student" in out:
|
||||
out = out["Student"]
|
||||
if isinstance(out, dict):
|
||||
out = out["logits"]
|
||||
else:
|
||||
msg = "Error: Wrong key in out!"
|
||||
raise Exception(msg)
|
||||
if isinstance(out, list):
|
||||
pred = []
|
||||
for x in out:
|
||||
pred_list = []
|
||||
paddle.distributed.all_gather(pred_list, x)
|
||||
pred_x = paddle.concat(pred_list, 0)
|
||||
pred.append(pred_x)
|
||||
elif "logits" in out:
|
||||
out = out["logits"]
|
||||
else:
|
||||
msg = "Error: Wrong key in out!"
|
||||
raise Exception(msg)
|
||||
if isinstance(out, list):
|
||||
preds = []
|
||||
for x in out:
|
||||
pred_list = []
|
||||
paddle.distributed.all_gather(pred_list, out)
|
||||
pred = paddle.concat(pred_list, 0)
|
||||
|
||||
if accum_samples > total_samples and not engine.use_dali:
|
||||
pred = pred[:total_samples + current_samples -
|
||||
accum_samples]
|
||||
labels = labels[:total_samples + current_samples -
|
||||
accum_samples]
|
||||
current_samples = total_samples + current_samples - accum_samples
|
||||
metric_dict = engine.eval_metric_func(pred, labels)
|
||||
paddle.distributed.all_gather(pred_list, x)
|
||||
pred_x = paddle.concat(pred_list, 0)
|
||||
preds.append(pred_x)
|
||||
else:
|
||||
metric_dict = engine.eval_metric_func(out, batch[1])
|
||||
pred_list = []
|
||||
paddle.distributed.all_gather(pred_list, out)
|
||||
preds = paddle.concat(pred_list, 0)
|
||||
|
||||
if accum_samples > total_samples and not engine.use_dali:
|
||||
preds = preds[:total_samples + current_samples - accum_samples]
|
||||
labels = labels[:total_samples + current_samples -
|
||||
accum_samples]
|
||||
current_samples = total_samples + current_samples - accum_samples
|
||||
else:
|
||||
labels = batch[1]
|
||||
preds = out
|
||||
|
||||
# calc loss
|
||||
if engine.eval_loss_func is not None:
|
||||
if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
|
||||
amp_level = engine.config['AMP'].get("level", "O1").upper()
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=amp_level):
|
||||
loss_dict = engine.eval_loss_func(preds, labels)
|
||||
else:
|
||||
loss_dict = engine.eval_loss_func(preds, labels)
|
||||
|
||||
for key in loss_dict:
|
||||
if key not in output_info:
|
||||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
|
||||
# calc metric
|
||||
if engine.eval_metric_func is not None:
|
||||
metric_dict = engine.eval_metric_func(preds, labels)
|
||||
for key in metric_dict:
|
||||
if metric_key is None:
|
||||
metric_key = key
|
||||
|
|
Loading…
Reference in New Issue