pull/13487/merge
Parag Londhe 2025-04-18 01:02:47 +02:00 committed by GitHub
commit b22f293141
1 changed file with 13 additions and 5 deletions


@@ -36,6 +36,12 @@ import torch.distributed as dist
 import torch.nn as nn
 import yaml
 from torch.optim import lr_scheduler
+try:
+    import torch.amp as amp
+except ImportError:
+    import torch.cuda.amp as amp
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
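
For context, a minimal self-contained sketch of how the shimmed `amp` module is meant to be used end to end (assuming a PyTorch build where torch.amp provides both GradScaler and autocast; the tiny model, optimizer, and random data below are placeholders for illustration, not part of this PR):

    import torch
    import torch.nn as nn

    try:
        import torch.amp as amp  # unified AMP namespace on newer PyTorch
    except ImportError:
        import torch.cuda.amp as amp  # fallback for older releases

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(10, 2).to(device)  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    use_amp = device.type == "cuda"  # mirrors the PR's use_amp flag

    scaler = amp.GradScaler(enabled=use_amp)
    x = torch.randn(4, 10, device=device)
    y = torch.randint(0, 2, (4,), device=device)
    with amp.autocast(enabled=use_amp, device_type=device.type):
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()  # no-op scaling when AMP is disabled
    scaler.step(optimizer)
    scaler.update()

Note the autocast call form here matches the unified torch.amp API; see the remark after the forward-pass hunk below.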
@@ -221,7 +227,7 @@ def train(hyp, opt, device, callbacks):
         LOGGER.info(f"Transferred {len(csd)}/{len(model.state_dict())} items from {weights}")  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create
-    amp = check_amp(model)  # check AMP
+    use_amp = check_amp(model)  # check AMP
 
     # Freeze
     freeze = [f"model.{x}." for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
@@ -238,7 +244,7 @@ def train(hyp, opt, device, callbacks):
 
     # Batch size
     if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
-        batch_size = check_train_batch_size(model, imgsz, amp)
+        batch_size = check_train_batch_size(model, imgsz, use_amp)
         loggers.on_params_update({"batch_size": batch_size})
 
     # Optimizer
@@ -352,7 +358,8 @@ def train(hyp, opt, device, callbacks):
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
-    scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    # scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    scaler = amp.GradScaler(enabled=use_amp)
     stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run("on_train_start")
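
One caveat worth noting (an observation, not part of the PR): torch.amp has shipped autocast for longer than it has shipped GradScaler, so on some intermediate releases the `import torch.amp` above succeeds while `amp.GradScaler` is still missing, and the ImportError guard never fires. A hypothetical, more defensive variant feature-tests the attribute instead:

    import torch.amp
    import torch.cuda.amp

    # Hypothetical alternative: pick the class by feature test rather than ImportError.
    if hasattr(torch.amp, "GradScaler"):
        GradScaler = torch.amp.GradScaler
    else:
        GradScaler = torch.cuda.amp.GradScaler
    scaler = GradScaler(enabled=use_amp)  # use_amp as computed by check_amp above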
@@ -409,7 +416,8 @@ def train(hyp, opt, device, callbacks):
                     imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
 
             # Forward
-            with torch.cuda.amp.autocast(amp):
+            # with torch.cuda.amp.autocast(amp):
+            with amp.autocast(enabled=use_amp, device_type=device.type):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                 if RANK != -1:
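
Passing device_type explicitly is what lets the same line work beyond CUDA; the older torch.cuda.amp.autocast does not take a device_type argument, so this call form effectively assumes the try branch of the import shim succeeded. For instance, a CPU-only sketch (illustrative, assuming the unified torch.amp API) would be:

    with amp.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
        pred = model(imgs)  # forward pass runs under bfloat16 autocast on CPU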
@@ -458,7 +466,7 @@ def train(hyp, opt, device, callbacks):
                     data_dict,
                     batch_size=batch_size // WORLD_SIZE * 2,
                     imgsz=imgsz,
-                    half=amp,
+                    half=use_amp,
                     model=ema.ema,
                     single_cls=single_cls,
                     dataloader=val_loader,
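
For reference, downstream validation typically consumes the half flag by casting the EMA weights, roughly as follows (a sketch of the usual YOLOv5 pattern, not code from this PR; half precision is skipped on CPU):

    half = use_amp and device.type != "cpu"  # FP16 validation only off-CPU
    model_val = ema.ema.half() if half else ema.ema.float()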