mirror of https://github.com/JDAI-CV/fast-reid.git
Support self-distill with EMA updated model
parent 256721cfde
commit 8ab3554958
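This commit adds an optional self-distillation teacher: a frozen copy of the student whose weights are updated as an exponential moving average (EMA) of the student, controlled by the `KD.EMA.ENABLED` and `KD.EMA.MOMENTUM` options read in the diff below. A minimal sketch of switching the feature on from Python, assuming the `KD.EMA` defaults and the `KD.MODEL_CONFIG` list are already registered in the config (that registration is not part of this diff, and the config file name below is hypothetical):

    # Sketch only: enable the EMA self-distillation teacher,
    # assuming cfg already carries the KD.EMA node introduced with this commit.
    from fastreid.config import get_cfg

    cfg = get_cfg()
    cfg.merge_from_file("configs/my_kd_config.yml")  # hypothetical config file
    cfg.KD.EMA.ENABLED = True      # insert the EMA copy of the student as teacher 0
    cfg.KD.EMA.MOMENTUM = 0.999    # m in: teacher = m * teacher + (1 - m) * student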
@@ -26,6 +26,13 @@ class Distiller(Baseline):
        for i in range(len(cfg.KD.MODEL_CONFIG)):
            cfg_t = get_cfg()
            cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])
            cfg_t.defrost()
            cfg_t.MODEL.META_ARCHITECTURE = "Baseline"
            # Change syncBN to BN due to no DDP wrapper
            if cfg_t.MODEL.BACKBONE.NORM == "syncBN":
                cfg_t.MODEL.BACKBONE.NORM = "BN"
            if cfg_t.MODEL.HEADS.NORM == "syncBN":
                cfg_t.MODEL.HEADS.NORM = "BN"

            model_t = build_model(cfg_t)
@@ -38,10 +45,43 @@ class Distiller(Baseline):
            model_ts.append(model_t)

        self.ema_enabled = cfg.KD.EMA.ENABLED
        self.ema_momentum = cfg.KD.EMA.MOMENTUM
        if self.ema_enabled:
            cfg_self = cfg.clone()
            cfg_self.defrost()
            cfg_self.MODEL.META_ARCHITECTURE = "Baseline"
            if cfg_self.MODEL.BACKBONE.NORM == "syncBN":
                cfg_self.MODEL.BACKBONE.NORM = "BN"
            if cfg_self.MODEL.HEADS.NORM == "syncBN":
                cfg_self.MODEL.HEADS.NORM = "BN"
            model_self = build_model(cfg_self)
            # No gradients for the self-distillation model
            for param in model_self.parameters():
                param.requires_grad_(False)

            if cfg_self.MODEL.WEIGHTS != '':
                logger.info("Loading self distillation model weights ...")
                Checkpointer(model_self).load(cfg_self.MODEL.WEIGHTS)
            else:
                # Make sure the initial state is the same as the student
                for param_q, param_k in zip(self.parameters(), model_self.parameters()):
                    param_k.data.copy_(param_q.data)

            model_ts.insert(0, model_self)

        # Do not register teacher models as `nn.Module`, so that
        # teacher model weights are not saved with the checkpoint
        self.model_ts = model_ts

    @torch.no_grad()
    def _momentum_update_key_encoder(self, m=0.999):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.parameters(), self.model_ts[0].parameters()):
            param_k.data = param_k.data * m + param_q.data * (1. - m)

    def forward(self, batched_inputs):
        if self.training:
            images = self.preprocess_image(batched_inputs)
@@ -57,6 +97,8 @@ class Distiller(Baseline):
            t_outputs = []
            # teacher model forward
            with torch.no_grad():
                if self.ema_enabled:
                    self._momentum_update_key_encoder(self.ema_momentum)  # update self distill model
                for model_t in self.model_ts:
                    t_feat = model_t.backbone(images)
                    t_output = model_t.heads(t_feat, targets)
@@ -90,6 +90,8 @@ class DistillerOverhaul(Distiller):
            t_outputs = []
            # teacher model forward
            with torch.no_grad():
                if self.ema_enabled:
                    self._momentum_update_key_encoder(self.ema_momentum)
                for model_t in self.model_ts:
                    t_feats, t_feat = model_t.backbone.extract_feature(images, preReLU=True)
                    t_output = model_t.heads(t_feat, targets)
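For reference, the update applied by `_momentum_update_key_encoder` is the usual EMA rule, teacher = m * teacher + (1 - m) * student. A self-contained sketch of the same update outside fastreid, using plain PyTorch linear layers in place of the student model and its EMA copy (names here are illustrative, not part of the diff):

    import torch
    import torch.nn as nn

    student = nn.Linear(8, 4)
    teacher = nn.Linear(8, 4)

    # Start the EMA copy from the student's weights and freeze it,
    # mirroring the initialization branch in the diff above.
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.copy_(p_s.data)
        p_t.requires_grad_(False)

    @torch.no_grad()
    def momentum_update(student, teacher, m=0.999):
        # Same rule as Distiller._momentum_update_key_encoder
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.data = p_t.data * m + p_s.data * (1. - m)

    # One EMA step, e.g. after an optimizer step on the student
    momentum_update(student, teacher, m=0.999)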