mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Utils] Add typing
This commit is contained in:
parent
41a642bc7b
commit
68b0aaa2e9
@ -1,25 +1,12 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmengine.config import ConfigDict
|
|
||||||
from mmengine.model.base_model import BaseModel
|
from mmengine.model.base_model import BaseModel
|
||||||
|
|
||||||
from mmocr.data import TextRecogDataSample
|
from mmocr.utils import (OptConfigType, OptMultiConfig, OptRecSampleList,
|
||||||
|
RecForwardResults, RecSampleList)
|
||||||
# TODO Move to type hint file
|
|
||||||
# Type hint of config data
|
|
||||||
ConfigType = Union[ConfigDict, dict]
|
|
||||||
OptConfigType = Optional[ConfigType]
|
|
||||||
# Type hint of one or more config data
|
|
||||||
MultiConfig = Union[ConfigType, List[ConfigType]]
|
|
||||||
OptMultiConfig = Optional[MultiConfig]
|
|
||||||
|
|
||||||
ForwardResults = Union[Dict[str, torch.Tensor], List[TextRecogDataSample],
|
|
||||||
Tuple[torch.Tensor], torch.Tensor]
|
|
||||||
SampleList = List[TextRecogDataSample]
|
|
||||||
OptSampleList = Optional[SampleList]
|
|
||||||
|
|
||||||
|
|
||||||
class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
||||||
@ -34,7 +21,7 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data_preprocessor: Optional[Union[ConfigDict, dict]] = None,
|
data_preprocessor: OptConfigType = None,
|
||||||
init_cfg: OptMultiConfig = None):
|
init_cfg: OptMultiConfig = None):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||||
@ -66,9 +53,9 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
|||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
batch_inputs: torch.Tensor,
|
batch_inputs: torch.Tensor,
|
||||||
batch_data_samples: OptSampleList = None,
|
batch_data_samples: OptRecSampleList = None,
|
||||||
mode: str = 'tensor',
|
mode: str = 'tensor',
|
||||||
**kwargs) -> ForwardResults:
|
**kwargs) -> RecForwardResults:
|
||||||
"""The unified entry for a forward process in both training and test.
|
"""The unified entry for a forward process in both training and test.
|
||||||
|
|
||||||
The method should accept three modes: "tensor", "predict" and "loss":
|
The method should accept three modes: "tensor", "predict" and "loss":
|
||||||
@ -108,14 +95,15 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
|||||||
'Only supports loss, predict and tensor mode')
|
'Only supports loss, predict and tensor mode')
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def loss(self, batch_inputs: torch.Tensor, batch_data_samples: SampleList,
|
def loss(self, batch_inputs: torch.Tensor,
|
||||||
|
batch_data_samples: RecSampleList,
|
||||||
**kwargs) -> Union[dict, tuple]:
|
**kwargs) -> Union[dict, tuple]:
|
||||||
"""Calculate losses from a batch of inputs and data samples."""
|
"""Calculate losses from a batch of inputs and data samples."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(self, batch_inputs: torch.Tensor,
|
def predict(self, batch_inputs: torch.Tensor,
|
||||||
batch_data_samples: SampleList, **kwargs) -> SampleList:
|
batch_data_samples: RecSampleList, **kwargs) -> RecSampleList:
|
||||||
"""Predict results from a batch of inputs and data samples with post-
|
"""Predict results from a batch of inputs and data samples with post-
|
||||||
processing."""
|
processing."""
|
||||||
pass
|
pass
|
||||||
@ -123,7 +111,7 @@ class BaseRecognizer(BaseModel, metaclass=ABCMeta):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _forward(self,
|
def _forward(self,
|
||||||
batch_inputs: torch.Tensor,
|
batch_inputs: torch.Tensor,
|
||||||
batch_data_samples: OptSampleList = None,
|
batch_data_samples: OptRecSampleList = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Network forward process.
|
"""Network forward process.
|
||||||
|
|
||||||
|
@ -24,19 +24,30 @@ from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect,
|
|||||||
rescale_polygons, shapely2poly)
|
rescale_polygons, shapely2poly)
|
||||||
from .setup_env import register_all_modules
|
from .setup_env import register_all_modules
|
||||||
from .string_util import StringStrip
|
from .string_util import StringStrip
|
||||||
|
from .typing import (ColorType, ConfigType, DetSampleList, InitConfigType,
|
||||||
|
KIESampleList, MultiConfig, OptConfigType,
|
||||||
|
OptDetSampleList, OptInitConfigType, OptKIESampleList,
|
||||||
|
OptMultiConfig, OptRecSampleList, OptTensor,
|
||||||
|
RecForwardResults, RecSampleList)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
|
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
|
||||||
'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist',
|
'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist',
|
||||||
'valid_boundary', 'list_to_file', 'list_from_file', 'is_on_same_line',
|
'valid_boundary', 'list_to_file', 'list_from_file', 'is_on_same_line',
|
||||||
'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
|
'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
|
||||||
'bezier2polygon', 'sort_points', 'dump_ocr_data', 'recog_anno_to_imginfo',
|
'bezier_to_polygon', 'sort_points', 'dump_ocr_data',
|
||||||
'rescale_polygons', 'rescale_polygon', 'rescale_bboxes', 'bbox2poly',
|
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
|
||||||
'crop_polygon', 'is_poly_inside_rect', 'poly2bbox', 'poly_intersection',
|
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect',
|
||||||
'poly_iou', 'poly_make_valid', 'poly_union', 'poly2shapely',
|
'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
|
||||||
'polys2shapely', 'register_all_modules', 'offset_polygon', 'sort_vertex8',
|
'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules',
|
||||||
'sort_vertex', 'bbox_center_distance', 'bbox_diag_distance',
|
'offset_polygon', 'sort_vertex8', 'sort_vertex', 'bbox_center_distance',
|
||||||
'compute_hmean', 'boundary_iou', 'point_distance', 'points_center',
|
'bbox_diag', 'compute_hmean', 'filter_2dlist_result',
|
||||||
'fill_hole', 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img',
|
'many2one_match_ic13', 'one2one_match_ic13', 'select_top_boundary',
|
||||||
'warp_img'
|
'boundary_iou', 'point_distance', 'points_center', 'fill_hole',
|
||||||
|
'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img',
|
||||||
|
'ConfigType', 'DetSampleList', 'RecForwardResults', 'InitConfigType',
|
||||||
|
'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', 'OptMultiConfig',
|
||||||
|
'OptRecSampleList', 'RecSampleList', 'MultiConfig', 'OptTensor',
|
||||||
|
'ColorType', 'OptKIESampleList', 'KIESampleList', 'bbox_diag_distance',
|
||||||
|
'bezier2polygon'
|
||||||
]
|
]
|
||||||
|
33
mmocr/utils/typing.py
Normal file
33
mmocr/utils/typing.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
"""Collecting some commonly used type hint in MMOCR."""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from mmengine.config import ConfigDict
|
||||||
|
|
||||||
|
from mmocr.data import KIEDataSample, TextDetDataSample, TextRecogDataSample
|
||||||
|
|
||||||
|
# Config
|
||||||
|
ConfigType = Union[ConfigDict, Dict]
|
||||||
|
OptConfigType = Optional[ConfigType]
|
||||||
|
MultiConfig = Union[ConfigType, List[ConfigType]]
|
||||||
|
OptMultiConfig = Optional[MultiConfig]
|
||||||
|
InitConfigType = Union[Dict, List[Dict]]
|
||||||
|
OptInitConfigType = Optional[InitConfigType]
|
||||||
|
|
||||||
|
# Data
|
||||||
|
RecSampleList = List[TextRecogDataSample]
|
||||||
|
DetSampleList = List[TextDetDataSample]
|
||||||
|
KIESampleList = List[KIEDataSample]
|
||||||
|
OptRecSampleList = Optional[RecSampleList]
|
||||||
|
OptDetSampleList = Optional[DetSampleList]
|
||||||
|
OptKIESampleList = Optional[KIESampleList]
|
||||||
|
|
||||||
|
OptTensor = Optional[torch.Tensor]
|
||||||
|
|
||||||
|
RecForwardResults = Union[Dict[str, torch.Tensor], List[TextRecogDataSample],
|
||||||
|
Tuple[torch.Tensor], torch.Tensor]
|
||||||
|
|
||||||
|
# Visualization
|
||||||
|
ColorType = Union[str, Tuple, List[str], List[Tuple]]
|
Loading…
x
Reference in New Issue
Block a user