[Refactor] Adapt to new dataflow (#1305)

* datasample->datasamples

* update rec data preprocessor

* rename datasamples

* update det preprocessor

* update metric

* update data_sample->data_samples in test

* update

* fix data preprocessor uts

* remove engine runner

* fix kie ut

* fix ut

* fix comments

* refactor evaluator ut

* apply comments

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>

* remove useless

* apply comments

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>

* apply comments

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
pull/1324/head
Xinyu Wang 2022-08-25 11:45:42 +08:00 committed by GitHub
parent 1b5764b155
commit 9bd5258513
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 316 additions and 857 deletions

View File

@ -1,31 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
from mmocr.apis import init_detector
from mmocr.apis.inference import text_model_inference
from mmocr.registry import DATASETS # NOQA
def main():
    """Interactively run a text model on one sentence typed by the user.

    Parses the config path, checkpoint path and optional ``--device`` from
    the command line, builds the model with ``init_detector``, prompts for a
    sentence on stdin and prints every predicted entity as
    ``<entity[0]>: <substring>``, where the substring spans the inclusive
    index range ``entity[1]..entity[2]`` of the input sentence.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument('config', help='Config file.')
    arg_parser.add_argument('checkpoint', help='Checkpoint file.')
    arg_parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    cli_args = arg_parser.parse_args()

    # Build the model from a config file and a checkpoint file.
    model = init_detector(
        cli_args.config, cli_args.checkpoint, device=cli_args.device)

    # Run inference on a single sentence read from stdin.
    input_sentence = input('Please enter a sentence you want to test: ')
    result = text_model_inference(model, input_sentence)

    # Show the results: ``result`` is iterated as groups of entities, each
    # entity being indexable as (label, start, end) with an inclusive end.
    every_entity = (entity for pred_entities in result
                    for entity in pred_entities)
    for entity in every_entity:
        print(f'{entity[0]}: {input_sentence[entity[1]:entity[2] + 1]}')


if __name__ == '__main__':
    main()

View File

@ -17,7 +17,7 @@ class PackTextDetInputs(BaseTransform):
The type of outputs is `dict`:
- inputs: image converted to tensor, whose shape is (C, H, W).
- data_sample: Two components of ``TextDetDataSample`` will be updated:
- data_samples: Two components of ``TextDetDataSample`` will be updated:
- gt_instances (InstanceData): Depending on annotations, a subset of the
following keys will be updated:
@ -82,7 +82,7 @@ class PackTextDetInputs(BaseTransform):
dict:
- 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
- 'data_samples' (obj:`DetDataSample`): The annotation info of the
sample.
"""
packed_results = dict()
@ -109,7 +109,7 @@ class PackTextDetInputs(BaseTransform):
for key in self.meta_keys:
img_meta[key] = results[key]
data_sample.set_metainfo(img_meta)
packed_results['data_sample'] = data_sample
packed_results['data_samples'] = data_sample
return packed_results
@ -126,7 +126,7 @@ class PackTextRecogInputs(BaseTransform):
The type of outputs is `dict`:
- inputs: Image as a tensor, whose shape is (C, H, W).
- data_sample: Two components of ``TextRecogDataSample`` will be updated:
- data_samples: Two components of ``TextRecogDataSample`` will be updated:
- gt_text (LabelData):
@ -166,7 +166,7 @@ class PackTextRecogInputs(BaseTransform):
dict:
- 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
- 'data_sample' (obj:`TextRecogDataSample`): The annotation info
- 'data_samples' (obj:`TextRecogDataSample`): The annotation info
of the sample.
"""
packed_results = dict()
@ -195,7 +195,7 @@ class PackTextRecogInputs(BaseTransform):
img_meta[key] = results[key]
data_sample.set_metainfo(img_meta)
packed_results['data_sample'] = data_sample
packed_results['data_samples'] = data_sample
return packed_results
@ -212,7 +212,7 @@ class PackKIEInputs(BaseTransform):
The type of outputs is `dict`:
- inputs: image converted to tensor, whose shape is (C, H, W).
- data_sample: Two components of ``TextDetDataSample`` will be updated:
- data_samples: Two components of ``TextDetDataSample`` will be updated:
- gt_instances (InstanceData): Depending on annotations, a subset of the
following keys will be updated:
@ -266,7 +266,7 @@ class PackKIEInputs(BaseTransform):
dict:
- 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
- 'data_samples' (obj:`DetDataSample`): The annotation info of the
sample.
"""
packed_results = dict()
@ -295,7 +295,7 @@ class PackKIEInputs(BaseTransform):
for key in self.meta_keys:
img_meta[key] = results[key]
data_sample.set_metainfo(img_meta)
packed_results['data_sample'] = data_sample
packed_results['data_samples'] = data_sample
return packed_results

View File

@ -1,3 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import * # NOQA
from .runner import * # NOQA

View File

@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .multi_loops import MultiTestLoop, MultiValLoop
__all__ = ['MultiValLoop', 'MultiTestLoop']

View File

@ -1,202 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List, Sequence, Union
import torch
from mmengine.evaluator import Evaluator
from mmengine.runner.amp import autocast
from mmengine.runner.base_loop import BaseLoop
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader
from mmocr.registry import LOOPS
@LOOPS.register_module()
class MultiValLoop(BaseLoop):
    """Loop for validation on multiple datasets.

    The i-th dataloader is evaluated with the i-th evaluator. The metrics of
    all datasets are merged into a single dict, so every evaluator must
    produce distinct metric names (e.g. through different prefixes).

    Args:
        runner (Runner): A reference of runner.
        dataloader (list[DataLoader or dict]): One entry per dataset, each
            either a dataloader object or a dict to build a dataloader.
        evaluator (list[Evaluator or dict or list]): One entry per dataset,
            used for computing metrics. Must have the same length as
            ``dataloader``.
        fp16 (bool): Whether to enable fp16 validation. Defaults to False.
    """

    def __init__(self,
                 runner,
                 dataloader: List[Union[DataLoader, Dict]],
                 evaluator: List[Union[Evaluator, Dict, List]],
                 fp16: bool = False) -> None:
        self._runner = runner

        assert isinstance(dataloader, list)
        # Dict entries are configs that still need to be built; anything else
        # is used as given.
        self.dataloaders = [
            runner.build_dataloader(loader, seed=runner.seed)
            if isinstance(loader, dict) else loader for loader in dataloader
        ]

        assert isinstance(evaluator, list)
        # Each entry goes through ``runner.build_evaluator`` exactly once.
        # Previously the list was built in a loop and then rebuilt from
        # scratch, constructing every config-defined evaluator twice.
        # NOTE(review): presumably ``build_evaluator`` returns already-built
        # evaluators unchanged - confirm against the mmengine version in use.
        self.evaluators = [
            runner.build_evaluator(single_evaluator)
            for single_evaluator in evaluator
        ]
        assert len(self.evaluators) == len(self.dataloaders)

        self.fp16 = fp16

    def run(self):
        """Launch validation.

        Raises:
            ValueError: If two datasets report a metric under the same key.
        """
        self.runner.call_hook('before_val')
        self.runner.model.eval()
        multi_metric = dict()
        self.runner.call_hook('before_val_epoch')
        for evaluator, dataloader in zip(self.evaluators, self.dataloaders):
            # ``run_iter`` reads the active pair from these attributes.
            self.evaluator = evaluator
            self.dataloader = dataloader
            if hasattr(self.dataloader.dataset, 'metainfo'):
                self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
                self.runner.visualizer.dataset_meta = \
                    self.dataloader.dataset.metainfo
            else:
                warnings.warn(
                    f'Dataset {self.dataloader.dataset.__class__.__name__} '
                    'has no metainfo. ``dataset_meta`` in evaluator, metric'
                    ' and visualizer will be None.')
            for idx, data_batch in enumerate(self.dataloader):
                self.run_iter(idx, data_batch)
            # compute metrics
            metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
            # An overlapping key would silently overwrite the result of an
            # earlier dataset, so refuse to merge instead.
            if metrics.keys() & multi_metric.keys():
                raise ValueError('Please set different prefix for different'
                                 ' datasets in `val_evaluator`')
            multi_metric.update(metrics)
        self.runner.call_hook('after_val_epoch', metrics=multi_metric)
        self.runner.call_hook('after_val')

    @torch.no_grad()
    def run_iter(self, idx: int, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            idx (int): The index of the current batch in the loop.
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_val_iter', batch_idx=idx, data_batch=data_batch)
        # outputs should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.val_step(data_batch)
        self.evaluator.process(data_batch, outputs)
        self.runner.call_hook(
            'after_val_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
@LOOPS.register_module()
class MultiTestLoop(BaseLoop):
    """Loop for testing on multiple datasets.

    The i-th dataloader is evaluated with the i-th evaluator. The metrics of
    all datasets are merged into a single dict, so every evaluator must
    produce distinct metric names (e.g. through different prefixes).

    Args:
        runner (Runner): A reference of runner.
        dataloader (list[DataLoader or dict]): One entry per dataset, each
            either a dataloader object or a dict to build a dataloader.
        evaluator (list[Evaluator or dict or list]): One entry per dataset,
            used for computing metrics. Must have the same length as
            ``dataloader``.
        fp16 (bool): Whether to enable fp16 testing. Defaults to False.
    """

    def __init__(self,
                 runner,
                 dataloader: List[Union[DataLoader, Dict]],
                 evaluator: List[Union[Evaluator, Dict, List]],
                 fp16: bool = False) -> None:
        self._runner = runner

        assert isinstance(dataloader, list)
        # Dict entries are configs that still need to be built; anything else
        # is used as given.
        self.dataloaders = [
            runner.build_dataloader(loader, seed=runner.seed)
            if isinstance(loader, dict) else loader for loader in dataloader
        ]

        assert isinstance(evaluator, list)
        # Each entry goes through ``runner.build_evaluator`` exactly once.
        # Previously the list was built in a loop and then rebuilt from
        # scratch, constructing every config-defined evaluator twice.
        # NOTE(review): presumably ``build_evaluator`` returns already-built
        # evaluators unchanged - confirm against the mmengine version in use.
        self.evaluators = [
            runner.build_evaluator(single_evaluator)
            for single_evaluator in evaluator
        ]
        assert len(self.evaluators) == len(self.dataloaders)

        self.fp16 = fp16

    def run(self):
        """Launch test.

        Raises:
            ValueError: If two datasets report a metric under the same key.
        """
        self.runner.call_hook('before_test')
        self.runner.model.eval()
        multi_metric = dict()
        self.runner.call_hook('before_test_epoch')
        for evaluator, dataloader in zip(self.evaluators, self.dataloaders):
            # ``run_iter`` reads the active pair from these attributes.
            self.dataloader = dataloader
            self.evaluator = evaluator
            if hasattr(self.dataloader.dataset, 'metainfo'):
                self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
                self.runner.visualizer.dataset_meta = \
                    self.dataloader.dataset.metainfo
            else:
                warnings.warn(
                    f'Dataset {self.dataloader.dataset.__class__.__name__} '
                    'has no metainfo. ``dataset_meta`` in evaluator, metric'
                    ' and visualizer will be None.')
            for idx, data_batch in enumerate(self.dataloader):
                self.run_iter(idx, data_batch)
            # compute metrics
            metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
            # An overlapping key would silently overwrite the result of an
            # earlier dataset, so refuse to merge instead.
            if metrics.keys() & multi_metric.keys():
                raise ValueError('Please set different prefix for different'
                                 ' datasets in `test_evaluator`')
            multi_metric.update(metrics)
        self.runner.call_hook('after_test_epoch', metrics=multi_metric)
        self.runner.call_hook('after_test')

    @torch.no_grad()
    def run_iter(self, idx: int, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            idx (int): The index of the current batch in the loop.
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)
        # predictions should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            predictions = self.runner.model.test_step(data_batch)
        self.evaluator.process(data_batch, predictions)
        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=predictions)

View File

@ -85,19 +85,18 @@ class F1Metric(BaseMetric):
self.key = key
def process(self, data_batch: Sequence[Dict],
predictions: Sequence[Dict]) -> None:
"""Process one batch of predictions. The processed results should be
data_samples: Sequence[Dict]) -> None:
"""Process one batch of data_samples. The processed results should be
stored in ``self.results``, which will be used to compute the metrics
when all batches have been processed.
Args:
data_batch (Sequence[Dict]): A batch of gts.
predictions (Sequence[Dict]): A batch of outputs from the model.
data_samples (Sequence[Dict]): A batch of outputs from the model.
"""
for data_samples in predictions:
pred_labels = data_samples.get('pred_instances').get(
self.key).cpu()
gt_labels = data_samples.get('gt_instances').get(self.key).cpu()
for data_sample in data_samples:
pred_labels = data_sample.get('pred_instances').get(self.key).cpu()
gt_labels = data_sample.get('gt_instances').get(self.key).cpu()
result = dict(
pred_labels=pred_labels.flatten(),

View File

@ -75,17 +75,17 @@ class HmeanIOUMetric(BaseMetric):
self.strategy = strategy
def process(self, data_batch: Sequence[Dict],
predictions: Sequence[Dict]) -> None:
data_samples: Sequence[Dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[Dict]): A batch of data from dataloader.
predictions (Sequence[Dict]): A batch of outputs from
data_samples (Sequence[Dict]): A batch of outputs from
the model.
"""
for data_sample in predictions:
for data_sample in data_samples:
pred_instances = data_sample.get('pred_instances')
pred_polygons = pred_instances.get('polygons')

View File

@ -51,16 +51,16 @@ class WordMetric(BaseMetric):
self.mode = set(mode)
def process(self, data_batch: Sequence[Dict],
predictions: Sequence[Dict]) -> None:
"""Process one batch of predictions. The processed results should be
data_samples: Sequence[Dict]) -> None:
"""Process one batch of data_samples. The processed results should be
stored in ``self.results``, which will be used to compute the metrics
when all batches have been processed.
Args:
data_batch (Sequence[Dict]): A batch of gts.
predictions (Sequence[Dict]): A batch of outputs from the model.
data_samples (Sequence[Dict]): A batch of outputs from the model.
"""
for data_sample in predictions:
for data_sample in data_samples:
match_num = 0
match_ignore_case_num = 0
match_ignore_case_symbol_num = 0
@ -149,16 +149,16 @@ class CharMetric(BaseMetric):
self.valid_symbol = re.compile(valid_symbol)
def process(self, data_batch: Sequence[Dict],
predictions: Sequence[Dict]) -> None:
"""Process one batch of predictions. The processed results should be
data_samples: Sequence[Dict]) -> None:
"""Process one batch of data_samples. The processed results should be
stored in ``self.results``, which will be used to compute the metrics
when all batches have been processed.
Args:
data_batch (Sequence[Dict]): A batch of gts.
predictions (Sequence[Dict]): A batch of outputs from the model.
data_samples (Sequence[Dict]): A batch of outputs from the model.
"""
for data_sample in predictions:
for data_sample in data_samples:
pred_text = data_sample.get('pred_text').get('item')
gt_text = data_sample.get('gt_text').get('item')
gt_text_lower = gt_text.lower()
@ -249,16 +249,16 @@ class OneMinusNEDMetric(BaseMetric):
self.valid_symbol = re.compile(valid_symbol)
def process(self, data_batch: Sequence[Dict],
predictions: Sequence[Dict]) -> None:
"""Process one batch of predictions. The processed results should be
data_samples: Sequence[Dict]) -> None:
"""Process one batch of data_samples. The processed results should be
stored in ``self.results``, which will be used to compute the metrics
when all batches have been processed.
Args:
data_batch (Sequence[Dict]): A batch of gts.
predictions (Sequence[Dict]): A batch of outputs from the model.
data_samples (Sequence[Dict]): A batch of outputs from the model.
"""
for data_sample in predictions:
for data_sample in data_samples:
pred_text = data_sample.get('pred_text').get('item')
gt_text = data_sample.get('gt_text').get('item')
gt_text_lower = gt_text.lower()

View File

@ -84,8 +84,8 @@ class SDMGR(BaseModel):
return feats.view(feats.size(0), -1)
def forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: Sequence[KIEDataSample] = None,
inputs: torch.Tensor,
data_samples: Sequence[KIEDataSample] = None,
mode: str = 'tensor',
**kwargs) -> torch.Tensor:
"""The unified entry for a forward process in both training and test.
@ -103,9 +103,9 @@ class SDMGR(BaseModel):
optimizer updating, which are done in the :meth:`train_step`.
Args:
batch_inputs (torch.Tensor): The input tensor with shape
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
batch_data_samples (list[:obj:`DetDataSample`], optional): The
data_samples (list[:obj:`DetDataSample`], optional): The
annotation data of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
@ -117,43 +117,42 @@ class SDMGR(BaseModel):
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'loss':
return self.loss(batch_inputs, batch_data_samples, **kwargs)
return self.loss(inputs, data_samples, **kwargs)
elif mode == 'predict':
return self.predict(batch_inputs, batch_data_samples, **kwargs)
return self.predict(inputs, data_samples, **kwargs)
elif mode == 'tensor':
return self._forward(batch_inputs, batch_data_samples, **kwargs)
return self._forward(inputs, data_samples, **kwargs)
else:
raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode')
def loss(self, batch_inputs: torch.Tensor,
batch_data_samples: Sequence[KIEDataSample], **kwargs) -> dict:
def loss(self, inputs: torch.Tensor, data_samples: Sequence[KIEDataSample],
**kwargs) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
inputs (torch.Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
batch_data_samples (list[KIEDataSample]): A list of N datasamples,
data_samples (list[KIEDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(batch_inputs, [
data_sample.gt_instances.bboxes
for data_sample in batch_data_samples
])
return self.kie_head.loss(x, batch_data_samples)
x = self.extract_feat(
inputs,
[data_sample.gt_instances.bboxes for data_sample in data_samples])
return self.kie_head.loss(x, data_samples)
def predict(self, batch_inputs: torch.Tensor,
batch_data_samples: Sequence[KIEDataSample],
def predict(self, inputs: torch.Tensor,
data_samples: Sequence[KIEDataSample],
**kwargs) -> List[KIEDataSample]:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
inputs (torch.Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
batch_data_samples (list[KIEDataSample]): A list of N datasamples,
data_samples (list[KIEDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
@ -162,22 +161,21 @@ class SDMGR(BaseModel):
Results are stored in ``pred_instances.labels`` and
``pred_instances.edge_labels``.
"""
x = self.extract_feat(batch_inputs, [
data_sample.gt_instances.bboxes
for data_sample in batch_data_samples
])
return self.kie_head.predict(x, batch_data_samples)
x = self.extract_feat(
inputs,
[data_sample.gt_instances.bboxes for data_sample in data_samples])
return self.kie_head.predict(x, data_samples)
def _forward(self, batch_inputs: torch.Tensor,
batch_data_samples: Sequence[KIEDataSample],
def _forward(self, inputs: torch.Tensor,
data_samples: Sequence[KIEDataSample],
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the raw tensor outputs from backbone and head without any post-
processing.
Args:
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
inputs (torch.Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
batch_data_samples (list[KIEDataSample]): A list of N datasamples,
data_samples (list[KIEDataSample]): A list of N datasamples,
containing meta information and gold annotations for each of
the images.
@ -187,8 +185,7 @@ class SDMGR(BaseModel):
- node_cls (torch.Tensor): Node classification output.
- edge_cls (torch.Tensor): Edge classification output.
"""
x = self.extract_feat(batch_inputs, [
data_sample.gt_instances.bboxes
for data_sample in batch_data_samples
])
return self.kie_head(x, batch_data_samples)
x = self.extract_feat(
inputs,
[data_sample.gt_instances.bboxes for data_sample in data_samples])
return self.kie_head(x, data_samples)

View File

@ -83,28 +83,26 @@ class SDMGRHead(BaseModule):
self.postprocessor = MODELS.build(postprocessor)
self.relation_norm = relation_norm
def loss(self, batch_inputs: Tensor,
batch_data_samples: List[KIEDataSample]) -> Dict:
def loss(self, inputs: Tensor, data_samples: List[KIEDataSample]) -> Dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs (torch.Tensor): Shape :math:`(N, E)`.
batch_data_samples (List[KIEDataSample]): List of data samples.
inputs (torch.Tensor): Shape :math:`(N, E)`.
data_samples (List[KIEDataSample]): List of data samples.
Returns:
dict[str, tensor]: A dictionary of loss components.
"""
preds = self.forward(batch_inputs, batch_data_samples)
return self.module_loss(preds, batch_data_samples)
preds = self.forward(inputs, data_samples)
return self.module_loss(preds, data_samples)
def predict(self, batch_inputs: Tensor,
batch_data_samples: List[KIEDataSample]
) -> List[KIEDataSample]:
def predict(self, inputs: Tensor,
data_samples: List[KIEDataSample]) -> List[KIEDataSample]:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs (torch.Tensor): Shape :math:`(N, E)`.
batch_data_samples (List[KIEDataSample]): List of data samples.
inputs (torch.Tensor): Shape :math:`(N, E)`.
data_samples (List[KIEDataSample]): List of data samples.
Returns:
List[KIEDataSample]: A list of datasamples of prediction results.
@ -121,16 +119,15 @@ class SDMGRHead(BaseModule):
- edge_scores (Tensor): A float tensor of shape (N, ), indicating
the confidence scores for edge predictions.
"""
preds = self.forward(batch_inputs, batch_data_samples)
return self.postprocessor(preds, batch_data_samples)
preds = self.forward(inputs, data_samples)
return self.postprocessor(preds, data_samples)
def forward(self, batch_inputs: Tensor,
batch_data_samples: List[KIEDataSample]
) -> Tuple[Tensor, Tensor]:
def forward(self, inputs: Tensor,
data_samples: List[KIEDataSample]) -> Tuple[Tensor, Tensor]:
"""
Args:
batch_inputs (torch.Tensor): Shape :math:`(N, E)`.
batch_data_samples (List[KIEDataSample]): List of data samples.
inputs (torch.Tensor): Shape :math:`(N, E)`.
data_samples (List[KIEDataSample]): List of data samples.
Returns:
tuple(Tensor, Tensor):
@ -143,8 +140,7 @@ class SDMGRHead(BaseModule):
device = self.node_embed.weight.device
node_nums, char_nums, all_nodes = self.convert_texts(
batch_data_samples)
node_nums, char_nums, all_nodes = self.convert_texts(data_samples)
embed_nodes = self.node_embed(all_nodes.to(device).long())
rnn_nodes, _ = self.rnn(embed_nodes)
@ -156,10 +152,10 @@ class SDMGRHead(BaseModule):
1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
-1, -1, rnn_nodes.size(-1))).squeeze(1)
if batch_inputs is not None:
nodes = self.fusion([batch_inputs, nodes])
if inputs is not None:
nodes = self.fusion([inputs, nodes])
relations = self.compute_relations(batch_data_samples)
relations = self.compute_relations(data_samples)
all_edges = torch.cat(
[relation.view(-1, relation.size(-1)) for relation in relations],
dim=0)

View File

@ -52,8 +52,7 @@ class SDMGRPostProcessor:
self.softmax = nn.Softmax(dim=-1)
def __call__(self, preds: Tuple[Tensor, Tensor],
batch_data_samples: List[KIEDataSample]
) -> List[KIEDataSample]:
data_samples: List[KIEDataSample]) -> List[KIEDataSample]:
"""Postprocess raw outputs from SDMGR heads and pack the results into a
list of KIEDataSample.
@ -83,7 +82,7 @@ class SDMGRPostProcessor:
all_edge_scores = self.softmax(edge_preds)
chunk_size = [
data_sample.gt_instances.bboxes.shape[0]
for data_sample in batch_data_samples
for data_sample in data_samples
]
node_scores, node_preds = torch.max(all_node_scores, dim=-1)
edge_scores, edge_preds = torch.max(all_edge_scores, dim=-1)
@ -97,17 +96,17 @@ class SDMGRPostProcessor:
edge_preds[i] = edge_preds[i].reshape((chunk, chunk))
edge_scores[i] = edge_scores[i].reshape((chunk, chunk))
for i in range(len(batch_data_samples)):
batch_data_samples[i].pred_instances = InstanceData()
batch_data_samples[i].pred_instances.labels = node_preds[i]
batch_data_samples[i].pred_instances.scores = node_scores[i]
for i in range(len(data_samples)):
data_samples[i].pred_instances = InstanceData()
data_samples[i].pred_instances.labels = node_preds[i]
data_samples[i].pred_instances.scores = node_scores[i]
if self.link_type != 'none':
edge_scores[i], edge_preds[i] = self.decode_edges(
node_preds[i], edge_scores[i], edge_preds[i])
batch_data_samples[i].pred_instances.edge_labels = edge_preds[i]
batch_data_samples[i].pred_instances.edge_scores = edge_scores[i]
data_samples[i].pred_instances.edge_labels = edge_preds[i]
data_samples[i].pred_instances.edge_scores = edge_scores[i]
return batch_data_samples
return data_samples
def decode_edges(self, node_labels: Tensor, edge_scores: Tensor,
edge_labels: Tensor) -> Tuple[Tensor, Tensor]:

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Union
import torch
import torch.nn as nn
from mmengine.model import ImgDataPreprocessor
@ -59,7 +58,7 @@ class TextDetDataPreprocessor(ImgDataPreprocessor):
pad_value: Union[float, int] = 0,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
batch_augments: Optional[List[dict]] = None):
batch_augments: Optional[List[Dict]] = None) -> None:
super().__init__(
mean=mean,
std=std,
@ -73,32 +72,28 @@ class TextDetDataPreprocessor(ImgDataPreprocessor):
else:
self.batch_augments = None
def forward(self,
data: Sequence[dict],
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
def forward(self, data: Dict, training: bool = False) -> Dict:
"""Perform normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (Sequence[dict]): data sampled from dataloader.
data (dict): data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
model input.
dict: Data in the same format as the model input.
"""
batch_inputs, batch_data_samples = super().forward(
data=data, training=training)
data = super().forward(data=data, training=training)
inputs, data_samples = data['inputs'], data['data_samples']
if batch_data_samples is not None:
batch_input_shape = tuple(batch_inputs[0].size()[-2:])
for data_samples in batch_data_samples:
data_samples.set_metainfo(
if data_samples is not None:
batch_input_shape = tuple(inputs[0].size()[-2:])
for data_sample in data_samples:
data_sample.set_metainfo(
{'batch_input_shape': batch_input_shape})
if training and self.batch_augments is not None:
for batch_aug in self.batch_augments:
batch_inputs, batch_data_samples = batch_aug(
batch_inputs, batch_data_samples)
inputs, data_samples = batch_aug(inputs, data_samples)
return batch_inputs, batch_data_samples
return data

View File

@ -35,8 +35,8 @@ class MMDetWrapper(BaseModel):
self.text_repr_type = text_repr_type
def forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: OptSampleList = None,
inputs: torch.Tensor,
data_samples: OptSampleList = None,
mode: str = 'tensor',
**kwargs) -> ForwardResults:
"""The unified entry for a forward process in both training and test.
@ -54,9 +54,9 @@ class MMDetWrapper(BaseModel):
optimizer updating, which are done in the :meth:`train_step`.
Args:
batch_inputs (torch.Tensor): The input tensor with shape
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
batch_data_samples (list[:obj:`DetDataSample`], optional): The
data_samples (list[:obj:`DetDataSample`], optional): The
annotation data of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
@ -67,8 +67,8 @@ class MMDetWrapper(BaseModel):
- If ``mode="predict"``, return a list of :obj:`TextDetDataSample`.
- If ``mode="loss"``, return a dict of tensor.
"""
results = self.wrapped_model.forward(batch_inputs, batch_data_samples,
mode, **kwargs)
results = self.wrapped_model.forward(inputs, data_samples, mode,
**kwargs)
if mode == 'predict':
results = self.adapt_predictions(results)

View File

@ -44,47 +44,46 @@ class SingleStageTextDetector(BaseTextDetector):
self.neck = MODELS.build(neck)
self.det_head = MODELS.build(det_head)
def extract_feat(self, batch_inputs: torch.Tensor) -> torch.Tensor:
def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
"""Extract features.
Args:
batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).
inputs (Tensor): Image tensor with shape (N, C, H ,W).
Returns:
Tensor or tuple[Tensor]: Multi-level features that may have
different resolutions.
"""
batch_inputs = self.backbone(batch_inputs)
inputs = self.backbone(inputs)
if self.with_neck:
batch_inputs = self.neck(batch_inputs)
return batch_inputs
inputs = self.neck(inputs)
return inputs
def loss(self, batch_inputs: torch.Tensor,
batch_data_samples: Sequence[TextDetDataSample]) -> Dict:
def loss(self, inputs: torch.Tensor,
data_samples: Sequence[TextDetDataSample]) -> Dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
inputs (torch.Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
batch_data_samples (list[TextDetDataSample]): A list of N
data_samples (list[TextDetDataSample]): A list of N
datasamples, containing meta information and gold annotations
for each of the images.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
batch_inputs = self.extract_feat(batch_inputs)
return self.det_head.loss(batch_inputs, batch_data_samples)
inputs = self.extract_feat(inputs)
return self.det_head.loss(inputs, data_samples)
def predict(
self, batch_inputs: torch.Tensor,
batch_data_samples: Sequence[TextDetDataSample]
) -> Sequence[TextDetDataSample]:
def predict(self, inputs: torch.Tensor,
data_samples: Sequence[TextDetDataSample]
) -> Sequence[TextDetDataSample]:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs (torch.Tensor): Images of shape (N, C, H, W).
batch_data_samples (list[TextDetDataSample]): A list of N
inputs (torch.Tensor): Images of shape (N, C, H, W).
data_samples (list[TextDetDataSample]): A list of N
datasamples, containing meta information and gold annotations
for each of the images.
@ -104,20 +103,19 @@ class SingleStageTextDetector(BaseTextDetector):
Each element represents the polygon of the
instance, in (xn, yn) order.
"""
x = self.extract_feat(batch_inputs)
return self.det_head.predict(x, batch_data_samples)
x = self.extract_feat(inputs)
return self.det_head.predict(x, data_samples)
def _forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: Optional[
Sequence[TextDetDataSample]] = None,
inputs: torch.Tensor,
data_samples: Optional[Sequence[TextDetDataSample]] = None,
**kwargs) -> torch.Tensor:
"""Network forward process. Usually includes backbone, neck and head
forward without any post-processing.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
batch_data_samples (list[TextDetDataSample]): A list of N
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (list[TextDetDataSample]): A list of N
datasamples, containing meta information and gold annotations
for each of the images.
@ -125,5 +123,5 @@ class SingleStageTextDetector(BaseTextDetector):
Tensor or tuple[Tensor]: A tuple of features from ``det_head``
forward.
"""
x = self.extract_feat(batch_inputs)
return self.det_head(x, batch_data_samples)
x = self.extract_feat(inputs)
return self.det_head(x, data_samples)

View File

@ -63,34 +63,32 @@ class BaseTextDetHead(BaseModule):
self.module_loss = MODELS.build(module_loss)
self.postprocessor = MODELS.build(postprocessor)
def loss(self, x: Tuple[Tensor], data_samples: DetSampleList) -> dict:
    """Run the head on upstream features and compute the training loss.

    Args:
        x (tuple[Tensor]): Features from the upstream network.
        data_samples (DetSampleList): Data samples of the batch. They are
            passed to both the head forward and ``module_loss``.

    Returns:
        dict: Whatever ``self.module_loss`` returns for the head outputs.
    """
    head_outputs = self(x, data_samples)
    return self.module_loss(head_outputs, data_samples)
def loss_and_predict(self, x: Tuple[Tensor],
batch_data_samples: DetSampleList
def loss_and_predict(self, x: Tuple[Tensor], data_samples: DetSampleList
) -> Tuple[dict, DetSampleList]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples.
Args:
x (tuple[Tensor]): Features from FPN.
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
data_samples (list[:obj:`DetDataSample`]): Each item contains
the meta information of each image and corresponding
annotations.
@ -101,21 +99,21 @@ class BaseTextDetHead(BaseModule):
- predictions (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
"""
outs = self(x, batch_data_samples)
losses = self.module_loss(outs, batch_data_samples)
outs = self(x, data_samples)
losses = self.module_loss(outs, data_samples)
predictions = self.postprocessor(outs, batch_data_samples)
predictions = self.postprocessor(outs, data_samples)
return losses, predictions
def predict(self, x: torch.Tensor,
batch_data_samples: DetSampleList) -> DetSampleList:
data_samples: DetSampleList) -> DetSampleList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
@ -123,7 +121,7 @@ class BaseTextDetHead(BaseModule):
SampleList: Detection results of each image
after the post process.
"""
outs = self(x, batch_data_samples)
outs = self(x, data_samples)
predictions = self.postprocessor(outs, batch_data_samples)
predictions = self.postprocessor(outs, data_samples)
return predictions

View File

@ -261,15 +261,13 @@ class DRRGHead(BaseTextDetHead):
self.in_channels + self.out_channels) + self.node_geo_feat_len
self.gcn = GCN(node_feat_len)
def loss(
self, batch_inputs: torch.Tensor,
batch_data_samples: List[TextDetDataSample]
) -> Tuple[Tensor, Tensor, Tensor]:
def loss(self, inputs: torch.Tensor, data_samples: List[TextDetDataSample]
) -> Tuple[Tensor, Tensor, Tensor]:
"""Loss function.
Args:
batch_inputs (Tensor): Shape of :math:`(N, C, H, W)`.
batch_data_samples (List[TextDetDataSample]): List of data samples.
inputs (Tensor): Shape of :math:`(N, C, H, W)`.
data_samples (List[TextDetDataSample]): List of data samples.
Returns:
tuple(pred_maps, gcn_pred, gt_labels):
@ -281,27 +279,26 @@ class DRRGHead(BaseTextDetHead):
- gt_labels (Tensor): Ground-truth label of shape
:math:`(m, n)` where :math:`m * n = N`.
"""
targets = self.module_loss.get_targets(batch_data_samples)
targets = self.module_loss.get_targets(data_samples)
gt_comp_attribs = targets[-1]
pred_maps = self.out_conv(batch_inputs)
feat_maps = torch.cat([batch_inputs, pred_maps], dim=1)
pred_maps = self.out_conv(inputs)
feat_maps = torch.cat([inputs, pred_maps], dim=1)
node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
feat_maps, np.stack(gt_comp_attribs))
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
return self.module_loss((pred_maps, gcn_pred, gt_labels),
batch_data_samples)
return self.module_loss((pred_maps, gcn_pred, gt_labels), data_samples)
def forward(
self,
batch_inputs: Tensor,
inputs: Tensor,
data_samples: Optional[List[TextDetDataSample]] = None
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Run DRRG head in prediction mode, and return the raw tensors only.
Args:
batch_inputs (Tensor): Shape of :math:`(1, C, H, W)`.
inputs (Tensor): Shape of :math:`(1, C, H, W)`.
data_samples (list[TextDetDataSample], optional): A list of data
samples. Defaults to None.
@ -317,10 +314,10 @@ class DRRGHead(BaseTextDetHead):
:math:`(M, 9)` where each row corresponds to one box and
its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
"""
pred_maps = self.out_conv(batch_inputs)
batch_inputs = torch.cat([batch_inputs, pred_maps], dim=1)
pred_maps = self.out_conv(inputs)
inputs = torch.cat([inputs, pred_maps], dim=1)
none_flag, graph_data = self.graph_test(pred_maps, batch_inputs)
none_flag, graph_data = self.graph_test(pred_maps, inputs)
(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
pivot_local_graphs, text_comps) = graph_data

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Union
import torch
import torch.nn as nn
from mmengine.model import ImgDataPreprocessor
@ -25,7 +24,7 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor):
- Pad inputs to the maximum size of current batch with defined
``pad_value``. The padding size can be divisible by a defined
``pad_size_divisor``
- Stack inputs to batch_inputs.
- Stack inputs to inputs.
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
- Normalize image with defined std and mean.
- Do batch augmentations during training.
@ -52,7 +51,7 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor):
pad_value: Union[float, int] = 0,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
batch_augments: Optional[List[dict]] = None):
batch_augments: Optional[List[Dict]] = None) -> None:
super().__init__(
mean=mean,
std=std,
@ -66,37 +65,33 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor):
else:
self.batch_augments = None
def forward(self, data: Dict, training: bool = False) -> Dict:
    """Perform normalization, padding and bgr2rgb conversion based on
    ``BaseDataPreprocessor``.

    The heavy lifting is delegated to the parent ``forward``. This method
    then rescales each sample's ``valid_ratio`` to the batch shape and, in
    training mode, applies the configured batch augmentations.

    Args:
        data (dict): Data sampled from dataloader.
        training (bool): Whether to enable training time augmentation.

    Returns:
        dict: Data in the same format as the model input, holding the
        (possibly augmented) ``inputs`` and ``data_samples``.
    """
    data = super().forward(data=data, training=training)
    inputs, data_samples = data['inputs'], data['data_samples']

    if data_samples is not None:
        # Last two dims of the first batched input.
        batch_input_shape = tuple(inputs[0].size()[-2:])
        for data_sample in data_samples:
            # Re-express valid_ratio relative to the batch shape
            # (index 1 is presumably the width -- confirm against
            # the packing transform).
            valid_ratio = data_sample.valid_ratio * \
                data_sample.img_shape[1] / batch_input_shape[1]
            data_sample.set_metainfo(
                dict(
                    valid_ratio=valid_ratio,
                    batch_input_shape=batch_input_shape))

    if training and self.batch_augments is not None:
        for batch_aug in self.batch_augments:
            inputs, data_samples = batch_aug(inputs, data_samples)
        # The loop above only rebinds local names; write the results back
        # so the augmented batch is what callers actually receive.
        data['inputs'] = inputs
        data['data_samples'] = data_samples

    return data

View File

@ -52,8 +52,8 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
pass
def forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: OptRecSampleList = None,
inputs: torch.Tensor,
data_samples: OptRecSampleList = None,
mode: str = 'tensor',
**kwargs) -> RecForwardResults:
"""The unified entry for a forward process in both training and test.
@ -71,9 +71,9 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
optimizer updating, which are done in the :meth:`train_step`.
Args:
batch_inputs (torch.Tensor): The input tensor with shape
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
batch_data_samples (list[:obj:`DetDataSample`], optional): The
data_samples (list[:obj:`DetDataSample`], optional): The
annotation data of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
@ -85,33 +85,32 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'loss':
return self.loss(batch_inputs, batch_data_samples, **kwargs)
return self.loss(inputs, data_samples, **kwargs)
elif mode == 'predict':
return self.predict(batch_inputs, batch_data_samples, **kwargs)
return self.predict(inputs, data_samples, **kwargs)
elif mode == 'tensor':
return self._forward(batch_inputs, batch_data_samples, **kwargs)
return self._forward(inputs, data_samples, **kwargs)
else:
raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode')
@abstractmethod
def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
         **kwargs) -> Union[dict, tuple]:
    """Calculate losses from a batch of inputs and data samples.

    Abstract hook: subclasses provide the implementation.
    """
@abstractmethod
def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
            **kwargs) -> RecSampleList:
    """Predict results from a batch of inputs and data samples with
    post-processing.

    Abstract hook: subclasses provide the implementation.
    """
@abstractmethod
def _forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: OptRecSampleList = None,
inputs: torch.Tensor,
data_samples: OptRecSampleList = None,
**kwargs):
"""Network forward process.

View File

@ -59,16 +59,16 @@ class EncoderDecoderRecognizer(BaseRecognizer):
assert decoder is not None
self.decoder = MODELS.build(decoder)
def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
    """Pass ``inputs`` through the optional preprocessor and backbone.

    A stage is applied only when its ``with_*`` flag is truthy; with both
    flags off the input is returned unchanged.
    """
    feat = inputs
    stages = (('with_preprocessor', 'preprocessor'),
              ('with_backbone', 'backbone'))
    for flag_name, module_name in stages:
        # Flags are read lazily, in order, and a disabled module is
        # never looked up.
        if getattr(self, flag_name):
            feat = getattr(self, module_name)(feat)
    return feat
def loss(self, batch_inputs: torch.Tensor,
batch_data_samples: RecSampleList, **kwargs) -> Dict:
def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
**kwargs) -> Dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (tensor): Input images of shape (N, C, H, W).
@ -80,14 +80,14 @@ class EncoderDecoderRecognizer(BaseRecognizer):
Returns:
dict[str, tensor]: A dictionary of loss components.
"""
feat = self.extract_feat(batch_inputs)
feat = self.extract_feat(inputs)
out_enc = None
if self.with_encoder:
out_enc = self.encoder(feat, batch_data_samples)
return self.decoder.loss(feat, out_enc, batch_data_samples)
out_enc = self.encoder(feat, data_samples)
return self.decoder.loss(feat, out_enc, data_samples)
def predict(self, batch_inputs: torch.Tensor,
batch_data_samples: RecSampleList, **kwargs) -> RecSampleList:
def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
**kwargs) -> RecSampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
@ -101,30 +101,30 @@ class EncoderDecoderRecognizer(BaseRecognizer):
list[TextRecogDataSample]: A list of N datasamples of prediction
results. Results are stored in ``pred_text``.
"""
feat = self.extract_feat(batch_inputs)
feat = self.extract_feat(inputs)
out_enc = None
if self.with_encoder:
out_enc = self.encoder(feat, batch_data_samples)
return self.decoder.predict(feat, out_enc, batch_data_samples)
out_enc = self.encoder(feat, data_samples)
return self.decoder.predict(feat, out_enc, data_samples)
def _forward(self,
             inputs: torch.Tensor,
             data_samples: OptRecSampleList = None,
             **kwargs) -> RecForwardResults:
    """Network forward process without any post-processing.

    Runs feature extraction, the encoder (only when ``with_encoder`` is
    truthy) and the decoder.

    Args:
        inputs (Tensor): Inputs with shape (N, C, H, W).
        data_samples (list[TextRecogDataSample], optional): A list of N
            datasamples, forwarded to the encoder and decoder. Defaults
            to None.

    Returns:
        The raw output of ``self.decoder``.
    """
    feat = self.extract_feat(inputs)
    out_enc = self.encoder(feat, data_samples) if self.with_encoder else None
    return self.decoder(feat, out_enc, data_samples)

View File

@ -37,9 +37,9 @@ class TestPackTextDetInputs(TestCase):
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
self.assertIn('data_sample', results)
self.assertIn('data_samples', results)
data_sample = results['data_sample']
data_sample = results['data_samples']
self.assertIn('bboxes', data_sample.gt_instances)
self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
self.assertEqual(data_sample.gt_instances.bboxes.dtype, torch.float32)
@ -56,9 +56,9 @@ class TestPackTextDetInputs(TestCase):
transform = PackTextDetInputs(meta_keys=('img_path', ))
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
self.assertIn('data_sample', results)
self.assertIn('data_samples', results)
data_sample = results['data_sample']
data_sample = results['data_samples']
self.assertIn('bboxes', data_sample.gt_instances)
self.assertIn('img_path', data_sample)
self.assertNotIn('flip', data_sample)
@ -66,14 +66,14 @@ class TestPackTextDetInputs(TestCase):
datainfo.pop('gt_texts')
transform = PackTextDetInputs()
results = transform(copy.deepcopy(datainfo))
data_sample = results['data_sample']
data_sample = results['data_samples']
self.assertNotIn('texts', data_sample.gt_instances)
datainfo = dict(img_shape=(10, 10))
transform = PackTextDetInputs(meta_keys=('img_shape', ))
results = transform(copy.deepcopy(datainfo))
self.assertNotIn('inputs', results)
data_sample = results['data_sample']
data_sample = results['data_samples']
self.assertNotIn('texts', data_sample.gt_instances)
def test_repr(self):
@ -108,8 +108,8 @@ class TestPackTextRecogInputs(TestCase):
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
self.assertIn('data_sample', results)
data_sample = results['data_sample']
self.assertIn('data_samples', results)
data_sample = results['data_samples']
self.assertEqual(data_sample.gt_text.item, 'mmocr')
self.assertIn('img_path', data_sample)
self.assertIn('valid_ratio', data_sample)
@ -118,8 +118,8 @@ class TestPackTextRecogInputs(TestCase):
transform = PackTextRecogInputs(meta_keys=('img_path', ))
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
self.assertIn('data_sample', results)
data_sample = results['data_sample']
self.assertIn('data_samples', results)
data_sample = results['data_samples']
self.assertEqual(data_sample.gt_text.item, 'mmocr')
self.assertIn('img_path', data_sample)
self.assertNotIn('valid_ratio', data_sample)
@ -129,7 +129,7 @@ class TestPackTextRecogInputs(TestCase):
transform = PackTextRecogInputs(meta_keys=('img_shape', ))
results = transform(copy.deepcopy(datainfo))
self.assertNotIn('inputs', results)
data_sample = results['data_sample']
data_sample = results['data_samples']
self.assertNotIn('item', data_sample.gt_text)
def test_repr(self):
@ -165,8 +165,8 @@ class TestPackKIEInputs(TestCase):
results = self.transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
self.assertIn('data_sample', results)
data_sample = results['data_sample']
self.assertIn('data_samples', results)
data_sample = results['data_samples']
self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
self.assertEqual(data_sample.gt_instances.bboxes.dtype, torch.float32)
self.assertEqual(data_sample.gt_instances.labels.dtype, torch.int64)
@ -179,9 +179,9 @@ class TestPackKIEInputs(TestCase):
transform = PackKIEInputs(meta_keys=('img_path', ))
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
self.assertIn('data_sample', results)
self.assertIn('data_samples', results)
data_sample = results['data_sample']
data_sample = results['data_samples']
self.assertIn('bboxes', data_sample.gt_instances)
self.assertIn('img_path', data_sample)

View File

@ -200,9 +200,3 @@ class TestResize(unittest.TestCase):
resize = Resize(scale=(40, 30))
result = resize(dummy_result)
self.assertEqual(result['gt_bboxes'].dtype, np.float32)
# Ad-hoc debugging entry point: instantiates TestRandomCrop and calls two of
# its test methods directly, bypassing any test runner (no setUp/tearDown is
# invoked here).
if __name__ == '__main__':
    t = TestRandomCrop()
    t.test_sample_crop_box()
    t.test_transform()

View File

@ -1,274 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import shutil
import tempfile
from unittest import TestCase
import torch
import torch.nn as nn
from mmengine.config import Config
from mmengine.evaluator import BaseMetric
from mmengine.hooks import Hook
from mmengine.model import BaseModel
from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS
from mmengine.runner import Runner
from torch.utils.data import Dataset
from mmocr.engine.runner import MultiTestLoop, MultiValLoop
@MODELS.register_module()
class ToyModel(BaseModel):
    """Two-layer linear toy model registered in ``MODELS`` for loop tests."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, batch_inputs, labels, mode='tensor'):
        """Run both linear layers and shape the result according to ``mode``.

        ``'tensor'`` and ``'predict'`` return the raw outputs, ``'loss'``
        returns ``dict(loss=...)``; any other mode yields ``None``.
        """
        stacked_labels = torch.stack(labels)
        preds = self.linear2(self.linear1(batch_inputs))
        if mode == 'loss':
            return dict(loss=(stacked_labels - preds).sum())
        if mode == 'tensor' or mode == 'predict':
            return preds
        return None
@DATASETS.register_module()
class ToyDataset(Dataset):
    """Fixed 12-item in-memory dataset registered in ``DATASETS``."""

    METAINFO = {}  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Expose the class-level ``METAINFO`` mapping."""
        return self.METAINFO

    def __len__(self):
        """Number of rows in ``data``."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return one item as ``inputs`` / ``data_sample``."""
        return {
            'inputs': self.data[index],
            'data_sample': self.label[index],
        }
@METRICS.register_module()
class ToyMetric3(BaseMetric):
    """Constant metric registered in ``METRICS``: always reports ``acc`` 1."""

    def __init__(self, collect_device='cpu', prefix=''):
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_samples, predictions):
        """Record one constant result per call; both arguments are unused."""
        self.results.append({'acc': 1})

    def compute_metrics(self, results):
        """Return the constant accuracy regardless of ``results``."""
        return {'acc': 1}
class TestRunner(TestCase):
    """Tests for ``MultiValLoop`` and ``MultiTestLoop`` driven by a Runner."""

    def setUp(self):
        """Create a temp work dir and the epoch-/iter-based runner configs."""
        self.temp_dir = tempfile.mkdtemp()
        # NOTE(review): 'ToyMetric1' below is not defined in the visible part
        # of this module; both tests replace the evaluator they exercise with
        # 'ToyMetric3' -- confirm 'ToyMetric1' is registered elsewhere.
        epoch_based_cfg = dict(
            default_scope='mmocr',
            model=dict(type='ToyModel'),
            work_dir=self.temp_dir,
            train_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            auto_scale_lr=dict(base_batch_size=16, enable=False),
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            val_evaluator=dict(type='ToyMetric1'),
            test_evaluator=dict(type='ToyMetric1'),
            train_cfg=dict(
                by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
            val_cfg=dict(),
            test_cfg=dict(),
            custom_hooks=[],
            default_hooks=dict(
                runtime_info=dict(type='RuntimeInfoHook'),
                timer=dict(type='IterTimerHook'),
                logger=dict(type='LoggerHook'),
                param_scheduler=dict(type='ParamSchedulerHook'),
                checkpoint=dict(
                    type='CheckpointHook', interval=1, by_epoch=True),
                sampler_seed=dict(type='DistSamplerSeedHook')),
            launcher='none',
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
        )
        self.epoch_based_cfg = Config(epoch_based_cfg)
        # The iter-based config is a deep copy with an infinite sampler,
        # an iteration budget and an iteration-based checkpoint hook.
        self.iter_based_cfg = copy.deepcopy(self.epoch_based_cfg)
        self.iter_based_cfg.train_dataloader = dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='InfiniteSampler', shuffle=True),
            batch_size=3,
            num_workers=0)
        self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
        self.iter_based_cfg.default_hooks = dict(
            runtime_info=dict(type='RuntimeInfoHook'),
            timer=dict(type='IterTimerHook'),
            logger=dict(type='LoggerHook'),
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
            sampler_seed=dict(type='DistSamplerSeedHook'))

    def tearDown(self):
        """Remove the temp work dir created in ``setUp``."""
        shutil.rmtree(self.temp_dir)

    def test_multi_val_loop(self):
        """Validate over two dataloaders with ``MultiValLoop``."""
        before_val_iter_results = []
        after_val_iter_results = []
        multi_metrics = dict()

        @HOOKS.register_module()
        class Fake_1(Hook):
            """Record val-iteration hook calls and the final val metrics."""

            def before_val_iter(self, runner, batch_idx, data_batch=None):
                before_val_iter_results.append('before')

            def after_val_iter(self,
                               runner,
                               batch_idx,
                               data_batch=None,
                               outputs=None):
                after_val_iter_results.append('after')

            def after_val_epoch(self, runner, metrics=None) -> None:
                multi_metrics.update(metrics)

        self.iter_based_cfg.val_cfg = dict(type='MultiValLoop')
        self.iter_based_cfg.val_dataloader = [
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0)
        ]
        # One evaluator per dataloader, told apart by their prefixes.
        self.iter_based_cfg.val_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp2')
        ]
        self.iter_based_cfg.custom_hooks = [dict(type='Fake_1', priority=50)]
        self.iter_based_cfg.experiment_name = 'test_multi_val_loop'
        runner = Runner.from_cfg(self.iter_based_cfg)
        runner.val()

        self.assertIsInstance(runner.val_loop, MultiValLoop)

        # test custom hook triggered as expected
        # (8 presumably = 2 dataloaders x 4 batches of 3 from a 12-item
        # ToyDataset -- confirm against the dataset definition)
        self.assertEqual(len(before_val_iter_results), 8)
        self.assertEqual(len(after_val_iter_results), 8)
        for before, after in zip(before_val_iter_results,
                                 after_val_iter_results):
            self.assertEqual(before, 'before')
            self.assertEqual(after, 'after')
        self.assertDictEqual(multi_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1})

        # test_same prefix
        # Two evaluators sharing one prefix must be rejected at val time.
        self.iter_based_cfg.val_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp1')
        ]
        self.iter_based_cfg.experiment_name = 'test_multi_val_loop_same_prefix'
        runner = Runner.from_cfg(self.iter_based_cfg)
        with self.assertRaisesRegex(ValueError,
                                    ('Please set different'
                                     ' prefix for different datasets'
                                     ' in `val_evaluator`')):
            runner.val()

    def test_multi_test_loop(self):
        """Test over two dataloaders with ``MultiTestLoop``."""
        before_test_iter_results = []
        after_test_iter_results = []
        multi_metrics = dict()

        @HOOKS.register_module()
        class Fake_2(Hook):
            """Record test-iteration hook calls and the final test metrics."""

            def before_test_iter(self, runner, batch_idx, data_batch=None):
                before_test_iter_results.append('before')

            def after_test_iter(self,
                                runner,
                                batch_idx,
                                data_batch=None,
                                outputs=None):
                after_test_iter_results.append('after')

            def after_test_epoch(self, runner, metrics=None) -> None:
                multi_metrics.update(metrics)

        self.iter_based_cfg.test_cfg = dict(type='MultiTestLoop')
        self.iter_based_cfg.test_dataloader = [
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0)
        ]
        # One evaluator per dataloader, told apart by their prefixes.
        self.iter_based_cfg.test_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp2')
        ]
        self.iter_based_cfg.custom_hooks = [dict(type='Fake_2', priority=50)]
        self.iter_based_cfg.experiment_name = 'multi_test_loop'
        runner = Runner.from_cfg(self.iter_based_cfg)
        runner.test()

        self.assertIsInstance(runner.test_loop, MultiTestLoop)

        # test custom hook triggered as expected
        # (8 presumably = 2 dataloaders x 4 batches of 3 from a 12-item
        # ToyDataset -- confirm against the dataset definition)
        self.assertEqual(len(before_test_iter_results), 8)
        self.assertEqual(len(after_test_iter_results), 8)
        for before, after in zip(before_test_iter_results,
                                 after_test_iter_results):
            self.assertEqual(before, 'before')
            self.assertEqual(after, 'after')
        self.assertDictEqual(multi_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1})

        # test_same prefix
        # Two evaluators sharing one prefix must be rejected at test time.
        self.iter_based_cfg.test_evaluator = [
            dict(type='ToyMetric3', prefix='tmp1'),
            dict(type='ToyMetric3', prefix='tmp1')
        ]
        self.iter_based_cfg.experiment_name = 'multi_test_loop_same_prefix'
        runner = Runner.from_cfg(self.iter_based_cfg)
        with self.assertRaisesRegex(ValueError,
                                    ('Please set different'
                                     ' prefix for different datasets'
                                     ' in `test_evaluator`')):
            runner.test()

View File

@ -40,9 +40,9 @@ class ToyMetric(BaseMetric):
def process(self, data_batch, predictions):
    """Collect the ``pred`` and ``label`` entry of every prediction.

    ``data_batch`` is accepted for interface compatibility but not read;
    both values are taken from the prediction itself.
    """
    collected = []
    for prediction in predictions:
        collected.append(
            dict(pred=prediction['pred'], label=prediction['label']))
    self.results.extend(collected)
def compute_metrics(self, results: List):
@ -67,12 +67,13 @@ def generate_test_results(size, batch_size, pred, label):
bs_residual = size % batch_size
for i in range(num_batch):
bs = bs_residual if i == num_batch - 1 else batch_size
data_batch = [
dict(
inputs=np.zeros((3, 10, 10)),
data_sample=BaseDataElement(label=label)) for _ in range(bs)
data_batch = {
'inputs': [np.zeros((3, 10, 10)) for _ in range(bs)],
'data_samples': [BaseDataElement(label=label) for _ in range(bs)]
}
predictions = [
BaseDataElement(pred=pred, label=label) for _ in range(bs)
]
predictions = [BaseDataElement(pred=pred) for _ in range(bs)]
yield (data_batch, predictions)
@ -92,7 +93,7 @@ class TestMultiDatasetsEvaluator(TestCase):
for data_samples, predictions in generate_test_results(
size, batch_size, pred=1, label=1):
evaluator.process(data_samples, predictions)
evaluator.process(predictions, data_samples)
metrics = evaluator.evaluate(size=size)
@ -109,7 +110,7 @@ class TestMultiDatasetsEvaluator(TestCase):
for data_samples, predictions in generate_test_results(
size, batch_size, pred=1, label=1):
evaluator.process(data_samples, predictions)
evaluator.process(predictions, data_samples)
with self.assertRaises(ValueError):
evaluator.evaluate(size=size)
@ -120,7 +121,7 @@ class TestMultiDatasetsEvaluator(TestCase):
for data_samples, predictions in generate_test_results(
size, batch_size, pred=1, label=1):
evaluator.process(data_samples, predictions)
evaluator.process(predictions, data_samples)
metrics = evaluator.evaluate(size=size)
self.assertIn('Fake/Toy/accuracy', metrics)
self.assertIn('Fake/accuracy', metrics)

View File

@ -46,8 +46,9 @@ class TestSDMGR(unittest.TestCase):
return model
def forward_wrapper(self, model, data, mode):
    """Preprocess ``data`` in non-training mode and run ``model.forward``.

    The preprocessor output is expected to be a mapping with ``inputs``
    and ``data_samples`` entries.
    """
    processed = model.data_preprocessor(data, False)
    batch_inputs = processed['inputs']
    batch_samples = processed['data_samples']
    return model.forward(batch_inputs, batch_samples, mode)
def setUp(self):
@ -66,14 +67,11 @@ class TestSDMGR(unittest.TestCase):
edge_labels=torch.LongTensor([[0, 1], [1, 0]]),
texts=['text1', 'text2'],
relations=torch.rand((2, 2, 5)))
self.visual_data = [
dict(inputs=torch.rand((3, 10, 10)), data_sample=data_sample)
]
self.novisual_data = [
dict(
inputs=torch.Tensor([]).reshape((0, 0, 0)),
data_sample=data_sample)
]
self.visual_data = dict(
inputs=[torch.rand((3, 10, 10))], data_samples=[data_sample])
self.novisual_data = dict(
inputs=[torch.Tensor([]).reshape((0, 0, 0))],
data_samples=[data_sample])
def test_forward_loss(self):
result = self.forward_wrapper(

View File

@ -11,8 +11,8 @@ from mmocr.structures import TextDetDataSample
@MODELS.register_module()
class TDAugment(torch.nn.Module):
    """Identity batch augment registered in ``MODELS`` for the tests."""

    def forward(self, inputs, data_samples):
        """Hand both arguments back unchanged, as a 2-tuple."""
        passthrough = (inputs, data_samples)
        return passthrough
class TestTextDetDataPreprocessor(TestCase):
@ -47,58 +47,63 @@ class TestTextDetDataPreprocessor(TestCase):
def test_forward(self):
processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
data = [{
'inputs':
torch.randint(0, 256, (3, 11, 10)),
'data_sample':
TextDetDataSample(
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0))
}]
inputs, data_samples = processor(data)
print(inputs.dtype)
data = {
'inputs': [
torch.randint(0, 256, (3, 11, 10)),
],
'data_samples': [
TextDetDataSample(
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0)),
]
}
out = processor(data)
inputs, data_samples = out['inputs'], out['data_samples']
self.assertEqual(inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(data_samples), 1)
# test channel_conversion
processor = TextDetDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
inputs, data_samples = processor(data)
out = processor(data)
inputs, data_samples = out['inputs'], out['data_samples']
self.assertEqual(inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(data_samples), 1)
# test padding
data = [{
'inputs': torch.randint(0, 256, (3, 10, 11))
}, {
'inputs': torch.randint(0, 256, (3, 9, 14))
}]
data = {
'inputs': [
torch.randint(0, 256, (3, 10, 11)),
torch.randint(0, 256, (3, 9, 14))
]
}
processor = TextDetDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
inputs, data_samples = processor(data)
out = processor(data)
inputs, data_samples = out['inputs'], out['data_samples']
self.assertEqual(inputs.shape, (2, 3, 10, 14))
self.assertIsNone(data_samples)
# test pad_size_divisor
data = [{
'inputs':
torch.randint(0, 256, (3, 10, 11)),
'data_sample':
TextDetDataSample(
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0))
}, {
'inputs':
torch.randint(0, 256, (3, 9, 24)),
'data_sample':
TextDetDataSample(
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
}]
data = {
'inputs': [
torch.randint(0, 256, (3, 10, 11)),
torch.randint(0, 256, (3, 9, 24))
],
'data_samples': [
TextDetDataSample(
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0)),
TextDetDataSample(
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
]
}
aug_cfg = [dict(type='TDAugment')]
processor = TextDetDataPreprocessor(
mean=[0., 0., 0.],
std=[1., 1., 1.],
pad_size_divisor=5,
batch_augments=aug_cfg)
inputs, data_samples = processor(data, training=True)
out = processor(data)
inputs, data_samples = out['inputs'], out['data_samples']
self.assertEqual(inputs.shape, (2, 3, 10, 25))
self.assertEqual(len(data_samples), 2)
for data_sample, expected_shape in zip(data_samples, [(10, 25),

View File

@ -262,9 +262,3 @@ class TestMMDetWrapper(unittest.TestCase):
self.assertEqual(len(results), 1)
self.assertIsInstance(results[0], TextDetDataSample)
self.assertTrue('polygons' in results[0].pred_instances.keys())
# Ad-hoc debugging entry point: builds a TestMMDetWrapper, calls its setUp
# by hand and runs a single test method directly, bypassing any test runner.
if __name__ == '__main__':
    test = TestMMDetWrapper()
    test.setUp()
    test.test_mask_two_stage_wrapper()

View File

@ -11,8 +11,8 @@ from mmocr.structures import TextRecogDataSample
@MODELS.register_module()
class Augment(torch.nn.Module):
    """Identity batch augment registered in ``MODELS`` for the tests."""

    def forward(self, inputs, data_samples):
        """Hand both arguments back unchanged, as a 2-tuple."""
        passthrough = (inputs, data_samples)
        return passthrough
class TestTextRecogDataPreprocessor(TestCase):
@ -47,14 +47,17 @@ class TestTextRecogDataPreprocessor(TestCase):
def test_forward(self):
processor = TextRecogDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
data = [{
'inputs':
torch.randint(0, 256, (3, 11, 10)),
'data_sample':
TextRecogDataSample(
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0))
}]
inputs, data_samples = processor(data)
data = {
'inputs': [
torch.randint(0, 256, (3, 11, 10)),
],
'data_samples': [
TextRecogDataSample(
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0)),
]
}
out = processor(data)
inputs, data_samples = out['inputs'], out['data_samples']
print(inputs.dtype)
self.assertEqual(inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(data_samples), 1)
@ -62,43 +65,46 @@ class TestTextRecogDataPreprocessor(TestCase):
# test channel_conversion
processor = TextRecogDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
inputs, data_samples = processor(data)
out = processor(data)
inputs, data_samples = out['inputs'], out['data_samples']
self.assertEqual(inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(data_samples), 1)
# test padding
data = [{
'inputs': torch.randint(0, 256, (3, 10, 11))
}, {
'inputs': torch.randint(0, 256, (3, 9, 14))
}]
data = {
'inputs': [
torch.randint(0, 256, (3, 10, 11)),
torch.randint(0, 256, (3, 9, 14))
]
}
processor = TextRecogDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
inputs, data_samples = processor(data)
out = processor(data)
inputs, data_samples = out['inputs'], out['data_samples']
self.assertEqual(inputs.shape, (2, 3, 10, 14))
self.assertIsNone(data_samples)
# test pad_size_divisor
data = [{
'inputs':
torch.randint(0, 256, (3, 10, 11)),
'data_sample':
TextRecogDataSample(
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0))
}, {
'inputs':
torch.randint(0, 256, (3, 9, 24)),
'data_sample':
TextRecogDataSample(
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
}]
data = {
'inputs': [
torch.randint(0, 256, (3, 10, 11)),
torch.randint(0, 256, (3, 9, 24)),
],
'data_samples': [
TextRecogDataSample(
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0)),
TextRecogDataSample(
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
]
}
aug_cfg = [dict(type='Augment')]
processor = TextRecogDataPreprocessor(
mean=[0., 0., 0.],
std=[1., 1., 1.],
pad_size_divisor=5,
batch_augments=aug_cfg)
inputs, data_samples = processor(data, training=True)
out = processor(data)
inputs, data_samples = out['inputs'], out['data_samples']
self.assertEqual(inputs.shape, (2, 3, 10, 25))
self.assertEqual(len(data_samples), 2)
for data_sample, expected_shape, expected_ratio in zip(