Auto-format by https://ultralytics.com
parent b65447a098
commit 94edd0785f
@@ -48,12 +48,13 @@ from utils.general import (
 )
 from utils.torch_utils import select_device, smart_inference_mode
 
-#version check
+# version check
 if torch.__version__.startswith("1.8"):
     Autocast = torch.cuda.amp.autocast
 else:
     Autocast = torch.amp.autocast
 
 
 @smart_inference_mode()
 def run(
     data=ROOT / "../datasets/mnist",  # dataset dir
@@ -55,12 +55,13 @@ from utils.general import (
 )
 from utils.torch_utils import copy_attr, smart_inference_mode
 
 # version check
 if torch.__version__.startswith("1.8"):
     Autocast = torch.cuda.amp.autocast
 else:
     Autocast = torch.amp.autocast
 
 
 def autopad(k, p=None, d=1):
     """
     Pads kernel to 'same' output shape, adjusting for optional dilation; returns padding size.
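Both hunks above gate the autocast import once at module load. As a side note, here is a minimal standalone sketch of that pattern, assuming a CUDA target; the use of functools.partial to give both branches the same call signature is an illustrative addition, not code from this commit:

import functools

import torch

# PyTorch 1.8 only ships torch.cuda.amp.autocast(enabled=...), while newer
# releases expose torch.amp.autocast(device_type, enabled=...). Binding the
# device type up front lets callers use one signature in both cases.
if torch.__version__.startswith("1.8"):
    Autocast = torch.cuda.amp.autocast
else:
    Autocast = functools.partial(torch.amp.autocast, "cuda")

with Autocast(enabled=False):  # enabled=False keeps this runnable on CPU-only machines
    _ = torch.ones(2, 2) * 3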
@@ -387,11 +387,10 @@ def train(hyp, opt, device, callbacks):
                     ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                     imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
 
-            amp_autocast = None
             if torch.__version__.startswith("1.8"):
-                amp_autocast = torch.cuda.amp.autocast(enabled=amp)
+                torch.cuda.amp.autocast(enabled=amp)
             else:
-                amp_autocast = torch.amp.autocast("cuda", enabled=amp)
+                torch.amp.autocast("cuda", enabled=amp)
             # Forward
             with Autocast(enabled=amp):
                 pred = model(imgs)  # forward
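For context on the train.py hunk, the version-gated Autocast alias is what ends up wrapping the forward pass. A self-contained usage sketch, assuming mixed precision is enabled only when a GPU is present; the functools.partial binding, the tiny Conv2d model, and the random batch are illustrative placeholders rather than repository code:

import functools

import torch
import torch.nn as nn

# Same version gate as in the diff, normalized to one call signature.
if torch.__version__.startswith("1.8"):
    Autocast = torch.cuda.amp.autocast
else:
    Autocast = functools.partial(torch.amp.autocast, "cuda")

amp = torch.cuda.is_available()  # enable mixed precision only on GPU
device = "cuda" if amp else "cpu"
model = nn.Conv2d(3, 16, 3, padding=1).to(device)  # placeholder model
imgs = torch.randn(2, 3, 64, 64, device=device)  # placeholder batch

# Forward pass runs under autocast when amp is True, mirroring the hunk above.
with Autocast(enabled=amp):
    pred = model(imgs)
print(pred.dtype)  # torch.float16 with amp on a GPU, torch.float32 otherwise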