diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
index 3af3d6e5..de0b881c 100644
--- a/timm/utils/cuda.py
+++ b/timm/utils/cuda.py
@@ -29,9 +29,9 @@ class ApexScaler:
     ):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
-        if clip_grad is not None:
-            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
         if need_update:
+            if clip_grad is not None:
+                dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
             optimizer.step()
 
     def state_dict(self):
@@ -60,11 +60,11 @@ class NativeScaler:
             need_update=True,
     ):
         self._scaler.scale(loss).backward(create_graph=create_graph)
-        if clip_grad is not None:
-            assert parameters is not None
-            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
         if need_update:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
             self._scaler.step(optimizer)
             self._scaler.update()
 
diff --git a/train.py b/train.py
index 6fd60292..cc73d0ae 100755
--- a/train.py
+++ b/train.py
@@ -927,13 +927,13 @@ def train_one_epoch(
                 )
             else:
                 loss.backward(create_graph=second_order)
-                if args.clip_grad is not None:
-                    utils.dispatch_clip_grad(
-                        model_parameters(model, exclude_head='agc' in args.clip_mode),
-                        value=args.clip_grad,
-                        mode=args.clip_mode,
-                    )
                 if need_update:
+                    if args.clip_grad is not None:
+                        utils.dispatch_clip_grad(
+                            model_parameters(model, exclude_head='agc' in args.clip_mode),
+                            value=args.clip_grad,
+                            mode=args.clip_mode,
+                        )
                     optimizer.step()
 
         if has_no_sync and not need_update:
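
The patch moves gradient clipping (and, for NativeScaler, the unscale_ call) inside the need_update branch, so with gradient accumulation the gradients are only clipped on the iteration that actually calls optimizer.step(), not on every partial backward pass. Below is a minimal usage sketch of NativeScaler under gradient accumulation, assuming generic model, loss_fn, optimizer, and loader objects and a hypothetical accum_steps value; it is illustrative only, not part of the patch.

# Minimal gradient-accumulation sketch using timm's NativeScaler.
# `model`, `loss_fn`, `optimizer`, `loader` are assumed to exist; `accum_steps` is hypothetical.
import torch
from timm.utils import NativeScaler

accum_steps = 4
loss_scaler = NativeScaler()

for step, (inputs, targets) in enumerate(loader):
    with torch.autocast(device_type='cuda'):
        loss = loss_fn(model(inputs), targets) / accum_steps

    # Only every accum_steps-th iteration performs an optimizer update.
    need_update = (step + 1) % accum_steps == 0
    loss_scaler(
        loss,
        optimizer,
        clip_grad=1.0,                    # with this patch, applied only when need_update is True
        clip_mode='norm',
        parameters=model.parameters(),
        need_update=need_update,          # gates unscale_/clip and scaler.step()/update()
    )
    if need_update:
        optimizer.zero_grad()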