Refactor training loop to incorporate effective epoch length and gradient accumulation steps
parent 9bdb158362
commit 83f2d7ff62
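With gradient accumulation, the optimizer steps only once every grad_accum_steps data-loader batches, so all cosine schedules, warmups, the checkpoint period, and the iteration budget are now sized in optimizer steps ("effective" iterations) rather than in micro-batches. A minimal sketch of the ceiling division used for that conversion (the standalone function name and the example numbers are illustrative, not part of the patch):

# Illustrative only: derive the number of optimizer steps per epoch from the
# number of data-loader batches and the accumulation factor.
def effective_epoch_length(official_epoch_length: int, grad_accum_steps: int) -> int:
    # Ceiling division: a trailing window shorter than grad_accum_steps
    # still produces one optimizer step.
    return (official_epoch_length + grad_accum_steps - 1) // grad_accum_steps

assert effective_epoch_length(1250, 4) == 313   # 312 full windows + one window of 2 batches
assert effective_epoch_length(1250, 1) == 1250  # no accumulation: lengths coincide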
@@ -63,29 +63,29 @@ def build_optimizer(cfg, params_groups):
 def build_schedulers(cfg):
-    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
+    EFFECTIVE_EPOCH_LENGTH = (cfg.train.OFFICIAL_EPOCH_LENGTH + cfg.train.grad_accum_steps - 1) // cfg.train.grad_accum_steps
     lr = dict(
         base_value=cfg.optim["lr"],
         final_value=cfg.optim["min_lr"],
-        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
-        warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=cfg.optim["epochs"] * EFFECTIVE_EPOCH_LENGTH,
+        warmup_iters=cfg.optim["warmup_epochs"] * EFFECTIVE_EPOCH_LENGTH,
         start_warmup_value=0,
     )
     wd = dict(
         base_value=cfg.optim["weight_decay"],
         final_value=cfg.optim["weight_decay_end"],
-        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=cfg.optim["epochs"] * EFFECTIVE_EPOCH_LENGTH,
     )
     momentum = dict(
         base_value=cfg.teacher["momentum_teacher"],
         final_value=cfg.teacher["final_momentum_teacher"],
-        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=cfg.optim["epochs"] * EFFECTIVE_EPOCH_LENGTH,
     )
     teacher_temp = dict(
         base_value=cfg.teacher["teacher_temp"],
         final_value=cfg.teacher["teacher_temp"],
-        total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
-        warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * EFFECTIVE_EPOCH_LENGTH,
+        warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * EFFECTIVE_EPOCH_LENGTH,
         start_warmup_value=cfg.teacher["warmup_teacher_temp"],
     )
@@ -96,7 +96,7 @@ def build_schedulers(cfg):
     last_layer_lr_schedule = CosineScheduler(**lr)
     last_layer_lr_schedule.schedule[
-        : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
+        : cfg.optim["freeze_last_layer_epochs"] * EFFECTIVE_EPOCH_LENGTH
     ] = 0  # mimicking the original schedules

     logger.info("Schedulers ready.")
@@ -154,11 +154,13 @@ def do_train(cfg, model, resume=False):
     OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
-    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
+    EFFECTIVE_EPOCH_LENGTH = (OFFICIAL_EPOCH_LENGTH + cfg.train.grad_accum_steps - 1) // cfg.train.grad_accum_steps
+    max_effective_iter = cfg.optim.epochs * EFFECTIVE_EPOCH_LENGTH

     periodic_checkpointer = PeriodicCheckpointer(
         checkpointer,
-        period=3 * OFFICIAL_EPOCH_LENGTH,
-        max_iter=max_iter,
+        period=3 * EFFECTIVE_EPOCH_LENGTH,
+        max_iter=max_effective_iter,
         max_to_keep=3,
     )
@@ -203,9 +205,9 @@ def do_train(cfg, model, resume=False):
         batch_size=cfg.train.batch_size_per_gpu,
         num_workers=cfg.train.num_workers,
         shuffle=True,
-        seed=start_iter,  # TODO: Fix this -- cfg.train.seed
+        seed=cfg.train.seed,  # TODO: Fix this -- cfg.train.seed
         sampler_type=sampler_type,
-        sampler_advance=0,  # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
+        sampler_advance=start_iter * cfg.train.batch_size_per_gpu,  # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
         drop_last=True,
         collate_fn=collate_fn,
     )
@@ -214,11 +216,18 @@ def do_train(cfg, model, resume=False):
     iteration = start_iter

+    accum_steps = cfg.train.grad_accum_steps
+    last_accum_steps = OFFICIAL_EPOCH_LENGTH % accum_steps
+    last_batch_idx = OFFICIAL_EPOCH_LENGTH - 1
+    last_batch_idx_to_accum = OFFICIAL_EPOCH_LENGTH - last_accum_steps
+
     logger.info("Starting training from iteration {}".format(start_iter))
     metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
     metric_logger = MetricLogger(delimiter="  ", output_file=metrics_file)
     header = "Training"

+    batch_idx = 0
+    optimizer.zero_grad(set_to_none=True)
     for data in metric_logger.log_every(
         data_loader,
         10,
@@ -227,48 +236,53 @@ def do_train(cfg, model, resume=False):
         start_iter,
     ):
         current_batch_size = data["collated_global_crops"].shape[0] / 2
-        if iteration > max_iter:
+        if iteration > max_effective_iter:
             return

-        # apply schedules
-
-        lr = lr_schedule[iteration]
-        wd = wd_schedule[iteration]
-        mom = momentum_schedule[iteration]
-        teacher_temp = teacher_temp_schedule[iteration]
-        last_layer_lr = last_layer_lr_schedule[iteration]
-        apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
+        # Determine if we need to update after this batch
+        last_batch = batch_idx == last_batch_idx
+        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
+        if batch_idx >= last_batch_idx_to_accum:
+            accum_steps = last_accum_steps
+
+        if need_update:
+            # apply schedules
+            lr = lr_schedule[iteration]
+            wd = wd_schedule[iteration]
+            mom = momentum_schedule[iteration]
+            teacher_temp = teacher_temp_schedule[iteration]
+            last_layer_lr = last_layer_lr_schedule[iteration]
+            apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)

         # compute losses
-        optimizer.zero_grad(set_to_none=True)
-        loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
+        loss_dict = model.forward_backward(data, teacher_temp=teacher_temp, scale=1.0 / accum_steps)

-        # clip gradients
-
-        if fp16_scaler is not None:
-            if cfg.optim.clip_grad:
-                fp16_scaler.unscale_(optimizer)
-                for v in model.student.values():
-                    v.clip_grad_norm_(cfg.optim.clip_grad)
-            fp16_scaler.step(optimizer)
-            fp16_scaler.update()
-        else:
-            if cfg.optim.clip_grad:
-                for v in model.student.values():
-                    v.clip_grad_norm_(cfg.optim.clip_grad)
-            optimizer.step()
-
-        # perform teacher EMA update
-
-        model.update_teacher(mom)
+        # clip gradients and update weights only when needed
+        if need_update:
+            if fp16_scaler is not None:
+                if cfg.optim.clip_grad:
+                    fp16_scaler.unscale_(optimizer)
+                    for v in model.student.values():
+                        v.clip_grad_norm_(cfg.optim.clip_grad)
+                fp16_scaler.step(optimizer)
+                fp16_scaler.update()
+            else:
+                if cfg.optim.clip_grad:
+                    for v in model.student.values():
+                        v.clip_grad_norm_(cfg.optim.clip_grad)
+                optimizer.step()
+
+            optimizer.zero_grad(set_to_none=True)
+
+            # perform teacher EMA update
+            model.update_teacher(mom)

         # logging

         if distributed.get_global_size() > 1:
             for v in loss_dict.values():
                 torch.distributed.all_reduce(v)
-        loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}
+        loss_dict_reduced = {k: v.item() * accum_steps / distributed.get_global_size() for k, v in loss_dict.items()}

         if math.isnan(sum(loss_dict_reduced.values())):
             logger.info("NaN detected")
@@ -283,13 +297,19 @@ def do_train(cfg, model, resume=False):
         metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)

         # checkpointing and testing

-        if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
+        if need_update and cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
             do_test(cfg, model, f"training_{iteration}")
             torch.cuda.synchronize()
-        periodic_checkpointer.step(iteration)
-
-        iteration = iteration + 1
+        if need_update:
+            periodic_checkpointer.step(iteration)
+            iteration = iteration + 1
+
+        batch_idx += 1
+        if batch_idx >= OFFICIAL_EPOCH_LENGTH:
+            batch_idx = 0
+            accum_steps = cfg.train.grad_accum_steps

     metric_logger.synchronize_between_processes()
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
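For reference, the accumulation bookkeeping that the refactored loop relies on reduces to the standalone pattern below. This is a sketch, not the project's code: train_one_run, compute_loss, data_loader, optimizer, and the lowercase config names are placeholders, and the loss is pre-scaled by 1 / accum_steps in the same way forward_backward(..., scale=1.0 / accum_steps) does in the patch.

# Standalone sketch of the per-batch bookkeeping (placeholder names; the real
# loop also applies schedules, clips gradients, and updates the teacher EMA).
def train_one_run(data_loader, optimizer, compute_loss, official_epoch_length, grad_accum_steps):
    accum_steps = grad_accum_steps
    last_accum_steps = official_epoch_length % accum_steps
    last_batch_idx = official_epoch_length - 1
    last_batch_idx_to_accum = official_epoch_length - last_accum_steps

    batch_idx = 0
    optimizer.zero_grad(set_to_none=True)
    for data in data_loader:
        # Step on every accum_steps-th batch and on the last (possibly shorter)
        # window of the epoch.
        last_batch = batch_idx == last_batch_idx
        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
        if batch_idx >= last_batch_idx_to_accum:
            accum_steps = last_accum_steps  # trailing window has fewer batches

        loss = compute_loss(data) / accum_steps  # average gradients over the window
        loss.backward()

        if need_update:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        batch_idx += 1
        if batch_idx >= official_epoch_length:  # reset bookkeeping at an epoch boundary
            batch_idx = 0
            accum_steps = grad_accum_steps

Because the losses are scaled down by 1 / accum_steps before backward, the logging line in the patch multiplies the reduced values back by accum_steps so reported losses stay comparable across accumulation settings.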