mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add worker_seeding arg to allow selecting old vs updated data loader worker seed for (old) experiment repeatability
This commit is contained in:
parent
6478bcd02c
commit
80075b0b8a
@ -3,8 +3,11 @@
|
|||||||
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
|
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
|
||||||
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
|
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
|
||||||
|
|
||||||
Hacked together by / Copyright 2020 Ross Wightman
|
Hacked together by / Copyright 2021 Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
import random
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -125,9 +128,19 @@ class PrefetchLoader:
|
|||||||
self.loader.collate_fn.mixup_enabled = x
|
self.loader.collate_fn.mixup_enabled = x
|
||||||
|
|
||||||
|
|
||||||
def _worker_init(worker_id):
|
def _worker_init(worker_id, worker_seeding='all'):
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
assert worker_info.id == worker_id
|
assert worker_info.id == worker_id
|
||||||
|
if isinstance(worker_seeding, Callable):
|
||||||
|
seed = worker_seeding(worker_info)
|
||||||
|
random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed % (2 ** 32 - 1))
|
||||||
|
else:
|
||||||
|
assert worker_seeding in ('all', 'part')
|
||||||
|
# random / torch seed already called in dataloader iter class w/ worker_info.seed
|
||||||
|
# to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
|
||||||
|
if worker_seeding == 'all':
|
||||||
np.random.seed(worker_info.seed % (2 ** 32 - 1))
|
np.random.seed(worker_info.seed % (2 ** 32 - 1))
|
||||||
|
|
||||||
|
|
||||||
@ -162,6 +175,7 @@ def create_loader(
|
|||||||
tf_preprocessing=False,
|
tf_preprocessing=False,
|
||||||
use_multi_epochs_loader=False,
|
use_multi_epochs_loader=False,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
|
worker_seeding='all',
|
||||||
):
|
):
|
||||||
re_num_splits = 0
|
re_num_splits = 0
|
||||||
if re_split:
|
if re_split:
|
||||||
@ -219,8 +233,9 @@ def create_loader(
|
|||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
drop_last=is_training,
|
drop_last=is_training,
|
||||||
worker_init_fn=_worker_init,
|
worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
|
||||||
persistent_workers=persistent_workers)
|
persistent_workers=persistent_workers
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
loader = loader_class(dataset, **loader_args)
|
loader = loader_class(dataset, **loader_args)
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
|
5
train.py
5
train.py
@ -252,6 +252,8 @@ parser.add_argument('--model-ema-decay', type=float, default=0.9998,
|
|||||||
# Misc
|
# Misc
|
||||||
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
||||||
help='random seed (default: 42)')
|
help='random seed (default: 42)')
|
||||||
|
parser.add_argument('--worker-seeding', type=str, default='all',
|
||||||
|
help='worker seed mode (default: all)')
|
||||||
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
|
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
|
||||||
help='how many batches to wait before logging training status')
|
help='how many batches to wait before logging training status')
|
||||||
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
|
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
|
||||||
@ -535,7 +537,8 @@ def main():
|
|||||||
distributed=args.distributed,
|
distributed=args.distributed,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
pin_memory=args.pin_mem,
|
pin_memory=args.pin_mem,
|
||||||
use_multi_epochs_loader=args.use_multi_epochs_loader
|
use_multi_epochs_loader=args.use_multi_epochs_loader,
|
||||||
|
worker_seeding=args.worker_seeding,
|
||||||
)
|
)
|
||||||
|
|
||||||
loader_eval = create_loader(
|
loader_eval = create_loader(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user