mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] Add BaseTextDetModuleLoss (#1323)
* [Enhancement] Add BaseTextDetModuleLoss * TextKernelMixin -> SegBasedModuleLoss * Update configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> pull/1327/head
parent
a45716d20e
commit
ad73fb10ff
|
@ -4,10 +4,10 @@ from .drrg_module_loss import DRRGModuleLoss
|
|||
from .fce_module_loss import FCEModuleLoss
|
||||
from .pan_module_loss import PANModuleLoss
|
||||
from .pse_module_loss import PSEModuleLoss
|
||||
from .text_kernel_mixin import TextKernelMixin
|
||||
from .seg_based_module_loss import SegBasedModuleLoss
|
||||
from .textsnake_module_loss import TextSnakeModuleLoss
|
||||
|
||||
__all__ = [
|
||||
'PANModuleLoss', 'PSEModuleLoss', 'DBModuleLoss', 'TextSnakeModuleLoss',
|
||||
'FCEModuleLoss', 'DRRGModuleLoss', 'TextKernelMixin'
|
||||
'FCEModuleLoss', 'DRRGModuleLoss', 'SegBasedModuleLoss'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.utils.typing import DetSampleList
|
||||
|
||||
# Accepted forms of the raw model output handed to a module loss: a single
# tensor, a sequence of tensors, or a dict.
INPUT_TYPES = Union[torch.Tensor, Sequence[torch.Tensor], Dict]


@MODELS.register_module()
class BaseTextDetModuleLoss(nn.Module, metaclass=ABCMeta):
    """Base class for text detection module loss.

    Concrete subclasses must implement both :meth:`forward` and
    :meth:`get_targets`; the class cannot be instantiated otherwise because
    both are declared abstract.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self,
                inputs: INPUT_TYPES,
                data_samples: Union[DetSampleList, None] = None) -> Dict:
        """Calculates losses from a batch of inputs and data samples. Returns a
        dict of losses.

        Args:
            inputs (Tensor or list[Tensor] or dict): The raw tensor outputs
                from the model.
            data_samples (list[TextDetDataSample], optional): Datasamples
                containing ground truth data. Defaults to None.

        Returns:
            dict: A dict of losses.
        """

    @abstractmethod
    def get_targets(self, data_samples: DetSampleList) -> Tuple:
        """Generates loss targets from data samples. Returns a tuple of target
        tensors.

        Args:
            data_samples (list[TextDetDataSample]): Ground truth data samples.

        Returns:
            tuple: A tuple of target tensors.
        """
|
|
@ -6,17 +6,17 @@ import numpy as np
|
|||
import torch
|
||||
from mmdet.models.utils import multi_apply
|
||||
from shapely.geometry import Polygon
|
||||
from torch import Tensor, nn
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from mmocr.utils import offset_polygon
|
||||
from mmocr.utils.typing import ArrayLike
|
||||
from .text_kernel_mixin import TextKernelMixin
|
||||
from .seg_based_module_loss import SegBasedModuleLoss
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DBModuleLoss(nn.Module, TextKernelMixin):
|
||||
class DBModuleLoss(SegBasedModuleLoss):
|
||||
r"""The class for implementing DBNet loss.
|
||||
|
||||
This is partially adapted from https://github.com/MhLiao/DB.
|
||||
|
|
|
@ -10,11 +10,11 @@ from torch import nn
|
|||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from .text_kernel_mixin import TextKernelMixin
|
||||
from .seg_based_module_loss import SegBasedModuleLoss
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PANModuleLoss(nn.Module, TextKernelMixin):
|
||||
class PANModuleLoss(SegBasedModuleLoss):
|
||||
"""The class for implementing PANet loss. This was partially adapted from
|
||||
https://github.com/whai362/pan_pp.pytorch and
|
||||
https://github.com/WenmuZhou/PAN.pytorch.
|
||||
|
|
|
@ -9,10 +9,12 @@ from mmengine.logging import MMLogger
|
|||
from shapely.geometry import Polygon
|
||||
|
||||
from mmocr.utils.polygon_utils import offset_polygon
|
||||
from .base import BaseTextDetModuleLoss
|
||||
|
||||
|
||||
class TextKernelMixin:
|
||||
"""Mixin class for text detection models that use text instance kernels."""
|
||||
class SegBasedModuleLoss(BaseTextDetModuleLoss):
|
||||
"""Base class for the module loss of segmentation-based text detection
|
||||
algorithms with some handy utilities."""
|
||||
|
||||
def _generate_kernels(
|
||||
self,
|
|
@ -8,15 +8,15 @@ from mmcv.image import impad, imrescale
|
|||
from mmdet.models.utils import multi_apply
|
||||
from numpy import ndarray
|
||||
from numpy.linalg import norm
|
||||
from torch import Tensor, nn
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.structures import TextDetDataSample
|
||||
from .text_kernel_mixin import TextKernelMixin
|
||||
from .seg_based_module_loss import SegBasedModuleLoss
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class TextSnakeModuleLoss(nn.Module, TextKernelMixin):
|
||||
class TextSnakeModuleLoss(SegBasedModuleLoss):
|
||||
"""The class for implementing TextSnake loss. This is partially adapted
|
||||
from https://github.com/princewang1994/TextSnake.pytorch.
|
||||
|
||||
|
|
Loading…
Reference in New Issue