mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix: fix the bug that DistributedBatchSampler may sample repeatedly
This commit is contained in:
parent
ccd15f5190
commit
978157e782
@ -89,9 +89,6 @@ def retrieval_eval(engine, epoch_id=0):
|
|||||||
|
|
||||||
|
|
||||||
def cal_feature(engine, name='gallery'):
|
def cal_feature(engine, name='gallery'):
|
||||||
all_feas = None
|
|
||||||
all_image_id = None
|
|
||||||
all_unique_id = None
|
|
||||||
has_unique_id = False
|
has_unique_id = False
|
||||||
|
|
||||||
if name == 'gallery':
|
if name == 'gallery':
|
||||||
@ -103,6 +100,9 @@ def cal_feature(engine, name='gallery'):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError("Only support gallery or query dataset")
|
raise RuntimeError("Only support gallery or query dataset")
|
||||||
|
|
||||||
|
batch_feas_list = []
|
||||||
|
img_id_list = []
|
||||||
|
unique_id_list = []
|
||||||
max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
|
max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
|
||||||
dataloader)
|
dataloader)
|
||||||
for idx, batch in enumerate(dataloader): # load is very time-consuming
|
for idx, batch in enumerate(dataloader): # load is very time-consuming
|
||||||
@ -140,32 +140,39 @@ def cal_feature(engine, name='gallery'):
|
|||||||
if engine.config["Global"].get("feature_binarize") == "sign":
|
if engine.config["Global"].get("feature_binarize") == "sign":
|
||||||
batch_feas = paddle.sign(batch_feas).astype("float32")
|
batch_feas = paddle.sign(batch_feas).astype("float32")
|
||||||
|
|
||||||
if all_feas is None:
|
if paddle.distributed.get_world_size() > 1:
|
||||||
all_feas = batch_feas
|
batch_feas_gather = []
|
||||||
|
img_id_gather = []
|
||||||
|
unique_id_gather = []
|
||||||
|
paddle.distributed.all_gather(batch_feas_gather, batch_feas)
|
||||||
|
paddle.distributed.all_gather(img_id_gather, batch[1])
|
||||||
|
batch_feas_list.append(paddle.concat(batch_feas_gather))
|
||||||
|
img_id_list.append(paddle.concat(img_id_gather))
|
||||||
if has_unique_id:
|
if has_unique_id:
|
||||||
all_unique_id = batch[2]
|
paddle.distributed.all_gather(unique_id_gather, batch[2])
|
||||||
all_image_id = batch[1]
|
unique_id_list.append(paddle.concat(unique_id_gather))
|
||||||
else:
|
else:
|
||||||
all_feas = paddle.concat([all_feas, batch_feas])
|
batch_feas_list.append(batch_feas)
|
||||||
all_image_id = paddle.concat([all_image_id, batch[1]])
|
img_id_list.append(batch[1])
|
||||||
if has_unique_id:
|
if has_unique_id:
|
||||||
all_unique_id = paddle.concat([all_unique_id, batch[2]])
|
unique_id_list.append(batch[2])
|
||||||
|
|
||||||
if engine.use_dali:
|
if engine.use_dali:
|
||||||
dataloader.reset()
|
dataloader.reset()
|
||||||
|
|
||||||
if paddle.distributed.get_world_size() > 1:
|
all_feas = paddle.concat(batch_feas_list)
|
||||||
feat_list = []
|
all_img_id = paddle.concat(img_id_list)
|
||||||
img_id_list = []
|
if has_unique_id:
|
||||||
unique_id_list = []
|
all_unique_id = paddle.concat(unique_id_list)
|
||||||
paddle.distributed.all_gather(feat_list, all_feas)
|
|
||||||
paddle.distributed.all_gather(img_id_list, all_image_id)
|
# just for DistributedBatchSampler issue: repeat sampling
|
||||||
all_feas = paddle.concat(feat_list, axis=0)
|
total_samples = len(
|
||||||
all_image_id = paddle.concat(img_id_list, axis=0)
|
dataloader.dataset) if not engine.use_dali else dataloader.size
|
||||||
if has_unique_id:
|
all_feas = all_feas[:total_samples]
|
||||||
paddle.distributed.all_gather(unique_id_list, all_unique_id)
|
all_img_id = all_img_id[:total_samples]
|
||||||
all_unique_id = paddle.concat(unique_id_list, axis=0)
|
if has_unique_id:
|
||||||
|
all_unique_id = all_unique_id[:total_samples]
|
||||||
|
|
||||||
logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
|
logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
|
||||||
name, all_feas.shape))
|
name, all_feas.shape))
|
||||||
return all_feas, all_image_id, all_unique_id
|
return all_feas, all_img_id, all_unique_id
|
||||||
|
Loading…
x
Reference in New Issue
Block a user