Merge remote-tracking branch 'origin/master'
commit
3e04d20c7d
|
@ -63,15 +63,51 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|||
batch_size = min(batch_size, len(dataset))
|
||||
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
||||
dataloader = torch.utils.data.DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=nw,
|
||||
sampler=train_sampler,
|
||||
pin_memory=True,
|
||||
collate_fn=LoadImagesAndLabels.collate_fn)
|
||||
dataloader = InfiniteDataLoader (dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=nw,
|
||||
sampler=train_sampler,
|
||||
pin_memory=True,
|
||||
collate_fn=LoadImagesAndLabels.collate_fn)
|
||||
return dataloader, dataset
|
||||
|
||||
|
||||
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
||||
'''
|
||||
Dataloader that reuses workers.
|
||||
|
||||
Uses same syntax as vanilla DataLoader.
|
||||
'''
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.batch_sampler.sampler)
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self)):
|
||||
yield next(self.iterator)
|
||||
|
||||
|
||||
class _RepeatSampler(object):
|
||||
'''
|
||||
Sampler that repeats forever.
|
||||
|
||||
Args:
|
||||
sampler (Sampler)
|
||||
'''
|
||||
|
||||
def __init__(self, sampler):
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
yield from iter(self.sampler)
|
||||
|
||||
|
||||
class LoadImages: # for inference
|
||||
def __init__(self, path, img_size=640):
|
||||
p = str(Path(path)) # os-agnostic
|
||||
|
|
Loading…
Reference in New Issue