diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py index a15c23156..a340a946c 100644 --- a/ppcls/data/dataloader/dali.py +++ b/ppcls/data/dataloader/dali.py @@ -230,7 +230,7 @@ def dali_dataloader(config, mode, device, seed=None): lower = ratio[0] upper = ratio[1] - if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env: + if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env: shard_id = int(env['PADDLE_TRAINER_ID']) num_shards = int(env['PADDLE_TRAINERS_NUM']) device_id = int(env['FLAGS_selected_gpus']) @@ -282,7 +282,7 @@ def dali_dataloader(config, mode, device, seed=None): else: resize_shorter = transforms["ResizeImage"].get("resize_short", 256) crop = transforms["CropImage"]["size"] - if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and sampler_name == "DistributedBatchSampler": + if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env and sampler_name == "DistributedBatchSampler": shard_id = int(env['PADDLE_TRAINER_ID']) num_shards = int(env['PADDLE_TRAINERS_NUM']) device_id = int(env['FLAGS_selected_gpus'])