diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index df45b1c..95f9d0d 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -321,12 +321,12 @@ class DefaultTrainer(TrainerBase):
                 cfg.TEST.PRECISE_BN.NUM_ITER,
             ))
 
-        if cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
-            ret.append(hooks.LayerFreeze(
-                self.model,
-                cfg.MODEL.FREEZE_LAYERS,
-                cfg.SOLVER.FREEZE_ITERS,
-            ))
+        ret.append(hooks.LayerFreeze(
+            self.model,
+            cfg.MODEL.FREEZE_LAYERS,
+            cfg.SOLVER.FREEZE_ITERS,
+            cfg.SOLVER.FREEZE_FC_ITERS,
+        ))
 
         # Do PreciseBN before checkpointer, because it updates the model and need to
         # be saved by checkpointer.
         # This is not always the best: if checkpointing has a different frequency,
diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py
index 74c42a7..8fc81e0 100644
--- a/fastreid/engine/hooks.py
+++ b/fastreid/engine/hooks.py
@@ -447,7 +447,7 @@ class PreciseBN(HookBase):
 
 
 class LayerFreeze(HookBase):
-    def __init__(self, model, freeze_layers, freeze_iters):
+    def __init__(self, model, freeze_layers, freeze_iters, fc_freeze_iters):
         self._logger = logging.getLogger(__name__)
 
         if isinstance(model, DistributedDataParallel):
@@ -456,9 +456,12 @@
 
         self.freeze_layers = freeze_layers
         self.freeze_iters = freeze_iters
+        self.fc_freeze_iters = fc_freeze_iters
 
         self.is_frozen = False
+        self.fc_frozen = False
+
     def before_step(self):
         # Freeze specific layers
         if self.trainer.iter < self.freeze_iters and not self.is_frozen:
@@ -468,6 +471,18 @@
         if self.trainer.iter >= self.freeze_iters and self.is_frozen:
             self.open_all_layer()
 
+        if self.trainer.max_iter - self.trainer.iter <= self.fc_freeze_iters \
+                and not self.fc_frozen:
+            self.freeze_classifier()
+
+    def freeze_classifier(self):
+        for p in self.model.heads.classifier.parameters():
+            p.requires_grad_(False)
+
+        self.fc_frozen = True
+        self._logger.info("Freeze classifier training for "
+                          "last {} iterations".format(self.fc_freeze_iters))
+
     def freeze_specific_layer(self):
         for layer in self.freeze_layers:
             if not hasattr(self.model, layer):
@@ -493,7 +508,7 @@
 
         self.is_frozen = False
 
-        freeze_layers = ",".join(self.freeze_layers)
+        freeze_layers = ", ".join(self.freeze_layers)
         self._logger.info(f'Open layer group "{freeze_layers}" training')
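
Note: the hook registration now reads cfg.SOLVER.FREEZE_FC_ITERS, so a matching key must be registered in the config defaults (yacs raises AttributeError for unregistered keys). Below is a minimal sketch of the corresponding defaults entry; the file location and the default value of 0 are assumptions, not part of this diff:

    # fastreid/config/defaults.py (sketch, not part of this patch)
    # Number of final training iterations during which the classifier
    # (model.heads.classifier) is frozen. LayerFreeze.before_step() calls
    # freeze_classifier() once max_iter - iter <= FREEZE_FC_ITERS.
    # The default of 0 is an assumption: it keeps the classifier trainable
    # for the whole schedule unless an experiment config overrides it.
    _C.SOLVER.FREEZE_FC_ITERS = 0

With this in place, setting SOLVER.FREEZE_FC_ITERS to, say, 500 in an experiment config stops gradient updates to the classifier weights for the last 500 iterations while the rest of the model keeps training. Also note the old "if cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0" guard is removed, so the LayerFreeze hook is now registered unconditionally.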