mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Update datasets.py
This commit is contained in:
parent
901243c780
commit
d3f9bf2bb7
@ -62,26 +62,25 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|||||||
|
|
||||||
batch_size = min(batch_size, len(dataset))
|
batch_size = min(batch_size, len(dataset))
|
||||||
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||||
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
||||||
dataloader = InfiniteDataLoader(dataset,
|
dataloader = InfiniteDataLoader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=nw,
|
num_workers=nw,
|
||||||
sampler=train_sampler,
|
sampler=sampler,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=LoadImagesAndLabels.collate_fn)
|
collate_fn=LoadImagesAndLabels.collate_fn)
|
||||||
return dataloader, dataset
|
return dataloader, dataset
|
||||||
|
|
||||||
|
|
||||||
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
||||||
'''
|
""" Dataloader that reuses workers.
|
||||||
Dataloader that reuses workers.
|
|
||||||
|
|
||||||
Uses same syntax as vanilla DataLoader.
|
Uses same syntax as vanilla DataLoader.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
|
||||||
self.iterator = super().__iter__()
|
self.iterator = super().__iter__()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@ -91,14 +90,12 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
|||||||
for i in range(len(self)):
|
for i in range(len(self)):
|
||||||
yield next(self.iterator)
|
yield next(self.iterator)
|
||||||
|
|
||||||
|
|
||||||
class _RepeatSampler(object):
|
class _RepeatSampler(object):
|
||||||
'''
|
""" Sampler that repeats forever.
|
||||||
Sampler that repeats forever.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sampler (Sampler)
|
sampler (Sampler)
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, sampler):
|
def __init__(self, sampler):
|
||||||
self.sampler = sampler
|
self.sampler = sampler
|
||||||
@ -684,14 +681,10 @@ def load_mosaic(self, index):
|
|||||||
# Concat/clip labels
|
# Concat/clip labels
|
||||||
if len(labels4):
|
if len(labels4):
|
||||||
labels4 = np.concatenate(labels4, 0)
|
labels4 = np.concatenate(labels4, 0)
|
||||||
# np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:]) # use with center crop
|
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_perspective
|
||||||
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_affine
|
# img4, labels4 = replicate(img4, labels4) # replicate
|
||||||
|
|
||||||
# Replicate
|
|
||||||
# img4, labels4 = replicate(img4, labels4)
|
|
||||||
|
|
||||||
# Augment
|
# Augment
|
||||||
# img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)] # center crop (WARNING, requires box pruning)
|
|
||||||
img4, labels4 = random_perspective(img4, labels4,
|
img4, labels4 = random_perspective(img4, labels4,
|
||||||
degrees=self.hyp['degrees'],
|
degrees=self.hyp['degrees'],
|
||||||
translate=self.hyp['translate'],
|
translate=self.hyp['translate'],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user