diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index 2fcf900..d5d7312 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -213,7 +213,6 @@ class DefaultTrainer(TrainerBase):
             # for part of the parameters is not updated.
             model = DistributedDataParallel(
                 model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
-                find_unused_parameters=True
             )
 
         self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
@@ -305,9 +304,9 @@ class DefaultTrainer(TrainerBase):
             ret.append(hooks.LayerFreeze(
                 self.model,
+                self.optimizer,
                 cfg.MODEL.FREEZE_LAYERS,
                 cfg.SOLVER.FREEZE_ITERS,
-                cfg.SOLVER.FREEZE_FC_ITERS,
             ))
 
         # Do PreciseBN before checkpointer, because it updates the model and need to
diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py
index c72a170..e66e3fa 100644
--- a/fastreid/engine/hooks.py
+++ b/fastreid/engine/hooks.py
@@ -449,19 +449,18 @@ class PreciseBN(HookBase):
 
 
 class LayerFreeze(HookBase):
-    def __init__(self, model, freeze_layers, freeze_iters, fc_freeze_iters):
+    def __init__(self, model, optimizer, freeze_layers, freeze_iters):
         self._logger = logging.getLogger(__name__)
 
         if isinstance(model, DistributedDataParallel):
             model = model.module
         self.model = model
+        self.optimizer = optimizer
 
         self.freeze_layers = freeze_layers
         self.freeze_iters = freeze_iters
-        self.fc_freeze_iters = fc_freeze_iters
 
         self.is_frozen = False
-        self.fc_frozen = False
 
     def before_step(self):
         # Freeze specific layers
@@ -472,18 +471,6 @@ class LayerFreeze(HookBase):
         if self.trainer.iter >= self.freeze_iters and self.is_frozen:
             self.open_all_layer()
 
-        if self.trainer.max_iter - self.trainer.iter <= self.fc_freeze_iters \
-                and not self.fc_frozen:
-            self.freeze_classifier()
-
-    def freeze_classifier(self):
-        for p in self.model.heads.classifier.parameters():
-            p.requires_grad_(False)
-
-        self.fc_frozen = True
-        self._logger.info("Freeze classifier training for "
-                          "last {} iterations".format(self.fc_freeze_iters))
-
     def freeze_specific_layer(self):
         for layer in self.freeze_layers:
             if not hasattr(self.model, layer):
@@ -493,8 +480,24 @@ class LayerFreeze(HookBase):
             if name in self.freeze_layers:
                 # Change BN in freeze layers to eval mode
                 module.eval()
-                for p in module.parameters():
-                    p.requires_grad_(False)
+
+        def zero_freeze_grad():
+            for group in self.optimizer.param_groups:
+                if group["name"].split('.')[0] in self.freeze_layers:
+                    for p in group["params"]:
+                        if p.grad is not None:
+                            p.grad = None
+
+        origin_step = self.optimizer.step
+        self.origin_step = origin_step
+
+        @torch.no_grad()
+        def step(closure=None):
+            zero_freeze_grad()
+            loss = origin_step(closure)
+            return loss
+
+        self.optimizer.step = step
 
         self.is_frozen = True
         freeze_layers = ", ".join(self.freeze_layers)
@@ -504,8 +507,8 @@ class LayerFreeze(HookBase):
         for name, module in self.model.named_children():
             if name in self.freeze_layers:
                 module.train()
-                for p in module.parameters():
-                    p.requires_grad_(True)
+
+        self.optimizer.step = self.origin_step
 
         self.is_frozen = False
 
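The hooks.py change swaps requires_grad toggling for a wrapped optimizer.step that discards the gradients of frozen parameter groups before the update. Below is a minimal, self-contained sketch of that technique outside the fast-reid codebase, assuming the optimizer was built with a "name" key in each param group (as the hook expects); the module names "backbone" and "heads" and the toy forward pass are illustrative assumptions, only the step-wrapping pattern mirrors the patch.

# Minimal sketch (not fast-reid code): freeze a layer group by wrapping
# optimizer.step so its gradients are dropped right before the update.
import torch
from torch import nn

model = nn.ModuleDict({
    "backbone": nn.Linear(8, 8),
    "heads": nn.Linear(8, 2),
})
# Each param group carries a "name" key, mirroring how the hook finds frozen groups.
optimizer = torch.optim.SGD(
    [{"params": m.parameters(), "name": name} for name, m in model.items()],
    lr=0.1,
)

freeze_layers = ["backbone"]          # hypothetical frozen group
origin_step = optimizer.step          # keep the bound method so it can be restored


@torch.no_grad()
def step(closure=None):
    # Drop gradients of frozen groups so origin_step leaves them untouched.
    for group in optimizer.param_groups:
        if group["name"].split(".")[0] in freeze_layers:
            for p in group["params"]:
                p.grad = None
    return origin_step(closure)


optimizer.step = step                 # monkey-patch, as freeze_specific_layer() does

# One toy iteration: only "heads" receives an update.
x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = nn.functional.mse_loss(model["heads"](model["backbone"](x)), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

optimizer.step = origin_step          # restore, mirroring open_all_layer()

Because the frozen parameters keep requires_grad=True and merely have their gradients discarded inside step, DDP still sees a gradient for every parameter, which is presumably why the same patch can drop find_unused_parameters=True in defaults.py.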