mirror of https://github.com/open-mmlab/mmocr.git
Add BaseTextDetHead
parent
9c3d741712
commit
26da038d49
.dev_scripts
mmocr/models/textdet/dense_heads
|
@ -5,3 +5,9 @@
|
|||
# .*/utils.py
|
||||
|
||||
.*/__init__.py
|
||||
|
||||
# will be deleted
|
||||
mmocr/models/textdet/dense_heads/head_mixin.py
|
||||
|
||||
# Will be covered by det head tests
|
||||
mmocr/models/textdet/dense_heads/base_textdet_head.py
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_textdet_head import BaseTextDetHead
|
||||
from .db_head import DBHead
|
||||
from .drrg_head import DRRGHead
|
||||
from .fce_head import FCEHead
|
||||
|
@ -9,5 +10,5 @@ from .textsnake_head import TextSnakeHead
|
|||
|
||||
__all__ = [
|
||||
'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'TextSnakeHead', 'DRRGHead',
|
||||
'HeadMixin'
|
||||
'HeadMixin', 'BaseTextDetHead'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BaseTextDetHead(BaseModule):
|
||||
"""Base head for text detection, build the loss and postprocessor.
|
||||
|
||||
Args:
|
||||
loss (dict): Config to build loss.
|
||||
postprocessor (dict): Config to build postprocessor.
|
||||
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss: Dict,
|
||||
postprocessor: Dict,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert isinstance(loss, dict)
|
||||
assert isinstance(postprocessor, dict)
|
||||
|
||||
self.loss = MODELS.build(loss)
|
||||
self.postprocessor = MODELS.build(postprocessor)
|
|
@ -5,6 +5,7 @@ from mmocr.registry import MODELS
|
|||
from mmocr.utils import check_argument
|
||||
|
||||
|
||||
# TODO: del this
|
||||
@MODELS.register_module()
|
||||
class HeadMixin:
|
||||
"""Base head class for text detection, including loss calcalation and
|
||||
|
|
Loading…
Reference in New Issue