Update train.py

updated deprecated call

Signed-off-by: Shubham Phapale <94707673+ShubhamPhapale@users.noreply.github.com>
pull/13497/head
Shubham Phapale 2025-01-22 22:06:30 +05:30 committed by GitHub
parent de62f93c21
commit c29e1d62e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
# Forward
with torch.cuda.amp.autocast(amp):
with torch.amp.autocast('cuda', amp):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1: