[FCENet] Add FCENet loss

pull/1178/head
jiangqing.vendor 2022-06-21 04:55:25 +00:00 committed by gaotongxiao
parent 17606c25fc
commit fded755af2
6 changed files with 490 additions and 84 deletions

View File

@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bce_loss import MaskedBalancedBCELoss
from .dice_loss import MaskedDiceLoss, MaskedSquareDiceLoss
from .l1_loss import MaskedSmoothL1Loss
from .l1_loss import MaskedSmoothL1Loss, SmoothL1Loss
__all__ = [
'MaskedBalancedBCELoss', 'MaskedDiceLoss', 'MaskedSmoothL1Loss',
'MaskedSquareDiceLoss'
'MaskedSquareDiceLoss', 'SmoothL1Loss'
]
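Side note: the new ``SmoothL1Loss`` export is consumed through the registry rather than imported directly. A minimal sketch (not part of this commit) of the build path, assuming the usual MMEngine registry semantics:

from mmocr.registry import MODELS

# builds the registered class from a config dict
loss_reg = MODELS.build(dict(type='SmoothL1Loss', reduction='none'))
# equivalent to torch.nn.SmoothL1Loss(reduction='none')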

View File

@@ -130,4 +130,4 @@ class MaskedBCELoss(nn.Module):
assert pred.max() <= 1 and pred.min() >= 0
loss = self.binary_cross_entropy(pred, gt)
return (loss * mask) / (mask.sum() + self.eps)
return (loss * mask).sum() / (mask.sum() + self.eps)
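The added ``.sum()`` is the actual fix: the old expression divided an unreduced tensor by a scalar, returning a loss map instead of a scalar. A standalone toy check:

import torch

loss = torch.tensor([1., 2., 3.])
mask = torch.tensor([1., 1., 0.])
eps = 1e-6
(loss * mask) / (mask.sum() + eps)        # tensor([0.5, 1.0, 0.0]): not a scalar
(loss * mask).sum() / (mask.sum() + eps)  # tensor(1.5): the masked mean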

View File

@@ -7,6 +7,11 @@ import torch.nn as nn
from mmocr.registry import MODELS
@MODELS.register_module()
class SmoothL1Loss(nn.SmoothL1Loss):
"""Smooth L1 loss."""
@MODELS.register_module()
class MaskedSmoothL1Loss(nn.Module):
"""Masked Smooth L1 loss.

View File

@@ -1,15 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import multi_apply
from torch import nn
from numpy.fft import fft
from numpy.linalg import norm
from numpy.typing import ArrayLike
from mmocr.core import TextDetDataSample
from mmocr.registry import MODELS
from .textsnake_loss import TextSnakeLoss
@MODELS.register_module()
class FCELoss(nn.Module):
class FCELoss(TextSnakeLoss):
"""The class for implementing FCENet loss.
FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
@@ -19,43 +25,75 @@ class FCELoss(nn.Module):
fourier_degree (int) : The maximum Fourier transform degree k.
num_sample (int): The number of sampled points for the regression
loss. If it is too small, FCENet tends to overfit.
ohem_ratio (float): the negative/positive ratio in OHEM.
negative_ratio (float or int): Maximum ratio of negative
samples to positive ones in OHEM. Defaults to 3.
resample_step (float): The step size for resampling the text center
line (TCL). It's better not to exceed half of the minimum width.
center_region_shrink_ratio (float): The shrink ratio of text center
region.
level_size_divisors (tuple(int)): The downsample ratio on each level.
level_proportion_range (tuple(tuple(float))): The range of text sizes
assigned to each level.
loss_tr (dict) : The loss config used to calculate the text region
loss. Defaults to dict(type='MaskedBalancedBCELoss').
loss_tcl (dict) : The loss config used to calculate the text center
line loss. Defaults to dict(type='MaskedBCELoss').
loss_reg_x (dict) : The loss config used to calculate the regression
loss on x axis. Defaults to dict(type='MaskedSmoothL1Loss').
loss_reg_y (dict) : The loss config used to calculate the regression
loss on y axis. Defaults to dict(type='MaskedSmoothL1Loss').
"""
def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
def __init__(
self,
fourier_degree: int,
num_sample: int,
negative_ratio: Union[float, int] = 3.,
resample_step: float = 4.0,
center_region_shrink_ratio: float = 0.3,
level_size_divisors: Tuple[int] = (8, 16, 32),
level_proportion_range: Tuple[Tuple[float]] = ((0, 0.4), (0.3, 0.7),
(0.6, 1.0)),
loss_tr: Dict = dict(type='MaskedBalancedBCELoss'),
loss_tcl: Dict = dict(type='MaskedBCELoss'),
loss_reg_x: Dict = dict(type='SmoothL1Loss', reduction='none'),
loss_reg_y: Dict = dict(type='SmoothL1Loss', reduction='none'),
) -> None:
super().__init__()
self.fourier_degree = fourier_degree
self.num_sample = num_sample
self.ohem_ratio = ohem_ratio
self.resample_step = resample_step
self.center_region_shrink_ratio = center_region_shrink_ratio
self.level_size_divisors = level_size_divisors
self.level_proportion_range = level_proportion_range
def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
loss_tr.update(negative_ratio=negative_ratio)
self.loss_tr = MODELS.build(loss_tr)
self.loss_tcl = MODELS.build(loss_tcl)
self.loss_reg_x = MODELS.build(loss_reg_x)
self.loss_reg_y = MODELS.build(loss_reg_y)
def forward(self, preds: Sequence[Dict],
data_samples: Sequence[TextDetDataSample]) -> Dict:
"""Compute FCENet loss.
Args:
preds (list[list[Tensor]]): The outer list indicates images
in a batch, and the inner list indicates the classification
prediction map (with shape :math:`(N, C, H, W)`) and
regression map (with shape :math:`(N, C, H, W)`).
p3_maps (list[ndarray]): List of level 3 ground truth target maps
with shape :math:`(C, H, W)`.
p4_maps (list[ndarray]): List of level 4 ground truth target maps
with shape :math:`(C, H, W)`.
p5_maps (list[ndarray]): List of level 5 ground truth target maps
with shape :math:`(C, H, W)`.
preds (list[dict]): A list of dicts with keys ``cls_res`` and
``reg_res``, corresponding to the classification and regression
results computed from the input tensor with the same index. They
have the shapes of :math:`(N, C_{cls,i}, H_i, W_i)` and
:math:`(N, C_{out,i}, H_i, W_i)`.
data_samples (list[TextDetDataSample]): The data samples.
Returns:
dict: A loss dict with ``loss_text``, ``loss_center``,
``loss_reg_x`` and ``loss_reg_y``.
dict: The dict for fcenet losses with loss_text, loss_center,
loss_reg_x and loss_reg_y.
"""
assert isinstance(preds, list)
assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
'fourier degree not equal in FCEhead and FCEtarget'
device = preds[0][0].device
# to tensor
gts = [p3_maps, p4_maps, p5_maps]
for idx, maps in enumerate(gts):
gts[idx] = torch.from_numpy(np.stack(maps)).float().to(device)
assert isinstance(preds, list) and len(preds) == 3
p3_maps, p4_maps, p5_maps = self.get_targets(data_samples)
device = preds[0]['cls_res'].device
# to device
gts = [p3_maps.to(device), p4_maps.to(device), p5_maps.to(device)]
losses = multi_apply(self.forward_single, preds, gts)
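``multi_apply`` runs ``forward_single`` once per feature level and transposes the per-level result tuples into per-loss lists. A sketch of its implementation as found in mmdet at the time (assumption: unchanged here):

from functools import partial

def multi_apply(func, *args, **kwargs):
    # map func over the zipped argument lists, then regroup the
    # per-call result tuples into per-field lists
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))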
@@ -83,9 +121,24 @@ class FCELoss(nn.Module):
return results
def forward_single(self, pred, gt):
cls_pred = pred[0].permute(0, 2, 3, 1).contiguous()
reg_pred = pred[1].permute(0, 2, 3, 1).contiguous()
def forward_single(self, pred: Dict,
gt: torch.Tensor) -> Sequence[torch.Tensor]:
"""Compute loss for one feature level.
Args:
pred (dict): A dict with keys ``cls_res`` and ``reg_res``,
corresponding to the classification and regression results from
one feature level.
gt (Tensor): Ground truth for one feature level. Cls and reg
targets are concatenated along the channel dimension.
Returns:
list[Tensor]: A list of losses for each feature level.
"""
assert isinstance(pred, dict) and isinstance(gt, torch.Tensor)
cls_pred = pred['cls_res'].permute(0, 2, 3, 1).contiguous()
reg_pred = pred['reg_res'].permute(0, 2, 3, 1).contiguous()
gt = gt.permute(0, 2, 3, 1).contiguous()
k = 2 * self.fourier_degree + 1
@@ -100,70 +153,371 @@ class FCELoss(nn.Module):
x_map = gt[:, :, :, 3:3 + k].view(-1, k)
y_map = gt[:, :, :, 3 + k:].view(-1, k)
tr_train_mask = train_mask * tr_mask
device = x_map.device
# tr loss
loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())
tr_train_mask = (train_mask * tr_mask).float()
# text region loss
loss_tr = self.loss_tr(tr_pred.softmax(-1)[:, 1], tr_mask, train_mask)
# tcl loss
loss_tcl = torch.tensor(0.).float().to(device)
# text center line loss
tr_neg_mask = 1 - tr_train_mask
if tr_train_mask.sum().item() > 0:
loss_tcl_pos = F.cross_entropy(
tcl_pred[tr_train_mask.bool()],
tcl_mask[tr_train_mask.bool()].long())
loss_tcl_neg = F.cross_entropy(tcl_pred[tr_neg_mask.bool()],
tcl_mask[tr_neg_mask.bool()].long())
loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
loss_tcl_positive = self.loss_tcl(
tcl_pred.softmax(-1)[:, 1], tcl_mask, tr_train_mask)
loss_tcl_negative = self.loss_tcl(
tcl_pred.softmax(-1)[:, 1], tcl_mask, tr_neg_mask)
loss_tcl = loss_tcl_positive + 0.5 * loss_tcl_negative
# regression loss
loss_reg_x = torch.tensor(0.).float().to(device)
loss_reg_y = torch.tensor(0.).float().to(device)
loss_reg_x = torch.tensor(0.).float().to(x_pred.device)
loss_reg_y = torch.tensor(0.).float().to(x_pred.device)
if tr_train_mask.sum().item() > 0:
weight = (tr_mask[tr_train_mask.bool()].float() +
tcl_mask[tr_train_mask.bool()].float()) / 2
weight = weight.contiguous().view(-1, 1)
ft_x, ft_y = self.fourier2poly(x_map, y_map)
ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
ft_x, ft_y = self._fourier2poly(x_map, y_map)
ft_x_pre, ft_y_pre = self._fourier2poly(x_pred, y_pred)
loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
ft_x_pre[tr_train_mask.bool()],
ft_x[tr_train_mask.bool()],
reduction='none'))
loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
ft_y_pre[tr_train_mask.bool()],
ft_y[tr_train_mask.bool()],
reduction='none'))
loss_reg_x = torch.mean(weight * self.loss_reg_x(
ft_x_pre[tr_train_mask.bool()], ft_x[tr_train_mask.bool()]))
loss_reg_y = torch.mean(weight * self.loss_reg_y(
ft_y_pre[tr_train_mask.bool()], ft_y[tr_train_mask.bool()]))
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
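The regression branch above weights every sampled point by the mean of its text-region and center-line masks before averaging. A standalone toy version of that pattern:

import torch
import torch.nn.functional as F

ft_x_pre = torch.tensor([[0., 1.], [2., 2.]])
ft_x = torch.zeros(2, 2)
weight = torch.tensor([[1.0], [0.5]])  # (tr_mask + tcl_mask) / 2, shape (-1, 1)
torch.mean(weight * F.smooth_l1_loss(ft_x_pre, ft_x, reduction='none'))
# tensor(0.5000)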
def ohem(self, predict, target, train_mask):
device = train_mask.device
pos = (target * train_mask).bool()
neg = ((1 - target) * train_mask).bool()
n_pos = pos.float().sum()
if n_pos.item() > 0:
loss_pos = F.cross_entropy(
predict[pos], target[pos], reduction='sum')
loss_neg = F.cross_entropy(
predict[neg], target[neg], reduction='none')
n_neg = min(
int(neg.float().sum().item()),
int(self.ohem_ratio * n_pos.float()))
else:
loss_pos = torch.tensor(0.).to(device)
loss_neg = F.cross_entropy(
predict[neg], target[neg], reduction='none')
n_neg = 100
if len(loss_neg) > n_neg:
loss_neg, _ = torch.topk(loss_neg, n_neg)
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
def get_targets(self, data_samples: List[TextDetDataSample]) -> Tuple:
"""Generate loss targets for FCENet from data samples.
Args:
data_samples (list(TextDetDataSample)): Ground truth data samples.
Returns:
tuple[Tensor]: A tuple of three tensors from three different
feature levels as FCENet targets.
"""
p3_maps, p4_maps, p5_maps = multi_apply(self._get_target_single,
data_samples)
p3_maps = torch.cat(p3_maps, 0)
p4_maps = torch.cat(p4_maps, 0)
p5_maps = torch.cat(p5_maps, 0)
return p3_maps, p4_maps, p5_maps
def _get_target_single(self, data_sample: TextDetDataSample) -> Tuple:
"""Generate loss target for fcenet from a data sample.
Args:
data_sample (TextDetDataSample): The data sample.
Returns:
tuple[Tensor]: A tuple of three tensors from three different
feature levels as the targets of one prediction.
"""
img_size = data_sample.img_shape[:2]
text_polys = data_sample.gt_instances.polygons
ignore_flags = data_sample.gt_instances.ignored
p3_map, p4_map, p5_map = self._generate_level_targets(
img_size, text_polys, ignore_flags)
# to tensor
p3_map = torch.from_numpy(p3_map).unsqueeze(0).float()
p4_map = torch.from_numpy(p4_map).unsqueeze(0).float()
p5_map = torch.from_numpy(p5_map).unsqueeze(0).float()
return p3_map, p4_map, p5_map
def _generate_level_targets(self,
img_size: Tuple[int, int],
text_polys: List[ArrayLike],
ignore_flags: Optional[torch.BoolTensor] = None
) -> Tuple[torch.Tensor]:
"""Generate targets for one feature level.
Args:
img_size (tuple(int, int)): The image size of (height, width).
text_polys (List[ndarray]): 2D array of text polygons.
ignore_flags (torch.BoolTensor, optional): Indicate whether the
corresponding text polygon is ignored. Defaults to None.
Returns:
tuple[Tensor]: A tuple of three target maps, one for each feature
level.
"""
h, w = img_size
lv_size_divs = self.level_size_divisors
lv_proportion_range = self.level_proportion_range
lv_text_polys = [[] for i in range(len(lv_size_divs))]
lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
level_maps = []
for poly_ind, poly in enumerate(text_polys):
poly = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
_, _, box_w, box_h = cv2.boundingRect(poly)
proportion = max(box_h, box_w) / (h + 1e-8)
for ind, proportion_range in enumerate(lv_proportion_range):
if proportion_range[0] < proportion < proportion_range[1]:
if ignore_flags is not None and ignore_flags[poly_ind]:
lv_ignore_polys[ind].append(poly[0] /
lv_size_divs[ind])
else:
lv_text_polys[ind].append(poly[0] / lv_size_divs[ind])
for ind, size_divisor in enumerate(lv_size_divs):
current_level_maps = []
level_img_size = (h // size_divisor, w // size_divisor)
text_region = self._generate_text_region_mask(
level_img_size, lv_text_polys[ind])[None]
current_level_maps.append(text_region)
center_region = self._generate_center_region_mask(
level_img_size, lv_text_polys[ind])[None]
current_level_maps.append(center_region)
effective_mask = self._generate_effective_mask(
level_img_size, lv_ignore_polys[ind])[None]
current_level_maps.append(effective_mask)
fourier_real_map, fourier_image_maps = self._generate_fourier_maps(
level_img_size, lv_text_polys[ind])
current_level_maps.append(fourier_real_map)
current_level_maps.append(fourier_image_maps)
level_maps.append(np.concatenate(current_level_maps))
return level_maps
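Each level map therefore stacks the text region mask (1 channel), the center region mask (1), the effective mask (1), and the real and imaginary Fourier maps (2k+1 each), i.e. 4k+5 channels, which is exactly what the old ``4 * self.fourier_degree + 5`` assertion encoded. A quick check with the k=5 used in the unit test at the end of this commit:

k = 5  # fourier_degree
channels = 1 + 1 + 1 + 2 * (2 * k + 1)
assert channels == 4 * k + 5 == 25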
def _generate_center_region_mask(self, img_size: Tuple[int, int],
text_polys: ArrayLike) -> np.ndarray:
"""Generate text center region mask.
Args:
img_size (tuple): The image size of (height, width).
text_polys (list[ndarray]): The list of text polygons.
Returns:
ndarray: The text center region mask.
"""
assert isinstance(img_size, tuple)
h, w = img_size
center_region_mask = np.zeros((h, w), np.uint8)
center_region_boxes = []
for poly in text_polys:
polygon_points = poly.reshape(-1, 2)
_, _, top_line, bot_line = self._reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self._resample_sidelines(
top_line, bot_line, self.resample_step)
resampled_bot_line = resampled_bot_line[::-1]
center_line = (resampled_top_line + resampled_bot_line) / 2
line_head_shrink_len = norm(resampled_top_line[0] -
resampled_bot_line[0]) / 4.0
line_tail_shrink_len = norm(resampled_top_line[-1] -
resampled_bot_line[-1]) / 4.0
head_shrink_num = int(line_head_shrink_len // self.resample_step)
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
center_line = center_line[head_shrink_num:len(center_line) -
tail_shrink_num]
resampled_top_line = resampled_top_line[
head_shrink_num:len(resampled_top_line) - tail_shrink_num]
resampled_bot_line = resampled_bot_line[
head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
for i in range(0, len(center_line) - 1):
tl = center_line[i] + (resampled_top_line[i] - center_line[i]
) * self.center_region_shrink_ratio
tr = center_line[i + 1] + (
resampled_top_line[i + 1] -
center_line[i + 1]) * self.center_region_shrink_ratio
br = center_line[i + 1] + (
resampled_bot_line[i + 1] -
center_line[i + 1]) * self.center_region_shrink_ratio
bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
) * self.center_region_shrink_ratio
current_center_box = np.vstack([tl, tr, br,
bl]).astype(np.int32)
center_region_boxes.append(current_center_box)
cv2.fillPoly(center_region_mask, center_region_boxes, 1)
return center_region_mask
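Each quad corner above is a linear interpolation from the center line toward the side line by ``center_region_shrink_ratio``. A one-line numeric sketch with the default ratio of 0.3:

import numpy as np

center, top = np.array([5., 5.]), np.array([5., 1.])
tl = center + (top - center) * 0.3  # array([5. , 3.8]): 30% of the way out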
def _generate_fourier_maps(self, img_size: Tuple[int, int],
text_polys: ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
"""Generate Fourier coefficient maps.
Args:
img_size (tuple): The image size of (height, width).
text_polys (list[ndarray]): The list of text polygons.
Returns:
tuple(ndarray, ndarray):
- fourier_real_map (ndarray): The Fourier coefficient real part
maps.
- fourier_image_map (ndarray): The Fourier coefficient imaginary
part maps.
"""
assert isinstance(img_size, tuple)
h, w = img_size
k = self.fourier_degree
real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
for poly in text_polys:
mask = np.zeros((h, w), dtype=np.uint8)
polygon = np.array(poly).reshape((1, -1, 2))
cv2.fillPoly(mask, polygon.astype(np.int32), 1)
fourier_coeff = self._cal_fourier_signature(polygon[0], k)
for i in range(-k, k + 1):
if i != 0:
real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
1 - mask) * real_map[i + k, :, :]
imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
1 - mask) * imag_map[i + k, :, :]
else:
yx = np.argwhere(mask > 0.5)
k_ind = np.ones((len(yx)), dtype=np.int64) * k
y, x = yx[:, 0], yx[:, 1]
real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
return real_map, imag_map
def _cal_fourier_signature(self, polygon: ArrayLike,
fourier_degree: int) -> np.ndarray:
"""Calculate Fourier signature from input polygon.
Args:
polygon (list[ndarray]): The input polygon.
fourier_degree (int): The maximum Fourier degree K.
Returns:
ndarray: An array shaped ``(2k+1, 2)``, containing the real and
imaginary parts of the 2k+1 Fourier coefficients.
"""
resampled_polygon = self._resample_polygon(polygon)
resampled_polygon = self._normalize_polygon(resampled_polygon)
fourier_coeff = self._poly2fourier(resampled_polygon, fourier_degree)
fourier_coeff = self._clockwise(fourier_coeff, fourier_degree)
real_part = np.real(fourier_coeff).reshape((-1, 1))
image_part = np.imag(fourier_coeff).reshape((-1, 1))
fourier_signature = np.hstack([real_part, image_part])
return fourier_signature
def _resample_polygon(self,
polygon: ArrayLike,
n: int = 400) -> np.ndarray:
"""Resample one polygon with n points on its boundary.
Args:
polygon (list[ndarray]): The input polygon.
n (int): The number of resampled points. Defaults to 400.
Returns:
ndarray: The resampled polygon.
"""
length = []
for i in range(len(polygon)):
p1 = polygon[i]
if i == len(polygon) - 1:
p2 = polygon[0]
else:
p2 = polygon[i + 1]
length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
total_length = sum(length)
n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
n_on_each_line = n_on_each_line.astype(np.int32)
new_polygon = []
for i in range(len(polygon)):
num = n_on_each_line[i]
p1 = polygon[i]
if i == len(polygon) - 1:
p2 = polygon[0]
else:
p2 = polygon[i + 1]
if num == 0:
continue
dxdy = (p2 - p1) / num
for j in range(num):
point = p1 + dxdy * j
new_polygon.append(point)
return np.array(new_polygon)
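A quick sanity check of the resampler (hedged: calling the private method directly on an instance, which assumes the registry is importable):

import numpy as np
from mmocr.models.textdet.losses import FCELoss

fce = FCELoss(fourier_degree=5, num_sample=400)
square = np.array([[0., 0.], [10., 0.], [10., 10.], [0., 10.]])
fce._resample_polygon(square, n=100).shape
# (96, 2): points are spread proportionally to edge length, and the
# per-edge counts are truncated to integers, so n is approximate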
def _normalize_polygon(self, polygon: ArrayLike) -> np.ndarray:
"""Normalize one polygon so that its start point is at right most.
Args:
polygon (list[ndarray]): The origin polygon.
Returns:
ndarray: The polygon with start point at right.
"""
temp_polygon = polygon - polygon.mean(axis=0)
x = np.abs(temp_polygon[:, 0])
y = temp_polygon[:, 1]
index_x = np.argsort(x)
index_y = np.argmin(y[index_x[:8]])
index = index_x[index_y]
new_polygon = np.concatenate([polygon[index:], polygon[:index]])
return new_polygon
def _clockwise(self, fourier_coeff: np.ndarray,
fourier_degree: int) -> np.ndarray:
"""Make sure the polygon reconstructed from Fourier coefficients c in
the clockwise direction.
Args:
fourier_coeff (ndarray[complex]): The Fourier coefficients.
fourier_degree (int): The maximum Fourier degree K.
Returns:
ndarray: The Fourier coefficients corresponding to the polygon in
clockwise point order.
"""
if np.abs(fourier_coeff[fourier_degree + 1]) > np.abs(
fourier_coeff[fourier_degree - 1]):
return fourier_coeff
elif np.abs(fourier_coeff[fourier_degree + 1]) < np.abs(
fourier_coeff[fourier_degree - 1]):
return fourier_coeff[::-1]
else:
if np.abs(fourier_coeff[fourier_degree + 2]) > np.abs(
fourier_coeff[fourier_degree - 2]):
return fourier_coeff
else:
return fourier_coeff[::-1]
def _poly2fourier(self, polygon: ArrayLike,
fourier_degree: int) -> np.ndarray:
"""Perform Fourier transformation to generate Fourier coefficients
:math:`c_k` from the polygon.
Args:
polygon (list[ndarray]): An input polygon.
fourier_degree (int): The maximum Fourier degree K.
Returns:
ndarray: Fourier coefficients.
"""
points = polygon[:, 0] + polygon[:, 1] * 1j
c_fft = fft(points) / len(points)
c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1]))
return c
def fourier2poly(self, real_maps, imag_maps):
def _fourier2poly(self, real_maps: torch.Tensor,
imag_maps: torch.Tensor) -> Sequence[torch.Tensor]:
"""Transform Fourier coefficient maps to polygon maps.
Args:
@@ -173,9 +527,11 @@ class FCELoss(nn.Module):
Fourier coefficients, whose shape is (-1, 2k+1)
Returns:
x_maps (tensor): A map composed of the x value of the polygon
tuple(tensor, tensor):
- x_maps (tensor): A map composed of the x value of the polygon
represented by n sample points (xn, yn), whose shape is (-1, n)
y_maps (tensor): A map composed of the y value of the polygon
- y_maps (tensor): A map composed of the y value of the polygon
represented by n sample points (xn, yn), whose shape is (-1, n)
"""

View File

@@ -29,8 +29,8 @@ class TextKernelMixin:
text_polys (Sequence[np.ndarray]): 2D array of text polygons.
shrink_ratio (float or int): The shrink ratio of kernel.
max_shrink_dist (float or int): The maximum shrinking distance.
ignore_flags (torch.BoolTensor, options): Indicate whether the
corresponding text polygon is ignored.
ignore_flags (torch.BoolTensor, optional): Indicate whether the
corresponding text polygon is ignored. Defaults to None.
Returns:
tuple(ndarray, ndarray): The text instance kernels of shape

View File

@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine import InstanceData
from mmocr.core import TextDetDataSample
from mmocr.models.textdet.losses import FCELoss
class TestFCELoss(TestCase):
def setUp(self) -> None:
self.fce_loss = FCELoss(fourier_degree=5, num_sample=400)
self.data_samples = [
TextDetDataSample(
metainfo=dict(img_shape=(320, 320)),
gt_instances=InstanceData(
polygons=np.array([
[0, 0, 10, 0, 10, 10, 0, 10],
[20, 0, 30, 0, 30, 10, 20, 10],
[0, 0, 15, 0, 15, 10, 0, 10],
],
dtype=np.float32),
ignored=torch.BoolTensor([False, False, True])))
]
self.preds = [
dict(
cls_res=torch.rand(1, 4, 40, 40),
reg_res=torch.rand(1, 22, 40, 40)),
dict(
cls_res=torch.rand(1, 4, 20, 20),
reg_res=torch.rand(1, 22, 20, 20)),
dict(
cls_res=torch.rand(1, 4, 10, 10),
reg_res=torch.rand(1, 22, 10, 10))
]
def test_forward(self):
losses = self.fce_loss(self.preds, self.data_samples)
assert 'loss_text' in losses
assert 'loss_center' in losses
assert 'loss_reg_x' in losses
assert 'loss_reg_y' in losses
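In a detector config, the registered loss is referenced by its type name and built through the registry; a hedged sketch mirroring the defaults above:

from mmocr.registry import MODELS

loss_cfg = dict(
    type='FCELoss',
    fourier_degree=5,
    num_sample=50,  # hypothetical config value; the test above uses 400
    negative_ratio=3.)
fce_loss = MODELS.build(loss_cfg)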