fix engine.py
parent
fa52acd614
commit
c07758b331
|
@ -144,22 +144,21 @@ class Engine(object):
|
||||||
self.config["Global"]["eval_during_train"]):
|
self.config["Global"]["eval_during_train"]):
|
||||||
if self.eval_mode in ["classification", "adaface"]:
|
if self.eval_mode in ["classification", "adaface"]:
|
||||||
self.eval_dataloader = build_dataloader(
|
self.eval_dataloader = build_dataloader(
|
||||||
self.config["DataLoader"], "Eval", self.device,
|
self.config["DataLoader"], "Eval", self.device, False)
|
||||||
self.use_dali)
|
|
||||||
elif self.eval_mode == "retrieval":
|
elif self.eval_mode == "retrieval":
|
||||||
self.gallery_query_dataloader = None
|
self.gallery_query_dataloader = None
|
||||||
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
|
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
|
||||||
key = list(self.config["DataLoader"]["Eval"].keys())[0]
|
key = list(self.config["DataLoader"]["Eval"].keys())[0]
|
||||||
self.gallery_query_dataloader = build_dataloader(
|
self.gallery_query_dataloader = build_dataloader(
|
||||||
self.config["DataLoader"]["Eval"], key, self.device,
|
self.config["DataLoader"]["Eval"], key, self.device,
|
||||||
self.use_dali)
|
False)
|
||||||
else:
|
else:
|
||||||
self.gallery_dataloader = build_dataloader(
|
self.gallery_dataloader = build_dataloader(
|
||||||
self.config["DataLoader"]["Eval"], "Gallery",
|
self.config["DataLoader"]["Eval"], "Gallery",
|
||||||
self.device, self.use_dali)
|
self.device, False)
|
||||||
self.query_dataloader = build_dataloader(
|
self.query_dataloader = build_dataloader(
|
||||||
self.config["DataLoader"]["Eval"], "Query",
|
self.config["DataLoader"]["Eval"], "Query",
|
||||||
self.device, self.use_dali)
|
self.device, False)
|
||||||
|
|
||||||
# build loss
|
# build loss
|
||||||
if self.mode == "train":
|
if self.mode == "train":
|
||||||
|
@ -339,11 +338,11 @@ class Engine(object):
|
||||||
)
|
)
|
||||||
self.config["Global"]["seed"] = seed = 42
|
self.config["Global"]["seed"] = seed = 42
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Set random seed to ({seed} + $PADDLE_TRAINER_ID) for different trainer"
|
f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
|
||||||
)
|
)
|
||||||
paddle.seed(seed + dist.get_rank())
|
paddle.seed(int(seed) + dist.get_rank())
|
||||||
np.random.seed(seed + dist.get_rank())
|
np.random.seed(int(seed) + dist.get_rank())
|
||||||
random.seed(seed + dist.get_rank())
|
random.seed(int(seed) + dist.get_rank())
|
||||||
|
|
||||||
# build postprocess for infer
|
# build postprocess for infer
|
||||||
if self.mode == 'infer':
|
if self.mode == 'infer':
|
||||||
|
|
Loading…
Reference in New Issue