diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index 81fe515dd..4bf62ba93 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -246,11 +246,11 @@ class Trainer(object): elif self.eval_mode == "retrieval": if self.gallery_dataloader is None: self.gallery_dataloader = build_dataloader( - self.config["DataLoader"], "Gallery", self.device) + self.config["DataLoader"]["Eval"], "Gallery", self.device) if self.query_dataloader is None: self.query_dataloader = build_dataloader( - self.config["DataLoader"], "Query", self.device) + self.config["DataLoader"]["Eval"], "Query", self.device) # build metric info if self.eval_metric_func is None: metric_config = self.config.get("Metric", None)