In dist training, update loss running avg every step, only sync on log updates / final.
parent
ae0737f5d0
commit
36b5d1adaa
18
train.py
18
train.py
|
@ -1042,8 +1042,7 @@ def train_one_epoch(
|
|||
loss = _forward()
|
||||
_backward(loss)
|
||||
|
||||
if not args.distributed:
|
||||
losses_m.update(loss.item() * accum_steps, input.size(0))
|
||||
losses_m.update(loss.item() * accum_steps, input.size(0))
|
||||
update_sample_count += input.size(0)
|
||||
|
||||
if not need_update:
|
||||
|
@ -1068,16 +1067,18 @@ def train_one_epoch(
|
|||
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
|
||||
lr = sum(lrl) / len(lrl)
|
||||
|
||||
loss_avg, loss_now = losses_m.avg, losses_m.val
|
||||
if args.distributed:
|
||||
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
|
||||
losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
|
||||
# synchronize current step and avg loss, each process keeps its own running avg
|
||||
loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item()
|
||||
loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item()
|
||||
update_sample_count *= args.world_size
|
||||
|
||||
if utils.is_primary(args):
|
||||
_logger.info(
|
||||
f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
|
||||
f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
|
||||
f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
|
||||
f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
|
||||
f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
|
||||
f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
|
||||
f'LR: {lr:.3e} '
|
||||
|
@ -1106,7 +1107,12 @@ def train_one_epoch(
|
|||
if hasattr(optimizer, 'sync_lookahead'):
|
||||
optimizer.sync_lookahead()
|
||||
|
||||
return OrderedDict([('loss', losses_m.avg)])
|
||||
loss_avg = losses_m.avg
|
||||
if args.distributed:
|
||||
# synchronize avg loss, each process keeps its own running avg
|
||||
loss_avg = torch.tensor([loss_avg], device=device, dtype=torch.float32)
|
||||
loss_avg = utils.reduce_tensor(loss_avg, args.world_size).item()
|
||||
return OrderedDict([('loss', loss_avg)])
|
||||
|
||||
|
||||
def validate(
|
||||
|
|
Loading…
Reference in New Issue