mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add worker_init_fn to loader for numpy seed per worker
This commit is contained in:
parent
515121cca1
commit
f8a63a3b71
@ -125,6 +125,12 @@ class PrefetchLoader:
|
|||||||
self.loader.collate_fn.mixup_enabled = x
|
self.loader.collate_fn.mixup_enabled = x
|
||||||
|
|
||||||
|
|
||||||
|
def _worker_init(worker_id):
|
||||||
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
assert worker_info.id == worker_id
|
||||||
|
np.random.seed(worker_info.seed % (2**32-1))
|
||||||
|
|
||||||
|
|
||||||
def create_loader(
|
def create_loader(
|
||||||
dataset,
|
dataset,
|
||||||
input_size,
|
input_size,
|
||||||
@ -202,7 +208,6 @@ def create_loader(
|
|||||||
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
|
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
|
||||||
|
|
||||||
loader_class = torch.utils.data.DataLoader
|
loader_class = torch.utils.data.DataLoader
|
||||||
|
|
||||||
if use_multi_epochs_loader:
|
if use_multi_epochs_loader:
|
||||||
loader_class = MultiEpochsDataLoader
|
loader_class = MultiEpochsDataLoader
|
||||||
|
|
||||||
@ -214,6 +219,7 @@ def create_loader(
|
|||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
drop_last=is_training,
|
drop_last=is_training,
|
||||||
|
worker_init_fn=_worker_init,
|
||||||
persistent_workers=persistent_workers)
|
persistent_workers=persistent_workers)
|
||||||
try:
|
try:
|
||||||
loader = loader_class(dataset, **loader_args)
|
loader = loader_class(dataset, **loader_args)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user