refractored the code
parent
5cdad8922c
commit
a63bfd38c1
|
@ -27,9 +27,16 @@ import torch.distributed as dist
|
||||||
import torch.hub as hub
|
import torch.hub as hub
|
||||||
import torch.optim.lr_scheduler as lr_scheduler
|
import torch.optim.lr_scheduler as lr_scheduler
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch.cuda import amp
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# version check
|
||||||
|
if torch.__version__.startswith("1.8"):
|
||||||
|
Autocast = torch.cuda.amp.autocast
|
||||||
|
GradScaler = torch.cuda.amp.GradScaler
|
||||||
|
else:
|
||||||
|
Autocast = torch.amp.autocast
|
||||||
|
GradScaler = torch.amp.GradScaler
|
||||||
|
|
||||||
FILE = Path(__file__).resolve()
|
FILE = Path(__file__).resolve()
|
||||||
ROOT = FILE.parents[1] # YOLOv5 root directory
|
ROOT = FILE.parents[1] # YOLOv5 root directory
|
||||||
if str(ROOT) not in sys.path:
|
if str(ROOT) not in sys.path:
|
||||||
|
@ -198,7 +205,7 @@ def train(opt, device):
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing) # loss function
|
criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing) # loss function
|
||||||
best_fitness = 0.0
|
best_fitness = 0.0
|
||||||
scaler = amp.GradScaler(enabled=cuda)
|
scaler = GradScaler(enabled=cuda)
|
||||||
val = test_dir.stem # 'val' or 'test'
|
val = test_dir.stem # 'val' or 'test'
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
f"Image sizes {imgsz} train, {imgsz} test\n"
|
f"Image sizes {imgsz} train, {imgsz} test\n"
|
||||||
|
@ -219,7 +226,7 @@ def train(opt, device):
|
||||||
images, labels = images.to(device, non_blocking=True), labels.to(device)
|
images, labels = images.to(device, non_blocking=True), labels.to(device)
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
with amp.autocast(enabled=cuda): # stability issues when enabled
|
with Autocast(enabled=device.type != "cpu"): # stability issues when enabled
|
||||||
loss = criterion(model(images), labels)
|
loss = criterion(model(images), labels)
|
||||||
|
|
||||||
# Backward
|
# Backward
|
||||||
|
|
|
@ -48,6 +48,11 @@ from utils.general import (
|
||||||
)
|
)
|
||||||
from utils.torch_utils import select_device, smart_inference_mode
|
from utils.torch_utils import select_device, smart_inference_mode
|
||||||
|
|
||||||
|
#version check
|
||||||
|
if torch.__version__.startswith("1.8"):
|
||||||
|
Autocast = torch.cuda.amp.autocast
|
||||||
|
else:
|
||||||
|
Autocast = torch.amp.autocast
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def run(
|
def run(
|
||||||
|
@ -108,7 +113,7 @@ def run(
|
||||||
action = "validating" if dataloader.dataset.root.stem == "val" else "testing"
|
action = "validating" if dataloader.dataset.root.stem == "val" else "testing"
|
||||||
desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
|
desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
|
||||||
bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0)
|
bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0)
|
||||||
with torch.cuda.amp.autocast(enabled=device.type != "cpu"):
|
with Autocast(enabled=device.type != "cpu"):
|
||||||
for images, labels in bar:
|
for images, labels in bar:
|
||||||
with dt[0]:
|
with dt[0]:
|
||||||
images, labels = images.to(device, non_blocking=True), labels.to(device)
|
images, labels = images.to(device, non_blocking=True), labels.to(device)
|
||||||
|
|
|
@ -20,7 +20,6 @@ import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.cuda import amp
|
|
||||||
|
|
||||||
# Import 'ultralytics' package or install if missing
|
# Import 'ultralytics' package or install if missing
|
||||||
try:
|
try:
|
||||||
|
@ -56,6 +55,11 @@ from utils.general import (
|
||||||
)
|
)
|
||||||
from utils.torch_utils import copy_attr, smart_inference_mode
|
from utils.torch_utils import copy_attr, smart_inference_mode
|
||||||
|
|
||||||
|
# version check
|
||||||
|
if torch.__version__.startswith("1.8"):
|
||||||
|
Autocast = torch.cuda.amp.autocast
|
||||||
|
else:
|
||||||
|
Autocast = torch.amp.autocast
|
||||||
|
|
||||||
def autopad(k, p=None, d=1):
|
def autopad(k, p=None, d=1):
|
||||||
"""
|
"""
|
||||||
|
@ -864,7 +868,7 @@ class AutoShape(nn.Module):
|
||||||
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
||||||
autocast = self.amp and (p.device.type != "cpu") # Automatic Mixed Precision (AMP) inference
|
autocast = self.amp and (p.device.type != "cpu") # Automatic Mixed Precision (AMP) inference
|
||||||
if isinstance(ims, torch.Tensor): # torch
|
if isinstance(ims, torch.Tensor): # torch
|
||||||
with amp.autocast(autocast):
|
with Autocast(enabled=autocast):
|
||||||
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
|
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
|
||||||
|
|
||||||
# Pre-process
|
# Pre-process
|
||||||
|
@ -891,7 +895,7 @@ class AutoShape(nn.Module):
|
||||||
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
||||||
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
||||||
|
|
||||||
with amp.autocast(autocast):
|
with Autocast(enabled=autocast):
|
||||||
# Inference
|
# Inference
|
||||||
with dt[1]:
|
with dt[1]:
|
||||||
y = self.model(x, augment=augment) # forward
|
y = self.model(x, augment=augment) # forward
|
||||||
|
|
|
@ -89,6 +89,14 @@ from utils.torch_utils import (
|
||||||
torch_distributed_zero_first,
|
torch_distributed_zero_first,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# version check
|
||||||
|
if torch.__version__.startswith("1.8"):
|
||||||
|
Autocast = torch.cuda.amp.autocast
|
||||||
|
GradScaler = torch.cuda.amp.GradScaler
|
||||||
|
else:
|
||||||
|
Autocast = torch.amp.autocast
|
||||||
|
GradScaler = torch.amp.GradScaler
|
||||||
|
|
||||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
RANK = int(os.getenv("RANK", -1))
|
RANK = int(os.getenv("RANK", -1))
|
||||||
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
||||||
|
@ -320,7 +328,7 @@ def train(hyp, opt, device, callbacks):
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
||||||
scheduler.last_epoch = start_epoch - 1 # do not move
|
scheduler.last_epoch = start_epoch - 1 # do not move
|
||||||
scaler = torch.cuda.amp.GradScaler(enabled=amp)
|
scaler = GradScaler(enabled=amp)
|
||||||
stopper, stop = EarlyStopping(patience=opt.patience), False
|
stopper, stop = EarlyStopping(patience=opt.patience), False
|
||||||
compute_loss = ComputeLoss(model, overlap=overlap) # init loss class
|
compute_loss = ComputeLoss(model, overlap=overlap) # init loss class
|
||||||
# callbacks.run('on_train_start')
|
# callbacks.run('on_train_start')
|
||||||
|
@ -380,7 +388,7 @@ def train(hyp, opt, device, callbacks):
|
||||||
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
with torch.cuda.amp.autocast(amp):
|
with Autocast(enabled=amp):
|
||||||
pred = model(imgs) # forward
|
pred = model(imgs) # forward
|
||||||
loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
|
loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
|
||||||
if RANK != -1:
|
if RANK != -1:
|
||||||
|
|
12
train.py
12
train.py
|
@ -94,6 +94,14 @@ from utils.torch_utils import (
|
||||||
torch_distributed_zero_first,
|
torch_distributed_zero_first,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# version check
|
||||||
|
if torch.__version__.startswith("1.8"):
|
||||||
|
Autocast = torch.cuda.amp.autocast
|
||||||
|
GradScaler = torch.cuda.amp.GradScaler
|
||||||
|
else:
|
||||||
|
Autocast = torch.amp.autocast
|
||||||
|
GradScaler = torch.amp.GradScaler
|
||||||
|
|
||||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
RANK = int(os.getenv("RANK", -1))
|
RANK = int(os.getenv("RANK", -1))
|
||||||
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
|
||||||
|
@ -352,7 +360,7 @@ def train(hyp, opt, device, callbacks):
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
||||||
scheduler.last_epoch = start_epoch - 1 # do not move
|
scheduler.last_epoch = start_epoch - 1 # do not move
|
||||||
scaler = torch.cuda.amp.GradScaler(enabled=amp)
|
scaler = GradScaler(enabled=amp)
|
||||||
stopper, stop = EarlyStopping(patience=opt.patience), False
|
stopper, stop = EarlyStopping(patience=opt.patience), False
|
||||||
compute_loss = ComputeLoss(model) # init loss class
|
compute_loss = ComputeLoss(model) # init loss class
|
||||||
callbacks.run("on_train_start")
|
callbacks.run("on_train_start")
|
||||||
|
@ -409,7 +417,7 @@ def train(hyp, opt, device, callbacks):
|
||||||
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
with torch.cuda.amp.autocast(amp):
|
with Autocast(enabled=amp):
|
||||||
pred = model(imgs) # forward
|
pred = model(imgs) # forward
|
||||||
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
||||||
if RANK != -1:
|
if RANK != -1:
|
||||||
|
|
|
@ -12,7 +12,7 @@ from utils.torch_utils import profile
|
||||||
|
|
||||||
def check_train_batch_size(model, imgsz=640, amp=True):
|
def check_train_batch_size(model, imgsz=640, amp=True):
|
||||||
"""Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting."""
|
"""Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting."""
|
||||||
with torch.cuda.amp.autocast(amp):
|
with torch.amp.autocast("cuda", enabled=amp):
|
||||||
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue