From a21bd0687ccccd2f7291b7ac41daca77d3cb8c12 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 25 Aug 2020 13:48:03 -0700 Subject: [PATCH] Update train.py forward simplification --- train.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index 52ea92752..8ac98df19 100644 --- a/train.py +++ b/train.py @@ -265,18 +265,12 @@ def train(hyp, opt, device, tb_writer=None): ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple) imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) - # Autocast + # Forward with amp.autocast(enabled=cuda): - # Forward - pred = model(imgs) - - # Loss - loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size + pred = model(imgs) # forward + loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size if rank != -1: loss *= opt.world_size # gradient averaged between devices in DDP mode - # if not torch.isfinite(loss): - # logger.info('WARNING: non-finite loss, ending training ', loss_items) - # return results # Backward scaler.scale(loss).backward()