refactored the code

pull/13483/head
Bala-Vignesh-Reddy 2025-02-23 20:07:59 +05:30
parent 5cdad8922c
commit a63bfd38c1
6 changed files with 44 additions and 12 deletions

View File

@ -27,9 +27,16 @@ import torch.distributed as dist
import torch.hub as hub import torch.hub as hub
import torch.optim.lr_scheduler as lr_scheduler import torch.optim.lr_scheduler as lr_scheduler
import torchvision import torchvision
from torch.cuda import amp
from tqdm import tqdm from tqdm import tqdm
# AMP compatibility shim. Feature-detect the API instead of matching the "1.8" version
# prefix: torch.amp.GradScaler is not available in every release newer than 1.8, and
# torch.amp.autocast takes the device type as a required first argument, so it cannot
# be called as Autocast(enabled=...) directly.
if hasattr(getattr(torch, "amp", None), "GradScaler"):

    def Autocast(enabled=True, **kwargs):
        """Return a CUDA autocast context manager built on the unified torch.amp API."""
        return torch.amp.autocast("cuda", enabled=enabled, **kwargs)

    def GradScaler(**kwargs):
        """Return a CUDA gradient scaler built on the unified torch.amp API."""
        return torch.amp.GradScaler("cuda", **kwargs)

else:  # older torch: only the CUDA-specific namespace provides both names
    Autocast = torch.cuda.amp.autocast
    GradScaler = torch.cuda.amp.GradScaler
FILE = Path(__file__).resolve() FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLOv5 root directory ROOT = FILE.parents[1] # YOLOv5 root directory
if str(ROOT) not in sys.path: if str(ROOT) not in sys.path:
@ -198,7 +205,7 @@ def train(opt, device):
t0 = time.time() t0 = time.time()
criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing) # loss function criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing) # loss function
best_fitness = 0.0 best_fitness = 0.0
scaler = amp.GradScaler(enabled=cuda) scaler = GradScaler(enabled=cuda)
val = test_dir.stem # 'val' or 'test' val = test_dir.stem # 'val' or 'test'
LOGGER.info( LOGGER.info(
f"Image sizes {imgsz} train, {imgsz} test\n" f"Image sizes {imgsz} train, {imgsz} test\n"
@ -219,7 +226,7 @@ def train(opt, device):
images, labels = images.to(device, non_blocking=True), labels.to(device) images, labels = images.to(device, non_blocking=True), labels.to(device)
# Forward # Forward
with amp.autocast(enabled=cuda): # stability issues when enabled with Autocast(enabled=device.type != "cpu"): # stability issues when enabled
loss = criterion(model(images), labels) loss = criterion(model(images), labels)
# Backward # Backward

View File

@ -48,6 +48,11 @@ from utils.general import (
) )
from utils.torch_utils import select_device, smart_inference_mode from utils.torch_utils import select_device, smart_inference_mode
# AMP compatibility shim. Feature-detect the API instead of matching the "1.8" version
# prefix: torch.amp.autocast takes the device type as a required first argument, so it
# cannot be called as Autocast(enabled=...) directly.
if hasattr(getattr(torch, "amp", None), "autocast"):

    def Autocast(enabled=True, **kwargs):
        """Return a CUDA autocast context manager built on the unified torch.amp API."""
        return torch.amp.autocast("cuda", enabled=enabled, **kwargs)

else:  # older torch: fall back to the CUDA-specific autocast
    Autocast = torch.cuda.amp.autocast
@smart_inference_mode() @smart_inference_mode()
def run( def run(
@ -108,7 +113,7 @@ def run(
action = "validating" if dataloader.dataset.root.stem == "val" else "testing" action = "validating" if dataloader.dataset.root.stem == "val" else "testing"
desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}" desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0) bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0)
with torch.cuda.amp.autocast(enabled=device.type != "cpu"): with Autocast(enabled=device.type != "cpu"):
for images, labels in bar: for images, labels in bar:
with dt[0]: with dt[0]:
images, labels = images.to(device, non_blocking=True), labels.to(device) images, labels = images.to(device, non_blocking=True), labels.to(device)

View File

@ -20,7 +20,6 @@ import requests
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from torch.cuda import amp
# Import 'ultralytics' package or install if missing # Import 'ultralytics' package or install if missing
try: try:
@ -56,6 +55,11 @@ from utils.general import (
) )
from utils.torch_utils import copy_attr, smart_inference_mode from utils.torch_utils import copy_attr, smart_inference_mode
# AMP compatibility shim. Feature-detect the API instead of matching the "1.8" version
# prefix: torch.amp.autocast takes the device type as a required first argument, so it
# cannot be called as Autocast(enabled=...) directly.
if hasattr(getattr(torch, "amp", None), "autocast"):

    def Autocast(enabled=True, **kwargs):
        """Return a CUDA autocast context manager built on the unified torch.amp API."""
        return torch.amp.autocast("cuda", enabled=enabled, **kwargs)

else:  # older torch: fall back to the CUDA-specific autocast
    Autocast = torch.cuda.amp.autocast
def autopad(k, p=None, d=1): def autopad(k, p=None, d=1):
""" """
@ -864,7 +868,7 @@ class AutoShape(nn.Module):
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
autocast = self.amp and (p.device.type != "cpu") # Automatic Mixed Precision (AMP) inference autocast = self.amp and (p.device.type != "cpu") # Automatic Mixed Precision (AMP) inference
if isinstance(ims, torch.Tensor): # torch if isinstance(ims, torch.Tensor): # torch
with amp.autocast(autocast): with Autocast(enabled=autocast):
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
# Pre-process # Pre-process
@ -891,7 +895,7 @@ class AutoShape(nn.Module):
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
with amp.autocast(autocast): with Autocast(enabled=autocast):
# Inference # Inference
with dt[1]: with dt[1]:
y = self.model(x, augment=augment) # forward y = self.model(x, augment=augment) # forward

View File

@ -89,6 +89,14 @@ from utils.torch_utils import (
torch_distributed_zero_first, torch_distributed_zero_first,
) )
# AMP compatibility shim. Feature-detect the API instead of matching the "1.8" version
# prefix: torch.amp.GradScaler is not available in every release newer than 1.8, and
# torch.amp.autocast takes the device type as a required first argument, so it cannot
# be called as Autocast(enabled=...) directly.
if hasattr(getattr(torch, "amp", None), "GradScaler"):

    def Autocast(enabled=True, **kwargs):
        """Return a CUDA autocast context manager built on the unified torch.amp API."""
        return torch.amp.autocast("cuda", enabled=enabled, **kwargs)

    def GradScaler(**kwargs):
        """Return a CUDA gradient scaler built on the unified torch.amp API."""
        return torch.amp.GradScaler("cuda", **kwargs)

else:  # older torch: only the CUDA-specific namespace provides both names
    Autocast = torch.cuda.amp.autocast
    GradScaler = torch.cuda.amp.GradScaler
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1)) RANK = int(os.getenv("RANK", -1))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1)) WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
@ -320,7 +328,7 @@ def train(hyp, opt, device, callbacks):
maps = np.zeros(nc) # mAP per class maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move scheduler.last_epoch = start_epoch - 1 # do not move
scaler = torch.cuda.amp.GradScaler(enabled=amp) scaler = GradScaler(enabled=amp)
stopper, stop = EarlyStopping(patience=opt.patience), False stopper, stop = EarlyStopping(patience=opt.patience), False
compute_loss = ComputeLoss(model, overlap=overlap) # init loss class compute_loss = ComputeLoss(model, overlap=overlap) # init loss class
# callbacks.run('on_train_start') # callbacks.run('on_train_start')
@ -380,7 +388,7 @@ def train(hyp, opt, device, callbacks):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
# Forward # Forward
with torch.cuda.amp.autocast(amp): with Autocast(enabled=amp):
pred = model(imgs) # forward pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float()) loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
if RANK != -1: if RANK != -1:

View File

@ -94,6 +94,14 @@ from utils.torch_utils import (
torch_distributed_zero_first, torch_distributed_zero_first,
) )
# AMP compatibility shim. Feature-detect the API instead of matching the "1.8" version
# prefix: torch.amp.GradScaler is not available in every release newer than 1.8, and
# torch.amp.autocast takes the device type as a required first argument, so it cannot
# be called as Autocast(enabled=...) directly.
if hasattr(getattr(torch, "amp", None), "GradScaler"):

    def Autocast(enabled=True, **kwargs):
        """Return a CUDA autocast context manager built on the unified torch.amp API."""
        return torch.amp.autocast("cuda", enabled=enabled, **kwargs)

    def GradScaler(**kwargs):
        """Return a CUDA gradient scaler built on the unified torch.amp API."""
        return torch.amp.GradScaler("cuda", **kwargs)

else:  # older torch: only the CUDA-specific namespace provides both names
    Autocast = torch.cuda.amp.autocast
    GradScaler = torch.cuda.amp.GradScaler
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1)) RANK = int(os.getenv("RANK", -1))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1)) WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
@ -352,7 +360,7 @@ def train(hyp, opt, device, callbacks):
maps = np.zeros(nc) # mAP per class maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move scheduler.last_epoch = start_epoch - 1 # do not move
scaler = torch.cuda.amp.GradScaler(enabled=amp) scaler = GradScaler(enabled=amp)
stopper, stop = EarlyStopping(patience=opt.patience), False stopper, stop = EarlyStopping(patience=opt.patience), False
compute_loss = ComputeLoss(model) # init loss class compute_loss = ComputeLoss(model) # init loss class
callbacks.run("on_train_start") callbacks.run("on_train_start")
@ -409,7 +417,7 @@ def train(hyp, opt, device, callbacks):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
# Forward # Forward
with torch.cuda.amp.autocast(amp): with Autocast(enabled=amp):
pred = model(imgs) # forward pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1: if RANK != -1:

View File

@ -12,7 +12,7 @@ from utils.torch_utils import profile
def check_train_batch_size(model, imgsz=640, amp=True): def check_train_batch_size(model, imgsz=640, amp=True):
"""Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting.""" """Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting."""
with torch.cuda.amp.autocast(amp): with torch.amp.autocast("cuda", enabled=amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size