Add --cache

pull/8478/head
Glenn Jocher 2022-07-12 15:24:01 +02:00
parent c655119653
commit c54631ceb3
2 changed files with 20 additions and 9 deletions

View File

@ -77,6 +77,7 @@ def train():
imgsz=imgsz,
batch_size=bs // WORLD_SIZE,
augment=True,
cache=opt.cache,
rank=LOCAL_RANK,
workers=nw)
@ -85,6 +86,7 @@ def train():
imgsz=imgsz,
batch_size=bs // WORLD_SIZE * 2,
augment=False,
cache=opt.cache,
rank=-1,
workers=nw)
@ -330,7 +332,7 @@ if __name__ == '__main__':
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='Adam', help='optimizer')
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
parser.add_argument('--cache', action='store_true', help='--cache images to disk for faster training')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
parser.add_argument('--project', default='runs/train', help='save to project/name')

View File

@ -1109,18 +1109,26 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
album_transform: Albumentations transforms, used if installed
"""
def __init__(self, root, torch_transforms, album_transforms=None):
def __init__(self, root, torch_transforms, album_transforms=None, cache=False):
super().__init__(root=root)
self.torch_transforms = torch_transforms
self.album_transforms = album_transforms
self.cache = cache
def __getitem__(self, idx):
path, target = self.samples[idx]
def __getitem__(self, i):
f, j = self.samples[i] # filename, index
if self.album_transforms:
sample = self.album_transforms(image=cv2.imread(path)[..., ::-1])["image"]
if self.cache:
fn = Path(f).with_suffix('.npy') # filename numpy
if not fn.exists(): # load npy
np.save(fn.as_posix(), cv2.imread(f))
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
sample = self.album_transforms(image=im[..., ::-1])["image"]
else:
sample = self.torch_transforms(self.loader(path))
return sample, target
sample = self.torch_transforms(self.loader(f))
return sample, j
def create_classification_dataloader(
@ -1128,7 +1136,7 @@ def create_classification_dataloader(
imgsz=224,
batch_size=16,
augment=True,
cache=False, # TODO
cache=False,
rank=-1,
workers=8,
shuffle=True):
@ -1136,7 +1144,8 @@ def create_classification_dataloader(
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = ClassificationDataset(root=path,
torch_transforms=classify_transforms(),
album_transforms=classify_albumentations(augment, imgsz))
album_transforms=classify_albumentations(augment, imgsz),
cache=cache)
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count()
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])