Support self-distill with EMA-updated model

pull/504/head
liaoxingyu 2021-05-31 17:17:24 +08:00
parent 256721cfde
commit 8ab3554958
2 changed files with 44 additions and 0 deletions

@@ -26,6 +26,13 @@ class Distiller(Baseline):
        for i in range(len(cfg.KD.MODEL_CONFIG)):
            cfg_t = get_cfg()
            cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])
            cfg_t.defrost()
            cfg_t.MODEL.META_ARCHITECTURE = "Baseline"
            # Change syncBN to BN because the teacher model is not wrapped by DDP
            if cfg_t.MODEL.BACKBONE.NORM == "syncBN":
                cfg_t.MODEL.BACKBONE.NORM = "BN"
            if cfg_t.MODEL.HEADS.NORM == "syncBN":
                cfg_t.MODEL.HEADS.NORM = "BN"
            model_t = build_model(cfg_t)
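
The syncBN-to-BN swap is needed because the teacher models run outside any DistributedDataParallel wrapper, so synchronized batch norm would have no process group to all-reduce statistics over. A minimal sketch of the same check as a helper (replace_sync_bn is a hypothetical name; cfg_t is assumed to be a defrosted yacs CfgNode laid out as in the hunk above):

def replace_sync_bn(cfg_t):
    # Sketch only: mirrors the norm swap in the diff.
    for node in (cfg_t.MODEL.BACKBONE, cfg_t.MODEL.HEADS):
        if node.NORM == "syncBN":
            node.NORM = "BN"  # plain BN needs no distributed process group
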
@@ -38,10 +45,43 @@ class Distiller(Baseline):
            model_ts.append(model_t)

        self.ema_enabled = cfg.KD.EMA.ENABLED
        self.ema_momentum = cfg.KD.EMA.MOMENTUM
        if self.ema_enabled:
            cfg_self = cfg.clone()
            cfg_self.defrost()
            cfg_self.MODEL.META_ARCHITECTURE = "Baseline"
            if cfg_self.MODEL.BACKBONE.NORM == "syncBN":
                cfg_self.MODEL.BACKBONE.NORM = "BN"
            if cfg_self.MODEL.HEADS.NORM == "syncBN":
                cfg_self.MODEL.HEADS.NORM = "BN"
            model_self = build_model(cfg_self)
            # No gradients for the EMA self-distillation model
            for param in model_self.parameters():
                param.requires_grad_(False)
            if cfg_self.MODEL.WEIGHTS != '':
                logger.info("Loading self distillation model weights ...")
                Checkpointer(model_self).load(cfg_self.MODEL.WEIGHTS)
            else:
                # Make sure the initial states are the same
                for param_q, param_k in zip(self.parameters(), model_self.parameters()):
                    param_k.data.copy_(param_q.data)
            model_ts.insert(0, model_self)

        # Do not register the teacher models as nn.Module, so that
        # their weights are not saved into checkpoints
        self.model_ts = model_ts
    @torch.no_grad()
    def _momentum_update_key_encoder(self, m=0.999):
        """
        Momentum update of the EMA self-distillation model (model_ts[0])
        """
        for param_q, param_k in zip(self.parameters(), self.model_ts[0].parameters()):
            param_k.data = param_k.data * m + param_q.data * (1. - m)

    def forward(self, batched_inputs):
        if self.training:
            images = self.preprocess_image(batched_inputs)
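
The additions above follow the MoCo-style momentum-encoder recipe: build a copy of the student, freeze it, start it from the same weights (or from MODEL.WEIGHTS), and blend the student's parameters into it once per step. A condensed, self-contained sketch of that pattern, with hypothetical init_ema_teacher / ema_update helpers standing in for the code above:

import torch
import torch.nn as nn

def init_ema_teacher(student: nn.Module, teacher: nn.Module) -> None:
    # Freeze the EMA copy and mirror the student's initial parameters.
    for p in teacher.parameters():
        p.requires_grad_(False)
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.copy_(p_s.data)

@torch.no_grad()
def ema_update(student: nn.Module, teacher: nn.Module, m: float = 0.999) -> None:
    # theta_teacher <- m * theta_teacher + (1 - m) * theta_student
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(m).add_(p_s.data, alpha=1.0 - m)
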
@@ -57,6 +97,8 @@ class Distiller(Baseline):
            t_outputs = []
            # teacher model forward
            with torch.no_grad():
                if self.ema_enabled:
                    self._momentum_update_key_encoder(self.ema_momentum)  # update self-distill model
                for model_t in self.model_ts:
                    t_feat = model_t.backbone(images)
                    t_output = model_t.heads(t_feat, targets)
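
With this hook in forward, the feature is driven entirely by the two new config keys the commit reads. A hedged sketch of turning it on (the key names come from the diff; the 0.999 value is an assumption, not something this commit sets):

# Assumes `cfg` already carries the KD node used by this class.
cfg.defrost()
cfg.KD.EMA.ENABLED = True
cfg.KD.EMA.MOMENTUM = 0.999  # assumed value, tune per experiment
cfg.freeze()
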

@@ -90,6 +90,8 @@ class DistillerOverhaul(Distiller):
            t_outputs = []
            # teacher model forward
            with torch.no_grad():
                if self.ema_enabled:
                    self._momentum_update_key_encoder(self.ema_momentum)
                for model_t in self.model_ts:
                    t_feats, t_feat = model_t.backbone.extract_feature(images, preReLU=True)
                    t_output = model_t.heads(t_feat, targets)
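
DistillerOverhaul only changes how teacher features are extracted (pre-ReLU feature maps); the EMA refresh is identical. As a quick sanity check on the mechanics with plain modules (a hedged demo, not fastreid code):

import torch
import torch.nn as nn

student, teacher = nn.Linear(4, 4), nn.Linear(4, 4)
teacher.load_state_dict(student.state_dict())   # same initial state
for p in teacher.parameters():
    p.requires_grad_(False)                     # teacher is never optimized

m = 0.999
with torch.no_grad():                           # one EMA step per iteration
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(m).add_(p_s.data, alpha=1.0 - m)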