mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Remove opt
from create_dataloader()
` (#3552)
This commit is contained in:
parent
0cfc5b2c18
commit
958ab92dc1
2
test.py
2
test.py
@ -88,7 +88,7 @@ def test(data,
|
|||||||
if device.type != 'cpu':
|
if device.type != 'cpu':
|
||||||
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
|
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
|
||||||
task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images
|
task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images
|
||||||
dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True,
|
dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
|
||||||
prefix=colorstr(f'{task}: '))[0]
|
prefix=colorstr(f'{task}: '))[0]
|
||||||
|
|
||||||
seen = 0
|
seen = 0
|
||||||
|
17
train.py
17
train.py
@ -41,8 +41,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def train(hyp, opt, device, tb_writer=None):
|
def train(hyp, opt, device, tb_writer=None):
|
||||||
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
||||||
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
|
||||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
|
||||||
|
opt.single_cls
|
||||||
|
|
||||||
# Directories
|
# Directories
|
||||||
wdir = save_dir / 'weights'
|
wdir = save_dir / 'weights'
|
||||||
@ -75,8 +76,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
if wandb_logger.wandb:
|
if wandb_logger.wandb:
|
||||||
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
||||||
|
|
||||||
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
|
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
||||||
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
||||||
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
||||||
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
||||||
|
|
||||||
@ -187,7 +188,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
logger.info('Using SyncBatchNorm()')
|
logger.info('Using SyncBatchNorm()')
|
||||||
|
|
||||||
# Trainloader
|
# Trainloader
|
||||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
|
||||||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
|
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
|
||||||
world_size=opt.world_size, workers=opt.workers,
|
world_size=opt.world_size, workers=opt.workers,
|
||||||
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
||||||
@ -197,7 +198,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
|
|
||||||
# Process 0
|
# Process 0
|
||||||
if rank in [-1, 0]:
|
if rank in [-1, 0]:
|
||||||
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
|
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
|
||||||
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
||||||
world_size=opt.world_size, workers=opt.workers,
|
world_size=opt.world_size, workers=opt.workers,
|
||||||
pad=0.5, prefix=colorstr('val: '))[0]
|
pad=0.5, prefix=colorstr('val: '))[0]
|
||||||
@ -357,7 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
batch_size=batch_size * 2,
|
batch_size=batch_size * 2,
|
||||||
imgsz=imgsz_test,
|
imgsz=imgsz_test,
|
||||||
model=ema.ema,
|
model=ema.ema,
|
||||||
single_cls=opt.single_cls,
|
single_cls=single_cls,
|
||||||
dataloader=testloader,
|
dataloader=testloader,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
save_json=is_coco and final_epoch,
|
save_json=is_coco and final_epoch,
|
||||||
@ -429,7 +430,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
conf_thres=0.001,
|
conf_thres=0.001,
|
||||||
iou_thres=0.7,
|
iou_thres=0.7,
|
||||||
model=attempt_load(m, device).half(),
|
model=attempt_load(m, device).half(),
|
||||||
single_cls=opt.single_cls,
|
single_cls=single_cls,
|
||||||
dataloader=testloader,
|
dataloader=testloader,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
save_json=True,
|
save_json=True,
|
||||||
|
@ -62,8 +62,8 @@ def exif_size(img):
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
|
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
|
||||||
rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
|
rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
|
||||||
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
||||||
with torch_distributed_zero_first(rank):
|
with torch_distributed_zero_first(rank):
|
||||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
||||||
@ -71,7 +71,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|||||||
hyp=hyp, # augmentation hyperparameters
|
hyp=hyp, # augmentation hyperparameters
|
||||||
rect=rect, # rectangular training
|
rect=rect, # rectangular training
|
||||||
cache_images=cache,
|
cache_images=cache,
|
||||||
single_cls=opt.single_cls,
|
single_cls=single_cls,
|
||||||
stride=int(stride),
|
stride=int(stride),
|
||||||
pad=pad,
|
pad=pad,
|
||||||
image_weights=image_weights,
|
image_weights=image_weights,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user