diff --git a/train.py b/train.py
index cc73d0ae..ec344a64 100755
--- a/train.py
+++ b/train.py
@@ -905,19 +905,10 @@ def train_one_epoch(
                 loss /= accum_steps
             return loss
 
-        if has_no_sync and not need_update:
-            with model.no_sync():
-                loss = _forward()
-        else:
-            loss = _forward()
-
-        if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
-
-        def _backward():
+        def _backward(_loss):
             if loss_scaler is not None:
                 loss_scaler(
-                    loss,
+                    _loss,
                     optimizer,
                     clip_grad=args.clip_grad,
                     clip_mode=args.clip_mode,
@@ -926,7 +917,7 @@ def train_one_epoch(
                     need_update=need_update,
                 )
             else:
-                loss.backward(create_graph=second_order)
+                _loss.backward(create_graph=second_order)
                 if need_update:
                     if args.clip_grad is not None:
                         utils.dispatch_clip_grad(
@@ -938,11 +929,16 @@ def train_one_epoch(
 
         if has_no_sync and not need_update:
             with model.no_sync():
-                _backward()
+                loss = _forward()
+                _backward(loss)
         else:
-            _backward()
+            loss = _forward()
+            _backward(loss)
 
+        if not args.distributed:
+            losses_m.update(loss.item() * accum_steps, input.size(0))
         update_sample_count += input.size(0)
+
         if not need_update:
             data_start_time = time.time()
             continue
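
The key change above is moving the `_forward()` call inside the `model.no_sync()` context for gradient-accumulation steps. Per the warning in PyTorch's `no_sync` documentation, DDP decides during the forward pass whether the following backward will all-reduce gradients, so a forward run outside `no_sync()` still triggers synchronization on the backward. Below is a minimal, self-contained sketch (not the timm implementation) of the resulting pattern; the model, data, and loop bounds are placeholders, and `accum_steps`/`need_update` only mirror the names in the diff:

```python
# Minimal DDP gradient-accumulation sketch. Launch with, e.g.:
#   torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="gloo")  # use "nccl" for GPU training
    torch.manual_seed(0)
    model = DDP(torch.nn.Linear(8, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accum_steps = 4

    for step in range(16):
        # Only every accum_steps-th micro-batch performs an optimizer update.
        need_update = (step + 1) % accum_steps == 0
        input = torch.randn(2, 8)
        target = torch.randn(2, 1)

        def _forward():
            loss = torch.nn.functional.mse_loss(model(input), target)
            return loss / accum_steps  # scale so accumulated grads average out

        def _backward(_loss):
            _loss.backward()

        if not need_update:
            # Forward AND backward must both run under no_sync(): DDP latches
            # its sync decision in forward, so a forward taken outside this
            # context would still all-reduce gradients in the backward.
            with model.no_sync():
                loss = _forward()
                _backward(loss)
        else:
            # Synced step: this backward all-reduces the accumulated grads.
            loss = _forward()
            _backward(loss)
            optimizer.step()
            optimizer.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

This also explains the relocation of `losses_m.update(...)` in the diff: once the forward moves into the branch, `loss` only exists after the branch runs, so the logging update has to follow it.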