From d3f9bf2bb7932bcd696d3e062031c20ec0adda6c Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 1 Sep 2020 17:02:47 -0700
Subject: [PATCH] Update datasets.py

---
 utils/datasets.py | 43 ++++++++++++++++++-------------------------
 1 file changed, 18 insertions(+), 25 deletions(-)

diff --git a/utils/datasets.py b/utils/datasets.py
index edb6b10fa..d3282bf14 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -62,26 +62,25 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
 
     batch_size = min(batch_size, len(dataset))
     nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
-    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
-    dataloader = InfiniteDataLoader (dataset,
+    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
+    dataloader = InfiniteDataLoader(dataset,
                                     batch_size=batch_size,
                                     num_workers=nw,
-                                    sampler=train_sampler,
+                                    sampler=sampler,
                                     pin_memory=True,
                                     collate_fn=LoadImagesAndLabels.collate_fn)
     return dataloader, dataset
 
 
 class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
-    '''
-    Dataloader that reuses workers.
+    """ Dataloader that reuses workers.
 
     Uses same syntax as vanilla DataLoader.
-    '''
+    """
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
+        object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
         self.iterator = super().__iter__()
 
     def __len__(self):
@@ -91,22 +90,20 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
         for i in range(len(self)):
             yield next(self.iterator)
 
+    class _RepeatSampler(object):
+        """ Sampler that repeats forever.
 
-class _RepeatSampler(object):
-    '''
-    Sampler that repeats forever.
+        Args:
+            sampler (Sampler)
+        """
 
-    Args:
-        sampler (Sampler)
-    '''
+        def __init__(self, sampler):
+            self.sampler = sampler
 
-    def __init__(self, sampler):
-        self.sampler = sampler
+        def __iter__(self):
+            while True:
+                yield from iter(self.sampler)
 
-    def __iter__(self):
-        while True:
-            yield from iter(self.sampler)
-
 
 class LoadImages:  # for inference
     def __init__(self, path, img_size=640):
@@ -684,14 +681,10 @@ def load_mosaic(self, index):
     # Concat/clip labels
     if len(labels4):
        labels4 = np.concatenate(labels4, 0)
-        # np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:])  # use with center crop
-        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_affine
-
-    # Replicate
-    # img4, labels4 = replicate(img4, labels4)
+        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_perspective
+        # img4, labels4 = replicate(img4, labels4)  # replicate
 
     # Augment
-    # img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)]  # center crop (WARNING, requires box pruning)
     img4, labels4 = random_perspective(img4, labels4,
                                        degrees=self.hyp['degrees'],
                                        translate=self.hyp['translate'],
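
For context, below is a minimal, self-contained sketch of the worker-reuse pattern this patch settles on: _RepeatSampler nested inside InfiniteDataLoader, a single long-lived iterator created in __init__, and __iter__ drawing one epoch's worth of batches from it. The __len__ body shown here and the toy TensorDataset demo at the bottom are assumptions filled in for illustration; they are not lines from this patch.

# Sketch of the reusable-worker dataloader pattern from this patch.
# The toy dataset and the demo loop are illustrative assumptions, not
# part of utils/datasets.py.
import torch
from torch.utils.data import TensorDataset


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers.

    Uses same syntax as vanilla DataLoader.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Swap in a repeating batch sampler so worker processes are created
        # once and never torn down between epochs.
        object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        # Assumed body: one "epoch" is the length of the wrapped batch sampler.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Pull batches from the single long-lived iterator instead of building
        # a new one (and new workers) every epoch.
        for _ in range(len(self)):
            yield next(self.iterator)

    class _RepeatSampler(object):
        """ Sampler that repeats forever.

        Args:
            sampler (Sampler)
        """

        def __init__(self, sampler):
            self.sampler = sampler

        def __iter__(self):
            while True:
                yield from iter(self.sampler)


if __name__ == '__main__':
    # Toy usage (assumed): two "epochs" share the same underlying iterator.
    dataset = TensorDataset(torch.arange(8).float())
    loader = InfiniteDataLoader(dataset, batch_size=4, num_workers=0)
    for epoch in range(2):
        for (batch,) in loader:
            print(epoch, batch.tolist())

Nesting _RepeatSampler inside the class (the self._RepeatSampler change above) keeps the helper out of the module namespace, and renaming train_sampler to sampler simply matches the keyword argument it feeds.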