diff --git a/mmengine/model/wrappers/test_time_aug.py b/mmengine/model/wrappers/test_time_aug.py index 677bc3ca..d99919df 100644 --- a/mmengine/model/wrappers/test_time_aug.py +++ b/mmengine/model/wrappers/test_time_aug.py @@ -20,7 +20,7 @@ MergedDataSamples = List[BaseDataElement] @MODELS.register_module() -class BaseTTAModel: +class BaseTTAModel(nn.Module): """Base model for inference with test-time augmentation. ``BaseTTAModel`` is a wrapper for inference given multi-batch data. @@ -74,6 +74,7 @@ class BaseTTAModel: """ def __init__(self, module: Union[dict, nn.Module]): + super().__init__() if isinstance(module, nn.Module): self.module = module elif isinstance(module, dict):