mirror of https://github.com/JDAI-CV/fast-reid.git
Impl `freezebb` in optimizer's step()
Make the implementation of `freezebb` consistent with the implementation of grad clip: both are now implemented through the optimizer's step().

pull/504/head
parent 07b8251ccb
commit 256721cfde
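
Only the hook-side removals loaded in this view; the hunks that add the optimizer-side implementation are not shown below. As a minimal sketch of the pattern the commit message describes, assuming (by analogy with the grad clip wrapper) a dynamically created optimizer subclass whose step() drops the gradients of frozen parameter groups. All names in this sketch are assumptions, not taken from this diff; only the per-group "name" convention comes from the removed hook code.

import torch

def _generate_optimizer_class_with_freeze_layer(optimizer_cls, freeze_layers):
    # Hypothetical sketch: fold layer freezing into step() the same way the
    # grad clip wrapper does, by subclassing the optimizer dynamically.
    @torch.no_grad()
    def step_with_freeze(self, closure=None):
        # Drop gradients of parameter groups that belong to frozen layers so
        # the wrapped step() leaves those parameters untouched.
        for group in self.param_groups:
            if group.get("name", "").split(".")[0] in freeze_layers:
                for p in group["params"]:
                    p.grad = None
        return optimizer_cls.step(self, closure)

    return type(
        optimizer_cls.__name__ + "WithFreezeLayer",
        (optimizer_cls,),
        {"step": step_with_freeze},
    )

# Usage sketch: an SGD variant whose step() never updates the backbone.
FrozenSGD = _generate_optimizer_class_with_freeze_layer(torch.optim.SGD, {"backbone"})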
@@ -152,10 +152,8 @@ class DefaultPredictor:
         inputs = {"images": image.to(self.model.device)}
         with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
             predictions = self.model(inputs)
-            # Normalize feature to compute cosine distance
-            features = F.normalize(predictions)
-            features = features.cpu().data
-            return features
+            return predictions
+


 class DefaultTrainer(TrainerBase):
@@ -281,17 +279,6 @@ class DefaultTrainer(TrainerBase):
             hooks.LRScheduler(self.optimizer, self.scheduler),
         ]

-        # if cfg.SOLVER.SWA.ENABLED:
-        #     ret.append(
-        #         hooks.SWA(
-        #             cfg.SOLVER.MAX_ITER,
-        #             cfg.SOLVER.SWA.PERIOD,
-        #             cfg.SOLVER.SWA.LR_FACTOR,
-        #             cfg.SOLVER.SWA.ETA_MIN_LR,
-        #             cfg.SOLVER.SWA.LR_SCHED,
-        #         )
-        #     )
-
         if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
             logger.info("Prepare precise BN dataset")
             ret.append(hooks.PreciseBN(
@@ -302,12 +289,12 @@ class DefaultTrainer(TrainerBase):
                 cfg.TEST.PRECISE_BN.NUM_ITER,
             ))

-        ret.append(hooks.LayerFreeze(
-            self.model,
-            self.optimizer,
-            cfg.MODEL.FREEZE_LAYERS,
-            cfg.SOLVER.FREEZE_ITERS,
-        ))
+        if len(cfg.MODEL.FREEZE_LAYERS) > 0 and cfg.SOLVER.FREEZE_ITERS > 0:
+            ret.append(hooks.LayerFreeze(
+                self.model,
+                cfg.MODEL.FREEZE_LAYERS,
+                cfg.SOLVER.FREEZE_ITERS,
+            ))

         # Do PreciseBN before checkpointer, because it updates the model and need to
         # be saved by checkpointer.
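With the optimizer argument gone from the hook, the trainer now registers LayerFreeze only when freezing is actually requested. Illustrative config values (examples, not from this diff) that would enable the branch above:

cfg.MODEL.FREEZE_LAYERS = ["backbone"]  # layer groups to freeze, matched by name
cfg.SOLVER.FREEZE_ITERS = 1000          # keep them frozen for the first 1000 iterations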
@@ -449,13 +449,11 @@ class PreciseBN(HookBase):


 class LayerFreeze(HookBase):
-    def __init__(self, model, optimizer, freeze_layers, freeze_iters):
+    def __init__(self, model, freeze_layers, freeze_iters):
         self._logger = logging.getLogger(__name__)

         if isinstance(model, DistributedDataParallel):
             model = model.module
         self.model = model
-        self.optimizer = optimizer
-
         self.freeze_layers = freeze_layers
         self.freeze_iters = freeze_iters
@@ -481,24 +479,6 @@ class LayerFreeze(HookBase):
                 # Change BN in freeze layers to eval mode
                 module.eval()

-        def zero_freeze_grad():
-            for group in self.optimizer.param_groups:
-                if group["name"].split('.')[0] in self.freeze_layers:
-                    for p in group["params"]:
-                        if p.grad is not None:
-                            p.grad = None
-
-        origin_step = self.optimizer.step
-        self.origin_step = origin_step
-
-        @torch.no_grad()
-        def step(closure=None):
-            zero_freeze_grad()
-            loss = origin_step(closure)
-            return loss
-
-        self.optimizer.step = step
-
         self.is_frozen = True
         freeze_layers = ", ".join(self.freeze_layers)
         self._logger.info(f'Freeze layer group "{freeze_layers}" training for {self.freeze_iters:d} iterations')
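The block deleted above was the hook-side half of `freezebb`: on freeze it monkey-patched self.optimizer.step with a wrapper that cleared the gradients of frozen parameter groups before delegating to the original step(). After this commit the hook only toggles the frozen modules between train() and eval(); the gradient handling presumably moves into the optimizer's own step(), in hunks not shown in this excerpt.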
@@ -508,8 +488,6 @@ class LayerFreeze(HookBase):
             if name in self.freeze_layers:
                 module.train()

-        self.optimizer.step = self.origin_step
-
         self.is_frozen = False

         freeze_layers = ", ".join(self.freeze_layers)