From 9bd5258513a9afcf8708799529c315d9c4f2bc2d Mon Sep 17 00:00:00 2001 From: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com> Date: Thu, 25 Aug 2022 11:45:42 +0800 Subject: [PATCH] [Refactor] Adapt to new dataflow (#1305) * datasample->datasamples * update rec data preprocessor * rename datasamples * update det preprocessor * update metric * update data_sample->data_samples in test * update * fix data preprocessor uts * remove engine runner * fix kie ut * fix ut * fix comments * refactor evaluator ut * apply comments Co-authored-by: Tong Gao * remove useless * apply comments Co-authored-by: Tong Gao * apply comments Co-authored-by: Tong Gao Co-authored-by: Tong Gao --- demo/ner_demo.py | 31 -- mmocr/datasets/transforms/formatting.py | 18 +- mmocr/engine/__init__.py | 1 - mmocr/engine/runner/__init__.py | 4 - mmocr/engine/runner/multi_loops.py | 202 ------------- mmocr/evaluation/metrics/f_metric.py | 13 +- mmocr/evaluation/metrics/hmean_iou_metric.py | 6 +- mmocr/evaluation/metrics/recog_metric.py | 24 +- mmocr/models/kie/extractors/sdmgr.py | 65 ++--- mmocr/models/kie/heads/sdmgr_head.py | 42 ++- .../kie/postprocessors/sdmgr_postprocessor.py | 19 +- .../data_preprocessors/data_preprocessor.py | 31 +- .../models/textdet/detectors/mmdet_wrapper.py | 12 +- .../detectors/single_stage_text_detector.py | 50 ++-- mmocr/models/textdet/heads/base.py | 28 +- mmocr/models/textdet/heads/drrg_head.py | 29 +- .../data_preprocessors/data_preprocessor.py | 37 +-- mmocr/models/textrecog/recognizers/base.py | 25 +- .../recognizers/encoder_decoder_recognizer.py | 42 +-- .../test_transforms/test_formatting.py | 30 +- .../test_transforms/test_ocr_transforms.py | 6 - .../test_runner/test_multi_loop.py | 274 ------------------ .../test_multi_datasets_evaluator.py | 23 +- .../test_kie/test_extractors/test_sdmgr.py | 18 +- .../test_textdet_data_preprocessor.py | 69 +++-- .../test_wrappers/test_mmdet_wrapper.py | 6 - .../test_data_preprocessor.py | 68 +++-- 27 files changed, 316 insertions(+), 857 deletions(-) delete mode 100755 demo/ner_demo.py delete mode 100644 mmocr/engine/runner/__init__.py delete mode 100644 mmocr/engine/runner/multi_loops.py delete mode 100644 tests/test_engine/test_runner/test_multi_loop.py diff --git a/demo/ner_demo.py b/demo/ner_demo.py deleted file mode 100755 index 003f2e89..00000000 --- a/demo/ner_demo.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from argparse import ArgumentParser - -from mmocr.apis import init_detector -from mmocr.apis.inference import text_model_inference -from mmocr.registry import DATASETS # NOQA - - -def main(): - parser = ArgumentParser() - parser.add_argument('config', help='Config file.') - parser.add_argument('checkpoint', help='Checkpoint file.') - parser.add_argument( - '--device', default='cuda:0', help='Device used for inference.') - args = parser.parse_args() - - # build the model from a config file and a checkpoint file - model = init_detector(args.config, args.checkpoint, device=args.device) - - # test a single text - input_sentence = input('Please enter a sentence you want to test: ') - result = text_model_inference(model, input_sentence) - - # show the results - for pred_entities in result: - for entity in pred_entities: - print(f'{entity[0]}: {input_sentence[entity[1]:entity[2] + 1]}') - - -if __name__ == '__main__': - main() diff --git a/mmocr/datasets/transforms/formatting.py b/mmocr/datasets/transforms/formatting.py index 54218fc4..64dd21ba 100644 --- a/mmocr/datasets/transforms/formatting.py +++ b/mmocr/datasets/transforms/formatting.py @@ -17,7 +17,7 @@ class PackTextDetInputs(BaseTransform): The type of outputs is `dict`: - inputs: image converted to tensor, whose shape is (C, H, W). - - data_sample: Two components of ``TextDetDataSample`` will be updated: + - data_samples: Two components of ``TextDetDataSample`` will be updated: - gt_instances (InstanceData): Depending on annotations, a subset of the following keys will be updated: @@ -82,7 +82,7 @@ class PackTextDetInputs(BaseTransform): dict: - 'inputs' (obj:`torch.Tensor`): Data for model forwarding. - - 'data_sample' (obj:`DetDataSample`): The annotation info of the + - 'data_samples' (obj:`DetDataSample`): The annotation info of the sample. """ packed_results = dict() @@ -109,7 +109,7 @@ class PackTextDetInputs(BaseTransform): for key in self.meta_keys: img_meta[key] = results[key] data_sample.set_metainfo(img_meta) - packed_results['data_sample'] = data_sample + packed_results['data_samples'] = data_sample return packed_results @@ -126,7 +126,7 @@ class PackTextRecogInputs(BaseTransform): The type of outputs is `dict`: - inputs: Image as a tensor, whose shape is (C, H, W). - - data_sample: Two components of ``TextRecogDataSample`` will be updated: + - data_samples: Two components of ``TextRecogDataSample`` will be updated: - gt_text (LabelData): @@ -166,7 +166,7 @@ class PackTextRecogInputs(BaseTransform): dict: - 'inputs' (obj:`torch.Tensor`): Data for model forwarding. - - 'data_sample' (obj:`TextRecogDataSample`): The annotation info + - 'data_samples' (obj:`TextRecogDataSample`): The annotation info of the sample. """ packed_results = dict() @@ -195,7 +195,7 @@ class PackTextRecogInputs(BaseTransform): img_meta[key] = results[key] data_sample.set_metainfo(img_meta) - packed_results['data_sample'] = data_sample + packed_results['data_samples'] = data_sample return packed_results @@ -212,7 +212,7 @@ class PackKIEInputs(BaseTransform): The type of outputs is `dict`: - inputs: image converted to tensor, whose shape is (C, H, W). - - data_sample: Two components of ``TextDetDataSample`` will be updated: + - data_samples: Two components of ``TextDetDataSample`` will be updated: - gt_instances (InstanceData): Depending on annotations, a subset of the following keys will be updated: @@ -266,7 +266,7 @@ class PackKIEInputs(BaseTransform): dict: - 'inputs' (obj:`torch.Tensor`): Data for model forwarding. - - 'data_sample' (obj:`DetDataSample`): The annotation info of the + - 'data_samples' (obj:`DetDataSample`): The annotation info of the sample. """ packed_results = dict() @@ -295,7 +295,7 @@ class PackKIEInputs(BaseTransform): for key in self.meta_keys: img_meta[key] = results[key] data_sample.set_metainfo(img_meta) - packed_results['data_sample'] = data_sample + packed_results['data_samples'] = data_sample return packed_results diff --git a/mmocr/engine/__init__.py b/mmocr/engine/__init__.py index c2db2ce5..1944bc1e 100644 --- a/mmocr/engine/__init__.py +++ b/mmocr/engine/__init__.py @@ -1,3 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. from .hooks import * # NOQA -from .runner import * # NOQA diff --git a/mmocr/engine/runner/__init__.py b/mmocr/engine/runner/__init__.py deleted file mode 100644 index c5d67a0a..00000000 --- a/mmocr/engine/runner/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .multi_loops import MultiTestLoop, MultiValLoop - -__all__ = ['MultiValLoop', 'MultiTestLoop'] diff --git a/mmocr/engine/runner/multi_loops.py b/mmocr/engine/runner/multi_loops.py deleted file mode 100644 index e769f6f9..00000000 --- a/mmocr/engine/runner/multi_loops.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from typing import Dict, List, Sequence, Union - -import torch -from mmengine.evaluator import Evaluator -from mmengine.runner.amp import autocast -from mmengine.runner.base_loop import BaseLoop -from mmengine.utils import is_list_of -from torch.utils.data import DataLoader - -from mmocr.registry import LOOPS - - -@LOOPS.register_module() -class MultiValLoop(BaseLoop): - """Loop for validation multi-datasets. - - Args: - runner (Runner): A reference of runner. - dataloader (list[Dataloader or dic]): A dataloader object or a dict to - build a dataloader. - evaluator (list[]): Used for computing metrics. - fp16 (bool): Whether to enable fp16 validation. Defaults to False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - fp16: bool = False) -> None: - self._runner = runner - assert isinstance(dataloader, list) - self.dataloaders = list() - for loader in dataloader: - if isinstance(loader, dict): - self.dataloaders.append( - runner.build_dataloader(loader, seed=runner.seed)) - else: - self.dataloaders.append(loader) - - assert isinstance(evaluator, list) - self.evaluators = list() - for single_evalator in evaluator: - if isinstance(single_evalator, dict) or is_list_of( - single_evalator, dict): - self.evaluators.append(runner.build_evaluator(single_evalator)) - else: - self.evaluators.append(single_evalator) - self.evaluators = [runner.build_evaluator(eval) for eval in evaluator] - - assert len(self.evaluators) == len(self.dataloaders) - - self.fp16 = fp16 - - def run(self): - """Launch validation.""" - self.runner.call_hook('before_val') - - self.runner.model.eval() - multi_metric = dict() - self.runner.call_hook('before_val_epoch') - for evaluator, dataloader in zip(self.evaluators, self.dataloaders): - self.evaluator = evaluator - self.dataloader = dataloader - if hasattr(self.dataloader.dataset, 'metainfo'): - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self.runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo - else: - warnings.warn( - f'Dataset {self.dataloader.dataset.__class__.__name__} ' - 'has no metainfo. ``dataset_meta`` in evaluator, metric' - ' and visualizer will be None.') - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - if multi_metric and metrics.keys() & multi_metric.keys(): - raise ValueError('Please set different prefix for different' - ' datasets in `val_evaluator`') - else: - - multi_metric.update(metrics) - self.runner.call_hook('after_val_epoch', metrics=multi_metric) - self.runner.call_hook('after_val') - - @torch.no_grad() - def run_iter(self, idx: int, data_batch: Sequence[dict]): - """Iterate one mini-batch. - - Args: - idx (int): The index of the current batch in the loop. - data_batch (Sequence[dict]): Batch of data - from dataloader. - """ - self.runner.call_hook( - 'before_val_iter', batch_idx=idx, data_batch=data_batch) - # outputs should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - outputs = self.runner.model.val_step(data_batch) - self.evaluator.process(data_batch, outputs) - self.runner.call_hook( - 'after_val_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - - -@LOOPS.register_module() -class MultiTestLoop(BaseLoop): - """Loop for validation multi-datasets. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - evaluator (Evaluator or dict or list): Used for computing metrics. - fp16 (bool): Whether to enable fp16 validation. Defaults to False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - fp16: bool = False) -> None: - self._runner = runner - assert isinstance(dataloader, list) - self.dataloaders = list() - for loader in dataloader: - if isinstance(loader, dict): - self.dataloaders.append( - runner.build_dataloader(loader, seed=runner.seed)) - else: - self.dataloaders.append(loader) - - assert isinstance(evaluator, list) - self.evaluators = list() - for single_evalator in evaluator: - if isinstance(single_evalator, dict) or is_list_of( - single_evalator, dict): - self.evaluators.append(runner.build_evaluator(single_evalator)) - else: - self.evaluators.append(single_evalator) - self.evaluators = [runner.build_evaluator(eval) for eval in evaluator] - - assert len(self.evaluators) == len(self.dataloaders) - - self.fp16 = fp16 - - def run(self): - """Launch test.""" - self.runner.call_hook('before_test') - - self.runner.model.eval() - multi_metric = dict() - self.runner.call_hook('before_test_epoch') - for evaluator, dataloader in zip(self.evaluators, self.dataloaders): - self.dataloader = dataloader - self.evaluator = evaluator - if hasattr(self.dataloader.dataset, 'metainfo'): - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self.runner.visualizer.dataset_meta = \ - self.dataloader.dataset.metainfo - else: - warnings.warn( - f'Dataset {self.dataloader.dataset.__class__.__name__} ' - 'has no metainfo. ``dataset_meta`` in evaluator, metric' - ' and visualizer will be None.') - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - if multi_metric and metrics.keys() & multi_metric.keys(): - raise ValueError('Please set different prefix for different' - ' datasets in `test_evaluator`') - else: - - multi_metric.update(metrics) - self.runner.call_hook('after_test_epoch', metrics=multi_metric) - self.runner.call_hook('after_test') - - @torch.no_grad() - def run_iter(self, idx: int, data_batch: Sequence[dict]): - """Iterate one mini-batch. - - Args: - idx (int): The index of the current batch in the loop. - data_batch (Sequence[dict]): Batch of data - from dataloader. - """ - self.runner.call_hook( - 'before_test_iter', batch_idx=idx, data_batch=data_batch) - # outputs should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - predictions = self.runner.model.test_step(data_batch) - self.evaluator.process(data_batch, predictions) - self.runner.call_hook( - 'after_test_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=predictions) diff --git a/mmocr/evaluation/metrics/f_metric.py b/mmocr/evaluation/metrics/f_metric.py index 4c4bba1a..e021ed6b 100644 --- a/mmocr/evaluation/metrics/f_metric.py +++ b/mmocr/evaluation/metrics/f_metric.py @@ -85,19 +85,18 @@ class F1Metric(BaseMetric): self.key = key def process(self, data_batch: Sequence[Dict], - predictions: Sequence[Dict]) -> None: - """Process one batch of predictions. The processed results should be + data_samples: Sequence[Dict]) -> None: + """Process one batch of data_samples. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: data_batch (Sequence[Dict]): A batch of gts. - predictions (Sequence[Dict]): A batch of outputs from the model. + data_samples (Sequence[Dict]): A batch of outputs from the model. """ - for data_samples in predictions: - pred_labels = data_samples.get('pred_instances').get( - self.key).cpu() - gt_labels = data_samples.get('gt_instances').get(self.key).cpu() + for data_sample in data_samples: + pred_labels = data_sample.get('pred_instances').get(self.key).cpu() + gt_labels = data_sample.get('gt_instances').get(self.key).cpu() result = dict( pred_labels=pred_labels.flatten(), diff --git a/mmocr/evaluation/metrics/hmean_iou_metric.py b/mmocr/evaluation/metrics/hmean_iou_metric.py index 2816c598..c5d40971 100644 --- a/mmocr/evaluation/metrics/hmean_iou_metric.py +++ b/mmocr/evaluation/metrics/hmean_iou_metric.py @@ -75,17 +75,17 @@ class HmeanIOUMetric(BaseMetric): self.strategy = strategy def process(self, data_batch: Sequence[Dict], - predictions: Sequence[Dict]) -> None: + data_samples: Sequence[Dict]) -> None: """Process one batch of data samples and predictions. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: data_batch (Sequence[Dict]): A batch of data from dataloader. - predictions (Sequence[Dict]): A batch of outputs from + data_samples (Sequence[Dict]): A batch of outputs from the model. """ - for data_sample in predictions: + for data_sample in data_samples: pred_instances = data_sample.get('pred_instances') pred_polygons = pred_instances.get('polygons') diff --git a/mmocr/evaluation/metrics/recog_metric.py b/mmocr/evaluation/metrics/recog_metric.py index 5ed0e0fc..a0469512 100644 --- a/mmocr/evaluation/metrics/recog_metric.py +++ b/mmocr/evaluation/metrics/recog_metric.py @@ -51,16 +51,16 @@ class WordMetric(BaseMetric): self.mode = set(mode) def process(self, data_batch: Sequence[Dict], - predictions: Sequence[Dict]) -> None: - """Process one batch of predictions. The processed results should be + data_samples: Sequence[Dict]) -> None: + """Process one batch of data_samples. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: data_batch (Sequence[Dict]): A batch of gts. - predictions (Sequence[Dict]): A batch of outputs from the model. + data_samples (Sequence[Dict]): A batch of outputs from the model. """ - for data_sample in predictions: + for data_sample in data_samples: match_num = 0 match_ignore_case_num = 0 match_ignore_case_symbol_num = 0 @@ -149,16 +149,16 @@ class CharMetric(BaseMetric): self.valid_symbol = re.compile(valid_symbol) def process(self, data_batch: Sequence[Dict], - predictions: Sequence[Dict]) -> None: - """Process one batch of predictions. The processed results should be + data_samples: Sequence[Dict]) -> None: + """Process one batch of data_samples. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: data_batch (Sequence[Dict]): A batch of gts. - predictions (Sequence[Dict]): A batch of outputs from the model. + data_samples (Sequence[Dict]): A batch of outputs from the model. """ - for data_sample in predictions: + for data_sample in data_samples: pred_text = data_sample.get('pred_text').get('item') gt_text = data_sample.get('gt_text').get('item') gt_text_lower = gt_text.lower() @@ -249,16 +249,16 @@ class OneMinusNEDMetric(BaseMetric): self.valid_symbol = re.compile(valid_symbol) def process(self, data_batch: Sequence[Dict], - predictions: Sequence[Dict]) -> None: - """Process one batch of predictions. The processed results should be + data_samples: Sequence[Dict]) -> None: + """Process one batch of data_samples. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: data_batch (Sequence[Dict]): A batch of gts. - predictions (Sequence[Dict]): A batch of outputs from the model. + data_samples (Sequence[Dict]): A batch of outputs from the model. """ - for data_sample in predictions: + for data_sample in data_samples: pred_text = data_sample.get('pred_text').get('item') gt_text = data_sample.get('gt_text').get('item') gt_text_lower = gt_text.lower() diff --git a/mmocr/models/kie/extractors/sdmgr.py b/mmocr/models/kie/extractors/sdmgr.py index add7c481..670dcdf5 100644 --- a/mmocr/models/kie/extractors/sdmgr.py +++ b/mmocr/models/kie/extractors/sdmgr.py @@ -84,8 +84,8 @@ class SDMGR(BaseModel): return feats.view(feats.size(0), -1) def forward(self, - batch_inputs: torch.Tensor, - batch_data_samples: Sequence[KIEDataSample] = None, + inputs: torch.Tensor, + data_samples: Sequence[KIEDataSample] = None, mode: str = 'tensor', **kwargs) -> torch.Tensor: """The unified entry for a forward process in both training and test. @@ -103,9 +103,9 @@ class SDMGR(BaseModel): optimizer updating, which are done in the :meth:`train_step`. Args: - batch_inputs (torch.Tensor): The input tensor with shape + inputs (torch.Tensor): The input tensor with shape (N, C, ...) in general. - batch_data_samples (list[:obj:`DetDataSample`], optional): The + data_samples (list[:obj:`DetDataSample`], optional): The annotation data of every samples. Defaults to None. mode (str): Return what kind of value. Defaults to 'tensor'. @@ -117,43 +117,42 @@ class SDMGR(BaseModel): - If ``mode="loss"``, return a dict of tensor. """ if mode == 'loss': - return self.loss(batch_inputs, batch_data_samples, **kwargs) + return self.loss(inputs, data_samples, **kwargs) elif mode == 'predict': - return self.predict(batch_inputs, batch_data_samples, **kwargs) + return self.predict(inputs, data_samples, **kwargs) elif mode == 'tensor': - return self._forward(batch_inputs, batch_data_samples, **kwargs) + return self._forward(inputs, data_samples, **kwargs) else: raise RuntimeError(f'Invalid mode "{mode}". ' 'Only supports loss, predict and tensor mode') - def loss(self, batch_inputs: torch.Tensor, - batch_data_samples: Sequence[KIEDataSample], **kwargs) -> dict: + def loss(self, inputs: torch.Tensor, data_samples: Sequence[KIEDataSample], + **kwargs) -> dict: """Calculate losses from a batch of inputs and data samples. Args: - batch_inputs (torch.Tensor): Input images of shape (N, C, H, W). + inputs (torch.Tensor): Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled. - batch_data_samples (list[KIEDataSample]): A list of N datasamples, + data_samples (list[KIEDataSample]): A list of N datasamples, containing meta information and gold annotations for each of the images. Returns: dict[str, Tensor]: A dictionary of loss components. """ - x = self.extract_feat(batch_inputs, [ - data_sample.gt_instances.bboxes - for data_sample in batch_data_samples - ]) - return self.kie_head.loss(x, batch_data_samples) + x = self.extract_feat( + inputs, + [data_sample.gt_instances.bboxes for data_sample in data_samples]) + return self.kie_head.loss(x, data_samples) - def predict(self, batch_inputs: torch.Tensor, - batch_data_samples: Sequence[KIEDataSample], + def predict(self, inputs: torch.Tensor, + data_samples: Sequence[KIEDataSample], **kwargs) -> List[KIEDataSample]: """Predict results from a batch of inputs and data samples with post- processing. Args: - batch_inputs (torch.Tensor): Input images of shape (N, C, H, W). + inputs (torch.Tensor): Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled. - batch_data_samples (list[KIEDataSample]): A list of N datasamples, + data_samples (list[KIEDataSample]): A list of N datasamples, containing meta information and gold annotations for each of the images. @@ -162,22 +161,21 @@ class SDMGR(BaseModel): Results are stored in ``pred_instances.labels`` and ``pred_instances.edge_labels``. """ - x = self.extract_feat(batch_inputs, [ - data_sample.gt_instances.bboxes - for data_sample in batch_data_samples - ]) - return self.kie_head.predict(x, batch_data_samples) + x = self.extract_feat( + inputs, + [data_sample.gt_instances.bboxes for data_sample in data_samples]) + return self.kie_head.predict(x, data_samples) - def _forward(self, batch_inputs: torch.Tensor, - batch_data_samples: Sequence[KIEDataSample], + def _forward(self, inputs: torch.Tensor, + data_samples: Sequence[KIEDataSample], **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: """Get the raw tensor outputs from backbone and head without any post- processing. Args: - batch_inputs (torch.Tensor): Input images of shape (N, C, H, W). + inputs (torch.Tensor): Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled. - batch_data_samples (list[KIEDataSample]): A list of N datasamples, + data_samples (list[KIEDataSample]): A list of N datasamples, containing meta information and gold annotations for each of the images. @@ -187,8 +185,7 @@ class SDMGR(BaseModel): - node_cls (torch.Tensor): Node classification output. - edge_cls (torch.Tensor): Edge classification output. """ - x = self.extract_feat(batch_inputs, [ - data_sample.gt_instances.bboxes - for data_sample in batch_data_samples - ]) - return self.kie_head(x, batch_data_samples) + x = self.extract_feat( + inputs, + [data_sample.gt_instances.bboxes for data_sample in data_samples]) + return self.kie_head(x, data_samples) diff --git a/mmocr/models/kie/heads/sdmgr_head.py b/mmocr/models/kie/heads/sdmgr_head.py index c98e508b..311e8709 100644 --- a/mmocr/models/kie/heads/sdmgr_head.py +++ b/mmocr/models/kie/heads/sdmgr_head.py @@ -83,28 +83,26 @@ class SDMGRHead(BaseModule): self.postprocessor = MODELS.build(postprocessor) self.relation_norm = relation_norm - def loss(self, batch_inputs: Tensor, - batch_data_samples: List[KIEDataSample]) -> Dict: + def loss(self, inputs: Tensor, data_samples: List[KIEDataSample]) -> Dict: """Calculate losses from a batch of inputs and data samples. Args: - batch_inputs (torch.Tensor): Shape :math:`(N, E)`. - batch_data_samples (List[KIEDataSample]): List of data samples. + inputs (torch.Tensor): Shape :math:`(N, E)`. + data_samples (List[KIEDataSample]): List of data samples. Returns: dict[str, tensor]: A dictionary of loss components. """ - preds = self.forward(batch_inputs, batch_data_samples) - return self.module_loss(preds, batch_data_samples) + preds = self.forward(inputs, data_samples) + return self.module_loss(preds, data_samples) - def predict(self, batch_inputs: Tensor, - batch_data_samples: List[KIEDataSample] - ) -> List[KIEDataSample]: + def predict(self, inputs: Tensor, + data_samples: List[KIEDataSample]) -> List[KIEDataSample]: """Predict results from a batch of inputs and data samples with post- processing. Args: - batch_inputs (torch.Tensor): Shape :math:`(N, E)`. - batch_data_samples (List[KIEDataSample]): List of data samples. + inputs (torch.Tensor): Shape :math:`(N, E)`. + data_samples (List[KIEDataSample]): List of data samples. Returns: List[KIEDataSample]: A list of datasamples of prediction results. @@ -121,16 +119,15 @@ class SDMGRHead(BaseModule): - edge_scores (Tensor): A float tensor of shape (N, ), indicating the confidence scores for edge predictions. """ - preds = self.forward(batch_inputs, batch_data_samples) - return self.postprocessor(preds, batch_data_samples) + preds = self.forward(inputs, data_samples) + return self.postprocessor(preds, data_samples) - def forward(self, batch_inputs: Tensor, - batch_data_samples: List[KIEDataSample] - ) -> Tuple[Tensor, Tensor]: + def forward(self, inputs: Tensor, + data_samples: List[KIEDataSample]) -> Tuple[Tensor, Tensor]: """ Args: - batch_inputs (torch.Tensor): Shape :math:`(N, E)`. - batch_data_samples (List[KIEDataSample]): List of data samples. + inputs (torch.Tensor): Shape :math:`(N, E)`. + data_samples (List[KIEDataSample]): List of data samples. Returns: tuple(Tensor, Tensor): @@ -143,8 +140,7 @@ class SDMGRHead(BaseModule): device = self.node_embed.weight.device - node_nums, char_nums, all_nodes = self.convert_texts( - batch_data_samples) + node_nums, char_nums, all_nodes = self.convert_texts(data_samples) embed_nodes = self.node_embed(all_nodes.to(device).long()) rnn_nodes, _ = self.rnn(embed_nodes) @@ -156,10 +152,10 @@ class SDMGRHead(BaseModule): 1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand( -1, -1, rnn_nodes.size(-1))).squeeze(1) - if batch_inputs is not None: - nodes = self.fusion([batch_inputs, nodes]) + if inputs is not None: + nodes = self.fusion([inputs, nodes]) - relations = self.compute_relations(batch_data_samples) + relations = self.compute_relations(data_samples) all_edges = torch.cat( [relation.view(-1, relation.size(-1)) for relation in relations], dim=0) diff --git a/mmocr/models/kie/postprocessors/sdmgr_postprocessor.py b/mmocr/models/kie/postprocessors/sdmgr_postprocessor.py index f09ca9d4..c6303618 100644 --- a/mmocr/models/kie/postprocessors/sdmgr_postprocessor.py +++ b/mmocr/models/kie/postprocessors/sdmgr_postprocessor.py @@ -52,8 +52,7 @@ class SDMGRPostProcessor: self.softmax = nn.Softmax(dim=-1) def __call__(self, preds: Tuple[Tensor, Tensor], - batch_data_samples: List[KIEDataSample] - ) -> List[KIEDataSample]: + data_samples: List[KIEDataSample]) -> List[KIEDataSample]: """Postprocess raw outputs from SDMGR heads and pack the results into a list of KIEDataSample. @@ -83,7 +82,7 @@ class SDMGRPostProcessor: all_edge_scores = self.softmax(edge_preds) chunk_size = [ data_sample.gt_instances.bboxes.shape[0] - for data_sample in batch_data_samples + for data_sample in data_samples ] node_scores, node_preds = torch.max(all_node_scores, dim=-1) edge_scores, edge_preds = torch.max(all_edge_scores, dim=-1) @@ -97,17 +96,17 @@ class SDMGRPostProcessor: edge_preds[i] = edge_preds[i].reshape((chunk, chunk)) edge_scores[i] = edge_scores[i].reshape((chunk, chunk)) - for i in range(len(batch_data_samples)): - batch_data_samples[i].pred_instances = InstanceData() - batch_data_samples[i].pred_instances.labels = node_preds[i] - batch_data_samples[i].pred_instances.scores = node_scores[i] + for i in range(len(data_samples)): + data_samples[i].pred_instances = InstanceData() + data_samples[i].pred_instances.labels = node_preds[i] + data_samples[i].pred_instances.scores = node_scores[i] if self.link_type != 'none': edge_scores[i], edge_preds[i] = self.decode_edges( node_preds[i], edge_scores[i], edge_preds[i]) - batch_data_samples[i].pred_instances.edge_labels = edge_preds[i] - batch_data_samples[i].pred_instances.edge_scores = edge_scores[i] + data_samples[i].pred_instances.edge_labels = edge_preds[i] + data_samples[i].pred_instances.edge_scores = edge_scores[i] - return batch_data_samples + return data_samples def decode_edges(self, node_labels: Tensor, edge_scores: Tensor, edge_labels: Tensor) -> Tuple[Tensor, Tensor]: diff --git a/mmocr/models/textdet/data_preprocessors/data_preprocessor.py b/mmocr/models/textdet/data_preprocessors/data_preprocessor.py index c6f87eed..990f0b14 100644 --- a/mmocr/models/textdet/data_preprocessors/data_preprocessor.py +++ b/mmocr/models/textdet/data_preprocessors/data_preprocessor.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from numbers import Number -from typing import List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Union -import torch import torch.nn as nn from mmengine.model import ImgDataPreprocessor @@ -59,7 +58,7 @@ class TextDetDataPreprocessor(ImgDataPreprocessor): pad_value: Union[float, int] = 0, bgr_to_rgb: bool = False, rgb_to_bgr: bool = False, - batch_augments: Optional[List[dict]] = None): + batch_augments: Optional[List[Dict]] = None) -> None: super().__init__( mean=mean, std=std, @@ -73,32 +72,28 @@ class TextDetDataPreprocessor(ImgDataPreprocessor): else: self.batch_augments = None - def forward(self, - data: Sequence[dict], - training: bool = False) -> Tuple[torch.Tensor, Optional[list]]: + def forward(self, data: Dict, training: bool = False) -> Dict: """Perform normalization、padding and bgr2rgb conversion based on ``BaseDataPreprocessor``. Args: - data (Sequence[dict]): data sampled from dataloader. + data (dict): data sampled from dataloader. training (bool): Whether to enable training time augmentation. Returns: - Tuple[torch.Tensor, Optional[list]]: Data in the same format as the - model input. + dict: Data in the same format as the model input. """ - batch_inputs, batch_data_samples = super().forward( - data=data, training=training) + data = super().forward(data=data, training=training) + inputs, data_samples = data['inputs'], data['data_samples'] - if batch_data_samples is not None: - batch_input_shape = tuple(batch_inputs[0].size()[-2:]) - for data_samples in batch_data_samples: - data_samples.set_metainfo( + if data_samples is not None: + batch_input_shape = tuple(inputs[0].size()[-2:]) + for data_sample in data_samples: + data_sample.set_metainfo( {'batch_input_shape': batch_input_shape}) if training and self.batch_augments is not None: for batch_aug in self.batch_augments: - batch_inputs, batch_data_samples = batch_aug( - batch_inputs, batch_data_samples) + inputs, data_samples = batch_aug(inputs, data_samples) - return batch_inputs, batch_data_samples + return data diff --git a/mmocr/models/textdet/detectors/mmdet_wrapper.py b/mmocr/models/textdet/detectors/mmdet_wrapper.py index 1859fdda..368e524e 100644 --- a/mmocr/models/textdet/detectors/mmdet_wrapper.py +++ b/mmocr/models/textdet/detectors/mmdet_wrapper.py @@ -35,8 +35,8 @@ class MMDetWrapper(BaseModel): self.text_repr_type = text_repr_type def forward(self, - batch_inputs: torch.Tensor, - batch_data_samples: OptSampleList = None, + inputs: torch.Tensor, + data_samples: OptSampleList = None, mode: str = 'tensor', **kwargs) -> ForwardResults: """The unified entry for a forward process in both training and test. @@ -54,9 +54,9 @@ class MMDetWrapper(BaseModel): optimizer updating, which are done in the :meth:`train_step`. Args: - batch_inputs (torch.Tensor): The input tensor with shape + inputs (torch.Tensor): The input tensor with shape (N, C, ...) in general. - batch_data_samples (list[:obj:`DetDataSample`], optional): The + data_samples (list[:obj:`DetDataSample`], optional): The annotation data of every samples. Defaults to None. mode (str): Return what kind of value. Defaults to 'tensor'. @@ -67,8 +67,8 @@ class MMDetWrapper(BaseModel): - If ``mode="predict"``, return a list of :obj:`TextDetDataSample`. - If ``mode="loss"``, return a dict of tensor. """ - results = self.wrapped_model.forward(batch_inputs, batch_data_samples, - mode, **kwargs) + results = self.wrapped_model.forward(inputs, data_samples, mode, + **kwargs) if mode == 'predict': results = self.adapt_predictions(results) diff --git a/mmocr/models/textdet/detectors/single_stage_text_detector.py b/mmocr/models/textdet/detectors/single_stage_text_detector.py index b30209d2..5617e26a 100644 --- a/mmocr/models/textdet/detectors/single_stage_text_detector.py +++ b/mmocr/models/textdet/detectors/single_stage_text_detector.py @@ -44,47 +44,46 @@ class SingleStageTextDetector(BaseTextDetector): self.neck = MODELS.build(neck) self.det_head = MODELS.build(det_head) - def extract_feat(self, batch_inputs: torch.Tensor) -> torch.Tensor: + def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor: """Extract features. Args: - batch_inputs (Tensor): Image tensor with shape (N, C, H ,W). + inputs (Tensor): Image tensor with shape (N, C, H ,W). Returns: Tensor or tuple[Tensor]: Multi-level features that may have different resolutions. """ - batch_inputs = self.backbone(batch_inputs) + inputs = self.backbone(inputs) if self.with_neck: - batch_inputs = self.neck(batch_inputs) - return batch_inputs + inputs = self.neck(inputs) + return inputs - def loss(self, batch_inputs: torch.Tensor, - batch_data_samples: Sequence[TextDetDataSample]) -> Dict: + def loss(self, inputs: torch.Tensor, + data_samples: Sequence[TextDetDataSample]) -> Dict: """Calculate losses from a batch of inputs and data samples. Args: - batch_inputs (torch.Tensor): Input images of shape (N, C, H, W). + inputs (torch.Tensor): Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled. - batch_data_samples (list[TextDetDataSample]): A list of N + data_samples (list[TextDetDataSample]): A list of N datasamples, containing meta information and gold annotations for each of the images. Returns: dict[str, Tensor]: A dictionary of loss components. """ - batch_inputs = self.extract_feat(batch_inputs) - return self.det_head.loss(batch_inputs, batch_data_samples) + inputs = self.extract_feat(inputs) + return self.det_head.loss(inputs, data_samples) - def predict( - self, batch_inputs: torch.Tensor, - batch_data_samples: Sequence[TextDetDataSample] - ) -> Sequence[TextDetDataSample]: + def predict(self, inputs: torch.Tensor, + data_samples: Sequence[TextDetDataSample] + ) -> Sequence[TextDetDataSample]: """Predict results from a batch of inputs and data samples with post- processing. Args: - batch_inputs (torch.Tensor): Images of shape (N, C, H, W). - batch_data_samples (list[TextDetDataSample]): A list of N + inputs (torch.Tensor): Images of shape (N, C, H, W). + data_samples (list[TextDetDataSample]): A list of N datasamples, containing meta information and gold annotations for each of the images. @@ -104,20 +103,19 @@ class SingleStageTextDetector(BaseTextDetector): Each element represents the polygon of the instance, in (xn, yn) order. """ - x = self.extract_feat(batch_inputs) - return self.det_head.predict(x, batch_data_samples) + x = self.extract_feat(inputs) + return self.det_head.predict(x, data_samples) def _forward(self, - batch_inputs: torch.Tensor, - batch_data_samples: Optional[ - Sequence[TextDetDataSample]] = None, + inputs: torch.Tensor, + data_samples: Optional[Sequence[TextDetDataSample]] = None, **kwargs) -> torch.Tensor: """Network forward process. Usually includes backbone, neck and head forward without any post-processing. Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (list[TextDetDataSample]): A list of N + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (list[TextDetDataSample]): A list of N datasamples, containing meta information and gold annotations for each of the images. @@ -125,5 +123,5 @@ class SingleStageTextDetector(BaseTextDetector): Tensor or tuple[Tensor]: A tuple of features from ``det_head`` forward. """ - x = self.extract_feat(batch_inputs) - return self.det_head(x, batch_data_samples) + x = self.extract_feat(inputs) + return self.det_head(x, data_samples) diff --git a/mmocr/models/textdet/heads/base.py b/mmocr/models/textdet/heads/base.py index ffb6e846..0e06b597 100644 --- a/mmocr/models/textdet/heads/base.py +++ b/mmocr/models/textdet/heads/base.py @@ -63,34 +63,32 @@ class BaseTextDetHead(BaseModule): self.module_loss = MODELS.build(module_loss) self.postprocessor = MODELS.build(postprocessor) - def loss(self, x: Tuple[Tensor], - batch_data_samples: DetSampleList) -> dict: + def loss(self, x: Tuple[Tensor], data_samples: DetSampleList) -> dict: """Perform forward propagation and loss calculation of the detection head on the features of the upstream network. Args: x (tuple[Tensor]): Features from the upstream network, each is a 4D-tensor. - batch_data_samples (List[:obj:`DetDataSample`]): The Data + data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. Returns: dict: A dictionary of loss components. """ - outs = self(x, batch_data_samples) - losses = self.module_loss(outs, batch_data_samples) + outs = self(x, data_samples) + losses = self.module_loss(outs, data_samples) return losses - def loss_and_predict(self, x: Tuple[Tensor], - batch_data_samples: DetSampleList + def loss_and_predict(self, x: Tuple[Tensor], data_samples: DetSampleList ) -> Tuple[dict, DetSampleList]: """Perform forward propagation of the head, then calculate loss and predictions from the features and data samples. Args: x (tuple[Tensor]): Features from FPN. - batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + data_samples (list[:obj:`DetDataSample`]): Each item contains the meta information of each image and corresponding annotations. @@ -101,21 +99,21 @@ class BaseTextDetHead(BaseModule): - predictions (list[:obj:`InstanceData`]): Detection results of each image after the post process. """ - outs = self(x, batch_data_samples) - losses = self.module_loss(outs, batch_data_samples) + outs = self(x, data_samples) + losses = self.module_loss(outs, data_samples) - predictions = self.postprocessor(outs, batch_data_samples) + predictions = self.postprocessor(outs, data_samples) return losses, predictions def predict(self, x: torch.Tensor, - batch_data_samples: DetSampleList) -> DetSampleList: + data_samples: DetSampleList) -> DetSampleList: """Perform forward propagation of the detection head and predict detection results on the features of the upstream network. Args: x (tuple[Tensor]): Multi-level features from the upstream network, each is a 4D-tensor. - batch_data_samples (List[:obj:`DetDataSample`]): The Data + data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. @@ -123,7 +121,7 @@ class BaseTextDetHead(BaseModule): SampleList: Detection results of each image after the post process. """ - outs = self(x, batch_data_samples) + outs = self(x, data_samples) - predictions = self.postprocessor(outs, batch_data_samples) + predictions = self.postprocessor(outs, data_samples) return predictions diff --git a/mmocr/models/textdet/heads/drrg_head.py b/mmocr/models/textdet/heads/drrg_head.py index 37a6d85e..353dc2ec 100644 --- a/mmocr/models/textdet/heads/drrg_head.py +++ b/mmocr/models/textdet/heads/drrg_head.py @@ -261,15 +261,13 @@ class DRRGHead(BaseTextDetHead): self.in_channels + self.out_channels) + self.node_geo_feat_len self.gcn = GCN(node_feat_len) - def loss( - self, batch_inputs: torch.Tensor, - batch_data_samples: List[TextDetDataSample] - ) -> Tuple[Tensor, Tensor, Tensor]: + def loss(self, inputs: torch.Tensor, data_samples: List[TextDetDataSample] + ) -> Tuple[Tensor, Tensor, Tensor]: """Loss function. Args: - batch_inputs (Tensor): Shape of :math:`(N, C, H, W)`. - batch_data_samples (List[TextDetDataSample]): List of data samples. + inputs (Tensor): Shape of :math:`(N, C, H, W)`. + data_samples (List[TextDetDataSample]): List of data samples. Returns: tuple(pred_maps, gcn_pred, gt_labels): @@ -281,27 +279,26 @@ class DRRGHead(BaseTextDetHead): - gt_labels (Tensor): Ground-truth label of shape :math:`(m, n)` where :math:`m * n = N`. """ - targets = self.module_loss.get_targets(batch_data_samples) + targets = self.module_loss.get_targets(data_samples) gt_comp_attribs = targets[-1] - pred_maps = self.out_conv(batch_inputs) - feat_maps = torch.cat([batch_inputs, pred_maps], dim=1) + pred_maps = self.out_conv(inputs) + feat_maps = torch.cat([inputs, pred_maps], dim=1) node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train( feat_maps, np.stack(gt_comp_attribs)) gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds) - return self.module_loss((pred_maps, gcn_pred, gt_labels), - batch_data_samples) + return self.module_loss((pred_maps, gcn_pred, gt_labels), data_samples) def forward( self, - batch_inputs: Tensor, + inputs: Tensor, data_samples: Optional[List[TextDetDataSample]] = None ) -> Tuple[Tensor, Tensor, Tensor]: r"""Run DRRG head in prediction mode, and return the raw tensors only. Args: - batch_inputs (Tensor): Shape of :math:`(1, C, H, W)`. + inputs (Tensor): Shape of :math:`(1, C, H, W)`. data_samples (list[TextDetDataSample], optional): A list of data samples. Defaults to None. @@ -317,10 +314,10 @@ class DRRGHead(BaseTextDetHead): :math:`(M, 9)` where each row corresponds to one box and its score: (x1, y1, x2, y2, x3, y3, x4, y4, score). """ - pred_maps = self.out_conv(batch_inputs) - batch_inputs = torch.cat([batch_inputs, pred_maps], dim=1) + pred_maps = self.out_conv(inputs) + inputs = torch.cat([inputs, pred_maps], dim=1) - none_flag, graph_data = self.graph_test(pred_maps, batch_inputs) + none_flag, graph_data = self.graph_test(pred_maps, inputs) (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds, pivot_local_graphs, text_comps) = graph_data diff --git a/mmocr/models/textrecog/data_preprocessors/data_preprocessor.py b/mmocr/models/textrecog/data_preprocessors/data_preprocessor.py index 16696d02..99ae1719 100644 --- a/mmocr/models/textrecog/data_preprocessors/data_preprocessor.py +++ b/mmocr/models/textrecog/data_preprocessors/data_preprocessor.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from numbers import Number -from typing import List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Union -import torch import torch.nn as nn from mmengine.model import ImgDataPreprocessor @@ -25,7 +24,7 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor): - Pad inputs to the maximum size of current batch with defined ``pad_value``. The padding size can be divisible by a defined ``pad_size_divisor`` - - Stack inputs to batch_inputs. + - Stack inputs to inputs. - Convert inputs from bgr to rgb if the shape of input is (3, H, W). - Normalize image with defined std and mean. - Do batch augmentations during training. @@ -52,7 +51,7 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor): pad_value: Union[float, int] = 0, bgr_to_rgb: bool = False, rgb_to_bgr: bool = False, - batch_augments: Optional[List[dict]] = None): + batch_augments: Optional[List[Dict]] = None) -> None: super().__init__( mean=mean, std=std, @@ -66,37 +65,33 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor): else: self.batch_augments = None - def forward(self, - data: Sequence[dict], - training: bool = False) -> Tuple[torch.Tensor, Optional[list]]: + def forward(self, data: Dict, training: bool = False) -> Dict: """Perform normalization、padding and bgr2rgb conversion based on ``BaseDataPreprocessor``. Args: - data (Sequence[dict]): data sampled from dataloader. + data (dict): Data sampled from dataloader. training (bool): Whether to enable training time augmentation. Returns: - Tuple[torch.Tensor, Optional[list]]: Data in the same format as the - model input. + dict: Data in the same format as the model input. """ - batch_inputs, batch_data_samples = super().forward( - data=data, training=training) + data = super().forward(data=data, training=training) + inputs, data_samples = data['inputs'], data['data_samples'] - if batch_data_samples is not None: - batch_input_shape = tuple(batch_inputs[0].size()[-2:]) - for data_samples in batch_data_samples: + if data_samples is not None: + batch_input_shape = tuple(inputs[0].size()[-2:]) + for data_sample in data_samples: - valid_ratio = data_samples.valid_ratio * \ - data_samples.img_shape[1] / batch_input_shape[1] - data_samples.set_metainfo( + valid_ratio = data_sample.valid_ratio * \ + data_sample.img_shape[1] / batch_input_shape[1] + data_sample.set_metainfo( dict( valid_ratio=valid_ratio, batch_input_shape=batch_input_shape)) if training and self.batch_augments is not None: for batch_aug in self.batch_augments: - batch_inputs, batch_data_samples = batch_aug( - batch_inputs, batch_data_samples) + inputs, data_samples = batch_aug(inputs, data_samples) - return batch_inputs, batch_data_samples + return data diff --git a/mmocr/models/textrecog/recognizers/base.py b/mmocr/models/textrecog/recognizers/base.py index 9d062b0e..41f5aa1d 100644 --- a/mmocr/models/textrecog/recognizers/base.py +++ b/mmocr/models/textrecog/recognizers/base.py @@ -52,8 +52,8 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta): pass def forward(self, - batch_inputs: torch.Tensor, - batch_data_samples: OptRecSampleList = None, + inputs: torch.Tensor, + data_samples: OptRecSampleList = None, mode: str = 'tensor', **kwargs) -> RecForwardResults: """The unified entry for a forward process in both training and test. @@ -71,9 +71,9 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta): optimizer updating, which are done in the :meth:`train_step`. Args: - batch_inputs (torch.Tensor): The input tensor with shape + inputs (torch.Tensor): The input tensor with shape (N, C, ...) in general. - batch_data_samples (list[:obj:`DetDataSample`], optional): The + data_samples (list[:obj:`DetDataSample`], optional): The annotation data of every samples. Defaults to None. mode (str): Return what kind of value. Defaults to 'tensor'. @@ -85,33 +85,32 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta): - If ``mode="loss"``, return a dict of tensor. """ if mode == 'loss': - return self.loss(batch_inputs, batch_data_samples, **kwargs) + return self.loss(inputs, data_samples, **kwargs) elif mode == 'predict': - return self.predict(batch_inputs, batch_data_samples, **kwargs) + return self.predict(inputs, data_samples, **kwargs) elif mode == 'tensor': - return self._forward(batch_inputs, batch_data_samples, **kwargs) + return self._forward(inputs, data_samples, **kwargs) else: raise RuntimeError(f'Invalid mode "{mode}". ' 'Only supports loss, predict and tensor mode') @abstractmethod - def loss(self, batch_inputs: torch.Tensor, - batch_data_samples: RecSampleList, + def loss(self, inputs: torch.Tensor, data_samples: RecSampleList, **kwargs) -> Union[dict, tuple]: """Calculate losses from a batch of inputs and data samples.""" pass @abstractmethod - def predict(self, batch_inputs: torch.Tensor, - batch_data_samples: RecSampleList, **kwargs) -> RecSampleList: + def predict(self, inputs: torch.Tensor, data_samples: RecSampleList, + **kwargs) -> RecSampleList: """Predict results from a batch of inputs and data samples with post- processing.""" pass @abstractmethod def _forward(self, - batch_inputs: torch.Tensor, - batch_data_samples: OptRecSampleList = None, + inputs: torch.Tensor, + data_samples: OptRecSampleList = None, **kwargs): """Network forward process. diff --git a/mmocr/models/textrecog/recognizers/encoder_decoder_recognizer.py b/mmocr/models/textrecog/recognizers/encoder_decoder_recognizer.py index c6aa295b..9be410b0 100644 --- a/mmocr/models/textrecog/recognizers/encoder_decoder_recognizer.py +++ b/mmocr/models/textrecog/recognizers/encoder_decoder_recognizer.py @@ -59,16 +59,16 @@ class EncoderDecoderRecognizer(BaseRecognizer): assert decoder is not None self.decoder = MODELS.build(decoder) - def extract_feat(self, batch_inputs: torch.Tensor) -> torch.Tensor: + def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor: """Directly extract features from the backbone.""" if self.with_preprocessor: - batch_inputs = self.preprocessor(batch_inputs) + inputs = self.preprocessor(inputs) if self.with_backbone: - batch_inputs = self.backbone(batch_inputs) - return batch_inputs + inputs = self.backbone(inputs) + return inputs - def loss(self, batch_inputs: torch.Tensor, - batch_data_samples: RecSampleList, **kwargs) -> Dict: + def loss(self, inputs: torch.Tensor, data_samples: RecSampleList, + **kwargs) -> Dict: """Calculate losses from a batch of inputs and data samples. Args: inputs (tensor): Input images of shape (N, C, H, W). @@ -80,14 +80,14 @@ class EncoderDecoderRecognizer(BaseRecognizer): Returns: dict[str, tensor]: A dictionary of loss components. """ - feat = self.extract_feat(batch_inputs) + feat = self.extract_feat(inputs) out_enc = None if self.with_encoder: - out_enc = self.encoder(feat, batch_data_samples) - return self.decoder.loss(feat, out_enc, batch_data_samples) + out_enc = self.encoder(feat, data_samples) + return self.decoder.loss(feat, out_enc, data_samples) - def predict(self, batch_inputs: torch.Tensor, - batch_data_samples: RecSampleList, **kwargs) -> RecSampleList: + def predict(self, inputs: torch.Tensor, data_samples: RecSampleList, + **kwargs) -> RecSampleList: """Predict results from a batch of inputs and data samples with post- processing. @@ -101,30 +101,30 @@ class EncoderDecoderRecognizer(BaseRecognizer): list[TextRecogDataSample]: A list of N datasamples of prediction results. Results are stored in ``pred_text``. """ - feat = self.extract_feat(batch_inputs) + feat = self.extract_feat(inputs) out_enc = None if self.with_encoder: - out_enc = self.encoder(feat, batch_data_samples) - return self.decoder.predict(feat, out_enc, batch_data_samples) + out_enc = self.encoder(feat, data_samples) + return self.decoder.predict(feat, out_enc, data_samples) def _forward(self, - batch_inputs: torch.Tensor, - batch_data_samples: OptRecSampleList = None, + inputs: torch.Tensor, + data_samples: OptRecSampleList = None, **kwargs) -> RecForwardResults: """Network forward process. Usually includes backbone, encoder and decoder forward without any post-processing. Args: - batch_inputs (Tensor): Inputs with shape (N, C, H, W). - batch_data_samples (list[TextRecogDataSample]): A list of N + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (list[TextRecogDataSample]): A list of N datasamples, containing meta information and gold annotations for each of the images. Returns: Tensor: A tuple of features from ``decoder`` forward. """ - feat = self.extract_feat(batch_inputs) + feat = self.extract_feat(inputs) out_enc = None if self.with_encoder: - out_enc = self.encoder(feat, batch_data_samples) - return self.decoder(feat, out_enc, batch_data_samples) + out_enc = self.encoder(feat, data_samples) + return self.decoder(feat, out_enc, data_samples) diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py index 053c570f..af3b84e2 100644 --- a/tests/test_datasets/test_transforms/test_formatting.py +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -37,9 +37,9 @@ class TestPackTextDetInputs(TestCase): results = transform(copy.deepcopy(datainfo)) self.assertIn('inputs', results) self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10)) - self.assertIn('data_sample', results) + self.assertIn('data_samples', results) - data_sample = results['data_sample'] + data_sample = results['data_samples'] self.assertIn('bboxes', data_sample.gt_instances) self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor) self.assertEqual(data_sample.gt_instances.bboxes.dtype, torch.float32) @@ -56,9 +56,9 @@ class TestPackTextDetInputs(TestCase): transform = PackTextDetInputs(meta_keys=('img_path', )) results = transform(copy.deepcopy(datainfo)) self.assertIn('inputs', results) - self.assertIn('data_sample', results) + self.assertIn('data_samples', results) - data_sample = results['data_sample'] + data_sample = results['data_samples'] self.assertIn('bboxes', data_sample.gt_instances) self.assertIn('img_path', data_sample) self.assertNotIn('flip', data_sample) @@ -66,14 +66,14 @@ class TestPackTextDetInputs(TestCase): datainfo.pop('gt_texts') transform = PackTextDetInputs() results = transform(copy.deepcopy(datainfo)) - data_sample = results['data_sample'] + data_sample = results['data_samples'] self.assertNotIn('texts', data_sample.gt_instances) datainfo = dict(img_shape=(10, 10)) transform = PackTextDetInputs(meta_keys=('img_shape', )) results = transform(copy.deepcopy(datainfo)) self.assertNotIn('inputs', results) - data_sample = results['data_sample'] + data_sample = results['data_samples'] self.assertNotIn('texts', data_sample.gt_instances) def test_repr(self): @@ -108,8 +108,8 @@ class TestPackTextRecogInputs(TestCase): results = transform(copy.deepcopy(datainfo)) self.assertIn('inputs', results) self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10)) - self.assertIn('data_sample', results) - data_sample = results['data_sample'] + self.assertIn('data_samples', results) + data_sample = results['data_samples'] self.assertEqual(data_sample.gt_text.item, 'mmocr') self.assertIn('img_path', data_sample) self.assertIn('valid_ratio', data_sample) @@ -118,8 +118,8 @@ class TestPackTextRecogInputs(TestCase): transform = PackTextRecogInputs(meta_keys=('img_path', )) results = transform(copy.deepcopy(datainfo)) self.assertIn('inputs', results) - self.assertIn('data_sample', results) - data_sample = results['data_sample'] + self.assertIn('data_samples', results) + data_sample = results['data_samples'] self.assertEqual(data_sample.gt_text.item, 'mmocr') self.assertIn('img_path', data_sample) self.assertNotIn('valid_ratio', data_sample) @@ -129,7 +129,7 @@ class TestPackTextRecogInputs(TestCase): transform = PackTextRecogInputs(meta_keys=('img_shape', )) results = transform(copy.deepcopy(datainfo)) self.assertNotIn('inputs', results) - data_sample = results['data_sample'] + data_sample = results['data_samples'] self.assertNotIn('item', data_sample.gt_text) def test_repr(self): @@ -165,8 +165,8 @@ class TestPackKIEInputs(TestCase): results = self.transform(copy.deepcopy(datainfo)) self.assertIn('inputs', results) self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10)) - self.assertIn('data_sample', results) - data_sample = results['data_sample'] + self.assertIn('data_samples', results) + data_sample = results['data_samples'] self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor) self.assertEqual(data_sample.gt_instances.bboxes.dtype, torch.float32) self.assertEqual(data_sample.gt_instances.labels.dtype, torch.int64) @@ -179,9 +179,9 @@ class TestPackKIEInputs(TestCase): transform = PackKIEInputs(meta_keys=('img_path', )) results = transform(copy.deepcopy(datainfo)) self.assertIn('inputs', results) - self.assertIn('data_sample', results) + self.assertIn('data_samples', results) - data_sample = results['data_sample'] + data_sample = results['data_samples'] self.assertIn('bboxes', data_sample.gt_instances) self.assertIn('img_path', data_sample) diff --git a/tests/test_datasets/test_transforms/test_ocr_transforms.py b/tests/test_datasets/test_transforms/test_ocr_transforms.py index e4d06641..4f32f4f3 100644 --- a/tests/test_datasets/test_transforms/test_ocr_transforms.py +++ b/tests/test_datasets/test_transforms/test_ocr_transforms.py @@ -200,9 +200,3 @@ class TestResize(unittest.TestCase): resize = Resize(scale=(40, 30)) result = resize(dummy_result) self.assertEqual(result['gt_bboxes'].dtype, np.float32) - - -if __name__ == '__main__': - t = TestRandomCrop() - t.test_sample_crop_box() - t.test_transform() diff --git a/tests/test_engine/test_runner/test_multi_loop.py b/tests/test_engine/test_runner/test_multi_loop.py deleted file mode 100644 index 0efe89ba..00000000 --- a/tests/test_engine/test_runner/test_multi_loop.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import shutil -import tempfile -from unittest import TestCase - -import torch -import torch.nn as nn -from mmengine.config import Config -from mmengine.evaluator import BaseMetric -from mmengine.hooks import Hook -from mmengine.model import BaseModel -from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS -from mmengine.runner import Runner -from torch.utils.data import Dataset - -from mmocr.engine.runner import MultiTestLoop, MultiValLoop - - -@MODELS.register_module() -class ToyModel(BaseModel): - - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(2, 2) - self.linear2 = nn.Linear(2, 1) - - def forward(self, batch_inputs, labels, mode='tensor'): - labels = torch.stack(labels) - outputs = self.linear1(batch_inputs) - outputs = self.linear2(outputs) - - if mode == 'tensor': - return outputs - elif mode == 'loss': - loss = (labels - outputs).sum() - outputs = dict(loss=loss) - return outputs - elif mode == 'predict': - return outputs - - -@DATASETS.register_module() -class ToyDataset(Dataset): - METAINFO = dict() # type: ignore - data = torch.randn(12, 2) - label = torch.ones(12) - - @property - def metainfo(self): - return self.METAINFO - - def __len__(self): - return self.data.size(0) - - def __getitem__(self, index): - return dict(inputs=self.data[index], data_sample=self.label[index]) - - -@METRICS.register_module() -class ToyMetric3(BaseMetric): - - def __init__(self, collect_device='cpu', prefix=''): - super().__init__(collect_device=collect_device, prefix=prefix) - - def process(self, data_samples, predictions): - result = {'acc': 1} - self.results.append(result) - - def compute_metrics(self, results): - return dict(acc=1) - - -class TestRunner(TestCase): - - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - epoch_based_cfg = dict( - default_scope='mmocr', - model=dict(type='ToyModel'), - work_dir=self.temp_dir, - train_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=3, - num_workers=0), - val_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - test_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - auto_scale_lr=dict(base_batch_size=16, enable=False), - optim_wrapper=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), - val_evaluator=dict(type='ToyMetric1'), - test_evaluator=dict(type='ToyMetric1'), - train_cfg=dict( - by_epoch=True, max_epochs=3, val_interval=1, val_begin=1), - val_cfg=dict(), - test_cfg=dict(), - custom_hooks=[], - default_hooks=dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict( - type='CheckpointHook', interval=1, by_epoch=True), - sampler_seed=dict(type='DistSamplerSeedHook')), - launcher='none', - env_cfg=dict(dist_cfg=dict(backend='nccl')), - ) - self.epoch_based_cfg = Config(epoch_based_cfg) - self.iter_based_cfg = copy.deepcopy(self.epoch_based_cfg) - self.iter_based_cfg.train_dataloader = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='InfiniteSampler', shuffle=True), - batch_size=3, - num_workers=0) - self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12) - self.iter_based_cfg.default_hooks = dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False), - sampler_seed=dict(type='DistSamplerSeedHook')) - - def tearDown(self): - shutil.rmtree(self.temp_dir) - - def test_multi_val_loop(self): - - before_val_iter_results = [] - after_val_iter_results = [] - multi_metrics = dict() - - @HOOKS.register_module() - class Fake_1(Hook): - """test custom train loop.""" - - def before_val_iter(self, runner, batch_idx, data_batch=None): - before_val_iter_results.append('before') - - def after_val_iter(self, - runner, - batch_idx, - data_batch=None, - outputs=None): - after_val_iter_results.append('after') - - def after_val_epoch(self, runner, metrics=None) -> None: - multi_metrics.update(metrics) - - self.iter_based_cfg.val_cfg = dict(type='MultiValLoop') - self.iter_based_cfg.val_dataloader = [ - dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0) - ] - self.iter_based_cfg.val_evaluator = [ - dict(type='ToyMetric3', prefix='tmp1'), - dict(type='ToyMetric3', prefix='tmp2') - ] - self.iter_based_cfg.custom_hooks = [dict(type='Fake_1', priority=50)] - self.iter_based_cfg.experiment_name = 'test_multi_val_loop' - runner = Runner.from_cfg(self.iter_based_cfg) - runner.val() - - self.assertIsInstance(runner.val_loop, MultiValLoop) - - # test custom hook triggered as expected - self.assertEqual(len(before_val_iter_results), 8) - self.assertEqual(len(after_val_iter_results), 8) - for before, after in zip(before_val_iter_results, - after_val_iter_results): - self.assertEqual(before, 'before') - self.assertEqual(after, 'after') - self.assertDictEqual(multi_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1}) - - # test_same prefix - self.iter_based_cfg.val_evaluator = [ - dict(type='ToyMetric3', prefix='tmp1'), - dict(type='ToyMetric3', prefix='tmp1') - ] - self.iter_based_cfg.experiment_name = 'test_multi_val_loop_same_prefix' - runner = Runner.from_cfg(self.iter_based_cfg) - with self.assertRaisesRegex(ValueError, - ('Please set different' - ' prefix for different datasets' - ' in `val_evaluator`')): - runner.val() - - def test_multi_test_loop(self): - - before_test_iter_results = [] - after_test_iter_results = [] - multi_metrics = dict() - - @HOOKS.register_module() - class Fake_2(Hook): - """test custom train loop.""" - - def before_test_iter(self, runner, batch_idx, data_batch=None): - before_test_iter_results.append('before') - - def after_test_iter(self, - runner, - batch_idx, - data_batch=None, - outputs=None): - after_test_iter_results.append('after') - - def after_test_epoch(self, runner, metrics=None) -> None: - multi_metrics.update(metrics) - - self.iter_based_cfg.test_cfg = dict(type='MultiTestLoop') - self.iter_based_cfg.test_dataloader = [ - dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0) - ] - self.iter_based_cfg.test_evaluator = [ - dict(type='ToyMetric3', prefix='tmp1'), - dict(type='ToyMetric3', prefix='tmp2') - ] - self.iter_based_cfg.custom_hooks = [dict(type='Fake_2', priority=50)] - self.iter_based_cfg.experiment_name = 'multi_test_loop' - runner = Runner.from_cfg(self.iter_based_cfg) - runner.test() - - self.assertIsInstance(runner.test_loop, MultiTestLoop) - - # test custom hook triggered as expected - self.assertEqual(len(before_test_iter_results), 8) - self.assertEqual(len(after_test_iter_results), 8) - for before, after in zip(before_test_iter_results, - after_test_iter_results): - self.assertEqual(before, 'before') - self.assertEqual(after, 'after') - self.assertDictEqual(multi_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1}) - - # test_same prefix - self.iter_based_cfg.test_evaluator = [ - dict(type='ToyMetric3', prefix='tmp1'), - dict(type='ToyMetric3', prefix='tmp1') - ] - self.iter_based_cfg.experiment_name = 'multi_test_loop_same_prefix' - runner = Runner.from_cfg(self.iter_based_cfg) - with self.assertRaisesRegex(ValueError, - ('Please set different' - ' prefix for different datasets' - ' in `test_evaluator`')): - runner.test() diff --git a/tests/test_evaluation/test_evaluator/test_multi_datasets_evaluator.py b/tests/test_evaluation/test_evaluator/test_multi_datasets_evaluator.py index 44f7c353..e4f620c2 100644 --- a/tests/test_evaluation/test_evaluator/test_multi_datasets_evaluator.py +++ b/tests/test_evaluation/test_evaluator/test_multi_datasets_evaluator.py @@ -40,9 +40,9 @@ class ToyMetric(BaseMetric): def process(self, data_batch, predictions): results = [{ - 'pred': pred.get('pred'), - 'label': data['data_sample'].get('label') - } for pred, data in zip(predictions, data_batch)] + 'pred': prediction['pred'], + 'label': prediction['label'] + } for prediction in predictions] self.results.extend(results) def compute_metrics(self, results: List): @@ -67,12 +67,13 @@ def generate_test_results(size, batch_size, pred, label): bs_residual = size % batch_size for i in range(num_batch): bs = bs_residual if i == num_batch - 1 else batch_size - data_batch = [ - dict( - inputs=np.zeros((3, 10, 10)), - data_sample=BaseDataElement(label=label)) for _ in range(bs) + data_batch = { + 'inputs': [np.zeros((3, 10, 10)) for _ in range(bs)], + 'data_samples': [BaseDataElement(label=label) for _ in range(bs)] + } + predictions = [ + BaseDataElement(pred=pred, label=label) for _ in range(bs) ] - predictions = [BaseDataElement(pred=pred) for _ in range(bs)] yield (data_batch, predictions) @@ -92,7 +93,7 @@ class TestMultiDatasetsEvaluator(TestCase): for data_samples, predictions in generate_test_results( size, batch_size, pred=1, label=1): - evaluator.process(data_samples, predictions) + evaluator.process(predictions, data_samples) metrics = evaluator.evaluate(size=size) @@ -109,7 +110,7 @@ class TestMultiDatasetsEvaluator(TestCase): for data_samples, predictions in generate_test_results( size, batch_size, pred=1, label=1): - evaluator.process(data_samples, predictions) + evaluator.process(predictions, data_samples) with self.assertRaises(ValueError): evaluator.evaluate(size=size) @@ -120,7 +121,7 @@ class TestMultiDatasetsEvaluator(TestCase): for data_samples, predictions in generate_test_results( size, batch_size, pred=1, label=1): - evaluator.process(data_samples, predictions) + evaluator.process(predictions, data_samples) metrics = evaluator.evaluate(size=size) self.assertIn('Fake/Toy/accuracy', metrics) self.assertIn('Fake/accuracy', metrics) diff --git a/tests/test_models/test_kie/test_extractors/test_sdmgr.py b/tests/test_models/test_kie/test_extractors/test_sdmgr.py index 07ae76a2..5a89a6f8 100644 --- a/tests/test_models/test_kie/test_extractors/test_sdmgr.py +++ b/tests/test_models/test_kie/test_extractors/test_sdmgr.py @@ -46,8 +46,9 @@ class TestSDMGR(unittest.TestCase): return model def forward_wrapper(self, model, data, mode): - inputs, data_sample = model.data_preprocessor(data, False) - return model.forward(inputs, data_sample, mode) + out = model.data_preprocessor(data, False) + inputs, data_samples = out['inputs'], out['data_samples'] + return model.forward(inputs, data_samples, mode) def setUp(self): @@ -66,14 +67,11 @@ class TestSDMGR(unittest.TestCase): edge_labels=torch.LongTensor([[0, 1], [1, 0]]), texts=['text1', 'text2'], relations=torch.rand((2, 2, 5))) - self.visual_data = [ - dict(inputs=torch.rand((3, 10, 10)), data_sample=data_sample) - ] - self.novisual_data = [ - dict( - inputs=torch.Tensor([]).reshape((0, 0, 0)), - data_sample=data_sample) - ] + self.visual_data = dict( + inputs=[torch.rand((3, 10, 10))], data_samples=[data_sample]) + self.novisual_data = dict( + inputs=[torch.Tensor([]).reshape((0, 0, 0))], + data_samples=[data_sample]) def test_forward_loss(self): result = self.forward_wrapper( diff --git a/tests/test_models/test_textdet/test_data_preprocessors/test_textdet_data_preprocessor.py b/tests/test_models/test_textdet/test_data_preprocessors/test_textdet_data_preprocessor.py index d5ccc3ee..8f0642ce 100644 --- a/tests/test_models/test_textdet/test_data_preprocessors/test_textdet_data_preprocessor.py +++ b/tests/test_models/test_textdet/test_data_preprocessors/test_textdet_data_preprocessor.py @@ -11,8 +11,8 @@ from mmocr.structures import TextDetDataSample @MODELS.register_module() class TDAugment(torch.nn.Module): - def forward(self, batch_inputs, batch_data_samples): - return batch_inputs, batch_data_samples + def forward(self, inputs, data_samples): + return inputs, data_samples class TestTextDetDataPreprocessor(TestCase): @@ -47,58 +47,63 @@ class TestTextDetDataPreprocessor(TestCase): def test_forward(self): processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1]) - data = [{ - 'inputs': - torch.randint(0, 256, (3, 11, 10)), - 'data_sample': - TextDetDataSample( - metainfo=dict(img_shape=(11, 10), valid_ratio=1.0)) - }] - inputs, data_samples = processor(data) - print(inputs.dtype) + data = { + 'inputs': [ + torch.randint(0, 256, (3, 11, 10)), + ], + 'data_samples': [ + TextDetDataSample( + metainfo=dict(img_shape=(11, 10), valid_ratio=1.0)), + ] + } + out = processor(data) + inputs, data_samples = out['inputs'], out['data_samples'] self.assertEqual(inputs.shape, (1, 3, 11, 10)) self.assertEqual(len(data_samples), 1) # test channel_conversion processor = TextDetDataPreprocessor( mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True) - inputs, data_samples = processor(data) + out = processor(data) + inputs, data_samples = out['inputs'], out['data_samples'] self.assertEqual(inputs.shape, (1, 3, 11, 10)) self.assertEqual(len(data_samples), 1) # test padding - data = [{ - 'inputs': torch.randint(0, 256, (3, 10, 11)) - }, { - 'inputs': torch.randint(0, 256, (3, 9, 14)) - }] + data = { + 'inputs': [ + torch.randint(0, 256, (3, 10, 11)), + torch.randint(0, 256, (3, 9, 14)) + ] + } processor = TextDetDataPreprocessor( mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True) - inputs, data_samples = processor(data) + out = processor(data) + inputs, data_samples = out['inputs'], out['data_samples'] self.assertEqual(inputs.shape, (2, 3, 10, 14)) self.assertIsNone(data_samples) # test pad_size_divisor - data = [{ - 'inputs': - torch.randint(0, 256, (3, 10, 11)), - 'data_sample': - TextDetDataSample( - metainfo=dict(img_shape=(10, 11), valid_ratio=1.0)) - }, { - 'inputs': - torch.randint(0, 256, (3, 9, 24)), - 'data_sample': - TextDetDataSample( - metainfo=dict(img_shape=(9, 24), valid_ratio=1.0)) - }] + data = { + 'inputs': [ + torch.randint(0, 256, (3, 10, 11)), + torch.randint(0, 256, (3, 9, 24)) + ], + 'data_samples': [ + TextDetDataSample( + metainfo=dict(img_shape=(10, 11), valid_ratio=1.0)), + TextDetDataSample( + metainfo=dict(img_shape=(9, 24), valid_ratio=1.0)) + ] + } aug_cfg = [dict(type='TDAugment')] processor = TextDetDataPreprocessor( mean=[0., 0., 0.], std=[1., 1., 1.], pad_size_divisor=5, batch_augments=aug_cfg) - inputs, data_samples = processor(data, training=True) + out = processor(data) + inputs, data_samples = out['inputs'], out['data_samples'] self.assertEqual(inputs.shape, (2, 3, 10, 25)) self.assertEqual(len(data_samples), 2) for data_sample, expected_shape in zip(data_samples, [(10, 25), diff --git a/tests/test_models/test_textdet/test_wrappers/test_mmdet_wrapper.py b/tests/test_models/test_textdet/test_wrappers/test_mmdet_wrapper.py index 57747ff7..de489228 100644 --- a/tests/test_models/test_textdet/test_wrappers/test_mmdet_wrapper.py +++ b/tests/test_models/test_textdet/test_wrappers/test_mmdet_wrapper.py @@ -262,9 +262,3 @@ class TestMMDetWrapper(unittest.TestCase): self.assertEqual(len(results), 1) self.assertIsInstance(results[0], TextDetDataSample) self.assertTrue('polygons' in results[0].pred_instances.keys()) - - -if __name__ == '__main__': - test = TestMMDetWrapper() - test.setUp() - test.test_mask_two_stage_wrapper() diff --git a/tests/test_models/test_textrecog/test_data_preprocessors/test_data_preprocessor.py b/tests/test_models/test_textrecog/test_data_preprocessors/test_data_preprocessor.py index f680a2cb..2216aba0 100644 --- a/tests/test_models/test_textrecog/test_data_preprocessors/test_data_preprocessor.py +++ b/tests/test_models/test_textrecog/test_data_preprocessors/test_data_preprocessor.py @@ -11,8 +11,8 @@ from mmocr.structures import TextRecogDataSample @MODELS.register_module() class Augment(torch.nn.Module): - def forward(self, batch_inputs, batch_data_samples): - return batch_inputs, batch_data_samples + def forward(self, inputs, data_samples): + return inputs, data_samples class TestTextRecogDataPreprocessor(TestCase): @@ -47,14 +47,17 @@ class TestTextRecogDataPreprocessor(TestCase): def test_forward(self): processor = TextRecogDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1]) - data = [{ - 'inputs': - torch.randint(0, 256, (3, 11, 10)), - 'data_sample': - TextRecogDataSample( - metainfo=dict(img_shape=(11, 10), valid_ratio=1.0)) - }] - inputs, data_samples = processor(data) + data = { + 'inputs': [ + torch.randint(0, 256, (3, 11, 10)), + ], + 'data_samples': [ + TextRecogDataSample( + metainfo=dict(img_shape=(11, 10), valid_ratio=1.0)), + ] + } + out = processor(data) + inputs, data_samples = out['inputs'], out['data_samples'] print(inputs.dtype) self.assertEqual(inputs.shape, (1, 3, 11, 10)) self.assertEqual(len(data_samples), 1) @@ -62,43 +65,46 @@ class TestTextRecogDataPreprocessor(TestCase): # test channel_conversion processor = TextRecogDataPreprocessor( mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True) - inputs, data_samples = processor(data) + out = processor(data) + inputs, data_samples = out['inputs'], out['data_samples'] self.assertEqual(inputs.shape, (1, 3, 11, 10)) self.assertEqual(len(data_samples), 1) # test padding - data = [{ - 'inputs': torch.randint(0, 256, (3, 10, 11)) - }, { - 'inputs': torch.randint(0, 256, (3, 9, 14)) - }] + data = { + 'inputs': [ + torch.randint(0, 256, (3, 10, 11)), + torch.randint(0, 256, (3, 9, 14)) + ] + } processor = TextRecogDataPreprocessor( mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True) - inputs, data_samples = processor(data) + out = processor(data) + inputs, data_samples = out['inputs'], out['data_samples'] self.assertEqual(inputs.shape, (2, 3, 10, 14)) self.assertIsNone(data_samples) # test pad_size_divisor - data = [{ - 'inputs': - torch.randint(0, 256, (3, 10, 11)), - 'data_sample': - TextRecogDataSample( - metainfo=dict(img_shape=(10, 11), valid_ratio=1.0)) - }, { - 'inputs': - torch.randint(0, 256, (3, 9, 24)), - 'data_sample': - TextRecogDataSample( - metainfo=dict(img_shape=(9, 24), valid_ratio=1.0)) - }] + data = { + 'inputs': [ + torch.randint(0, 256, (3, 10, 11)), + torch.randint(0, 256, (3, 9, 24)), + ], + 'data_samples': [ + TextRecogDataSample( + metainfo=dict(img_shape=(10, 11), valid_ratio=1.0)), + TextRecogDataSample( + metainfo=dict(img_shape=(9, 24), valid_ratio=1.0)) + ] + } aug_cfg = [dict(type='Augment')] processor = TextRecogDataPreprocessor( mean=[0., 0., 0.], std=[1., 1., 1.], pad_size_divisor=5, batch_augments=aug_cfg) - inputs, data_samples = processor(data, training=True) + out = processor(data) + inputs, data_samples = out['inputs'], out['data_samples'] self.assertEqual(inputs.shape, (2, 3, 10, 25)) self.assertEqual(len(data_samples), 2) for data_sample, expected_shape, expected_ratio in zip(