workers to max os.cpu_count()

This commit is contained in:
Glenn Jocher 2021-03-09 15:09:03 -08:00
parent 05b8c6a501
commit bbcc482416

View File

@ -37,7 +37,8 @@ def imshow(img):
def train():
save_dir, data, bs, epochs, nw = Path(opt.save_dir), opt.data, opt.batch_size, opt.epochs, opt.workers
save_dir, data, bs, epochs, nw = Path(opt.save_dir), opt.data, opt.batch_size, opt.epochs, \
min(os.cpu_count(), opt.workers)
# Directories
wdir = save_dir / 'weights'
@ -62,7 +63,7 @@ def train():
testform = T.Compose(trainform.transforms[-2:])
# Dataloaders
trainset = torchvision.datasets.ImageFolder(root=f'../{data}/train', transform=trainform)
trainset = torchvision.datasets.ImageFolder(root=f'../{data}/train', transformd=trainform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=nw)
testset = torchvision.datasets.ImageFolder(root=f'../{data}/test', transform=testform)
testloader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers=nw)