Add --cache

parent c655119653
commit c54631ceb3

@@ -77,6 +77,7 @@ def train():
         imgsz=imgsz,
         batch_size=bs // WORLD_SIZE,
         augment=True,
+        cache=opt.cache,
         rank=LOCAL_RANK,
         workers=nw)

@@ -85,6 +86,7 @@ def train():
         imgsz=imgsz,
         batch_size=bs // WORLD_SIZE * 2,
         augment=False,
+        cache=opt.cache,
         rank=-1,
         workers=nw)

@@ -330,7 +332,7 @@ if __name__ == '__main__':
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
     parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='Adam', help='optimizer')
     parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
-    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
+    parser.add_argument('--cache', action='store_true', help='--cache images to disk for faster training')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
     parser.add_argument('--project', default='runs/train', help='save to project/name')
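
For reference, a minimal sketch of how the renamed flag behaves once parsed. It exercises only argparse and is not code from the repository:

    import argparse

    # --cache is a plain boolean switch (replacing the old --cache-images):
    # present on the command line -> True, absent -> False.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cache', action='store_true', help='cache images to disk for faster training')

    print(parser.parse_args(['--cache']).cache)  # True
    print(parser.parse_args([]).cache)           # False (default)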

@@ -1109,18 +1109,26 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transform: Albumentations transforms, used if installed
     """
 
-    def __init__(self, root, torch_transforms, album_transforms=None):
+    def __init__(self, root, torch_transforms, album_transforms=None, cache=False):
         super().__init__(root=root)
         self.torch_transforms = torch_transforms
         self.album_transforms = album_transforms
+        self.cache = cache
 
-    def __getitem__(self, idx):
-        path, target = self.samples[idx]
+    def __getitem__(self, i):
+        f, j = self.samples[i]  # filename, index
         if self.album_transforms:
-            sample = self.album_transforms(image=cv2.imread(path)[..., ::-1])["image"]
+            if self.cache:
+                fn = Path(f).with_suffix('.npy')  # filename numpy
+                if not fn.exists():  # load npy
+                    np.save(fn.as_posix(), cv2.imread(f))
+                im = np.load(fn)
+            else:  # read image
+                im = cv2.imread(f)  # BGR
+            sample = self.album_transforms(image=im[..., ::-1])["image"]
         else:
-            sample = self.torch_transforms(self.loader(path))
-        return sample, target
+            sample = self.torch_transforms(self.loader(f))
+        return sample, j
 
 
 def create_classification_dataloader(
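
The cache branch added above trades disk space for decode time: on the first pass each image is decoded once with OpenCV and the raw BGR array is saved as a .npy file next to the source image; later passes load the array directly instead of re-decoding the JPEG/PNG. A standalone sketch of that pattern (the helper name is illustrative, not part of the repository):

    from pathlib import Path

    import cv2
    import numpy as np


    def load_bgr_cached(path):
        # Decode once, persist the BGR array as <name>.npy beside the image,
        # and reuse the cached array on every subsequent read.
        npy = Path(path).with_suffix('.npy')
        if not npy.exists():
            np.save(npy.as_posix(), cv2.imread(str(path)))
        return np.load(npy)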

@@ -1128,7 +1136,7 @@ def create_classification_dataloader(
                                      imgsz=224,
                                      batch_size=16,
                                      augment=True,
-                                     cache=False,  # TODO
+                                     cache=False,
                                      rank=-1,
                                      workers=8,
                                      shuffle=True):

@@ -1136,7 +1144,8 @@ def create_classification_dataloader(
     with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
         dataset = ClassificationDataset(root=path,
                                         torch_transforms=classify_transforms(),
-                                        album_transforms=classify_albumentations(augment, imgsz))
+                                        album_transforms=classify_albumentations(augment, imgsz),
+                                        cache=cache)
         batch_size = min(batch_size, len(dataset))
         nd = torch.cuda.device_count()
         nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
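
Finally, a hedged usage sketch of the factory after this change. The import path, the dataset directory, and the assumption that the factory returns a standard iterable DataLoader are all guesses, since neither file name appears in this diff:

    # Assumed module path -- adjust to wherever create_classification_dataloader lives.
    from utils.dataloaders import create_classification_dataloader

    loader = create_classification_dataloader(
        path='datasets/imagenette160/train',  # placeholder dataset directory
        imgsz=224,
        batch_size=16,
        augment=True,
        cache=True,   # new: first epoch writes .npy copies, later epochs read them
        rank=-1,      # single-process (non-DDP) run
        workers=8,
        shuffle=True)

    for images, labels in loader:
        ...  # training loop goes here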