fix: fix bug about calc loss in dist
parent
9d2bcbd703
commit
c46189bad0
|
@ -66,68 +66,70 @@ def classification_eval(engine, epoch_id=0):
|
||||||
},
|
},
|
||||||
level=amp_level):
|
level=amp_level):
|
||||||
out = engine.model(batch[0])
|
out = engine.model(batch[0])
|
||||||
# calc loss
|
|
||||||
if engine.eval_loss_func is not None:
|
|
||||||
loss_dict = engine.eval_loss_func(out, batch[1])
|
|
||||||
for key in loss_dict:
|
|
||||||
if key not in output_info:
|
|
||||||
output_info[key] = AverageMeter(key, '7.5f')
|
|
||||||
output_info[key].update(loss_dict[key].numpy()[0],
|
|
||||||
batch_size)
|
|
||||||
else:
|
else:
|
||||||
out = engine.model(batch[0])
|
out = engine.model(batch[0])
|
||||||
# calc loss
|
|
||||||
if engine.eval_loss_func is not None:
|
|
||||||
loss_dict = engine.eval_loss_func(out, batch[1])
|
|
||||||
for key in loss_dict:
|
|
||||||
if key not in output_info:
|
|
||||||
output_info[key] = AverageMeter(key, '7.5f')
|
|
||||||
output_info[key].update(loss_dict[key].numpy()[0],
|
|
||||||
batch_size)
|
|
||||||
|
|
||||||
# just for DistributedBatchSampler issue: repeat sampling
|
# just for DistributedBatchSampler issue: repeat sampling
|
||||||
current_samples = batch_size * paddle.distributed.get_world_size()
|
current_samples = batch_size * paddle.distributed.get_world_size()
|
||||||
accum_samples += current_samples
|
accum_samples += current_samples
|
||||||
|
|
||||||
# calc metric
|
# gather Tensor when distributed
|
||||||
if engine.eval_metric_func is not None:
|
if paddle.distributed.get_world_size() > 1:
|
||||||
if paddle.distributed.get_world_size() > 1:
|
label_list = []
|
||||||
label_list = []
|
paddle.distributed.all_gather(label_list, batch[1])
|
||||||
paddle.distributed.all_gather(label_list, batch[1])
|
labels = paddle.concat(label_list, 0)
|
||||||
labels = paddle.concat(label_list, 0)
|
|
||||||
|
|
||||||
if isinstance(out, dict):
|
if isinstance(out, dict):
|
||||||
if "Student" in out:
|
if "Student" in out:
|
||||||
out = out["Student"]
|
out = out["Student"]
|
||||||
if isinstance(out, dict):
|
if isinstance(out, dict):
|
||||||
out = out["logits"]
|
|
||||||
elif "logits" in out:
|
|
||||||
out = out["logits"]
|
out = out["logits"]
|
||||||
else:
|
elif "logits" in out:
|
||||||
msg = "Error: Wrong key in out!"
|
out = out["logits"]
|
||||||
raise Exception(msg)
|
|
||||||
if isinstance(out, list):
|
|
||||||
pred = []
|
|
||||||
for x in out:
|
|
||||||
pred_list = []
|
|
||||||
paddle.distributed.all_gather(pred_list, x)
|
|
||||||
pred_x = paddle.concat(pred_list, 0)
|
|
||||||
pred.append(pred_x)
|
|
||||||
else:
|
else:
|
||||||
|
msg = "Error: Wrong key in out!"
|
||||||
|
raise Exception(msg)
|
||||||
|
if isinstance(out, list):
|
||||||
|
preds = []
|
||||||
|
for x in out:
|
||||||
pred_list = []
|
pred_list = []
|
||||||
paddle.distributed.all_gather(pred_list, out)
|
paddle.distributed.all_gather(pred_list, x)
|
||||||
pred = paddle.concat(pred_list, 0)
|
pred_x = paddle.concat(pred_list, 0)
|
||||||
|
preds.append(pred_x)
|
||||||
if accum_samples > total_samples and not engine.use_dali:
|
|
||||||
pred = pred[:total_samples + current_samples -
|
|
||||||
accum_samples]
|
|
||||||
labels = labels[:total_samples + current_samples -
|
|
||||||
accum_samples]
|
|
||||||
current_samples = total_samples + current_samples - accum_samples
|
|
||||||
metric_dict = engine.eval_metric_func(pred, labels)
|
|
||||||
else:
|
else:
|
||||||
metric_dict = engine.eval_metric_func(out, batch[1])
|
pred_list = []
|
||||||
|
paddle.distributed.all_gather(pred_list, out)
|
||||||
|
preds = paddle.concat(pred_list, 0)
|
||||||
|
|
||||||
|
if accum_samples > total_samples and not engine.use_dali:
|
||||||
|
preds = preds[:total_samples + current_samples - accum_samples]
|
||||||
|
labels = labels[:total_samples + current_samples -
|
||||||
|
accum_samples]
|
||||||
|
current_samples = total_samples + current_samples - accum_samples
|
||||||
|
else:
|
||||||
|
labels = batch[1]
|
||||||
|
preds = out
|
||||||
|
|
||||||
|
# calc loss
|
||||||
|
if engine.eval_loss_func is not None:
|
||||||
|
if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
|
||||||
|
amp_level = engine.config['AMP'].get("level", "O1").upper()
|
||||||
|
with paddle.amp.auto_cast(
|
||||||
|
custom_black_list={
|
||||||
|
"flatten_contiguous_range", "greater_than"
|
||||||
|
},
|
||||||
|
level=amp_level):
|
||||||
|
loss_dict = engine.eval_loss_func(preds, labels)
|
||||||
|
else:
|
||||||
|
loss_dict = engine.eval_loss_func(preds, labels)
|
||||||
|
|
||||||
|
for key in loss_dict:
|
||||||
|
if key not in output_info:
|
||||||
|
output_info[key] = AverageMeter(key, '7.5f')
|
||||||
|
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
|
||||||
|
# calc metric
|
||||||
|
if engine.eval_metric_func is not None:
|
||||||
|
metric_dict = engine.eval_metric_func(preds, labels)
|
||||||
for key in metric_dict:
|
for key in metric_dict:
|
||||||
if metric_key is None:
|
if metric_key is None:
|
||||||
metric_key = key
|
metric_key = key
|
||||||
|
|
Loading…
Reference in New Issue