diff --git a/mmocr/models/textdet/detectors/base.py b/mmocr/models/textdet/detectors/base.py index d88713cb..dd0adab2 100644 --- a/mmocr/models/textdet/detectors/base.py +++ b/mmocr/models/textdet/detectors/base.py @@ -36,8 +36,8 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta): return hasattr(self, 'neck') and self.neck is not None def forward(self, - batch_inputs: torch.Tensor, - batch_data_samples: OptDetSampleList = None, + inputs: torch.Tensor, + data_samples: OptDetSampleList = None, mode: str = 'tensor') -> ForwardResults: """The unified entry for a forward process in both training and test. @@ -54,10 +54,11 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta): optimizer updating, which are done in the :meth:`train_step`. Args: - batch_inputs (torch.Tensor): The input tensor with shape + inputs (torch.Tensor): The input tensor with shape (N, C, ...) in general. - batch_data_samples (list[:obj:`TextDetDataSample`], optional): The - annotation data of every samples. Defaults to None. + data_samples (list[:obj:`TextDetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + Defaults to None. mode (str): Return what kind of value. Defaults to 'tensor'. Returns: @@ -68,32 +69,30 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta): - If ``mode="loss"``, return a dict of tensor. """ if mode == 'loss': - return self.loss(batch_inputs, batch_data_samples) + return self.loss(inputs, data_samples) elif mode == 'predict': - return self.predict(batch_inputs, batch_data_samples) + return self.predict(inputs, data_samples) elif mode == 'tensor': - return self._forward(batch_inputs, batch_data_samples) + return self._forward(inputs, data_samples) else: raise RuntimeError(f'Invalid mode "{mode}". ' 'Only supports loss, predict and tensor mode') @abstractmethod - def loss(self, batch_inputs: Tensor, - batch_data_samples: DetSampleList) -> Union[dict, tuple]: + def loss(self, inputs: Tensor, + data_samples: DetSampleList) -> Union[dict, tuple]: """Calculate losses from a batch of inputs and data samples.""" pass @abstractmethod - def predict(self, batch_inputs: Tensor, - batch_data_samples: DetSampleList) -> DetSampleList: + def predict(self, inputs: Tensor, + data_samples: DetSampleList) -> DetSampleList: """Predict results from a batch of inputs and data samples with post- processing.""" pass @abstractmethod - def _forward(self, - batch_inputs: Tensor, - batch_data_samples: OptDetSampleList = None): + def _forward(self, inputs: Tensor, data_samples: OptDetSampleList = None): """Network forward process. Usually includes backbone, neck and head forward without any post- @@ -102,6 +101,6 @@ class BaseTextDetector(BaseModel, metaclass=ABCMeta): pass @abstractmethod - def extract_feat(self, batch_inputs: Tensor): + def extract_feat(self, inputs: Tensor): """Extract features from images.""" pass