mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #140 from yoniaflalo/PR_MultiEpochsDataLoader
added MultiEpochsDataLoader
This commit is contained in:
commit
3b72ebff51
@ -140,6 +140,7 @@ def create_loader(
|
|||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
fp16=False,
|
fp16=False,
|
||||||
tf_preprocessing=False,
|
tf_preprocessing=False,
|
||||||
|
use_multi_epochs_loader=False
|
||||||
):
|
):
|
||||||
re_num_splits = 0
|
re_num_splits = 0
|
||||||
if re_split:
|
if re_split:
|
||||||
@ -175,7 +176,12 @@ def create_loader(
|
|||||||
if collate_fn is None:
|
if collate_fn is None:
|
||||||
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
|
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
|
||||||
|
|
||||||
loader = torch.utils.data.DataLoader(
|
loader_class = torch.utils.data.DataLoader
|
||||||
|
|
||||||
|
if use_multi_epochs_loader:
|
||||||
|
loader_class = MultiEpochsDataLoader
|
||||||
|
|
||||||
|
loader = loader_class(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=sampler is None and is_training,
|
shuffle=sampler is None and is_training,
|
||||||
@ -198,3 +204,35 @@ def create_loader(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
|
|
||||||
|
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses a single underlying iterator across epochs.

    The index sampler is wrapped in ``_RepeatSampler`` so that the iterator
    created once in ``__init__`` never runs out. Each call to ``__iter__``
    then pulls exactly ``len(self)`` items from that shared iterator instead
    of building a fresh one at the start of every epoch.

    Accepts the same positional and keyword arguments as
    ``torch.utils.data.DataLoader``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader rejects re-assignment of its sampler attributes once it
        # is flagged as initialized, so lift the (name-mangled) flag while
        # the sampler is swapped and restore it right after.
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            # batch_size=None leaves batch_sampler unset and indices come
            # straight from `sampler`; wrap that one instead of wrapping None.
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Created once and shared by every epoch.
        self.iterator = super().__iter__()

    def __len__(self):
        """Return the length of the sampler that was wrapped in ``__init__``."""
        if self.batch_sampler is None:
            return len(self.sampler.sampler)
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Yield ``len(self)`` items from the persistent iterator."""
        for _ in range(len(self)):
            yield next(self.iterator)
|
||||||
|
|
||||||
|
|
||||||
|
class _RepeatSampler(object):
|
||||||
|
""" Sampler that repeats forever.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sampler (Sampler)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sampler):
|
||||||
|
self.sampler = sampler
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
while True:
|
||||||
|
yield from iter(self.sampler)
|
||||||
|
3
train.py
3
train.py
@ -198,6 +198,8 @@ parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_MET
|
|||||||
parser.add_argument('--tta', type=int, default=0, metavar='N',
|
parser.add_argument('--tta', type=int, default=0, metavar='N',
|
||||||
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
||||||
parser.add_argument("--local_rank", default=0, type=int)
|
parser.add_argument("--local_rank", default=0, type=int)
|
||||||
|
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
|
||||||
|
help='use the multi-epochs-loader to save time at the beginning of every epoch')
|
||||||
|
|
||||||
|
|
||||||
def _parse_args():
|
def _parse_args():
|
||||||
@ -391,6 +393,7 @@ def main():
|
|||||||
distributed=args.distributed,
|
distributed=args.distributed,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
pin_memory=args.pin_mem,
|
pin_memory=args.pin_mem,
|
||||||
|
use_multi_epochs_loader=args.use_multi_epochs_loader
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_dir = os.path.join(args.data, 'val')
|
eval_dir = os.path.join(args.data, 'val')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user