mirror of https://github.com/open-mmlab/mmocr.git
parent
03720f46c3
commit
aa87b69f12
|
@ -1,6 +1,5 @@
|
|||
from .dense_heads import * # noqa: F401,F403
|
||||
from .detectors import * # noqa: F401,F403
|
||||
from .losses import * # noqa: F401,F403
|
||||
from .modules import * # noqa: F401,F403
|
||||
from .necks import * # noqa: F401,F403
|
||||
from .postprocess import * # noqa: F401,F403
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
from .db_head import DBHead
|
||||
from .drrg_head import DRRGHead
|
||||
from .head_mixin import HeadMixin
|
||||
from .pan_head import PANHead
|
||||
from .pse_head import PSEHead
|
||||
from .textsnake_head import TextSnakeHead
|
||||
|
||||
__all__ = [
|
||||
'PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'DRRGHead', 'TextSnakeHead'
|
||||
]
|
||||
__all__ = ['PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'TextSnakeHead']
|
||||
|
|
|
@ -1,193 +0,0 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import normal_init
|
||||
|
||||
from mmdet.models.builder import HEADS, build_loss
|
||||
from mmocr.models.textdet.modules import (GCN, LocalGraphs,
|
||||
ProposalLocalGraphs,
|
||||
merge_text_comps)
|
||||
from .head_mixin import HeadMixin
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DRRGHead(HeadMixin, nn.Module):
|
||||
"""The class for DRRG head: Deep Relational Reasoning Graph Network for
|
||||
Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of i-hop neighbors,
|
||||
i = 1, 2, ..., h.
|
||||
active_connection (int): The number of two hop neighbors deem as
|
||||
linked to a pivot.
|
||||
node_geo_feat_dim (int): The dimension of embedded geometric features
|
||||
of a component.
|
||||
pooling_scale (float): The spatial scale of RRoI-Aligning.
|
||||
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
|
||||
graph_filter_thr (float): The threshold to filter identical local
|
||||
graphs.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
nms_thr (float): The locality-aware NMS threshold.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_ratio (float): The reciprocal of aspect ratio of text components.
|
||||
text_region_thr (float): The threshold for text region probability map.
|
||||
center_region_thr (float): The threshold of text center region
|
||||
probability map.
|
||||
center_region_area_thr (int): The threshold of filtering small-size
|
||||
text center region.
|
||||
link_thr (float): The threshold for connected components searching.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
k_at_hops=(8, 4),
|
||||
active_connection=3,
|
||||
node_geo_feat_dim=120,
|
||||
pooling_scale=1.0,
|
||||
pooling_output_size=(3, 4),
|
||||
graph_filter_thr=0.75,
|
||||
comp_shrink_ratio=0.95,
|
||||
nms_thr=0.25,
|
||||
min_width=8.0,
|
||||
max_width=24.0,
|
||||
comp_ratio=0.65,
|
||||
text_region_thr=0.6,
|
||||
center_region_thr=0.4,
|
||||
center_region_area_thr=100,
|
||||
link_thr=0.85,
|
||||
loss=dict(type='DRRGLoss'),
|
||||
train_cfg=None,
|
||||
test_cfg=None):
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(k_at_hops, tuple)
|
||||
assert isinstance(active_connection, int)
|
||||
assert isinstance(node_geo_feat_dim, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert isinstance(pooling_output_size, tuple)
|
||||
assert isinstance(graph_filter_thr, float)
|
||||
assert isinstance(comp_shrink_ratio, float)
|
||||
assert isinstance(nms_thr, float)
|
||||
assert isinstance(min_width, float)
|
||||
assert isinstance(max_width, float)
|
||||
assert isinstance(comp_ratio, float)
|
||||
assert isinstance(center_region_area_thr, int)
|
||||
assert isinstance(link_thr, float)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = 6
|
||||
self.downsample_ratio = 1.0
|
||||
self.k_at_hops = k_at_hops
|
||||
self.active_connection = active_connection
|
||||
self.node_geo_feat_dim = node_geo_feat_dim
|
||||
self.pooling_scale = pooling_scale
|
||||
self.pooling_output_size = pooling_output_size
|
||||
self.graph_filter_thr = graph_filter_thr
|
||||
self.comp_shrink_ratio = comp_shrink_ratio
|
||||
self.nms_thr = nms_thr
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.comp_ratio = comp_ratio
|
||||
self.text_region_thr = text_region_thr
|
||||
self.center_region_thr = center_region_thr
|
||||
self.center_region_area_thr = center_region_area_thr
|
||||
self.link_thr = link_thr
|
||||
self.loss_module = build_loss(loss)
|
||||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
|
||||
self.out_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.init_weights()
|
||||
|
||||
self.graph_train = LocalGraphs(self.k_at_hops, self.active_connection,
|
||||
self.node_geo_feat_dim,
|
||||
self.pooling_scale,
|
||||
self.pooling_output_size,
|
||||
self.graph_filter_thr)
|
||||
|
||||
self.graph_test = ProposalLocalGraphs(
|
||||
self.k_at_hops, self.active_connection, self.node_geo_feat_dim,
|
||||
self.pooling_scale, self.pooling_output_size, self.nms_thr,
|
||||
self.min_width, self.max_width, self.comp_shrink_ratio,
|
||||
self.comp_ratio, self.text_region_thr, self.center_region_thr,
|
||||
self.center_region_area_thr)
|
||||
|
||||
pool_w, pool_h = self.pooling_output_size
|
||||
gcn_in_dim = (pool_w * pool_h) * (
|
||||
self.in_channels + self.out_channels) + self.node_geo_feat_dim
|
||||
self.gcn = GCN(gcn_in_dim, 32)
|
||||
|
||||
def init_weights(self):
|
||||
normal_init(self.out_conv, mean=0, std=0.01)
|
||||
|
||||
def forward(self, inputs, text_comp_feats):
|
||||
|
||||
pred_maps = self.out_conv(inputs)
|
||||
|
||||
feat_maps = torch.cat([inputs, pred_maps], dim=1)
|
||||
node_feats, adjacent_matrices, knn_inx, gt_labels = self.graph_train(
|
||||
feat_maps, np.array(text_comp_feats))
|
||||
|
||||
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inx)
|
||||
|
||||
return (pred_maps, (gcn_pred, gt_labels))
|
||||
|
||||
def single_test(self, feat_maps):
|
||||
|
||||
pred_maps = self.out_conv(feat_maps)
|
||||
feat_maps = torch.cat([feat_maps[0], pred_maps], dim=1)
|
||||
|
||||
none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
|
||||
|
||||
(node_feats, adjacent_matrix, pivot_inx, knn_inx, local_graph_nodes,
|
||||
text_comps) = graph_data
|
||||
|
||||
if none_flag:
|
||||
return None, None, None
|
||||
|
||||
adjacent_matrix, pivot_inx, knn_inx = map(
|
||||
lambda x: x.to(feat_maps.device),
|
||||
(adjacent_matrix, pivot_inx, knn_inx))
|
||||
gcn_pred = self.gcn_model(node_feats, adjacent_matrix, knn_inx)
|
||||
|
||||
pred_labels = F.softmax(gcn_pred, dim=1)
|
||||
|
||||
edges = []
|
||||
scores = []
|
||||
local_graph_nodes = local_graph_nodes.long().squeeze().cpu().numpy()
|
||||
graph_num = node_feats.size(0)
|
||||
|
||||
for graph_inx in range(graph_num):
|
||||
pivot = pivot_inx[graph_inx].int().item()
|
||||
nodes = local_graph_nodes[graph_inx]
|
||||
for neighbor_inx, neighbor in enumerate(knn_inx[graph_inx]):
|
||||
neighbor = neighbor.item()
|
||||
edges.append([nodes[pivot], nodes[neighbor]])
|
||||
scores.append(pred_labels[graph_inx * (knn_inx.shape[1]) +
|
||||
neighbor_inx, 1].item())
|
||||
|
||||
edges = np.asarray(edges)
|
||||
scores = np.asarray(scores)
|
||||
|
||||
return edges, scores, text_comps
|
||||
|
||||
def get_boundary(self, edges, scores, text_comps):
|
||||
|
||||
boundaries = []
|
||||
if edges is not None:
|
||||
boundaries = merge_text_comps(edges, scores, text_comps,
|
||||
self.link_thr)
|
||||
|
||||
results = dict(boundary_result=boundaries)
|
||||
|
||||
return results
|
|
@ -4,10 +4,9 @@ from .dbnet import DBNet # isort:skip
|
|||
from .ocr_mask_rcnn import OCRMaskRCNN # isort:skip
|
||||
from .panet import PANet # isort:skip
|
||||
from .psenet import PSENet # isort:skip
|
||||
from .drrg import DRRG # isort:skip
|
||||
from .textsnake import TextSnake # isort:skip
|
||||
|
||||
__all__ = [
|
||||
'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
|
||||
'PANet', 'PSENet', 'DRRG', 'TextSnake'
|
||||
'PANet', 'PSENet', 'TextSnake'
|
||||
]
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
from mmdet.models.builder import DETECTORS
|
||||
from . import SingleStageTextDetector, TextDetectorMixin
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class DRRG(TextDetectorMixin, SingleStageTextDetector):
|
||||
"""The class for implementing DRRG text detector: Deep Relational Reasoning
|
||||
Graph Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
"""
|
||||
|
||||
def forward_train(self, img, img_metas, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
img (Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
img_metas (list[dict]): A list of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details of the values of these keys, see
|
||||
:class:`mmdet.datasets.pipelines.Collect`.
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
x = self.extract_feat(img)
|
||||
gt_comp_attribs = kwargs.pop('gt_comp_attribs')
|
||||
preds = self.bbox_head(x, gt_comp_attribs)
|
||||
losses = self.bbox_head.loss(preds, **kwargs)
|
||||
return losses
|
||||
|
||||
def simple_test(self, img, img_metas, rescale=False):
|
||||
|
||||
x = self.extract_feat(img)
|
||||
outs = self.bbox_head.single_test(x, img)
|
||||
boundaries = self.bbox_head.get_boundary(*outs, img_metas, rescale)
|
||||
|
||||
return [boundaries]
|
|
@ -1,7 +1,6 @@
|
|||
from .db_loss import DBLoss
|
||||
from .drrg_loss import DRRGLoss
|
||||
from .pan_loss import PANLoss
|
||||
from .pse_loss import PSELoss
|
||||
from .textsnake_loss import TextSnakeLoss
|
||||
|
||||
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'DRRGLoss', 'TextSnakeLoss']
|
||||
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss']
|
||||
|
|
|
@ -1,206 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.models.builder import LOSSES
|
||||
from mmocr.utils import check_argument
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class DRRGLoss(nn.Module):
|
||||
"""The class for implementing DRRG loss: Deep Relational Reasoning Graph
|
||||
Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/1908.05900] This is partially adapted from
|
||||
https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
|
||||
def __init__(self, ohem_ratio=3.0):
|
||||
"""Initialization.
|
||||
|
||||
Args:
|
||||
ohem_ratio (float): The negative/positive ratio in OHEM.
|
||||
"""
|
||||
super().__init__()
|
||||
self.ohem_ratio = ohem_ratio
|
||||
|
||||
def balance_bce_loss(self, pred, gt, mask):
|
||||
|
||||
assert pred.shape == gt.shape == mask.shape
|
||||
positive = gt * mask
|
||||
negative = (1 - gt) * mask
|
||||
positive_count = int(positive.float().sum())
|
||||
gt = gt.float()
|
||||
if positive_count > 0:
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction='none')
|
||||
positive_loss = torch.sum(loss * positive.float())
|
||||
negative_loss = loss * negative.float()
|
||||
negative_count = min(
|
||||
int(negative.float().sum()),
|
||||
int(positive_count * self.ohem_ratio))
|
||||
else:
|
||||
positive_loss = torch.tensor(0.0)
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction='none')
|
||||
negative_loss = loss * negative.float()
|
||||
negative_count = 100
|
||||
negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
|
||||
|
||||
balance_loss = (positive_loss + torch.sum(negative_loss)) / (
|
||||
float(positive_count + negative_count) + 1e-5)
|
||||
|
||||
return balance_loss
|
||||
|
||||
def gcn_loss(self, gcn_data):
|
||||
|
||||
gcn_pred, gt_labels = gcn_data
|
||||
gt_labels = gt_labels.view(-1).to(gcn_pred.device)
|
||||
loss = F.cross_entropy(gcn_pred, gt_labels)
|
||||
|
||||
return loss
|
||||
|
||||
def bitmasks2tensor(self, bitmasks, target_sz):
|
||||
"""Convert Bitmasks to tensor.
|
||||
|
||||
Args:
|
||||
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
|
||||
for one img.
|
||||
target_sz (tuple(int, int)): The target tensor size HxW.
|
||||
|
||||
Returns
|
||||
results (list[tensor]): The list of kernel tensors. Each
|
||||
element is for one kernel level.
|
||||
"""
|
||||
assert check_argument.is_type_list(bitmasks, BitmapMasks)
|
||||
assert isinstance(target_sz, tuple)
|
||||
|
||||
batch_size = len(bitmasks)
|
||||
num_masks = len(bitmasks[0])
|
||||
|
||||
results = []
|
||||
|
||||
for level_inx in range(num_masks):
|
||||
kernel = []
|
||||
for batch_inx in range(batch_size):
|
||||
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
|
||||
# hxw
|
||||
mask_sz = mask.shape
|
||||
# left, right, top, bottom
|
||||
pad = [
|
||||
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
|
||||
]
|
||||
mask = F.pad(mask, pad, mode='constant', value=0)
|
||||
kernel.append(mask)
|
||||
kernel = torch.stack(kernel)
|
||||
results.append(kernel)
|
||||
|
||||
return results
|
||||
|
||||
def forward(self, preds, downsample_ratio, gt_text_mask,
|
||||
gt_center_region_mask, gt_mask, gt_top_height_map,
|
||||
gt_bot_height_map, gt_sin_map, gt_cos_map):
|
||||
|
||||
assert isinstance(preds, tuple)
|
||||
assert isinstance(downsample_ratio, float)
|
||||
assert abs(downsample_ratio - 1.0) < 1e-5
|
||||
assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_mask, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
|
||||
assert check_argument.is_type_list(gt_cos_map, BitmapMasks)
|
||||
|
||||
pred_maps, gcn_data = preds
|
||||
pred_text_region = pred_maps[:, 0, :, :]
|
||||
pred_center_region = pred_maps[:, 1, :, :]
|
||||
pred_sin_map = pred_maps[:, 2, :, :]
|
||||
pred_cos_map = pred_maps[:, 3, :, :]
|
||||
pred_top_height_map = pred_maps[:, 4, :, :]
|
||||
pred_bot_height_map = pred_maps[:, 5, :, :]
|
||||
feature_sz = pred_maps.size()
|
||||
|
||||
# bitmask 2 tensor
|
||||
mapping = {
|
||||
'gt_text_mask': gt_text_mask,
|
||||
'gt_center_region_mask': gt_center_region_mask,
|
||||
'gt_mask': gt_mask,
|
||||
'gt_top_height_map': gt_top_height_map,
|
||||
'gt_bot_height_map': gt_bot_height_map,
|
||||
'gt_sin_map': gt_sin_map,
|
||||
'gt_cos_map': gt_cos_map
|
||||
}
|
||||
gt = {}
|
||||
for key, value in mapping.items():
|
||||
gt[key] = value
|
||||
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
|
||||
gt[key] = [item.to(pred_maps.device) for item in gt[key]]
|
||||
|
||||
scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
|
||||
pred_sin_map = pred_sin_map * scale
|
||||
pred_cos_map = pred_cos_map * scale
|
||||
|
||||
loss_text = self.balance_bce_loss(
|
||||
torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
|
||||
gt['gt_mask'][0])
|
||||
|
||||
text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
|
||||
negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
|
||||
gt['gt_mask'][0]).float()
|
||||
gt_center_region_mask = gt['gt_center_region_mask'][0].float()
|
||||
loss_center = F.binary_cross_entropy(
|
||||
torch.sigmoid(pred_center_region),
|
||||
gt_center_region_mask,
|
||||
reduction='none')
|
||||
if int(text_mask.sum()) > 0:
|
||||
loss_center_positive = torch.sum(
|
||||
loss_center * text_mask) / torch.sum(text_mask)
|
||||
else:
|
||||
loss_center_positive = torch.tensor(0.0)
|
||||
loss_center_negative = torch.sum(
|
||||
loss_center * negative_text_mask) / torch.sum(negative_text_mask)
|
||||
loss_center = loss_center_positive + 0.5 * loss_center_negative
|
||||
|
||||
center_mask = (gt['gt_center_region_mask'][0] *
|
||||
gt['gt_mask'][0]).float()
|
||||
if int(center_mask.sum()) > 0:
|
||||
ones = torch.ones_like(
|
||||
gt['gt_top_height_map'][0], dtype=torch.float)
|
||||
loss_top = F.smooth_l1_loss(
|
||||
pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
|
||||
ones,
|
||||
reduction='none')
|
||||
loss_bot = F.smooth_l1_loss(
|
||||
pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
|
||||
ones,
|
||||
reduction='none')
|
||||
gt_height = (
|
||||
gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
|
||||
loss_height = torch.sum(
|
||||
(torch.log(gt_height + 1) *
|
||||
(loss_top + loss_bot)) * center_mask) / torch.sum(center_mask)
|
||||
|
||||
loss_sin = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
|
||||
center_mask) / torch.sum(center_mask)
|
||||
loss_cos = torch.sum(
|
||||
F.smooth_l1_loss(
|
||||
pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
|
||||
center_mask) / torch.sum(center_mask)
|
||||
else:
|
||||
loss_height = torch.tensor(0.0)
|
||||
loss_sin = torch.tensor(0.0)
|
||||
loss_cos = torch.tensor(0.0)
|
||||
|
||||
loss_gcn = self.gcn_loss(gcn_data)
|
||||
|
||||
results = dict(
|
||||
loss_text=loss_text,
|
||||
loss_center=loss_center,
|
||||
loss_height=loss_height,
|
||||
loss_sin=loss_sin,
|
||||
loss_cos=loss_cos,
|
||||
loss_gcn=loss_gcn)
|
||||
|
||||
return results
|
|
@ -1,6 +0,0 @@
|
|||
from .gcn import GCN
|
||||
from .local_graph import LocalGraphs
|
||||
from .proposal_local_graph import ProposalLocalGraphs
|
||||
from .utils import merge_text_comps
|
||||
|
||||
__all__ = ['LocalGraphs', 'ProposalLocalGraphs', 'GCN', 'merge_text_comps']
|
|
@ -1,80 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import init
|
||||
|
||||
|
||||
class MeanAggregator(nn.Module):
|
||||
|
||||
def forward(self, features, A):
|
||||
x = torch.bmm(A, features)
|
||||
return x
|
||||
|
||||
|
||||
class GraphConv(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super(GraphConv, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
|
||||
self.bias = nn.Parameter(torch.FloatTensor(out_dim))
|
||||
init.xavier_uniform_(self.weight)
|
||||
init.constant_(self.bias, 0)
|
||||
self.agg = MeanAggregator()
|
||||
|
||||
def forward(self, features, A):
|
||||
b, n, d = features.shape
|
||||
assert d == self.in_dim
|
||||
agg_feats = self.agg(features, A)
|
||||
cat_feats = torch.cat([features, agg_feats], dim=2)
|
||||
out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
|
||||
out = F.relu(out + self.bias)
|
||||
return out
|
||||
|
||||
|
||||
class GCN(nn.Module):
|
||||
"""Predict linkage between instances. This was from repo
|
||||
https://github.com/Zhongdao/gcn_clustering: Linkage Based Face Clustering
|
||||
via Graph Convolution Network.
|
||||
|
||||
[https://arxiv.org/abs/1903.11306]
|
||||
|
||||
Args:
|
||||
in_dim(int): The input dimension.
|
||||
out_dim(int): The output dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super(GCN, self).__init__()
|
||||
self.bn0 = nn.BatchNorm1d(in_dim, affine=False).float()
|
||||
self.conv1 = GraphConv(in_dim, 512)
|
||||
self.conv2 = GraphConv(512, 256)
|
||||
self.conv3 = GraphConv(256, 128)
|
||||
self.conv4 = GraphConv(128, 64)
|
||||
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(64, out_dim), nn.PReLU(out_dim), nn.Linear(out_dim, 2))
|
||||
|
||||
def forward(self, x, A, one_hop_indexes, train=True):
|
||||
|
||||
B, N, D = x.shape
|
||||
|
||||
x = x.view(-1, D)
|
||||
x = self.bn0(x)
|
||||
x = x.view(B, N, D)
|
||||
|
||||
x = self.conv1(x, A)
|
||||
x = self.conv2(x, A)
|
||||
x = self.conv3(x, A)
|
||||
x = self.conv4(x, A)
|
||||
k1 = one_hop_indexes.size(-1)
|
||||
dout = x.size(-1)
|
||||
edge_feat = torch.zeros(B, k1, dout)
|
||||
for b in range(B):
|
||||
edge_feat[b, :, :] = x[b, one_hop_indexes[b]]
|
||||
edge_feat = edge_feat.view(-1, dout).to(x.device)
|
||||
pred = self.classifier(edge_feat)
|
||||
|
||||
# shape: (B*k1)x2
|
||||
return pred
|
|
@ -1,307 +0,0 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmocr.models.utils import RROIAlign
|
||||
from .utils import (embed_geo_feats, euclidean_distance_matrix,
|
||||
normalize_adjacent_matrix)
|
||||
|
||||
|
||||
class LocalGraphs:
|
||||
"""Generate local graphs for GCN to predict which instance a text component
|
||||
belongs to. This was partially adapted from https://github.com/GXYM/DRRG:
|
||||
Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of h-hop neighbors.
|
||||
active_connection (int): The number of neighbors deem as linked to a
|
||||
pivot.
|
||||
node_geo_feat_dim (int): The dimension of embedded geometric features
|
||||
of a component.
|
||||
pooling_scale (float): The spatial scale of RRoI-Aligning.
|
||||
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
|
||||
local_graph_filter_thr (float): The threshold to filter out identical
|
||||
local graphs.
|
||||
"""
|
||||
|
||||
def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
|
||||
pooling_scale, pooling_output_size, local_graph_filter_thr):
|
||||
|
||||
assert isinstance(k_at_hops, tuple)
|
||||
assert isinstance(active_connection, int)
|
||||
assert isinstance(node_geo_feat_dim, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert isinstance(pooling_output_size, tuple)
|
||||
assert isinstance(local_graph_filter_thr, float)
|
||||
|
||||
self.k_at_hops = k_at_hops
|
||||
self.local_graph_depth = len(self.k_at_hops)
|
||||
self.active_connection = active_connection
|
||||
self.node_geo_feat_dim = node_geo_feat_dim
|
||||
self.pooling = RROIAlign(pooling_output_size, pooling_scale)
|
||||
self.local_graph_filter_thr = local_graph_filter_thr
|
||||
|
||||
def generate_local_graphs(self, sorted_complete_graph, gt_belong_labels):
|
||||
"""Generate local graphs for GCN to predict which instance a text
|
||||
component belongs to.
|
||||
|
||||
Args:
|
||||
sorted_complete_graph (ndarray): The complete graph where nodes are
|
||||
sorted according to their Euclidean distance.
|
||||
gt_belong_labels (ndarray): The ground truth labels define which
|
||||
instance text components (nodes in graphs) belong to.
|
||||
|
||||
Returns:
|
||||
local_graph_node_list (list): The list of local graph neighbors of
|
||||
pivots.
|
||||
knn_graph_neighbor_list (list): The list of k nearest neighbors of
|
||||
pivots.
|
||||
"""
|
||||
|
||||
assert sorted_complete_graph.ndim == 2
|
||||
assert (sorted_complete_graph.shape[0] ==
|
||||
sorted_complete_graph.shape[1] == gt_belong_labels.shape[0])
|
||||
|
||||
knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
|
||||
local_graph_node_list = list()
|
||||
knn_graph_neighbor_list = list()
|
||||
for pivot_inx, knn_graph in enumerate(knn_graphs):
|
||||
|
||||
h_hop_neighbor_list = list()
|
||||
one_hop_neighbors = set(knn_graph[1:])
|
||||
h_hop_neighbor_list.append(one_hop_neighbors)
|
||||
|
||||
for hop_inx in range(1, self.local_graph_depth):
|
||||
h_hop_neighbor_list.append(set())
|
||||
for last_hop_neighbor_inx in h_hop_neighbor_list[-2]:
|
||||
h_hop_neighbor_list[-1].update(
|
||||
set(sorted_complete_graph[last_hop_neighbor_inx]
|
||||
[1:self.k_at_hops[hop_inx] + 1]))
|
||||
|
||||
hops_neighbor_set = set(
|
||||
[node for hop in h_hop_neighbor_list for node in hop])
|
||||
hops_neighbor_list = list(hops_neighbor_set)
|
||||
hops_neighbor_list.insert(0, pivot_inx)
|
||||
|
||||
if pivot_inx < 1:
|
||||
local_graph_node_list.append(hops_neighbor_list)
|
||||
knn_graph_neighbor_list.append(one_hop_neighbors)
|
||||
else:
|
||||
add_flag = True
|
||||
for graph_inx in range(len(knn_graph_neighbor_list)):
|
||||
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
|
||||
local_graph_nodes = local_graph_node_list[graph_inx]
|
||||
|
||||
node_union_num = len(
|
||||
list(
|
||||
set(knn_graph_neighbors).union(
|
||||
set(one_hop_neighbors))))
|
||||
node_intersect_num = len(
|
||||
list(
|
||||
set(knn_graph_neighbors).intersection(
|
||||
set(one_hop_neighbors))))
|
||||
one_hop_iou = node_intersect_num / (node_union_num + 1e-8)
|
||||
|
||||
if (one_hop_iou > self.local_graph_filter_thr
|
||||
and pivot_inx in knn_graph_neighbors
|
||||
and gt_belong_labels[local_graph_nodes[0]]
|
||||
== gt_belong_labels[pivot_inx]
|
||||
and gt_belong_labels[local_graph_nodes[0]] != 0):
|
||||
add_flag = False
|
||||
break
|
||||
if add_flag:
|
||||
local_graph_node_list.append(hops_neighbor_list)
|
||||
knn_graph_neighbor_list.append(one_hop_neighbors)
|
||||
|
||||
return local_graph_node_list, knn_graph_neighbor_list
|
||||
|
||||
def generate_gcn_input(self, node_feat_batch, belong_label_batch,
|
||||
local_graph_node_batch, knn_graph_neighbor_batch,
|
||||
sorted_complete_graph):
|
||||
"""Generate graph convolution network input data.
|
||||
|
||||
Args:
|
||||
node_feat_batch (List[Tensor]): The node feature batch.
|
||||
belong_label_batch (List[ndarray]): The text component belong label
|
||||
batch.
|
||||
local_graph_node_batch (List[List[list]]): The local graph
|
||||
neighbors batch.
|
||||
knn_graph_neighbor_batch (List[List[set]]): The knn graph neighbor
|
||||
batch.
|
||||
sorted_complete_graph (List[ndarray]): The complete graph sorted
|
||||
according to the Euclidean distance.
|
||||
|
||||
Returns:
|
||||
node_feat_batch_tensor (Tensor): The node features of Graph
|
||||
Convolutional Network (GCN).
|
||||
adjacent_mat_batch_tensor (Tensor): The adjacent matrices.
|
||||
knn_inx_batch_tensor (Tensor): The indices of k nearest neighbors.
|
||||
gt_linkage_batch_tensor (Tensor): The surpervision signal of GCN
|
||||
for linkage prediction.
|
||||
"""
|
||||
|
||||
assert isinstance(node_feat_batch, list)
|
||||
assert isinstance(belong_label_batch, list)
|
||||
assert isinstance(local_graph_node_batch, list)
|
||||
assert isinstance(knn_graph_neighbor_batch, list)
|
||||
assert isinstance(sorted_complete_graph, list)
|
||||
|
||||
max_graph_node_num = max([
|
||||
len(local_graph_nodes)
|
||||
for local_graph_node_list in local_graph_node_batch
|
||||
for local_graph_nodes in local_graph_node_list
|
||||
])
|
||||
|
||||
node_feat_batch_list = list()
|
||||
adjacent_matrix_batch_list = list()
|
||||
knn_inx_batch_list = list()
|
||||
gt_linkage_batch_list = list()
|
||||
|
||||
for batch_inx, sorted_neighbors in enumerate(sorted_complete_graph):
|
||||
node_feats = node_feat_batch[batch_inx]
|
||||
local_graph_list = local_graph_node_batch[batch_inx]
|
||||
knn_graph_neighbor_list = knn_graph_neighbor_batch[batch_inx]
|
||||
belong_labels = belong_label_batch[batch_inx]
|
||||
|
||||
for graph_inx in range(len(local_graph_list)):
|
||||
local_graph_nodes = local_graph_list[graph_inx]
|
||||
local_graph_node_num = len(local_graph_nodes)
|
||||
pivot_inx = local_graph_nodes[0]
|
||||
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
|
||||
node_to_graph_inx = {
|
||||
j: i
|
||||
for i, j in enumerate(local_graph_nodes)
|
||||
}
|
||||
|
||||
knn_inx_in_local_graph = torch.tensor(
|
||||
[node_to_graph_inx[i] for i in knn_graph_neighbors],
|
||||
dtype=torch.long)
|
||||
pivot_feats = node_feats[torch.tensor(
|
||||
pivot_inx, dtype=torch.long)]
|
||||
normalized_feats = node_feats[torch.tensor(
|
||||
local_graph_nodes, dtype=torch.long)] - pivot_feats
|
||||
|
||||
adjacent_matrix = np.zeros(
|
||||
(local_graph_node_num, local_graph_node_num))
|
||||
pad_normalized_feats = torch.cat([
|
||||
normalized_feats,
|
||||
torch.zeros(max_graph_node_num - local_graph_node_num,
|
||||
normalized_feats.shape[1]).to(
|
||||
node_feats.device)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
for node in local_graph_nodes:
|
||||
neighbors = sorted_neighbors[node,
|
||||
1:self.active_connection + 1]
|
||||
for neighbor in neighbors:
|
||||
if neighbor in local_graph_nodes:
|
||||
adjacent_matrix[node_to_graph_inx[node],
|
||||
node_to_graph_inx[neighbor]] = 1
|
||||
adjacent_matrix[node_to_graph_inx[neighbor],
|
||||
node_to_graph_inx[node]] = 1
|
||||
|
||||
adjacent_matrix = normalize_adjacent_matrix(
|
||||
adjacent_matrix, type='DAD')
|
||||
adjacent_matrix_tensor = torch.zeros(max_graph_node_num,
|
||||
max_graph_node_num).to(
|
||||
node_feats.device)
|
||||
adjacent_matrix_tensor[:local_graph_node_num, :
|
||||
local_graph_node_num] = adjacent_matrix
|
||||
|
||||
local_graph_labels = torch.from_numpy(
|
||||
belong_labels[local_graph_nodes]).type(torch.long)
|
||||
knn_labels = local_graph_labels[knn_inx_in_local_graph]
|
||||
edge_labels = ((belong_labels[pivot_inx] == knn_labels)
|
||||
& (belong_labels[pivot_inx] > 0)).long()
|
||||
|
||||
node_feat_batch_list.append(pad_normalized_feats)
|
||||
adjacent_matrix_batch_list.append(adjacent_matrix_tensor)
|
||||
knn_inx_batch_list.append(knn_inx_in_local_graph)
|
||||
gt_linkage_batch_list.append(edge_labels)
|
||||
|
||||
node_feat_batch_tensor = torch.stack(node_feat_batch_list, 0)
|
||||
adjacent_mat_batch_tensor = torch.stack(adjacent_matrix_batch_list, 0)
|
||||
knn_inx_batch_tensor = torch.stack(knn_inx_batch_list, 0)
|
||||
gt_linkage_batch_tensor = torch.stack(gt_linkage_batch_list, 0)
|
||||
|
||||
return (node_feat_batch_tensor, adjacent_mat_batch_tensor,
|
||||
knn_inx_batch_tensor, gt_linkage_batch_tensor)
|
||||
|
||||
def __call__(self, feat_maps, comp_attribs):
|
||||
"""Generate local graphs.
|
||||
|
||||
Args:
|
||||
feat_maps (Tensor): The feature maps to propose node features in
|
||||
graph.
|
||||
comp_attribs (ndarray): The text components attributes.
|
||||
|
||||
Returns:
|
||||
node_feats_batch (Tensor): The node features of Graph Convolutional
|
||||
Network(GCN).
|
||||
adjacent_matrices_batch (Tensor): The adjacent matrices.
|
||||
knn_inx_batch (Tensor): The indices of k nearest neighbors.
|
||||
gt_linkage_batch (Tensor): The surpervision signal of GCN for
|
||||
linkage prediction.
|
||||
"""
|
||||
|
||||
assert isinstance(feat_maps, torch.Tensor)
|
||||
assert comp_attribs.shape[2] == 8
|
||||
|
||||
dist_sort_graph_batch_list = []
|
||||
local_graph_node_batch_list = []
|
||||
knn_graph_neighbor_batch_list = []
|
||||
node_feature_batch_list = []
|
||||
belong_label_batch_list = []
|
||||
|
||||
for batch_inx in range(comp_attribs.shape[0]):
|
||||
comp_num = int(comp_attribs[batch_inx, 0, 0])
|
||||
comp_geo_attribs = comp_attribs[batch_inx, :comp_num, 1:7]
|
||||
node_belong_labels = comp_attribs[batch_inx, :comp_num,
|
||||
7].astype(np.int32)
|
||||
|
||||
comp_centers = comp_geo_attribs[:, 0:2]
|
||||
distance_matrix = euclidean_distance_matrix(
|
||||
comp_centers, comp_centers)
|
||||
|
||||
graph_node_geo_feats = embed_geo_feats(comp_geo_attribs,
|
||||
self.node_geo_feat_dim)
|
||||
graph_node_geo_feats = torch.from_numpy(
|
||||
graph_node_geo_feats).float().to(feat_maps.device)
|
||||
|
||||
batch_id = np.zeros(
|
||||
(comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_inx
|
||||
text_comps = np.hstack(
|
||||
(batch_id, comp_geo_attribs.astype(np.float32)))
|
||||
text_comps = torch.from_numpy(text_comps).float().to(
|
||||
feat_maps.device)
|
||||
|
||||
comp_content_feats = self.pooling(
|
||||
feat_maps[batch_inx].unsqueeze(0), text_comps)
|
||||
comp_content_feats = comp_content_feats.view(
|
||||
comp_content_feats.shape[0], -1).to(feat_maps.device)
|
||||
node_feats = torch.cat((comp_content_feats, graph_node_geo_feats),
|
||||
dim=-1)
|
||||
|
||||
dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
|
||||
(local_graph_nodes,
|
||||
knn_graph_neighbors) = self.generate_local_graphs(
|
||||
dist_sort_complete_graph, node_belong_labels)
|
||||
|
||||
node_feature_batch_list.append(node_feats)
|
||||
belong_label_batch_list.append(node_belong_labels)
|
||||
local_graph_node_batch_list.append(local_graph_nodes)
|
||||
knn_graph_neighbor_batch_list.append(knn_graph_neighbors)
|
||||
dist_sort_graph_batch_list.append(dist_sort_complete_graph)
|
||||
|
||||
(node_feats_batch, adjacent_matrices_batch, knn_inx_batch,
|
||||
gt_linkage_batch) = \
|
||||
self.generate_gcn_input(node_feature_batch_list,
|
||||
belong_label_batch_list,
|
||||
local_graph_node_batch_list,
|
||||
knn_graph_neighbor_batch_list,
|
||||
dist_sort_graph_batch_list)
|
||||
|
||||
return (node_feats_batch, adjacent_matrices_batch, knn_inx_batch,
|
||||
gt_linkage_batch)
|
|
@ -1,418 +0,0 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# from mmocr.models.textdet.postprocess import la_nms
|
||||
from mmocr.models.utils import RROIAlign
|
||||
from .utils import (embed_geo_feats, euclidean_distance_matrix,
|
||||
normalize_adjacent_matrix)
|
||||
|
||||
|
||||
class ProposalLocalGraphs:
|
||||
"""Propose text components and generate local graphs. This was partially
|
||||
adapted from https://github.com/GXYM/DRRG: Deep Relational Reasoning Graph
|
||||
Network for Arbitrary Shape Text Detection.
|
||||
|
||||
[https://arxiv.org/abs/2003.07493]
|
||||
|
||||
Args:
|
||||
k_at_hops (tuple(int)): The number of i-hop neighbors,
|
||||
i = 1, 2, ..., h.
|
||||
active_connection (int): The number of two hop neighbors deem as linked
|
||||
to a pivot.
|
||||
node_geo_feat_dim (int): The dimension of embedded geometric features
|
||||
of a component.
|
||||
pooling_scale (float): The spatial scale of RRoI-Aligning.
|
||||
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
|
||||
nms_thr (float): The locality-aware NMS threshold.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
comp_ratio (float): The reciprocal of aspect ratio of text components.
|
||||
text_region_thr (float): The threshold for text region probability map.
|
||||
center_region_thr (float): The threshold for text center region
|
||||
probability map.
|
||||
center_region_area_thr (int): The threshold for filtering out
|
||||
small-size text center region.
|
||||
"""
|
||||
|
||||
def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
|
||||
pooling_scale, pooling_output_size, nms_thr, min_width,
|
||||
max_width, comp_shrink_ratio, comp_ratio, text_region_thr,
|
||||
center_region_thr, center_region_area_thr):
|
||||
|
||||
assert isinstance(k_at_hops, tuple)
|
||||
assert isinstance(active_connection, int)
|
||||
assert isinstance(node_geo_feat_dim, int)
|
||||
assert isinstance(pooling_scale, float)
|
||||
assert isinstance(pooling_output_size, tuple)
|
||||
assert isinstance(nms_thr, float)
|
||||
assert isinstance(min_width, float)
|
||||
assert isinstance(max_width, float)
|
||||
assert isinstance(comp_shrink_ratio, float)
|
||||
assert isinstance(comp_ratio, float)
|
||||
assert isinstance(text_region_thr, float)
|
||||
assert isinstance(center_region_thr, float)
|
||||
assert isinstance(center_region_area_thr, int)
|
||||
|
||||
self.k_at_hops = k_at_hops
|
||||
self.active_connection = active_connection
|
||||
self.local_graph_depth = len(self.k_at_hops)
|
||||
self.node_geo_feat_dim = node_geo_feat_dim
|
||||
self.pooling = RROIAlign(pooling_output_size, pooling_scale)
|
||||
self.nms_thr = nms_thr
|
||||
self.min_width = min_width
|
||||
self.max_width = max_width
|
||||
self.comp_shrink_ratio = comp_shrink_ratio
|
||||
self.comp_ratio = comp_ratio
|
||||
self.text_region_thr = text_region_thr
|
||||
self.center_region_thr = center_region_thr
|
||||
self.center_region_area_thr = center_region_area_thr
|
||||
|
||||
def fill_hole(self, input_mask):
|
||||
h, w = input_mask.shape
|
||||
canvas = np.zeros((h + 2, w + 2), np.uint8)
|
||||
canvas[1:h + 1, 1:w + 1] = input_mask.copy()
|
||||
|
||||
mask = np.zeros((h + 4, w + 4), np.uint8)
|
||||
|
||||
cv2.floodFill(canvas, mask, (0, 0), 1)
|
||||
canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool)
|
||||
|
||||
return (~canvas | input_mask.astype(np.uint8))
|
||||
|
||||
def propose_comps(self, top_radius_map, bot_radius_map, sin_map, cos_map,
|
||||
score_map, min_width, max_width, comp_shrink_ratio,
|
||||
comp_ratio):
|
||||
"""Generate text components.
|
||||
|
||||
Args:
|
||||
top_radius_map (ndarray): The predicted distance map from each
|
||||
pixel in text center region to top sideline.
|
||||
bot_radius_map (ndarray): The predicted distance map from each
|
||||
pixel in text center region to bottom sideline.
|
||||
sin_map (ndarray): The predicted sin(theta) map.
|
||||
cos_map (ndarray): The predicted cos(theta) map.
|
||||
score_map (ndarray): The score map for NMS.
|
||||
min_width (float): The minimum width of text components.
|
||||
max_width (float): The maximum width of text components.
|
||||
comp_shrink_ratio (float): The shrink ratio of text components.
|
||||
comp_ratio (float): The reciprocal of aspect ratio of text
|
||||
components.
|
||||
|
||||
Returns:
|
||||
text_comps (ndarray): The text components.
|
||||
"""
|
||||
|
||||
comp_centers = np.argwhere(score_map > 0)
|
||||
comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
|
||||
y = comp_centers[:, 0]
|
||||
x = comp_centers[:, 1]
|
||||
|
||||
top_radius = top_radius_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
|
||||
bot_radius = bot_radius_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
|
||||
top_mid_x_offset = top_radius * cos
|
||||
top_mid_y_offset = top_radius * sin
|
||||
bot_mid_x_offset = bot_radius * cos
|
||||
bot_mid_y_offset = bot_radius * sin
|
||||
|
||||
top_mid_pnt = comp_centers + np.hstack(
|
||||
[top_mid_y_offset, top_mid_x_offset])
|
||||
bot_mid_pnt = comp_centers - np.hstack(
|
||||
[bot_mid_y_offset, bot_mid_x_offset])
|
||||
|
||||
width = (top_radius + bot_radius) * comp_ratio
|
||||
width = np.clip(width, min_width, max_width)
|
||||
|
||||
top_left = top_mid_pnt - np.hstack([width * cos, -width * sin
|
||||
])[:, ::-1]
|
||||
top_right = top_mid_pnt + np.hstack([width * cos, -width * sin
|
||||
])[:, ::-1]
|
||||
bot_right = bot_mid_pnt + np.hstack([width * cos, -width * sin
|
||||
])[:, ::-1]
|
||||
bot_left = bot_mid_pnt - np.hstack([width * cos, -width * sin
|
||||
])[:, ::-1]
|
||||
|
||||
text_comps = np.hstack([top_left, top_right, bot_right, bot_left])
|
||||
score = score_map[y, x].reshape((-1, 1))
|
||||
text_comps = np.hstack([text_comps, score])
|
||||
|
||||
return text_comps
|
||||
|
||||
def propose_comps_and_attribs(self, text_region_map, center_region_map,
|
||||
top_radius_map, bot_radius_map, sin_map,
|
||||
cos_map):
|
||||
"""Generate text components and attributes.
|
||||
|
||||
Args:
|
||||
text_region_map (ndarray): The predicted text region probability
|
||||
map.
|
||||
center_region_map (ndarray): The predicted text center region
|
||||
probability map.
|
||||
top_radius_map (ndarray): The predicted distance map from each
|
||||
pixel in text center region to top sideline.
|
||||
bot_radius_map (ndarray): The predicted distance map from each
|
||||
pixel in text center region to bottom sideline.
|
||||
sin_map (ndarray): The predicted sin(theta) map.
|
||||
cos_map (ndarray): The predicted cos(theta) map.
|
||||
|
||||
Returns:
|
||||
comp_attribs (ndarray): The text components attributes.
|
||||
text_comps (ndarray): The text components.
|
||||
"""
|
||||
|
||||
assert (text_region_map.shape == center_region_map.shape ==
|
||||
top_radius_map.shape == bot_radius_map == sin_map.shape ==
|
||||
cos_map.shape)
|
||||
text_mask = text_region_map > self.text_region_thr
|
||||
center_region_mask = (center_region_map >
|
||||
self.center_region_thr) * text_mask
|
||||
|
||||
scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2))
|
||||
sin_map, cos_map = sin_map * scale, cos_map * scale
|
||||
|
||||
center_region_mask = self.fill_hole(center_region_mask)
|
||||
center_region_contours, _ = cv2.findContours(
|
||||
center_region_mask.astype(np.uint8), cv2.RETR_TREE,
|
||||
cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
mask = np.zeros_like(center_region_mask)
|
||||
comp_list = []
|
||||
for contour in center_region_contours:
|
||||
current_center_mask = mask.copy()
|
||||
cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
|
||||
if current_center_mask.sum() <= self.center_region_area_thr:
|
||||
continue
|
||||
score_map = text_region_map * current_center_mask
|
||||
|
||||
text_comp = self.propose_comps(top_radius_map, bot_radius_map,
|
||||
sin_map, cos_map, score_map,
|
||||
self.min_width, self.max_width,
|
||||
self.comp_shrink_ratio,
|
||||
self.comp_ratio)
|
||||
|
||||
# text_comp = la_nms(text_comp.astype('float32'), self.nms_thr)
|
||||
|
||||
text_comp_mask = mask.copy()
|
||||
text_comps_bboxes = text_comp[:, :8].reshape(
|
||||
(-1, 4, 2)).astype(np.int32)
|
||||
|
||||
cv2.drawContours(text_comp_mask, text_comps_bboxes, -1, 1, -1)
|
||||
if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
|
||||
continue
|
||||
|
||||
comp_list.append(text_comp)
|
||||
|
||||
if len(comp_list) <= 0:
|
||||
return None, None
|
||||
|
||||
text_comps = np.vstack(comp_list)
|
||||
|
||||
centers = np.mean(
|
||||
text_comps[:, :8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
|
||||
|
||||
x = centers[:, 0]
|
||||
y = centers[:, 1]
|
||||
|
||||
h = top_radius_map[y, x].reshape(
|
||||
(-1, 1)) + bot_radius_map[y, x].reshape((-1, 1))
|
||||
w = np.clip(h * self.comp_ratio, self.min_width, self.max_width)
|
||||
sin = sin_map[y, x].reshape((-1, 1))
|
||||
cos = cos_map[y, x].reshape((-1, 1))
|
||||
x = x.reshape((-1, 1))
|
||||
y = y.reshape((-1, 1))
|
||||
comp_attribs = np.hstack([x, y, h, w, cos, sin])
|
||||
|
||||
return comp_attribs, text_comps
|
||||
|
||||
def generate_local_graphs(self, sorted_complete_graph, node_feats):
|
||||
"""Generate local graphs and Graph Convolution Network input data.
|
||||
|
||||
Args:
|
||||
sorted_complete_graph (ndarray): The complete graph where nodes are
|
||||
sorted according to their Euclidean distance.
|
||||
node_feats (tensor): The graph nodes features.
|
||||
|
||||
Returns:
|
||||
node_feats_tensor (tensor): The graph nodes features.
|
||||
adjacent_matrix_tensor (tensor): The adjacent matrix of graph.
|
||||
pivot_inx_tensor (tensor): The pivot indices in local graph.
|
||||
knn_inx_tensor (tensor): The k nearest neighbor nodes indexes in
|
||||
local graph.
|
||||
local_graph_node_tensor (tensor): The indices of nodes in local
|
||||
graph.
|
||||
"""
|
||||
|
||||
assert sorted_complete_graph.ndim == 2
|
||||
assert (sorted_complete_graph.shape[0] ==
|
||||
sorted_complete_graph.shape[1] == node_feats.shape[0])
|
||||
|
||||
knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
|
||||
local_graph_node_list = list()
|
||||
knn_graph_neighbor_list = list()
|
||||
for pivot_inx, knn_graph in enumerate(knn_graphs):
|
||||
|
||||
h_hop_neighbor_list = list()
|
||||
one_hop_neighbors = set(knn_graph[1:])
|
||||
h_hop_neighbor_list.append(one_hop_neighbors)
|
||||
|
||||
for hop_inx in range(1, self.local_graph_depth):
|
||||
h_hop_neighbor_list.append(set())
|
||||
for last_hop_neighbor_inx in h_hop_neighbor_list[-2]:
|
||||
h_hop_neighbor_list[-1].update(
|
||||
set(sorted_complete_graph[last_hop_neighbor_inx]
|
||||
[1:self.k_at_hops[hop_inx] + 1]))
|
||||
|
||||
hops_neighbor_set = set(
|
||||
[node for hop in h_hop_neighbor_list for node in hop])
|
||||
hops_neighbor_list = list(hops_neighbor_set)
|
||||
hops_neighbor_list.insert(0, pivot_inx)
|
||||
|
||||
local_graph_node_list.append(hops_neighbor_list)
|
||||
knn_graph_neighbor_list.append(one_hop_neighbors)
|
||||
|
||||
max_graph_node_num = max([
|
||||
len(local_graph_nodes)
|
||||
for local_graph_nodes in local_graph_node_list
|
||||
])
|
||||
|
||||
node_normalized_feats = list()
|
||||
adjacent_matrix_list = list()
|
||||
knn_inx = list()
|
||||
pivot_graph_inx = list()
|
||||
local_graph_tensor_list = list()
|
||||
|
||||
for graph_inx in range(len(local_graph_node_list)):
|
||||
|
||||
local_graph_nodes = local_graph_node_list[graph_inx]
|
||||
local_graph_node_num = len(local_graph_nodes)
|
||||
pivot_inx = local_graph_nodes[0]
|
||||
knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
|
||||
node_to_graph_inx = {j: i for i, j in enumerate(local_graph_nodes)}
|
||||
|
||||
pivot_node_inx = torch.tensor([
|
||||
node_to_graph_inx[pivot_inx],
|
||||
]).type(torch.long)
|
||||
knn_inx_in_local_graph = torch.tensor(
|
||||
[node_to_graph_inx[i] for i in knn_graph_neighbors],
|
||||
dtype=torch.long)
|
||||
pivot_feats = node_feats[torch.tensor(pivot_inx, dtype=torch.long)]
|
||||
normalized_feats = node_feats[torch.tensor(
|
||||
local_graph_nodes, dtype=torch.long)] - pivot_feats
|
||||
|
||||
adjacent_matrix = np.zeros(
|
||||
(local_graph_node_num, local_graph_node_num))
|
||||
pad_normalized_feats = torch.cat([
|
||||
normalized_feats,
|
||||
torch.zeros(max_graph_node_num - local_graph_node_num,
|
||||
normalized_feats.shape[1]).to(node_feats.device)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
for node in local_graph_nodes:
|
||||
neighbors = sorted_complete_graph[node,
|
||||
1:self.active_connection + 1]
|
||||
for neighbor in neighbors:
|
||||
if neighbor in local_graph_nodes:
|
||||
adjacent_matrix[node_to_graph_inx[node],
|
||||
node_to_graph_inx[neighbor]] = 1
|
||||
adjacent_matrix[node_to_graph_inx[neighbor],
|
||||
node_to_graph_inx[node]] = 1
|
||||
|
||||
adjacent_matrix = normalize_adjacent_matrix(
|
||||
adjacent_matrix, type='DAD')
|
||||
adjacent_matrix_tensor = torch.zeros(
|
||||
max_graph_node_num, max_graph_node_num).to(node_feats.device)
|
||||
adjacent_matrix_tensor[:local_graph_node_num, :
|
||||
local_graph_node_num] = adjacent_matrix
|
||||
|
||||
local_graph_tensor = torch.tensor(local_graph_nodes)
|
||||
local_graph_tensor = torch.cat([
|
||||
local_graph_tensor,
|
||||
torch.zeros(
|
||||
max_graph_node_num - local_graph_node_num,
|
||||
dtype=torch.long)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
node_normalized_feats.append(pad_normalized_feats)
|
||||
adjacent_matrix_list.append(adjacent_matrix_tensor)
|
||||
pivot_graph_inx.append(pivot_node_inx)
|
||||
knn_inx.append(knn_inx_in_local_graph)
|
||||
local_graph_tensor_list.append(local_graph_tensor)
|
||||
|
||||
node_feats_tensor = torch.stack(node_normalized_feats, 0)
|
||||
adjacent_matrix_tensor = torch.stack(adjacent_matrix_list, 0)
|
||||
pivot_inx_tensor = torch.stack(pivot_graph_inx, 0)
|
||||
knn_inx_tensor = torch.stack(knn_inx, 0)
|
||||
local_graph_node_tensor = torch.stack(local_graph_tensor_list, 0)
|
||||
|
||||
return (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
|
||||
knn_inx_tensor, local_graph_node_tensor)
|
||||
|
||||
def __call__(self, preds, feat_maps):
|
||||
"""Generate local graphs and Graph Convolution Network input data.
|
||||
|
||||
Args:
|
||||
preds (tensor): The predicted maps.
|
||||
feat_maps (tensor): The feature maps to extract content features of
|
||||
text components.
|
||||
|
||||
Returns:
|
||||
node_feats_tensor (tensor): The graph nodes features.
|
||||
adjacent_matrix_tensor (tensor): The adjacent matrix of graph.
|
||||
pivot_inx_tensor (tensor): The pivot indices in local graph.
|
||||
knn_inx_tensor (tensor): The k nearest neighbor nodes indices in
|
||||
local graph.
|
||||
local_graph_node_tensor (tensor): The indices of nodes in local
|
||||
graph.
|
||||
text_comps (ndarray): The predicted text components.
|
||||
"""
|
||||
|
||||
pred_text_region = torch.sigmoid(preds[0, 0]).data.cpu().numpy()
|
||||
pred_center_region = torch.sigmoid(preds[0, 1]).data.cpu().numpy()
|
||||
pred_sin_map = preds[0, 2].data.cpu().numpy()
|
||||
pred_cos_map = preds[0, 3].data.cpu().numpy()
|
||||
pred_top_radius_map = preds[0, 4].data.cpu().numpy()
|
||||
pred_bot_radius_map = preds[0, 5].data.cpu().numpy()
|
||||
|
||||
comp_attribs, text_comps = self.propose_comps_and_attribs(
|
||||
pred_text_region, pred_center_region, pred_top_radius_map,
|
||||
pred_bot_radius_map, pred_sin_map, pred_cos_map)
|
||||
|
||||
if comp_attribs is None:
|
||||
none_flag = True
|
||||
return none_flag, (0, 0, 0, 0, 0, 0)
|
||||
|
||||
comp_centers = comp_attribs[:, 0:2]
|
||||
distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
|
||||
|
||||
graph_node_geo_feats = embed_geo_feats(comp_attribs,
|
||||
self.node_geo_feat_dim)
|
||||
graph_node_geo_feats = torch.from_numpy(
|
||||
graph_node_geo_feats).float().to(preds.device)
|
||||
|
||||
batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
|
||||
text_comps_bboxes = np.hstack(
|
||||
(batch_id, comp_attribs.astype(np.float32, copy=False)))
|
||||
text_comps_bboxes = torch.from_numpy(text_comps_bboxes).float().to(
|
||||
preds.device)
|
||||
|
||||
comp_content_feats = self.pooling(feat_maps, text_comps_bboxes)
|
||||
comp_content_feats = comp_content_feats.view(
|
||||
comp_content_feats.shape[0], -1).to(preds.device)
|
||||
node_feats = torch.cat((comp_content_feats, graph_node_geo_feats),
|
||||
dim=-1)
|
||||
|
||||
dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
|
||||
(node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
|
||||
knn_inx_tensor, local_graph_node_tensor) = self.generate_local_graphs(
|
||||
dist_sort_complete_graph, node_feats)
|
||||
|
||||
none_flag = False
|
||||
return none_flag, (node_feats_tensor, adjacent_matrix_tensor,
|
||||
pivot_inx_tensor, knn_inx_tensor,
|
||||
local_graph_node_tensor, text_comps)
|
|
@ -1,354 +0,0 @@
|
|||
import functools
|
||||
import operator
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.linalg import norm
|
||||
|
||||
|
||||
def normalize_adjacent_matrix(A, type='AD'):
|
||||
"""Normalize adjacent matrix for GCN.
|
||||
|
||||
This was from repo https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
if type == 'DAD':
|
||||
# d is Degree of nodes A=A+I
|
||||
# L = D^-1/2 A D^-1/2
|
||||
A = A + np.eye(A.shape[0]) # A=A+I
|
||||
d = np.sum(A, axis=0)
|
||||
d_inv = np.power(d, -0.5).flatten()
|
||||
d_inv[np.isinf(d_inv)] = 0.0
|
||||
d_inv = np.diag(d_inv)
|
||||
G = A.dot(d_inv).transpose().dot(d_inv)
|
||||
G = torch.from_numpy(G)
|
||||
elif type == 'AD':
|
||||
A = A + np.eye(A.shape[0]) # A=A+I
|
||||
A = torch.from_numpy(A)
|
||||
D = A.sum(1, keepdim=True)
|
||||
G = A.div(D)
|
||||
else:
|
||||
A = A + np.eye(A.shape[0]) # A=A+I
|
||||
A = torch.from_numpy(A)
|
||||
D = A.sum(1, keepdim=True)
|
||||
D = np.diag(D)
|
||||
G = D - A
|
||||
return G
|
||||
|
||||
|
||||
def euclidean_distance_matrix(A, B):
|
||||
"""Calculate the Euclidean distance matrix."""
|
||||
|
||||
M = A.shape[0]
|
||||
N = B.shape[0]
|
||||
|
||||
assert A.shape[1] == B.shape[1]
|
||||
|
||||
A_dots = (A * A).sum(axis=1).reshape((M, 1)) * np.ones(shape=(1, N))
|
||||
B_dots = (B * B).sum(axis=1) * np.ones(shape=(M, 1))
|
||||
D_squared = A_dots + B_dots - 2 * A.dot(B.T)
|
||||
|
||||
zero_mask = np.less(D_squared, 0.0)
|
||||
D_squared[zero_mask] = 0.0
|
||||
return np.sqrt(D_squared)
|
||||
|
||||
|
||||
def embed_geo_feats(geo_feats, out_dim):
|
||||
"""Embed geometric features of text components. This was partially adapted
|
||||
from https://github.com/GXYM/DRRG.
|
||||
|
||||
Args:
|
||||
geo_feats (ndarray): The geometric features of text components.
|
||||
out_dim (int): The output dimension.
|
||||
|
||||
Returns:
|
||||
embedded_feats (ndarray): The embedded geometric features.
|
||||
"""
|
||||
assert isinstance(out_dim, int)
|
||||
assert out_dim >= geo_feats.shape[1]
|
||||
comp_num = geo_feats.shape[0]
|
||||
feat_dim = geo_feats.shape[1]
|
||||
feat_repeat_times = out_dim // feat_dim
|
||||
residue_dim = out_dim % feat_dim
|
||||
|
||||
if residue_dim > 0:
|
||||
embed_wave = np.array([
|
||||
np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
|
||||
for j in range(feat_repeat_times + 1)
|
||||
]).reshape((feat_repeat_times + 1, 1, 1))
|
||||
repeat_feats = np.repeat(
|
||||
np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0)
|
||||
residue_feats = np.hstack([
|
||||
geo_feats[:, 0:residue_dim],
|
||||
np.zeros((comp_num, feat_dim - residue_dim))
|
||||
])
|
||||
repeat_feats = np.stack([repeat_feats, residue_feats], axis=0)
|
||||
embedded_feats = repeat_feats / embed_wave
|
||||
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
|
||||
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
|
||||
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
|
||||
(comp_num, -1))[:, 0:out_dim]
|
||||
else:
|
||||
embed_wave = np.array([
|
||||
np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
|
||||
for j in range(feat_repeat_times)
|
||||
]).reshape((feat_repeat_times, 1, 1))
|
||||
repeat_feats = np.repeat(
|
||||
np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0)
|
||||
embedded_feats = repeat_feats / embed_wave
|
||||
embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
|
||||
embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
|
||||
embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
|
||||
(comp_num, -1))
|
||||
|
||||
return embedded_feats
|
||||
|
||||
|
||||
def min_connect_path(list_all: List[list]):
|
||||
"""This is from https://github.com/GXYM/DRRG."""
|
||||
|
||||
list_nodo = list_all.copy()
|
||||
res: List[List[int]] = []
|
||||
ept = [0, 0]
|
||||
|
||||
def norm2(a, b):
|
||||
return ((a[0] - b[0])**2 + (a[1] - b[1])**2)**0.5
|
||||
|
||||
dict00 = {}
|
||||
dict11 = {}
|
||||
ept[0] = list_nodo[0]
|
||||
ept[1] = list_nodo[0]
|
||||
list_nodo.remove(list_nodo[0])
|
||||
while list_nodo:
|
||||
for i in list_nodo:
|
||||
length0 = norm2(i, ept[0])
|
||||
dict00[length0] = [i, ept[0]]
|
||||
length1 = norm2(ept[1], i)
|
||||
dict11[length1] = [ept[1], i]
|
||||
key0 = min(dict00.keys())
|
||||
key1 = min(dict11.keys())
|
||||
|
||||
if key0 <= key1:
|
||||
ss = dict00[key0][0]
|
||||
ee = dict00[key0][1]
|
||||
res.insert(0, [list_all.index(ss), list_all.index(ee)])
|
||||
list_nodo.remove(ss)
|
||||
ept[0] = ss
|
||||
else:
|
||||
ss = dict11[key1][0]
|
||||
ee = dict11[key1][1]
|
||||
res.append([list_all.index(ss), list_all.index(ee)])
|
||||
list_nodo.remove(ee)
|
||||
ept[1] = ee
|
||||
|
||||
dict00 = {}
|
||||
dict11 = {}
|
||||
|
||||
path = functools.reduce(operator.concat, res)
|
||||
path = sorted(set(path), key=path.index)
|
||||
|
||||
return res, path
|
||||
|
||||
|
||||
def clusters2labels(clusters, node_num):
|
||||
"""This is from https://github.com/GXYM/DRRG."""
|
||||
labels = (-1) * np.ones((node_num, ))
|
||||
for cluster_inx, cluster in enumerate(clusters):
|
||||
for node in cluster:
|
||||
labels[node.inx] = cluster_inx
|
||||
assert np.sum(labels < 0) < 1
|
||||
return labels
|
||||
|
||||
|
||||
def remove_single(text_comps, pred):
|
||||
"""Remove isolated single text components.
|
||||
|
||||
This is from https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
single_flags = np.zeros_like(pred)
|
||||
pred_labels = np.unique(pred)
|
||||
for label in pred_labels:
|
||||
current_label_flag = pred == label
|
||||
if np.sum(current_label_flag) == 1:
|
||||
single_flags[np.where(current_label_flag)[0][0]] = 1
|
||||
remain_inx = [i for i in range(len(pred)) if not single_flags[i]]
|
||||
remain_inx = np.asarray(remain_inx)
|
||||
return text_comps[remain_inx, :], pred[remain_inx]
|
||||
|
||||
|
||||
class Node:
|
||||
|
||||
def __init__(self, inx):
|
||||
self.__inx = inx
|
||||
self.__links = set()
|
||||
|
||||
@property
|
||||
def inx(self):
|
||||
return self.__inx
|
||||
|
||||
@property
|
||||
def links(self):
|
||||
return set(self.__links)
|
||||
|
||||
def add_link(self, other, score):
|
||||
self.__links.add(other)
|
||||
other.__links.add(self)
|
||||
|
||||
|
||||
def connected_components(nodes, score_dict, thr):
|
||||
"""Connected components searching.
|
||||
|
||||
This is from https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
|
||||
result = []
|
||||
nodes = set(nodes)
|
||||
while nodes:
|
||||
node = nodes.pop()
|
||||
group = {node}
|
||||
queue = [node]
|
||||
while queue:
|
||||
node = queue.pop(0)
|
||||
if thr is not None:
|
||||
neighbors = {
|
||||
linked_neighbor
|
||||
for linked_neighbor in node.links if score_dict[tuple(
|
||||
sorted([node.inx, linked_neighbor.inx]))] >= thr
|
||||
}
|
||||
else:
|
||||
neighbors = node.links
|
||||
neighbors.difference_update(group)
|
||||
nodes.difference_update(neighbors)
|
||||
group.update(neighbors)
|
||||
queue.extend(neighbors)
|
||||
result.append(group)
|
||||
return result
|
||||
|
||||
|
||||
def graph_propagation(edges,
|
||||
scores,
|
||||
link_thr,
|
||||
bboxes=None,
|
||||
dis_thr=50,
|
||||
pool='avg'):
|
||||
"""Propagate graph linkage score information.
|
||||
|
||||
This is from repo https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
edges = np.sort(edges, axis=1)
|
||||
|
||||
score_dict = {}
|
||||
if pool is None:
|
||||
for i, edge in enumerate(edges):
|
||||
score_dict[edge[0], edge[1]] = scores[i]
|
||||
elif pool == 'avg':
|
||||
for i, edge in enumerate(edges):
|
||||
if bboxes is not None:
|
||||
box1 = bboxes[edge[0]][:8].reshape(4, 2)
|
||||
box2 = bboxes[edge[1]][:8].reshape(4, 2)
|
||||
center1 = np.mean(box1, axis=0)
|
||||
center2 = np.mean(box2, axis=0)
|
||||
dst = norm(center1 - center2)
|
||||
if dst > dis_thr:
|
||||
scores[i] = 0
|
||||
if (edge[0], edge[1]) in score_dict:
|
||||
score_dict[edge[0], edge[1]] = 0.5 * (
|
||||
score_dict[edge[0], edge[1]] + scores[i])
|
||||
else:
|
||||
score_dict[edge[0], edge[1]] = scores[i]
|
||||
|
||||
elif pool == 'max':
|
||||
for i, edge in enumerate(edges):
|
||||
if (edge[0], edge[1]) in score_dict:
|
||||
score_dict[edge[0],
|
||||
edge[1]] = max(score_dict[edge[0], edge[1]],
|
||||
scores[i])
|
||||
else:
|
||||
score_dict[edge[0], edge[1]] = scores[i]
|
||||
else:
|
||||
raise ValueError('Pooling operation not supported')
|
||||
|
||||
nodes = np.sort(np.unique(edges.flatten()))
|
||||
mapping = -1 * np.ones((nodes.max() + 1), dtype=np.int)
|
||||
mapping[nodes] = np.arange(nodes.shape[0])
|
||||
link_inx = mapping[edges]
|
||||
vertex = [Node(node) for node in nodes]
|
||||
for link, score in zip(link_inx, scores):
|
||||
vertex[link[0]].add_link(vertex[link[1]], score)
|
||||
|
||||
clusters = connected_components(vertex, score_dict, link_thr)
|
||||
|
||||
return clusters
|
||||
|
||||
|
||||
def in_contour(cont, point):
|
||||
x, y = point
|
||||
return cv2.pointPolygonTest(cont, (x, y), False) > 0
|
||||
|
||||
|
||||
def select_edge(cont, box):
|
||||
"""This is from repo https://github.com/GXYM/DRRG."""
|
||||
cont = np.array(cont)
|
||||
box = box.astype(np.int32)
|
||||
c1 = np.array(0.5 * (box[0, :] + box[3, :]), dtype=np.int)
|
||||
c2 = np.array(0.5 * (box[1, :] + box[2, :]), dtype=np.int)
|
||||
|
||||
if not in_contour(cont, c1):
|
||||
return [box[0, :].tolist(), box[3, :].tolist()]
|
||||
elif not in_contour(cont, c2):
|
||||
return [box[1, :].tolist(), box[2, :].tolist()]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def comps2boundary(text_comps, final_pred):
|
||||
"""Propose text components and generate local graphs.
|
||||
|
||||
This is from repo https://github.com/GXYM/DRRG.
|
||||
"""
|
||||
bbox_contours = list()
|
||||
for inx in range(0, int(np.max(final_pred)) + 1):
|
||||
current_instance = np.where(final_pred == inx)
|
||||
boxes = text_comps[current_instance, :8].reshape(
|
||||
(-1, 4, 2)).astype(np.int32)
|
||||
|
||||
boundary_point = None
|
||||
if boxes.shape[0] > 1:
|
||||
centers = np.mean(boxes, axis=1).astype(np.int32).tolist()
|
||||
paths, routes_path = min_connect_path(centers)
|
||||
boxes = boxes[routes_path]
|
||||
top = np.mean(boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
|
||||
bot = np.mean(boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
|
||||
edge1 = select_edge(top + bot[::-1], boxes[0])
|
||||
edge2 = select_edge(top + bot[::-1], boxes[-1])
|
||||
if edge1 is not None:
|
||||
top.insert(0, edge1[0])
|
||||
bot.insert(0, edge1[1])
|
||||
if edge2 is not None:
|
||||
top.append(edge2[0])
|
||||
bot.append(edge2[1])
|
||||
boundary_point = np.array(top + bot[::-1])
|
||||
|
||||
elif boxes.shape[0] == 1:
|
||||
top = boxes[0, 0:2, :].astype(np.int32).tolist()
|
||||
bot = boxes[0, 2:4:-1, :].astype(np.int32).tolist()
|
||||
boundary_point = np.array(top + bot)
|
||||
|
||||
if boundary_point is None:
|
||||
continue
|
||||
|
||||
boundary_point = [p for p in boundary_point.flatten().tolist()]
|
||||
bbox_contours.append(boundary_point)
|
||||
|
||||
return bbox_contours
|
||||
|
||||
|
||||
def merge_text_comps(edges, scores, text_comps, link_thr):
|
||||
"""Merge text components into text instance."""
|
||||
clusters = graph_propagation(edges, scores, link_thr)
|
||||
pred_labels = clusters2labels(clusters, text_comps.shape[0])
|
||||
text_comps, final_pred = remove_single(text_comps, pred_labels)
|
||||
boundaries = comps2boundary(text_comps, final_pred)
|
||||
|
||||
return boundaries
|
|
@ -1,5 +1,5 @@
|
|||
from .ce_loss import CELoss, SARLoss, TFLoss
|
||||
from .ctc_loss import CTCLoss
|
||||
from .seg_loss import CAFCNLoss, SegLoss
|
||||
from .seg_loss import SegLoss
|
||||
|
||||
__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'CAFCNLoss']
|
||||
__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss']
|
||||
|
|
|
@ -3,7 +3,6 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdet.models.builder import LOSSES
|
||||
from mmocr.models.common.losses import DiceLoss
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
|
@ -68,109 +67,3 @@ class SegLoss(nn.Module):
|
|||
losses['loss_seg'] = loss_seg
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class CAFCNLoss(SegLoss):
|
||||
"""Implementation of loss module in `CA-FCN.
|
||||
|
||||
<https://arxiv.org/pdf/1809.06508.pdf>`_
|
||||
|
||||
Args:
|
||||
alpha (float): Weight ratio of attention loss.
|
||||
attn_s2_downsample_ratio (float): Downsample ratio
|
||||
of attention map from output stage 2.
|
||||
attn_s3_downsample_ratio (float): Downsample ratio
|
||||
of attention map from output stage 3.
|
||||
seg_downsample_ratio (float): Downsample ratio of
|
||||
segmentation map.
|
||||
attn_with_dice_loss (bool): If True, use dice_loss for attention,
|
||||
else BCELoss.
|
||||
with_attn (bool): If True, include attention loss, else
|
||||
segmentation loss only.
|
||||
seg_with_loss_weight (bool): If True, set weight for
|
||||
segmentation loss.
|
||||
ignore_index (int): Specifies a target value that is ignored
|
||||
and does not contribute to the input gradient.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
alpha=1.0,
|
||||
attn_s2_downsample_ratio=0.25,
|
||||
attn_s3_downsample_ratio=0.125,
|
||||
seg_downsample_ratio=0.5,
|
||||
attn_with_dice_loss=False,
|
||||
with_attn=True,
|
||||
seg_with_loss_weight=True,
|
||||
ignore_index=255):
|
||||
super().__init__(seg_downsample_ratio, seg_with_loss_weight,
|
||||
ignore_index)
|
||||
assert isinstance(alpha, (int, float))
|
||||
assert isinstance(attn_s2_downsample_ratio, (int, float))
|
||||
assert isinstance(attn_s3_downsample_ratio, (int, float))
|
||||
assert 0 < attn_s2_downsample_ratio <= 1
|
||||
assert 0 < attn_s3_downsample_ratio <= 1
|
||||
|
||||
self.alpha = alpha
|
||||
self.attn_s2_downsample_ratio = attn_s2_downsample_ratio
|
||||
self.attn_s3_downsample_ratio = attn_s3_downsample_ratio
|
||||
self.with_attn = with_attn
|
||||
self.attn_with_dice_loss = attn_with_dice_loss
|
||||
|
||||
# attention loss
|
||||
if with_attn:
|
||||
if attn_with_dice_loss:
|
||||
self.criterion_attn = DiceLoss()
|
||||
else:
|
||||
self.criterion_attn = nn.BCELoss(reduction='none')
|
||||
|
||||
def attn_loss(self, out_neck, gt_kernels):
|
||||
attn_map_s2 = out_neck[0] # bsz * 2 * H/4 * W/4
|
||||
|
||||
mask_s2 = torch.stack([
|
||||
item[2].rescale(self.attn_s2_downsample_ratio).to_tensor(
|
||||
torch.float, attn_map_s2.device) for item in gt_kernels
|
||||
])
|
||||
|
||||
attn_target_s2 = torch.stack([
|
||||
item[0].rescale(self.attn_s2_downsample_ratio).to_tensor(
|
||||
torch.float, attn_map_s2.device) for item in gt_kernels
|
||||
])
|
||||
|
||||
mask_s3 = torch.stack([
|
||||
item[2].rescale(self.attn_s3_downsample_ratio).to_tensor(
|
||||
torch.float, attn_map_s2.device) for item in gt_kernels
|
||||
])
|
||||
|
||||
attn_target_s3 = torch.stack([
|
||||
item[0].rescale(self.attn_s3_downsample_ratio).to_tensor(
|
||||
torch.float, attn_map_s2.device) for item in gt_kernels
|
||||
])
|
||||
|
||||
targets = [
|
||||
attn_target_s2, attn_target_s3, attn_target_s3, attn_target_s3
|
||||
]
|
||||
|
||||
masks = [mask_s2, mask_s3, mask_s3, mask_s3]
|
||||
|
||||
loss_attn = 0.
|
||||
for i in range(len(out_neck) - 1):
|
||||
pred = out_neck[i]
|
||||
if self.attn_with_dice_loss:
|
||||
loss_attn += self.criterion_attn(pred, targets[i], masks[i])
|
||||
else:
|
||||
loss_attn += torch.sum(
|
||||
self.criterion_attn(pred, targets[i]) *
|
||||
masks[i]) / torch.sum(masks[i])
|
||||
|
||||
return loss_attn
|
||||
|
||||
def forward(self, out_neck, out_head, gt_kernels):
|
||||
|
||||
losses = super().forward(out_neck, out_head, gt_kernels)
|
||||
|
||||
if self.with_attn:
|
||||
loss_attn = self.attn_loss(out_neck, gt_kernels)
|
||||
losses['loss_attn'] = loss_attn
|
||||
|
||||
return losses
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
from .cafcn_neck import CAFCNNeck
|
||||
from .fpn_ocr import FPNOCR
|
||||
from .fpn_seg import FPNSeg
|
||||
|
||||
__all__ = ['CAFCNNeck', 'FPNSeg', 'FPNOCR']
|
||||
__all__ = ['FPNOCR']
|
||||
|
|
|
@ -1,223 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.ops import DeformConv2dPack
|
||||
from torch import nn
|
||||
|
||||
from mmdet.models.builder import NECKS
|
||||
|
||||
|
||||
class CharAttn(nn.Module):
|
||||
"""Implementation of Character attention module in `CA-FCN.
|
||||
|
||||
<https://arxiv.org/pdf/1809.06508.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels=128, out_channels=128, deformable=False):
|
||||
super().__init__()
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(deformable, bool)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.deformable = deformable
|
||||
|
||||
# attention layers
|
||||
self.attn_layer = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=dict(type='BN')),
|
||||
ConvModule(
|
||||
in_channels,
|
||||
1,
|
||||
3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='Sigmoid')))
|
||||
|
||||
conv_kwargs = dict(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=1,
|
||||
padding=1)
|
||||
if self.deformable:
|
||||
self.conv = DeformConv2dPack(**conv_kwargs)
|
||||
else:
|
||||
self.conv = nn.Conv2d(**conv_kwargs)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, in_feat):
|
||||
# Calculate attn map
|
||||
attn_map = self.attn_layer(in_feat) # N * 1 * H * W
|
||||
|
||||
in_feat = self.relu(self.bn(self.conv(in_feat)))
|
||||
|
||||
out_feat_map = self._upsample_mul(in_feat, 1 + attn_map)
|
||||
|
||||
return out_feat_map, attn_map
|
||||
|
||||
def _upsample_add(self, x, y):
|
||||
return F.interpolate(x, size=y.size()[2:]) + y
|
||||
|
||||
def _upsample_mul(self, x, y):
|
||||
return F.interpolate(x, size=y.size()[2:]) * y
|
||||
|
||||
|
||||
class FeatGenerator(nn.Module):
|
||||
"""Generate attention-augmented stage feature from backbone stage
|
||||
feature."""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=512,
|
||||
out_channels=128,
|
||||
deformable=True,
|
||||
concat=False,
|
||||
upsample=False,
|
||||
with_attn=True):
|
||||
super().__init__()
|
||||
|
||||
self.concat = concat
|
||||
self.upsample = upsample
|
||||
self.with_attn = with_attn
|
||||
|
||||
if with_attn:
|
||||
self.char_attn = CharAttn(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
deformable=deformable)
|
||||
else:
|
||||
self.char_attn = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=dict(type='BN'))
|
||||
|
||||
if concat:
|
||||
self.conv_to_concat = ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=dict(type='BN'))
|
||||
|
||||
kernel_size = (3, 1) if deformable else 3
|
||||
padding = (1, 0) if deformable else 1
|
||||
tmp_in_channels = out_channels * 2 if concat else out_channels
|
||||
|
||||
self.conv_after_concat = ConvModule(
|
||||
tmp_in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
norm_cfg=dict(type='BN'))
|
||||
|
||||
def forward(self, x, y=None, size=None):
|
||||
if self.with_attn:
|
||||
feat_map, attn_map = self.char_attn(x)
|
||||
else:
|
||||
feat_map = self.char_attn(x)
|
||||
attn_map = feat_map
|
||||
|
||||
if self.concat:
|
||||
y = self.conv_to_concat(y)
|
||||
feat_map = torch.cat((y, feat_map), dim=1)
|
||||
|
||||
feat_map = self.conv_after_concat(feat_map)
|
||||
|
||||
if self.upsample:
|
||||
feat_map = F.interpolate(feat_map, size)
|
||||
|
||||
return attn_map, feat_map
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
class CAFCNNeck(nn.Module):
|
||||
"""Implementation of neck module in `CA-FCN.
|
||||
|
||||
<https://arxiv.org/pdf/1809.06508.pdf>`_
|
||||
|
||||
Args:
|
||||
in_channels (list[int]): Number of input channels for each scale.
|
||||
out_channels (int): Number of output channels for each scale.
|
||||
deformable (bool): If True, use deformable conv.
|
||||
with_attn (bool): If True, add attention for each output feature map.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=[128, 256, 512, 512],
|
||||
out_channels=128,
|
||||
deformable=True,
|
||||
with_attn=True):
|
||||
super().__init__()
|
||||
|
||||
self.deformable = deformable
|
||||
self.with_attn = with_attn
|
||||
|
||||
# stage_in5_to_out5
|
||||
self.s5 = FeatGenerator(
|
||||
in_channels=in_channels[-1],
|
||||
out_channels=out_channels,
|
||||
deformable=deformable,
|
||||
concat=False,
|
||||
with_attn=with_attn)
|
||||
|
||||
# stage_in4_to_out4
|
||||
self.s4 = FeatGenerator(
|
||||
in_channels=in_channels[-2],
|
||||
out_channels=out_channels,
|
||||
deformable=deformable,
|
||||
concat=True,
|
||||
with_attn=with_attn)
|
||||
|
||||
# stage_in3_to_out3
|
||||
self.s3 = FeatGenerator(
|
||||
in_channels=in_channels[-3],
|
||||
out_channels=out_channels,
|
||||
deformable=False,
|
||||
concat=True,
|
||||
upsample=True,
|
||||
with_attn=with_attn)
|
||||
|
||||
# stage_in2_to_out2
|
||||
self.s2 = FeatGenerator(
|
||||
in_channels=in_channels[-4],
|
||||
out_channels=out_channels,
|
||||
deformable=False,
|
||||
concat=True,
|
||||
upsample=True,
|
||||
with_attn=with_attn)
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def forward(self, feats):
|
||||
in_stage1 = feats[0]
|
||||
in_stage2, in_stage3 = feats[1], feats[2]
|
||||
in_stage4, in_stage5 = feats[3], feats[4]
|
||||
# out stage 5
|
||||
out_s5_attn_map, out_s5 = self.s5(in_stage5)
|
||||
|
||||
# out stage 4
|
||||
out_s4_attn_map, out_s4 = self.s4(in_stage4, out_s5)
|
||||
|
||||
# out stage 3
|
||||
out_s3_attn_map, out_s3 = self.s3(in_stage3, out_s4,
|
||||
in_stage2.size()[2:])
|
||||
|
||||
# out stage 2
|
||||
out_s2_attn_map, out_s2 = self.s2(in_stage2, out_s3,
|
||||
in_stage1.size()[2:])
|
||||
|
||||
return (out_s2_attn_map, out_s3_attn_map, out_s4_attn_map,
|
||||
out_s5_attn_map, out_s2)
|
|
@ -1,43 +0,0 @@
|
|||
import torch.nn.functional as F
|
||||
from mmcv.runner import auto_fp16
|
||||
|
||||
from mmdet.models.builder import NECKS
|
||||
from mmdet.models.necks import FPN
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
class FPNSeg(FPN):
|
||||
"""Feature Pyramid Network for segmentation based text recognition.
|
||||
|
||||
Args:
|
||||
in_channels (list[int]): Number of input channels for each scale.
|
||||
out_channels (int): Number of output channels for each scale.
|
||||
num_outs (int): Number of output scales.
|
||||
upsample_param (dict | None): Config dict for interpolate layer.
|
||||
Default: `dict(scale_factor=1.0, mode='nearest')`
|
||||
last_stage_only (bool): If True, output last stage of FPN only.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_outs,
|
||||
upsample_param=None,
|
||||
last_stage_only=True):
|
||||
super().__init__(in_channels, out_channels, num_outs)
|
||||
self.upsample_param = upsample_param
|
||||
self.last_stage_only = last_stage_only
|
||||
|
||||
@auto_fp16()
|
||||
def forward(self, inputs):
|
||||
outs = super().forward(inputs)
|
||||
|
||||
outs = list(outs)
|
||||
|
||||
if self.upsample_param is not None:
|
||||
outs[0] = F.interpolate(outs[0], **self.upsample_param)
|
||||
|
||||
if self.last_stage_only:
|
||||
return tuple(outs[0:1])
|
||||
|
||||
return tuple(outs[::-1])
|
|
@ -1,5 +1,4 @@
|
|||
from .base import BaseRecognizer
|
||||
from .cafcn import CAFCNNet
|
||||
from .crnn import CRNNNet
|
||||
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||
from .nrtr import NRTR
|
||||
|
@ -9,5 +8,5 @@ from .seg_recognizer import SegRecognizer
|
|||
|
||||
__all__ = [
|
||||
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR',
|
||||
'SegRecognizer', 'RobustScanner', 'CAFCNNet'
|
||||
'SegRecognizer', 'RobustScanner'
|
||||
]
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
from mmdet.models.builder import DETECTORS
|
||||
from .seg_recognizer import SegRecognizer
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class CAFCNNet(SegRecognizer):
|
||||
"""Implementation of `CAFCN <https://arxiv.org/pdf/1809.06508.pdf>`_"""
|
|
@ -1,11 +1,8 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmocr.models.common.losses import DiceLoss
|
||||
from mmocr.models.textrecog.losses import (CAFCNLoss, CELoss, CTCLoss, SARLoss,
|
||||
TFLoss)
|
||||
from mmocr.models.textrecog.losses import CELoss, CTCLoss, SARLoss, TFLoss
|
||||
|
||||
|
||||
def test_ctc_loss():
|
||||
|
@ -69,44 +66,6 @@ def test_tf_loss():
|
|||
assert new_target.shape == torch.Size([1, 9])
|
||||
|
||||
|
||||
def test_cafcn_loss():
|
||||
with pytest.raises(AssertionError):
|
||||
CAFCNLoss(alpha='1')
|
||||
with pytest.raises(AssertionError):
|
||||
CAFCNLoss(attn_s2_downsample_ratio='2')
|
||||
with pytest.raises(AssertionError):
|
||||
CAFCNLoss(attn_s3_downsample_ratio='1.5')
|
||||
with pytest.raises(AssertionError):
|
||||
CAFCNLoss(seg_downsample_ratio='1.5')
|
||||
with pytest.raises(AssertionError):
|
||||
CAFCNLoss(attn_s2_downsample_ratio=2)
|
||||
with pytest.raises(AssertionError):
|
||||
CAFCNLoss(attn_s3_downsample_ratio=1.5)
|
||||
with pytest.raises(AssertionError):
|
||||
CAFCNLoss(seg_downsample_ratio=1.5)
|
||||
|
||||
bsz = 1
|
||||
H = W = 64
|
||||
out_neck = (torch.ones(bsz, 1, H // 4, W // 4) * 0.5,
|
||||
torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
|
||||
torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
|
||||
torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
|
||||
torch.ones(bsz, 1, H // 2, W // 2) * 0.5)
|
||||
out_head = torch.rand(bsz, 37, H // 2, W // 2)
|
||||
|
||||
attn_tgt = np.zeros((H, W), dtype=np.float32)
|
||||
segm_tgt = np.zeros((H, W), dtype=np.float32)
|
||||
mask = np.ones((H, W), dtype=np.float32)
|
||||
gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], H, W)
|
||||
|
||||
cafcn_loss = CAFCNLoss()
|
||||
losses = cafcn_loss(out_neck, out_head, [gt_kernels])
|
||||
assert isinstance(losses, dict)
|
||||
assert 'loss_seg' in losses
|
||||
assert torch.allclose(losses['loss_seg'],
|
||||
torch.tensor(losses['loss_seg'].item()).float())
|
||||
|
||||
|
||||
def test_dice_loss():
|
||||
with pytest.raises(AssertionError):
|
||||
DiceLoss(eps='1')
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.necks.cafcn_neck import (CAFCNNeck, CharAttn,
|
||||
FeatGenerator)
|
||||
|
||||
|
||||
def test_char_attn():
|
||||
with pytest.raises(AssertionError):
|
||||
CharAttn(in_channels=5.0)
|
||||
with pytest.raises(AssertionError):
|
||||
CharAttn(deformable='deformabel')
|
||||
|
||||
in_feat = torch.rand(1, 128, 32, 32)
|
||||
char_attn = CharAttn()
|
||||
out_feat_map, attn_map = char_attn(in_feat)
|
||||
assert attn_map.shape == torch.Size([1, 1, 32, 32])
|
||||
assert out_feat_map.shape == torch.Size([1, 128, 32, 32])
|
||||
|
||||
|
||||
def test_feat_generator():
|
||||
in_feat = torch.rand(1, 128, 32, 32)
|
||||
feat_generator = FeatGenerator(
|
||||
in_channels=128, out_channels=128, deformable=False)
|
||||
|
||||
attn_map, feat_map = feat_generator(in_feat)
|
||||
assert attn_map.shape == torch.Size([1, 1, 32, 32])
|
||||
assert feat_map.shape == torch.Size([1, 128, 32, 32])
|
||||
|
||||
|
||||
def test_cafcn_neck():
|
||||
in_s1 = torch.rand(1, 64, 64, 64)
|
||||
in_s2 = torch.rand(1, 128, 32, 32)
|
||||
in_s3 = torch.rand(1, 256, 16, 16)
|
||||
in_s4 = torch.rand(1, 512, 16, 16)
|
||||
in_s5 = torch.rand(1, 512, 16, 16)
|
||||
|
||||
cafcn_neck = CAFCNNeck(deformable=False)
|
||||
cafcn_neck.init_weights()
|
||||
cafcn_neck.train()
|
||||
|
||||
out_neck = cafcn_neck((in_s1, in_s2, in_s3, in_s4, in_s5))
|
||||
assert out_neck[0].shape == torch.Size([1, 1, 32, 32])
|
||||
assert out_neck[1].shape == torch.Size([1, 1, 16, 16])
|
||||
assert out_neck[2].shape == torch.Size([1, 1, 16, 16])
|
||||
assert out_neck[3].shape == torch.Size([1, 1, 16, 16])
|
||||
assert out_neck[4].shape == torch.Size([1, 128, 64, 64])
|
Loading…
Reference in New Issue