[Enhancement] Add BaseTextDetModuleLoss (#1323)

* [Enhancement] Add BaseTextDetModuleLoss

* textkernelmixin->SegBasedModuleLoss

* Update configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py

Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>

Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
pull/1327/head
Tong Gao 2022-08-25 14:49:45 +08:00 committed by GitHub
parent a45716d20e
commit ad73fb10ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 65 additions and 12 deletions

View File

@ -4,10 +4,10 @@ from .drrg_module_loss import DRRGModuleLoss
from .fce_module_loss import FCEModuleLoss
from .pan_module_loss import PANModuleLoss
from .pse_module_loss import PSEModuleLoss
from .text_kernel_mixin import TextKernelMixin
from .seg_based_module_loss import SegBasedModuleLoss
from .textsnake_module_loss import TextSnakeModuleLoss
__all__ = [
'PANModuleLoss', 'PSEModuleLoss', 'DBModuleLoss', 'TextSnakeModuleLoss',
'FCEModuleLoss', 'DRRGModuleLoss', 'TextKernelMixin'
'FCEModuleLoss', 'DRRGModuleLoss', 'SegBasedModuleLoss'
]

View File

@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, Sequence, Tuple, Union
import torch
from torch import nn
from mmocr.registry import MODELS
from mmocr.utils.typing import DetSampleList
# Type alias used to annotate the ``inputs`` argument of module losses: a
# single tensor, a sequence of tensors, or a dict.
INPUT_TYPES = Union[torch.Tensor, Sequence[torch.Tensor], Dict]
@MODELS.register_module()
class BaseTextDetModuleLoss(nn.Module, metaclass=ABCMeta):
    """Abstract interface shared by text detection module losses.

    Both :meth:`forward` and :meth:`get_targets` are abstract and have to be
    implemented by concrete subclasses.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(
        self,
        inputs: INPUT_TYPES,
        data_samples: DetSampleList = None,
    ) -> Dict:
        """Compute the losses for a batch of inputs and data samples.

        Args:
            inputs (Tensor or list[Tensor] or dict): The raw tensor outputs
                from the model.
            data_samples (list[TextDetDataSample]): Data samples that carry
                the ground truth data.

        Returns:
            dict: A dict of losses.
        """

    @abstractmethod
    def get_targets(self, data_samples: DetSampleList) -> Tuple:
        """Build the loss targets from data samples.

        Args:
            data_samples (list[TextDetDataSample]): Ground truth data
                samples.

        Returns:
            tuple: A tuple of target tensors.
        """

View File

@ -6,17 +6,17 @@ import numpy as np
import torch
from mmdet.models.utils import multi_apply
from shapely.geometry import Polygon
from torch import Tensor, nn
from torch import Tensor
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import offset_polygon
from mmocr.utils.typing import ArrayLike
from .text_kernel_mixin import TextKernelMixin
from .seg_based_module_loss import SegBasedModuleLoss
@MODELS.register_module()
class DBModuleLoss(nn.Module, TextKernelMixin):
class DBModuleLoss(SegBasedModuleLoss):
r"""The class for implementing DBNet loss.
This is partially adapted from https://github.com/MhLiao/DB.

View File

@ -10,11 +10,11 @@ from torch import nn
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .text_kernel_mixin import TextKernelMixin
from .seg_based_module_loss import SegBasedModuleLoss
@MODELS.register_module()
class PANModuleLoss(nn.Module, TextKernelMixin):
class PANModuleLoss(SegBasedModuleLoss):
"""The class for implementing PANet loss. This was partially adapted from
https://github.com/whai362/pan_pp.pytorch and
https://github.com/WenmuZhou/PAN.pytorch.

View File

@ -9,10 +9,12 @@ from mmengine.logging import MMLogger
from shapely.geometry import Polygon
from mmocr.utils.polygon_utils import offset_polygon
from .base import BaseTextDetModuleLoss
class TextKernelMixin:
"""Mixin class for text detection models that use text instance kernels."""
class SegBasedModuleLoss(BaseTextDetModuleLoss):
"""Base class for the module loss of segmentation-based text detection
algorithms with some handy utilities."""
def _generate_kernels(
self,

View File

@ -8,15 +8,15 @@ from mmcv.image import impad, imrescale
from mmdet.models.utils import multi_apply
from numpy import ndarray
from numpy.linalg import norm
from torch import Tensor, nn
from torch import Tensor
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .text_kernel_mixin import TextKernelMixin
from .seg_based_module_loss import SegBasedModuleLoss
@MODELS.register_module()
class TextSnakeModuleLoss(nn.Module, TextKernelMixin):
class TextSnakeModuleLoss(SegBasedModuleLoss):
"""The class for implementing TextSnake loss. This is partially adapted
from https://github.com/princewang1994/TextSnake.pytorch.