feat: freeze FC

Summary: freeze the FC classifier layer during the last stages of training
pull/380/head
liaoxingyu 2020-12-28 14:46:28 +08:00
parent fe2e46d40e
commit 2c17847980
2 changed files with 23 additions and 8 deletions

View File

@@ -321,12 +321,12 @@ class DefaultTrainer(TrainerBase):
cfg.TEST.PRECISE_BN.NUM_ITER,
))
if cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
ret.append(hooks.LayerFreeze(
self.model,
cfg.MODEL.FREEZE_LAYERS,
cfg.SOLVER.FREEZE_ITERS,
))
ret.append(hooks.LayerFreeze(
self.model,
cfg.MODEL.FREEZE_LAYERS,
cfg.SOLVER.FREEZE_ITERS,
cfg.SOLVER.FREEZE_FC_ITERS,
))
# Do PreciseBN before checkpointer, because it updates the model and need to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,

View File

@@ -447,7 +447,7 @@ class PreciseBN(HookBase):
class LayerFreeze(HookBase):
def __init__(self, model, freeze_layers, freeze_iters):
def __init__(self, model, freeze_layers, freeze_iters, fc_freeze_iters):
self._logger = logging.getLogger(__name__)
if isinstance(model, DistributedDataParallel):
@@ -456,9 +456,12 @@ class LayerFreeze(HookBase):
self.freeze_layers = freeze_layers
self.freeze_iters = freeze_iters
self.fc_freeze_iters = fc_freeze_iters
self.is_frozen = False
self.fc_frozen = False
def before_step(self):
# Freeze specific layers
if self.trainer.iter < self.freeze_iters and not self.is_frozen:
@@ -468,6 +471,18 @@ class LayerFreeze(HookBase):
if self.trainer.iter >= self.freeze_iters and self.is_frozen:
self.open_all_layer()
if self.trainer.max_iter - self.trainer.iter <= self.fc_freeze_iters \
and not self.fc_frozen:
self.freeze_classifier()
# Freeze the FC classifier head: disable gradient updates for its parameters
# so the classifier weights stay fixed for the final `fc_freeze_iters` iterations.
def freeze_classifier(self):
# Turn off autograd for every parameter of the classifier head.
# NOTE(review): assumes the (unwrapped) model exposes `heads.classifier` — confirm against the model definition.
for p in self.model.heads.classifier.parameters():
p.requires_grad_(False)
# Mark as done so before_step() does not call this again on later iterations.
self.fc_frozen = True
self._logger.info("Freeze classifier training for "
"last {} iterations".format(self.fc_freeze_iters))
def freeze_specific_layer(self):
for layer in self.freeze_layers:
if not hasattr(self.model, layer):
@@ -493,7 +508,7 @@ class LayerFreeze(HookBase):
self.is_frozen = False
freeze_layers = ",".join(self.freeze_layers)
freeze_layers = ", ".join(self.freeze_layers)
self._logger.info(f'Open layer group "{freeze_layers}" training')