Update train.py forward simplification
parent
455f7b8f76
commit
a21bd0687c
12
train.py
12
train.py
|
@ -265,18 +265,12 @@ def train(hyp, opt, device, tb_writer=None):
|
|||
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
|
||||
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
|
||||
|
||||
# Autocast
|
||||
# Forward
|
||||
with amp.autocast(enabled=cuda):
|
||||
# Forward
|
||||
pred = model(imgs)
|
||||
|
||||
# Loss
|
||||
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
|
||||
pred = model(imgs) # forward
|
||||
loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
|
||||
if rank != -1:
|
||||
loss *= opt.world_size # gradient averaged between devices in DDP mode
|
||||
# if not torch.isfinite(loss):
|
||||
# logger.info('WARNING: non-finite loss, ending training ', loss_items)
|
||||
# return results
|
||||
|
||||
# Backward
|
||||
scaler.scale(loss).backward()
|
||||
|
|
Loading…
Reference in New Issue