Update train.py forward simplification

pull/860/head
Glenn Jocher 2020-08-25 13:48:03 -07:00
parent 455f7b8f76
commit a21bd0687c
1 changed files with 3 additions and 9 deletions

View File

@ -265,18 +265,12 @@ def train(hyp, opt, device, tb_writer=None):
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
# Autocast
# Forward
with amp.autocast(enabled=cuda):
# Forward
pred = model(imgs)
# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
# if not torch.isfinite(loss):
# logger.info('WARNING: non-finite loss, ending training ', loss_items)
# return results
# Backward
scaler.scale(loss).backward()