[Utils] Add typing

This commit is contained in:
wangxinyu 2022-07-13 12:04:06 +00:00 committed by gaotongxiao
parent 41a642bc7b
commit 68b0aaa2e9
3 changed files with 63 additions and 31 deletions

View File

@ -1,25 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from typing import Union
import torch
from mmengine.config import ConfigDict
from mmengine.model.base_model import BaseModel
from mmocr.data import TextRecogDataSample
# TODO Move to type hint file
# Type hint of config data
ConfigType = Union[ConfigDict, dict]
OptConfigType = Optional[ConfigType]
# Type hint of one or more config data
MultiConfig = Union[ConfigType, List[ConfigType]]
OptMultiConfig = Optional[MultiConfig]
ForwardResults = Union[Dict[str, torch.Tensor], List[TextRecogDataSample],
Tuple[torch.Tensor], torch.Tensor]
SampleList = List[TextRecogDataSample]
OptSampleList = Optional[SampleList]
from mmocr.utils import (OptConfigType, OptMultiConfig, OptRecSampleList,
RecForwardResults, RecSampleList)
class BaseRecognizer(BaseModel, metaclass=ABCMeta):
@ -34,7 +21,7 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
"""
def __init__(self,
data_preprocessor: Optional[Union[ConfigDict, dict]] = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
@ -66,9 +53,9 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
def forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: OptSampleList = None,
batch_data_samples: OptRecSampleList = None,
mode: str = 'tensor',
**kwargs) -> ForwardResults:
**kwargs) -> RecForwardResults:
"""The unified entry for a forward process in both training and test.
The method should accept three modes: "tensor", "predict" and "loss":
@ -108,14 +95,15 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
'Only supports loss, predict and tensor mode')
@abstractmethod
def loss(self, batch_inputs: torch.Tensor, batch_data_samples: SampleList,
def loss(self, batch_inputs: torch.Tensor,
batch_data_samples: RecSampleList,
**kwargs) -> Union[dict, tuple]:
"""Calculate losses from a batch of inputs and data samples."""
pass
@abstractmethod
def predict(self, batch_inputs: torch.Tensor,
batch_data_samples: SampleList, **kwargs) -> SampleList:
batch_data_samples: RecSampleList, **kwargs) -> RecSampleList:
"""Predict results from a batch of inputs and data samples with post-
processing."""
pass
@ -123,7 +111,7 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
@abstractmethod
def _forward(self,
batch_inputs: torch.Tensor,
batch_data_samples: OptSampleList = None,
batch_data_samples: OptRecSampleList = None,
**kwargs):
"""Network forward process.

View File

@ -24,19 +24,30 @@ from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect,
rescale_polygons, shapely2poly)
from .setup_env import register_all_modules
from .string_util import StringStrip
from .typing import (ColorType, ConfigType, DetSampleList, InitConfigType,
KIESampleList, MultiConfig, OptConfigType,
OptDetSampleList, OptInitConfigType, OptKIESampleList,
OptMultiConfig, OptRecSampleList, OptTensor,
RecForwardResults, RecSampleList)
__all__ = [
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist',
'valid_boundary', 'list_to_file', 'list_from_file', 'is_on_same_line',
'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
'bezier2polygon', 'sort_points', 'dump_ocr_data', 'recog_anno_to_imginfo',
'rescale_polygons', 'rescale_polygon', 'rescale_bboxes', 'bbox2poly',
'crop_polygon', 'is_poly_inside_rect', 'poly2bbox', 'poly_intersection',
'poly_iou', 'poly_make_valid', 'poly_union', 'poly2shapely',
'polys2shapely', 'register_all_modules', 'offset_polygon', 'sort_vertex8',
'sort_vertex', 'bbox_center_distance', 'bbox_diag_distance',
'compute_hmean', 'boundary_iou', 'point_distance', 'points_center',
'fill_hole', 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img',
'warp_img'
'bezier_to_polygon', 'sort_points', 'dump_ocr_data',
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect',
'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules',
'offset_polygon', 'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
'bbox_diag', 'compute_hmean', 'filter_2dlist_result',
'many2one_match_ic13', 'one2one_match_ic13', 'select_top_boundary',
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img',
'ConfigType', 'DetSampleList', 'RecForwardResults', 'InitConfigType',
'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', 'OptMultiConfig',
'OptRecSampleList', 'RecSampleList', 'MultiConfig', 'OptTensor',
'ColorType', 'OptKIESampleList', 'KIESampleList', 'bbox_diag_distance',
'bezier2polygon'
]

33
mmocr/utils/typing.py Normal file
View File

@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Collecting some commonly used type hint in MMOCR."""
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmengine.config import ConfigDict
from mmocr.data import KIEDataSample, TextDetDataSample, TextRecogDataSample
# Config
ConfigType = Union[ConfigDict, Dict]
OptConfigType = Optional[ConfigType]
MultiConfig = Union[ConfigType, List[ConfigType]]
OptMultiConfig = Optional[MultiConfig]
InitConfigType = Union[Dict, List[Dict]]
OptInitConfigType = Optional[InitConfigType]
# Data
RecSampleList = List[TextRecogDataSample]
DetSampleList = List[TextDetDataSample]
KIESampleList = List[KIEDataSample]
OptRecSampleList = Optional[RecSampleList]
OptDetSampleList = Optional[DetSampleList]
OptKIESampleList = Optional[KIESampleList]
OptTensor = Optional[torch.Tensor]
RecForwardResults = Union[Dict[str, torch.Tensor], List[TextRecogDataSample],
Tuple[torch.Tensor], torch.Tensor]
# Visualization
ColorType = Union[str, Tuple, List[str], List[Tuple]]