mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Use BCEWithLogitsLoss as many as possible for AMP (#1309)
* Use BCEWithLogitsLoss as many as possible for AMP
* fix
* Optimize DBNet
* fix docstr
* Use branch in dbhead, fix missing data_samples in textdethead
* fix

pull/1315/head
parent 5c8c774aa9
commit 1860a3a3b6
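
Background for the change: ``nn.BCELoss`` and ``F.binary_cross_entropy`` operate on probabilities and are rejected inside ``torch.autocast`` regions because they are numerically unsafe in reduced precision, whereas ``nn.BCEWithLogitsLoss`` fuses the sigmoid into the loss with a log-sum-exp formulation that autocast always evaluates in float32. A minimal sketch of the difference (not part of this commit; assumes a CUDA device):

import torch
import torch.nn.functional as F

logits = torch.randn(8, device='cuda', requires_grad=True)
target = torch.randint(0, 2, (8, ), device='cuda').float()

with torch.autocast('cuda'):
    # AMP-safe: fused sigmoid + BCE, evaluated in float32 under autocast.
    loss = F.binary_cross_entropy_with_logits(logits, target)
    # AMP-unsafe: PyTorch raises a RuntimeError here and points to the
    # with-logits variant instead.
    # loss = F.binary_cross_entropy(torch.sigmoid(logits), target)
loss.backward()
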
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .bce_loss import MaskedBalancedBCELoss, MaskedBCELoss
+from .bce_loss import (MaskedBalancedBCELoss, MaskedBalancedBCEWithLogitsLoss,
+                       MaskedBCELoss, MaskedBCEWithLogitsLoss)
 from .ce_loss import CrossEntropyLoss
 from .dice_loss import MaskedDiceLoss, MaskedSquareDiceLoss
 from .l1_loss import MaskedSmoothL1Loss, SmoothL1Loss

 __all__ = [
-    'MaskedBalancedBCELoss', 'MaskedDiceLoss', 'MaskedSmoothL1Loss',
-    'MaskedSquareDiceLoss', 'MaskedBCELoss', 'SmoothL1Loss', 'CrossEntropyLoss'
+    'MaskedBalancedBCEWithLogitsLoss', 'MaskedDiceLoss', 'MaskedSmoothL1Loss',
+    'MaskedSquareDiceLoss', 'MaskedBCEWithLogitsLoss', 'SmoothL1Loss',
+    'CrossEntropyLoss', 'MaskedBalancedBCELoss', 'MaskedBCELoss'
 ]
@@ -8,8 +8,9 @@ from mmocr.registry import MODELS


 @MODELS.register_module()
-class MaskedBalancedBCELoss(nn.Module):
-    """Masked Balanced BCE loss.
+class MaskedBalancedBCEWithLogitsLoss(nn.Module):
+    """This loss combines a Sigmoid layer and a masked balanced BCE loss in
+    one single class. It's AMP-eligible.

     Args:
         reduction (str, optional): The method to reduce the loss.
@@ -37,7 +38,7 @@ class MaskedBalancedBCELoss(nn.Module):
         self.negative_ratio = negative_ratio
         self.reduction = reduction
         self.fallback_negative_num = fallback_negative_num
-        self.binary_cross_entropy = nn.BCELoss(reduction=reduction)
+        self.loss = nn.BCEWithLogitsLoss(reduction=reduction)

     def forward(self,
                 pred: torch.Tensor,
@@ -74,8 +75,7 @@ class MaskedBalancedBCELoss(nn.Module):
             int(negative.sum()), int(positive_count * self.negative_ratio))

         assert gt.max() <= 1 and gt.min() >= 0
-        assert pred.max() <= 1 and pred.min() >= 0
-        loss = self.binary_cross_entropy(pred, gt)
+        loss = self.loss(pred, gt)
         positive_loss = loss * positive
         negative_loss = loss * negative

@@ -88,11 +88,67 @@ class MaskedBalancedBCELoss(nn.Module):


 @MODELS.register_module()
-class MaskedBCELoss(nn.Module):
-    """Masked BCE loss.
+class MaskedBalancedBCELoss(MaskedBalancedBCEWithLogitsLoss):
+    """Masked Balanced BCE loss.

     Args:
-        eps (float, optional): Eps to avoid zero-division error. Defaults to
+        reduction (str, optional): The method to reduce the loss.
+            Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
+        negative_ratio (float or int): Maximum ratio of negative
+            samples to positive ones. Defaults to 3.
+        fallback_negative_num (int): When the mask contains no
+            positive samples, the number of negative samples to be sampled.
+            Defaults to 0.
+        eps (float): Eps to avoid zero-division error. Defaults to
             1e-6.
     """

+    def __init__(self,
+                 reduction: str = 'none',
+                 negative_ratio: Union[float, int] = 3,
+                 fallback_negative_num: int = 0,
+                 eps: float = 1e-6) -> None:
+        super().__init__()
+        assert reduction in ['none', 'mean', 'sum']
+        assert isinstance(negative_ratio, (float, int))
+        assert isinstance(fallback_negative_num, int)
+        assert isinstance(eps, float)
+        self.eps = eps
+        self.negative_ratio = negative_ratio
+        self.reduction = reduction
+        self.fallback_negative_num = fallback_negative_num
+        self.loss = nn.BCELoss(reduction=reduction)
+
+    def forward(self,
+                pred: torch.Tensor,
+                gt: torch.Tensor,
+                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward function.
+
+        Args:
+            pred (torch.Tensor): The prediction in any shape.
+            gt (torch.Tensor): The learning target of the prediction in the
+                same shape as pred.
+            mask (torch.Tensor, optional): Binary mask in the same shape of
+                pred, indicating positive regions to calculate the loss. Whole
+                region will be taken into account if not provided. Defaults to
+                None.
+
+        Returns:
+            torch.Tensor: The loss value.
+        """
+
+        assert pred.max() <= 1 and pred.min() >= 0
+        return super().forward(pred, gt, mask)
+
+
+@MODELS.register_module()
+class MaskedBCEWithLogitsLoss(nn.Module):
+    """This loss combines a Sigmoid layer and a masked BCE loss in one single
+    class. It's AMP-eligible.
+
+    Args:
+        eps (float): Eps to avoid zero-division error. Defaults to
+            1e-6.
+    """
+
@@ -100,7 +156,7 @@ class MaskedBCELoss(nn.Module):
         super().__init__()
         assert isinstance(eps, float)
         self.eps = eps
-        self.binary_cross_entropy = nn.BCELoss(reduction='none')
+        self.loss = nn.BCEWithLogitsLoss(reduction='none')

     def forward(self,
                 pred: torch.Tensor,
@@ -127,7 +183,45 @@ class MaskedBCELoss(nn.Module):
             assert mask.size() == gt.size()

         assert gt.max() <= 1 and gt.min() >= 0
-        assert pred.max() <= 1 and pred.min() >= 0
-        loss = self.binary_cross_entropy(pred, gt)
+        loss = self.loss(pred, gt)

         return (loss * mask).sum() / (mask.sum() + self.eps)
+
+
+@MODELS.register_module()
+class MaskedBCELoss(MaskedBCEWithLogitsLoss):
+    """Masked BCE loss.
+
+    Args:
+        eps (float): Eps to avoid zero-division error. Defaults to
+            1e-6.
+    """
+
+    def __init__(self, eps: float = 1e-6) -> None:
+        super().__init__()
+        assert isinstance(eps, float)
+        self.eps = eps
+        self.loss = nn.BCELoss(reduction='none')
+
+    def forward(self,
+                pred: torch.Tensor,
+                gt: torch.Tensor,
+                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward function.
+
+        Args:
+            pred (torch.Tensor): The prediction in any shape.
+            gt (torch.Tensor): The learning target of the prediction in the
+                same shape as pred.
+            mask (torch.Tensor, optional): Binary mask in the same shape of
+                pred, indicating positive regions to calculate the loss. Whole
+                region will be taken into account if not provided. Defaults to
+                None.
+
+        Returns:
+            torch.Tensor: The loss value.
+        """
+
+        assert pred.max() <= 1 and pred.min() >= 0
+
+        return super().forward(pred, gt, mask)
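
The new classes consume raw logits, while the probability-based losses are kept as thin subclasses for backward compatibility. A rough usage sketch (assumes an mmocr install where these classes are importable as in the diff above); the two calls should agree up to floating-point error:

import torch

from mmocr.models.common.losses import (MaskedBalancedBCELoss,
                                        MaskedBalancedBCEWithLogitsLoss)

logits = torch.randn(2, 160, 200)                # raw head outputs
gt = torch.randint(0, 2, (2, 160, 200)).float()  # binary targets
mask = torch.ones_like(gt)                       # keep every pixel

loss_from_logits = MaskedBalancedBCEWithLogitsLoss()(logits, gt, mask)
loss_from_probs = MaskedBalancedBCELoss()(logits.sigmoid(), gt, mask)
assert torch.allclose(loss_from_logits, loss_from_probs, atol=1e-5)
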
@@ -6,9 +6,7 @@ from mmengine.model import BaseModule
 from torch import Tensor

 from mmocr.registry import MODELS
-from mmocr.structures import TextDetDataSample
-
-SampleList = List[TextDetDataSample]
+from mmocr.utils.typing import DetSampleList


 @MODELS.register_module()
@@ -65,7 +63,8 @@ class BaseTextDetHead(BaseModule):
             self.module_loss = MODELS.build(module_loss)
         self.postprocessor = MODELS.build(postprocessor)

-    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
+    def loss(self, x: Tuple[Tensor],
+             batch_data_samples: DetSampleList) -> dict:
         """Perform forward propagation and loss calculation of the detection
         head on the features of the upstream network.

@@ -79,12 +78,13 @@ class BaseTextDetHead(BaseModule):
         Returns:
             dict: A dictionary of loss components.
         """
-        outs = self(x)
+        outs = self(x, batch_data_samples)
         losses = self.module_loss(outs, batch_data_samples)
         return losses

-    def loss_and_predict(self, x: Tuple[Tensor], batch_data_samples: SampleList
-                         ) -> Tuple[dict, SampleList]:
+    def loss_and_predict(self, x: Tuple[Tensor],
+                         batch_data_samples: DetSampleList
+                         ) -> Tuple[dict, DetSampleList]:
         """Perform forward propagation of the head, then calculate loss and
         predictions from the features and data samples.

@@ -101,14 +101,14 @@ class BaseTextDetHead(BaseModule):
                 - predictions (list[:obj:`InstanceData`]): Detection
                   results of each image after the post process.
         """
-        outs = self(x)
+        outs = self(x, batch_data_samples)
         losses = self.module_loss(outs, batch_data_samples)

         predictions = self.postprocessor(outs, batch_data_samples)
         return losses, predictions

     def predict(self, x: torch.Tensor,
-                batch_data_samples: SampleList) -> SampleList:
+                batch_data_samples: DetSampleList) -> DetSampleList:
         """Perform forward propagation of the detection head and predict
         detection results on the features of the upstream network.

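
The head API change here is that ``loss()`` and ``loss_and_predict()`` now call ``self(x, batch_data_samples)`` instead of ``self(x)``, so a head's ``forward`` must accept the data samples as a second argument. A hypothetical minimal head (names invented for illustration, not mmocr code) showing the convention:

import torch
import torch.nn as nn

class TinyTextDetHead(nn.Module):
    """Toy head following the new calling convention."""

    def forward(self, x, data_samples=None):
        # data_samples carries per-image annotations and meta info; a head
        # that does not need it simply accepts and ignores the argument.
        return x.mean(dim=1)

head = TinyTextDetHead()
outs = head(torch.randn(2, 8, 4, 4), None)  # mirrors outs = self(x, batch_data_samples)
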
@@ -9,6 +9,7 @@ from torch import Tensor
 from mmocr.models.textdet.heads import BaseTextDetHead
 from mmocr.registry import MODELS
 from mmocr.structures import TextDetDataSample
+from mmocr.utils.typing import DetSampleList


 @MODELS.register_module()
@@ -53,8 +54,9 @@ class DBHead(BaseTextDetHead):
             nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
             nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
             nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
+            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2))
         self.threshold = self._init_thr(in_channels)
+        self.sigmoid = nn.Sigmoid()

     def _diff_binarize(self, prob_map: Tensor, thr_map: Tensor,
                        k: int) -> Tensor:
@@ -70,26 +72,6 @@ class DBHead(BaseTextDetHead):
         """
         return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))

-    def forward(
-        self,
-        img: Tensor,
-        data_samples: Optional[List[TextDetDataSample]] = None
-    ) -> Tuple[Tensor, Tensor, Tensor]:
-        """
-        Args:
-            img (Tensor): Shape :math:`(N, C, H, W)`.
-            data_samples (list[TextDetDataSample], optional): A list of data
-                samples. Defaults to None.
-
-        Returns:
-            tuple(Tensor, Tensor, Tensor): A tuple of ``prob_map``, ``thr_map``
-            and ``binary_map``, each of shape :math:`(N, 4H, 4W)`.
-        """
-        prob_map = self.binarize(img).squeeze(1)
-        thr_map = self.threshold(img).squeeze(1)
-        binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
-        return (prob_map, thr_map, binary_map)
-
     def _init_thr(self,
                   inner_channels: int,
                   bias: bool = False) -> nn.ModuleList:
@@ -103,3 +85,100 @@ class DBHead(BaseTextDetHead):
             nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
             nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
         return seq
+
+    def forward(self,
+                img: Tensor,
+                data_samples: Optional[List[TextDetDataSample]] = None,
+                mode: str = 'predict') -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        Args:
+            img (Tensor): Shape :math:`(N, C, H, W)`.
+            data_samples (list[TextDetDataSample], optional): A list of data
+                samples. Defaults to None.
+            mode (str): Forward mode. It affects the return values. Options are
+                "loss", "predict" and "both". Defaults to "predict".
+
+                - ``loss``: Run the full network and return the prob
+                  logits, threshold map and binary map.
+                - ``predict``: Run the binarization part and return the prob
+                  map only.
+                - ``both``: Run the full network and return prob logits,
+                  threshold map, binary map and prob map.
+
+        Returns:
+            Tensor or tuple(Tensor): Its type depends on ``mode``, read its
+            docstring for details. Each has the shape of
+            :math:`(N, 4H, 4W)`.
+        """
+        prob_logits = self.binarize(img).squeeze(1)
+        prob_map = self.sigmoid(prob_logits)
+        if mode == 'predict':
+            return prob_map
+        thr_map = self.threshold(img).squeeze(1)
+        binary_map = self._diff_binarize(prob_map, thr_map, k=50).squeeze(1)
+        if mode == 'loss':
+            return prob_logits, thr_map, binary_map
+        return prob_logits, thr_map, binary_map, prob_map
+
+    def loss(self, x: Tuple[Tensor],
+             batch_data_samples: DetSampleList) -> Dict:
+        """Perform forward propagation and loss calculation of the detection
+        head on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Features from the upstream network, each is
+                a 4D-tensor.
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        outs = self(x, batch_data_samples, mode='loss')
+        losses = self.module_loss(outs, batch_data_samples)
+        return losses
+
+    def loss_and_predict(self, x: Tuple[Tensor],
+                         batch_data_samples: DetSampleList
+                         ) -> Tuple[dict, DetSampleList]:
+        """Perform forward propagation of the head, then calculate loss and
+        predictions from the features and data samples.
+
+        Args:
+            x (tuple[Tensor]): Features from FPN.
+            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
+                the meta information of each image and corresponding
+                annotations.
+
+        Returns:
+            tuple: the return value is a tuple containing:
+
+                - losses: (dict[str, Tensor]): A dictionary of loss components.
+                - predictions (list[:obj:`InstanceData`]): Detection
+                  results of each image after the post process.
+        """
+        outs = self(x, batch_data_samples, mode='both')
+        losses = self.module_loss(outs[:3], batch_data_samples)
+        predictions = self.postprocessor(outs[3], batch_data_samples)
+        return losses, predictions
+
+    def predict(self, x: torch.Tensor,
+                batch_data_samples: DetSampleList) -> DetSampleList:
+        """Perform forward propagation of the detection head and predict
+        detection results on the features of the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Multi-level features from the
+                upstream network, each is a 4D-tensor.
+            batch_data_samples (List[:obj:`DetDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+
+        Returns:
+            SampleList: Detection results of each image
+            after the post process.
+        """
+        outs = self(x, batch_data_samples, mode='predict')
+        predictions = self.postprocessor(outs, batch_data_samples)
+        return predictions
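
With the ``nn.Sigmoid`` moved out of ``binarize``, the new ``mode`` switch decides whether ``DBHead`` returns raw logits (for the AMP-safe loss), probabilities (for the postprocessor), or both. A usage sketch, assuming ``DBHead`` can be built with its default module loss and postprocessor:

import torch

from mmocr.models.textdet.heads import DBHead

head = DBHead(in_channels=10)
feats = torch.randn(2, 10, 40, 50)

prob_map = head(feats, None, mode='predict')  # probability map only
prob_logits, thr_map, binary_map = head(feats, None, mode='loss')
outs = head(feats, None, mode='both')  # logits, thr map, binary map, prob map
# The probability map is exactly the sigmoid of the logits:
assert torch.allclose(outs[3], outs[0].sigmoid())
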
@@ -22,22 +22,26 @@ class DBModuleLoss(nn.Module, TextKernelMixin):
     This is partially adapted from https://github.com/MhLiao/DB.

     Args:
-        loss_prob (dict): The loss config for probability map.
-        loss_thr (dict): The loss config for threshold map.
-        loss_db (dict): The loss config for binary map.
+        loss_prob (dict): The loss config for probability map. Defaults to
+            dict(type='MaskedBalancedBCEWithLogitsLoss').
+        loss_thr (dict): The loss config for threshold map. Defaults to
+            dict(type='MaskedSmoothL1Loss', beta=0).
+        loss_db (dict): The loss config for binary map. Defaults to
+            dict(type='MaskedDiceLoss').
         weight_prob (float): The weight of probability map loss.
-            Denoted as :math:`\alpha` in paper.
+            Denoted as :math:`\alpha` in paper. Defaults to 5.
         weight_thr (float): The weight of threshold map loss.
-            Denoted as :math:`\beta` in paper.
-        shrink_ratio (float): The ratio of shrunk text region.
-        thr_min (float): The minimum threshold map value.
-        thr_max (float): The maximum threshold map value.
+            Denoted as :math:`\beta` in paper. Defaults to 10.
+        shrink_ratio (float): The ratio of shrunk text region. Defaults to 0.4.
+        thr_min (float): The minimum threshold map value. Defaults to 0.3.
+        thr_max (float): The maximum threshold map value. Defaults to 0.7.
         min_sidelength (int or float): The minimum sidelength of the
-            minimum rotated rectangle around any text region.
+            minimum rotated rectangle around any text region. Defaults to 8.
     """

     def __init__(self,
-                 loss_prob: Dict = dict(type='MaskedBalancedBCELoss'),
+                 loss_prob: Dict = dict(
+                     type='MaskedBalancedBCEWithLogitsLoss'),
                  loss_thr: Dict = dict(type='MaskedSmoothL1Loss', beta=0),
                  loss_db: Dict = dict(type='MaskedDiceLoss'),
                  weight_prob: float = 5.,
@@ -63,22 +67,22 @@ class DBModuleLoss(nn.Module, TextKernelMixin):

         Args:
             preds (tuple(tensor)): Raw predictions from model, containing
-                ``prob_map``, ``thr_map`` and ``binary_map``. Each is a tensor
-                of shape :math:`(N, H, W)`.
+                ``prob_logits``, ``thr_map`` and ``binary_map``.
+                Each is a tensor of shape :math:`(N, H, W)`.
             data_samples (list[TextDetDataSample]): The data samples.

         Returns:
             results(dict): The dict for dbnet losses with loss_prob, \
                 loss_db and loss_thr.
         """
-        prob_map, thr_map, binary_map = preds
+        prob_logits, thr_map, binary_map = preds
         gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = self.get_targets(
             data_samples)
-        gt_shrinks = gt_shrinks.to(prob_map.device)
-        gt_shrink_masks = gt_shrink_masks.to(prob_map.device)
+        gt_shrinks = gt_shrinks.to(prob_logits.device)
+        gt_shrink_masks = gt_shrink_masks.to(prob_logits.device)
         gt_thrs = gt_thrs.to(thr_map.device)
         gt_thr_masks = gt_thr_masks.to(thr_map.device)
-        loss_prob = self.loss_prob(prob_map, gt_shrinks, gt_shrink_masks)
+        loss_prob = self.loss_prob(prob_logits, gt_shrinks, gt_shrink_masks)

         loss_thr = self.loss_thr(thr_map, gt_thrs, gt_thr_masks)
         loss_db = self.loss_db(binary_map, gt_shrinks, gt_shrink_masks)
@@ -56,10 +56,10 @@ class DRRGModuleLoss(TextSnakeModuleLoss):
         jitter_level (float): The jitter level of text component geometric
             features. Defaults to 0.2.
         loss_text (dict): The loss config used to calculate the text loss.
-            Defaults to ``dict(
-            type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)``.
+            Defaults to ``dict(type='MaskedBalancedBCEWithLogitsLoss',
+            fallback_negative_num=100, eps=1e-5)``.
         loss_center (dict): The loss config used to calculate the center loss.
-            Defaults to ``dict(type='MaskedBCELoss')``.
+            Defaults to ``dict(type='MaskedBCEWithLogitsLoss')``.
         loss_top (dict): The loss config used to calculate the top loss, which
             is a part of the height loss. Defaults to
             ``dict(type='SmoothL1Loss', reduction='none')``.
@@ -92,8 +92,10 @@ class DRRGModuleLoss(TextSnakeModuleLoss):
                  max_rand_half_height: float = 24.0,
                  jitter_level: float = 0.2,
                  loss_text: Dict = dict(
-                     type='MaskedBalancedBCELoss', fallback_negative_num=100, eps=1e-5),
-                 loss_center: Dict = dict(type='MaskedBCELoss'),
+                     type='MaskedBalancedBCEWithLogitsLoss',
+                     fallback_negative_num=100,
+                     eps=1e-5),
+                 loss_center: Dict = dict(type='MaskedBCEWithLogitsLoss'),
                  loss_top: Dict = dict(type='SmoothL1Loss', reduction='none'),
                  loss_btm: Dict = dict(type='SmoothL1Loss', reduction='none'),
                  loss_sin: Dict = dict(type='MaskedSmoothL1Loss'),
@@ -185,16 +187,16 @@ class DRRGModuleLoss(TextSnakeModuleLoss):
         pred_sin_map = pred_sin_map * scale
         pred_cos_map = pred_cos_map * scale

-        loss_text = self.loss_text(pred_text_region.sigmoid(),
-                                   gt['gt_text_masks'], gt['gt_masks'])
+        loss_text = self.loss_text(pred_text_region, gt['gt_text_masks'],
+                                   gt['gt_masks'])

         text_mask = (gt['gt_text_masks'] * gt['gt_masks']).float()
         negative_text_mask = ((1 - gt['gt_text_masks']) *
                               gt['gt_masks']).float()
-        loss_center_positive = self.loss_center(pred_center_region.sigmoid(),
+        loss_center_positive = self.loss_center(pred_center_region,
                                                 gt['gt_center_region_masks'],
                                                 text_mask)
-        loss_center_negative = self.loss_center(pred_center_region.sigmoid(),
+        loss_center_negative = self.loss_center(pred_center_region,
                                                 gt['gt_center_region_masks'],
                                                 negative_text_mask)
         loss_center = loss_center_positive + 0.5 * loss_center_negative
@@ -47,8 +47,10 @@ class TextSnakeModuleLoss(nn.Module, TextKernelMixin):
                  resample_step: float = 4.0,
                  center_region_shrink_ratio: float = 0.3,
                  loss_text: Dict = dict(
-                     type='MaskedBalancedBCELoss', fallback_negative_num=100, eps=1e-5),
-                 loss_center: Dict = dict(type='MaskedBCELoss'),
+                     type='MaskedBalancedBCEWithLogitsLoss',
+                     fallback_negative_num=100,
+                     eps=1e-5),
+                 loss_center: Dict = dict(type='MaskedBCEWithLogitsLoss'),
                  loss_radius: Dict = dict(type='MaskedSmoothL1Loss'),
                  loss_sin: Dict = dict(type='MaskedSmoothL1Loss'),
                  loss_cos: Dict = dict(type='MaskedSmoothL1Loss')
@@ -146,11 +148,11 @@ class TextSnakeModuleLoss(nn.Module, TextKernelMixin):
         pred_sin_map = pred_sin_map * scale
         pred_cos_map = pred_cos_map * scale

-        loss_text = self.loss_text(pred_text_region.sigmoid(),
-                                   gt['gt_text_masks'], gt['gt_masks'])
+        loss_text = self.loss_text(pred_text_region, gt['gt_text_masks'],
+                                   gt['gt_masks'])

         text_mask = (gt['gt_text_masks'] * gt['gt_masks']).float()
-        loss_center = self.loss_center(pred_center_region.sigmoid(),
+        loss_center = self.loss_center(pred_center_region,
                                        gt['gt_center_region_masks'], text_mask)

         center_mask = (gt['gt_center_region_masks'] * gt['gt_masks']).float()
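
The same mechanical change repeats across DRRG and TextSnake: every ``.sigmoid()`` at a loss call site is dropped because the with-logits loss applies it internally, which leaves the computed value unchanged. A quick sanity sketch of that equivalence in plain PyTorch (not mmocr code):

import torch
import torch.nn.functional as F

pred_text_region = torch.randn(1, 64, 64)  # raw logits from a det head
gt = torch.randint(0, 2, (1, 64, 64)).float()

old_style = F.binary_cross_entropy(pred_text_region.sigmoid(), gt)    # pre-commit
new_style = F.binary_cross_entropy_with_logits(pred_text_region, gt)  # post-commit
assert torch.allclose(old_style, new_style, atol=1e-6)
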
@@ -39,15 +39,15 @@ class BaseTextDetPostProcessor:
         self.test_cfg = test_cfg

     def __call__(self,
-                 pred_results: dict,
+                 pred_results: Union[Tensor, List[Tensor]],
                  data_samples: Sequence[TextDetDataSample],
                  training: bool = False) -> Sequence[TextDetDataSample]:
         """Postprocess pred_results according to metainfos in data_samples.

         Args:
-            pred_results (dict): The prediction results stored in a dictionary.
-                Usually each item to be post-processed is expected to be a
-                batched tensor.
+            pred_results (Union[Tensor, List[Tensor]]): The prediction results
+                stored in a tensor or a list of tensors. Usually each item to
+                be post-processed is expected to be a batched tensor.
             data_samples (list[TextDetDataSample]): Batch of data_samples,
                 each corresponding to a prediction result.
             training (bool): Whether the model is in training mode. Defaults to
@@ -65,13 +65,14 @@ class BaseTextDetPostProcessor:

         return results

-    def _process_single(self, pred_result: dict,
+    def _process_single(self, pred_result: Union[Tensor, List[Tensor]],
                         data_sample: TextDetDataSample,
                         **kwargs) -> TextDetDataSample:
         """Process prediction results from one image.

         Args:
-            pred_result (dict): Prediction results of an image.
+            pred_result (Union[Tensor, List[Tensor]]): Prediction results of an
+                image.
             data_sample (TextDetDataSample): Datasample of an image.
         """

@@ -103,13 +104,13 @@ class BaseTextDetPostProcessor:
                 results.pred_instances[key], scale_factor, mode='div')
         return results

-    def get_text_instances(self, pred_results: dict,
+    def get_text_instances(self, pred_results: Union[Tensor, List[Tensor]],
                            data_sample: TextDetDataSample,
                            **kwargs) -> TextDetDataSample:
         """Get text instance predictions of one image.

         Args:
-            pred_result (dict): Prediction results of an image.
+            pred_result (tuple(Tensor)): Prediction results of an image.
             data_sample (TextDetDataSample): Datasample of an image.
             **kwargs: Other parameters. Configurable via ``__init__.train_cfg``
                 and ``__init__.test_cfg``.
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Sequence, Tuple
+from typing import Sequence

 import cv2
 import numpy as np
@@ -59,14 +59,14 @@ class DBPostprocessor(BaseTextDetPostProcessor):
         self.epsilon_ratio = epsilon_ratio
         self.max_candidates = max_candidates

-    def get_text_instances(self, pred_results: Tuple[Tensor, Tensor, Tensor],
+    def get_text_instances(self, prob_map: Tensor,
                            data_sample: TextDetDataSample
                            ) -> TextDetDataSample:
         """Get text instance predictions of one image.

         Args:
-            pred_result (tuple(Tensor)): A tuple of 3 tensors where the first
-                tensor is ``prob_map`` of shape :math:`(N, H, W)`.
+            pred_result (Tensor): DBNet's output ``prob_map`` of shape
+                :math:`(H, W)`.
             data_sample (TextDetDataSample): Datasample of an image.

         Returns:
@@ -80,7 +80,6 @@ class DBPostprocessor(BaseTextDetPostProcessor):
             data_sample.pred_instances.polygons = []
             data_sample.pred_instances.scores = []

-        prob_map = pred_results[0]
         text_mask = prob_map > self.mask_thr

         score_map = prob_map.data.cpu().numpy().astype(np.float32)
@@ -3,7 +3,9 @@ from unittest import TestCase

 import torch

-from mmocr.models.common.losses import MaskedBalancedBCELoss, MaskedBCELoss
+from mmocr.models.common.losses import (MaskedBalancedBCELoss,
+                                        MaskedBalancedBCEWithLogitsLoss,
+                                        MaskedBCELoss, MaskedBCEWithLogitsLoss)


 class TestMaskedBalancedBCELoss(TestCase):
@@ -117,3 +119,117 @@ class TestMaskedBCELoss(TestCase):
         zero_mask = torch.FloatTensor([0, 0, 0, 0])
         self.assertAlmostEqual(
             self.bce_loss(self.pred, self.gt, zero_mask).item(), 0)
+
+
+class TestMaskedBalancedWithLogitsBCELoss(TestCase):
+
+    def setUp(self) -> None:
+        self.loss = MaskedBalancedBCEWithLogitsLoss(negative_ratio=2)
+        self.pred = torch.FloatTensor([1.5, 1.5, 1.5, 1.5])
+        self.gt = torch.FloatTensor([1, 1, 1, 0])
+        self.mask = torch.BoolTensor([True, False, False, True])
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            MaskedBalancedBCEWithLogitsLoss(reduction='any')
+
+        with self.assertRaises(AssertionError):
+            MaskedBalancedBCEWithLogitsLoss(negative_ratio='a')
+
+        with self.assertRaises(AssertionError):
+            MaskedBalancedBCEWithLogitsLoss(eps='a')
+
+        with self.assertRaises(AssertionError):
+            MaskedBalancedBCEWithLogitsLoss(fallback_negative_num='a')
+
+    def test_forward(self):
+
+        # Shape mismatch between pred and gt
+        with self.assertRaises(AssertionError):
+            invalid_gt = torch.FloatTensor([0, 0, 0])
+            self.loss(self.pred, invalid_gt)
+
+        # Shape mismatch between pred and mask
+        with self.assertRaises(AssertionError):
+            invalid_mask = torch.BoolTensor([True, False, False])
+            self.loss(self.pred, self.gt, invalid_mask)
+
+        # Invalid gt
+        with self.assertRaises(AssertionError):
+            invalid_gt = torch.FloatTensor([2, 3, 4, 5])
+            self.loss(self.pred, invalid_gt, self.mask)
+
+        logit = torch.FloatTensor([1.5])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt).item(),
+            ((-torch.log(torch.sigmoid(logit)) * 3 -
+              torch.log(1 - torch.sigmoid(logit))) / 4).item(),
+            delta=0.0001)
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, self.mask).item(),
+            (-torch.log(torch.sigmoid(logit)) -
+             torch.log(1 - torch.sigmoid(logit))).item() / 2,
+            delta=0.0001)
+
+        # Test zero mask
+        zero_mask = torch.FloatTensor([0, 0, 0, 0])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, zero_mask).item(), 0)
+
+        # Test 0 < fallback_negative_num < negative numbers
+        all_neg_gt = torch.zeros((4, ))
+        self.fallback_bce_loss = MaskedBalancedBCEWithLogitsLoss(
+            fallback_negative_num=1)
+        self.assertAlmostEqual(
+            self.fallback_bce_loss(self.pred, all_neg_gt, self.mask).item(),
+            -torch.log(1 - torch.sigmoid(logit)).item(),
+            delta=0.001)
+        # Test fallback_negative_num > negative numbers
+        self.fallback_bce_loss = MaskedBalancedBCEWithLogitsLoss(
+            fallback_negative_num=5)
+        self.assertAlmostEqual(
+            self.fallback_bce_loss(self.pred, all_neg_gt, self.mask).item(),
+            -torch.log(1 - torch.sigmoid(logit)).item(),
+            delta=0.001)
+
+
+class TestMaskedBCEWithLogitsLoss(TestCase):
+
+    def setUp(self) -> None:
+        self.loss = MaskedBCEWithLogitsLoss()
+        self.pred = torch.FloatTensor([1.5, 1.5, 1.5, 1.5])
+        self.gt = torch.FloatTensor([1, 1, 1, 0])
+        self.mask = torch.BoolTensor([True, False, False, True])
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            MaskedBCEWithLogitsLoss(eps='a')
+
+    def test_forward(self):
+
+        # Shape mismatch between pred and gt
+        with self.assertRaises(AssertionError):
+            invalid_gt = torch.FloatTensor([0, 0, 0])
+            self.loss(self.pred, invalid_gt)
+
+        # Shape mismatch between pred and mask
+        with self.assertRaises(AssertionError):
+            invalid_mask = torch.BoolTensor([True, False, False])
+            self.loss(self.pred, self.gt, invalid_mask)
+
+        logit = torch.FloatTensor([1.5])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt).item(),
+            ((-torch.log(torch.sigmoid(logit)) * 3 -
+              torch.log(1 - torch.sigmoid(logit))) / 4).item(),
+            delta=0.0001)
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, self.mask).item(),
+            (-torch.log(torch.sigmoid(logit)) -
+             torch.log(1 - torch.sigmoid(logit))).item() / 2,
+            delta=0.0001)
+
+        # Test zero mask
+        zero_mask = torch.FloatTensor([0, 0, 0, 0])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, zero_mask).item(), 0)
@@ -4,10 +4,24 @@ from unittest import TestCase
 import torch

 from mmocr.models.textdet.heads import DBHead
+from mmocr.registry import MODELS


 class TestDBHead(TestCase):

+    # Used to replace module loss and postprocessors
+    @MODELS.register_module(name='DBDummy')
+    class DummyModule:
+
+        def __call__(self, x, data_samples):
+            return x
+
+    def setUp(self) -> None:
+        self.db_head = DBHead(
+            in_channels=10,
+            module_loss=dict(type='DBDummy'),
+            postprocessor=dict(type='DBDummy'))
+
     def test_init(self):
         with self.assertRaises(AssertionError):
             DBHead(in_channels='test', with_bias=False)
@@ -16,9 +30,28 @@ class TestDBHead(TestCase):
             DBHead(in_channels=1, with_bias='Text')

     def test_forward(self):
-        db_head = DBHead(in_channels=10)
         data = torch.randn((2, 10, 40, 50))
-        results = db_head(data, None)
+        results = self.db_head(data, None)
         self.assertEqual(results[0].shape, (2, 160, 200))
         self.assertEqual(results[1].shape, (2, 160, 200))
         self.assertEqual(results[2].shape, (2, 160, 200))
+
+    def test_loss(self):
+        data = torch.randn((2, 10, 40, 50))
+        results = self.db_head.loss(data, None)
+        for i in range(3):
+            self.assertEqual(results[i].shape, (2, 160, 200))
+
+    def test_predict(self):
+        data = torch.randn((2, 10, 40, 50))
+        results = self.db_head.predict(data, None)
+        self.assertEqual(results.shape, (2, 160, 200))
+
+    def test_loss_and_predict(self):
+        data = torch.randn((2, 10, 40, 50))
+        loss_results, pred_results = self.db_head.loss_and_predict(data, None)
+        for i in range(3):
+            self.assertEqual(loss_results[i].shape, (2, 160, 200))
+        self.assertEqual(pred_results.shape, (2, 160, 200))
+        self.assertTrue(
+            torch.allclose(pred_results, loss_results[0].sigmoid()))
@@ -23,7 +23,7 @@ class TestDBPostProcessor(unittest.TestCase):
     def test_get_text_instances(self, text_repr_type):

         postprocessor = DBPostprocessor(text_repr_type=text_repr_type)
-        pred_result = (torch.rand(4, 5), torch.rand(4, 5), torch.rand(4, 5))
+        pred_result = torch.rand(4, 5)
         data_sample = TextDetDataSample(
             metainfo=dict(scale_factor=(0.5, 1)),
             gt_instances=InstanceData(polygons=[
@@ -36,12 +36,11 @@ class TestDBPostProcessor(unittest.TestCase):
         self.assertTrue(
             isinstance(results.pred_instances['scores'], torch.FloatTensor))

-        preds = (torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0],
-                                    [0.8, 0.8, 0.8, 0.8, 0],
-                                    [0.8, 0.8, 0.8, 0.8, 0],
-                                    [0.8, 0.8, 0.8, 0.8, 0],
-                                    [0.8, 0.8, 0.8, 0.8, 0]]),
-                 torch.rand([1, 10]), torch.rand([1, 10]))
+        preds = torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0],
+                                   [0.8, 0.8, 0.8, 0.8, 0],
+                                   [0.8, 0.8, 0.8, 0.8, 0],
+                                   [0.8, 0.8, 0.8, 0.8, 0],
+                                   [0.8, 0.8, 0.8, 0.8, 0]])
         postprocessor = DBPostprocessor(
             text_repr_type=text_repr_type, min_text_width=0)
         results = postprocessor.get_text_instances(preds, data_sample)
@@ -49,8 +48,7 @@ class TestDBPostProcessor(unittest.TestCase):

         postprocessor = DBPostprocessor(
             min_text_score=1, text_repr_type=text_repr_type)
-        pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8,
-                       torch.rand(4, 5) * 0.8)
+        pred_result = torch.rand(4, 5) * 0.8
         results = postprocessor.get_text_instances(pred_result, data_sample)
         self.assertEqual(results.pred_instances.polygons, [])
         self.assertTrue(