[WIP] Feature/ddp fixed (#401)
* Squashed commit of the following: commit d738487089e41c22b3b1cd73aa7c1c40320a6ebf Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 14 17:33:38 2020 +0700 Adding world_size Reduce calls to torch.distributed. For use in create_dataloader. commit e742dd9619d29306c7541821238d3d7cddcdc508 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 14 15:38:48 2020 +0800 Make SyncBN a choice commit e90d4004387e6103fecad745f8cbc2edc918e906 Merge: 5bf8beb cd90360 Author: yzchen <Chenyzsjtu@gmail.com> Date: Tue Jul 14 15:32:10 2020 +0800 Merge pull request #6 from NanoCode012/patch-5 Update train.py commit cd9036017e7f8bd519a8b62adab0f47ea67f4962 Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 14 13:39:29 2020 +0700 Update train.py Remove redundant `opt.` prefix. commit 5bf8bebe8873afb18b762fe1f409aca116fac073 Merge: c9558a9pull/454/heada1c8406
Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 14 14:09:51 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit c9558a9b51547febb03d9c1ca42e2ef0fc15bb31 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 14 13:51:34 2020 +0800 Add device allocation for loss compute commit 4f08c692fb5e943a89e0ee354ef6c80a50eeb28d Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 9 11:16:27 2020 +0800 Revert drop_last commit 1dabe33a5a223b758cc761fc8741c6224205a34b Merge: a1ce9b1 4b8450b Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 9 11:15:49 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit a1ce9b1e96b71d7fcb9d3e8143013eb8cebe5e27 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 9 11:15:21 2020 +0800 fix lr warning commit 4b8450b46db76e5e58cd95df965d4736077cfb0e Merge: b9a50ae 02c63ef Author: yzchen <Chenyzsjtu@gmail.com> Date: Wed Jul 8 21:24:24 2020 +0800 Merge pull request #4 from NanoCode012/patch-4 Add drop_last for multi gpu commit 02c63ef81cf98b28b10344fe2cce08a03b143941 Author: NanoCode012 <kevinvong@rocketmail.com> Date: Wed Jul 8 10:08:30 2020 +0700 Add drop_last for multi gpu commit b9a50aed48ab1536f94d49269977e2accd67748f Merge: ec2dc6c121d90b
Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 7 19:48:04 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit ec2dc6cc56de43ddff939e14c450672d0fbf9b3d Merge: d0326e3 82a6182 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 7 19:34:31 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit d0326e398dfeeeac611ccc64198d4fe91b7aa969 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 7 19:31:24 2020 +0800 Add SyncBN commit 82a6182b3ad0689a4432b631b438004e5acb3b74 Merge: 96fa40a 050b2a5 Author: yzchen <Chenyzsjtu@gmail.com> Date: Tue Jul 7 19:21:01 2020 +0800 Merge pull request #1 from NanoCode012/patch-2 Convert BatchNorm to SyncBatchNorm commit 050b2a5a79a89c9405854d439a1f70f892139b1c Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 7 12:38:14 2020 +0700 Add cleanup for process_group commit 2aa330139f3cc1237aeb3132245ed7e5d6da1683 Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 7 12:07:40 2020 +0700 Remove apex.parallel. Use torch.nn.parallel For future compatibility commit 77c8e27e603bea9a69e7647587ca8d509dc1990d Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 7 01:54:39 2020 +0700 Convert BatchNorm to SyncBatchNorm commit 96fa40a3a925e4ffd815fe329e1b5181ec92adc8 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Mon Jul 6 21:53:56 2020 +0800 Fix the datset inconsistency problem commit 16e7c269d062c8d16c4d4ff70cc80fd87935dc95 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Mon Jul 6 11:34:03 2020 +0800 Add loss multiplication to preserver the single-process performance commit e83805563065ffd2e38f85abe008fc662cc17909 Merge: 625bb493bdea3f
Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Fri Jul 3 20:56:30 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit 625bb49f4e52d781143fea0af36d14e5be8b040c Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 2 22:45:15 2020 +0800 DDP established * Squashed commit of the following: commit 94147314e559a6bdd13cb9de62490d385c27596f Merge: 65157e237acbdc
Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 16 14:00:17 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov4 into feature/DDP_fixed commit37acbdc0b6
Author: Glenn Jocher <glenn.jocher@ultralytics.com> Date: Wed Jul 15 20:03:41 2020 -0700 update test.py --save-txt commitb8c2da4a0d
Author: Glenn Jocher <glenn.jocher@ultralytics.com> Date: Wed Jul 15 20:00:48 2020 -0700 update test.py --save-txt commit 65157e2fc97d371bc576e18b424e130eb3026917 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Wed Jul 15 16:44:13 2020 +0800 Revert the README.md removal commit 1c802bfa503623661d8617ca3f259835d27c5345 Merge: cd55b44 0f3b8bb Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Wed Jul 15 16:43:38 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit cd55b445c4dcd8003ff4b0b46b64adf7c16e5ce7 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Wed Jul 15 16:42:33 2020 +0800 fix the DDP performance deterioration bug. commit 0f3b8bb1fae5885474ba861bbbd1924fb622ee93 Author: Glenn Jocher <glenn.jocher@ultralytics.com> Date: Wed Jul 15 00:28:53 2020 -0700 Delete README.md commit f5921ba1e35475f24b062456a890238cb7a3cf94 Merge: 85ab2f3 bd3fdbb Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Wed Jul 15 11:20:17 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit bd3fdbbf1b08ef87931eef49fa8340621caa7e87 Author: Glenn Jocher <glenn.jocher@ultralytics.com> Date: Tue Jul 14 18:38:20 2020 -0700 Update README.md commit c1a97a7767ccb2aa9afc7a5e72fd159e7c62ec02 Merge: 2bf86b8f796708
Author: Glenn Jocher <glenn.jocher@ultralytics.com> Date: Tue Jul 14 18:36:53 2020 -0700 Merge branch 'master' into feature/DDP_fixed commit 2bf86b892fa2fd712f6530903a0d9b8533d7447a Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 14 22:18:15 2020 +0700 Fixed world_size not found when called from test commit 85ab2f38cdda28b61ad15a3a5a14c3aafb620dc8 Merge: 5a19011 c8357ad Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 14 22:19:58 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit 5a19011949398d06e744d8d5521ab4e6dfa06ab7 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 14 22:19:15 2020 +0800 Add assertion for <=2 gpus DDP commit c8357ad5b15a0e6aeef4d7fe67ca9637f7322a4d Merge: e742dd9 787582f Author: yzchen <Chenyzsjtu@gmail.com> Date: Tue Jul 14 22:10:02 2020 +0800 Merge pull request #8 from MagicFrogSJTU/NanoCode012-patch-1 Modify number of dataloaders' workers commit 787582f97251834f955ef05a77072b8c673a8397 Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 14 20:38:58 2020 +0700 Fixed issue with single gpu not having world_size commit 63648925288d63a21174a4dd28f92dbfebfeb75a Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 14 19:16:15 2020 +0700 Add assert message for clarification Clarify why assertion was thrown to users commit 69364d6050e048d0d8834e0f30ce84da3f6a13f3 Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 14 17:36:48 2020 +0700 Changed number of workers check commit d738487089e41c22b3b1cd73aa7c1c40320a6ebf Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 14 17:33:38 2020 +0700 Adding world_size Reduce calls to torch.distributed. For use in create_dataloader. commit e742dd9619d29306c7541821238d3d7cddcdc508 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 14 15:38:48 2020 +0800 Make SyncBN a choice commit e90d4004387e6103fecad745f8cbc2edc918e906 Merge: 5bf8beb cd90360 Author: yzchen <Chenyzsjtu@gmail.com> Date: Tue Jul 14 15:32:10 2020 +0800 Merge pull request #6 from NanoCode012/patch-5 Update train.py commit cd9036017e7f8bd519a8b62adab0f47ea67f4962 Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 14 13:39:29 2020 +0700 Update train.py Remove redundant `opt.` prefix. commit 5bf8bebe8873afb18b762fe1f409aca116fac073 Merge: c9558a9a1c8406
Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 14 14:09:51 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit c9558a9b51547febb03d9c1ca42e2ef0fc15bb31 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 14 13:51:34 2020 +0800 Add device allocation for loss compute commit 4f08c692fb5e943a89e0ee354ef6c80a50eeb28d Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 9 11:16:27 2020 +0800 Revert drop_last commit 1dabe33a5a223b758cc761fc8741c6224205a34b Merge: a1ce9b1 4b8450b Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 9 11:15:49 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit a1ce9b1e96b71d7fcb9d3e8143013eb8cebe5e27 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 9 11:15:21 2020 +0800 fix lr warning commit 4b8450b46db76e5e58cd95df965d4736077cfb0e Merge: b9a50ae 02c63ef Author: yzchen <Chenyzsjtu@gmail.com> Date: Wed Jul 8 21:24:24 2020 +0800 Merge pull request #4 from NanoCode012/patch-4 Add drop_last for multi gpu commit 02c63ef81cf98b28b10344fe2cce08a03b143941 Author: NanoCode012 <kevinvong@rocketmail.com> Date: Wed Jul 8 10:08:30 2020 +0700 Add drop_last for multi gpu commit b9a50aed48ab1536f94d49269977e2accd67748f Merge: ec2dc6c121d90b
Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 7 19:48:04 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit ec2dc6cc56de43ddff939e14c450672d0fbf9b3d Merge: d0326e3 82a6182 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 7 19:34:31 2020 +0800 Merge branch 'feature/DDP_fixed' of https://github.com/MagicFrogSJTU/yolov5 into feature/DDP_fixed commit d0326e398dfeeeac611ccc64198d4fe91b7aa969 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Tue Jul 7 19:31:24 2020 +0800 Add SyncBN commit 82a6182b3ad0689a4432b631b438004e5acb3b74 Merge: 96fa40a 050b2a5 Author: yzchen <Chenyzsjtu@gmail.com> Date: Tue Jul 7 19:21:01 2020 +0800 Merge pull request #1 from NanoCode012/patch-2 Convert BatchNorm to SyncBatchNorm commit 050b2a5a79a89c9405854d439a1f70f892139b1c Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 7 12:38:14 2020 +0700 Add cleanup for process_group commit 2aa330139f3cc1237aeb3132245ed7e5d6da1683 Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 7 12:07:40 2020 +0700 Remove apex.parallel. Use torch.nn.parallel For future compatibility commit 77c8e27e603bea9a69e7647587ca8d509dc1990d Author: NanoCode012 <kevinvong@rocketmail.com> Date: Tue Jul 7 01:54:39 2020 +0700 Convert BatchNorm to SyncBatchNorm commit 96fa40a3a925e4ffd815fe329e1b5181ec92adc8 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Mon Jul 6 21:53:56 2020 +0800 Fix the datset inconsistency problem commit 16e7c269d062c8d16c4d4ff70cc80fd87935dc95 Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Mon Jul 6 11:34:03 2020 +0800 Add loss multiplication to preserver the single-process performance commit e83805563065ffd2e38f85abe008fc662cc17909 Merge: 625bb493bdea3f
Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Fri Jul 3 20:56:30 2020 +0800 Merge branch 'master' of https://github.com/ultralytics/yolov5 into feature/DDP_fixed commit 625bb49f4e52d781143fea0af36d14e5be8b040c Author: yizhi.chen <chenyzsjtu@outlook.com> Date: Thu Jul 2 22:45:15 2020 +0800 DDP established * Fixed destroy_process_group in DP mode * Update torch_utils.py * Update utils.py Revert build_targets() to current master. * Update datasets.py * Fixed world_size attribute not found Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent
b6fe2e4595
commit
4102fcc9a7
330
train.py
330
train.py
|
@ -1,11 +1,13 @@
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
import torch.utils.data
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import test # import test.py to get mAP after each epoch
|
||||
from models.yolo import Model
|
||||
|
@ -42,7 +44,7 @@ hyp = {'optimizer': 'SGD', # ['adam', 'SGD', None] if none, default is SGD
|
|||
'shear': 0.0} # image shear (+/- deg)
|
||||
|
||||
|
||||
def train(hyp):
|
||||
def train(hyp, tb_writer, opt, device):
|
||||
print(f'Hyperparameters {hyp}')
|
||||
log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution' # run directory
|
||||
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory
|
||||
|
@ -59,11 +61,16 @@ def train(hyp):
|
|||
yaml.dump(vars(opt), f, sort_keys=False)
|
||||
|
||||
epochs = opt.epochs # 300
|
||||
batch_size = opt.batch_size # 64
|
||||
batch_size = opt.batch_size # batch size per process.
|
||||
total_batch_size = opt.total_batch_size
|
||||
weights = opt.weights # initial training weights
|
||||
local_rank = opt.local_rank
|
||||
|
||||
# TODO: Init DDP logging. Only the first process is allowed to log.
|
||||
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
|
||||
|
||||
# Configure
|
||||
init_seeds(1)
|
||||
init_seeds(2+local_rank)
|
||||
with open(opt.data) as f:
|
||||
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
||||
train_path = data_dict['train']
|
||||
|
@ -72,8 +79,9 @@ def train(hyp):
|
|||
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
||||
|
||||
# Remove previous results
|
||||
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
||||
os.remove(f)
|
||||
if local_rank in [-1, 0]:
|
||||
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
||||
os.remove(f)
|
||||
|
||||
# Create model
|
||||
model = Model(opt.cfg, nc=nc).to(device)
|
||||
|
@ -84,8 +92,15 @@ def train(hyp):
|
|||
|
||||
# Optimizer
|
||||
nbs = 64 # nominal batch size
|
||||
accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
|
||||
hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
|
||||
# the default DDP implementation is slow for accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
|
||||
# all-reduce operation is carried out during loss.backward().
|
||||
# Thus, there would be redundant all-reduce communications in a accumulation procedure,
|
||||
# which means, the result is still right but the training speed gets slower.
|
||||
# TODO: If acceleration is needed, there is an implementation of allreduce_post_accumulation
|
||||
# in https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/run_pretraining.py
|
||||
accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
|
||||
hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
|
||||
|
||||
pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
|
@ -106,13 +121,10 @@ def train(hyp):
|
|||
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
|
||||
del pg0, pg1, pg2
|
||||
|
||||
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
||||
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
||||
# plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)
|
||||
|
||||
# Load Model
|
||||
google_utils.attempt_download(weights)
|
||||
# Avoid multiple downloads.
|
||||
with torch_distributed_zero_first(local_rank):
|
||||
google_utils.attempt_download(weights)
|
||||
start_epoch, best_fitness = 0, 0.0
|
||||
if weights.endswith('.pt'): # pytorch format
|
||||
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
||||
|
@ -124,7 +136,7 @@ def train(hyp):
|
|||
except KeyError as e:
|
||||
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|
||||
"Please delete or update %s and try again, or use --weights '' to train from scratch." \
|
||||
% (opt.weights, opt.cfg, opt.weights, opt.weights)
|
||||
% (weights, opt.cfg, weights, weights)
|
||||
raise KeyError(s) from e
|
||||
|
||||
# load optimizer
|
||||
|
@ -141,7 +153,7 @@ def train(hyp):
|
|||
start_epoch = ckpt['epoch'] + 1
|
||||
if epochs < start_epoch:
|
||||
print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
|
||||
(opt.weights, ckpt['epoch'], epochs))
|
||||
(weights, ckpt['epoch'], epochs))
|
||||
epochs += ckpt['epoch'] # finetune additional epochs
|
||||
|
||||
del ckpt
|
||||
|
@ -150,25 +162,41 @@ def train(hyp):
|
|||
if mixed_precision:
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
|
||||
|
||||
# Distributed training
|
||||
if device.type != 'cpu' and torch.cuda.device_count() > 1 and dist.is_available():
|
||||
dist.init_process_group(backend='nccl', # distributed backend
|
||||
init_method='tcp://127.0.0.1:9999', # init method
|
||||
world_size=1, # number of nodes
|
||||
rank=0) # node rank
|
||||
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) # requires world_size > 1
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
||||
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
||||
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
|
||||
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
||||
|
||||
# DP mode
|
||||
if device.type != 'cpu' and local_rank == -1 and torch.cuda.device_count() > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Exponential moving average
|
||||
# From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
|
||||
# "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
|
||||
# chenyzsjtu: ema should be placed before after SyncBN. As SyncBN introduces new modules.
|
||||
if opt.sync_bn and device.type != 'cpu' and local_rank != -1:
|
||||
print("SyncBN activated!")
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
|
||||
ema = torch_utils.ModelEMA(model) if local_rank in [-1, 0] else None
|
||||
|
||||
# DDP mode
|
||||
if device.type != 'cpu' and local_rank != -1:
|
||||
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
|
||||
|
||||
# Trainloader
|
||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
||||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
||||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
|
||||
cache=opt.cache_images, rect=opt.rect, local_rank=local_rank, world_size=opt.world_size)
|
||||
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
||||
nb = len(dataloader) # number of batches
|
||||
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
||||
|
||||
# Testloader
|
||||
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
|
||||
hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
|
||||
if local_rank in [-1, 0]:
|
||||
# local_rank is set to -1. Because only the first process is expected to do evaluation.
|
||||
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
|
||||
cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
|
||||
|
||||
# Model parameters
|
||||
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
||||
|
@ -179,48 +207,63 @@ def train(hyp):
|
|||
model.names = names
|
||||
|
||||
# Class frequency
|
||||
labels = np.concatenate(dataset.labels, 0)
|
||||
c = torch.tensor(labels[:, 0]) # classes
|
||||
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
||||
# model._initialize_biases(cf.to(device))
|
||||
plot_labels(labels, save_dir=log_dir)
|
||||
if tb_writer:
|
||||
# tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
|
||||
tb_writer.add_histogram('classes', c, 0)
|
||||
|
||||
# Check anchors
|
||||
if not opt.noautoanchor:
|
||||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
||||
|
||||
# Exponential moving average
|
||||
ema = torch_utils.ModelEMA(model)
|
||||
# Only one check and log is needed.
|
||||
if local_rank in [-1, 0]:
|
||||
labels = np.concatenate(dataset.labels, 0)
|
||||
c = torch.tensor(labels[:, 0]) # classes
|
||||
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
||||
# model._initialize_biases(cf.to(device))
|
||||
plot_labels(labels, save_dir=log_dir)
|
||||
if tb_writer:
|
||||
# tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
|
||||
tb_writer.add_histogram('classes', c, 0)
|
||||
|
||||
# Check anchors
|
||||
if not opt.noautoanchor:
|
||||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
||||
# Start training
|
||||
t0 = time.time()
|
||||
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
||||
maps = np.zeros(nc) # mAP per class
|
||||
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
||||
scheduler.last_epoch = start_epoch - 1 # do not move
|
||||
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
|
||||
print('Using %g dataloader workers' % dataloader.num_workers)
|
||||
print('Starting training for %g epochs...' % epochs)
|
||||
if local_rank in [0, -1]:
|
||||
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
|
||||
print('Using %g dataloader workers' % dataloader.num_workers)
|
||||
print('Starting training for %g epochs...' % epochs)
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||
model.train()
|
||||
|
||||
# Update image weights (optional)
|
||||
# When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
|
||||
if dataset.image_weights:
|
||||
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
||||
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
||||
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
|
||||
# Generate indices.
|
||||
if local_rank in [-1, 0]:
|
||||
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
||||
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
||||
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
|
||||
# Broadcast.
|
||||
if local_rank != -1:
|
||||
indices = torch.zeros([dataset.n], dtype=torch.int)
|
||||
if local_rank == 0:
|
||||
indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int)
|
||||
dist.broadcast(indices, 0)
|
||||
if local_rank != 0:
|
||||
dataset.indices = indices.cpu().numpy()
|
||||
|
||||
# Update mosaic border
|
||||
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
|
||||
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
|
||||
|
||||
mloss = torch.zeros(4, device=device) # mean losses
|
||||
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
|
||||
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
||||
if local_rank != -1:
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
pbar = enumerate(dataloader)
|
||||
if local_rank in [-1, 0]:
|
||||
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
|
||||
pbar = tqdm(pbar, total=nb) # progress bar
|
||||
optimizer.zero_grad()
|
||||
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
||||
ni = i + nb * epoch # number integrated batches (since train start)
|
||||
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
|
||||
|
@ -229,7 +272,7 @@ def train(hyp):
|
|||
if ni <= nw:
|
||||
xi = [0, nw] # x interp
|
||||
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
|
||||
accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
|
||||
accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
|
||||
for j, x in enumerate(optimizer.param_groups):
|
||||
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||
x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
|
||||
|
@ -249,6 +292,9 @@ def train(hyp):
|
|||
|
||||
# Loss
|
||||
loss, loss_items = compute_loss(pred, targets.to(device), model)
|
||||
# loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
|
||||
if local_rank != -1:
|
||||
loss *= opt.world_size
|
||||
if not torch.isfinite(loss):
|
||||
print('WARNING: non-finite loss, ending training ', loss_items)
|
||||
return results
|
||||
|
@ -264,106 +310,110 @@ def train(hyp):
|
|||
if ni % accumulate == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
ema.update(model)
|
||||
if ema is not None:
|
||||
ema.update(model)
|
||||
|
||||
# Print
|
||||
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
||||
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
||||
s = ('%10s' * 2 + '%10.4g' * 6) % (
|
||||
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
|
||||
pbar.set_description(s)
|
||||
if local_rank in [-1, 0]:
|
||||
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
||||
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
||||
s = ('%10s' * 2 + '%10.4g' * 6) % (
|
||||
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
|
||||
pbar.set_description(s)
|
||||
|
||||
# Plot
|
||||
if ni < 3:
|
||||
f = str(Path(log_dir) / ('train_batch%g.jpg' % ni)) # filename
|
||||
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
|
||||
if tb_writer and result is not None:
|
||||
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
||||
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
||||
# Plot
|
||||
if ni < 3:
|
||||
f = str(Path(log_dir) / ('train_batch%g.jpg' % ni)) # filename
|
||||
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
|
||||
if tb_writer and result is not None:
|
||||
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
||||
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
||||
|
||||
# end batch ------------------------------------------------------------------------------------------------
|
||||
|
||||
# Scheduler
|
||||
scheduler.step()
|
||||
|
||||
# mAP
|
||||
ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
|
||||
final_epoch = epoch + 1 == epochs
|
||||
if not opt.notest or final_epoch: # Calculate mAP
|
||||
results, maps, times = test.test(opt.data,
|
||||
batch_size=batch_size,
|
||||
imgsz=imgsz_test,
|
||||
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
|
||||
model=ema.ema,
|
||||
single_cls=opt.single_cls,
|
||||
dataloader=testloader,
|
||||
save_dir=log_dir)
|
||||
# Only the first process in DDP mode is allowed to log or save checkpoints.
|
||||
if local_rank in [-1, 0]:
|
||||
# mAP
|
||||
if ema is not None:
|
||||
ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
|
||||
final_epoch = epoch + 1 == epochs
|
||||
if not opt.notest or final_epoch: # Calculate mAP
|
||||
results, maps, times = test.test(opt.data,
|
||||
batch_size=total_batch_size,
|
||||
imgsz=imgsz_test,
|
||||
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
|
||||
model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
|
||||
single_cls=opt.single_cls,
|
||||
dataloader=testloader,
|
||||
save_dir=log_dir)
|
||||
# Explicitly keep the shape.
|
||||
# Write
|
||||
with open(results_file, 'a') as f:
|
||||
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
|
||||
if len(opt.name) and opt.bucket:
|
||||
os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))
|
||||
|
||||
# Write
|
||||
with open(results_file, 'a') as f:
|
||||
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
|
||||
if len(opt.name) and opt.bucket:
|
||||
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
|
||||
# Tensorboard
|
||||
if tb_writer:
|
||||
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
|
||||
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
|
||||
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
|
||||
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
|
||||
tb_writer.add_scalar(tag, x, epoch)
|
||||
|
||||
# Tensorboard
|
||||
if tb_writer:
|
||||
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
|
||||
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
|
||||
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
|
||||
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
|
||||
tb_writer.add_scalar(tag, x, epoch)
|
||||
# Update best mAP
|
||||
fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
|
||||
if fi > best_fitness:
|
||||
best_fitness = fi
|
||||
|
||||
# Update best mAP
|
||||
fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
|
||||
if fi > best_fitness:
|
||||
best_fitness = fi
|
||||
|
||||
# Save model
|
||||
save = (not opt.nosave) or (final_epoch and not opt.evolve)
|
||||
if save:
|
||||
with open(results_file, 'r') as f: # create checkpoint
|
||||
ckpt = {'epoch': epoch,
|
||||
'best_fitness': best_fitness,
|
||||
'training_results': f.read(),
|
||||
'model': ema.ema,
|
||||
'optimizer': None if final_epoch else optimizer.state_dict()}
|
||||
|
||||
# Save last, best and delete
|
||||
torch.save(ckpt, last)
|
||||
if (best_fitness == fi) and not final_epoch:
|
||||
torch.save(ckpt, best)
|
||||
del ckpt
|
||||
# Save model
|
||||
save = (not opt.nosave) or (final_epoch and not opt.evolve)
|
||||
if save:
|
||||
with open(results_file, 'r') as f: # create checkpoint
|
||||
ckpt = {'epoch': epoch,
|
||||
'best_fitness': best_fitness,
|
||||
'training_results': f.read(),
|
||||
'model': ema.ema.module if hasattr(ema, 'module') else ema.ema,
|
||||
'optimizer': None if final_epoch else optimizer.state_dict()}
|
||||
|
||||
# Save last, best and delete
|
||||
torch.save(ckpt, last)
|
||||
if (best_fitness == fi) and not final_epoch:
|
||||
torch.save(ckpt, best)
|
||||
del ckpt
|
||||
# end epoch ----------------------------------------------------------------------------------------------------
|
||||
# end training
|
||||
|
||||
# Strip optimizers
|
||||
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
|
||||
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
|
||||
for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
|
||||
if os.path.exists(f1):
|
||||
os.rename(f1, f2) # rename
|
||||
ispt = f2.endswith('.pt') # is *.pt
|
||||
strip_optimizer(f2) if ispt else None # strip optimizer
|
||||
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
|
||||
if local_rank in [-1, 0]:
|
||||
# Strip optimizers
|
||||
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
|
||||
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
|
||||
for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
|
||||
if os.path.exists(f1):
|
||||
os.rename(f1, f2) # rename
|
||||
ispt = f2.endswith('.pt') # is *.pt
|
||||
strip_optimizer(f2) if ispt else None # strip optimizer
|
||||
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
|
||||
# Finish
|
||||
if not opt.evolve:
|
||||
plot_results() # save as results.png
|
||||
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||
|
||||
# Finish
|
||||
if not opt.evolve:
|
||||
plot_results(save_dir=log_dir) # save as results.png
|
||||
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
||||
dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
|
||||
dist.destroy_process_group() if local_rank not in [-1,0] else None
|
||||
torch.cuda.empty_cache()
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_git_status()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
|
||||
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
|
||||
parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
|
||||
parser.add_argument('--epochs', type=int, default=300)
|
||||
parser.add_argument('--batch-size', type=int, default=16)
|
||||
parser.add_argument('--batch-size', type=int, default=16, help="Total batch size for all gpus.")
|
||||
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
|
||||
parser.add_argument('--rect', action='store_true', help='rectangular training')
|
||||
parser.add_argument('--resume', nargs='?', const='get_last', default=False,
|
||||
|
@ -379,32 +429,54 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
||||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
||||
parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.")
|
||||
# Parameter For DDP.
|
||||
parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
|
||||
opt = parser.parse_args()
|
||||
|
||||
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
|
||||
if last and not opt.weights:
|
||||
print(f'Resuming training from {last}')
|
||||
opt.weights = last if opt.resume and not opt.weights else opt.weights
|
||||
if opt.local_rank in [-1, 0]:
|
||||
check_git_status()
|
||||
opt.cfg = check_file(opt.cfg) # check file
|
||||
opt.data = check_file(opt.data) # check file
|
||||
if opt.hyp: # update hyps
|
||||
opt.hyp = check_file(opt.hyp) # check file
|
||||
with open(opt.hyp) as f:
|
||||
hyp.update(yaml.load(f, Loader=yaml.FullLoader)) # update hyps
|
||||
print(opt)
|
||||
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
||||
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
|
||||
opt.total_batch_size = opt.batch_size
|
||||
opt.world_size = 1
|
||||
if device.type == 'cpu':
|
||||
mixed_precision = False
|
||||
elif opt.local_rank != -1:
|
||||
# DDP mode
|
||||
assert torch.cuda.device_count() > opt.local_rank
|
||||
torch.cuda.set_device(opt.local_rank)
|
||||
device = torch.device("cuda", opt.local_rank)
|
||||
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
|
||||
|
||||
opt.world_size = dist.get_world_size()
|
||||
assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
|
||||
opt.batch_size = opt.total_batch_size // opt.world_size
|
||||
print(opt)
|
||||
|
||||
# Train
|
||||
if not opt.evolve:
|
||||
tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
|
||||
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
|
||||
train(hyp)
|
||||
if opt.local_rank in [-1, 0]:
|
||||
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
|
||||
tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
|
||||
else:
|
||||
tb_writer = None
|
||||
train(hyp, tb_writer, opt, device)
|
||||
|
||||
# Evolve hyperparameters (optional)
|
||||
else:
|
||||
assert opt.local_rank == -1, "DDP mode currently not implemented for Evolve!"
|
||||
|
||||
tb_writer = None
|
||||
opt.notest, opt.nosave = True, True # only test/save final epoch
|
||||
if opt.bucket:
|
||||
|
@ -443,7 +515,7 @@ if __name__ == '__main__':
|
|||
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
||||
|
||||
# Train mutation
|
||||
results = train(hyp.copy())
|
||||
results = train(hyp.copy(), tb_writer, opt, device)
|
||||
|
||||
# Write mutation results
|
||||
print_mutation(hyp, results, opt.bucket)
|
||||
|
|
|
@ -14,7 +14,7 @@ from PIL import Image, ExifTags
|
|||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.utils import xyxy2xywh, xywh2xyxy
|
||||
from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
|
||||
|
||||
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
||||
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
|
||||
|
@ -46,21 +46,25 @@ def exif_size(img):
|
|||
return s
|
||||
|
||||
|
||||
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
|
||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
||||
augment=augment, # augment images
|
||||
hyp=hyp, # augmentation hyperparameters
|
||||
rect=rect, # rectangular training
|
||||
cache_images=cache,
|
||||
single_cls=opt.single_cls,
|
||||
stride=int(stride),
|
||||
pad=pad)
|
||||
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False, local_rank=-1, world_size=1):
|
||||
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache.
|
||||
with torch_distributed_zero_first(local_rank):
|
||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
||||
augment=augment, # augment images
|
||||
hyp=hyp, # augmentation hyperparameters
|
||||
rect=rect, # rectangular training
|
||||
cache_images=cache,
|
||||
single_cls=opt.single_cls,
|
||||
stride=int(stride),
|
||||
pad=pad)
|
||||
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
||||
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8]) # number of workers
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None
|
||||
dataloader = torch.utils.data.DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=nw,
|
||||
sampler=train_sampler,
|
||||
pin_memory=True,
|
||||
collate_fn=LoadImagesAndLabels.collate_fn)
|
||||
return dataloader, dataset
|
||||
|
@ -301,7 +305,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
f += glob.iglob(p + os.sep + '*.*')
|
||||
else:
|
||||
raise Exception('%s does not exist' % p)
|
||||
self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
|
||||
self.img_files = sorted([x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
|
||||
except Exception as e:
|
||||
raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import time
|
|||
from copy import copy
|
||||
from pathlib import Path
|
||||
from sys import platform
|
||||
from contextlib import contextmanager
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
|
@ -31,6 +32,18 @@ matplotlib.rc('font', **{'size': 11})
|
|||
cv2.setNumThreads(0)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_distributed_zero_first(local_rank: int):
|
||||
"""
|
||||
Decorator to make all processes in distributed training wait for each local_master to do something.
|
||||
"""
|
||||
if local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier()
|
||||
yield
|
||||
if local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
@ -424,15 +437,16 @@ class BCEBlurWithLogitsLoss(nn.Module):
|
|||
|
||||
|
||||
def compute_loss(p, targets, model): # predictions, targets, model
|
||||
device = targets.device
|
||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
||||
lcls, lbox, lobj = ft([0]).to(device), ft([0]).to(device), ft([0]).to(device)
|
||||
tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets
|
||||
h = model.hyp # hyperparameters
|
||||
red = 'mean' # Loss reduction (sum or mean)
|
||||
|
||||
# Define criteria
|
||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
|
||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)
|
||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red).to(device)
|
||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red).to(device)
|
||||
|
||||
# class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
||||
cp, cn = smooth_BCE(eps=0.0)
|
||||
|
@ -448,7 +462,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
balance = [1.0, 1.0, 1.0]
|
||||
for i, pi in enumerate(p): # layer index, layer predictions
|
||||
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
||||
tobj = torch.zeros_like(pi[..., 0]) # target obj
|
||||
tobj = torch.zeros_like(pi[..., 0]).to(device) # target obj
|
||||
|
||||
nb = b.shape[0] # number of targets
|
||||
if nb:
|
||||
|
@ -458,7 +472,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
# GIoU
|
||||
pxy = ps[:, :2].sigmoid() * 2. - 0.5
|
||||
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
|
||||
pbox = torch.cat((pxy, pwh), 1) # predicted box
|
||||
pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
|
||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou(prediction, target)
|
||||
lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss
|
||||
|
||||
|
@ -467,7 +481,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
|
||||
# Class
|
||||
if model.nc > 1: # cls loss (only if multiple classes)
|
||||
t = torch.full_like(ps[:, 5:], cn) # targets
|
||||
t = torch.full_like(ps[:, 5:], cn).to(device) # targets
|
||||
t[range(nb), tcls[i]] = cp
|
||||
lcls += BCEcls(ps[:, 5:], t) # BCE
|
||||
|
||||
|
|
Loading…
Reference in New Issue