mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
refactored
This commit is contained in:
commit
c4ec77fa98
@ -28,10 +28,8 @@ import torch.distributed as dist
|
|||||||
import torch.hub as hub
|
import torch.hub as hub
|
||||||
import torch.optim.lr_scheduler as lr_scheduler
|
import torch.optim.lr_scheduler as lr_scheduler
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch.cuda import amp
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
FILE = Path(__file__).resolve()
|
FILE = Path(__file__).resolve()
|
||||||
ROOT = FILE.parents[1] # YOLOv5 root directory
|
ROOT = FILE.parents[1] # YOLOv5 root directory
|
||||||
if str(ROOT) not in sys.path:
|
if str(ROOT) not in sys.path:
|
||||||
|
@ -90,7 +90,6 @@ from utils.torch_utils import (
|
|||||||
torch_distributed_zero_first,
|
torch_distributed_zero_first,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
RANK = int(os.getenv("RANK", -1))
|
RANK = int(os.getenv("RANK", -1))
|
||||||
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
||||||
@ -388,7 +387,7 @@ def train(hyp, opt, device, callbacks):
|
|||||||
else:
|
else:
|
||||||
amp_autocast = torch.amp.autocast("cuda", enabled=amp)
|
amp_autocast = torch.amp.autocast("cuda", enabled=amp)
|
||||||
# Forward
|
# Forward
|
||||||
with amp_autocast:
|
with amp_autocast:
|
||||||
pred = model(imgs) # forward
|
pred = model(imgs) # forward
|
||||||
loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
|
loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
|
||||||
if RANK != -1:
|
if RANK != -1:
|
||||||
|
1
train.py
1
train.py
@ -95,7 +95,6 @@ from utils.torch_utils import (
|
|||||||
torch_distributed_zero_first,
|
torch_distributed_zero_first,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
RANK = int(os.getenv("RANK", -1))
|
RANK = int(os.getenv("RANK", -1))
|
||||||
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user