--image_weights bug fix (#1524)
parent
e9a0ae6f19
commit
9728e2b8ae
5
train.py
5
train.py
|
@ -181,8 +181,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
|
||||
# Trainloader
|
||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
||||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect,
|
||||
rank=rank, world_size=opt.world_size, workers=opt.workers)
|
||||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
|
||||
world_size=opt.world_size, workers=opt.workers,
|
||||
image_weights=opt.image_weights)
|
||||
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
||||
nb = len(dataloader) # number of batches
|
||||
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
||||
|
|
|
@ -55,7 +55,7 @@ def exif_size(img):
|
|||
|
||||
|
||||
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
|
||||
rank=-1, world_size=1, workers=8):
|
||||
rank=-1, world_size=1, workers=8, image_weights=False):
|
||||
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
||||
with torch_distributed_zero_first(rank):
|
||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
||||
|
@ -66,7 +66,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|||
single_cls=opt.single_cls,
|
||||
stride=int(stride),
|
||||
pad=pad,
|
||||
rank=rank)
|
||||
rank=rank,
|
||||
image_weights=image_weights)
|
||||
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
|
@ -392,6 +393,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
nb = bi[-1] + 1 # number of batches
|
||||
self.batch = bi # batch index of image
|
||||
self.n = n
|
||||
self.indices = range(n)
|
||||
|
||||
# Rectangular Training
|
||||
if self.rect:
|
||||
|
@ -485,8 +487,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
# return self
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.image_weights:
|
||||
index = self.indices[index]
|
||||
index = self.indices[index] # linear, shuffled, or image_weights
|
||||
|
||||
hyp = self.hyp
|
||||
mosaic = self.mosaic and random.random() < hyp['mosaic']
|
||||
|
@ -497,7 +498,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
|
||||
# MixUp https://arxiv.org/pdf/1710.09412.pdf
|
||||
if random.random() < hyp['mixup']:
|
||||
img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
|
||||
img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
|
||||
r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
|
||||
img = (img * r + img2 * (1 - r)).astype(np.uint8)
|
||||
labels = np.concatenate((labels, labels2), 0)
|
||||
|
@ -619,7 +620,7 @@ def load_mosaic(self, index):
|
|||
labels4 = []
|
||||
s = self.img_size
|
||||
yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border] # mosaic center x, y
|
||||
indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)] # 3 additional image indices
|
||||
indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(3)] # 3 additional image indices
|
||||
for i, index in enumerate(indices):
|
||||
# Load image
|
||||
img, _, (h, w) = load_image(self, index)
|
||||
|
|
Loading…
Reference in New Issue