Refactor training loop to incorporate effective epoch length and gradient accumulation steps

pull/509/head
ChuaHanChong 2025-03-25 06:06:52 +00:00
parent 9bdb158362
commit 83f2d7ff62
1 changed file with 64 additions and 44 deletions


@@ -63,29 +63,29 @@ def build_optimizer(cfg, params_groups):
 def build_schedulers(cfg):
     OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
+    EFFECTIVE_EPOCH_LENGTH = (cfg.train.OFFICIAL_EPOCH_LENGTH + cfg.train.grad_accum_steps - 1) // cfg.train.grad_accum_steps
     lr = dict(
         base_value=cfg.optim["lr"],
         final_value=cfg.optim["min_lr"],
-        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
-        warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=cfg.optim["epochs"] * EFFECTIVE_EPOCH_LENGTH,
+        warmup_iters=cfg.optim["warmup_epochs"] * EFFECTIVE_EPOCH_LENGTH,
         start_warmup_value=0,
     )
     wd = dict(
         base_value=cfg.optim["weight_decay"],
         final_value=cfg.optim["weight_decay_end"],
-        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=cfg.optim["epochs"] * EFFECTIVE_EPOCH_LENGTH,
     )
     momentum = dict(
         base_value=cfg.teacher["momentum_teacher"],
         final_value=cfg.teacher["final_momentum_teacher"],
-        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=cfg.optim["epochs"] * EFFECTIVE_EPOCH_LENGTH,
     )
     teacher_temp = dict(
         base_value=cfg.teacher["teacher_temp"],
         final_value=cfg.teacher["teacher_temp"],
-        total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
-        warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * EFFECTIVE_EPOCH_LENGTH,
+        warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * EFFECTIVE_EPOCH_LENGTH,
         start_warmup_value=cfg.teacher["warmup_teacher_temp"],
     )
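
The schedulers above are now sized in optimizer updates rather than data-loader iterations; the ceil division rounds up so that a short final accumulation group still counts as one update. A minimal sketch of the arithmetic, not part of the diff, with hypothetical config values:

    # Minimal sketch (not part of the diff): effective epoch length via ceil division.
    import math

    OFFICIAL_EPOCH_LENGTH = 1250   # hypothetical data-loader iterations per epoch
    grad_accum_steps = 4           # hypothetical accumulation factor
    epochs = 100                   # hypothetical cfg.optim["epochs"]

    EFFECTIVE_EPOCH_LENGTH = (OFFICIAL_EPOCH_LENGTH + grad_accum_steps - 1) // grad_accum_steps
    assert EFFECTIVE_EPOCH_LENGTH == math.ceil(OFFICIAL_EPOCH_LENGTH / grad_accum_steps)  # 313

    # Scheduler lengths now count optimizer updates, not micro-batches.
    total_iters = epochs * EFFECTIVE_EPOCH_LENGTH   # 31300 instead of 125000
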
@@ -96,7 +96,7 @@ def build_schedulers(cfg):
     last_layer_lr_schedule = CosineScheduler(**lr)
     last_layer_lr_schedule.schedule[
-        : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
+        : cfg.optim["freeze_last_layer_epochs"] * EFFECTIVE_EPOCH_LENGTH
     ] = 0  # mimicking the original schedules

     logger.info("Schedulers ready.")
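
The slice assignment keeps the last layer's learning rate at zero for the first freeze_last_layer_epochs epochs; since the schedule is now indexed by update step, the frozen span is measured in EFFECTIVE_EPOCH_LENGTH as well. A stand-in sketch, not part of the diff, assuming the schedule behaves like a NumPy array (as the slice assignment suggests):

    # Stand-in sketch (not part of the diff): freeze the head of a schedule array.
    import numpy as np

    EFFECTIVE_EPOCH_LENGTH = 313     # hypothetical, from the ceil division above
    freeze_last_layer_epochs = 1     # hypothetical config value

    schedule = np.linspace(1e-3, 0.0, 10 * EFFECTIVE_EPOCH_LENGTH)  # stand-in for CosineScheduler.schedule
    schedule[: freeze_last_layer_epochs * EFFECTIVE_EPOCH_LENGTH] = 0
    assert schedule[EFFECTIVE_EPOCH_LENGTH - 1] == 0 and schedule[EFFECTIVE_EPOCH_LENGTH] > 0
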
@@ -154,11 +154,13 @@ def do_train(cfg, model, resume=False):
     OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
     max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
+    EFFECTIVE_EPOCH_LENGTH = (OFFICIAL_EPOCH_LENGTH + cfg.train.grad_accum_steps - 1) // cfg.train.grad_accum_steps
+    max_effective_iter = cfg.optim.epochs * EFFECTIVE_EPOCH_LENGTH

     periodic_checkpointer = PeriodicCheckpointer(
         checkpointer,
-        period=3 * OFFICIAL_EPOCH_LENGTH,
-        max_iter=max_iter,
+        period=3 * EFFECTIVE_EPOCH_LENGTH,
+        max_iter=max_effective_iter,
         max_to_keep=3,
     )
@@ -203,9 +205,9 @@ def do_train(cfg, model, resume=False):
         batch_size=cfg.train.batch_size_per_gpu,
         num_workers=cfg.train.num_workers,
         shuffle=True,
-        seed=start_iter,  # TODO: Fix this -- cfg.train.seed
+        seed=cfg.train.seed,  # TODO: Fix this -- cfg.train.seed
         sampler_type=sampler_type,
-        sampler_advance=0,  # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
+        sampler_advance=start_iter * cfg.train.batch_size_per_gpu,  # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
         drop_last=True,
         collate_fn=collate_fn,
     )
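
The loader is now rebuilt on resume with a fixed shuffle seed and an explicit sampler_advance, so the stream continues where it stopped instead of being reseeded at the restart point. A toy illustration of that idea, not part of the diff and using stand-in names rather than the dinov2 sampler API:

    # Toy illustration (not part of the diff): fixed seed + advance reproduces the stream.
    import itertools, random

    def sample_stream(seed, advance=0):
        # Fixed seed -> identical pseudo-random order on every run; 'advance'
        # skips the samples consumed before the restart.
        rng = random.Random(seed)
        stream = (rng.randrange(10_000) for _ in itertools.count())
        return itertools.islice(stream, advance, None)

    start_iter, batch_size_per_gpu = 7, 32          # hypothetical resume point
    advance = start_iter * batch_size_per_gpu
    resumed = list(itertools.islice(sample_stream(seed=0, advance=advance), 5))
    full = list(itertools.islice(sample_stream(seed=0), advance + 5))[-5:]
    assert resumed == full  # the resumed run continues exactly where the first stopped
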
@@ -214,11 +216,18 @@ def do_train(cfg, model, resume=False):
     iteration = start_iter
+    accum_steps = cfg.train.grad_accum_steps
+    last_accum_steps = OFFICIAL_EPOCH_LENGTH % accum_steps
+    last_batch_idx = OFFICIAL_EPOCH_LENGTH - 1
+    last_batch_idx_to_accum = OFFICIAL_EPOCH_LENGTH - last_accum_steps

     logger.info("Starting training from iteration {}".format(start_iter))
     metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
     metric_logger = MetricLogger(delimiter="  ", output_file=metrics_file)
     header = "Training"
+    batch_idx = 0
+    optimizer.zero_grad(set_to_none=True)

     for data in metric_logger.log_every(
         data_loader,
         10,
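
The new bookkeeping splits each official epoch into groups of accum_steps micro-batches, with a shorter final group whenever OFFICIAL_EPOCH_LENGTH is not divisible by accum_steps. A worked example, not part of the diff, with hypothetical sizes:

    # Worked example (not part of the diff): accumulation groups within one epoch.
    OFFICIAL_EPOCH_LENGTH = 10
    accum_steps = 4

    last_accum_steps = OFFICIAL_EPOCH_LENGTH % accum_steps              # 2 leftover micro-batches
    last_batch_idx = OFFICIAL_EPOCH_LENGTH - 1                          # 9
    last_batch_idx_to_accum = OFFICIAL_EPOCH_LENGTH - last_accum_steps  # 8

    groups, group = [], []
    for batch_idx in range(OFFICIAL_EPOCH_LENGTH):
        steps = last_accum_steps if batch_idx >= last_batch_idx_to_accum else accum_steps
        group.append(batch_idx)
        if batch_idx == last_batch_idx or (batch_idx + 1) % steps == 0:
            groups.append(group)   # one optimizer update per group
            group = []
    print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -> 3 updates = EFFECTIVE_EPOCH_LENGTH
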
@@ -227,48 +236,53 @@ def do_train(cfg, model, resume=False):
         start_iter,
     ):
         current_batch_size = data["collated_global_crops"].shape[0] / 2
-        if iteration > max_iter:
+        if iteration > max_effective_iter:
             return

-        # apply schedules
+        # Determine if we need to update after this batch
+        last_batch = batch_idx == last_batch_idx
+        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
+        if batch_idx >= last_batch_idx_to_accum:
+            accum_steps = last_accum_steps

-        lr = lr_schedule[iteration]
-        wd = wd_schedule[iteration]
-        mom = momentum_schedule[iteration]
-        teacher_temp = teacher_temp_schedule[iteration]
-        last_layer_lr = last_layer_lr_schedule[iteration]
-        apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
+        if need_update:
+            # apply schedules
+            lr = lr_schedule[iteration]
+            wd = wd_schedule[iteration]
+            mom = momentum_schedule[iteration]
+            teacher_temp = teacher_temp_schedule[iteration]
+            last_layer_lr = last_layer_lr_schedule[iteration]
+            apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)

         # compute losses
+        loss_dict = model.forward_backward(data, teacher_temp=teacher_temp, scale=1.0/accum_steps)

-        optimizer.zero_grad(set_to_none=True)
-        loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
+        # clip gradients and update weights only when needed
+        if need_update:
+            if fp16_scaler is not None:
+                if cfg.optim.clip_grad:
+                    fp16_scaler.unscale_(optimizer)
+                    for v in model.student.values():
+                        v.clip_grad_norm_(cfg.optim.clip_grad)
+                fp16_scaler.step(optimizer)
+                fp16_scaler.update()
+            else:
+                if cfg.optim.clip_grad:
+                    for v in model.student.values():
+                        v.clip_grad_norm_(cfg.optim.clip_grad)
+                optimizer.step()

-        # clip gradients
+            optimizer.zero_grad(set_to_none=True)

-        if fp16_scaler is not None:
-            if cfg.optim.clip_grad:
-                fp16_scaler.unscale_(optimizer)
-                for v in model.student.values():
-                    v.clip_grad_norm_(cfg.optim.clip_grad)
-            fp16_scaler.step(optimizer)
-            fp16_scaler.update()
-        else:
-            if cfg.optim.clip_grad:
-                for v in model.student.values():
-                    v.clip_grad_norm_(cfg.optim.clip_grad)
-            optimizer.step()
+            # perform teacher EMA update
+            model.update_teacher(mom)

-        # perform teacher EMA update
-
-        model.update_teacher(mom)
-
         # logging

         if distributed.get_global_size() > 1:
             for v in loss_dict.values():
                 torch.distributed.all_reduce(v)
-        loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}
+        loss_dict_reduced = {k: v.item() * accum_steps / distributed.get_global_size() for k, v in loss_dict.items()}

         if math.isnan(sum(loss_dict_reduced.values())):
             logger.info("NaN detected")
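
Inside the loop, every micro-batch runs forward/backward with the loss scaled by 1/accum_steps, while the schedule lookup, gradient clipping, optimizer step, gradient zeroing, and teacher EMA update only run when need_update is set. The bare pattern, reduced to vanilla PyTorch with a hypothetical model and loader rather than the PR's FSDP/fp16 path:

    # Minimal sketch (not part of the diff): the plain gradient-accumulation pattern.
    import torch

    model = torch.nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loader = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(12)]
    accum_steps = 4

    optimizer.zero_grad(set_to_none=True)
    for batch_idx, (x, y) in enumerate(loader):
        loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps  # scale so grads average over the group
        loss.backward()                                                 # accumulate into .grad
        need_update = (batch_idx + 1) % accum_steps == 0 or batch_idx == len(loader) - 1
        if need_update:                                                 # step and zero once per group
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
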
@@ -283,13 +297,19 @@ def do_train(cfg, model, resume=False):
         metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)

         # checkpointing and testing

-        if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
+        if need_update and cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
             do_test(cfg, model, f"training_{iteration}")
             torch.cuda.synchronize()
-        periodic_checkpointer.step(iteration)

-        iteration = iteration + 1
+        if need_update:
+            periodic_checkpointer.step(iteration)
+            iteration = iteration + 1
+
+        batch_idx += 1
+        if batch_idx >= OFFICIAL_EPOCH_LENGTH:
+            batch_idx = 0
+            accum_steps = cfg.train.grad_accum_steps
+
     metric_logger.synchronize_between_processes()
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
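
The tail of the loop advances the effective iteration counter only on update steps and resets batch_idx and accum_steps at each official epoch boundary, so every epoch performs exactly EFFECTIVE_EPOCH_LENGTH optimizer updates. A small check of that invariant, not part of the diff, with hypothetical sizes:

    # Invariant check (not part of the diff): updates per official epoch.
    OFFICIAL_EPOCH_LENGTH, grad_accum_steps = 10, 4
    accum_steps = grad_accum_steps
    last_accum_steps = OFFICIAL_EPOCH_LENGTH % accum_steps
    last_batch_idx = OFFICIAL_EPOCH_LENGTH - 1
    last_batch_idx_to_accum = OFFICIAL_EPOCH_LENGTH - last_accum_steps

    batch_idx, updates_this_epoch = 0, 0
    for _ in range(2 * OFFICIAL_EPOCH_LENGTH):           # two epochs of micro-batches
        if batch_idx >= last_batch_idx_to_accum:
            accum_steps = last_accum_steps
        if batch_idx == last_batch_idx or (batch_idx + 1) % accum_steps == 0:
            updates_this_epoch += 1                      # stands in for the optimizer step
        batch_idx += 1
        if batch_idx >= OFFICIAL_EPOCH_LENGTH:
            assert updates_this_epoch == (OFFICIAL_EPOCH_LENGTH + grad_accum_steps - 1) // grad_accum_steps
            batch_idx, updates_this_epoch = 0, 0         # new official epoch
            accum_steps = grad_accum_steps               # undo the shorter final group size
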