diff --git a/tools/train.py b/tools/train.py index aec796c71..d1d014639 100644 --- a/tools/train.py +++ b/tools/train.py @@ -22,9 +22,13 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) from ppcls.utils import config from ppcls.engine.trainer import Trainer +from ppcls.engine.trainer_reid import TrainerReID if __name__ == "__main__": args = config.parse_args() config = config.get_config(args.config, overrides=args.override, show=True) - trainer = Trainer(config, mode="train") + if "Trainer" in config: + trainer = eval(config["Trainer"]["name"])(config, mode="train") + else: + trainer = Trainer(config, mode="train") trainer.train()