[Fix] Fix base tta model (#593)

Co-authored-by: ubuntu <ubuntu@localhost.localdomain>
pull/596/head
Mashiro 2022-10-11 09:49:50 +08:00 committed by GitHub
parent 46add351fe
commit 2df5bc137b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -20,7 +20,7 @@ MergedDataSamples = List[BaseDataElement]
@MODELS.register_module()
class BaseTTAModel:
class BaseTTAModel(nn.Module):
"""Base model for inference with test-time augmentation.
``BaseTTAModel`` is a wrapper for inference given multi-batch data.
@ -74,6 +74,7 @@ class BaseTTAModel:
"""
def __init__(self, module: Union[dict, nn.Module]):
super().__init__()
if isinstance(module, nn.Module):
self.module = module
elif isinstance(module, dict):