mirror of https://github.com/open-mmlab/mmocr.git
parent
b32412a9e9
commit
a45716d20e
|
@ -36,8 +36,8 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta):
|
|||
return hasattr(self, 'neck') and self.neck is not None
|
||||
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: OptDetSampleList = None,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: OptDetSampleList = None,
|
||||
mode: str = 'tensor') -> ForwardResults:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
|
@ -54,10 +54,11 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta):
|
|||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
batch_data_samples (list[:obj:`TextDetDataSample`], optional): The
|
||||
annotation data of every samples. Defaults to None.
|
||||
data_samples (list[:obj:`TextDetDataSample`], optional): A batch of
|
||||
data samples that contain annotations and predictions.
|
||||
Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
Returns:
|
||||
|
@ -68,32 +69,30 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta):
|
|||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'loss':
|
||||
return self.loss(batch_inputs, batch_data_samples)
|
||||
return self.loss(inputs, data_samples)
|
||||
elif mode == 'predict':
|
||||
return self.predict(batch_inputs, batch_data_samples)
|
||||
return self.predict(inputs, data_samples)
|
||||
elif mode == 'tensor':
|
||||
return self._forward(batch_inputs, batch_data_samples)
|
||||
return self._forward(inputs, data_samples)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}". '
|
||||
'Only supports loss, predict and tensor mode')
|
||||
|
||||
@abstractmethod
|
||||
def loss(self, batch_inputs: Tensor,
|
||||
batch_data_samples: DetSampleList) -> Union[dict, tuple]:
|
||||
def loss(self, inputs: Tensor,
|
||||
data_samples: DetSampleList) -> Union[dict, tuple]:
|
||||
"""Calculate losses from a batch of inputs and data samples."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, batch_inputs: Tensor,
|
||||
batch_data_samples: DetSampleList) -> DetSampleList:
|
||||
def predict(self, inputs: Tensor,
|
||||
data_samples: DetSampleList) -> DetSampleList:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _forward(self,
|
||||
batch_inputs: Tensor,
|
||||
batch_data_samples: OptDetSampleList = None):
|
||||
def _forward(self, inputs: Tensor, data_samples: OptDetSampleList = None):
|
||||
"""Network forward process.
|
||||
|
||||
Usually includes backbone, neck and head forward without any post-
|
||||
|
@ -102,6 +101,6 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def extract_feat(self, batch_inputs: Tensor):
|
||||
def extract_feat(self, inputs: Tensor):
|
||||
"""Extract features from images."""
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue