Auto-format by https://ultralytics.com
parent
ed9b85f0d2
commit
6fe685f9aa
16
train.py
16
train.py
|
@ -95,12 +95,12 @@ from utils.torch_utils import (
|
|||
)
|
||||
|
||||
# version check
|
||||
#if torch.__version__.startswith("1.8"):
|
||||
# Autocast = torch.cuda.amp.autocast(enabled=amp)
|
||||
# GradScaler = torch.cuda.amp.GradScaler
|
||||
#else:
|
||||
# Autocast = torch.amp.autocast("cuda", enabled=amp)
|
||||
# GradScaler = torch.amp.GradScaler
|
||||
# if torch.__version__.startswith("1.8"):
|
||||
# Autocast = torch.cuda.amp.autocast(enabled=amp)
|
||||
# GradScaler = torch.cuda.amp.GradScaler
|
||||
# else:
|
||||
# Autocast = torch.amp.autocast("cuda", enabled=amp)
|
||||
# GradScaler = torch.amp.GradScaler
|
||||
|
||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv("RANK", -1))
|
||||
|
@ -121,7 +121,7 @@ def train(hyp, opt, device, callbacks):
|
|||
|
||||
Returns:
|
||||
None
|
||||
#
|
||||
#
|
||||
Models and datasets download automatically from the latest YOLOv5 release.
|
||||
|
||||
Example:
|
||||
|
@ -422,7 +422,7 @@ def train(hyp, opt, device, callbacks):
|
|||
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||
|
||||
# Forward
|
||||
#with Autocast:
|
||||
# with Autocast:
|
||||
amp_autocast = None
|
||||
if torch.__version__.startswith("1.8"):
|
||||
amp_autocast = torch.cuda.amp.autocast(enabled=amp)
|
||||
|
|
Loading…
Reference in New Issue