diff --git a/demo/test_ap_on_coco.py b/demo/test_ap_on_coco.py index 59ce6a2..f1e532f 100644 --- a/demo/test_ap_on_coco.py +++ b/demo/test_ap_on_coco.py @@ -26,7 +26,7 @@ def load_model(model_config_path: str, model_checkpoint_path: str, device: str = args.device = device model = build_model(args) checkpoint = torch.load(model_checkpoint_path, map_location="cpu") - model.load_state_dict(clean_state_dict(checkpoint["ema_model"]), strict=False) + model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) model.eval() return model