Update train.py forward simplification
parent 455f7b8f76
commit a21bd0687c

train.py: 12 changed lines
@@ -265,18 +265,12 @@ def train(hyp, opt, device, tb_writer=None):
                     ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                     imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
 
-            # Autocast
-            with amp.autocast(enabled=cuda):
-                # Forward
-                pred = model(imgs)
-
-                # Loss
-                loss, loss_items = compute_loss(pred, targets.to(device), model)  # scaled by batch_size
+            # Forward
+            with amp.autocast(enabled=cuda):
+                pred = model(imgs)  # forward
+                loss, loss_items = compute_loss(pred, targets.to(device), model)  # loss scaled by batch_size
                 if rank != -1:
                     loss *= opt.world_size  # gradient averaged between devices in DDP mode
-                # if not torch.isfinite(loss):
-                #     logger.info('WARNING: non-finite loss, ending training ', loss_items)
-                #     return results
 
             # Backward
             scaler.scale(loss).backward()
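The change is cosmetic: the "# Autocast" and "# Loss" comments are folded into "# Forward" and inline comments, and the commented-out non-finite-loss check is dropped, so runtime behavior is unchanged. For reference, below is a minimal, self-contained sketch of the training-step pattern this hunk converges to (multi-scale resize to a gs-multiple, mixed-precision forward under amp.autocast, scaled backward via GradScaler). The stand-in model, optimizer, dummy batch, size draw and simplified loss are illustrative assumptions, not the actual train.py code; compute_loss, the DDP world_size scaling and gradient accumulation are only indicated in comments.

# Minimal sketch of the simplified training step (PyTorch AMP); names not in the
# diff (Net stand-in, optimizer, dummy batch, size draw) are assumptions.
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import amp

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
gs = 32  # grid size (max stride); image sizes stay multiples of this

model = nn.Conv2d(3, 8, 3, padding=1).to(device)  # stand-in for the detection model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = amp.GradScaler(enabled=cuda)

imgs = torch.rand(2, 3, 320, 320, device=device)  # dummy batch

# Multi-scale: stretch the batch to a random gs-multiple shape (context lines of the hunk)
sz = random.randrange(int(320 * 0.5), int(320 * 1.5) + gs) // gs * gs  # assumed size draw
sf = sz / max(imgs.shape[2:])  # scale factor
if sf != 1:
    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

# Forward
with amp.autocast(enabled=cuda):
    pred = model(imgs)  # forward
    loss = pred.mean()  # stand-in for compute_loss(pred, targets.to(device), model)
    # In DDP mode train.py also does: loss *= opt.world_size  (gradients averaged between devices)

# Backward
scaler.scale(loss).backward()

# Optimize (train.py steps only every `accumulate` batches; shown unconditionally here)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()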