diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py index 6f84133b3..9074a6b49 100644 --- a/ppcls/engine/train/classification.py +++ b/ppcls/engine/train/classification.py @@ -258,7 +258,7 @@ class ClassTrainer(object): return None def _build_ema_model(self): - if "EMA" in self.config: + if "EMA" in self.config and self.mode == "train": model_ema = ExponentialMovingAverage( self.model, self.config['EMA'].get("decay", 0.9999)) self.best_metric["metric_ema"] = 0