[Feature] Use BCEWithLogitsLoss as many as possible for AMP (#1309)

* Use BCEWithLogitsLoss as many as possible for AMP

* fix

* Optimize DBNet

* fix docstr

* Use branch in dbhead, fix missing data_samples in textdethead

* fix
pull/1315/head
Tong Gao 2022-08-23 19:17:27 +08:00 committed by GitHub
parent 5c8c774aa9
commit 1860a3a3b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 429 additions and 99 deletions

View File

@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bce_loss import MaskedBalancedBCELoss, MaskedBCELoss
from .bce_loss import (MaskedBalancedBCELoss, MaskedBalancedBCEWithLogitsLoss,
MaskedBCELoss, MaskedBCEWithLogitsLoss)
from .ce_loss import CrossEntropyLoss
from .dice_loss import MaskedDiceLoss, MaskedSquareDiceLoss
from .l1_loss import MaskedSmoothL1Loss, SmoothL1Loss
__all__ = [
'MaskedBalancedBCELoss', 'MaskedDiceLoss', 'MaskedSmoothL1Loss',
'MaskedSquareDiceLoss', 'MaskedBCELoss', 'SmoothL1Loss', 'CrossEntropyLoss'
'MaskedBalancedBCEWithLogitsLoss', 'MaskedDiceLoss', 'MaskedSmoothL1Loss',
'MaskedSquareDiceLoss', 'MaskedBCEWithLogitsLoss', 'SmoothL1Loss',
'CrossEntropyLoss', 'MaskedBalancedBCELoss', 'MaskedBCELoss'
]

View File

@ -8,8 +8,9 @@ from mmocr.registry import MODELS
@MODELS.register_module()
class MaskedBalancedBCELoss(nn.Module):
"""Masked Balanced BCE loss.
class MaskedBalancedBCEWithLogitsLoss(nn.Module):
"""This loss combines a Sigmoid layers and a masked balanced BCE loss in
one single class. It's AMP-eligible.
Args:
reduction (str, optional): The method to reduce the loss.
@ -37,7 +38,7 @@ class MaskedBalancedBCELoss(nn.Module):
self.negative_ratio = negative_ratio
self.reduction = reduction
self.fallback_negative_num = fallback_negative_num
self.binary_cross_entropy = nn.BCELoss(reduction=reduction)
self.loss = nn.BCEWithLogitsLoss(reduction=reduction)
def forward(self,
pred: torch.Tensor,
@ -74,8 +75,7 @@ class MaskedBalancedBCELoss(nn.Module):
int(negative.sum()), int(positive_count * self.negative_ratio))
assert gt.max() <= 1 and gt.min() >= 0
assert pred.max() <= 1 and pred.min() >= 0
loss = self.binary_cross_entropy(pred, gt)
loss = self.loss(pred, gt)
positive_loss = loss * positive
negative_loss = loss * negative
@ -88,11 +88,67 @@ class MaskedBalancedBCELoss(nn.Module):
@MODELS.register_module()
class MaskedBCELoss(nn.Module):
"""Masked BCE loss.
class MaskedBalancedBCELoss(MaskedBalancedBCEWithLogitsLoss):
"""Masked Balanced BCE loss.
Args:
eps (float, optional): Eps to avoid zero-division error. Defaults to
reduction (str, optional): The method to reduce the loss.
Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
negative_ratio (float or int): Maximum ratio of negative
samples to positive ones. Defaults to 3.
fallback_negative_num (int): When the mask contains no
positive samples, the number of negative samples to be sampled.
Defaults to 0.
eps (float): Eps to avoid zero-division error. Defaults to
1e-6.
"""
def __init__(self,
             reduction: str = 'none',
             negative_ratio: Union[float, int] = 3,
             fallback_negative_num: int = 0,
             eps: float = 1e-6) -> None:
    """Configure the probability-input variant of the balanced BCE loss.

    Args:
        reduction (str): How the element-wise loss is reduced. One of
            'none', 'mean' and 'sum'. Defaults to 'none'.
        negative_ratio (float or int): Maximum ratio of negative samples
            to positive ones. Defaults to 3.
        fallback_negative_num (int): Number of negative samples to use
            when the mask contains no positive sample. Defaults to 0.
        eps (float): Eps to avoid zero-division error. Defaults to 1e-6.
    """
    # The parent constructor is run with its own defaults; the attributes
    # assigned below replace the values it stored.
    super().__init__()

    # Reject invalid arguments before storing anything.
    assert reduction in ('none', 'mean', 'sum')
    assert isinstance(negative_ratio, (float, int))
    assert isinstance(fallback_negative_num, int)
    assert isinstance(eps, float)

    self.reduction = reduction
    self.negative_ratio = negative_ratio
    self.fallback_negative_num = fallback_negative_num
    self.eps = eps
    # Swap the logits-based criterion for the plain BCE criterion, since
    # this class is fed probabilities instead of logits.
    self.loss = nn.BCELoss(reduction=reduction)
def forward(self,
            pred: torch.Tensor,
            gt: torch.Tensor,
            mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Check that ``pred`` holds probabilities, then compute the loss.

    Args:
        pred (torch.Tensor): The prediction in any shape. Every value
            must lie in :math:`[0, 1]`.
        gt (torch.Tensor): The learning target of the prediction in the
            same shape as pred.
        mask (torch.Tensor, optional): Binary mask in the same shape of
            pred, indicating positive regions to calculate the loss. Whole
            region will be taken into account if not provided. Defaults to
            None.

    Returns:
        torch.Tensor: The loss value.
    """
    # This variant works on probabilities rather than logits, so anything
    # outside [0, 1] is rejected before delegating to the parent.
    assert pred.max() <= 1
    assert pred.min() >= 0
    loss_value = super().forward(pred, gt, mask)
    return loss_value
@MODELS.register_module()
class MaskedBCEWithLogitsLoss(nn.Module):
"""This loss combines a Sigmoid layers and a masked BCE loss in one single
class. It's AMP-eligible.
Args:
eps (float): Eps to avoid zero-division error. Defaults to
1e-6.
"""
@ -100,7 +156,7 @@ class MaskedBCELoss(nn.Module):
super().__init__()
assert isinstance(eps, float)
self.eps = eps
self.binary_cross_entropy = nn.BCELoss(reduction='none')
self.loss = nn.BCEWithLogitsLoss(reduction='none')
def forward(self,
pred: torch.Tensor,
@ -127,7 +183,45 @@ class MaskedBCELoss(nn.Module):
assert mask.size() == gt.size()
assert gt.max() <= 1 and gt.min() >= 0
assert pred.max() <= 1 and pred.min() >= 0
loss = self.binary_cross_entropy(pred, gt)
loss = self.loss(pred, gt)
return (loss * mask).sum() / (mask.sum() + self.eps)
@MODELS.register_module()
class MaskedBCELoss(MaskedBCEWithLogitsLoss):
"""Masked BCE loss.
Args:
eps (float): Eps to avoid zero-division error. Defaults to
1e-6.
"""
def __init__(self, eps: float = 1e-6) -> None:
    """Configure the probability-input variant of the masked BCE loss.

    Args:
        eps (float): Eps to avoid zero-division error. Defaults to 1e-6.
    """
    # The parent constructor is run with its default ``eps``; the value
    # given here is validated and stored afterwards.
    super().__init__()
    assert isinstance(eps, float)
    # Replace the logits-based criterion: this class is fed probabilities.
    self.loss = nn.BCELoss(reduction='none')
    self.eps = eps
def forward(self,
            pred: torch.Tensor,
            gt: torch.Tensor,
            mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Check that ``pred`` holds probabilities, then compute the loss.

    Args:
        pred (torch.Tensor): The prediction in any shape. Every value
            must lie in :math:`[0, 1]`.
        gt (torch.Tensor): The learning target of the prediction in the
            same shape as pred.
        mask (torch.Tensor, optional): Binary mask in the same shape of
            pred, indicating positive regions to calculate the loss. Whole
            region will be taken into account if not provided. Defaults to
            None.

    Returns:
        torch.Tensor: The loss value.
    """
    # Probabilities are expected here, so out-of-range inputs are rejected
    # before the parent implementation is invoked.
    assert pred.max() <= 1
    assert pred.min() >= 0
    masked_loss = super().forward(pred, gt, mask)
    return masked_loss

View File

@ -6,9 +6,7 @@ from mmengine.model import BaseModule
from torch import Tensor
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
SampleList = List[TextDetDataSample]
from mmocr.utils.typing import DetSampleList
@MODELS.register_module()
@ -65,7 +63,8 @@ class BaseTextDetHead(BaseModule):
self.module_loss = MODELS.build(module_loss)
self.postprocessor = MODELS.build(postprocessor)
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
def loss(self, x: Tuple[Tensor],
batch_data_samples: DetSampleList) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network.
@ -79,12 +78,13 @@ class BaseTextDetHead(BaseModule):
Returns:
dict: A dictionary of loss components.
"""
outs = self(x)
outs = self(x, batch_data_samples)
losses = self.module_loss(outs, batch_data_samples)
return losses
def loss_and_predict(self, x: Tuple[Tensor], batch_data_samples: SampleList
) -> Tuple[dict, SampleList]:
def loss_and_predict(self, x: Tuple[Tensor],
batch_data_samples: DetSampleList
) -> Tuple[dict, DetSampleList]:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples.
@ -101,14 +101,14 @@ class BaseTextDetHead(BaseModule):
- predictions (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
"""
outs = self(x)
outs = self(x, batch_data_samples)
losses = self.module_loss(outs, batch_data_samples)
predictions = self.postprocessor(outs, batch_data_samples)
return losses, predictions
def predict(self, x: torch.Tensor,
batch_data_samples: SampleList) -> SampleList:
batch_data_samples: DetSampleList) -> DetSampleList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.

View File

@ -9,6 +9,7 @@ from torch import Tensor
from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils.typing import DetSampleList
@MODELS.register_module()
@ -53,8 +54,9 @@ class DBHead(BaseTextDetHead):
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2))
self.threshold = self._init_thr(in_channels)
self.sigmoid = nn.Sigmoid()
def _diff_binarize(self, prob_map: Tensor, thr_map: Tensor,
k: int) -> Tensor:
@ -70,26 +72,6 @@ class DBHead(BaseTextDetHead):
"""
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
def forward(
self,
img: Tensor,
data_samples: Optional[List[TextDetDataSample]] = None
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Args:
img (Tensor): Shape :math:`(N, C, H, W)`.
data_samples (list[TextDetDataSample], optional): A list of data
samples. Defaults to None.
Returns:
tuple(Tensor, Tensor, Tensor): A tuple of ``prob_map``, ``thr_map``
and ``binary_map``, each of shape :math:`(N, 4H, 4W)`.
"""
prob_map = self.binarize(img).squeeze(1)
thr_map = self.threshold(img).squeeze(1)
binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
return (prob_map, thr_map, binary_map)
def _init_thr(self,
inner_channels: int,
bias: bool = False) -> nn.ModuleList:
@ -103,3 +85,100 @@ class DBHead(BaseTextDetHead):
nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
return seq
def forward(self,
            img: Tensor,
            data_samples: Optional[List[TextDetDataSample]] = None,
            mode: str = 'predict') -> Tuple[Tensor, Tensor, Tensor]:
    """Run the head and return the maps required by ``mode``.

    Args:
        img (Tensor): Shape :math:`(N, C, H, W)`.
        data_samples (list[TextDetDataSample], optional): A list of data
            samples. Unused by this method. Defaults to None.
        mode (str): Forward mode. It affects the return values. Options are
            "loss", "predict" and "both". Defaults to "predict".

            - ``loss``: Run the full network and return the prob
              logits, threshold map and binary map.
            - ``predict``: Run the binarization part and return the prob
              map only.
            - ``both``: Run the full network and return prob logits,
              threshold map, binary map and prob map.

    Returns:
        Tensor or tuple(Tensor): Its type depends on ``mode``: a single
        tensor for ``predict``, a 3-tuple for ``loss`` and a 4-tuple for
        ``both``. Each tensor has the shape of :math:`(N, 4H, 4W)`.
        Note that the signature's return annotation only describes the
        ``loss`` case.

    Raises:
        ValueError: If ``mode`` is not one of the supported options.
    """
    # Fail fast on a misspelled mode instead of silently behaving like
    # ``both`` after running the whole network.
    if mode not in ('loss', 'predict', 'both'):
        raise ValueError(
            f'mode should be one of "loss", "predict" and "both", '
            f'but got "{mode}"')

    prob_logits = self.binarize(img).squeeze(1)
    prob_map = self.sigmoid(prob_logits)
    if mode == 'predict':
        # The threshold branch is not needed for inference.
        return prob_map

    thr_map = self.threshold(img).squeeze(1)
    # Both inputs already had their channel axis squeezed, so the result
    # needs no further squeeze.
    binary_map = self._diff_binarize(prob_map, thr_map, k=50)
    if mode == 'loss':
        return prob_logits, thr_map, binary_map
    return prob_logits, thr_map, binary_map, prob_map
def loss(self, x: Tuple[Tensor],
         batch_data_samples: DetSampleList) -> Dict:
    """Run the head in ``loss`` mode and compute the loss components.

    Args:
        x (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.
        batch_data_samples (DetSampleList): The data samples, passed on
            unchanged to the head and to ``module_loss``.

    Returns:
        dict: A dictionary of loss components.
    """
    # ``loss`` mode yields the prob logits, threshold map and binary map,
    # which are handed straight to the loss module.
    return self.module_loss(
        self(x, batch_data_samples, mode='loss'), batch_data_samples)
def loss_and_predict(self, x: Tuple[Tensor],
                     batch_data_samples: DetSampleList
                     ) -> Tuple[dict, DetSampleList]:
    """Run the head once in ``both`` mode, then derive losses and
    predictions from that single forward pass.

    Args:
        x (tuple[Tensor]): Features from FPN.
        batch_data_samples (DetSampleList): Each item contains the meta
            information of each image and corresponding annotations.

    Returns:
        tuple: the return value is a tuple contains:

            - losses: (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection
              results of each image after the post process.
    """
    head_outputs = self(x, batch_data_samples, mode='both')
    # The first three outputs (prob logits, threshold map, binary map)
    # feed the loss; the fourth (prob map) feeds the postprocessor.
    losses = self.module_loss(head_outputs[:3], batch_data_samples)
    return losses, self.postprocessor(head_outputs[3], batch_data_samples)
def predict(self, x: torch.Tensor,
            batch_data_samples: DetSampleList) -> DetSampleList:
    """Run the head in ``predict`` mode and post-process its output.

    Args:
        x (tuple[Tensor]): Multi-level features from the
            upstream network, each is a 4D-tensor.
        batch_data_samples (DetSampleList): The data samples, passed on
            unchanged to the head and to the postprocessor.

    Returns:
        DetSampleList: Detection results of each image
        after the post process.
    """
    # Only the prob map is produced in ``predict`` mode.
    return self.postprocessor(
        self(x, batch_data_samples, mode='predict'), batch_data_samples)

View File

@ -22,22 +22,26 @@ class DBModuleLoss(nn.Module, TextKernelMixin):
This is partially adapted from https://github.com/MhLiao/DB.
Args:
loss_prob (dict): The loss config for probability map.
loss_thr (dict): The loss config for threshold map.
loss_db (dict): The loss config for binary map.
loss_prob (dict): The loss config for probability map. Defaults to
dict(type='MaskedBalancedBCEWithLogitsLoss').
loss_thr (dict): The loss config for threshold map. Defaults to
dict(type='MaskedSmoothL1Loss', beta=0).
loss_db (dict): The loss config for binary map. Defaults to
dict(type='MaskedDiceLoss').
weight_prob (float): The weight of probability map loss.
Denoted as :math:`\alpha` in paper.
Denoted as :math:`\alpha` in paper. Defaults to 5.
weight_thr (float): The weight of threshold map loss.
Denoted as :math:`\beta` in paper.
shrink_ratio (float): The ratio of shrunk text region.
thr_min (float): The minimum threshold map value.
thr_max (float): The maximum threshold map value.
Denoted as :math:`\beta` in paper. Defaults to 10.
shrink_ratio (float): The ratio of shrunk text region. Defaults to 0.4.
thr_min (float): The minimum threshold map value. Defaults to 0.3.
thr_max (float): The maximum threshold map value. Defaults to 0.7.
min_sidelength (int or float): The minimum sidelength of the
minimum rotated rectangle around any text region.
minimum rotated rectangle around any text region. Defaults to 8.
"""
def __init__(self,
loss_prob: Dict = dict(type='MaskedBalancedBCELoss'),
loss_prob: Dict = dict(
type='MaskedBalancedBCEWithLogitsLoss'),
loss_thr: Dict = dict(type='MaskedSmoothL1Loss', beta=0),
loss_db: Dict = dict(type='MaskedDiceLoss'),
weight_prob: float = 5.,
@ -63,22 +67,22 @@ class DBModuleLoss(nn.Module, TextKernelMixin):
Args:
preds (tuple(tensor)): Raw predictions from model, containing
``prob_map``, ``thr_map`` and ``binary_map``. Each is a tensor
of shape :math:`(N, H, W)`.
``prob_logits``, ``thr_map`` and ``binary_map``.
Each is a tensor of shape :math:`(N, H, W)`.
data_samples (list[TextDetDataSample]): The data samples.
Returns:
results(dict): The dict for dbnet losses with loss_prob, \
loss_db and loss_thr.
"""
prob_map, thr_map, binary_map = preds
prob_logits, thr_map, binary_map = preds
gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = self.get_targets(
data_samples)
gt_shrinks = gt_shrinks.to(prob_map.device)
gt_shrink_masks = gt_shrink_masks.to(prob_map.device)
gt_shrinks = gt_shrinks.to(prob_logits.device)
gt_shrink_masks = gt_shrink_masks.to(prob_logits.device)
gt_thrs = gt_thrs.to(thr_map.device)
gt_thr_masks = gt_thr_masks.to(thr_map.device)
loss_prob = self.loss_prob(prob_map, gt_shrinks, gt_shrink_masks)
loss_prob = self.loss_prob(prob_logits, gt_shrinks, gt_shrink_masks)
loss_thr = self.loss_thr(thr_map, gt_thrs, gt_thr_masks)
loss_db = self.loss_db(binary_map, gt_shrinks, gt_shrink_masks)

View File

@ -56,10 +56,10 @@ class DRRGModuleLoss(TextSnakeModuleLoss):
jitter_level (float): The jitter level of text component geometric
features. Defaults to 0.2.
loss_text (dict): The loss config used to calculate the text loss.
Defaults to ``dict(
type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)``.
Defaults to ``dict(type='MaskedBalancedBCEWithLogitsLoss',
fallback_negative_num=100, eps=1e-5)``.
loss_center (dict): The loss config used to calculate the center loss.
Defaults to ``dict(type='MaskedBCELoss')``.
Defaults to ``dict(type='MaskedBCEWithLogitsLoss')``.
loss_top (dict): The loss config used to calculate the top loss, which
is a part of the height loss. Defaults to
``dict(type='SmoothL1Loss', reduction='none')``.
@ -92,8 +92,10 @@ class DRRGModuleLoss(TextSnakeModuleLoss):
max_rand_half_height: float = 24.0,
jitter_level: float = 0.2,
loss_text: Dict = dict(
type='MaskedBalancedBCELoss', fallback_negative_num=100, eps=1e-5),
loss_center: Dict = dict(type='MaskedBCELoss'),
type='MaskedBalancedBCEWithLogitsLoss',
fallback_negative_num=100,
eps=1e-5),
loss_center: Dict = dict(type='MaskedBCEWithLogitsLoss'),
loss_top: Dict = dict(type='SmoothL1Loss', reduction='none'),
loss_btm: Dict = dict(type='SmoothL1Loss', reduction='none'),
loss_sin: Dict = dict(type='MaskedSmoothL1Loss'),
@ -185,16 +187,16 @@ class DRRGModuleLoss(TextSnakeModuleLoss):
pred_sin_map = pred_sin_map * scale
pred_cos_map = pred_cos_map * scale
loss_text = self.loss_text(pred_text_region.sigmoid(),
gt['gt_text_masks'], gt['gt_masks'])
loss_text = self.loss_text(pred_text_region, gt['gt_text_masks'],
gt['gt_masks'])
text_mask = (gt['gt_text_masks'] * gt['gt_masks']).float()
negative_text_mask = ((1 - gt['gt_text_masks']) *
gt['gt_masks']).float()
loss_center_positive = self.loss_center(pred_center_region.sigmoid(),
loss_center_positive = self.loss_center(pred_center_region,
gt['gt_center_region_masks'],
text_mask)
loss_center_negative = self.loss_center(pred_center_region.sigmoid(),
loss_center_negative = self.loss_center(pred_center_region,
gt['gt_center_region_masks'],
negative_text_mask)
loss_center = loss_center_positive + 0.5 * loss_center_negative

View File

@ -47,8 +47,10 @@ class TextSnakeModuleLoss(nn.Module, TextKernelMixin):
resample_step: float = 4.0,
center_region_shrink_ratio: float = 0.3,
loss_text: Dict = dict(
type='MaskedBalancedBCELoss', fallback_negative_num=100, eps=1e-5),
loss_center: Dict = dict(type='MaskedBCELoss'),
type='MaskedBalancedBCEWithLogitsLoss',
fallback_negative_num=100,
eps=1e-5),
loss_center: Dict = dict(type='MaskedBCEWithLogitsLoss'),
loss_radius: Dict = dict(type='MaskedSmoothL1Loss'),
loss_sin: Dict = dict(type='MaskedSmoothL1Loss'),
loss_cos: Dict = dict(type='MaskedSmoothL1Loss')
@ -146,11 +148,11 @@ class TextSnakeModuleLoss(nn.Module, TextKernelMixin):
pred_sin_map = pred_sin_map * scale
pred_cos_map = pred_cos_map * scale
loss_text = self.loss_text(pred_text_region.sigmoid(),
gt['gt_text_masks'], gt['gt_masks'])
loss_text = self.loss_text(pred_text_region, gt['gt_text_masks'],
gt['gt_masks'])
text_mask = (gt['gt_text_masks'] * gt['gt_masks']).float()
loss_center = self.loss_center(pred_center_region.sigmoid(),
loss_center = self.loss_center(pred_center_region,
gt['gt_center_region_masks'], text_mask)
center_mask = (gt['gt_center_region_masks'] * gt['gt_masks']).float()

View File

@ -39,15 +39,15 @@ class BaseTextDetPostProcessor:
self.test_cfg = test_cfg
def __call__(self,
pred_results: dict,
pred_results: Union[Tensor, List[Tensor]],
data_samples: Sequence[TextDetDataSample],
training: bool = False) -> Sequence[TextDetDataSample]:
"""Postprocess pred_results according to metainfos in data_samples.
Args:
pred_results (dict): The prediction results stored in a dictionary.
Usually each item to be post-processed is expected to be a
batched tensor.
pred_results (Union[Tensor, List[Tensor]]): The prediction results
stored in a tensor or a list of tensor. Usually each item to
be post-processed is expected to be a batched tensor.
data_samples (list[TextDetDataSample]): Batch of data_samples,
each corresponding to a prediction result.
training (bool): Whether the model is in training mode. Defaults to
@ -65,13 +65,14 @@ class BaseTextDetPostProcessor:
return results
def _process_single(self, pred_result: dict,
def _process_single(self, pred_result: Union[Tensor, List[Tensor]],
data_sample: TextDetDataSample,
**kwargs) -> TextDetDataSample:
"""Process prediction results from one image.
Args:
pred_result (dict): Prediction results of an image.
pred_result (Union[Tensor, List[Tensor]]): Prediction results of an
image.
data_sample (TextDetDataSample): Datasample of an image.
"""
@ -103,13 +104,13 @@ class BaseTextDetPostProcessor:
results.pred_instances[key], scale_factor, mode='div')
return results
def get_text_instances(self, pred_results: dict,
def get_text_instances(self, pred_results: Union[Tensor, List[Tensor]],
data_sample: TextDetDataSample,
**kwargs) -> TextDetDataSample:
"""Get text instance predictions of one image.
Args:
pred_result (dict): Prediction results of an image.
pred_results (Union[Tensor, List[Tensor]]): Prediction results of an
image.
data_sample (TextDetDataSample): Datasample of an image.
**kwargs: Other parameters. Configurable via ``__init__.train_cfg``
and ``__init__.test_cfg``.

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple
from typing import Sequence
import cv2
import numpy as np
@ -59,14 +59,14 @@ class DBPostprocessor(BaseTextDetPostProcessor):
self.epsilon_ratio = epsilon_ratio
self.max_candidates = max_candidates
def get_text_instances(self, pred_results: Tuple[Tensor, Tensor, Tensor],
def get_text_instances(self, prob_map: Tensor,
data_sample: TextDetDataSample
) -> TextDetDataSample:
"""Get text instance predictions of one image.
Args:
pred_result (tuple(Tensor)): A tuple of 3 tensors where the first
tensor is ``prob_map`` of shape :math:`(N, H, W)`.
prob_map (Tensor): DBNet's output ``prob_map`` of shape
:math:`(H, W)`.
data_sample (TextDetDataSample): Datasample of an image.
Returns:
@ -80,7 +80,6 @@ class DBPostprocessor(BaseTextDetPostProcessor):
data_sample.pred_instances.polygons = []
data_sample.pred_instances.scores = []
prob_map = pred_results[0]
text_mask = prob_map > self.mask_thr
score_map = prob_map.data.cpu().numpy().astype(np.float32)

View File

@ -3,7 +3,9 @@ from unittest import TestCase
import torch
from mmocr.models.common.losses import MaskedBalancedBCELoss, MaskedBCELoss
from mmocr.models.common.losses import (MaskedBalancedBCELoss,
MaskedBalancedBCEWithLogitsLoss,
MaskedBCELoss, MaskedBCEWithLogitsLoss)
class TestMaskedBalancedBCELoss(TestCase):
@ -117,3 +119,117 @@ class TestMaskedBCELoss(TestCase):
zero_mask = torch.FloatTensor([0, 0, 0, 0])
self.assertAlmostEqual(
self.bce_loss(self.pred, self.gt, zero_mask).item(), 0)
# Unit tests for MaskedBalancedBCEWithLogitsLoss, which takes raw logits as
# ``pred``.
class TestMaskedBalancedWithLogitsBCELoss(TestCase):
def setUp(self) -> None:
# Four identical logits (1.5), three positive targets and one negative
# target; the mask keeps only the first and the last element.
self.loss = MaskedBalancedBCEWithLogitsLoss(negative_ratio=2)
self.pred = torch.FloatTensor([1.5, 1.5, 1.5, 1.5])
self.gt = torch.FloatTensor([1, 1, 1, 0])
self.mask = torch.BoolTensor([True, False, False, True])
def test_init(self):
# Every constructor argument is expected to be validated by an assert.
with self.assertRaises(AssertionError):
MaskedBalancedBCEWithLogitsLoss(reduction='any')
with self.assertRaises(AssertionError):
MaskedBalancedBCEWithLogitsLoss(negative_ratio='a')
with self.assertRaises(AssertionError):
MaskedBalancedBCEWithLogitsLoss(eps='a')
with self.assertRaises(AssertionError):
MaskedBalancedBCEWithLogitsLoss(fallback_negative_num='a')
def test_forward(self):
# Shape mismatch between pred and gt
with self.assertRaises(AssertionError):
invalid_gt = torch.FloatTensor([0, 0, 0])
self.loss(self.pred, invalid_gt)
# Shape mismatch between pred and mask
with self.assertRaises(AssertionError):
invalid_mask = torch.BoolTensor([True, False, False])
self.loss(self.pred, self.gt, invalid_mask)
# Invalid gt
with self.assertRaises(AssertionError):
invalid_gt = torch.FloatTensor([2, 3, 4, 5])
self.loss(self.pred, invalid_gt, self.mask)
# Reference values are built by hand from sigmoid(1.5).
logit = torch.FloatTensor([1.5])
# No mask: expected value averages three positive terms and one negative
# term over the four elements.
self.assertAlmostEqual(
self.loss(self.pred, self.gt).item(),
((-torch.log(torch.sigmoid(logit)) * 3 -
torch.log(1 - torch.sigmoid(logit))) / 4).item(),
delta=0.0001)
# With the mask: one positive and one negative element remain, so the
# expected value averages those two terms.
self.assertAlmostEqual(
self.loss(self.pred, self.gt, self.mask).item(),
(-torch.log(torch.sigmoid(logit)) -
torch.log(1 - torch.sigmoid(logit))).item() / 2,
delta=0.0001)
# Test zero mask
zero_mask = torch.FloatTensor([0, 0, 0, 0])
self.assertAlmostEqual(
self.loss(self.pred, self.gt, zero_mask).item(), 0)
# Test 0 < fallback_negative_num < negative numbers
all_neg_gt = torch.zeros((4, ))
self.fallback_bce_loss = MaskedBalancedBCEWithLogitsLoss(
fallback_negative_num=1)
# With no positive target the expected value is the loss of a negative
# element.
self.assertAlmostEqual(
self.fallback_bce_loss(self.pred, all_neg_gt, self.mask).item(),
-torch.log(1 - torch.sigmoid(logit)).item(),
delta=0.001)
# Test fallback_negative_num > negative numbers
self.fallback_bce_loss = MaskedBalancedBCEWithLogitsLoss(
fallback_negative_num=5)
# All logits are identical, so the expected value is again the loss of a
# single negative element.
self.assertAlmostEqual(
self.fallback_bce_loss(self.pred, all_neg_gt, self.mask).item(),
-torch.log(1 - torch.sigmoid(logit)).item(),
delta=0.001)
# Unit tests for MaskedBCEWithLogitsLoss, which takes raw logits as ``pred``.
class TestMaskedBCEWithLogitsLoss(TestCase):
def setUp(self) -> None:
# Four identical logits (1.5), three positive targets and one negative
# target; the mask keeps only the first and the last element.
self.loss = MaskedBCEWithLogitsLoss()
self.pred = torch.FloatTensor([1.5, 1.5, 1.5, 1.5])
self.gt = torch.FloatTensor([1, 1, 1, 0])
self.mask = torch.BoolTensor([True, False, False, True])
def test_init(self):
# ``eps`` must be a float.
with self.assertRaises(AssertionError):
MaskedBCEWithLogitsLoss(eps='a')
def test_forward(self):
# Shape mismatch between pred and gt
with self.assertRaises(AssertionError):
invalid_gt = torch.FloatTensor([0, 0, 0])
self.loss(self.pred, invalid_gt)
# Shape mismatch between pred and mask
with self.assertRaises(AssertionError):
invalid_mask = torch.BoolTensor([True, False, False])
self.loss(self.pred, self.gt, invalid_mask)
# Reference values are built by hand from sigmoid(1.5).
logit = torch.FloatTensor([1.5])
# No mask: expected value averages three positive terms and one negative
# term over the four elements.
self.assertAlmostEqual(
self.loss(self.pred, self.gt).item(),
((-torch.log(torch.sigmoid(logit)) * 3 -
torch.log(1 - torch.sigmoid(logit))) / 4).item(),
delta=0.0001)
# With the mask: one positive and one negative element remain, so the
# expected value averages those two terms.
self.assertAlmostEqual(
self.loss(self.pred, self.gt, self.mask).item(),
(-torch.log(torch.sigmoid(logit)) -
torch.log(1 - torch.sigmoid(logit))).item() / 2,
delta=0.0001)
# Test zero mask
zero_mask = torch.FloatTensor([0, 0, 0, 0])
self.assertAlmostEqual(
self.loss(self.pred, self.gt, zero_mask).item(), 0)

View File

@ -4,10 +4,24 @@ from unittest import TestCase
import torch
from mmocr.models.textdet.heads import DBHead
from mmocr.registry import MODELS
class TestDBHead(TestCase):
# Use to replace module loss and postprocessors
@MODELS.register_module(name='DBDummy')
class DummyModule:
def __call__(self, x, data_samples):
return x
def setUp(self) -> None:
self.db_head = DBHead(
in_channels=10,
module_loss=dict(type='DBDummy'),
postprocessor=dict(type='DBDummy'))
def test_init(self):
with self.assertRaises(AssertionError):
DBHead(in_channels='test', with_bias=False)
@ -16,9 +30,28 @@ class TestDBHead(TestCase):
DBHead(in_channels=1, with_bias='Text')
def test_forward(self):
    """Check the structure and shape of the output of every forward mode."""
    data = torch.randn((2, 10, 40, 50))

    # The default mode is ``predict``, which returns a single prob map
    # rather than a tuple, so it is checked as one batched tensor.
    results = self.db_head(data, None)
    self.assertEqual(results.shape, (2, 160, 200))

    # ``loss`` mode returns prob logits, threshold map and binary map.
    results = self.db_head(data, None, mode='loss')
    self.assertEqual(len(results), 3)
    for result in results:
        self.assertEqual(result.shape, (2, 160, 200))

    # ``both`` mode additionally returns the prob map as a fourth item.
    results = self.db_head(data, None, mode='both')
    self.assertEqual(len(results), 4)
    for result in results:
        self.assertEqual(result.shape, (2, 160, 200))
def test_loss(self):
# The head is built with the DBDummy loss module, which returns its input
# unchanged, so ``results`` holds the three maps produced in ``loss`` mode.
data = torch.randn((2, 10, 40, 50))
results = self.db_head.loss(data, None)
for i in range(3):
self.assertEqual(results[i].shape, (2, 160, 200))
def test_predict(self):
# The DBDummy postprocessor returns its input unchanged, so ``results`` is
# the single prob map produced in ``predict`` mode.
data = torch.randn((2, 10, 40, 50))
results = self.db_head.predict(data, None)
self.assertEqual(results.shape, (2, 160, 200))
def test_loss_and_predict(self):
# Both DBDummy modules return their input unchanged: ``loss_results`` is the
# first three head outputs and ``pred_results`` is the prob map.
data = torch.randn((2, 10, 40, 50))
loss_results, pred_results = self.db_head.loss_and_predict(data, None)
for i in range(3):
self.assertEqual(loss_results[i].shape, (2, 160, 200))
self.assertEqual(pred_results.shape, (2, 160, 200))
# The prob map must be the sigmoid of the prob logits from the same pass.
self.assertTrue(
torch.allclose(pred_results, loss_results[0].sigmoid()))

View File

@ -23,7 +23,7 @@ class TestDBPostProcessor(unittest.TestCase):
def test_get_text_instances(self, text_repr_type):
postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
pred_result = (torch.rand(4, 5), torch.rand(4, 5), torch.rand(4, 5))
pred_result = torch.rand(4, 5)
data_sample = TextDetDataSample(
metainfo=dict(scale_factor=(0.5, 1)),
gt_instances=InstanceData(polygons=[
@ -36,12 +36,11 @@ class TestDBPostProcessor(unittest.TestCase):
self.assertTrue(
isinstance(results.pred_instances['scores'], torch.FloatTensor))
preds = (torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0]]),
torch.rand([1, 10]), torch.rand([1, 10]))
preds = torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0]])
postprocessor = DBPostprocessor(
text_repr_type=text_repr_type, min_text_width=0)
results = postprocessor.get_text_instances(preds, data_sample)
@ -49,8 +48,7 @@ class TestDBPostProcessor(unittest.TestCase):
postprocessor = DBPostprocessor(
min_text_score=1, text_repr_type=text_repr_type)
pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8,
torch.rand(4, 5) * 0.8)
pred_result = torch.rand(4, 5) * 0.8
results = postprocessor.get_text_instances(pred_result, data_sample)
self.assertEqual(results.pred_instances.polygons, [])
self.assertTrue(