parent
46add351fe
commit
2df5bc137b
|
@ -20,7 +20,7 @@ MergedDataSamples = List[BaseDataElement]
|
|||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BaseTTAModel:
|
||||
class BaseTTAModel(nn.Module):
|
||||
"""Base model for inference with test-time augmentation.
|
||||
|
||||
``BaseTTAModel`` is a wrapper for inference given multi-batch data.
|
||||
|
@ -74,6 +74,7 @@ class BaseTTAModel:
|
|||
"""
|
||||
|
||||
def __init__(self, module: Union[dict, nn.Module]):
|
||||
super().__init__()
|
||||
if isinstance(module, nn.Module):
|
||||
self.module = module
|
||||
elif isinstance(module, dict):
|
||||
|
|
Loading…
Reference in New Issue