forward & backward in same no_sync context, slightly easier to read than splitting

wip-voidbag-accumulate-grad
Ross Wightman 2023-04-20 08:14:05 -07:00
parent 4cd7fb88b2
commit a83e9f2d3b
1 changed file with 10 additions and 14 deletions

@@ -905,19 +905,10 @@ def train_one_epoch(
                 loss /= accum_steps
             return loss

-        if has_no_sync and not need_update:
-            with model.no_sync():
-                loss = _forward()
-        else:
-            loss = _forward()
-
-        if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
-
-        def _backward():
+        def _backward(_loss):
             if loss_scaler is not None:
                 loss_scaler(
-                    loss,
+                    _loss,
                     optimizer,
                     clip_grad=args.clip_grad,
                     clip_mode=args.clip_mode,
@@ -926,7 +917,7 @@ def train_one_epoch(
                     need_update=need_update,
                 )
             else:
-                loss.backward(create_graph=second_order)
+                _loss.backward(create_graph=second_order)
                 if need_update:
                     if args.clip_grad is not None:
                         utils.dispatch_clip_grad(
@@ -938,11 +929,16 @@ def train_one_epoch(
         if has_no_sync and not need_update:
             with model.no_sync():
-                _backward()
+                loss = _forward()
+                _backward(loss)
         else:
-            _backward()
+            loss = _forward()
+            _backward(loss)
+
+        if not args.distributed:
+            losses_m.update(loss.item() * accum_steps, input.size(0))

         update_sample_count += input.size(0)

         if not need_update:
             data_start_time = time.time()
             continue
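
For context, a minimal sketch of the pattern this diff moves to: on a pure accumulation step, the forward and backward passes both run under the same DistributedDataParallel no_sync() context, so no gradient all-reduce fires until the step that actually updates the weights. This is an illustration only, not timm's train_one_epoch; the helper name and arguments below are hypothetical, and the AMP/grad-scaler and grad-clipping handling from train.py is omitted.

import contextlib


def accumulation_step(model, criterion, inputs, targets, accum_steps, need_update):
    # Hypothetical helper, not timm code: run one micro-batch, deferring DDP
    # gradient synchronization unless this step will call optimizer.step().
    has_no_sync = hasattr(model, "no_sync")
    sync_ctx = model.no_sync() if has_no_sync and not need_update else contextlib.nullcontext()
    with sync_ctx:
        # Forward and backward share the no_sync() scope, as in the change above.
        loss = criterion(model(inputs), targets) / accum_steps
        loss.backward()
    return loss.detach()

On the final micro-batch of each accumulation window, need_update is True, gradients all-reduce as usual during backward, and the caller can then clip gradients and step the optimizer.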