mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Utils] Add typing
This commit is contained in:
parent
41a642bc7b
commit
68b0aaa2e9
@ -1,25 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from mmengine.config import ConfigDict
|
||||
from mmengine.model.base_model import BaseModel
|
||||
|
||||
from mmocr.data import TextRecogDataSample
|
||||
|
||||
# TODO Move to type hint file
|
||||
# Type hint of config data
|
||||
ConfigType = Union[ConfigDict, dict]
|
||||
OptConfigType = Optional[ConfigType]
|
||||
# Type hint of one or more config data
|
||||
MultiConfig = Union[ConfigType, List[ConfigType]]
|
||||
OptMultiConfig = Optional[MultiConfig]
|
||||
|
||||
ForwardResults = Union[Dict[str, torch.Tensor], List[TextRecogDataSample],
|
||||
Tuple[torch.Tensor], torch.Tensor]
|
||||
SampleList = List[TextRecogDataSample]
|
||||
OptSampleList = Optional[SampleList]
|
||||
from mmocr.utils import (OptConfigType, OptMultiConfig, OptRecSampleList,
|
||||
RecForwardResults, RecSampleList)
|
||||
|
||||
|
||||
class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
||||
@ -34,7 +21,7 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_preprocessor: Optional[Union[ConfigDict, dict]] = None,
|
||||
data_preprocessor: OptConfigType = None,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
super().__init__(
|
||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||
@ -66,9 +53,9 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
||||
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: OptSampleList = None,
|
||||
batch_data_samples: OptRecSampleList = None,
|
||||
mode: str = 'tensor',
|
||||
**kwargs) -> ForwardResults:
|
||||
**kwargs) -> RecForwardResults:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
The method should accept three modes: "tensor", "predict" and "loss":
|
||||
@ -108,14 +95,15 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
||||
'Only supports loss, predict and tensor mode')
|
||||
|
||||
@abstractmethod
|
||||
def loss(self, batch_inputs: torch.Tensor, batch_data_samples: SampleList,
|
||||
def loss(self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: RecSampleList,
|
||||
**kwargs) -> Union[dict, tuple]:
|
||||
"""Calculate losses from a batch of inputs and data samples."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, batch_inputs: torch.Tensor,
|
||||
batch_data_samples: SampleList, **kwargs) -> SampleList:
|
||||
batch_data_samples: RecSampleList, **kwargs) -> RecSampleList:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing."""
|
||||
pass
|
||||
@ -123,7 +111,7 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def _forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
batch_data_samples: OptSampleList = None,
|
||||
batch_data_samples: OptRecSampleList = None,
|
||||
**kwargs):
|
||||
"""Network forward process.
|
||||
|
||||
|
@ -24,19 +24,30 @@ from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect,
|
||||
rescale_polygons, shapely2poly)
|
||||
from .setup_env import register_all_modules
|
||||
from .string_util import StringStrip
|
||||
from .typing import (ColorType, ConfigType, DetSampleList, InitConfigType,
|
||||
KIESampleList, MultiConfig, OptConfigType,
|
||||
OptDetSampleList, OptInitConfigType, OptKIESampleList,
|
||||
OptMultiConfig, OptRecSampleList, OptTensor,
|
||||
RecForwardResults, RecSampleList)
|
||||
|
||||
__all__ = [
|
||||
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
|
||||
'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist',
|
||||
'valid_boundary', 'list_to_file', 'list_from_file', 'is_on_same_line',
|
||||
'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
|
||||
'bezier2polygon', 'sort_points', 'dump_ocr_data', 'recog_anno_to_imginfo',
|
||||
'rescale_polygons', 'rescale_polygon', 'rescale_bboxes', 'bbox2poly',
|
||||
'crop_polygon', 'is_poly_inside_rect', 'poly2bbox', 'poly_intersection',
|
||||
'poly_iou', 'poly_make_valid', 'poly_union', 'poly2shapely',
|
||||
'polys2shapely', 'register_all_modules', 'offset_polygon', 'sort_vertex8',
|
||||
'sort_vertex', 'bbox_center_distance', 'bbox_diag_distance',
|
||||
'compute_hmean', 'boundary_iou', 'point_distance', 'points_center',
|
||||
'fill_hole', 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img',
|
||||
'warp_img'
|
||||
'bezier_to_polygon', 'sort_points', 'dump_ocr_data',
|
||||
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
|
||||
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect',
|
||||
'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
|
||||
'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules',
|
||||
'offset_polygon', 'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
|
||||
'bbox_diag', 'compute_hmean', 'filter_2dlist_result',
|
||||
'many2one_match_ic13', 'one2one_match_ic13', 'select_top_boundary',
|
||||
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
|
||||
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img',
|
||||
'ConfigType', 'DetSampleList', 'RecForwardResults', 'InitConfigType',
|
||||
'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', 'OptMultiConfig',
|
||||
'OptRecSampleList', 'RecSampleList', 'MultiConfig', 'OptTensor',
|
||||
'ColorType', 'OptKIESampleList', 'KIESampleList', 'bbox_diag_distance',
|
||||
'bezier2polygon'
|
||||
]
|
||||
|
33
mmocr/utils/typing.py
Normal file
33
mmocr/utils/typing.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Collecting some commonly used type hint in MMOCR."""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mmengine.config import ConfigDict
|
||||
|
||||
from mmocr.data import KIEDataSample, TextDetDataSample, TextRecogDataSample
|
||||
|
||||
# Config
|
||||
ConfigType = Union[ConfigDict, Dict]
|
||||
OptConfigType = Optional[ConfigType]
|
||||
MultiConfig = Union[ConfigType, List[ConfigType]]
|
||||
OptMultiConfig = Optional[MultiConfig]
|
||||
InitConfigType = Union[Dict, List[Dict]]
|
||||
OptInitConfigType = Optional[InitConfigType]
|
||||
|
||||
# Data
|
||||
RecSampleList = List[TextRecogDataSample]
|
||||
DetSampleList = List[TextDetDataSample]
|
||||
KIESampleList = List[KIEDataSample]
|
||||
OptRecSampleList = Optional[RecSampleList]
|
||||
OptDetSampleList = Optional[DetSampleList]
|
||||
OptKIESampleList = Optional[KIESampleList]
|
||||
|
||||
OptTensor = Optional[torch.Tensor]
|
||||
|
||||
RecForwardResults = Union[Dict[str, torch.Tensor], List[TextRecogDataSample],
|
||||
Tuple[torch.Tensor], torch.Tensor]
|
||||
|
||||
# Visualization
|
||||
ColorType = Union[str, Tuple, List[str], List[Tuple]]
|
Loading…
x
Reference in New Issue
Block a user