mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] Adapt to new dataflow (#1305)
* datasample->datasamples * update rec data preprocessor * rename datasamples * update det preprocessor * update metric * update data_sample->data_samples in test * update * fix data preprocessor uts * remove engine runner * fix kie ut * fix ut * fix comments * refactor evaluator ut * apply comments Co-authored-by: Tong Gao <gaotongxiao@gmail.com> * remove useless * apply comments Co-authored-by: Tong Gao <gaotongxiao@gmail.com> * apply comments Co-authored-by: Tong Gao <gaotongxiao@gmail.com> Co-authored-by: Tong Gao <gaotongxiao@gmail.com>pull/1324/head
parent
1b5764b155
commit
9bd5258513
|
@ -1,31 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from mmocr.apis import init_detector
|
||||
from mmocr.apis.inference import text_model_inference
|
||||
from mmocr.registry import DATASETS # NOQA
|
||||
|
||||
|
||||
def main():
    """Run text-model inference on one sentence typed at the prompt.

    The config path, checkpoint path and device come from the command line.
    The sentence is read from standard input, and every predicted entity is
    printed together with the slice of the sentence it refers to.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument('config', help='Config file.')
    arg_parser.add_argument('checkpoint', help='Checkpoint file.')
    arg_parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    cli_args = arg_parser.parse_args()

    # Build the model from the config file and the checkpoint file.
    model = init_detector(
        cli_args.config, cli_args.checkpoint, device=cli_args.device)

    # Run inference on a single sentence entered by the user.
    sentence = input('Please enter a sentence you want to test: ')
    predictions = text_model_inference(model, sentence)

    # Each entity is read by index: [0] is printed as the tag, and [1]/[2]
    # are used as a start/end pair with the end treated as inclusive.
    for entities in predictions:
        for entity in entities:
            tag = entity[0]
            start = entity[1]
            end = entity[2]
            print(f'{tag}: {sentence[start:end + 1]}')


if __name__ == '__main__':
    main()
|
|
@ -17,7 +17,7 @@ class PackTextDetInputs(BaseTransform):
|
|||
The type of outputs is `dict`:
|
||||
|
||||
- inputs: image converted to tensor, whose shape is (C, H, W).
|
||||
- data_sample: Two components of ``TextDetDataSample`` will be updated:
|
||||
- data_samples: Two components of ``TextDetDataSample`` will be updated:
|
||||
|
||||
- gt_instances (InstanceData): Depending on annotations, a subset of the
|
||||
following keys will be updated:
|
||||
|
@ -82,7 +82,7 @@ class PackTextDetInputs(BaseTransform):
|
|||
dict:
|
||||
|
||||
- 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
|
||||
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
|
||||
- 'data_samples' (obj:`DetDataSample`): The annotation info of the
|
||||
sample.
|
||||
"""
|
||||
packed_results = dict()
|
||||
|
@ -109,7 +109,7 @@ class PackTextDetInputs(BaseTransform):
|
|||
for key in self.meta_keys:
|
||||
img_meta[key] = results[key]
|
||||
data_sample.set_metainfo(img_meta)
|
||||
packed_results['data_sample'] = data_sample
|
||||
packed_results['data_samples'] = data_sample
|
||||
|
||||
return packed_results
|
||||
|
||||
|
@ -126,7 +126,7 @@ class PackTextRecogInputs(BaseTransform):
|
|||
The type of outputs is `dict`:
|
||||
|
||||
- inputs: Image as a tensor, whose shape is (C, H, W).
|
||||
- data_sample: Two components of ``TextRecogDataSample`` will be updated:
|
||||
- data_samples: Two components of ``TextRecogDataSample`` will be updated:
|
||||
|
||||
- gt_text (LabelData):
|
||||
|
||||
|
@ -166,7 +166,7 @@ class PackTextRecogInputs(BaseTransform):
|
|||
dict:
|
||||
|
||||
- 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
|
||||
- 'data_sample' (obj:`TextRecogDataSample`): The annotation info
|
||||
- 'data_samples' (obj:`TextRecogDataSample`): The annotation info
|
||||
of the sample.
|
||||
"""
|
||||
packed_results = dict()
|
||||
|
@ -195,7 +195,7 @@ class PackTextRecogInputs(BaseTransform):
|
|||
img_meta[key] = results[key]
|
||||
data_sample.set_metainfo(img_meta)
|
||||
|
||||
packed_results['data_sample'] = data_sample
|
||||
packed_results['data_samples'] = data_sample
|
||||
|
||||
return packed_results
|
||||
|
||||
|
@ -212,7 +212,7 @@ class PackKIEInputs(BaseTransform):
|
|||
The type of outputs is `dict`:
|
||||
|
||||
- inputs: image converted to tensor, whose shape is (C, H, W).
|
||||
- data_sample: Two components of ``TextDetDataSample`` will be updated:
|
||||
- data_samples: Two components of ``TextDetDataSample`` will be updated:
|
||||
|
||||
- gt_instances (InstanceData): Depending on annotations, a subset of the
|
||||
following keys will be updated:
|
||||
|
@ -266,7 +266,7 @@ class PackKIEInputs(BaseTransform):
|
|||
dict:
|
||||
|
||||
- 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
|
||||
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
|
||||
- 'data_samples' (obj:`DetDataSample`): The annotation info of the
|
||||
sample.
|
||||
"""
|
||||
packed_results = dict()
|
||||
|
@ -295,7 +295,7 @@ class PackKIEInputs(BaseTransform):
|
|||
for key in self.meta_keys:
|
||||
img_meta[key] = results[key]
|
||||
data_sample.set_metainfo(img_meta)
|
||||
packed_results['data_sample'] = data_sample
|
||||
packed_results['data_samples'] = data_sample
|
||||
|
||||
return packed_results
|
||||
|
||||
|
|
|
@ -1,3 +1,2 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hooks import * # NOQA
|
||||
from .runner import * # NOQA
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .multi_loops import MultiTestLoop, MultiValLoop
|
||||
|
||||
__all__ = ['MultiValLoop', 'MultiTestLoop']
|
|
@ -1,202 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import torch
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.runner.amp import autocast
|
||||
from mmengine.runner.base_loop import BaseLoop
|
||||
from mmengine.utils import is_list_of
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmocr.registry import LOOPS
|
||||
|
||||
|
||||
@LOOPS.register_module()
class MultiValLoop(BaseLoop):
    """Loop for validation on multiple datasets.

    The dataloader and the evaluator at the same list position form a pair.
    Every pair is evaluated in turn and the resulting metrics are merged
    into a single dict, so the metric names of different pairs must not
    collide.

    Args:
        runner (Runner): A reference of runner.
        dataloader (list[Dataloader or dict]): Dataloader objects, or dicts
            used to build them, one per dataset.
        evaluator (list[Evaluator or dict or list]): Evaluators, or the
            configs used to build them, one per dataloader.
        fp16 (bool): Whether to enable fp16 validation. Defaults to False.
    """

    def __init__(self,
                 runner,
                 dataloader: List[Union[DataLoader, Dict]],
                 evaluator: List[Union[Evaluator, Dict, List]],
                 fp16: bool = False) -> None:
        # NOTE(review): ``BaseLoop.__init__`` is not called; presumably
        # because it expects a single dataloader rather than a list -
        # confirm against the installed mmengine version.
        self._runner = runner

        assert isinstance(dataloader, list), (
            '`dataloader` must be a list, but got '
            f'{type(dataloader).__name__}')
        self.dataloaders = list()
        for loader in dataloader:
            if isinstance(loader, dict):
                self.dataloaders.append(
                    runner.build_dataloader(loader, seed=runner.seed))
            else:
                self.dataloaders.append(loader)

        assert isinstance(evaluator, list), (
            '`evaluator` must be a list, but got '
            f'{type(evaluator).__name__}')
        # Every entry goes through ``runner.build_evaluator`` exactly once.
        # The previous code built the list twice and kept only the second
        # result, so configs were instantiated redundantly.
        self.evaluators = [
            runner.build_evaluator(single_evaluator)
            for single_evaluator in evaluator
        ]

        assert len(self.evaluators) == len(self.dataloaders), (
            'The number of evaluators must match the number of dataloaders, '
            f'but got {len(self.evaluators)} and {len(self.dataloaders)}')

        self.fp16 = fp16

    def run(self):
        """Launch validation."""
        self.runner.call_hook('before_val')

        self.runner.model.eval()
        multi_metric = dict()
        self.runner.call_hook('before_val_epoch')
        for evaluator, dataloader in zip(self.evaluators, self.dataloaders):
            # ``run_iter`` reads the active pair from these attributes.
            self.evaluator = evaluator
            self.dataloader = dataloader
            if hasattr(self.dataloader.dataset, 'metainfo'):
                self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
                self.runner.visualizer.dataset_meta = \
                    self.dataloader.dataset.metainfo
            else:
                warnings.warn(
                    f'Dataset {self.dataloader.dataset.__class__.__name__} '
                    'has no metainfo. ``dataset_meta`` in evaluator, metric'
                    ' and visualizer will be None.')
            for idx, data_batch in enumerate(self.dataloader):
                self.run_iter(idx, data_batch)
            # compute metrics
            metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
            # A shared key would silently overwrite an earlier dataset's
            # result in the merged dict, so refuse to continue.
            if multi_metric and metrics.keys() & multi_metric.keys():
                raise ValueError('Please set different prefix for different'
                                 ' datasets in `val_evaluator`')
            multi_metric.update(metrics)
        self.runner.call_hook('after_val_epoch', metrics=multi_metric)
        self.runner.call_hook('after_val')

    @torch.no_grad()
    def run_iter(self, idx: int, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            idx (int): The index of the current batch in the loop.
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_val_iter', batch_idx=idx, data_batch=data_batch)
        # outputs should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.val_step(data_batch)
        self.evaluator.process(data_batch, outputs)
        self.runner.call_hook(
            'after_val_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
|
||||
|
||||
|
||||
@LOOPS.register_module()
class MultiTestLoop(BaseLoop):
    """Loop for testing on multiple datasets.

    The dataloader and the evaluator at the same list position form a pair.
    Every pair is evaluated in turn and the resulting metrics are merged
    into a single dict, so the metric names of different pairs must not
    collide.

    Args:
        runner (Runner): A reference of runner.
        dataloader (list[Dataloader or dict]): Dataloader objects, or dicts
            used to build them, one per dataset.
        evaluator (list[Evaluator or dict or list]): Evaluators, or the
            configs used to build them, one per dataloader.
        fp16 (bool): Whether to enable fp16 testing. Defaults to False.
    """

    def __init__(self,
                 runner,
                 dataloader: List[Union[DataLoader, Dict]],
                 evaluator: List[Union[Evaluator, Dict, List]],
                 fp16: bool = False) -> None:
        # NOTE(review): ``BaseLoop.__init__`` is not called; presumably
        # because it expects a single dataloader rather than a list -
        # confirm against the installed mmengine version.
        self._runner = runner

        assert isinstance(dataloader, list), (
            '`dataloader` must be a list, but got '
            f'{type(dataloader).__name__}')
        self.dataloaders = list()
        for loader in dataloader:
            if isinstance(loader, dict):
                self.dataloaders.append(
                    runner.build_dataloader(loader, seed=runner.seed))
            else:
                self.dataloaders.append(loader)

        assert isinstance(evaluator, list), (
            '`evaluator` must be a list, but got '
            f'{type(evaluator).__name__}')
        # Every entry goes through ``runner.build_evaluator`` exactly once.
        # The previous code built the list twice and kept only the second
        # result, so configs were instantiated redundantly.
        self.evaluators = [
            runner.build_evaluator(single_evaluator)
            for single_evaluator in evaluator
        ]

        assert len(self.evaluators) == len(self.dataloaders), (
            'The number of evaluators must match the number of dataloaders, '
            f'but got {len(self.evaluators)} and {len(self.dataloaders)}')

        self.fp16 = fp16

    def run(self):
        """Launch test."""
        self.runner.call_hook('before_test')

        self.runner.model.eval()
        multi_metric = dict()
        self.runner.call_hook('before_test_epoch')
        for evaluator, dataloader in zip(self.evaluators, self.dataloaders):
            # ``run_iter`` reads the active pair from these attributes.
            self.dataloader = dataloader
            self.evaluator = evaluator
            if hasattr(self.dataloader.dataset, 'metainfo'):
                self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
                self.runner.visualizer.dataset_meta = \
                    self.dataloader.dataset.metainfo
            else:
                warnings.warn(
                    f'Dataset {self.dataloader.dataset.__class__.__name__} '
                    'has no metainfo. ``dataset_meta`` in evaluator, metric'
                    ' and visualizer will be None.')
            for idx, data_batch in enumerate(self.dataloader):
                self.run_iter(idx, data_batch)
            # compute metrics
            metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
            # A shared key would silently overwrite an earlier dataset's
            # result in the merged dict, so refuse to continue.
            if multi_metric and metrics.keys() & multi_metric.keys():
                raise ValueError('Please set different prefix for different'
                                 ' datasets in `test_evaluator`')
            multi_metric.update(metrics)
        self.runner.call_hook('after_test_epoch', metrics=multi_metric)
        self.runner.call_hook('after_test')

    @torch.no_grad()
    def run_iter(self, idx: int, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            idx (int): The index of the current batch in the loop.
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)
        # predictions should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            predictions = self.runner.model.test_step(data_batch)
        self.evaluator.process(data_batch, predictions)
        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=predictions)
|
|
@ -85,19 +85,18 @@ class F1Metric(BaseMetric):
|
|||
self.key = key
|
||||
|
||||
def process(self, data_batch: Sequence[Dict],
|
||||
predictions: Sequence[Dict]) -> None:
|
||||
"""Process one batch of predictions. The processed results should be
|
||||
data_samples: Sequence[Dict]) -> None:
|
||||
"""Process one batch of data_samples. The processed results should be
|
||||
stored in ``self.results``, which will be used to compute the metrics
|
||||
when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Dict]): A batch of gts.
|
||||
predictions (Sequence[Dict]): A batch of outputs from the model.
|
||||
data_samples (Sequence[Dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_samples in predictions:
|
||||
pred_labels = data_samples.get('pred_instances').get(
|
||||
self.key).cpu()
|
||||
gt_labels = data_samples.get('gt_instances').get(self.key).cpu()
|
||||
for data_sample in data_samples:
|
||||
pred_labels = data_sample.get('pred_instances').get(self.key).cpu()
|
||||
gt_labels = data_sample.get('gt_instances').get(self.key).cpu()
|
||||
|
||||
result = dict(
|
||||
pred_labels=pred_labels.flatten(),
|
||||
|
|
|
@ -75,17 +75,17 @@ class HmeanIOUMetric(BaseMetric):
|
|||
self.strategy = strategy
|
||||
|
||||
def process(self, data_batch: Sequence[Dict],
|
||||
predictions: Sequence[Dict]) -> None:
|
||||
data_samples: Sequence[Dict]) -> None:
|
||||
"""Process one batch of data samples and predictions. The processed
|
||||
results should be stored in ``self.results``, which will be used to
|
||||
compute the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Dict]): A batch of data from dataloader.
|
||||
predictions (Sequence[Dict]): A batch of outputs from
|
||||
data_samples (Sequence[Dict]): A batch of outputs from
|
||||
the model.
|
||||
"""
|
||||
for data_sample in predictions:
|
||||
for data_sample in data_samples:
|
||||
|
||||
pred_instances = data_sample.get('pred_instances')
|
||||
pred_polygons = pred_instances.get('polygons')
|
||||
|
|
|
@ -51,16 +51,16 @@ class WordMetric(BaseMetric):
|
|||
self.mode = set(mode)
|
||||
|
||||
def process(self, data_batch: Sequence[Dict],
|
||||
predictions: Sequence[Dict]) -> None:
|
||||
"""Process one batch of predictions. The processed results should be
|
||||
data_samples: Sequence[Dict]) -> None:
|
||||
"""Process one batch of data_samples. The processed results should be
|
||||
stored in ``self.results``, which will be used to compute the metrics
|
||||
when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Dict]): A batch of gts.
|
||||
predictions (Sequence[Dict]): A batch of outputs from the model.
|
||||
data_samples (Sequence[Dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_sample in predictions:
|
||||
for data_sample in data_samples:
|
||||
match_num = 0
|
||||
match_ignore_case_num = 0
|
||||
match_ignore_case_symbol_num = 0
|
||||
|
@ -149,16 +149,16 @@ class CharMetric(BaseMetric):
|
|||
self.valid_symbol = re.compile(valid_symbol)
|
||||
|
||||
def process(self, data_batch: Sequence[Dict],
|
||||
predictions: Sequence[Dict]) -> None:
|
||||
"""Process one batch of predictions. The processed results should be
|
||||
data_samples: Sequence[Dict]) -> None:
|
||||
"""Process one batch of data_samples. The processed results should be
|
||||
stored in ``self.results``, which will be used to compute the metrics
|
||||
when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Dict]): A batch of gts.
|
||||
predictions (Sequence[Dict]): A batch of outputs from the model.
|
||||
data_samples (Sequence[Dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_sample in predictions:
|
||||
for data_sample in data_samples:
|
||||
pred_text = data_sample.get('pred_text').get('item')
|
||||
gt_text = data_sample.get('gt_text').get('item')
|
||||
gt_text_lower = gt_text.lower()
|
||||
|
@ -249,16 +249,16 @@ class OneMinusNEDMetric(BaseMetric):
|
|||
self.valid_symbol = re.compile(valid_symbol)
|
||||
|
||||
def process(self, data_batch: Sequence[Dict],
|
||||
predictions: Sequence[Dict]) -> None:
|
||||
"""Process one batch of predictions. The processed results should be
|
||||
data_samples: Sequence[Dict]) -> None:
|
||||
"""Process one batch of data_samples. The processed results should be
|
||||
stored in ``self.results``, which will be used to compute the metrics
|
||||
when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Dict]): A batch of gts.
|
||||
predictions (Sequence[Dict]): A batch of outputs from the model.
|
||||
data_samples (Sequence[Dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_sample in predictions:
|
||||
for data_sample in data_samples:
|
||||
pred_text = data_sample.get('pred_text').get('item')
|
||||
gt_text = data_sample.get('gt_text').get('item')
|
||||
gt_text_lower = gt_text.lower()
|
||||
|
|
|
@ -84,8 +84,8 @@ class SDMGR(BaseModel):
|
|||
return feats.view(feats.size(0), -1)
|
||||
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Sequence[KIEDataSample] = None,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Sequence[KIEDataSample] = None,
|
||||
mode: str = 'tensor',
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
@ -103,9 +103,9 @@ class SDMGR(BaseModel):
|
|||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
batch_data_samples (list[:obj:`DetDataSample`], optional): The
|
||||
data_samples (list[:obj:`DetDataSample`], optional): The
|
||||
annotation data of every samples. Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
|
@ -117,43 +117,42 @@ class SDMGR(BaseModel):
|
|||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'loss':
|
||||
return self.loss(batch_inputs, batch_data_samples, **kwargs)
|
||||
return self.loss(inputs, data_samples, **kwargs)
|
||||
elif mode == 'predict':
|
||||
return self.predict(batch_inputs, batch_data_samples, **kwargs)
|
||||
return self.predict(inputs, data_samples, **kwargs)
|
||||
elif mode == 'tensor':
|
||||
return self._forward(batch_inputs, batch_data_samples, **kwargs)
|
||||
return self._forward(inputs, data_samples, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}". '
|
||||
'Only supports loss, predict and tensor mode')
|
||||
|
||||
def loss(self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Sequence[KIEDataSample], **kwargs) -> dict:
|
||||
def loss(self, inputs: torch.Tensor, data_samples: Sequence[KIEDataSample],
|
||||
**kwargs) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
|
||||
inputs (torch.Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
batch_data_samples (list[KIEDataSample]): A list of N datasamples,
|
||||
data_samples (list[KIEDataSample]): A list of N datasamples,
|
||||
containing meta information and gold annotations for each of
|
||||
the images.
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
x = self.extract_feat(batch_inputs, [
|
||||
data_sample.gt_instances.bboxes
|
||||
for data_sample in batch_data_samples
|
||||
])
|
||||
return self.kie_head.loss(x, batch_data_samples)
|
||||
x = self.extract_feat(
|
||||
inputs,
|
||||
[data_sample.gt_instances.bboxes for data_sample in data_samples])
|
||||
return self.kie_head.loss(x, data_samples)
|
||||
|
||||
def predict(self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Sequence[KIEDataSample],
|
||||
def predict(self, inputs: torch.Tensor,
|
||||
data_samples: Sequence[KIEDataSample],
|
||||
**kwargs) -> List[KIEDataSample]:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing.
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
|
||||
inputs (torch.Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
batch_data_samples (list[KIEDataSample]): A list of N datasamples,
|
||||
data_samples (list[KIEDataSample]): A list of N datasamples,
|
||||
containing meta information and gold annotations for each of
|
||||
the images.
|
||||
|
||||
|
@ -162,22 +161,21 @@ class SDMGR(BaseModel):
|
|||
Results are stored in ``pred_instances.labels`` and
|
||||
``pred_instances.edge_labels``.
|
||||
"""
|
||||
x = self.extract_feat(batch_inputs, [
|
||||
data_sample.gt_instances.bboxes
|
||||
for data_sample in batch_data_samples
|
||||
])
|
||||
return self.kie_head.predict(x, batch_data_samples)
|
||||
x = self.extract_feat(
|
||||
inputs,
|
||||
[data_sample.gt_instances.bboxes for data_sample in data_samples])
|
||||
return self.kie_head.predict(x, data_samples)
|
||||
|
||||
def _forward(self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Sequence[KIEDataSample],
|
||||
def _forward(self, inputs: torch.Tensor,
|
||||
data_samples: Sequence[KIEDataSample],
|
||||
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Get the raw tensor outputs from backbone and head without any post-
|
||||
processing.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): Input images of shape (N, C, H, W).
|
||||
inputs (torch.Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
batch_data_samples (list[KIEDataSample]): A list of N datasamples,
|
||||
data_samples (list[KIEDataSample]): A list of N datasamples,
|
||||
containing meta information and gold annotations for each of
|
||||
the images.
|
||||
|
||||
|
@ -187,8 +185,7 @@ class SDMGR(BaseModel):
|
|||
- node_cls (torch.Tensor): Node classification output.
|
||||
- edge_cls (torch.Tensor): Edge classification output.
|
||||
"""
|
||||
x = self.extract_feat(batch_inputs, [
|
||||
data_sample.gt_instances.bboxes
|
||||
for data_sample in batch_data_samples
|
||||
])
|
||||
return self.kie_head(x, batch_data_samples)
|
||||
x = self.extract_feat(
|
||||
inputs,
|
||||
[data_sample.gt_instances.bboxes for data_sample in data_samples])
|
||||
return self.kie_head(x, data_samples)
|
||||
|
|
|
@ -83,28 +83,26 @@ class SDMGRHead(BaseModule):
|
|||
self.postprocessor = MODELS.build(postprocessor)
|
||||
self.relation_norm = relation_norm
|
||||
|
||||
def loss(self, batch_inputs: Tensor,
|
||||
batch_data_samples: List[KIEDataSample]) -> Dict:
|
||||
def loss(self, inputs: Tensor, data_samples: List[KIEDataSample]) -> Dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): Shape :math:`(N, E)`.
|
||||
batch_data_samples (List[KIEDataSample]): List of data samples.
|
||||
inputs (torch.Tensor): Shape :math:`(N, E)`.
|
||||
data_samples (List[KIEDataSample]): List of data samples.
|
||||
|
||||
Returns:
|
||||
dict[str, tensor]: A dictionary of loss components.
|
||||
"""
|
||||
preds = self.forward(batch_inputs, batch_data_samples)
|
||||
return self.module_loss(preds, batch_data_samples)
|
||||
preds = self.forward(inputs, data_samples)
|
||||
return self.module_loss(preds, data_samples)
|
||||
|
||||
def predict(self, batch_inputs: Tensor,
|
||||
batch_data_samples: List[KIEDataSample]
|
||||
) -> List[KIEDataSample]:
|
||||
def predict(self, inputs: Tensor,
|
||||
data_samples: List[KIEDataSample]) -> List[KIEDataSample]:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): Shape :math:`(N, E)`.
|
||||
batch_data_samples (List[KIEDataSample]): List of data samples.
|
||||
inputs (torch.Tensor): Shape :math:`(N, E)`.
|
||||
data_samples (List[KIEDataSample]): List of data samples.
|
||||
|
||||
Returns:
|
||||
List[KIEDataSample]: A list of datasamples of prediction results.
|
||||
|
@ -121,16 +119,15 @@ class SDMGRHead(BaseModule):
|
|||
- edge_scores (Tensor): A float tensor of shape (N, ), indicating
|
||||
the confidence scores for edge predictions.
|
||||
"""
|
||||
preds = self.forward(batch_inputs, batch_data_samples)
|
||||
return self.postprocessor(preds, batch_data_samples)
|
||||
preds = self.forward(inputs, data_samples)
|
||||
return self.postprocessor(preds, data_samples)
|
||||
|
||||
def forward(self, batch_inputs: Tensor,
|
||||
batch_data_samples: List[KIEDataSample]
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, inputs: Tensor,
|
||||
data_samples: List[KIEDataSample]) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): Shape :math:`(N, E)`.
|
||||
batch_data_samples (List[KIEDataSample]): List of data samples.
|
||||
inputs (torch.Tensor): Shape :math:`(N, E)`.
|
||||
data_samples (List[KIEDataSample]): List of data samples.
|
||||
|
||||
Returns:
|
||||
tuple(Tensor, Tensor):
|
||||
|
@ -143,8 +140,7 @@ class SDMGRHead(BaseModule):
|
|||
|
||||
device = self.node_embed.weight.device
|
||||
|
||||
node_nums, char_nums, all_nodes = self.convert_texts(
|
||||
batch_data_samples)
|
||||
node_nums, char_nums, all_nodes = self.convert_texts(data_samples)
|
||||
|
||||
embed_nodes = self.node_embed(all_nodes.to(device).long())
|
||||
rnn_nodes, _ = self.rnn(embed_nodes)
|
||||
|
@ -156,10 +152,10 @@ class SDMGRHead(BaseModule):
|
|||
1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
|
||||
-1, -1, rnn_nodes.size(-1))).squeeze(1)
|
||||
|
||||
if batch_inputs is not None:
|
||||
nodes = self.fusion([batch_inputs, nodes])
|
||||
if inputs is not None:
|
||||
nodes = self.fusion([inputs, nodes])
|
||||
|
||||
relations = self.compute_relations(batch_data_samples)
|
||||
relations = self.compute_relations(data_samples)
|
||||
all_edges = torch.cat(
|
||||
[relation.view(-1, relation.size(-1)) for relation in relations],
|
||||
dim=0)
|
||||
|
|
|
@ -52,8 +52,7 @@ class SDMGRPostProcessor:
|
|||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def __call__(self, preds: Tuple[Tensor, Tensor],
|
||||
batch_data_samples: List[KIEDataSample]
|
||||
) -> List[KIEDataSample]:
|
||||
data_samples: List[KIEDataSample]) -> List[KIEDataSample]:
|
||||
"""Postprocess raw outputs from SDMGR heads and pack the results into a
|
||||
list of KIEDataSample.
|
||||
|
||||
|
@ -83,7 +82,7 @@ class SDMGRPostProcessor:
|
|||
all_edge_scores = self.softmax(edge_preds)
|
||||
chunk_size = [
|
||||
data_sample.gt_instances.bboxes.shape[0]
|
||||
for data_sample in batch_data_samples
|
||||
for data_sample in data_samples
|
||||
]
|
||||
node_scores, node_preds = torch.max(all_node_scores, dim=-1)
|
||||
edge_scores, edge_preds = torch.max(all_edge_scores, dim=-1)
|
||||
|
@ -97,17 +96,17 @@ class SDMGRPostProcessor:
|
|||
edge_preds[i] = edge_preds[i].reshape((chunk, chunk))
|
||||
edge_scores[i] = edge_scores[i].reshape((chunk, chunk))
|
||||
|
||||
for i in range(len(batch_data_samples)):
|
||||
batch_data_samples[i].pred_instances = InstanceData()
|
||||
batch_data_samples[i].pred_instances.labels = node_preds[i]
|
||||
batch_data_samples[i].pred_instances.scores = node_scores[i]
|
||||
for i in range(len(data_samples)):
|
||||
data_samples[i].pred_instances = InstanceData()
|
||||
data_samples[i].pred_instances.labels = node_preds[i]
|
||||
data_samples[i].pred_instances.scores = node_scores[i]
|
||||
if self.link_type != 'none':
|
||||
edge_scores[i], edge_preds[i] = self.decode_edges(
|
||||
node_preds[i], edge_scores[i], edge_preds[i])
|
||||
batch_data_samples[i].pred_instances.edge_labels = edge_preds[i]
|
||||
batch_data_samples[i].pred_instances.edge_scores = edge_scores[i]
|
||||
data_samples[i].pred_instances.edge_labels = edge_preds[i]
|
||||
data_samples[i].pred_instances.edge_scores = edge_scores[i]
|
||||
|
||||
return batch_data_samples
|
||||
return data_samples
|
||||
|
||||
def decode_edges(self, node_labels: Tensor, edge_scores: Tensor,
|
||||
edge_labels: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from numbers import Number
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import ImgDataPreprocessor
|
||||
|
||||
|
@ -59,7 +58,7 @@ class TextDetDataPreprocessor(ImgDataPreprocessor):
|
|||
pad_value: Union[float, int] = 0,
|
||||
bgr_to_rgb: bool = False,
|
||||
rgb_to_bgr: bool = False,
|
||||
batch_augments: Optional[List[dict]] = None):
|
||||
batch_augments: Optional[List[Dict]] = None) -> None:
|
||||
super().__init__(
|
||||
mean=mean,
|
||||
std=std,
|
||||
|
@ -73,32 +72,28 @@ class TextDetDataPreprocessor(ImgDataPreprocessor):
|
|||
else:
|
||||
self.batch_augments = None
|
||||
|
||||
def forward(self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
|
||||
def forward(self, data: Dict, training: bool = False) -> Dict:
|
||||
"""Perform normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
data (dict): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||
model input.
|
||||
dict: Data in the same format as the model input.
|
||||
"""
|
||||
batch_inputs, batch_data_samples = super().forward(
|
||||
data=data, training=training)
|
||||
data = super().forward(data=data, training=training)
|
||||
inputs, data_samples = data['inputs'], data['data_samples']
|
||||
|
||||
if batch_data_samples is not None:
|
||||
batch_input_shape = tuple(batch_inputs[0].size()[-2:])
|
||||
for data_samples in batch_data_samples:
|
||||
data_samples.set_metainfo(
|
||||
if data_samples is not None:
|
||||
batch_input_shape = tuple(inputs[0].size()[-2:])
|
||||
for data_sample in data_samples:
|
||||
data_sample.set_metainfo(
|
||||
{'batch_input_shape': batch_input_shape})
|
||||
|
||||
if training and self.batch_augments is not None:
|
||||
for batch_aug in self.batch_augments:
|
||||
batch_inputs, batch_data_samples = batch_aug(
|
||||
batch_inputs, batch_data_samples)
|
||||
inputs, data_samples = batch_aug(inputs, data_samples)
|
||||
|
||||
return batch_inputs, batch_data_samples
|
||||
return data
|
||||
|
|
|
@ -35,8 +35,8 @@ class MMDetWrapper(BaseModel):
|
|||
self.text_repr_type = text_repr_type
|
||||
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: OptSampleList = None,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: OptSampleList = None,
|
||||
mode: str = 'tensor',
|
||||
**kwargs) -> ForwardResults:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
@ -54,9 +54,9 @@ class MMDetWrapper(BaseModel):
|
|||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
batch_data_samples (list[:obj:`DetDataSample`], optional): The
|
||||
data_samples (list[:obj:`DetDataSample`], optional): The
|
||||
annotation data of every samples. Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
|
@ -67,8 +67,8 @@ class MMDetWrapper(BaseModel):
|
|||
- If ``mode="predict"``, return a list of :obj:`TextDetDataSample`.
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
results = self.wrapped_model.forward(batch_inputs, batch_data_samples,
|
||||
mode, **kwargs)
|
||||
results = self.wrapped_model.forward(inputs, data_samples, mode,
|
||||
**kwargs)
|
||||
if mode == 'predict':
|
||||
results = self.adapt_predictions(results)
|
||||
|
||||
|
|
|
@ -44,47 +44,46 @@ class SingleStageTextDetector(BaseTextDetector):
|
|||
self.neck = MODELS.build(neck)
|
||||
self.det_head = MODELS.build(det_head)
|
||||
|
||||
def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
    """Run the backbone and, when one is configured, the neck.

    Args:
        inputs (Tensor): Image tensor with shape (N, C, H, W).

    Returns:
        Tensor or tuple[Tensor]: Multi-level features that may have
        different resolutions.
    """
    feats = self.backbone(inputs)
    # The neck is optional; without it the backbone output is returned.
    return self.neck(feats) if self.with_neck else feats
|
||||
|
||||
def loss(self, inputs: torch.Tensor,
         data_samples: Sequence[TextDetDataSample]) -> Dict:
    """Compute the training losses for a batch.

    Args:
        inputs (torch.Tensor): Input images of shape (N, C, H, W).
            Typically these should be mean centered and std scaled.
        data_samples (list[TextDetDataSample]): A list of N
            datasamples, containing meta information and gold annotations
            for each of the images.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    # Features come from ``extract_feat``; the loss itself is delegated to
    # the detection head.
    feats = self.extract_feat(inputs)
    losses = self.det_head.loss(feats, data_samples)
    return losses
|
||||
|
||||
def predict(
|
||||
self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Sequence[TextDetDataSample]
|
||||
) -> Sequence[TextDetDataSample]:
|
||||
def predict(self, inputs: torch.Tensor,
|
||||
data_samples: Sequence[TextDetDataSample]
|
||||
) -> Sequence[TextDetDataSample]:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): Images of shape (N, C, H, W).
|
||||
batch_data_samples (list[TextDetDataSample]): A list of N
|
||||
inputs (torch.Tensor): Images of shape (N, C, H, W).
|
||||
data_samples (list[TextDetDataSample]): A list of N
|
||||
datasamples, containing meta information and gold annotations
|
||||
for each of the images.
|
||||
|
||||
|
@ -104,20 +103,19 @@ class SingleStageTextDetector(BaseTextDetector):
|
|||
Each element represents the polygon of the
|
||||
instance, in (xn, yn) order.
|
||||
"""
|
||||
x = self.extract_feat(batch_inputs)
|
||||
return self.det_head.predict(x, batch_data_samples)
|
||||
x = self.extract_feat(inputs)
|
||||
return self.det_head.predict(x, data_samples)
|
||||
|
||||
def _forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: Optional[
|
||||
Sequence[TextDetDataSample]] = None,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[Sequence[TextDetDataSample]] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""Network forward process. Usually includes backbone, neck and head
|
||||
forward without any post-processing.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
batch_data_samples (list[TextDetDataSample]): A list of N
|
||||
inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
data_samples (list[TextDetDataSample]): A list of N
|
||||
datasamples, containing meta information and gold annotations
|
||||
for each of the images.
|
||||
|
||||
|
@ -125,5 +123,5 @@ class SingleStageTextDetector(BaseTextDetector):
|
|||
Tensor or tuple[Tensor]: A tuple of features from ``det_head``
|
||||
forward.
|
||||
"""
|
||||
x = self.extract_feat(batch_inputs)
|
||||
return self.det_head(x, batch_data_samples)
|
||||
x = self.extract_feat(inputs)
|
||||
return self.det_head(x, data_samples)
|
||||
|
|
|
@ -63,34 +63,32 @@ class BaseTextDetHead(BaseModule):
|
|||
self.module_loss = MODELS.build(module_loss)
|
||||
self.postprocessor = MODELS.build(postprocessor)
|
||||
|
||||
def loss(self, x: Tuple[Tensor], data_samples: DetSampleList) -> dict:
    """Forward the head on upstream features and compute its loss.

    Args:
        x (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.
        data_samples (List[:obj:`DetDataSample`]): The data samples that
            are passed both to the head's forward and to ``module_loss``.

    Returns:
        dict: A dictionary of loss components.
    """
    # ``self(...)`` invokes the head's own forward pass.
    return self.module_loss(self(x, data_samples), data_samples)
|
||||
|
||||
def loss_and_predict(self, x: Tuple[Tensor],
|
||||
batch_data_samples: DetSampleList
|
||||
def loss_and_predict(self, x: Tuple[Tensor], data_samples: DetSampleList
|
||||
) -> Tuple[dict, DetSampleList]:
|
||||
"""Perform forward propagation of the head, then calculate loss and
|
||||
predictions from the features and data samples.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Features from FPN.
|
||||
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
|
||||
data_samples (list[:obj:`DetDataSample`]): Each item contains
|
||||
the meta information of each image and corresponding
|
||||
annotations.
|
||||
|
||||
|
@ -101,21 +99,21 @@ class BaseTextDetHead(BaseModule):
|
|||
- predictions (list[:obj:`InstanceData`]): Detection
|
||||
results of each image after the post process.
|
||||
"""
|
||||
outs = self(x, batch_data_samples)
|
||||
losses = self.module_loss(outs, batch_data_samples)
|
||||
outs = self(x, data_samples)
|
||||
losses = self.module_loss(outs, data_samples)
|
||||
|
||||
predictions = self.postprocessor(outs, batch_data_samples)
|
||||
predictions = self.postprocessor(outs, data_samples)
|
||||
return losses, predictions
|
||||
|
||||
def predict(self, x: torch.Tensor,
|
||||
batch_data_samples: DetSampleList) -> DetSampleList:
|
||||
data_samples: DetSampleList) -> DetSampleList:
|
||||
"""Perform forward propagation of the detection head and predict
|
||||
detection results on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the
|
||||
upstream network, each is a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`DetDataSample`]): The Data
|
||||
data_samples (List[:obj:`DetDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
||||
|
||||
|
@ -123,7 +121,7 @@ class BaseTextDetHead(BaseModule):
|
|||
SampleList: Detection results of each image
|
||||
after the post process.
|
||||
"""
|
||||
outs = self(x, batch_data_samples)
|
||||
outs = self(x, data_samples)
|
||||
|
||||
predictions = self.postprocessor(outs, batch_data_samples)
|
||||
predictions = self.postprocessor(outs, data_samples)
|
||||
return predictions
|
||||
|
|
|
@ -261,15 +261,13 @@ class DRRGHead(BaseTextDetHead):
|
|||
self.in_channels + self.out_channels) + self.node_geo_feat_len
|
||||
self.gcn = GCN(node_feat_len)
|
||||
|
||||
def loss(
|
||||
self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: List[TextDetDataSample]
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
def loss(self, inputs: torch.Tensor, data_samples: List[TextDetDataSample]
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""Loss function.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): Shape of :math:`(N, C, H, W)`.
|
||||
batch_data_samples (List[TextDetDataSample]): List of data samples.
|
||||
inputs (Tensor): Shape of :math:`(N, C, H, W)`.
|
||||
data_samples (List[TextDetDataSample]): List of data samples.
|
||||
|
||||
Returns:
|
||||
tuple(pred_maps, gcn_pred, gt_labels):
|
||||
|
@ -281,27 +279,26 @@ class DRRGHead(BaseTextDetHead):
|
|||
- gt_labels (Tensor): Ground-truth label of shape
|
||||
:math:`(m, n)` where :math:`m * n = N`.
|
||||
"""
|
||||
targets = self.module_loss.get_targets(batch_data_samples)
|
||||
targets = self.module_loss.get_targets(data_samples)
|
||||
gt_comp_attribs = targets[-1]
|
||||
|
||||
pred_maps = self.out_conv(batch_inputs)
|
||||
feat_maps = torch.cat([batch_inputs, pred_maps], dim=1)
|
||||
pred_maps = self.out_conv(inputs)
|
||||
feat_maps = torch.cat([inputs, pred_maps], dim=1)
|
||||
node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
|
||||
feat_maps, np.stack(gt_comp_attribs))
|
||||
|
||||
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
|
||||
|
||||
return self.module_loss((pred_maps, gcn_pred, gt_labels),
|
||||
batch_data_samples)
|
||||
return self.module_loss((pred_maps, gcn_pred, gt_labels), data_samples)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_inputs: Tensor,
|
||||
inputs: Tensor,
|
||||
data_samples: Optional[List[TextDetDataSample]] = None
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
r"""Run DRRG head in prediction mode, and return the raw tensors only.
|
||||
Args:
|
||||
batch_inputs (Tensor): Shape of :math:`(1, C, H, W)`.
|
||||
inputs (Tensor): Shape of :math:`(1, C, H, W)`.
|
||||
data_samples (list[TextDetDataSample], optional): A list of data
|
||||
samples. Defaults to None.
|
||||
|
||||
|
@ -317,10 +314,10 @@ class DRRGHead(BaseTextDetHead):
|
|||
:math:`(M, 9)` where each row corresponds to one box and
|
||||
its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
|
||||
"""
|
||||
pred_maps = self.out_conv(batch_inputs)
|
||||
batch_inputs = torch.cat([batch_inputs, pred_maps], dim=1)
|
||||
pred_maps = self.out_conv(inputs)
|
||||
inputs = torch.cat([inputs, pred_maps], dim=1)
|
||||
|
||||
none_flag, graph_data = self.graph_test(pred_maps, batch_inputs)
|
||||
none_flag, graph_data = self.graph_test(pred_maps, inputs)
|
||||
|
||||
(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
|
||||
pivot_local_graphs, text_comps) = graph_data
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from numbers import Number
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import ImgDataPreprocessor
|
||||
|
||||
|
@ -25,7 +24,7 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor):
|
|||
- Pad inputs to the maximum size of current batch with defined
|
||||
``pad_value``. The padding size can be divisible by a defined
|
||||
``pad_size_divisor``
|
||||
- Stack inputs to batch_inputs.
|
||||
- Stack inputs to inputs.
|
||||
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
|
||||
- Normalize image with defined std and mean.
|
||||
- Do batch augmentations during training.
|
||||
|
@ -52,7 +51,7 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor):
|
|||
pad_value: Union[float, int] = 0,
|
||||
bgr_to_rgb: bool = False,
|
||||
rgb_to_bgr: bool = False,
|
||||
batch_augments: Optional[List[dict]] = None):
|
||||
batch_augments: Optional[List[Dict]] = None) -> None:
|
||||
super().__init__(
|
||||
mean=mean,
|
||||
std=std,
|
||||
|
@ -66,37 +65,33 @@ class TextRecogDataPreprocessor(ImgDataPreprocessor):
|
|||
else:
|
||||
self.batch_augments = None
|
||||
|
||||
def forward(self, data: Dict, training: bool = False) -> Dict:
    """Perform normalization, padding and bgr2rgb conversion based on
    ``BaseDataPreprocessor``, then update per-sample meta information.

    Args:
        data (dict): Data sampled from dataloader.
        training (bool): Whether to enable training time augmentation.

    Returns:
        dict: Data in the same format as the model input. The dict returned
        by the parent class is reused, with ``inputs`` and ``data_samples``
        replaced by the augmented values when batch augments were applied.
    """
    data = super().forward(data=data, training=training)
    inputs, data_samples = data['inputs'], data['data_samples']

    if data_samples is not None:
        # Last two dimensions of the first batched input.
        batch_input_shape = tuple(inputs[0].size()[-2:])
        for data_sample in data_samples:
            # Rescale the stored valid ratio from the sample's own width
            # (``img_shape[1]``) to the width of the batched input.
            valid_ratio = data_sample.valid_ratio * \
                data_sample.img_shape[1] / batch_input_shape[1]
            data_sample.set_metainfo(
                dict(
                    valid_ratio=valid_ratio,
                    batch_input_shape=batch_input_shape))

    if training and self.batch_augments is not None:
        for batch_aug in self.batch_augments:
            inputs, data_samples = batch_aug(inputs, data_samples)
        # A batch augment hands back (possibly new) objects. Store them in
        # the returned dict, otherwise the augmented batch is discarded.
        data['inputs'] = inputs
        data['data_samples'] = data_samples

    return data
|
||||
|
|
|
@ -52,8 +52,8 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
|||
pass
|
||||
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: OptRecSampleList = None,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: OptRecSampleList = None,
|
||||
mode: str = 'tensor',
|
||||
**kwargs) -> RecForwardResults:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
@ -71,9 +71,9 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
|||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
batch_data_samples (list[:obj:`DetDataSample`], optional): The
|
||||
data_samples (list[:obj:`DetDataSample`], optional): The
|
||||
annotation data of every samples. Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
|
@ -85,33 +85,32 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
|||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'loss':
|
||||
return self.loss(batch_inputs, batch_data_samples, **kwargs)
|
||||
return self.loss(inputs, data_samples, **kwargs)
|
||||
elif mode == 'predict':
|
||||
return self.predict(batch_inputs, batch_data_samples, **kwargs)
|
||||
return self.predict(inputs, data_samples, **kwargs)
|
||||
elif mode == 'tensor':
|
||||
return self._forward(batch_inputs, batch_data_samples, **kwargs)
|
||||
return self._forward(inputs, data_samples, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}". '
|
||||
'Only supports loss, predict and tensor mode')
|
||||
|
||||
@abstractmethod
def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
         **kwargs) -> Union[dict, tuple]:
    """Calculate losses from a batch of inputs and data samples.

    Abstract hook: this base implementation has no body and returns
    ``None``.
    """
|
||||
|
||||
@abstractmethod
def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
            **kwargs) -> RecSampleList:
    """Predict results from a batch of inputs and data samples with
    post-processing.

    Abstract hook: this base implementation has no body and returns
    ``None``.
    """
|
||||
|
||||
@abstractmethod
|
||||
def _forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: OptRecSampleList = None,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: OptRecSampleList = None,
|
||||
**kwargs):
|
||||
"""Network forward process.
|
||||
|
||||
|
|
|
@ -59,16 +59,16 @@ class EncoderDecoderRecognizer(BaseRecognizer):
|
|||
assert decoder is not None
|
||||
self.decoder = MODELS.build(decoder)
|
||||
|
||||
def extract_feat(self, inputs: torch.Tensor) -> torch.Tensor:
    """Extract features with the optional preprocessor and backbone.

    Each stage is applied only when its ``with_*`` flag is truthy; with
    neither stage present the input is returned unchanged.
    """
    feat = self.preprocessor(inputs) if self.with_preprocessor else inputs
    return self.backbone(feat) if self.with_backbone else feat
|
||||
|
||||
def loss(self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: RecSampleList, **kwargs) -> Dict:
|
||||
def loss(self, inputs: torch.Tensor, data_samples: RecSampleList,
|
||||
**kwargs) -> Dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
Args:
|
||||
inputs (tensor): Input images of shape (N, C, H, W).
|
||||
|
@ -80,14 +80,14 @@ class EncoderDecoderRecognizer(BaseRecognizer):
|
|||
Returns:
|
||||
dict[str, tensor]: A dictionary of loss components.
|
||||
"""
|
||||
feat = self.extract_feat(batch_inputs)
|
||||
feat = self.extract_feat(inputs)
|
||||
out_enc = None
|
||||
if self.with_encoder:
|
||||
out_enc = self.encoder(feat, batch_data_samples)
|
||||
return self.decoder.loss(feat, out_enc, batch_data_samples)
|
||||
out_enc = self.encoder(feat, data_samples)
|
||||
return self.decoder.loss(feat, out_enc, data_samples)
|
||||
|
||||
def predict(self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: RecSampleList, **kwargs) -> RecSampleList:
|
||||
def predict(self, inputs: torch.Tensor, data_samples: RecSampleList,
|
||||
**kwargs) -> RecSampleList:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing.
|
||||
|
||||
|
@ -101,30 +101,30 @@ class EncoderDecoderRecognizer(BaseRecognizer):
|
|||
list[TextRecogDataSample]: A list of N datasamples of prediction
|
||||
results. Results are stored in ``pred_text``.
|
||||
"""
|
||||
feat = self.extract_feat(batch_inputs)
|
||||
feat = self.extract_feat(inputs)
|
||||
out_enc = None
|
||||
if self.with_encoder:
|
||||
out_enc = self.encoder(feat, batch_data_samples)
|
||||
return self.decoder.predict(feat, out_enc, batch_data_samples)
|
||||
out_enc = self.encoder(feat, data_samples)
|
||||
return self.decoder.predict(feat, out_enc, data_samples)
|
||||
|
||||
def _forward(self,
             inputs: torch.Tensor,
             data_samples: OptRecSampleList = None,
             **kwargs) -> RecForwardResults:
    """Network forward process. Usually includes backbone, encoder and
    decoder forward without any post-processing.

    Args:
        inputs (Tensor): Inputs with shape (N, C, H, W).
        data_samples (list[TextRecogDataSample]): A list of N
            datasamples, containing meta information and gold
            annotations for each of the images.

    Returns:
        Tensor: A tuple of features from ``decoder`` forward.
    """
    feat = self.extract_feat(inputs)
    # Without an encoder the decoder receives ``None`` as encoder output.
    out_enc = self.encoder(feat, data_samples) if self.with_encoder else None
    return self.decoder(feat, out_enc, data_samples)
|
||||
|
|
|
@ -37,9 +37,9 @@ class TestPackTextDetInputs(TestCase):
|
|||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
|
||||
self.assertIn('data_sample', results)
|
||||
self.assertIn('data_samples', results)
|
||||
|
||||
data_sample = results['data_sample']
|
||||
data_sample = results['data_samples']
|
||||
self.assertIn('bboxes', data_sample.gt_instances)
|
||||
self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
|
||||
self.assertEqual(data_sample.gt_instances.bboxes.dtype, torch.float32)
|
||||
|
@ -56,9 +56,9 @@ class TestPackTextDetInputs(TestCase):
|
|||
transform = PackTextDetInputs(meta_keys=('img_path', ))
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIn('data_sample', results)
|
||||
self.assertIn('data_samples', results)
|
||||
|
||||
data_sample = results['data_sample']
|
||||
data_sample = results['data_samples']
|
||||
self.assertIn('bboxes', data_sample.gt_instances)
|
||||
self.assertIn('img_path', data_sample)
|
||||
self.assertNotIn('flip', data_sample)
|
||||
|
@ -66,14 +66,14 @@ class TestPackTextDetInputs(TestCase):
|
|||
datainfo.pop('gt_texts')
|
||||
transform = PackTextDetInputs()
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
data_sample = results['data_sample']
|
||||
data_sample = results['data_samples']
|
||||
self.assertNotIn('texts', data_sample.gt_instances)
|
||||
|
||||
datainfo = dict(img_shape=(10, 10))
|
||||
transform = PackTextDetInputs(meta_keys=('img_shape', ))
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertNotIn('inputs', results)
|
||||
data_sample = results['data_sample']
|
||||
data_sample = results['data_samples']
|
||||
self.assertNotIn('texts', data_sample.gt_instances)
|
||||
|
||||
def test_repr(self):
|
||||
|
@ -108,8 +108,8 @@ class TestPackTextRecogInputs(TestCase):
|
|||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
|
||||
self.assertIn('data_sample', results)
|
||||
data_sample = results['data_sample']
|
||||
self.assertIn('data_samples', results)
|
||||
data_sample = results['data_samples']
|
||||
self.assertEqual(data_sample.gt_text.item, 'mmocr')
|
||||
self.assertIn('img_path', data_sample)
|
||||
self.assertIn('valid_ratio', data_sample)
|
||||
|
@ -118,8 +118,8 @@ class TestPackTextRecogInputs(TestCase):
|
|||
transform = PackTextRecogInputs(meta_keys=('img_path', ))
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIn('data_sample', results)
|
||||
data_sample = results['data_sample']
|
||||
self.assertIn('data_samples', results)
|
||||
data_sample = results['data_samples']
|
||||
self.assertEqual(data_sample.gt_text.item, 'mmocr')
|
||||
self.assertIn('img_path', data_sample)
|
||||
self.assertNotIn('valid_ratio', data_sample)
|
||||
|
@ -129,7 +129,7 @@ class TestPackTextRecogInputs(TestCase):
|
|||
transform = PackTextRecogInputs(meta_keys=('img_shape', ))
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertNotIn('inputs', results)
|
||||
data_sample = results['data_sample']
|
||||
data_sample = results['data_samples']
|
||||
self.assertNotIn('item', data_sample.gt_text)
|
||||
|
||||
def test_repr(self):
|
||||
|
@ -165,8 +165,8 @@ class TestPackKIEInputs(TestCase):
|
|||
results = self.transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
|
||||
self.assertIn('data_sample', results)
|
||||
data_sample = results['data_sample']
|
||||
self.assertIn('data_samples', results)
|
||||
data_sample = results['data_samples']
|
||||
self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
|
||||
self.assertEqual(data_sample.gt_instances.bboxes.dtype, torch.float32)
|
||||
self.assertEqual(data_sample.gt_instances.labels.dtype, torch.int64)
|
||||
|
@ -179,9 +179,9 @@ class TestPackKIEInputs(TestCase):
|
|||
transform = PackKIEInputs(meta_keys=('img_path', ))
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIn('data_sample', results)
|
||||
self.assertIn('data_samples', results)
|
||||
|
||||
data_sample = results['data_sample']
|
||||
data_sample = results['data_samples']
|
||||
self.assertIn('bboxes', data_sample.gt_instances)
|
||||
self.assertIn('img_path', data_sample)
|
||||
|
||||
|
|
|
@ -200,9 +200,3 @@ class TestResize(unittest.TestCase):
|
|||
resize = Resize(scale=(40, 30))
|
||||
result = resize(dummy_result)
|
||||
self.assertEqual(result['gt_bboxes'].dtype, np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
t = TestRandomCrop()
|
||||
t.test_sample_crop_box()
|
||||
t.test_transform()
|
||||
|
|
|
@ -1,274 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.config import Config
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS
|
||||
from mmengine.runner import Runner
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmocr.engine.runner import MultiTestLoop, MultiValLoop
|
||||
|
||||
|
||||
@MODELS.register_module()
class ToyModel(BaseModel):
    """Two-layer linear toy model used to drive the runner in tests."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, batch_inputs, labels, mode='tensor'):
        """Run both linear layers and return a result that depends on
        ``mode``.

        ``'tensor'`` and ``'predict'`` return the raw outputs, ``'loss'``
        returns ``dict(loss=...)``; any other mode returns ``None``.
        """
        targets = torch.stack(labels)
        preds = self.linear2(self.linear1(batch_inputs))

        if mode == 'loss':
            return dict(loss=(targets - preds).sum())
        if mode in ('tensor', 'predict'):
            return preds
        return None
|
||||
|
||||
|
||||
@DATASETS.register_module()
class ToyDataset(Dataset):
    """In-memory dataset of 12 random 2-d inputs with all-ones labels."""

    METAINFO = {}  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Expose the class-level ``METAINFO`` dict."""
        return self.METAINFO

    def __len__(self):
        # Number of rows in the class-level data tensor.
        return self.data.size(0)

    def __getitem__(self, index):
        sample = {
            'inputs': self.data[index],
            'data_sample': self.label[index],
        }
        return sample
|
||||
|
||||
|
||||
@METRICS.register_module()
class ToyMetric3(BaseMetric):
    """Dummy metric that always reports ``acc == 1``."""

    def __init__(self, collect_device='cpu', prefix=''):
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_samples, predictions):
        # The arguments are ignored; one constant record is stored per call.
        self.results.append(dict(acc=1))

    def compute_metrics(self, results):
        # The collected results are ignored on purpose.
        return {'acc': 1}
|
||||
|
||||
|
||||
class TestRunner(TestCase):
|
||||
|
||||
def setUp(self):
    """Create a temporary work dir plus an epoch-based and an iter-based
    runner config built on the toy model, dataset and hooks."""
    self.temp_dir = tempfile.mkdtemp()

    def toy_loader(sampler_type, shuffle):
        # All dataloaders differ only in sampler type and shuffling.
        return dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type=sampler_type, shuffle=shuffle),
            batch_size=3,
            num_workers=0)

    def hooks(by_epoch):
        # Only the checkpoint hook depends on the training mode.
        return dict(
            runtime_info=dict(type='RuntimeInfoHook'),
            timer=dict(type='IterTimerHook'),
            logger=dict(type='LoggerHook'),
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(
                type='CheckpointHook', interval=1, by_epoch=by_epoch),
            sampler_seed=dict(type='DistSamplerSeedHook'))

    self.epoch_based_cfg = Config(
        dict(
            default_scope='mmocr',
            model=dict(type='ToyModel'),
            work_dir=self.temp_dir,
            train_dataloader=toy_loader('DefaultSampler', True),
            val_dataloader=toy_loader('DefaultSampler', False),
            test_dataloader=toy_loader('DefaultSampler', False),
            auto_scale_lr=dict(base_batch_size=16, enable=False),
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            val_evaluator=dict(type='ToyMetric1'),
            test_evaluator=dict(type='ToyMetric1'),
            train_cfg=dict(
                by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
            val_cfg=dict(),
            test_cfg=dict(),
            custom_hooks=[],
            default_hooks=hooks(True),
            launcher='none',
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
        ))

    # The iter-based config is a deep copy with the training-specific
    # entries swapped out.
    iter_cfg = copy.deepcopy(self.epoch_based_cfg)
    iter_cfg.train_dataloader = toy_loader('InfiniteSampler', True)
    iter_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
    iter_cfg.default_hooks = hooks(False)
    self.iter_based_cfg = iter_cfg
|
||||
|
||||
def tearDown(self):
    """Remove the directory tree stored in ``self.temp_dir``."""
    work_dir = self.temp_dir
    shutil.rmtree(work_dir)
|
||||
|
||||
def test_multi_val_loop(self):
    """Run ``MultiValLoop`` over two dataloader/evaluator pairs.

    Checks that the custom hook fires for every validation iteration,
    that metrics of both evaluators are merged under their prefixes, and
    that identical prefixes are rejected with a ``ValueError``.
    """
    before_calls, after_calls = [], []
    merged_metrics = {}

    @HOOKS.register_module()
    class Fake_1(Hook):
        """Hook recording the validation callbacks it receives."""

        def before_val_iter(self, runner, batch_idx, data_batch=None):
            before_calls.append('before')

        def after_val_iter(self,
                           runner,
                           batch_idx,
                           data_batch=None,
                           outputs=None):
            after_calls.append('after')

        def after_val_epoch(self, runner, metrics=None) -> None:
            merged_metrics.update(metrics)

    cfg = self.iter_based_cfg
    cfg.val_cfg = dict(type='MultiValLoop')
    # Two identical validation dataloaders, one per evaluator.
    cfg.val_dataloader = [
        dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=3,
            num_workers=0) for _ in range(2)
    ]
    cfg.val_evaluator = [
        dict(type='ToyMetric3', prefix=prefix) for prefix in ('tmp1', 'tmp2')
    ]
    cfg.custom_hooks = [dict(type='Fake_1', priority=50)]
    cfg.experiment_name = 'test_multi_val_loop'
    runner = Runner.from_cfg(cfg)
    runner.val()

    self.assertIsInstance(runner.val_loop, MultiValLoop)

    # test custom hook triggered as expected
    # (8 iterations in total; presumably 12 samples / batch size 3 for each
    # of the two dataloaders)
    self.assertEqual(len(before_calls), 8)
    self.assertEqual(len(after_calls), 8)
    for before, after in zip(before_calls, after_calls):
        self.assertEqual(before, 'before')
        self.assertEqual(after, 'after')
    self.assertDictEqual(merged_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1})

    # test_same prefix
    cfg.val_evaluator = [
        dict(type='ToyMetric3', prefix='tmp1') for _ in range(2)
    ]
    cfg.experiment_name = 'test_multi_val_loop_same_prefix'
    runner = Runner.from_cfg(cfg)
    expected_msg = ('Please set different'
                    ' prefix for different datasets'
                    ' in `val_evaluator`')
    with self.assertRaisesRegex(ValueError, expected_msg):
        runner.val()
|
||||
|
||||
def test_multi_test_loop(self):
    """Run ``MultiTestLoop`` over two dataloaders and check hooks/metrics.

    A custom hook records every ``before_test_iter`` / ``after_test_iter``
    call and the metrics passed to ``after_test_epoch``. The test then
    checks the callback counts, the merged metric dict, and that two
    evaluators sharing a prefix make ``runner.test()`` raise ``ValueError``.
    """
    before_calls = []
    after_calls = []
    collected_metrics = {}

    @HOOKS.register_module()
    class Fake_2(Hook):
        """Hook that records test callbacks for later assertions."""

        def before_test_iter(self, runner, batch_idx, data_batch=None):
            before_calls.append('before')

        def after_test_iter(self,
                            runner,
                            batch_idx,
                            data_batch=None,
                            outputs=None):
            after_calls.append('after')

        def after_test_epoch(self, runner, metrics=None) -> None:
            collected_metrics.update(metrics)

    cfg = self.iter_based_cfg
    cfg.test_cfg = {'type': 'MultiTestLoop'}
    # Two identical test dataloaders, each built as its own dict.
    cfg.test_dataloader = [
        {
            'dataset': {'type': 'ToyDataset'},
            'sampler': {'type': 'DefaultSampler', 'shuffle': False},
            'batch_size': 3,
            'num_workers': 0,
        } for _ in range(2)
    ]
    # One evaluator per dataloader, distinguished by prefix.
    cfg.test_evaluator = [
        {'type': 'ToyMetric3', 'prefix': prefix}
        for prefix in ('tmp1', 'tmp2')
    ]
    cfg.custom_hooks = [{'type': 'Fake_2', 'priority': 50}]
    cfg.experiment_name = 'multi_test_loop'

    runner = Runner.from_cfg(cfg)
    runner.test()

    self.assertIsInstance(runner.test_loop, MultiTestLoop)

    # The custom hook must have fired 8 times for each per-iteration
    # callback, always appending the expected marker.
    self.assertEqual(len(before_calls), 8)
    self.assertEqual(len(after_calls), 8)
    for marker in before_calls:
        self.assertEqual(marker, 'before')
    for marker in after_calls:
        self.assertEqual(marker, 'after')
    self.assertDictEqual(collected_metrics, {'tmp1/acc': 1, 'tmp2/acc': 1})

    # Evaluators sharing the same prefix must be rejected.
    cfg.test_evaluator = [
        {'type': 'ToyMetric3', 'prefix': 'tmp1'} for _ in range(2)
    ]
    cfg.experiment_name = 'multi_test_loop_same_prefix'
    runner = Runner.from_cfg(cfg)
    expected_msg = ('Please set different prefix for different datasets'
                    ' in `test_evaluator`')
    with self.assertRaisesRegex(ValueError, expected_msg):
        runner.test()
|
|
@ -40,9 +40,9 @@ class ToyMetric(BaseMetric):
|
|||
|
||||
def process(self, data_batch, predictions):
|
||||
results = [{
|
||||
'pred': pred.get('pred'),
|
||||
'label': data['data_sample'].get('label')
|
||||
} for pred, data in zip(predictions, data_batch)]
|
||||
'pred': prediction['pred'],
|
||||
'label': prediction['label']
|
||||
} for prediction in predictions]
|
||||
self.results.extend(results)
|
||||
|
||||
def compute_metrics(self, results: List):
|
||||
|
@ -67,12 +67,13 @@ def generate_test_results(size, batch_size, pred, label):
|
|||
bs_residual = size % batch_size
|
||||
for i in range(num_batch):
|
||||
bs = bs_residual if i == num_batch - 1 else batch_size
|
||||
data_batch = [
|
||||
dict(
|
||||
inputs=np.zeros((3, 10, 10)),
|
||||
data_sample=BaseDataElement(label=label)) for _ in range(bs)
|
||||
data_batch = {
|
||||
'inputs': [np.zeros((3, 10, 10)) for _ in range(bs)],
|
||||
'data_samples': [BaseDataElement(label=label) for _ in range(bs)]
|
||||
}
|
||||
predictions = [
|
||||
BaseDataElement(pred=pred, label=label) for _ in range(bs)
|
||||
]
|
||||
predictions = [BaseDataElement(pred=pred) for _ in range(bs)]
|
||||
yield (data_batch, predictions)
|
||||
|
||||
|
||||
|
@ -92,7 +93,7 @@ class TestMultiDatasetsEvaluator(TestCase):
|
|||
|
||||
for data_samples, predictions in generate_test_results(
|
||||
size, batch_size, pred=1, label=1):
|
||||
evaluator.process(data_samples, predictions)
|
||||
evaluator.process(predictions, data_samples)
|
||||
|
||||
metrics = evaluator.evaluate(size=size)
|
||||
|
||||
|
@ -109,7 +110,7 @@ class TestMultiDatasetsEvaluator(TestCase):
|
|||
|
||||
for data_samples, predictions in generate_test_results(
|
||||
size, batch_size, pred=1, label=1):
|
||||
evaluator.process(data_samples, predictions)
|
||||
evaluator.process(predictions, data_samples)
|
||||
with self.assertRaises(ValueError):
|
||||
evaluator.evaluate(size=size)
|
||||
|
||||
|
@ -120,7 +121,7 @@ class TestMultiDatasetsEvaluator(TestCase):
|
|||
|
||||
for data_samples, predictions in generate_test_results(
|
||||
size, batch_size, pred=1, label=1):
|
||||
evaluator.process(data_samples, predictions)
|
||||
evaluator.process(predictions, data_samples)
|
||||
metrics = evaluator.evaluate(size=size)
|
||||
self.assertIn('Fake/Toy/accuracy', metrics)
|
||||
self.assertIn('Fake/accuracy', metrics)
|
||||
|
|
|
@ -46,8 +46,9 @@ class TestSDMGR(unittest.TestCase):
|
|||
return model
|
||||
|
||||
def forward_wrapper(self, model, data, mode):
|
||||
inputs, data_sample = model.data_preprocessor(data, False)
|
||||
return model.forward(inputs, data_sample, mode)
|
||||
out = model.data_preprocessor(data, False)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
return model.forward(inputs, data_samples, mode)
|
||||
|
||||
def setUp(self):
|
||||
|
||||
|
@ -66,14 +67,11 @@ class TestSDMGR(unittest.TestCase):
|
|||
edge_labels=torch.LongTensor([[0, 1], [1, 0]]),
|
||||
texts=['text1', 'text2'],
|
||||
relations=torch.rand((2, 2, 5)))
|
||||
self.visual_data = [
|
||||
dict(inputs=torch.rand((3, 10, 10)), data_sample=data_sample)
|
||||
]
|
||||
self.novisual_data = [
|
||||
dict(
|
||||
inputs=torch.Tensor([]).reshape((0, 0, 0)),
|
||||
data_sample=data_sample)
|
||||
]
|
||||
self.visual_data = dict(
|
||||
inputs=[torch.rand((3, 10, 10))], data_samples=[data_sample])
|
||||
self.novisual_data = dict(
|
||||
inputs=[torch.Tensor([]).reshape((0, 0, 0))],
|
||||
data_samples=[data_sample])
|
||||
|
||||
def test_forward_loss(self):
|
||||
result = self.forward_wrapper(
|
||||
|
|
|
@ -11,8 +11,8 @@ from mmocr.structures import TextDetDataSample
|
|||
@MODELS.register_module()
|
||||
class TDAugment(torch.nn.Module):
|
||||
|
||||
def forward(self, batch_inputs, batch_data_samples):
|
||||
return batch_inputs, batch_data_samples
|
||||
def forward(self, inputs, data_samples):
|
||||
return inputs, data_samples
|
||||
|
||||
|
||||
class TestTextDetDataPreprocessor(TestCase):
|
||||
|
@ -47,58 +47,63 @@ class TestTextDetDataPreprocessor(TestCase):
|
|||
def test_forward(self):
|
||||
processor = TextDetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
|
||||
|
||||
data = [{
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 11, 10)),
|
||||
'data_sample':
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0))
|
||||
}]
|
||||
inputs, data_samples = processor(data)
|
||||
print(inputs.dtype)
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 11, 10)),
|
||||
],
|
||||
'data_samples': [
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0)),
|
||||
]
|
||||
}
|
||||
out = processor(data)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
self.assertEqual(inputs.shape, (1, 3, 11, 10))
|
||||
self.assertEqual(len(data_samples), 1)
|
||||
|
||||
# test channel_conversion
|
||||
processor = TextDetDataPreprocessor(
|
||||
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
|
||||
inputs, data_samples = processor(data)
|
||||
out = processor(data)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
self.assertEqual(inputs.shape, (1, 3, 11, 10))
|
||||
self.assertEqual(len(data_samples), 1)
|
||||
|
||||
# test padding
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 10, 11))
|
||||
}, {
|
||||
'inputs': torch.randint(0, 256, (3, 9, 14))
|
||||
}]
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 10, 11)),
|
||||
torch.randint(0, 256, (3, 9, 14))
|
||||
]
|
||||
}
|
||||
processor = TextDetDataPreprocessor(
|
||||
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
|
||||
inputs, data_samples = processor(data)
|
||||
out = processor(data)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
self.assertEqual(inputs.shape, (2, 3, 10, 14))
|
||||
self.assertIsNone(data_samples)
|
||||
|
||||
# test pad_size_divisor
|
||||
data = [{
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 10, 11)),
|
||||
'data_sample':
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0))
|
||||
}, {
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 9, 24)),
|
||||
'data_sample':
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
|
||||
}]
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 10, 11)),
|
||||
torch.randint(0, 256, (3, 9, 24))
|
||||
],
|
||||
'data_samples': [
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0)),
|
||||
TextDetDataSample(
|
||||
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
|
||||
]
|
||||
}
|
||||
aug_cfg = [dict(type='TDAugment')]
|
||||
processor = TextDetDataPreprocessor(
|
||||
mean=[0., 0., 0.],
|
||||
std=[1., 1., 1.],
|
||||
pad_size_divisor=5,
|
||||
batch_augments=aug_cfg)
|
||||
inputs, data_samples = processor(data, training=True)
|
||||
out = processor(data)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
self.assertEqual(inputs.shape, (2, 3, 10, 25))
|
||||
self.assertEqual(len(data_samples), 2)
|
||||
for data_sample, expected_shape in zip(data_samples, [(10, 25),
|
||||
|
|
|
@ -262,9 +262,3 @@ class TestMMDetWrapper(unittest.TestCase):
|
|||
self.assertEqual(len(results), 1)
|
||||
self.assertIsInstance(results[0], TextDetDataSample)
|
||||
self.assertTrue('polygons' in results[0].pred_instances.keys())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test = TestMMDetWrapper()
|
||||
test.setUp()
|
||||
test.test_mask_two_stage_wrapper()
|
||||
|
|
|
@ -11,8 +11,8 @@ from mmocr.structures import TextRecogDataSample
|
|||
@MODELS.register_module()
|
||||
class Augment(torch.nn.Module):
|
||||
|
||||
def forward(self, batch_inputs, batch_data_samples):
|
||||
return batch_inputs, batch_data_samples
|
||||
def forward(self, inputs, data_samples):
|
||||
return inputs, data_samples
|
||||
|
||||
|
||||
class TestTextRecogDataPreprocessor(TestCase):
|
||||
|
@ -47,14 +47,17 @@ class TestTextRecogDataPreprocessor(TestCase):
|
|||
def test_forward(self):
|
||||
processor = TextRecogDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
|
||||
|
||||
data = [{
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 11, 10)),
|
||||
'data_sample':
|
||||
TextRecogDataSample(
|
||||
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0))
|
||||
}]
|
||||
inputs, data_samples = processor(data)
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 11, 10)),
|
||||
],
|
||||
'data_samples': [
|
||||
TextRecogDataSample(
|
||||
metainfo=dict(img_shape=(11, 10), valid_ratio=1.0)),
|
||||
]
|
||||
}
|
||||
out = processor(data)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
print(inputs.dtype)
|
||||
self.assertEqual(inputs.shape, (1, 3, 11, 10))
|
||||
self.assertEqual(len(data_samples), 1)
|
||||
|
@ -62,43 +65,46 @@ class TestTextRecogDataPreprocessor(TestCase):
|
|||
# test channel_conversion
|
||||
processor = TextRecogDataPreprocessor(
|
||||
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
|
||||
inputs, data_samples = processor(data)
|
||||
out = processor(data)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
self.assertEqual(inputs.shape, (1, 3, 11, 10))
|
||||
self.assertEqual(len(data_samples), 1)
|
||||
|
||||
# test padding
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 10, 11))
|
||||
}, {
|
||||
'inputs': torch.randint(0, 256, (3, 9, 14))
|
||||
}]
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 10, 11)),
|
||||
torch.randint(0, 256, (3, 9, 14))
|
||||
]
|
||||
}
|
||||
processor = TextRecogDataPreprocessor(
|
||||
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
|
||||
inputs, data_samples = processor(data)
|
||||
out = processor(data)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
self.assertEqual(inputs.shape, (2, 3, 10, 14))
|
||||
self.assertIsNone(data_samples)
|
||||
|
||||
# test pad_size_divisor
|
||||
data = [{
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 10, 11)),
|
||||
'data_sample':
|
||||
TextRecogDataSample(
|
||||
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0))
|
||||
}, {
|
||||
'inputs':
|
||||
torch.randint(0, 256, (3, 9, 24)),
|
||||
'data_sample':
|
||||
TextRecogDataSample(
|
||||
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
|
||||
}]
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 10, 11)),
|
||||
torch.randint(0, 256, (3, 9, 24)),
|
||||
],
|
||||
'data_samples': [
|
||||
TextRecogDataSample(
|
||||
metainfo=dict(img_shape=(10, 11), valid_ratio=1.0)),
|
||||
TextRecogDataSample(
|
||||
metainfo=dict(img_shape=(9, 24), valid_ratio=1.0))
|
||||
]
|
||||
}
|
||||
aug_cfg = [dict(type='Augment')]
|
||||
processor = TextRecogDataPreprocessor(
|
||||
mean=[0., 0., 0.],
|
||||
std=[1., 1., 1.],
|
||||
pad_size_divisor=5,
|
||||
batch_augments=aug_cfg)
|
||||
inputs, data_samples = processor(data, training=True)
|
||||
out = processor(data)
|
||||
inputs, data_samples = out['inputs'], out['data_samples']
|
||||
self.assertEqual(inputs.shape, (2, 3, 10, 25))
|
||||
self.assertEqual(len(data_samples), 2)
|
||||
for data_sample, expected_shape, expected_ratio in zip(
|
||||
|
|
Loading…
Reference in New Issue