Merge c9468ec281 into fe1d4d9947
commit b22f293141

train.py (18 changed lines)
@@ -36,6 +36,12 @@ import torch.distributed as dist
 import torch.nn as nn
 import yaml
 from torch.optim import lr_scheduler
+
+try:
+    import torch.amp as amp
+except ImportError:
+    import torch.cuda.amp as amp
+
 from tqdm import tqdm

 FILE = Path(__file__).resolve()
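A minimal, standalone sketch of the import fallback this hunk adds: recent PyTorch exposes the device-agnostic torch.amp namespace, while older releases only ship torch.cuda.amp, so the try/except keeps the script importable on both. The version notes in the comments and the final print are illustrative assumptions, not part of the patch.

try:
    import torch.amp as amp  # device-agnostic AMP namespace on recent PyTorch
except ImportError:
    import torch.cuda.amp as amp  # CUDA-only AMP module on older releases

# Both paths expose autocast; GradScaler moved under torch.amp only in later
# releases, so the probe below may report False for it on mid-range versions.
print(amp.__name__, hasattr(amp, "autocast"), hasattr(amp, "GradScaler"))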
@@ -221,7 +227,7 @@ def train(hyp, opt, device, callbacks):
         LOGGER.info(f"Transferred {len(csd)}/{len(model.state_dict())} items from {weights}")  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create
-    amp = check_amp(model)  # check AMP
+    use_amp = check_amp(model)  # check AMP

     # Freeze
     freeze = [f"model.{x}." for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
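A short illustrative note on the rename from amp to use_amp: with import torch.amp as amp now sitting at module scope (first hunk), a local boolean named amp inside train() would shadow the module and break later attribute lookups on it. The two toy functions below are hypothetical and only demonstrate the failure mode; they are not part of the patch.

import torch.amp as amp  # module-level name, as introduced in the first hunk

def train_shadowed():
    amp = True           # old naming: the boolean reuses the module's name
    return amp.autocast  # AttributeError: 'bool' object has no attribute 'autocast'

def train_renamed():
    use_amp = True       # new naming: distinct flag, module name left untouched
    return amp.autocast(device_type="cuda", enabled=use_amp)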
@@ -238,7 +244,7 @@ def train(hyp, opt, device, callbacks):

     # Batch size
     if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
-        batch_size = check_train_batch_size(model, imgsz, amp)
+        batch_size = check_train_batch_size(model, imgsz, use_amp)
         loggers.on_params_update({"batch_size": batch_size})

     # Optimizer
@@ -352,7 +358,8 @@ def train(hyp, opt, device, callbacks):
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
-    scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    # scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    scaler = amp.GradScaler(enabled=use_amp)
     stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run("on_train_start")
@@ -409,7 +416,8 @@ def train(hyp, opt, device, callbacks):
                     imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

             # Forward
-            with torch.cuda.amp.autocast(amp):
+            # with torch.cuda.amp.autocast(amp):
+            with amp.autocast(enabled=use_amp, device_type=device.type):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                 if RANK != -1:
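Taken together, the GradScaler and autocast hunks move train.py onto the torch.amp entry points. Below is a self-contained sketch of that training-step pattern on a toy model; it assumes a PyTorch recent enough that GradScaler also lives under torch.amp (roughly 2.3+), and the model, data, and hyperparameters are placeholders, not taken from the patch.

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # enable mixed precision only where supported

model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.amp.GradScaler(enabled=use_amp)  # behaves as a no-op when disabled

x = torch.randn(8, 16, device=device)
y = torch.randint(0, 4, (8,), device=device)

with torch.amp.autocast(device_type=device.type, enabled=use_amp):
    pred = model(x)  # forward pass runs in reduced precision when enabled
    loss = nn.functional.cross_entropy(pred, y)

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
scaler.update()                # adjusts the scale factor for the next iteration
optimizer.zero_grad()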
@@ -458,7 +466,7 @@ def train(hyp, opt, device, callbacks):
                     data_dict,
                     batch_size=batch_size // WORLD_SIZE * 2,
                     imgsz=imgsz,
-                    half=amp,
+                    half=use_amp,
                     model=ema.ema,
                     single_cls=single_cls,
                     dataloader=val_loader,