forward & backward in same no_sync context, slightly easier to read than splitting

wip-voidbag-accumulate-grad
Ross Wightman 2023-04-20 08:14:05 -07:00
parent 4cd7fb88b2
commit a83e9f2d3b
1 changed file with 10 additions and 14 deletions


@@ -905,19 +905,10 @@ def train_one_epoch(
                 loss /= accum_steps
             return loss
 
-        if has_no_sync and not need_update:
-            with model.no_sync():
-                loss = _forward()
-        else:
-            loss = _forward()
-
-        if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
-
-        def _backward():
+        def _backward(_loss):
             if loss_scaler is not None:
                 loss_scaler(
-                    loss,
+                    _loss,
                     optimizer,
                     clip_grad=args.clip_grad,
                     clip_mode=args.clip_mode,
@@ -926,7 +917,7 @@ def train_one_epoch(
                     need_update=need_update,
                 )
             else:
-                loss.backward(create_graph=second_order)
+                _loss.backward(create_graph=second_order)
                 if need_update:
                     if args.clip_grad is not None:
                         utils.dispatch_clip_grad(
@@ -938,11 +929,16 @@ def train_one_epoch(
         if has_no_sync and not need_update:
             with model.no_sync():
-                _backward()
+                loss = _forward()
+                _backward(loss)
         else:
-            _backward()
+            loss = _forward()
+            _backward(loss)
 
+        if not args.distributed:
+            losses_m.update(loss.item() * accum_steps, input.size(0))
+
         update_sample_count += input.size(0)
 
         if not need_update:
             data_start_time = time.time()
             continue
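
For reference, the pattern the commit settles on can be sketched outside of timm as follows: when the optimizer step is deferred during gradient accumulation, both the forward and the backward pass run inside a single model.no_sync() context, so DDP skips the gradient all-reduce for intermediate micro-batches. This is a minimal illustration only, assuming a model already wrapped in torch.nn.parallel.DistributedDataParallel; train_accum, loader, loss_fn, and accum_steps are hypothetical names, and the real train_one_epoch additionally handles AMP loss scaling, gradient clipping, and a partial final accumulation group.

```python
import contextlib


def train_accum(model, loader, optimizer, loss_fn, accum_steps=4):
    """Gradient accumulation with DDP gradient-sync suppression (sketch)."""
    has_no_sync = hasattr(model, "no_sync")  # True for DistributedDataParallel
    last_idx = len(loader) - 1

    optimizer.zero_grad()
    for batch_idx, (input, target) in enumerate(loader):
        need_update = batch_idx == last_idx or (batch_idx + 1) % accum_steps == 0

        # Suppress DDP's gradient all-reduce on steps that won't call optimizer.step();
        # forward and backward share the same context, as in this commit.
        ctx = model.no_sync() if has_no_sync and not need_update else contextlib.nullcontext()
        with ctx:
            output = model(input)
            loss = loss_fn(output, target) / accum_steps
            loss.backward()

        if need_update:
            optimizer.step()
            optimizer.zero_grad()
```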