Merge pull request #1833 from TingquanGao/dev/fix_dist_loss

fix calc metric error and calc loss error in distributed.

commit 4e6c36e269

@@ -73,68 +73,71 @@ def classification_eval(engine, epoch_id=0):
                     },
                     level=amp_level):
                 out = engine.model(batch[0])
-                # calc loss
-                if engine.eval_loss_func is not None:
-                    loss_dict = engine.eval_loss_func(out, batch[1])
-                    for key in loss_dict:
-                        if key not in output_info:
-                            output_info[key] = AverageMeter(key, '7.5f')
-                        output_info[key].update(loss_dict[key].numpy()[0],
-                                                batch_size)
         else:
             out = engine.model(batch[0])
-            # calc loss
-            if engine.eval_loss_func is not None:
-                loss_dict = engine.eval_loss_func(out, batch[1])
-                for key in loss_dict:
-                    if key not in output_info:
-                        output_info[key] = AverageMeter(key, '7.5f')
-                    output_info[key].update(loss_dict[key].numpy()[0],
-                                            batch_size)
 
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
         accum_samples += current_samples
 
-        # calc metric
-        if engine.eval_metric_func is not None:
-            if paddle.distributed.get_world_size() > 1:
-                label_list = []
-                paddle.distributed.all_gather(label_list, batch[1])
-                labels = paddle.concat(label_list, 0)
+        # gather Tensor when distributed
+        if paddle.distributed.get_world_size() > 1:
+            label_list = []
+            paddle.distributed.all_gather(label_list, batch[1])
+            labels = paddle.concat(label_list, 0)
 
-                if isinstance(out, dict):
-                    if "Student" in out:
-                        out = out["Student"]
-                        if isinstance(out, dict):
-                            out = out["logits"]
-                    elif "logits" in out:
-                        out = out["logits"]
-                    else:
-                        msg = "Error: Wrong key in out!"
-                        raise Exception(msg)
-                if isinstance(out, list):
-                    pred = []
-                    for x in out:
-                        pred_list = []
-                        paddle.distributed.all_gather(pred_list, x)
-                        pred_x = paddle.concat(pred_list, 0)
-                        pred.append(pred_x)
-                else:
-                    pred_list = []
-                    paddle.distributed.all_gather(pred_list, out)
-                    pred = paddle.concat(pred_list, 0)
-
-                if accum_samples > total_samples and not engine.use_dali:
-                    pred = pred[:total_samples + current_samples -
-                                accum_samples]
-                    labels = labels[:total_samples + current_samples -
-                                    accum_samples]
-                    current_samples = total_samples + current_samples - accum_samples
-                metric_dict = engine.eval_metric_func(pred, labels)
-            else:
-                metric_dict = engine.eval_metric_func(out, batch[1])
+            if isinstance(out, dict):
+                if "Student" in out:
+                    out = out["Student"]
+                    if isinstance(out, dict):
+                        out = out["logits"]
+                elif "logits" in out:
+                    out = out["logits"]
+                else:
+                    msg = "Error: Wrong key in out!"
+                    raise Exception(msg)
+            if isinstance(out, list):
+                preds = []
+                for x in out:
+                    pred_list = []
+                    paddle.distributed.all_gather(pred_list, x)
+                    pred_x = paddle.concat(pred_list, 0)
+                    preds.append(pred_x)
+            else:
+                pred_list = []
+                paddle.distributed.all_gather(pred_list, out)
+                preds = paddle.concat(pred_list, 0)
+
+            if accum_samples > total_samples and not engine.use_dali:
+                preds = preds[:total_samples + current_samples - accum_samples]
+                labels = labels[:total_samples + current_samples -
+                                accum_samples]
+                current_samples = total_samples + current_samples - accum_samples
+        else:
+            labels = batch[1]
+            preds = out
+
+        # calc loss
+        if engine.eval_loss_func is not None:
+            if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
+                amp_level = engine.config['AMP'].get("level", "O1").upper()
+                with paddle.amp.auto_cast(
+                        custom_black_list={
+                            "flatten_contiguous_range", "greater_than"
+                        },
+                        level=amp_level):
+                    loss_dict = engine.eval_loss_func(preds, labels)
+            else:
+                loss_dict = engine.eval_loss_func(preds, labels)
+
+            for key in loss_dict:
+                if key not in output_info:
+                    output_info[key] = AverageMeter(key, '7.5f')
+                output_info[key].update(loss_dict[key].numpy()[0],
+                                        current_samples)
+        # calc metric
+        if engine.eval_metric_func is not None:
+            metric_dict = engine.eval_metric_func(preds, labels)
         for key in metric_dict:
             if metric_key is None:
                 metric_key = key
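Note on the classification_eval hunk above: loss and metrics are now computed
after predictions and labels have been gathered across all ranks, so
single-card and multi-card evaluation see the same global batch.
DistributedBatchSampler pads the dataset so that every rank draws the same
number of batches, which means the last gathered batch can contain repeated
samples; the slice total_samples + current_samples - accum_samples drops that
padding. A minimal standalone sketch of the truncation arithmetic, with
hypothetical sizes, not code from this PR:

    # 10 real samples, world_size 4, batch_size 3: the sampler pads each rank
    # to 3 samples, so one gathered step yields 12 items, 2 of them repeats.
    total_samples = 10
    batch_size, world_size = 3, 4
    accum_samples = 0

    current_samples = batch_size * world_size   # 12 gathered this step
    accum_samples += current_samples            # 12 seen so far
    if accum_samples > total_samples:
        keep = total_samples + current_samples - accum_samples  # 10
        # preds, labels = preds[:keep], labels[:keep]
        current_samples = keep  # weight AverageMeter by real samples only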
@@ -89,9 +89,6 @@ def retrieval_eval(engine, epoch_id=0):
 
 
 def cal_feature(engine, name='gallery'):
-    all_feas = None
-    all_image_id = None
-    all_unique_id = None
     has_unique_id = False
 
     if name == 'gallery':
@@ -103,6 +100,9 @@ def cal_feature(engine, name='gallery'):
     else:
         raise RuntimeError("Only support gallery or query dataset")
 
+    batch_feas_list = []
+    img_id_list = []
+    unique_id_list = []
     max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
         dataloader)
     for idx, batch in enumerate(dataloader):  # load is very time-consuming
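Note on the hunk above: batch_feas_list, img_id_list and unique_id_list
collect per-batch results so everything is concatenated once after the loop,
instead of re-concatenating a growing tensor on every iteration (each such
concat copies the whole accumulated buffer). A minimal sketch of the idiom,
with illustrative shapes, not code from this PR:

    import paddle

    feas_list = []
    for _ in range(4):                          # stand-in for the dataloader loop
        feas_list.append(paddle.rand([8, 128])) # per-batch features
    all_feas = paddle.concat(feas_list)         # single concat after the loop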
@@ -140,32 +140,39 @@ def cal_feature(engine, name='gallery'):
         if engine.config["Global"].get("feature_binarize") == "sign":
             batch_feas = paddle.sign(batch_feas).astype("float32")
 
-        if all_feas is None:
-            all_feas = batch_feas
+        if paddle.distributed.get_world_size() > 1:
+            batch_feas_gather = []
+            img_id_gather = []
+            unique_id_gather = []
+            paddle.distributed.all_gather(batch_feas_gather, batch_feas)
+            paddle.distributed.all_gather(img_id_gather, batch[1])
+            batch_feas_list.append(paddle.concat(batch_feas_gather))
+            img_id_list.append(paddle.concat(img_id_gather))
             if has_unique_id:
-                all_unique_id = batch[2]
-            all_image_id = batch[1]
+                paddle.distributed.all_gather(unique_id_gather, batch[2])
+                unique_id_list.append(paddle.concat(unique_id_gather))
         else:
-            all_feas = paddle.concat([all_feas, batch_feas])
-            all_image_id = paddle.concat([all_image_id, batch[1]])
+            batch_feas_list.append(batch_feas)
+            img_id_list.append(batch[1])
             if has_unique_id:
-                all_unique_id = paddle.concat([all_unique_id, batch[2]])
+                unique_id_list.append(batch[2])
 
     if engine.use_dali:
         dataloader.reset()
 
-    if paddle.distributed.get_world_size() > 1:
-        feat_list = []
-        img_id_list = []
-        unique_id_list = []
-        paddle.distributed.all_gather(feat_list, all_feas)
-        paddle.distributed.all_gather(img_id_list, all_image_id)
-        all_feas = paddle.concat(feat_list, axis=0)
-        all_image_id = paddle.concat(img_id_list, axis=0)
-        if has_unique_id:
-            paddle.distributed.all_gather(unique_id_list, all_unique_id)
-            all_unique_id = paddle.concat(unique_id_list, axis=0)
+    all_feas = paddle.concat(batch_feas_list)
+    all_img_id = paddle.concat(img_id_list)
+    if has_unique_id:
+        all_unique_id = paddle.concat(unique_id_list)
+
+    # just for DistributedBatchSampler issue: repeat sampling
+    total_samples = len(
+        dataloader.dataset) if not engine.use_dali else dataloader.size
+    all_feas = all_feas[:total_samples]
+    all_img_id = all_img_id[:total_samples]
+    if has_unique_id:
+        all_unique_id = all_unique_id[:total_samples]
 
     logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
         name, all_feas.shape))
-    return all_feas, all_image_id, all_unique_id
+    return all_feas, all_img_id, all_unique_id
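Note on the hunk above: cal_feature previously built one big feature tensor
and ran a single all_gather over it at the end; it now gathers each batch
across ranks as it is produced, keeps the pieces in Python lists, concatenates
once, and trims the result to the true dataset size so the samples repeated by
DistributedBatchSampler are dropped. A minimal standalone sketch of the
per-batch gather pattern; the function name and arguments are illustrative,
not code from this PR:

    import paddle
    import paddle.distributed as dist

    def gather_batches(batches, total_samples):
        # Gather each per-batch tensor from every rank, concat once, trim padding.
        pieces = []
        for t in batches:
            if dist.get_world_size() > 1:
                gathered = []
                dist.all_gather(gathered, t)   # fills one entry per rank
                pieces.append(paddle.concat(gathered))
            else:
                pieces.append(t)
        all_t = paddle.concat(pieces)
        return all_t[:total_samples]           # drop the repeat-sampled tail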