Remove `opt` from `create_dataloader()`` (#3552)
parent
0cfc5b2c18
commit
958ab92dc1
2
test.py
2
test.py
|
@ -88,7 +88,7 @@ def test(data,
|
|||
if device.type != 'cpu':
|
||||
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
|
||||
task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images
|
||||
dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True,
|
||||
dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
|
||||
prefix=colorstr(f'{task}: '))[0]
|
||||
|
||||
seen = 0
|
||||
|
|
17
train.py
17
train.py
|
@ -41,8 +41,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
def train(hyp, opt, device, tb_writer=None):
|
||||
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
||||
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
||||
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
|
||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
|
||||
opt.single_cls
|
||||
|
||||
# Directories
|
||||
wdir = save_dir / 'weights'
|
||||
|
@ -75,8 +76,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
if wandb_logger.wandb:
|
||||
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
||||
|
||||
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
|
||||
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
||||
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
||||
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
||||
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
||||
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
||||
|
||||
|
@ -187,7 +188,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
logger.info('Using SyncBatchNorm()')
|
||||
|
||||
# Trainloader
|
||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
|
||||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
|
||||
world_size=opt.world_size, workers=opt.workers,
|
||||
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
||||
|
@ -197,7 +198,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
|
||||
# Process 0
|
||||
if rank in [-1, 0]:
|
||||
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
|
||||
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
|
||||
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
||||
world_size=opt.world_size, workers=opt.workers,
|
||||
pad=0.5, prefix=colorstr('val: '))[0]
|
||||
|
@ -357,7 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
batch_size=batch_size * 2,
|
||||
imgsz=imgsz_test,
|
||||
model=ema.ema,
|
||||
single_cls=opt.single_cls,
|
||||
single_cls=single_cls,
|
||||
dataloader=testloader,
|
||||
save_dir=save_dir,
|
||||
save_json=is_coco and final_epoch,
|
||||
|
@ -429,7 +430,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
conf_thres=0.001,
|
||||
iou_thres=0.7,
|
||||
model=attempt_load(m, device).half(),
|
||||
single_cls=opt.single_cls,
|
||||
single_cls=single_cls,
|
||||
dataloader=testloader,
|
||||
save_dir=save_dir,
|
||||
save_json=True,
|
||||
|
|
|
@ -62,8 +62,8 @@ def exif_size(img):
|
|||
return s
|
||||
|
||||
|
||||
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
|
||||
rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
|
||||
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
|
||||
rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
|
||||
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
||||
with torch_distributed_zero_first(rank):
|
||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
||||
|
@ -71,7 +71,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|||
hyp=hyp, # augmentation hyperparameters
|
||||
rect=rect, # rectangular training
|
||||
cache_images=cache,
|
||||
single_cls=opt.single_cls,
|
||||
single_cls=single_cls,
|
||||
stride=int(stride),
|
||||
pad=pad,
|
||||
image_weights=image_weights,
|
||||
|
|
Loading…
Reference in New Issue