In dist training, update loss running avg every step, only sync on log updates / final.

friday_weights
Ross Wightman 2024-11-20 11:45:53 -08:00 committed by Ross Wightman
parent ae0737f5d0
commit 36b5d1adaa
1 changed file with 12 additions and 6 deletions


@@ -1042,8 +1042,7 @@ def train_one_epoch(
             loss = _forward()
             _backward(loss)

-        if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
+        losses_m.update(loss.item() * accum_steps, input.size(0))
         update_sample_count += input.size(0)

         if not need_update:
@@ -1068,16 +1067,18 @@ def train_one_epoch(
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)

+            loss_avg, loss_now = losses_m.avg, losses_m.val
             if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
+                # synchronize current step and avg loss, each process keeps its own running avg
+                loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item()
+                loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item()
                 update_sample_count *= args.world_size

             if utils.is_primary(args):
                 _logger.info(
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
-                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
+                    f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
                     f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
                     f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
@@ -1106,7 +1107,12 @@ def train_one_epoch(
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()

-    return OrderedDict([('loss', losses_m.avg)])
+    loss_avg = losses_m.avg
+    if args.distributed:
+        # synchronize avg loss, each process keeps its own running avg
+        loss_avg = torch.tensor([loss_avg], device=device, dtype=torch.float32)
+        loss_avg = utils.reduce_tensor(loss_avg, args.world_size).item()
+    return OrderedDict([('loss', loss_avg)])


 def validate(
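
In isolation, the pattern this commit adopts looks like the sketch below: each process feeds every step's loss into its own local running average, and cross-process all-reduces happen only when a log line is emitted, plus once at the end of the epoch. This is a minimal sketch, not timm's code: AverageMeter and the training loop are simplified stand-ins for the ones in train.py, and reduce_tensor is assumed to behave like timm's utils.reduce_tensor (all-reduce sum, divided by world size).

import torch
import torch.distributed as dist


class AverageMeter:
    """Tracks the latest value and a running average, local to each process."""
    def __init__(self):
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def reduce_tensor(tensor, n):
    # sum across all processes, then divide by world size (an average)
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / n


def train_one_epoch_sketch(model, loader, loss_fn, optimizer, device,
                           distributed=False, world_size=1, log_interval=50):
    losses_m = AverageMeter()
    for step, (input, target) in enumerate(loader):
        loss = loss_fn(model(input.to(device)), target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # every step: update the local running average, no cross-process sync
        losses_m.update(loss.item(), input.size(0))

        if step % log_interval == 0:
            loss_avg, loss_now = losses_m.avg, losses_m.val
            if distributed:
                # sync only for display; each process keeps its own meter
                loss_avg = reduce_tensor(loss.new([loss_avg]), world_size).item()
                loss_now = reduce_tensor(loss.new([loss_now]), world_size).item()
            # (primary-rank guard omitted for brevity)
            print(f'step {step} loss: {loss_now:#.3g} ({loss_avg:#.3g})')

    # final sync so every process returns the same epoch loss
    loss_avg = losses_m.avg
    if distributed:
        loss_avg = reduce_tensor(
            torch.tensor([loss_avg], device=device, dtype=torch.float32),
            world_size).item()
    return loss_avg

Before this commit, the meter itself was only updated inside the log branch when distributed (with an already-reduced loss), so the running average sampled just the logged steps; after it, every step contributes to each process's local average, and the only collective calls are the two scalar reduces per log interval plus one at the end of the epoch.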