diff --git a/train.py b/train.py
index 0b703e9ba..535993acf 100644
--- a/train.py
+++ b/train.py
@@ -352,15 +352,7 @@ def train(hyp, opt, device, callbacks):
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
-<<<<<<< HEAD
-    scaler = torch.amp.GradScaler(enabled=device.type != "cpu")
-=======
-
-    # checking if autocast is available
-    device_amp = torch.is_autocast_available(device_type=device.type)
-
-    scaler = torch.amp.GradScaler(enabled=(device_amp and device.type != "cpu"))
->>>>>>> 5d03fd8cdd44ce49148653ba4ea874d9cd41a832
+    scaler = torch.amp.GradScaler("cuda", enabled=amp)  # updated
     stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run("on_train_start")
@@ -417,7 +409,7 @@ def train(hyp, opt, device, callbacks):
                     imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
 
             # Forward
-            with torch.amp.autocast(device_type=device.type, enabled=True):
+            with torch.amp.autocast("cuda", enabled=amp):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                 if RANK != -1:
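
For context, the patch resolves the merge conflict by standardizing on the torch.amp API: one GradScaler gated by an `amp` flag (defined earlier in train(), outside this hunk), and an autocast context around the forward pass. Below is a minimal, self-contained sketch of that autocast + GradScaler pattern; the model, optimizer, data, and the local `amp` definition are illustrative stand-ins, not YOLOv5 code.

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
amp = device.type == "cuda"  # stand-in for the `amp` flag used in train()

model = nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.amp.GradScaler("cuda", enabled=amp)  # no-op when amp is False

for _ in range(3):  # toy training loop
    imgs = torch.randn(4, 10, device=device)
    targets = torch.randn(4, 1, device=device)

    # Forward pass runs in reduced precision when amp is True
    with torch.amp.autocast("cuda", enabled=amp):
        loss = nn.functional.mse_loss(model(imgs), targets)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)  # unscales grads; skips the step if inf/NaN is found
    scaler.update()

Because both GradScaler and autocast accept an `enabled` flag, the same code path works on CPU-only machines: with enabled=False, both become no-ops and training proceeds in full precision, which is what makes the single gated line in the patch preferable to the per-device branching that the conflict markers show.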