From ad73fb10ff056a937b16faae575b070d7d7e0aaf Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Thu, 25 Aug 2022 14:49:45 +0800 Subject: [PATCH] [Enhancemnet] Add BaseTextDetModuleLoss (#1323) * [Enhancemnet] Add BaseTextDetModuleLoss * textkernelmixin->SegBasedModuleLoss * Update configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> --- .../models/textdet/module_losses/__init__.py | 4 +- mmocr/models/textdet/module_losses/base.py | 51 +++++++++++++++++++ .../textdet/module_losses/db_module_loss.py | 6 +-- .../textdet/module_losses/pan_module_loss.py | 4 +- ...rnel_mixin.py => seg_based_module_loss.py} | 6 ++- .../module_losses/textsnake_module_loss.py | 6 +-- 6 files changed, 65 insertions(+), 12 deletions(-) create mode 100644 mmocr/models/textdet/module_losses/base.py rename mmocr/models/textdet/module_losses/{text_kernel_mixin.py => seg_based_module_loss.py} (94%) diff --git a/mmocr/models/textdet/module_losses/__init__.py b/mmocr/models/textdet/module_losses/__init__.py index 1662f9cc..111c4799 100644 --- a/mmocr/models/textdet/module_losses/__init__.py +++ b/mmocr/models/textdet/module_losses/__init__.py @@ -4,10 +4,10 @@ from .drrg_module_loss import DRRGModuleLoss from .fce_module_loss import FCEModuleLoss from .pan_module_loss import PANModuleLoss from .pse_module_loss import PSEModuleLoss -from .text_kernel_mixin import TextKernelMixin +from .seg_based_module_loss import SegBasedModuleLoss from .textsnake_module_loss import TextSnakeModuleLoss __all__ = [ 'PANModuleLoss', 'PSEModuleLoss', 'DBModuleLoss', 'TextSnakeModuleLoss', - 'FCEModuleLoss', 'DRRGModuleLoss', 'TextKernelMixin' + 'FCEModuleLoss', 'DRRGModuleLoss', 'SegBasedModuleLoss' ] diff --git a/mmocr/models/textdet/module_losses/base.py b/mmocr/models/textdet/module_losses/base.py new file mode 100644 index 00000000..a884b69c --- /dev/null +++ b/mmocr/models/textdet/module_losses/base.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Dict, Sequence, Tuple, Union + +import torch +from torch import nn + +from mmocr.registry import MODELS +from mmocr.utils.typing import DetSampleList + +INPUT_TYPES = Union[torch.Tensor, Sequence[torch.Tensor], Dict] + + +@MODELS.register_module() +class BaseTextDetModuleLoss(nn.Module, metaclass=ABCMeta): + r"""Base class for text detection module loss. + """ + + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def forward(self, + inputs: INPUT_TYPES, + data_samples: DetSampleList = None) -> Dict: + """Calculates losses from a batch of inputs and data samples. Returns a + dict of losses. + + Args: + inputs (Tensor or list[Tensor] or dict): The raw tensor outputs + from the model. + data_samples (list(TextDetDataSample)): Datasamples containing + ground truth data. + + Returns: + dict: A dict of losses. + """ + pass + + @abstractmethod + def get_targets(self, data_samples: DetSampleList) -> Tuple: + """Generates loss targets from data samples. Returns a tuple of target + tensors. + + Args: + data_samples (list(TextDetDataSample)): Ground truth data samples. + + Returns: + tuple: A tuple of target tensors. + """ + pass diff --git a/mmocr/models/textdet/module_losses/db_module_loss.py b/mmocr/models/textdet/module_losses/db_module_loss.py index 7d850b74..f611e768 100644 --- a/mmocr/models/textdet/module_losses/db_module_loss.py +++ b/mmocr/models/textdet/module_losses/db_module_loss.py @@ -6,17 +6,17 @@ import numpy as np import torch from mmdet.models.utils import multi_apply from shapely.geometry import Polygon -from torch import Tensor, nn +from torch import Tensor from mmocr.registry import MODELS from mmocr.structures import TextDetDataSample from mmocr.utils import offset_polygon from mmocr.utils.typing import ArrayLike -from .text_kernel_mixin import TextKernelMixin +from .seg_based_module_loss import SegBasedModuleLoss @MODELS.register_module() -class DBModuleLoss(nn.Module, TextKernelMixin): +class DBModuleLoss(SegBasedModuleLoss): r"""The class for implementing DBNet loss. This is partially adapted from https://github.com/MhLiao/DB. diff --git a/mmocr/models/textdet/module_losses/pan_module_loss.py b/mmocr/models/textdet/module_losses/pan_module_loss.py index ebb50345..6a5a6685 100644 --- a/mmocr/models/textdet/module_losses/pan_module_loss.py +++ b/mmocr/models/textdet/module_losses/pan_module_loss.py @@ -10,11 +10,11 @@ from torch import nn from mmocr.registry import MODELS from mmocr.structures import TextDetDataSample -from .text_kernel_mixin import TextKernelMixin +from .seg_based_module_loss import SegBasedModuleLoss @MODELS.register_module() -class PANModuleLoss(nn.Module, TextKernelMixin): +class PANModuleLoss(SegBasedModuleLoss): """The class for implementing PANet loss. This was partially adapted from https://github.com/whai362/pan_pp.pytorch and https://github.com/WenmuZhou/PAN.pytorch. diff --git a/mmocr/models/textdet/module_losses/text_kernel_mixin.py b/mmocr/models/textdet/module_losses/seg_based_module_loss.py similarity index 94% rename from mmocr/models/textdet/module_losses/text_kernel_mixin.py rename to mmocr/models/textdet/module_losses/seg_based_module_loss.py index f3c95ab1..2f216692 100644 --- a/mmocr/models/textdet/module_losses/text_kernel_mixin.py +++ b/mmocr/models/textdet/module_losses/seg_based_module_loss.py @@ -9,10 +9,12 @@ from mmengine.logging import MMLogger from shapely.geometry import Polygon from mmocr.utils.polygon_utils import offset_polygon +from .base import BaseTextDetModuleLoss -class TextKernelMixin: - """Mixin class for text detection models that use text instance kernels.""" +class SegBasedModuleLoss(BaseTextDetModuleLoss): + """Base class for the module loss of segmentation-based text detection + algorithms with some handy utilities.""" def _generate_kernels( self, diff --git a/mmocr/models/textdet/module_losses/textsnake_module_loss.py b/mmocr/models/textdet/module_losses/textsnake_module_loss.py index 00faff50..651a7475 100644 --- a/mmocr/models/textdet/module_losses/textsnake_module_loss.py +++ b/mmocr/models/textdet/module_losses/textsnake_module_loss.py @@ -8,15 +8,15 @@ from mmcv.image import impad, imrescale from mmdet.models.utils import multi_apply from numpy import ndarray from numpy.linalg import norm -from torch import Tensor, nn +from torch import Tensor from mmocr.registry import MODELS from mmocr.structures import TextDetDataSample -from .text_kernel_mixin import TextKernelMixin +from .seg_based_module_loss import SegBasedModuleLoss @MODELS.register_module() -class TextSnakeModuleLoss(nn.Module, TextKernelMixin): +class TextSnakeModuleLoss(SegBasedModuleLoss): """The class for implementing TextSnake loss. This is partially adapted from https://github.com/princewang1994/TextSnake.pytorch.