fix engine.py
parent
fa52acd614
commit
c07758b331
|
@ -144,22 +144,21 @@ class Engine(object):
|
|||
self.config["Global"]["eval_during_train"]):
|
||||
if self.eval_mode in ["classification", "adaface"]:
|
||||
self.eval_dataloader = build_dataloader(
|
||||
self.config["DataLoader"], "Eval", self.device,
|
||||
self.use_dali)
|
||||
self.config["DataLoader"], "Eval", self.device, False)
|
||||
elif self.eval_mode == "retrieval":
|
||||
self.gallery_query_dataloader = None
|
||||
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
|
||||
key = list(self.config["DataLoader"]["Eval"].keys())[0]
|
||||
self.gallery_query_dataloader = build_dataloader(
|
||||
self.config["DataLoader"]["Eval"], key, self.device,
|
||||
self.use_dali)
|
||||
False)
|
||||
else:
|
||||
self.gallery_dataloader = build_dataloader(
|
||||
self.config["DataLoader"]["Eval"], "Gallery",
|
||||
self.device, self.use_dali)
|
||||
self.device, False)
|
||||
self.query_dataloader = build_dataloader(
|
||||
self.config["DataLoader"]["Eval"], "Query",
|
||||
self.device, self.use_dali)
|
||||
self.device, False)
|
||||
|
||||
# build loss
|
||||
if self.mode == "train":
|
||||
|
@ -339,11 +338,11 @@ class Engine(object):
|
|||
)
|
||||
self.config["Global"]["seed"] = seed = 42
|
||||
logger.info(
|
||||
f"Set random seed to ({seed} + $PADDLE_TRAINER_ID) for different trainer"
|
||||
f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
|
||||
)
|
||||
paddle.seed(seed + dist.get_rank())
|
||||
np.random.seed(seed + dist.get_rank())
|
||||
random.seed(seed + dist.get_rank())
|
||||
paddle.seed(int(seed) + dist.get_rank())
|
||||
np.random.seed(int(seed) + dist.get_rank())
|
||||
random.seed(int(seed) + dist.get_rank())
|
||||
|
||||
# build postprocess for infer
|
||||
if self.mode == 'infer':
|
||||
|
|
Loading…
Reference in New Issue