mirror of https://github.com/JDAI-CV/fast-reid.git
change way of layer freezing (pull/504/head)

Remove `find_unused_parameters` from the DDP wrapper and instead wrap the optimizer's `step` function to freeze the backbone. Frozen layers now keep `requires_grad=True`, so every parameter still produces a gradient during backward and DDP no longer needs the extra per-iteration graph traversal that `find_unused_parameters=True` triggers; the wrapped `step` simply discards the gradients of frozen layers before updating. This accelerates training.
parent dbf1604231
commit 2b65882447
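The core trick: instead of setting `requires_grad_(False)` on frozen modules, the hook monkey-patches `optimizer.step` so that gradients of frozen parameter groups are dropped right before the update (torch optimizers skip any parameter whose `.grad` is `None`). A minimal standalone sketch of the idea, assuming, as fastreid's optimizer builder does, that each param group carries a `"name"` key; the toy model, layer names, and learning rate below are placeholders:

```python
import torch

# Toy stand-in for a real model; pretend child "0" is the backbone to freeze.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
optimizer = torch.optim.SGD(
    [{"params": list(m.parameters()), "name": n} for n, m in model.named_children()],
    lr=0.1,
)
freeze_layers = {"0"}

origin_step = optimizer.step  # keep a handle so the wrap can be undone

@torch.no_grad()
def step(closure=None):
    # Drop gradients of frozen groups; origin_step then leaves those
    # parameters untouched, since optimizers skip params with .grad None.
    for group in optimizer.param_groups:
        if group["name"] in freeze_layers:
            for p in group["params"]:
                p.grad = None
    return origin_step(closure)

optimizer.step = step

x = torch.randn(8, 4)
model(x).sum().backward()
before = model[0].weight.clone()
optimizer.step()
assert torch.equal(model[0].weight, before)  # frozen layer unchanged
optimizer.step = origin_step                 # restore, as open_all_layer() does
```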
The change, first in `DefaultTrainer`:

```diff
@@ -213,7 +213,6 @@ class DefaultTrainer(TrainerBase):
             # for part of the parameters is not updated.
             model = DistributedDataParallel(
                 model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
-                find_unused_parameters=True
             )
 
         self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
```
```diff
@@ -305,9 +304,9 @@ class DefaultTrainer(TrainerBase):
 
             ret.append(hooks.LayerFreeze(
                 self.model,
+                self.optimizer,
                 cfg.MODEL.FREEZE_LAYERS,
                 cfg.SOLVER.FREEZE_ITERS,
-                cfg.SOLVER.FREEZE_FC_ITERS,
             ))
 
         # Do PreciseBN before checkpointer, because it updates the model and need to
```
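The two config keys that remain drive the hook. A hypothetical excerpt enabling backbone freezing (layer name and iteration count are illustrative; `FREEZE_LAYERS` entries must match top-level child modules of the model):

```python
cfg.MODEL.FREEZE_LAYERS = ["backbone"]
cfg.SOLVER.FREEZE_ITERS = 1000  # train with a frozen backbone for the first 1000 iters
```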
And in the `LayerFreeze` hook itself:

```diff
@@ -449,19 +449,18 @@ class PreciseBN(HookBase):
 
 
 class LayerFreeze(HookBase):
-    def __init__(self, model, freeze_layers, freeze_iters, fc_freeze_iters):
+    def __init__(self, model, optimizer, freeze_layers, freeze_iters):
         self._logger = logging.getLogger(__name__)
 
         if isinstance(model, DistributedDataParallel):
             model = model.module
         self.model = model
+        self.optimizer = optimizer
 
         self.freeze_layers = freeze_layers
         self.freeze_iters = freeze_iters
-        self.fc_freeze_iters = fc_freeze_iters
 
         self.is_frozen = False
-        self.fc_frozen = False
 
     def before_step(self):
         # Freeze specific layers
```
```diff
@@ -472,18 +471,6 @@ class LayerFreeze(HookBase):
         if self.trainer.iter >= self.freeze_iters and self.is_frozen:
             self.open_all_layer()
 
-        if self.trainer.max_iter - self.trainer.iter <= self.fc_freeze_iters \
-                and not self.fc_frozen:
-            self.freeze_classifier()
-
-    def freeze_classifier(self):
-        for p in self.model.heads.classifier.parameters():
-            p.requires_grad_(False)
-
-        self.fc_frozen = True
-        self._logger.info("Freeze classifier training for "
-                          "last {} iterations".format(self.fc_freeze_iters))
-
     def freeze_specific_layer(self):
         for layer in self.freeze_layers:
             if not hasattr(self.model, layer):
```
```diff
@@ -493,8 +480,24 @@ class LayerFreeze(HookBase):
             if name in self.freeze_layers:
                 # Change BN in freeze layers to eval mode
                 module.eval()
-                for p in module.parameters():
-                    p.requires_grad_(False)
+
+        def zero_freeze_grad():
+            for group in self.optimizer.param_groups:
+                if group["name"].split('.')[0] in self.freeze_layers:
+                    for p in group["params"]:
+                        if p.grad is not None:
+                            p.grad = None
+
+        origin_step = self.optimizer.step
+        self.origin_step = origin_step
+
+        @torch.no_grad()
+        def step(closure=None):
+            zero_freeze_grad()
+            loss = origin_step(closure)
+            return loss
+
+        self.optimizer.step = step
 
         self.is_frozen = True
         freeze_layers = ", ".join(self.freeze_layers)
```
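`zero_freeze_grad()` assumes each param group has a `"name"` key whose first dotted component is a top-level module name. A sketch of how such groups might be built; `build_named_param_groups` is a hypothetical helper, not fastreid's actual builder:

```python
import torch

def build_named_param_groups(model):
    # One group per parameter, tagged with its dotted path, so that
    # group["name"].split('.')[0] recovers the top-level module name,
    # e.g. "backbone" from "backbone.layer1.conv1.weight".
    return [{"params": [p], "name": name} for name, p in model.named_parameters()]

# Usage: optimizer = torch.optim.SGD(build_named_param_groups(model), lr=0.01)
```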
And the matching restore in `open_all_layer`:

```diff
@@ -504,8 +507,8 @@ class LayerFreeze(HookBase):
         for name, module in self.model.named_children():
             if name in self.freeze_layers:
                 module.train()
-                for p in module.parameters():
-                    p.requires_grad_(True)
+
+        self.optimizer.step = self.origin_step
 
         self.is_frozen = False
```
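A quick sanity check of the wrap/restore round trip the hook performs (illustrative; `model` and `optimizer` as in the sketch above):

```python
hook = LayerFreeze(model, optimizer, freeze_layers=["backbone"], freeze_iters=1000)
hook.freeze_specific_layer()                  # wraps optimizer.step
assert optimizer.step is not hook.origin_step
hook.open_all_layer()                         # restores the original step
assert optimizer.step is hook.origin_step
```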