forward & backward in same no_sync context, slightly easier to read than splitting
parent 4cd7fb88b2
commit a83e9f2d3b
1 changed file: train.py (24 changed lines)
@@ -905,19 +905,10 @@ def train_one_epoch(
             loss /= accum_steps
             return loss

-        if has_no_sync and not need_update:
-            with model.no_sync():
-                loss = _forward()
-        else:
-            loss = _forward()
-
-        if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
-
-        def _backward():
+        def _backward(_loss):
             if loss_scaler is not None:
                 loss_scaler(
-                    loss,
+                    _loss,
                     optimizer,
                     clip_grad=args.clip_grad,
                     clip_mode=args.clip_mode,
@@ -926,7 +917,7 @@ def train_one_epoch(
                     need_update=need_update,
                 )
             else:
-                loss.backward(create_graph=second_order)
+                _loss.backward(create_graph=second_order)
                 if need_update:
                     if args.clip_grad is not None:
                         utils.dispatch_clip_grad(
@@ -938,11 +929,16 @@ def train_one_epoch(

         if has_no_sync and not need_update:
             with model.no_sync():
-                _backward()
+                loss = _forward()
+                _backward(loss)
         else:
-            _backward()
+            loss = _forward()
+            _backward(loss)
+
+        if not args.distributed:
+            losses_m.update(loss.item() * accum_steps, input.size(0))
         update_sample_count += input.size(0)

         if not need_update:
             data_start_time = time.time()
             continue
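For context, a minimal standalone sketch of the gradient-accumulation pattern this commit settles on, assuming a DistributedDataParallel-wrapped model. The helper train_steps and the names loader/accum_steps/device are illustrative, not part of train.py: every micro-batch except the last of an accumulation window runs both its forward and its backward inside a single model.no_sync() context, so DDP defers the gradient all-reduce to the final micro-batch's backward.

    # Illustrative sketch only (assumed names: train_steps, loader,
    # accum_steps, device); not the actual timm train.py code path.
    import torch
    import torch.nn.functional as F

    def train_steps(model, optimizer, loader, accum_steps, device):
        model.train()
        # DDP-wrapped models expose no_sync(); plain modules do not.
        has_no_sync = hasattr(model, 'no_sync')
        optimizer.zero_grad()
        for step, (input, target) in enumerate(loader):
            input, target = input.to(device), target.to(device)
            need_update = (step + 1) % accum_steps == 0

            def _forward():
                loss = F.cross_entropy(model(input), target)
                # Average the loss over the accumulation window.
                return loss / accum_steps

            def _backward(_loss):
                _loss.backward()

            if has_no_sync and not need_update:
                # Forward and backward in the same no_sync() context:
                # gradients accumulate locally, no all-reduce this step.
                with model.no_sync():
                    loss = _forward()
                    _backward(loss)
            else:
                # Last micro-batch of the window: backward triggers
                # DDP's gradient all-reduce as usual.
                loss = _forward()
                _backward(loss)

            if need_update:
                optimizer.step()
                optimizer.zero_grad()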