Update train.py
updated deprecated call Signed-off-by: Shubham Phapale <94707673+ShubhamPhapale@users.noreply.github.com>pull/13497/head
parent
de62f93c21
commit
c29e1d62e3
2
train.py
2
train.py
|
@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
|
||||||
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
with torch.cuda.amp.autocast(amp):
|
with torch.amp.autocast('cuda', amp):
|
||||||
pred = model(imgs) # forward
|
pred = model(imgs) # forward
|
||||||
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
||||||
if RANK != -1:
|
if RANK != -1:
|
||||||
|
|
Loading…
Reference in New Issue