fix dali env bug
parent
a944603da0
commit
11b0860a1b
|
@ -230,7 +230,7 @@ def dali_dataloader(config, mode, device, seed=None):
|
||||||
lower = ratio[0]
|
lower = ratio[0]
|
||||||
upper = ratio[1]
|
upper = ratio[1]
|
||||||
|
|
||||||
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
|
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env:
|
||||||
shard_id = int(env['PADDLE_TRAINER_ID'])
|
shard_id = int(env['PADDLE_TRAINER_ID'])
|
||||||
num_shards = int(env['PADDLE_TRAINERS_NUM'])
|
num_shards = int(env['PADDLE_TRAINERS_NUM'])
|
||||||
device_id = int(env['FLAGS_selected_gpus'])
|
device_id = int(env['FLAGS_selected_gpus'])
|
||||||
|
@ -282,7 +282,7 @@ def dali_dataloader(config, mode, device, seed=None):
|
||||||
else:
|
else:
|
||||||
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
|
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
|
||||||
crop = transforms["CropImage"]["size"]
|
crop = transforms["CropImage"]["size"]
|
||||||
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and sampler_name == "DistributedBatchSampler":
|
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and 'FLAGS_selected_gpus' in env and sampler_name == "DistributedBatchSampler":
|
||||||
shard_id = int(env['PADDLE_TRAINER_ID'])
|
shard_id = int(env['PADDLE_TRAINER_ID'])
|
||||||
num_shards = int(env['PADDLE_TRAINERS_NUM'])
|
num_shards = int(env['PADDLE_TRAINERS_NUM'])
|
||||||
device_id = int(env['FLAGS_selected_gpus'])
|
device_id = int(env['FLAGS_selected_gpus'])
|
||||||
|
|
Loading…
Reference in New Issue