Merge pull request from RainFrost1/dali_bug

修复dali环境变量bug
pull/1918/head^2
Walter 2022-05-11 17:20:13 +08:00 committed by GitHub
commit 353e050c21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions
ppcls/data/dataloader

View File

@ -230,7 +230,7 @@ def dali_dataloader(config, mode, device, seed=None):
lower = ratio[0]
upper = ratio[1]
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env:
shard_id = int(env['PADDLE_TRAINER_ID'])
num_shards = int(env['PADDLE_TRAINERS_NUM'])
device_id = int(env['FLAGS_selected_gpus'])
@ -282,7 +282,7 @@ def dali_dataloader(config, mode, device, seed=None):
else:
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
crop = transforms["CropImage"]["size"]
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and sampler_name == "DistributedBatchSampler":
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env and sampler_name == "DistributedBatchSampler":
shard_id = int(env['PADDLE_TRAINER_ID'])
num_shards = int(env['PADDLE_TRAINERS_NUM'])
device_id = int(env['FLAGS_selected_gpus'])