--image_weights bug fix ()

pull/1526/head
Glenn Jocher 2020-11-26 11:49:01 +01:00 committed by GitHub
parent e9a0ae6f19
commit 9728e2b8ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 8 deletions

View File

@ -181,8 +181,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect,
rank=rank, world_size=opt.world_size, workers=opt.workers)
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
world_size=opt.world_size, workers=opt.workers,
image_weights=opt.image_weights)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

View File

@ -55,7 +55,7 @@ def exif_size(img):
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
rank=-1, world_size=1, workers=8):
rank=-1, world_size=1, workers=8, image_weights=False):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@ -66,7 +66,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
single_cls=opt.single_cls,
stride=int(stride),
pad=pad,
rank=rank)
rank=rank,
image_weights=image_weights)
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
@ -392,6 +393,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
nb = bi[-1] + 1 # number of batches
self.batch = bi # batch index of image
self.n = n
self.indices = range(n)
# Rectangular Training
if self.rect:
@ -485,8 +487,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# return self
def __getitem__(self, index):
if self.image_weights:
index = self.indices[index]
index = self.indices[index] # linear, shuffled, or image_weights
hyp = self.hyp
mosaic = self.mosaic and random.random() < hyp['mosaic']
@ -497,7 +498,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# MixUp https://arxiv.org/pdf/1710.09412.pdf
if random.random() < hyp['mixup']:
img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
img = (img * r + img2 * (1 - r)).astype(np.uint8)
labels = np.concatenate((labels, labels2), 0)
@ -619,7 +620,7 @@ def load_mosaic(self, index):
labels4 = []
s = self.img_size
yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border] # mosaic center x, y
indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)] # 3 additional image indices
indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(3)] # 3 additional image indices
for i, index in enumerate(indices):
# Load image
img, _, (h, w) = load_image(self, index)