Rename 'feat' mode to 'tensor' mode

pull/913/head
mzr1996 2022-06-15 14:53:41 +08:00
parent 125b74d4ca
commit cecff79a79
2 changed files with 11 additions and 11 deletions

View File

@ -45,12 +45,12 @@ class BaseClassifier(BaseModel, metaclass=ABCMeta):
def forward(self,
batch_inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'feat'):
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
The method should accept three modes: "feat", "predict" and "loss":
The method should accept three modes: "tensor", "predict" and "loss":
- "feat": Forward the whole network and return tensor or tuple of
- "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`BaseDataElement`.
@ -66,12 +66,12 @@ class BaseClassifier(BaseModel, metaclass=ABCMeta):
data_samples (List[BaseDataElement], optional): The annotation
data of every samples. It's required if ``mode="loss"``.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'feat'.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
The return type depends on ``mode``.
- If ``mode="feat"``, return a tensor or a tuple of tensor.
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of
:obj:`mmengine.BaseDataElement`.
- If ``mode="loss"``, return a dict of tensor.

View File

@ -72,12 +72,12 @@ class ImageClassifier(BaseClassifier):
def forward(self,
batch_inputs: torch.Tensor,
data_samples: Optional[List[ClsDataSample]] = None,
mode: str = 'feat'):
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
The method should accept three modes: "feat", "predict" and "loss":
The method should accept three modes: "tensor", "predict" and "loss":
- "feat": Forward the whole network and return tensor or tuple of
- "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`ClsDataSample`.
@ -93,17 +93,17 @@ class ImageClassifier(BaseClassifier):
data_samples (List[ClsDataSample], optional): The annotation
data of every samples. It's required if ``mode="loss"``.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'feat'.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
The return type depends on ``mode``.
- If ``mode="feat"``, return a tuple of tensor.
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of
:obj:`mmcls.core.ClsDataSample`.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'feat':
if mode == 'tensor':
feats = self.extract_feat(batch_inputs)
return self.head(feats) if self.with_head else feats
elif mode == 'loss':