hyperparameter expansion to flips, perspective, mixup
parent
6f08e8bcce
commit
127cbeb3f5
23
train.py
23
train.py
|
@ -16,25 +16,29 @@ from utils.datasets import *
|
|||
from utils.utils import *
|
||||
|
||||
# Hyperparameters
|
||||
hyp = {'optimizer': 'SGD', # ['adam', 'SGD', None] if none, default is SGD
|
||||
hyp = {'optimizer': 'SGD', # ['Adam', 'SGD', ...] from torch.optim
|
||||
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||
'momentum': 0.937, # SGD momentum/Adam beta1
|
||||
'weight_decay': 5e-4, # optimizer weight decay
|
||||
'giou': 0.05, # giou loss gain
|
||||
'giou': 0.05, # GIoU loss gain
|
||||
'cls': 0.5, # cls loss gain
|
||||
'cls_pw': 1.0, # cls BCELoss positive_weight
|
||||
'obj': 1.0, # obj loss gain (*=img_size/320 if img_size != 320)
|
||||
'obj': 1.0, # obj loss gain (scale with pixels)
|
||||
'obj_pw': 1.0, # obj BCELoss positive_weight
|
||||
'iou_t': 0.20, # iou training threshold
|
||||
'iou_t': 0.20, # IoU training threshold
|
||||
'anchor_t': 4.0, # anchor-multiple threshold
|
||||
'fl_gamma': 0.0, # focal loss gamma (efficientDet default is gamma=1.5)
|
||||
'fl_gamma': 0.0, # focal loss gamma (EfficientDet default gamma=1.5)
|
||||
'hsv_h': 0.015, # image HSV-Hue augmentation (fraction)
|
||||
'hsv_s': 0.7, # image HSV-Saturation augmentation (fraction)
|
||||
'hsv_v': 0.4, # image HSV-Value augmentation (fraction)
|
||||
'degrees': 0.0, # image rotation (+/- deg)
|
||||
'translate': 0.5, # image translation (+/- fraction)
|
||||
'scale': 0.5, # image scale (+/- gain)
|
||||
'shear': 0.0} # image shear (+/- deg)
|
||||
'shear': 0.0, # image shear (+/- deg)
|
||||
'perspective': 0.0, # image perspective (+/- fraction), range 0-0.001
|
||||
'flipud': 0.0, # image flip up-down (probability)
|
||||
'fliplr': 0.5, # image flip left-right (probability)
|
||||
'mixup': 0.0} # image mixup (probability)
|
||||
|
||||
|
||||
def train(hyp, tb_writer, opt, device):
|
||||
|
@ -47,8 +51,7 @@ def train(hyp, tb_writer, opt, device):
|
|||
results_file = log_dir + os.sep + 'results.txt'
|
||||
epochs, batch_size, total_batch_size, weights, rank = \
|
||||
opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.local_rank
|
||||
# TODO: Init DDP logging. Only the first process is allowed to log.
|
||||
# Logging configuration is skipped for now because this code relies on many print calls; output may be repeated across processes.
|
||||
# TODO: Use DDP logging. Only the first process is allowed to log.
|
||||
|
||||
# Save run settings
|
||||
with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
|
||||
|
@ -99,7 +102,7 @@ def train(hyp, tb_writer, opt, device):
|
|||
else:
|
||||
pg0.append(v) # all else
|
||||
|
||||
if hyp['optimizer'] == 'adam': # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
|
||||
if hyp['optimizer'] == 'Adam':
|
||||
optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
||||
else:
|
||||
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
||||
|
@ -110,9 +113,9 @@ def train(hyp, tb_writer, opt, device):
|
|||
del pg0, pg1, pg2
|
||||
|
||||
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
||||
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
|
||||
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
||||
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
|
||||
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
||||
|
||||
# Load Model
|
||||
|
|
|
@ -484,11 +484,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
shapes = None
|
||||
|
||||
# MixUp https://arxiv.org/pdf/1710.09412.pdf
|
||||
# if random.random() < 0.5:
|
||||
# img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
|
||||
# r = np.random.beta(0.3, 0.3) # mixup ratio, alpha=beta=0.3
|
||||
# img = (img * r + img2 * (1 - r)).astype(np.uint8)
|
||||
# labels = np.concatenate((labels, labels2), 0)
|
||||
if random.random() < hyp['mixup']:
|
||||
img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
|
||||
r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
|
||||
img = (img * r + img2 * (1 - r)).astype(np.uint8)
|
||||
labels = np.concatenate((labels, labels2), 0)
|
||||
|
||||
else:
|
||||
# Load image
|
||||
|
@ -517,7 +517,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
degrees=hyp['degrees'],
|
||||
translate=hyp['translate'],
|
||||
scale=hyp['scale'],
|
||||
shear=hyp['shear'])
|
||||
shear=hyp['shear'],
|
||||
perspective=hyp['perspective'])
|
||||
|
||||
# Augment colorspace
|
||||
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
|
||||
|
@ -528,28 +529,23 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
|
||||
nL = len(labels) # number of labels
|
||||
if nL:
|
||||
# convert xyxy to xywh
|
||||
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])
|
||||
|
||||
# Normalize coordinates 0 - 1
|
||||
labels[:, [2, 4]] /= img.shape[0] # height
|
||||
labels[:, [1, 3]] /= img.shape[1] # width
|
||||
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh
|
||||
labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1
|
||||
labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1
|
||||
|
||||
if self.augment:
|
||||
# random left-right flip
|
||||
lr_flip = True
|
||||
if lr_flip and random.random() < 0.5:
|
||||
img = np.fliplr(img)
|
||||
if nL:
|
||||
labels[:, 1] = 1 - labels[:, 1]
|
||||
|
||||
# random up-down flip
|
||||
ud_flip = False
|
||||
if ud_flip and random.random() < 0.5:
|
||||
# flip up-down
|
||||
if random.random() < hyp['flipud']:
|
||||
img = np.flipud(img)
|
||||
if nL:
|
||||
labels[:, 2] = 1 - labels[:, 2]
|
||||
|
||||
# flip left-right
|
||||
if random.random() < hyp['fliplr']:
|
||||
img = np.fliplr(img)
|
||||
if nL:
|
||||
labels[:, 1] = 1 - labels[:, 1]
|
||||
|
||||
labels_out = torch.zeros((nL, 6))
|
||||
if nL:
|
||||
labels_out[:, 1:] = torch.from_numpy(labels)
|
||||
|
@ -661,6 +657,7 @@ def load_mosaic(self, index):
|
|||
translate=self.hyp['translate'],
|
||||
scale=self.hyp['scale'],
|
||||
shear=self.hyp['shear'],
|
||||
perspective=self.hyp['perspective'],
|
||||
border=self.mosaic_border) # border to remove
|
||||
|
||||
return img4, labels4
|
||||
|
|
Loading…
Reference in New Issue