diff --git a/tools/train.py b/tools/train.py index 65a0c8f90..ab7752fd1 100644 --- a/tools/train.py +++ b/tools/train.py @@ -95,7 +95,7 @@ def main(args): # 1. train with train dataset program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, epoch_id, 'train') - if int(os.environ.get("PADDLE_TRAINERS_ID", 0)) == 0: + if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: # 2. validate with validate dataset if config.validate and epoch_id % config.valid_interval == 0: top1_acc = program.run(valid_dataloader, exe,