mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix base tta model (#593)
Co-authored-by: ubuntu <ubuntu@localhost.localdomain>
This commit is contained in:
parent
46add351fe
commit
2df5bc137b
@ -20,7 +20,7 @@ MergedDataSamples = List[BaseDataElement]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BaseTTAModel:
|
||||
class BaseTTAModel(nn.Module):
|
||||
"""Base model for inference with test-time augmentation.
|
||||
|
||||
``BaseTTAModel`` is a wrapper for inference given multi-batch data.
|
||||
@ -74,6 +74,7 @@ class BaseTTAModel:
|
||||
"""
|
||||
|
||||
def __init__(self, module: Union[dict, nn.Module]):
|
||||
super().__init__()
|
||||
if isinstance(module, nn.Module):
|
||||
self.module = module
|
||||
elif isinstance(module, dict):
|
||||
|
Loading…
x
Reference in New Issue
Block a user