From 8ab35549585caf62ea5b770bc969f29135ca69ed Mon Sep 17 00:00:00 2001
From: liaoxingyu
Date: Mon, 31 May 2021 17:17:24 +0800
Subject: [PATCH] Support self-distillation with an EMA-updated model

---
 fastreid/modeling/meta_arch/distiller.py     | 42 ++++++++++++++++++++
 projects/FastDistill/fastdistill/overhaul.py |  2 +
 2 files changed, 44 insertions(+)

diff --git a/fastreid/modeling/meta_arch/distiller.py b/fastreid/modeling/meta_arch/distiller.py
index 1bbb7dc..fc68f3c 100644
--- a/fastreid/modeling/meta_arch/distiller.py
+++ b/fastreid/modeling/meta_arch/distiller.py
@@ -26,6 +26,13 @@ class Distiller(Baseline):
         for i in range(len(cfg.KD.MODEL_CONFIG)):
             cfg_t = get_cfg()
             cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])
+            cfg_t.defrost()
+            cfg_t.MODEL.META_ARCHITECTURE = "Baseline"
+            # Change syncBN to BN because the teacher model has no DDP wrapper
+            if cfg_t.MODEL.BACKBONE.NORM == "syncBN":
+                cfg_t.MODEL.BACKBONE.NORM = "BN"
+            if cfg_t.MODEL.HEADS.NORM == "syncBN":
+                cfg_t.MODEL.HEADS.NORM = "BN"
 
             model_t = build_model(cfg_t)
 
@@ -38,10 +45,43 @@ class Distiller(Baseline):
 
             model_ts.append(model_t)
 
+        self.ema_enabled = cfg.KD.EMA.ENABLED
+        self.ema_momentum = cfg.KD.EMA.MOMENTUM
+        if self.ema_enabled:
+            cfg_self = cfg.clone()
+            cfg_self.defrost()
+            cfg_self.MODEL.META_ARCHITECTURE = "Baseline"
+            if cfg_self.MODEL.BACKBONE.NORM == "syncBN":
+                cfg_self.MODEL.BACKBONE.NORM = "BN"
+            if cfg_self.MODEL.HEADS.NORM == "syncBN":
+                cfg_self.MODEL.HEADS.NORM = "BN"
+            model_self = build_model(cfg_self)
+            # No gradients for the self-distillation model
+            for param in model_self.parameters():
+                param.requires_grad_(False)
+
+            if cfg_self.MODEL.WEIGHTS != '':
+                logger.info("Loading self distillation model weights ...")
+                Checkpointer(model_self).load(cfg_self.MODEL.WEIGHTS)
+            else:
+                # Make sure the initial states are the same
+                for param_q, param_k in zip(self.parameters(), model_self.parameters()):
+                    param_k.data.copy_(param_q.data)
+
+            model_ts.insert(0, model_self)
+
         # Do not register the teacher models as `nn.Module`, to make
         # sure the teacher model weights are not saved
         self.model_ts = model_ts
 
+    @torch.no_grad()
+    def _momentum_update_key_encoder(self, m=0.999):
+        """
+        Momentum update of the self-distillation (EMA) model
+        """
+        for param_q, param_k in zip(self.parameters(), self.model_ts[0].parameters()):
+            param_k.data = param_k.data * m + param_q.data * (1. - m)
+
     def forward(self, batched_inputs):
         if self.training:
             images = self.preprocess_image(batched_inputs)
@@ -57,6 +97,8 @@ class Distiller(Baseline):
         t_outputs = []
         # teacher model forward
         with torch.no_grad():
+            if self.ema_enabled:
+                self._momentum_update_key_encoder(self.ema_momentum)  # update self-distill model
             for model_t in self.model_ts:
                 t_feat = model_t.backbone(images)
                 t_output = model_t.heads(t_feat, targets)
diff --git a/projects/FastDistill/fastdistill/overhaul.py b/projects/FastDistill/fastdistill/overhaul.py
index d311102..5c4bbb6 100644
--- a/projects/FastDistill/fastdistill/overhaul.py
+++ b/projects/FastDistill/fastdistill/overhaul.py
@@ -90,6 +90,8 @@ class DistillerOverhaul(Distiller):
         t_outputs = []
         # teacher model forward
         with torch.no_grad():
+            if self.ema_enabled:
+                self._momentum_update_key_encoder(self.ema_momentum)  # update self-distill model
             for model_t in self.model_ts:
                 t_feats, t_feat = model_t.backbone.extract_feature(images, preReLU=True)
                 t_output = model_t.heads(t_feat, targets)
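
For reference, `_momentum_update_key_encoder` applies the standard EMA rule theta_k <- m * theta_k + (1 - m) * theta_q to every parameter pair, so the self-distillation teacher trails the student as an exponential moving average. Below is a minimal sketch of the same mechanism in plain PyTorch; the `ema_update` helper and the `Linear` modules are illustrative placeholders, not fastreid code:

    import copy

    import torch


    @torch.no_grad()
    def ema_update(student, teacher, m=0.999):
        # theta_teacher <- m * theta_teacher + (1 - m) * theta_student
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.mul_(m).add_(p_s, alpha=1.0 - m)


    # The teacher starts as a frozen copy of the student, mirroring the
    # `param_k.data.copy_(param_q.data)` initialization in the patch.
    student = torch.nn.Linear(8, 4)
    teacher = copy.deepcopy(student)
    for p in teacher.parameters():
        p.requires_grad_(False)

    ema_update(student, teacher, m=0.999)

With m close to 1 the teacher moves slowly, so its outputs act as a smoothed target for the student; m = 0.999 averages over roughly the last 1000 update steps.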
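
The patch also reads two new config keys, `cfg.KD.EMA.ENABLED` and `cfg.KD.EMA.MOMENTUM`, which must be registered in the default config before `merge_from_file` can see them. A sketch of the matching yacs defaults follows; the key names come from the diff, while the default values shown are assumptions:

    from yacs.config import CfgNode as CN

    _C = CN()
    _C.KD = CN()
    _C.KD.EMA = CN()
    _C.KD.EMA.ENABLED = False    # assumed default: EMA self-distillation off
    _C.KD.EMA.MOMENTUM = 0.999   # assumed default, matching the method's default m

With defaults like these in place, a project config can opt in by setting `KD.EMA.ENABLED: True` and tuning `KD.EMA.MOMENTUM` per experiment.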