From c54631ceb30a50d8a31b00bd7645e1b77ccc1eff Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 12 Jul 2022 15:24:01 +0200
Subject: [PATCH] Add --cache

---
 classifier.py        |  4 +++-
 utils/dataloaders.py | 25 +++++++++++++++++--------
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/classifier.py b/classifier.py
index 0edc4a59f..7036b197b 100644
--- a/classifier.py
+++ b/classifier.py
@@ -77,6 +77,7 @@ def train():
                                                    imgsz=imgsz,
                                                    batch_size=bs // WORLD_SIZE,
                                                    augment=True,
+                                                   cache=opt.cache,
                                                    rank=LOCAL_RANK,
                                                    workers=nw)
 
@@ -85,6 +86,7 @@ def train():
                                                   imgsz=imgsz,
                                                   batch_size=bs // WORLD_SIZE * 2,
                                                   augment=False,
+                                                  cache=opt.cache,
                                                   rank=-1,
                                                   workers=nw)
 
@@ -330,7 +332,7 @@ if __name__ == '__main__':
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
     parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='Adam', help='optimizer')
     parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
-    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
+    parser.add_argument('--cache', action='store_true', help='--cache images to disk for faster training')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
     parser.add_argument('--project', default='runs/train', help='save to project/name')
diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index 9b3becaf7..f31367e31 100755
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -1109,18 +1109,26 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transform: Albumentations transforms, used if installed
     """
 
-    def __init__(self, root, torch_transforms, album_transforms=None):
+    def __init__(self, root, torch_transforms, album_transforms=None, cache=False):
         super().__init__(root=root)
         self.torch_transforms = torch_transforms
         self.album_transforms = album_transforms
+        self.cache = cache
 
-    def __getitem__(self, idx):
-        path, target = self.samples[idx]
+    def __getitem__(self, i):
+        f, j = self.samples[i]  # filename, index
         if self.album_transforms:
-            sample = self.album_transforms(image=cv2.imread(path)[..., ::-1])["image"]
+            if self.cache:
+                fn = Path(f).with_suffix('.npy')  # filename numpy
+                if not fn.exists():  # load npy
+                    np.save(fn.as_posix(), cv2.imread(f))
+                im = np.load(fn)
+            else:  # read image
+                im = cv2.imread(f)  # BGR
+            sample = self.album_transforms(image=im[..., ::-1])["image"]
         else:
-            sample = self.torch_transforms(self.loader(path))
-        return sample, target
+            sample = self.torch_transforms(self.loader(f))
+        return sample, j
 
 
 def create_classification_dataloader(
@@ -1128,7 +1136,7 @@ def create_classification_dataloader(
         imgsz=224,
         batch_size=16,
         augment=True,
-        cache=False,  # TODO
+        cache=False,
         rank=-1,
         workers=8,
         shuffle=True):
@@ -1136,7 +1144,8 @@ def create_classification_dataloader(
     with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
         dataset = ClassificationDataset(root=path,
                                         torch_transforms=classify_transforms(),
-                                        album_transforms=classify_albumentations(augment, imgsz))
+                                        album_transforms=classify_albumentations(augment, imgsz),
+                                        cache=cache)
     batch_size = min(batch_size, len(dataset))
     nd = torch.cuda.device_count()
     nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
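
Note (outside the diff): with --cache enabled, ClassificationDataset.__getitem__ decodes each
image once with cv2.imread, writes the raw array to a sibling *.npy file, and loads that array
on later epochs, trading disk space for JPEG/PNG decode time. The cache only takes effect on
the Albumentations code path; the torchvision fallback still reads through self.loader. Below
is a minimal standalone sketch of the same disk-cache pattern, assuming a hypothetical
load_cached() helper and example image path (neither exists in the repository):

    # Sketch of the *.npy disk-cache pattern added by this patch (illustrative only).
    from pathlib import Path

    import cv2
    import numpy as np


    def load_cached(f, cache=True):
        """Return the BGR image for file `f`, reusing a .npy copy saved next to it when cache=True."""
        if cache:
            fn = Path(f).with_suffix('.npy')  # cached raw array alongside the original image
            if not fn.exists():
                np.save(fn.as_posix(), cv2.imread(f))  # decode once, store the pixels as a raw array
            return np.load(fn)  # later reads skip image decoding entirely
        return cv2.imread(f)  # BGR


    if __name__ == '__main__':
        im = load_cached('data/images/example.jpg')  # hypothetical path
        print(im.shape)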