From c29e1d62e31ea8b49d050b0e162d2fde4fe06436 Mon Sep 17 00:00:00 2001 From: Shubham Phapale <94707673+ShubhamPhapale@users.noreply.github.com> Date: Wed, 22 Jan 2025 22:06:30 +0530 Subject: [PATCH] Update train.py updated deprecated call Signed-off-by: Shubham Phapale <94707673+ShubhamPhapale@users.noreply.github.com> --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 1401ccb96..babbc4c9f 100644 --- a/train.py +++ b/train.py @@ -409,7 +409,7 @@ def train(hyp, opt, device, callbacks): imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) # Forward - with torch.cuda.amp.autocast(amp): + with torch.amp.autocast('cuda', amp): pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size if RANK != -1: