Add generator and worker seed (#8602)
* Add generator and worker seed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataloaders.py * Update dataloaders.py * Update dataloaders.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/8690/head
parent
b92430a83b
commit
1c5e92aba1
utils
|
@ -91,6 +91,13 @@ def exif_transpose(image):
|
|||
return image
|
||||
|
||||
|
||||
def seed_worker(worker_id):
|
||||
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
|
||||
worker_seed = torch.initial_seed() % 2 ** 32
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def create_dataloader(path,
|
||||
imgsz,
|
||||
batch_size,
|
||||
|
@ -130,13 +137,17 @@ def create_dataloader(path,
|
|||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(0)
|
||||
return loader(dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=True,
|
||||
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
|
||||
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator), dataset
|
||||
|
||||
|
||||
class InfiniteDataLoader(dataloader.DataLoader):
|
||||
|
|
Loading…
Reference in New Issue