diff --git a/mmocr/models/common/losses/__init__.py b/mmocr/models/common/losses/__init__.py
index 1d6956b0..336d2ed8 100644
--- a/mmocr/models/common/losses/__init__.py
+++ b/mmocr/models/common/losses/__init__.py
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .bce_loss import MaskedBalancedBCELoss, MaskedBCELoss
+from .bce_loss import (MaskedBalancedBCELoss, MaskedBalancedBCEWithLogitsLoss,
+                       MaskedBCELoss, MaskedBCEWithLogitsLoss)
 from .ce_loss import CrossEntropyLoss
 from .dice_loss import MaskedDiceLoss, MaskedSquareDiceLoss
 from .l1_loss import MaskedSmoothL1Loss, SmoothL1Loss
 
 __all__ = [
-    'MaskedBalancedBCELoss', 'MaskedDiceLoss', 'MaskedSmoothL1Loss',
-    'MaskedSquareDiceLoss', 'MaskedBCELoss', 'SmoothL1Loss', 'CrossEntropyLoss'
+    'MaskedBalancedBCEWithLogitsLoss', 'MaskedDiceLoss', 'MaskedSmoothL1Loss',
+    'MaskedSquareDiceLoss', 'MaskedBCEWithLogitsLoss', 'SmoothL1Loss',
+    'CrossEntropyLoss', 'MaskedBalancedBCELoss', 'MaskedBCELoss'
 ]
diff --git a/mmocr/models/common/losses/bce_loss.py b/mmocr/models/common/losses/bce_loss.py
index 2b8c1b5f..df4ce140 100644
--- a/mmocr/models/common/losses/bce_loss.py
+++ b/mmocr/models/common/losses/bce_loss.py
@@ -8,8 +8,9 @@ from mmocr.registry import MODELS
 
 
 @MODELS.register_module()
-class MaskedBalancedBCELoss(nn.Module):
-    """Masked Balanced BCE loss.
+class MaskedBalancedBCEWithLogitsLoss(nn.Module):
+    """This loss combines a Sigmoid layer and a masked balanced BCE loss in
+    one single class. It's AMP-eligible.
 
     Args:
         reduction (str, optional): The method to reduce the loss.
@@ -37,7 +38,7 @@ class MaskedBalancedBCELoss(nn.Module):
         self.negative_ratio = negative_ratio
         self.reduction = reduction
         self.fallback_negative_num = fallback_negative_num
-        self.binary_cross_entropy = nn.BCELoss(reduction=reduction)
+        self.loss = nn.BCEWithLogitsLoss(reduction=reduction)
 
     def forward(self,
                 pred: torch.Tensor,
@@ -74,8 +75,7 @@ class MaskedBalancedBCELoss(nn.Module):
                 int(negative.sum()),
                 int(positive_count * self.negative_ratio))
         assert gt.max() <= 1 and gt.min() >= 0
-        assert pred.max() <= 1 and pred.min() >= 0
-        loss = self.binary_cross_entropy(pred, gt)
+        loss = self.loss(pred, gt)
 
         positive_loss = loss * positive
         negative_loss = loss * negative
@@ -88,11 +88,67 @@ class MaskedBalancedBCELoss(nn.Module):
 
 
 @MODELS.register_module()
-class MaskedBCELoss(nn.Module):
-    """Masked BCE loss.
+class MaskedBalancedBCELoss(MaskedBalancedBCEWithLogitsLoss):
+    """Masked Balanced BCE loss.
 
     Args:
-        eps (float, optional): Eps to avoid zero-division error. Defaults to
+        reduction (str, optional): The method to reduce the loss.
+            Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
+        negative_ratio (float or int): Maximum ratio of negative
+            samples to positive ones. Defaults to 3.
+        fallback_negative_num (int): When the mask contains no
+            positive samples, the number of negative samples to be sampled.
+            Defaults to 0.
+        eps (float): Eps to avoid zero-division error. Defaults to
+            1e-6.
+ """ + + def __init__(self, + reduction: str = 'none', + negative_ratio: Union[float, int] = 3, + fallback_negative_num: int = 0, + eps: float = 1e-6) -> None: + super().__init__() + assert reduction in ['none', 'mean', 'sum'] + assert isinstance(negative_ratio, (float, int)) + assert isinstance(fallback_negative_num, int) + assert isinstance(eps, float) + self.eps = eps + self.negative_ratio = negative_ratio + self.reduction = reduction + self.fallback_negative_num = fallback_negative_num + self.loss = nn.BCELoss(reduction=reduction) + + def forward(self, + pred: torch.Tensor, + gt: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward function. + + Args: + pred (torch.Tensor): The prediction in any shape. + gt (torch.Tensor): The learning target of the prediction in the + same shape as pred. + mask (torch.Tensor, optional): Binary mask in the same shape of + pred, indicating positive regions to calculate the loss. Whole + region will be taken into account if not provided. Defaults to + None. + + Returns: + torch.Tensor: The loss value. + """ + + assert pred.max() <= 1 and pred.min() >= 0 + return super().forward(pred, gt, mask) + + +@MODELS.register_module() +class MaskedBCEWithLogitsLoss(nn.Module): + """This loss combines a Sigmoid layers and a masked BCE loss in one single + class. It's AMP-eligible. + + Args: + eps (float): Eps to avoid zero-division error. Defaults to 1e-6. """ @@ -100,7 +156,7 @@ class MaskedBCELoss(nn.Module): super().__init__() assert isinstance(eps, float) self.eps = eps - self.binary_cross_entropy = nn.BCELoss(reduction='none') + self.loss = nn.BCEWithLogitsLoss(reduction='none') def forward(self, pred: torch.Tensor, @@ -127,7 +183,45 @@ class MaskedBCELoss(nn.Module): assert mask.size() == gt.size() assert gt.max() <= 1 and gt.min() >= 0 - assert pred.max() <= 1 and pred.min() >= 0 - loss = self.binary_cross_entropy(pred, gt) + loss = self.loss(pred, gt) return (loss * mask).sum() / (mask.sum() + self.eps) + + +@MODELS.register_module() +class MaskedBCELoss(MaskedBCEWithLogitsLoss): + """Masked BCE loss. + + Args: + eps (float): Eps to avoid zero-division error. Defaults to + 1e-6. + """ + + def __init__(self, eps: float = 1e-6) -> None: + super().__init__() + assert isinstance(eps, float) + self.eps = eps + self.loss = nn.BCELoss(reduction='none') + + def forward(self, + pred: torch.Tensor, + gt: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward function. + + Args: + pred (torch.Tensor): The prediction in any shape. + gt (torch.Tensor): The learning target of the prediction in the + same shape as pred. + mask (torch.Tensor, optional): Binary mask in the same shape of + pred, indicating positive regions to calculate the loss. Whole + region will be taken into account if not provided. Defaults to + None. + + Returns: + torch.Tensor: The loss value. 
+ """ + + assert pred.max() <= 1 and pred.min() >= 0 + + return super().forward(pred, gt, mask) diff --git a/mmocr/models/textdet/heads/base_textdet_head.py b/mmocr/models/textdet/heads/base_textdet_head.py index b4b94692..ffb6e846 100644 --- a/mmocr/models/textdet/heads/base_textdet_head.py +++ b/mmocr/models/textdet/heads/base_textdet_head.py @@ -6,9 +6,7 @@ from mmengine.model import BaseModule from torch import Tensor from mmocr.registry import MODELS -from mmocr.structures import TextDetDataSample - -SampleList = List[TextDetDataSample] +from mmocr.utils.typing import DetSampleList @MODELS.register_module() @@ -65,7 +63,8 @@ class BaseTextDetHead(BaseModule): self.module_loss = MODELS.build(module_loss) self.postprocessor = MODELS.build(postprocessor) - def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict: + def loss(self, x: Tuple[Tensor], + batch_data_samples: DetSampleList) -> dict: """Perform forward propagation and loss calculation of the detection head on the features of the upstream network. @@ -79,12 +78,13 @@ class BaseTextDetHead(BaseModule): Returns: dict: A dictionary of loss components. """ - outs = self(x) + outs = self(x, batch_data_samples) losses = self.module_loss(outs, batch_data_samples) return losses - def loss_and_predict(self, x: Tuple[Tensor], batch_data_samples: SampleList - ) -> Tuple[dict, SampleList]: + def loss_and_predict(self, x: Tuple[Tensor], + batch_data_samples: DetSampleList + ) -> Tuple[dict, DetSampleList]: """Perform forward propagation of the head, then calculate loss and predictions from the features and data samples. @@ -101,14 +101,14 @@ class BaseTextDetHead(BaseModule): - predictions (list[:obj:`InstanceData`]): Detection results of each image after the post process. """ - outs = self(x) + outs = self(x, batch_data_samples) losses = self.module_loss(outs, batch_data_samples) predictions = self.postprocessor(outs, batch_data_samples) return losses, predictions def predict(self, x: torch.Tensor, - batch_data_samples: SampleList) -> SampleList: + batch_data_samples: DetSampleList) -> DetSampleList: """Perform forward propagation of the detection head and predict detection results on the features of the upstream network. diff --git a/mmocr/models/textdet/heads/db_head.py b/mmocr/models/textdet/heads/db_head.py index 5014f9e0..c67c55a8 100644 --- a/mmocr/models/textdet/heads/db_head.py +++ b/mmocr/models/textdet/heads/db_head.py @@ -9,6 +9,7 @@ from torch import Tensor from mmocr.models.textdet.heads import BaseTextDetHead from mmocr.registry import MODELS from mmocr.structures import TextDetDataSample +from mmocr.utils.typing import DetSampleList @MODELS.register_module() @@ -53,8 +54,9 @@ class DBHead(BaseTextDetHead): nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True), nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2), nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True), - nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid()) + nn.ConvTranspose2d(in_channels // 4, 1, 2, 2)) self.threshold = self._init_thr(in_channels) + self.sigmoid = nn.Sigmoid() def _diff_binarize(self, prob_map: Tensor, thr_map: Tensor, k: int) -> Tensor: @@ -70,26 +72,6 @@ class DBHead(BaseTextDetHead): """ return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map))) - def forward( - self, - img: Tensor, - data_samples: Optional[List[TextDetDataSample]] = None - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - img (Tensor): Shape :math:`(N, C, H, W)`. 
-            data_samples (list[TextDetDataSample], optional): A list of data
-                samples. Defaults to None.
-
-        Returns:
-            tuple(Tensor, Tensor, Tensor): A tuple of ``prob_map``, ``thr_map``
-            and ``binary_map``, each of shape :math:`(N, 4H, 4W)`.
-        """
-        prob_map = self.binarize(img).squeeze(1)
-        thr_map = self.threshold(img).squeeze(1)
-        binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
-        return (prob_map, thr_map, binary_map)
-
     def _init_thr(self,
                   inner_channels: int,
                   bias: bool = False) -> nn.ModuleList:
@@ -103,3 +85,100 @@ class DBHead(BaseTextDetHead):
             nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
             nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
         return seq
+
+    def forward(self,
+                img: Tensor,
+                data_samples: Optional[List[TextDetDataSample]] = None,
+                mode: str = 'predict') -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        Args:
+            img (Tensor): Shape :math:`(N, C, H, W)`.
+            data_samples (list[TextDetDataSample], optional): A list of data
+                samples. Defaults to None.
+            mode (str): Forward mode. It affects the return values. Options are
+                "loss", "predict" and "both". Defaults to "predict".
+
+                - ``loss``: Run the full network and return the prob
+                  logits, threshold map and binary map.
+                - ``predict``: Run the binarization part and return the prob
+                  map only.
+                - ``both``: Run the full network and return prob logits,
+                  threshold map, binary map and prob map.
+
+        Returns:
+            Tensor or tuple(Tensor): Its type depends on ``mode``, read its
+            docstring for details. Each has the shape of
+            :math:`(N, 4H, 4W)`.
+        """
+        prob_logits = self.binarize(img).squeeze(1)
+        prob_map = self.sigmoid(prob_logits)
+        if mode == 'predict':
+            return prob_map
+        thr_map = self.threshold(img).squeeze(1)
+        binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
+        if mode == 'loss':
+            return prob_logits, thr_map, binary_map
+        return prob_logits, thr_map, binary_map, prob_map
+
+    def loss(self, x: Tuple[Tensor],
+             batch_data_samples: DetSampleList) -> Dict:
+        """Perform forward propagation and loss calculation of the detection
+        head on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Features from the upstream network, each is
+                a 4D-tensor.
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        outs = self(x, batch_data_samples, mode='loss')
+        losses = self.module_loss(outs, batch_data_samples)
+        return losses
+
+    def loss_and_predict(self, x: Tuple[Tensor],
+                         batch_data_samples: DetSampleList
+                         ) -> Tuple[dict, DetSampleList]:
+        """Perform forward propagation of the head, then calculate loss and
+        predictions from the features and data samples.
+
+        Args:
+            x (tuple[Tensor]): Features from FPN.
+            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
+                the meta information of each image and corresponding
+                annotations.
+
+        Returns:
+            tuple: the return value is a tuple containing:
+
+            - losses: (dict[str, Tensor]): A dictionary of loss components.
+            - predictions (list[:obj:`InstanceData`]): Detection
+              results of each image after the post process.
+ """ + outs = self(x, batch_data_samples, mode='both') + losses = self.module_loss(outs[:3], batch_data_samples) + predictions = self.postprocessor(outs[3], batch_data_samples) + return losses, predictions + + def predict(self, x: torch.Tensor, + batch_data_samples: DetSampleList) -> DetSampleList: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + SampleList: Detection results of each image + after the post process. + """ + outs = self(x, batch_data_samples, mode='predict') + predictions = self.postprocessor(outs, batch_data_samples) + return predictions diff --git a/mmocr/models/textdet/module_losses/db_module_loss.py b/mmocr/models/textdet/module_losses/db_module_loss.py index ee99e6c3..7d850b74 100644 --- a/mmocr/models/textdet/module_losses/db_module_loss.py +++ b/mmocr/models/textdet/module_losses/db_module_loss.py @@ -22,22 +22,26 @@ class DBModuleLoss(nn.Module, TextKernelMixin): This is partially adapted from https://github.com/MhLiao/DB. Args: - loss_prob (dict): The loss config for probability map. - loss_thr (dict): The loss config for threshold map. - loss_db (dict): The loss config for binary map. + loss_prob (dict): The loss config for probability map. Defaults to + dict(type='MaskedBalancedBCEWithLogitsLoss'). + loss_thr (dict): The loss config for threshold map. Defaults to + dict(type='MaskedSmoothL1Loss', beta=0). + loss_db (dict): The loss config for binary map. Defaults to + dict(type='MaskedDiceLoss'). weight_prob (float): The weight of probability map loss. - Denoted as :math:`\alpha` in paper. + Denoted as :math:`\alpha` in paper. Defaults to 5. weight_thr (float): The weight of threshold map loss. - Denoted as :math:`\beta` in paper. - shrink_ratio (float): The ratio of shrunk text region. - thr_min (float): The minimum threshold map value. - thr_max (float): The maximum threshold map value. + Denoted as :math:`\beta` in paper. Defaults to 10. + shrink_ratio (float): The ratio of shrunk text region. Defaults to 0.4. + thr_min (float): The minimum threshold map value. Defaults to 0.3. + thr_max (float): The maximum threshold map value. Defaults to 0.7. min_sidelength (int or float): The minimum sidelength of the - minimum rotated rectangle around any text region. + minimum rotated rectangle around any text region. Defaults to 8. """ def __init__(self, - loss_prob: Dict = dict(type='MaskedBalancedBCELoss'), + loss_prob: Dict = dict( + type='MaskedBalancedBCEWithLogitsLoss'), loss_thr: Dict = dict(type='MaskedSmoothL1Loss', beta=0), loss_db: Dict = dict(type='MaskedDiceLoss'), weight_prob: float = 5., @@ -63,22 +67,22 @@ class DBModuleLoss(nn.Module, TextKernelMixin): Args: preds (tuple(tensor)): Raw predictions from model, containing - ``prob_map``, ``thr_map`` and ``binary_map``. Each is a tensor - of shape :math:`(N, H, W)`. + ``prob_logits``, ``thr_map`` and ``binary_map``. + Each is a tensor of shape :math:`(N, H, W)`. data_samples (list[TextDetDataSample]): The data samples. Returns: results(dict): The dict for dbnet losses with loss_prob, \ loss_db and loss_thr. 
""" - prob_map, thr_map, binary_map = preds + prob_logits, thr_map, binary_map = preds gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = self.get_targets( data_samples) - gt_shrinks = gt_shrinks.to(prob_map.device) - gt_shrink_masks = gt_shrink_masks.to(prob_map.device) + gt_shrinks = gt_shrinks.to(prob_logits.device) + gt_shrink_masks = gt_shrink_masks.to(prob_logits.device) gt_thrs = gt_thrs.to(thr_map.device) gt_thr_masks = gt_thr_masks.to(thr_map.device) - loss_prob = self.loss_prob(prob_map, gt_shrinks, gt_shrink_masks) + loss_prob = self.loss_prob(prob_logits, gt_shrinks, gt_shrink_masks) loss_thr = self.loss_thr(thr_map, gt_thrs, gt_thr_masks) loss_db = self.loss_db(binary_map, gt_shrinks, gt_shrink_masks) diff --git a/mmocr/models/textdet/module_losses/drrg_module_loss.py b/mmocr/models/textdet/module_losses/drrg_module_loss.py index c36461b3..51923ef0 100644 --- a/mmocr/models/textdet/module_losses/drrg_module_loss.py +++ b/mmocr/models/textdet/module_losses/drrg_module_loss.py @@ -56,10 +56,10 @@ class DRRGModuleLoss(TextSnakeModuleLoss): jitter_level (float): The jitter level of text component geometric features. Defaults to 0.2. loss_text (dict): The loss config used to calculate the text loss. - Defaults to ``dict( - type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)``. + Defaults to ``dict(type='MaskedBalancedBCEWithLogitsLoss', + fallback_negative_num=100, eps=1e-5)``. loss_center (dict): The loss config used to calculate the center loss. - Defaults to ``dict(type='MaskedBCELoss')``. + Defaults to ``dict(type='MaskedBCEWithLogitsLoss')``. loss_top (dict): The loss config used to calculate the top loss, which is a part of the height loss. Defaults to ``dict(type='SmoothL1Loss', reduction='none')``. @@ -92,8 +92,10 @@ class DRRGModuleLoss(TextSnakeModuleLoss): max_rand_half_height: float = 24.0, jitter_level: float = 0.2, loss_text: Dict = dict( - type='MaskedBalancedBCELoss', fallback_negative_num=100, eps=1e-5), - loss_center: Dict = dict(type='MaskedBCELoss'), + type='MaskedBalancedBCEWithLogitsLoss', + fallback_negative_num=100, + eps=1e-5), + loss_center: Dict = dict(type='MaskedBCEWithLogitsLoss'), loss_top: Dict = dict(type='SmoothL1Loss', reduction='none'), loss_btm: Dict = dict(type='SmoothL1Loss', reduction='none'), loss_sin: Dict = dict(type='MaskedSmoothL1Loss'), @@ -185,16 +187,16 @@ class DRRGModuleLoss(TextSnakeModuleLoss): pred_sin_map = pred_sin_map * scale pred_cos_map = pred_cos_map * scale - loss_text = self.loss_text(pred_text_region.sigmoid(), - gt['gt_text_masks'], gt['gt_masks']) + loss_text = self.loss_text(pred_text_region, gt['gt_text_masks'], + gt['gt_masks']) text_mask = (gt['gt_text_masks'] * gt['gt_masks']).float() negative_text_mask = ((1 - gt['gt_text_masks']) * gt['gt_masks']).float() - loss_center_positive = self.loss_center(pred_center_region.sigmoid(), + loss_center_positive = self.loss_center(pred_center_region, gt['gt_center_region_masks'], text_mask) - loss_center_negative = self.loss_center(pred_center_region.sigmoid(), + loss_center_negative = self.loss_center(pred_center_region, gt['gt_center_region_masks'], negative_text_mask) loss_center = loss_center_positive + 0.5 * loss_center_negative diff --git a/mmocr/models/textdet/module_losses/textsnake_module_loss.py b/mmocr/models/textdet/module_losses/textsnake_module_loss.py index 75f12ba6..00faff50 100644 --- a/mmocr/models/textdet/module_losses/textsnake_module_loss.py +++ b/mmocr/models/textdet/module_losses/textsnake_module_loss.py @@ -47,8 +47,10 @@ class 
TextSnakeModuleLoss(nn.Module, TextKernelMixin): resample_step: float = 4.0, center_region_shrink_ratio: float = 0.3, loss_text: Dict = dict( - type='MaskedBalancedBCELoss', fallback_negative_num=100, eps=1e-5), - loss_center: Dict = dict(type='MaskedBCELoss'), + type='MaskedBalancedBCEWithLogitsLoss', + fallback_negative_num=100, + eps=1e-5), + loss_center: Dict = dict(type='MaskedBCEWithLogitsLoss'), loss_radius: Dict = dict(type='MaskedSmoothL1Loss'), loss_sin: Dict = dict(type='MaskedSmoothL1Loss'), loss_cos: Dict = dict(type='MaskedSmoothL1Loss') @@ -146,11 +148,11 @@ class TextSnakeModuleLoss(nn.Module, TextKernelMixin): pred_sin_map = pred_sin_map * scale pred_cos_map = pred_cos_map * scale - loss_text = self.loss_text(pred_text_region.sigmoid(), - gt['gt_text_masks'], gt['gt_masks']) + loss_text = self.loss_text(pred_text_region, gt['gt_text_masks'], + gt['gt_masks']) text_mask = (gt['gt_text_masks'] * gt['gt_masks']).float() - loss_center = self.loss_center(pred_center_region.sigmoid(), + loss_center = self.loss_center(pred_center_region, gt['gt_center_region_masks'], text_mask) center_mask = (gt['gt_center_region_masks'] * gt['gt_masks']).float() diff --git a/mmocr/models/textdet/postprocessors/base_postprocessor.py b/mmocr/models/textdet/postprocessors/base_postprocessor.py index f5f73465..706b1526 100644 --- a/mmocr/models/textdet/postprocessors/base_postprocessor.py +++ b/mmocr/models/textdet/postprocessors/base_postprocessor.py @@ -39,15 +39,15 @@ class BaseTextDetPostProcessor: self.test_cfg = test_cfg def __call__(self, - pred_results: dict, + pred_results: Union[Tensor, List[Tensor]], data_samples: Sequence[TextDetDataSample], training: bool = False) -> Sequence[TextDetDataSample]: """Postprocess pred_results according to metainfos in data_samples. Args: - pred_results (dict): The prediction results stored in a dictionary. - Usually each item to be post-processed is expected to be a - batched tensor. + pred_results (Union[Tensor, List[Tensor]]): The prediction results + stored in a tensor or a list of tensor. Usually each item to + be post-processed is expected to be a batched tensor. data_samples (list[TextDetDataSample]): Batch of data_samples, each corresponding to a prediction result. training (bool): Whether the model is in training mode. Defaults to @@ -65,13 +65,14 @@ class BaseTextDetPostProcessor: return results - def _process_single(self, pred_result: dict, + def _process_single(self, pred_result: Union[Tensor, List[Tensor]], data_sample: TextDetDataSample, **kwargs) -> TextDetDataSample: """Process prediction results from one image. Args: - pred_result (dict): Prediction results of an image. + pred_result (Union[Tensor, List[Tensor]]): Prediction results of an + image. data_sample (TextDetDataSample): Datasample of an image. """ @@ -103,13 +104,13 @@ class BaseTextDetPostProcessor: results.pred_instances[key], scale_factor, mode='div') return results - def get_text_instances(self, pred_results: dict, + def get_text_instances(self, pred_results: Union[Tensor, List[Tensor]], data_sample: TextDetDataSample, **kwargs) -> TextDetDataSample: """Get text instance predictions of one image. Args: - pred_result (dict): Prediction results of an image. + pred_result (tuple(Tensor)): Prediction results of an image. data_sample (TextDetDataSample): Datasample of an image. **kwargs: Other parameters. Configurable via ``__init__.train_cfg`` and ``__init__.test_cfg``. 
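Reviewer note (illustration only, not part of the patch): the behavioural contract introduced above is that the new ``*WithLogitsLoss`` classes consume raw logits, while the original classes still expect probabilities and now simply wrap the logits variants with ``nn.BCELoss``. A minimal sketch of that equivalence, assuming this branch of MMOCR is installed; the tensor shapes are arbitrary and only chosen to mimic a DBHead output:

```python
# Hypothetical sanity check: feeding logits to the new loss should match
# feeding sigmoid(logits) to the original probability-based loss.
import torch

from mmocr.models.common.losses import (MaskedBalancedBCELoss,
                                        MaskedBalancedBCEWithLogitsLoss)

logits = torch.randn(2, 160, 200)             # e.g. DBHead prob logits
gt = (torch.rand(2, 160, 200) > 0.5).float()  # binary targets
mask = torch.ones_like(gt)                    # whole image contributes

loss_from_logits = MaskedBalancedBCEWithLogitsLoss()(logits, gt, mask)
loss_from_probs = MaskedBalancedBCELoss()(logits.sigmoid(), gt, mask)
assert torch.allclose(loss_from_logits, loss_from_probs, atol=1e-5)
```

The same relationship is what ``test_loss_and_predict`` below checks for ``DBHead``: the prediction branch returns ``loss_results[0].sigmoid()``.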
diff --git a/mmocr/models/textdet/postprocessors/db_postprocessor.py b/mmocr/models/textdet/postprocessors/db_postprocessor.py
index ab245201..9cb8d3ae 100644
--- a/mmocr/models/textdet/postprocessors/db_postprocessor.py
+++ b/mmocr/models/textdet/postprocessors/db_postprocessor.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Sequence, Tuple
+from typing import Sequence
 
 import cv2
 import numpy as np
@@ -59,14 +59,14 @@ class DBPostprocessor(BaseTextDetPostProcessor):
         self.epsilon_ratio = epsilon_ratio
         self.max_candidates = max_candidates
 
-    def get_text_instances(self, pred_results: Tuple[Tensor, Tensor, Tensor],
+    def get_text_instances(self, prob_map: Tensor,
                            data_sample: TextDetDataSample
                            ) -> TextDetDataSample:
         """Get text instance predictions of one image.
 
         Args:
-            pred_result (tuple(Tensor)): A tuple of 3 tensors where the first
-                tensor is ``prob_map`` of shape :math:`(N, H, W)`.
+            pred_result (Tensor): DBNet's output ``prob_map`` of shape
+                :math:`(H, W)`.
             data_sample (TextDetDataSample): Datasample of an image.
 
         Returns:
@@ -80,7 +80,6 @@ class DBPostprocessor(BaseTextDetPostProcessor):
             data_sample.pred_instances.polygons = []
             data_sample.pred_instances.scores = []
 
-        prob_map = pred_results[0]
         text_mask = prob_map > self.mask_thr
         score_map = prob_map.data.cpu().numpy().astype(np.float32)
 
diff --git a/tests/test_models/test_common/test_losses/test_bce_loss.py b/tests/test_models/test_common/test_losses/test_bce_loss.py
index 5f10be07..0de420d2 100644
--- a/tests/test_models/test_common/test_losses/test_bce_loss.py
+++ b/tests/test_models/test_common/test_losses/test_bce_loss.py
@@ -3,7 +3,9 @@ from unittest import TestCase
 
 import torch
 
-from mmocr.models.common.losses import MaskedBalancedBCELoss, MaskedBCELoss
+from mmocr.models.common.losses import (MaskedBalancedBCELoss,
+                                        MaskedBalancedBCEWithLogitsLoss,
+                                        MaskedBCELoss, MaskedBCEWithLogitsLoss)
 
 
 class TestMaskedBalancedBCELoss(TestCase):
@@ -117,3 +119,117 @@ class TestMaskedBCELoss(TestCase):
         zero_mask = torch.FloatTensor([0, 0, 0, 0])
         self.assertAlmostEqual(
             self.bce_loss(self.pred, self.gt, zero_mask).item(), 0)
+
+
+class TestMaskedBalancedBCEWithLogitsLoss(TestCase):
+
+    def setUp(self) -> None:
+        self.loss = MaskedBalancedBCEWithLogitsLoss(negative_ratio=2)
+        self.pred = torch.FloatTensor([1.5, 1.5, 1.5, 1.5])
+        self.gt = torch.FloatTensor([1, 1, 1, 0])
+        self.mask = torch.BoolTensor([True, False, False, True])
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            MaskedBalancedBCEWithLogitsLoss(reduction='any')
+
+        with self.assertRaises(AssertionError):
+            MaskedBalancedBCEWithLogitsLoss(negative_ratio='a')
+
+        with self.assertRaises(AssertionError):
+            MaskedBalancedBCEWithLogitsLoss(eps='a')
+
+        with self.assertRaises(AssertionError):
+            MaskedBalancedBCEWithLogitsLoss(fallback_negative_num='a')
+
+    def test_forward(self):
+
+        # Shape mismatch between pred and gt
+        with self.assertRaises(AssertionError):
+            invalid_gt = torch.FloatTensor([0, 0, 0])
+            self.loss(self.pred, invalid_gt)
+
+        # Shape mismatch between pred and mask
+        with self.assertRaises(AssertionError):
+            invalid_mask = torch.BoolTensor([True, False, False])
+            self.loss(self.pred, self.gt, invalid_mask)
+
+        # Invalid gt
+        with self.assertRaises(AssertionError):
+            invalid_gt = torch.FloatTensor([2, 3, 4, 5])
+            self.loss(self.pred, invalid_gt, self.mask)
+
+        logit = torch.FloatTensor([1.5])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt).item(),
+            ((-torch.log(torch.sigmoid(logit)) * 3 -
+              torch.log(1 - torch.sigmoid(logit))) / 4).item(),
+            delta=0.0001)
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, self.mask).item(),
+            (-torch.log(torch.sigmoid(logit)) -
+             torch.log(1 - torch.sigmoid(logit))).item() / 2,
+            delta=0.0001)
+
+        # Test zero mask
+        zero_mask = torch.FloatTensor([0, 0, 0, 0])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, zero_mask).item(), 0)
+
+        # Test 0 < fallback_negative_num < negative numbers
+        all_neg_gt = torch.zeros((4, ))
+        self.fallback_bce_loss = MaskedBalancedBCEWithLogitsLoss(
+            fallback_negative_num=1)
+        self.assertAlmostEqual(
+            self.fallback_bce_loss(self.pred, all_neg_gt, self.mask).item(),
+            -torch.log(1 - torch.sigmoid(logit)).item(),
+            delta=0.001)
+        # Test fallback_negative_num > negative numbers
+        self.fallback_bce_loss = MaskedBalancedBCEWithLogitsLoss(
+            fallback_negative_num=5)
+        self.assertAlmostEqual(
+            self.fallback_bce_loss(self.pred, all_neg_gt, self.mask).item(),
+            -torch.log(1 - torch.sigmoid(logit)).item(),
+            delta=0.001)
+
+
+class TestMaskedBCEWithLogitsLoss(TestCase):
+
+    def setUp(self) -> None:
+        self.loss = MaskedBCEWithLogitsLoss()
+        self.pred = torch.FloatTensor([1.5, 1.5, 1.5, 1.5])
+        self.gt = torch.FloatTensor([1, 1, 1, 0])
+        self.mask = torch.BoolTensor([True, False, False, True])
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            MaskedBCEWithLogitsLoss(eps='a')
+
+    def test_forward(self):
+
+        # Shape mismatch between pred and gt
+        with self.assertRaises(AssertionError):
+            invalid_gt = torch.FloatTensor([0, 0, 0])
+            self.loss(self.pred, invalid_gt)
+
+        # Shape mismatch between pred and mask
+        with self.assertRaises(AssertionError):
+            invalid_mask = torch.BoolTensor([True, False, False])
+            self.loss(self.pred, self.gt, invalid_mask)
+
+        logit = torch.FloatTensor([1.5])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt).item(),
+            ((-torch.log(torch.sigmoid(logit)) * 3 -
+              torch.log(1 - torch.sigmoid(logit))) / 4).item(),
+            delta=0.0001)
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, self.mask).item(),
+            (-torch.log(torch.sigmoid(logit)) -
+             torch.log(1 - torch.sigmoid(logit))).item() / 2,
+            delta=0.0001)
+
+        # Test zero mask
+        zero_mask = torch.FloatTensor([0, 0, 0, 0])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, zero_mask).item(), 0)
diff --git a/tests/test_models/test_textdet/test_heads/test_db_head.py b/tests/test_models/test_textdet/test_heads/test_db_head.py
index 60d1dd2e..b0c8d424 100644
--- a/tests/test_models/test_textdet/test_heads/test_db_head.py
+++ b/tests/test_models/test_textdet/test_heads/test_db_head.py
@@ -4,10 +4,24 @@ from unittest import TestCase
 import torch
 
 from mmocr.models.textdet.heads import DBHead
+from mmocr.registry import MODELS
 
 
 class TestDBHead(TestCase):
 
+    # Used to replace the module loss and postprocessor
+    @MODELS.register_module(name='DBDummy')
+    class DummyModule:
+
+        def __call__(self, x, data_samples):
+            return x
+
+    def setUp(self) -> None:
+        self.db_head = DBHead(
+            in_channels=10,
+            module_loss=dict(type='DBDummy'),
+            postprocessor=dict(type='DBDummy'))
+
     def test_init(self):
         with self.assertRaises(AssertionError):
             DBHead(in_channels='test', with_bias=False)
@@ -16,9 +30,28 @@ class TestDBHead(TestCase):
             DBHead(in_channels=1, with_bias='Text')
 
     def test_forward(self):
-        db_head = DBHead(in_channels=10)
         data = torch.randn((2, 10, 40, 50))
-        results = db_head(data, None)
+        results = self.db_head(data, None)
         self.assertEqual(results[0].shape, (2, 160, 200))
         self.assertEqual(results[1].shape, (2, 160, 200))
         self.assertEqual(results[2].shape, (2, 160, 200))
+
+    def test_loss(self):
+        data = torch.randn((2, 10, 40, 50))
+        results = self.db_head.loss(data, None)
+        for i in range(3):
+            self.assertEqual(results[i].shape, (2, 160, 200))
+
+    def test_predict(self):
+        data = torch.randn((2, 10, 40, 50))
+        results = self.db_head.predict(data, None)
+        self.assertEqual(results.shape, (2, 160, 200))
+
+    def test_loss_and_predict(self):
+        data = torch.randn((2, 10, 40, 50))
+        loss_results, pred_results = self.db_head.loss_and_predict(data, None)
+        for i in range(3):
+            self.assertEqual(loss_results[i].shape, (2, 160, 200))
+        self.assertEqual(pred_results.shape, (2, 160, 200))
+        self.assertTrue(
+            torch.allclose(pred_results, loss_results[0].sigmoid()))
diff --git a/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py b/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py
index e832d6cd..abd0da0f 100644
--- a/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py
+++ b/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py
@@ -23,7 +23,7 @@ class TestDBPostProcessor(unittest.TestCase):
     def test_get_text_instances(self, text_repr_type):
 
         postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
-        pred_result = (torch.rand(4, 5), torch.rand(4, 5), torch.rand(4, 5))
+        pred_result = torch.rand(4, 5)
         data_sample = TextDetDataSample(
             metainfo=dict(scale_factor=(0.5, 1)),
             gt_instances=InstanceData(polygons=[
@@ -36,12 +36,11 @@ class TestDBPostProcessor(unittest.TestCase):
         self.assertTrue(
             isinstance(results.pred_instances['scores'], torch.FloatTensor))
 
-        preds = (torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0],
-                                    [0.8, 0.8, 0.8, 0.8, 0],
-                                    [0.8, 0.8, 0.8, 0.8, 0],
-                                    [0.8, 0.8, 0.8, 0.8, 0],
-                                    [0.8, 0.8, 0.8, 0.8, 0]]),
-                 torch.rand([1, 10]), torch.rand([1, 10]))
+        preds = torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0],
+                                   [0.8, 0.8, 0.8, 0.8, 0],
+                                   [0.8, 0.8, 0.8, 0.8, 0],
+                                   [0.8, 0.8, 0.8, 0.8, 0],
+                                   [0.8, 0.8, 0.8, 0.8, 0]])
         postprocessor = DBPostprocessor(
             text_repr_type=text_repr_type, min_text_width=0)
         results = postprocessor.get_text_instances(preds, data_sample)
@@ -49,8 +48,7 @@ class TestDBPostProcessor(unittest.TestCase):
 
         postprocessor = DBPostprocessor(
             min_text_score=1, text_repr_type=text_repr_type)
-        pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8,
-                       torch.rand(4, 5) * 0.8)
+        pred_result = torch.rand(4, 5) * 0.8
         results = postprocessor.get_text_instances(pred_result, data_sample)
         self.assertEqual(results.pred_instances.polygons, [])
         self.assertTrue(