Impl `freezebb` in optimizer's step()

Make the implementation of `freezebb` consistent with that of grad clipping: both are now implemented through the optimizer's step()
pull/504/head
liaoxingyu 2021-05-31 17:15:26 +08:00
parent 07b8251ccb
commit 256721cfde
2 changed files with 9 additions and 44 deletions

View File

@ -152,10 +152,8 @@ class DefaultPredictor:
inputs = {"images": image.to(self.model.device)} inputs = {"images": image.to(self.model.device)}
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
predictions = self.model(inputs) predictions = self.model(inputs)
# Normalize feature to compute cosine distance return predictions
features = F.normalize(predictions)
features = features.cpu().data
return features
class DefaultTrainer(TrainerBase): class DefaultTrainer(TrainerBase):
@ -281,17 +279,6 @@ class DefaultTrainer(TrainerBase):
hooks.LRScheduler(self.optimizer, self.scheduler), hooks.LRScheduler(self.optimizer, self.scheduler),
] ]
# if cfg.SOLVER.SWA.ENABLED:
# ret.append(
# hooks.SWA(
# cfg.SOLVER.MAX_ITER,
# cfg.SOLVER.SWA.PERIOD,
# cfg.SOLVER.SWA.LR_FACTOR,
# cfg.SOLVER.SWA.ETA_MIN_LR,
# cfg.SOLVER.SWA.LR_SCHED,
# )
# )
if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model): if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
logger.info("Prepare precise BN dataset") logger.info("Prepare precise BN dataset")
ret.append(hooks.PreciseBN( ret.append(hooks.PreciseBN(
@ -302,9 +289,9 @@ class DefaultTrainer(TrainerBase):
cfg.TEST.PRECISE_BN.NUM_ITER, cfg.TEST.PRECISE_BN.NUM_ITER,
)) ))
if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0:
ret.append(hooks.LayerFreeze( ret.append(hooks.LayerFreeze(
self.model, self.model,
self.optimizer,
cfg.MODEL.FREEZE_LAYERS, cfg.MODEL.FREEZE_LAYERS,
cfg.SOLVER.FREEZE_ITERS, cfg.SOLVER.FREEZE_ITERS,
)) ))

View File

@ -449,13 +449,11 @@ class PreciseBN(HookBase):
class LayerFreeze(HookBase): class LayerFreeze(HookBase):
def __init__(self, model, optimizer, freeze_layers, freeze_iters): def __init__(self, model, freeze_layers, freeze_iters):
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
if isinstance(model, DistributedDataParallel): if isinstance(model, DistributedDataParallel):
model = model.module model = model.module
self.model = model self.model = model
self.optimizer = optimizer
self.freeze_layers = freeze_layers self.freeze_layers = freeze_layers
self.freeze_iters = freeze_iters self.freeze_iters = freeze_iters
@ -481,24 +479,6 @@ class LayerFreeze(HookBase):
# Change BN in freeze layers to eval mode # Change BN in freeze layers to eval mode
module.eval() module.eval()
def zero_freeze_grad():
for group in self.optimizer.param_groups:
if group["name"].split('.')[0] in self.freeze_layers:
for p in group["params"]:
if p.grad is not None:
p.grad = None
origin_step = self.optimizer.step
self.origin_step = origin_step
@torch.no_grad()
def step(closure=None):
zero_freeze_grad()
loss = origin_step(closure)
return loss
self.optimizer.step = step
self.is_frozen = True self.is_frozen = True
freeze_layers = ", ".join(self.freeze_layers) freeze_layers = ", ".join(self.freeze_layers)
self._logger.info(f'Freeze layer group "{freeze_layers}" training for {self.freeze_iters:d} iterations') self._logger.info(f'Freeze layer group "{freeze_layers}" training for {self.freeze_iters:d} iterations')
@ -508,8 +488,6 @@ class LayerFreeze(HookBase):
if name in self.freeze_layers: if name in self.freeze_layers:
module.train() module.train()
self.optimizer.step = self.origin_step
self.is_frozen = False self.is_frozen = False
freeze_layers = ", ".join(self.freeze_layers) freeze_layers = ", ".join(self.freeze_layers)