pull/13483/head
UltralyticsAssistant 2025-02-23 14:49:51 +00:00
parent b65447a098
commit 94edd0785f
3 changed files with 7 additions and 6 deletions

View File

@@ -48,12 +48,13 @@ from utils.general import (
)
from utils.torch_utils import select_device, smart_inference_mode
#version check
# version check
if torch.__version__.startswith("1.8"):
Autocast = torch.cuda.amp.autocast
Autocast = torch.cuda.amp.autocast
else:
Autocast = torch.amp.autocast
@smart_inference_mode()
def run(
data=ROOT / "../datasets/mnist", # dataset dir

View File

@@ -55,12 +55,13 @@ from utils.general import (
)
from utils.torch_utils import copy_attr, smart_inference_mode
# version check
# version check
if torch.__version__.startswith("1.8"):
Autocast = torch.cuda.amp.autocast
else:
Autocast = torch.amp.autocast
def autopad(k, p=None, d=1):
"""
Pads kernel to 'same' output shape, adjusting for optional dilation; returns padding size.

View File

@@ -387,11 +387,10 @@ def train(hyp, opt, device, callbacks):
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
amp_autocast = None
if torch.__version__.startswith("1.8"):
amp_autocast = torch.cuda.amp.autocast(enabled=amp)
torch.cuda.amp.autocast(enabled=amp)
else:
amp_autocast = torch.amp.autocast("cuda", enabled=amp)
torch.amp.autocast("cuda", enabled=amp)
# Forward
with Autocast(enabled=amp):
pred = model(imgs) # forward