[Fix] Fix BaseTextDetector (#1324)

* fix base

* update docstring
pull/1327/head
Xinyu Wang 2022-08-25 14:04:25 +08:00 committed by GitHub
parent b32412a9e9
commit a45716d20e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 16 deletions

View File

@ -36,8 +36,8 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta):
return hasattr(self, 'neck') and self.neck is not None
def forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: OptDetSampleList = None,
inputs: torch.Tensor,
data_samples: OptDetSampleList = None,
mode: str = 'tensor') -> ForwardResults:
"""The unified entry for a forward process in both training and test.
@ -54,10 +54,11 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta):
optimizer updating, which are done in the :meth:`train_step`.
Args:
batch_inputs (torch.Tensor): The input tensor with shape
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
batch_data_samples (list[:obj:`TextDetDataSample`], optional): The
annotation data of every samples. Defaults to None.
data_samples (list[:obj:`TextDetDataSample`], optional): A batch of
data samples that contain annotations and predictions.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
@ -68,32 +69,30 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta):
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'loss':
return self.loss(batch_inputs, batch_data_samples)
return self.loss(inputs, data_samples)
elif mode == 'predict':
return self.predict(batch_inputs, batch_data_samples)
return self.predict(inputs, data_samples)
elif mode == 'tensor':
return self._forward(batch_inputs, batch_data_samples)
return self._forward(inputs, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode')
@abstractmethod
def loss(self, batch_inputs: Tensor,
batch_data_samples: DetSampleList) -> Union[dict, tuple]:
def loss(self, inputs: Tensor,
data_samples: DetSampleList) -> Union[dict, tuple]:
"""Calculate losses from a batch of inputs and data samples."""
pass
@abstractmethod
def predict(self, batch_inputs: Tensor,
batch_data_samples: DetSampleList) -> DetSampleList:
def predict(self, inputs: Tensor,
data_samples: DetSampleList) -> DetSampleList:
"""Predict results from a batch of inputs and data samples with post-
processing."""
pass
@abstractmethod
def _forward(self,
batch_inputs: Tensor,
batch_data_samples: OptDetSampleList = None):
def _forward(self, inputs: Tensor, data_samples: OptDetSampleList = None):
"""Network forward process.
Usually includes backbone, neck and head forward without any post-
@ -102,6 +101,6 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta):
pass
@abstractmethod
def extract_feat(self, batch_inputs: Tensor):
def extract_feat(self, inputs: Tensor):
"""Extract features from images."""
pass