branch: refactored
commit c4ec77fa98
@@ -28,10 +28,8 @@ import torch.distributed as dist
 import torch.hub as hub
 import torch.optim.lr_scheduler as lr_scheduler
 import torchvision
 from torch.cuda import amp
 from tqdm import tqdm

 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]  # YOLOv5 root directory
 if str(ROOT) not in sys.path:
@@ -90,7 +90,6 @@ from utils.torch_utils import (
     torch_distributed_zero_first,
 )

 LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv("RANK", -1))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
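Note on the unchanged `LOCAL_RANK`/`RANK`/`WORLD_SIZE` context lines above: these are the environment variables that torchrun exports to every worker process. A minimal sketch of how a script can consume them; the backend choice at the end is an assumption for illustration, not part of this diff:

```python
import os

import torch.distributed as dist

# torchrun (https://pytorch.org/docs/stable/elastic/run.html) exports these
# per worker process; the fallbacks (-1 / 1) mean "single-process run".
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # GPU index on this node
RANK = int(os.getenv("RANK", -1))              # global process index
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))   # total number of processes

if LOCAL_RANK != -1:
    # Assumed initialization (not shown in the diff): NCCL when available, else gloo.
    dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
```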
@@ -388,7 +387,7 @@ def train(hyp, opt, device, callbacks):
             else:
                 amp_autocast = torch.amp.autocast("cuda", enabled=amp)
             # Forward
             with amp_autocast:
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
                 if RANK != -1: