opt.img_weights bug fix (#885)

pull/898/head
Glenn Jocher 2020-08-31 10:33:07 -07:00
parent 987c226849
commit 69ff781ca5
1 changed file with 8 additions and 10 deletions

View File

@@ -216,18 +216,15 @@ def train(hyp, opt, device, tb_writer=None):
model.train() model.train()
# Update image weights (optional) # Update image weights (optional)
if dataset.image_weights: if opt.img_weights:
# Generate indices # Generate indices
if rank in [-1, 0]: if rank in [-1, 0]:
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
dataset.indices = random.choices(range(dataset.n), weights=image_weights, dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
k=dataset.n) # rand weighted idx
# Broadcast if DDP # Broadcast if DDP
if rank != -1: if rank != -1:
indices = torch.zeros([dataset.n], dtype=torch.int) indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
if rank == 0:
indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
dist.broadcast(indices, 0) dist.broadcast(indices, 0)
if rank != 0: if rank != 0:
dataset.indices = indices.cpu().numpy() dataset.indices = indices.cpu().numpy()
@@ -388,7 +385,8 @@ if __name__ == '__main__':
parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml') parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
parser.add_argument('--epochs', type=int, default=300) parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs') parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes') parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
parser.add_argument('--img-weights', action='store_true', help='use weighted image selection for training')
parser.add_argument('--rect', action='store_true', help='rectangular training') parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
@@ -471,7 +469,7 @@ if __name__ == '__main__':
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (0, 0.1, 0.7), # IoU training threshold 'iou_t': (0, 0.1, 0.7), # IoU training threshold
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
# 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore) # 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5) 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction) 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction) 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)