diff --git a/mmocr/models/textdet/__init__.py b/mmocr/models/textdet/__init__.py index 3dcd957d..bf95e0f7 100644 --- a/mmocr/models/textdet/__init__.py +++ b/mmocr/models/textdet/__init__.py @@ -1,6 +1,5 @@ from .dense_heads import * # noqa: F401,F403 from .detectors import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 -from .modules import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 from .postprocess import * # noqa: F401,F403 diff --git a/mmocr/models/textdet/dense_heads/__init__.py b/mmocr/models/textdet/dense_heads/__init__.py index 975f2f59..8f227b80 100644 --- a/mmocr/models/textdet/dense_heads/__init__.py +++ b/mmocr/models/textdet/dense_heads/__init__.py @@ -1,10 +1,7 @@ from .db_head import DBHead -from .drrg_head import DRRGHead from .head_mixin import HeadMixin from .pan_head import PANHead from .pse_head import PSEHead from .textsnake_head import TextSnakeHead -__all__ = [ - 'PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'DRRGHead', 'TextSnakeHead' -] +__all__ = ['PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'TextSnakeHead'] diff --git a/mmocr/models/textdet/dense_heads/drrg_head.py b/mmocr/models/textdet/dense_heads/drrg_head.py deleted file mode 100644 index 8de153ae..00000000 --- a/mmocr/models/textdet/dense_heads/drrg_head.py +++ /dev/null @@ -1,193 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from mmcv.cnn import normal_init - -from mmdet.models.builder import HEADS, build_loss -from mmocr.models.textdet.modules import (GCN, LocalGraphs, - ProposalLocalGraphs, - merge_text_comps) -from .head_mixin import HeadMixin - - -@HEADS.register_module() -class DRRGHead(HeadMixin, nn.Module): - """The class for DRRG head: Deep Relational Reasoning Graph Network for - Arbitrary Shape Text Detection. - - [https://arxiv.org/abs/2003.07493] - - Args: - k_at_hops (tuple(int)): The number of i-hop neighbors, - i = 1, 2, ..., h. - active_connection (int): The number of two hop neighbors deem as - linked to a pivot. - node_geo_feat_dim (int): The dimension of embedded geometric features - of a component. - pooling_scale (float): The spatial scale of RRoI-Aligning. - pooling_output_size (tuple(int)): The size of RRoI-Aligning output. - graph_filter_thr (float): The threshold to filter identical local - graphs. - comp_shrink_ratio (float): The shrink ratio of text components. - nms_thr (float): The locality-aware NMS threshold. - min_width (float): The minimum width of text components. - max_width (float): The maximum width of text components. - comp_ratio (float): The reciprocal of aspect ratio of text components. - text_region_thr (float): The threshold for text region probability map. - center_region_thr (float): The threshold of text center region - probability map. - center_region_area_thr (int): The threshold of filtering small-size - text center region. - link_thr (float): The threshold for connected components searching. - """ - - def __init__(self, - in_channels, - k_at_hops=(8, 4), - active_connection=3, - node_geo_feat_dim=120, - pooling_scale=1.0, - pooling_output_size=(3, 4), - graph_filter_thr=0.75, - comp_shrink_ratio=0.95, - nms_thr=0.25, - min_width=8.0, - max_width=24.0, - comp_ratio=0.65, - text_region_thr=0.6, - center_region_thr=0.4, - center_region_area_thr=100, - link_thr=0.85, - loss=dict(type='DRRGLoss'), - train_cfg=None, - test_cfg=None): - super().__init__() - - assert isinstance(in_channels, int) - assert isinstance(k_at_hops, tuple) - assert isinstance(active_connection, int) - assert isinstance(node_geo_feat_dim, int) - assert isinstance(pooling_scale, float) - assert isinstance(pooling_output_size, tuple) - assert isinstance(graph_filter_thr, float) - assert isinstance(comp_shrink_ratio, float) - assert isinstance(nms_thr, float) - assert isinstance(min_width, float) - assert isinstance(max_width, float) - assert isinstance(comp_ratio, float) - assert isinstance(center_region_area_thr, int) - assert isinstance(link_thr, float) - - self.in_channels = in_channels - self.out_channels = 6 - self.downsample_ratio = 1.0 - self.k_at_hops = k_at_hops - self.active_connection = active_connection - self.node_geo_feat_dim = node_geo_feat_dim - self.pooling_scale = pooling_scale - self.pooling_output_size = pooling_output_size - self.graph_filter_thr = graph_filter_thr - self.comp_shrink_ratio = comp_shrink_ratio - self.nms_thr = nms_thr - self.min_width = min_width - self.max_width = max_width - self.comp_ratio = comp_ratio - self.text_region_thr = text_region_thr - self.center_region_thr = center_region_thr - self.center_region_area_thr = center_region_area_thr - self.link_thr = link_thr - self.loss_module = build_loss(loss) - self.train_cfg = train_cfg - self.test_cfg = test_cfg - - self.out_conv = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=1, - stride=1, - padding=0) - self.init_weights() - - self.graph_train = LocalGraphs(self.k_at_hops, self.active_connection, - self.node_geo_feat_dim, - self.pooling_scale, - self.pooling_output_size, - self.graph_filter_thr) - - self.graph_test = ProposalLocalGraphs( - self.k_at_hops, self.active_connection, self.node_geo_feat_dim, - self.pooling_scale, self.pooling_output_size, self.nms_thr, - self.min_width, self.max_width, self.comp_shrink_ratio, - self.comp_ratio, self.text_region_thr, self.center_region_thr, - self.center_region_area_thr) - - pool_w, pool_h = self.pooling_output_size - gcn_in_dim = (pool_w * pool_h) * ( - self.in_channels + self.out_channels) + self.node_geo_feat_dim - self.gcn = GCN(gcn_in_dim, 32) - - def init_weights(self): - normal_init(self.out_conv, mean=0, std=0.01) - - def forward(self, inputs, text_comp_feats): - - pred_maps = self.out_conv(inputs) - - feat_maps = torch.cat([inputs, pred_maps], dim=1) - node_feats, adjacent_matrices, knn_inx, gt_labels = self.graph_train( - feat_maps, np.array(text_comp_feats)) - - gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inx) - - return (pred_maps, (gcn_pred, gt_labels)) - - def single_test(self, feat_maps): - - pred_maps = self.out_conv(feat_maps) - feat_maps = torch.cat([feat_maps[0], pred_maps], dim=1) - - none_flag, graph_data = self.graph_test(pred_maps, feat_maps) - - (node_feats, adjacent_matrix, pivot_inx, knn_inx, local_graph_nodes, - text_comps) = graph_data - - if none_flag: - return None, None, None - - adjacent_matrix, pivot_inx, knn_inx = map( - lambda x: x.to(feat_maps.device), - (adjacent_matrix, pivot_inx, knn_inx)) - gcn_pred = self.gcn_model(node_feats, adjacent_matrix, knn_inx) - - pred_labels = F.softmax(gcn_pred, dim=1) - - edges = [] - scores = [] - local_graph_nodes = local_graph_nodes.long().squeeze().cpu().numpy() - graph_num = node_feats.size(0) - - for graph_inx in range(graph_num): - pivot = pivot_inx[graph_inx].int().item() - nodes = local_graph_nodes[graph_inx] - for neighbor_inx, neighbor in enumerate(knn_inx[graph_inx]): - neighbor = neighbor.item() - edges.append([nodes[pivot], nodes[neighbor]]) - scores.append(pred_labels[graph_inx * (knn_inx.shape[1]) + - neighbor_inx, 1].item()) - - edges = np.asarray(edges) - scores = np.asarray(scores) - - return edges, scores, text_comps - - def get_boundary(self, edges, scores, text_comps): - - boundaries = [] - if edges is not None: - boundaries = merge_text_comps(edges, scores, text_comps, - self.link_thr) - - results = dict(boundary_result=boundaries) - - return results diff --git a/mmocr/models/textdet/detectors/__init__.py b/mmocr/models/textdet/detectors/__init__.py index 802ca09c..77e342c2 100644 --- a/mmocr/models/textdet/detectors/__init__.py +++ b/mmocr/models/textdet/detectors/__init__.py @@ -4,10 +4,9 @@ from .dbnet import DBNet # isort:skip from .ocr_mask_rcnn import OCRMaskRCNN # isort:skip from .panet import PANet # isort:skip from .psenet import PSENet # isort:skip -from .drrg import DRRG # isort:skip from .textsnake import TextSnake # isort:skip __all__ = [ 'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet', - 'PANet', 'PSENet', 'DRRG', 'TextSnake' + 'PANet', 'PSENet', 'TextSnake' ] diff --git a/mmocr/models/textdet/detectors/drrg.py b/mmocr/models/textdet/detectors/drrg.py deleted file mode 100644 index c96d40b0..00000000 --- a/mmocr/models/textdet/detectors/drrg.py +++ /dev/null @@ -1,38 +0,0 @@ -from mmdet.models.builder import DETECTORS -from . import SingleStageTextDetector, TextDetectorMixin - - -@DETECTORS.register_module() -class DRRG(TextDetectorMixin, SingleStageTextDetector): - """The class for implementing DRRG text detector: Deep Relational Reasoning - Graph Network for Arbitrary Shape Text Detection. - - [https://arxiv.org/abs/2003.07493] - """ - - def forward_train(self, img, img_metas, **kwargs): - """ - Args: - img (Tensor): Input images of shape (N, C, H, W). - Typically these should be mean centered and std scaled. - img_metas (list[dict]): A list of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details of the values of these keys, see - :class:`mmdet.datasets.pipelines.Collect`. - Returns: - dict[str, Tensor]: A dictionary of loss components. - """ - x = self.extract_feat(img) - gt_comp_attribs = kwargs.pop('gt_comp_attribs') - preds = self.bbox_head(x, gt_comp_attribs) - losses = self.bbox_head.loss(preds, **kwargs) - return losses - - def simple_test(self, img, img_metas, rescale=False): - - x = self.extract_feat(img) - outs = self.bbox_head.single_test(x, img) - boundaries = self.bbox_head.get_boundary(*outs, img_metas, rescale) - - return [boundaries] diff --git a/mmocr/models/textdet/losses/__init__.py b/mmocr/models/textdet/losses/__init__.py index d671b281..eaa4d9cd 100644 --- a/mmocr/models/textdet/losses/__init__.py +++ b/mmocr/models/textdet/losses/__init__.py @@ -1,7 +1,6 @@ from .db_loss import DBLoss -from .drrg_loss import DRRGLoss from .pan_loss import PANLoss from .pse_loss import PSELoss from .textsnake_loss import TextSnakeLoss -__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'DRRGLoss', 'TextSnakeLoss'] +__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss'] diff --git a/mmocr/models/textdet/losses/drrg_loss.py b/mmocr/models/textdet/losses/drrg_loss.py deleted file mode 100644 index 3caaf28f..00000000 --- a/mmocr/models/textdet/losses/drrg_loss.py +++ /dev/null @@ -1,206 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - -from mmdet.core import BitmapMasks -from mmdet.models.builder import LOSSES -from mmocr.utils import check_argument - - -@LOSSES.register_module() -class DRRGLoss(nn.Module): - """The class for implementing DRRG loss: Deep Relational Reasoning Graph - Network for Arbitrary Shape Text Detection. - - [https://arxiv.org/abs/1908.05900] This is partially adapted from - https://github.com/GXYM/DRRG. - """ - - def __init__(self, ohem_ratio=3.0): - """Initialization. - - Args: - ohem_ratio (float): The negative/positive ratio in OHEM. - """ - super().__init__() - self.ohem_ratio = ohem_ratio - - def balance_bce_loss(self, pred, gt, mask): - - assert pred.shape == gt.shape == mask.shape - positive = gt * mask - negative = (1 - gt) * mask - positive_count = int(positive.float().sum()) - gt = gt.float() - if positive_count > 0: - loss = F.binary_cross_entropy(pred, gt, reduction='none') - positive_loss = torch.sum(loss * positive.float()) - negative_loss = loss * negative.float() - negative_count = min( - int(negative.float().sum()), - int(positive_count * self.ohem_ratio)) - else: - positive_loss = torch.tensor(0.0) - loss = F.binary_cross_entropy(pred, gt, reduction='none') - negative_loss = loss * negative.float() - negative_count = 100 - negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) - - balance_loss = (positive_loss + torch.sum(negative_loss)) / ( - float(positive_count + negative_count) + 1e-5) - - return balance_loss - - def gcn_loss(self, gcn_data): - - gcn_pred, gt_labels = gcn_data - gt_labels = gt_labels.view(-1).to(gcn_pred.device) - loss = F.cross_entropy(gcn_pred, gt_labels) - - return loss - - def bitmasks2tensor(self, bitmasks, target_sz): - """Convert Bitmasks to tensor. - - Args: - bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is - for one img. - target_sz (tuple(int, int)): The target tensor size HxW. - - Returns - results (list[tensor]): The list of kernel tensors. Each - element is for one kernel level. - """ - assert check_argument.is_type_list(bitmasks, BitmapMasks) - assert isinstance(target_sz, tuple) - - batch_size = len(bitmasks) - num_masks = len(bitmasks[0]) - - results = [] - - for level_inx in range(num_masks): - kernel = [] - for batch_inx in range(batch_size): - mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx]) - # hxw - mask_sz = mask.shape - # left, right, top, bottom - pad = [ - 0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0] - ] - mask = F.pad(mask, pad, mode='constant', value=0) - kernel.append(mask) - kernel = torch.stack(kernel) - results.append(kernel) - - return results - - def forward(self, preds, downsample_ratio, gt_text_mask, - gt_center_region_mask, gt_mask, gt_top_height_map, - gt_bot_height_map, gt_sin_map, gt_cos_map): - - assert isinstance(preds, tuple) - assert isinstance(downsample_ratio, float) - assert abs(downsample_ratio - 1.0) < 1e-5 - assert check_argument.is_type_list(gt_text_mask, BitmapMasks) - assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks) - assert check_argument.is_type_list(gt_mask, BitmapMasks) - assert check_argument.is_type_list(gt_top_height_map, BitmapMasks) - assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks) - assert check_argument.is_type_list(gt_sin_map, BitmapMasks) - assert check_argument.is_type_list(gt_cos_map, BitmapMasks) - - pred_maps, gcn_data = preds - pred_text_region = pred_maps[:, 0, :, :] - pred_center_region = pred_maps[:, 1, :, :] - pred_sin_map = pred_maps[:, 2, :, :] - pred_cos_map = pred_maps[:, 3, :, :] - pred_top_height_map = pred_maps[:, 4, :, :] - pred_bot_height_map = pred_maps[:, 5, :, :] - feature_sz = pred_maps.size() - - # bitmask 2 tensor - mapping = { - 'gt_text_mask': gt_text_mask, - 'gt_center_region_mask': gt_center_region_mask, - 'gt_mask': gt_mask, - 'gt_top_height_map': gt_top_height_map, - 'gt_bot_height_map': gt_bot_height_map, - 'gt_sin_map': gt_sin_map, - 'gt_cos_map': gt_cos_map - } - gt = {} - for key, value in mapping.items(): - gt[key] = value - gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:]) - gt[key] = [item.to(pred_maps.device) for item in gt[key]] - - scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8)) - pred_sin_map = pred_sin_map * scale - pred_cos_map = pred_cos_map * scale - - loss_text = self.balance_bce_loss( - torch.sigmoid(pred_text_region), gt['gt_text_mask'][0], - gt['gt_mask'][0]) - - text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float() - negative_text_mask = ((1 - gt['gt_text_mask'][0]) * - gt['gt_mask'][0]).float() - gt_center_region_mask = gt['gt_center_region_mask'][0].float() - loss_center = F.binary_cross_entropy( - torch.sigmoid(pred_center_region), - gt_center_region_mask, - reduction='none') - if int(text_mask.sum()) > 0: - loss_center_positive = torch.sum( - loss_center * text_mask) / torch.sum(text_mask) - else: - loss_center_positive = torch.tensor(0.0) - loss_center_negative = torch.sum( - loss_center * negative_text_mask) / torch.sum(negative_text_mask) - loss_center = loss_center_positive + 0.5 * loss_center_negative - - center_mask = (gt['gt_center_region_mask'][0] * - gt['gt_mask'][0]).float() - if int(center_mask.sum()) > 0: - ones = torch.ones_like( - gt['gt_top_height_map'][0], dtype=torch.float) - loss_top = F.smooth_l1_loss( - pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2), - ones, - reduction='none') - loss_bot = F.smooth_l1_loss( - pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2), - ones, - reduction='none') - gt_height = ( - gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0]) - loss_height = torch.sum( - (torch.log(gt_height + 1) * - (loss_top + loss_bot)) * center_mask) / torch.sum(center_mask) - - loss_sin = torch.sum( - F.smooth_l1_loss( - pred_sin_map, gt['gt_sin_map'][0], reduction='none') * - center_mask) / torch.sum(center_mask) - loss_cos = torch.sum( - F.smooth_l1_loss( - pred_cos_map, gt['gt_cos_map'][0], reduction='none') * - center_mask) / torch.sum(center_mask) - else: - loss_height = torch.tensor(0.0) - loss_sin = torch.tensor(0.0) - loss_cos = torch.tensor(0.0) - - loss_gcn = self.gcn_loss(gcn_data) - - results = dict( - loss_text=loss_text, - loss_center=loss_center, - loss_height=loss_height, - loss_sin=loss_sin, - loss_cos=loss_cos, - loss_gcn=loss_gcn) - - return results diff --git a/mmocr/models/textdet/modules/__init__.py b/mmocr/models/textdet/modules/__init__.py deleted file mode 100644 index 9d62df75..00000000 --- a/mmocr/models/textdet/modules/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .gcn import GCN -from .local_graph import LocalGraphs -from .proposal_local_graph import ProposalLocalGraphs -from .utils import merge_text_comps - -__all__ = ['LocalGraphs', 'ProposalLocalGraphs', 'GCN', 'merge_text_comps'] diff --git a/mmocr/models/textdet/modules/gcn.py b/mmocr/models/textdet/modules/gcn.py deleted file mode 100644 index f2fde039..00000000 --- a/mmocr/models/textdet/modules/gcn.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import init - - -class MeanAggregator(nn.Module): - - def forward(self, features, A): - x = torch.bmm(A, features) - return x - - -class GraphConv(nn.Module): - - def __init__(self, in_dim, out_dim): - super(GraphConv, self).__init__() - self.in_dim = in_dim - self.out_dim = out_dim - self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim)) - self.bias = nn.Parameter(torch.FloatTensor(out_dim)) - init.xavier_uniform_(self.weight) - init.constant_(self.bias, 0) - self.agg = MeanAggregator() - - def forward(self, features, A): - b, n, d = features.shape - assert d == self.in_dim - agg_feats = self.agg(features, A) - cat_feats = torch.cat([features, agg_feats], dim=2) - out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight)) - out = F.relu(out + self.bias) - return out - - -class GCN(nn.Module): - """Predict linkage between instances. This was from repo - https://github.com/Zhongdao/gcn_clustering: Linkage Based Face Clustering - via Graph Convolution Network. - - [https://arxiv.org/abs/1903.11306] - - Args: - in_dim(int): The input dimension. - out_dim(int): The output dimension. - """ - - def __init__(self, in_dim, out_dim): - super(GCN, self).__init__() - self.bn0 = nn.BatchNorm1d(in_dim, affine=False).float() - self.conv1 = GraphConv(in_dim, 512) - self.conv2 = GraphConv(512, 256) - self.conv3 = GraphConv(256, 128) - self.conv4 = GraphConv(128, 64) - - self.classifier = nn.Sequential( - nn.Linear(64, out_dim), nn.PReLU(out_dim), nn.Linear(out_dim, 2)) - - def forward(self, x, A, one_hop_indexes, train=True): - - B, N, D = x.shape - - x = x.view(-1, D) - x = self.bn0(x) - x = x.view(B, N, D) - - x = self.conv1(x, A) - x = self.conv2(x, A) - x = self.conv3(x, A) - x = self.conv4(x, A) - k1 = one_hop_indexes.size(-1) - dout = x.size(-1) - edge_feat = torch.zeros(B, k1, dout) - for b in range(B): - edge_feat[b, :, :] = x[b, one_hop_indexes[b]] - edge_feat = edge_feat.view(-1, dout).to(x.device) - pred = self.classifier(edge_feat) - - # shape: (B*k1)x2 - return pred diff --git a/mmocr/models/textdet/modules/local_graph.py b/mmocr/models/textdet/modules/local_graph.py deleted file mode 100644 index a1b0848c..00000000 --- a/mmocr/models/textdet/modules/local_graph.py +++ /dev/null @@ -1,307 +0,0 @@ -import numpy as np -import torch - -from mmocr.models.utils import RROIAlign -from .utils import (embed_geo_feats, euclidean_distance_matrix, - normalize_adjacent_matrix) - - -class LocalGraphs: - """Generate local graphs for GCN to predict which instance a text component - belongs to. This was partially adapted from https://github.com/GXYM/DRRG: - Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection. - - [https://arxiv.org/abs/2003.07493] - - Args: - k_at_hops (tuple(int)): The number of h-hop neighbors. - active_connection (int): The number of neighbors deem as linked to a - pivot. - node_geo_feat_dim (int): The dimension of embedded geometric features - of a component. - pooling_scale (float): The spatial scale of RRoI-Aligning. - pooling_output_size (tuple(int)): The size of RRoI-Aligning output. - local_graph_filter_thr (float): The threshold to filter out identical - local graphs. - """ - - def __init__(self, k_at_hops, active_connection, node_geo_feat_dim, - pooling_scale, pooling_output_size, local_graph_filter_thr): - - assert isinstance(k_at_hops, tuple) - assert isinstance(active_connection, int) - assert isinstance(node_geo_feat_dim, int) - assert isinstance(pooling_scale, float) - assert isinstance(pooling_output_size, tuple) - assert isinstance(local_graph_filter_thr, float) - - self.k_at_hops = k_at_hops - self.local_graph_depth = len(self.k_at_hops) - self.active_connection = active_connection - self.node_geo_feat_dim = node_geo_feat_dim - self.pooling = RROIAlign(pooling_output_size, pooling_scale) - self.local_graph_filter_thr = local_graph_filter_thr - - def generate_local_graphs(self, sorted_complete_graph, gt_belong_labels): - """Generate local graphs for GCN to predict which instance a text - component belongs to. - - Args: - sorted_complete_graph (ndarray): The complete graph where nodes are - sorted according to their Euclidean distance. - gt_belong_labels (ndarray): The ground truth labels define which - instance text components (nodes in graphs) belong to. - - Returns: - local_graph_node_list (list): The list of local graph neighbors of - pivots. - knn_graph_neighbor_list (list): The list of k nearest neighbors of - pivots. - """ - - assert sorted_complete_graph.ndim == 2 - assert (sorted_complete_graph.shape[0] == - sorted_complete_graph.shape[1] == gt_belong_labels.shape[0]) - - knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1] - local_graph_node_list = list() - knn_graph_neighbor_list = list() - for pivot_inx, knn_graph in enumerate(knn_graphs): - - h_hop_neighbor_list = list() - one_hop_neighbors = set(knn_graph[1:]) - h_hop_neighbor_list.append(one_hop_neighbors) - - for hop_inx in range(1, self.local_graph_depth): - h_hop_neighbor_list.append(set()) - for last_hop_neighbor_inx in h_hop_neighbor_list[-2]: - h_hop_neighbor_list[-1].update( - set(sorted_complete_graph[last_hop_neighbor_inx] - [1:self.k_at_hops[hop_inx] + 1])) - - hops_neighbor_set = set( - [node for hop in h_hop_neighbor_list for node in hop]) - hops_neighbor_list = list(hops_neighbor_set) - hops_neighbor_list.insert(0, pivot_inx) - - if pivot_inx < 1: - local_graph_node_list.append(hops_neighbor_list) - knn_graph_neighbor_list.append(one_hop_neighbors) - else: - add_flag = True - for graph_inx in range(len(knn_graph_neighbor_list)): - knn_graph_neighbors = knn_graph_neighbor_list[graph_inx] - local_graph_nodes = local_graph_node_list[graph_inx] - - node_union_num = len( - list( - set(knn_graph_neighbors).union( - set(one_hop_neighbors)))) - node_intersect_num = len( - list( - set(knn_graph_neighbors).intersection( - set(one_hop_neighbors)))) - one_hop_iou = node_intersect_num / (node_union_num + 1e-8) - - if (one_hop_iou > self.local_graph_filter_thr - and pivot_inx in knn_graph_neighbors - and gt_belong_labels[local_graph_nodes[0]] - == gt_belong_labels[pivot_inx] - and gt_belong_labels[local_graph_nodes[0]] != 0): - add_flag = False - break - if add_flag: - local_graph_node_list.append(hops_neighbor_list) - knn_graph_neighbor_list.append(one_hop_neighbors) - - return local_graph_node_list, knn_graph_neighbor_list - - def generate_gcn_input(self, node_feat_batch, belong_label_batch, - local_graph_node_batch, knn_graph_neighbor_batch, - sorted_complete_graph): - """Generate graph convolution network input data. - - Args: - node_feat_batch (List[Tensor]): The node feature batch. - belong_label_batch (List[ndarray]): The text component belong label - batch. - local_graph_node_batch (List[List[list]]): The local graph - neighbors batch. - knn_graph_neighbor_batch (List[List[set]]): The knn graph neighbor - batch. - sorted_complete_graph (List[ndarray]): The complete graph sorted - according to the Euclidean distance. - - Returns: - node_feat_batch_tensor (Tensor): The node features of Graph - Convolutional Network (GCN). - adjacent_mat_batch_tensor (Tensor): The adjacent matrices. - knn_inx_batch_tensor (Tensor): The indices of k nearest neighbors. - gt_linkage_batch_tensor (Tensor): The surpervision signal of GCN - for linkage prediction. - """ - - assert isinstance(node_feat_batch, list) - assert isinstance(belong_label_batch, list) - assert isinstance(local_graph_node_batch, list) - assert isinstance(knn_graph_neighbor_batch, list) - assert isinstance(sorted_complete_graph, list) - - max_graph_node_num = max([ - len(local_graph_nodes) - for local_graph_node_list in local_graph_node_batch - for local_graph_nodes in local_graph_node_list - ]) - - node_feat_batch_list = list() - adjacent_matrix_batch_list = list() - knn_inx_batch_list = list() - gt_linkage_batch_list = list() - - for batch_inx, sorted_neighbors in enumerate(sorted_complete_graph): - node_feats = node_feat_batch[batch_inx] - local_graph_list = local_graph_node_batch[batch_inx] - knn_graph_neighbor_list = knn_graph_neighbor_batch[batch_inx] - belong_labels = belong_label_batch[batch_inx] - - for graph_inx in range(len(local_graph_list)): - local_graph_nodes = local_graph_list[graph_inx] - local_graph_node_num = len(local_graph_nodes) - pivot_inx = local_graph_nodes[0] - knn_graph_neighbors = knn_graph_neighbor_list[graph_inx] - node_to_graph_inx = { - j: i - for i, j in enumerate(local_graph_nodes) - } - - knn_inx_in_local_graph = torch.tensor( - [node_to_graph_inx[i] for i in knn_graph_neighbors], - dtype=torch.long) - pivot_feats = node_feats[torch.tensor( - pivot_inx, dtype=torch.long)] - normalized_feats = node_feats[torch.tensor( - local_graph_nodes, dtype=torch.long)] - pivot_feats - - adjacent_matrix = np.zeros( - (local_graph_node_num, local_graph_node_num)) - pad_normalized_feats = torch.cat([ - normalized_feats, - torch.zeros(max_graph_node_num - local_graph_node_num, - normalized_feats.shape[1]).to( - node_feats.device) - ], - dim=0) - - for node in local_graph_nodes: - neighbors = sorted_neighbors[node, - 1:self.active_connection + 1] - for neighbor in neighbors: - if neighbor in local_graph_nodes: - adjacent_matrix[node_to_graph_inx[node], - node_to_graph_inx[neighbor]] = 1 - adjacent_matrix[node_to_graph_inx[neighbor], - node_to_graph_inx[node]] = 1 - - adjacent_matrix = normalize_adjacent_matrix( - adjacent_matrix, type='DAD') - adjacent_matrix_tensor = torch.zeros(max_graph_node_num, - max_graph_node_num).to( - node_feats.device) - adjacent_matrix_tensor[:local_graph_node_num, : - local_graph_node_num] = adjacent_matrix - - local_graph_labels = torch.from_numpy( - belong_labels[local_graph_nodes]).type(torch.long) - knn_labels = local_graph_labels[knn_inx_in_local_graph] - edge_labels = ((belong_labels[pivot_inx] == knn_labels) - & (belong_labels[pivot_inx] > 0)).long() - - node_feat_batch_list.append(pad_normalized_feats) - adjacent_matrix_batch_list.append(adjacent_matrix_tensor) - knn_inx_batch_list.append(knn_inx_in_local_graph) - gt_linkage_batch_list.append(edge_labels) - - node_feat_batch_tensor = torch.stack(node_feat_batch_list, 0) - adjacent_mat_batch_tensor = torch.stack(adjacent_matrix_batch_list, 0) - knn_inx_batch_tensor = torch.stack(knn_inx_batch_list, 0) - gt_linkage_batch_tensor = torch.stack(gt_linkage_batch_list, 0) - - return (node_feat_batch_tensor, adjacent_mat_batch_tensor, - knn_inx_batch_tensor, gt_linkage_batch_tensor) - - def __call__(self, feat_maps, comp_attribs): - """Generate local graphs. - - Args: - feat_maps (Tensor): The feature maps to propose node features in - graph. - comp_attribs (ndarray): The text components attributes. - - Returns: - node_feats_batch (Tensor): The node features of Graph Convolutional - Network(GCN). - adjacent_matrices_batch (Tensor): The adjacent matrices. - knn_inx_batch (Tensor): The indices of k nearest neighbors. - gt_linkage_batch (Tensor): The surpervision signal of GCN for - linkage prediction. - """ - - assert isinstance(feat_maps, torch.Tensor) - assert comp_attribs.shape[2] == 8 - - dist_sort_graph_batch_list = [] - local_graph_node_batch_list = [] - knn_graph_neighbor_batch_list = [] - node_feature_batch_list = [] - belong_label_batch_list = [] - - for batch_inx in range(comp_attribs.shape[0]): - comp_num = int(comp_attribs[batch_inx, 0, 0]) - comp_geo_attribs = comp_attribs[batch_inx, :comp_num, 1:7] - node_belong_labels = comp_attribs[batch_inx, :comp_num, - 7].astype(np.int32) - - comp_centers = comp_geo_attribs[:, 0:2] - distance_matrix = euclidean_distance_matrix( - comp_centers, comp_centers) - - graph_node_geo_feats = embed_geo_feats(comp_geo_attribs, - self.node_geo_feat_dim) - graph_node_geo_feats = torch.from_numpy( - graph_node_geo_feats).float().to(feat_maps.device) - - batch_id = np.zeros( - (comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_inx - text_comps = np.hstack( - (batch_id, comp_geo_attribs.astype(np.float32))) - text_comps = torch.from_numpy(text_comps).float().to( - feat_maps.device) - - comp_content_feats = self.pooling( - feat_maps[batch_inx].unsqueeze(0), text_comps) - comp_content_feats = comp_content_feats.view( - comp_content_feats.shape[0], -1).to(feat_maps.device) - node_feats = torch.cat((comp_content_feats, graph_node_geo_feats), - dim=-1) - - dist_sort_complete_graph = np.argsort(distance_matrix, axis=1) - (local_graph_nodes, - knn_graph_neighbors) = self.generate_local_graphs( - dist_sort_complete_graph, node_belong_labels) - - node_feature_batch_list.append(node_feats) - belong_label_batch_list.append(node_belong_labels) - local_graph_node_batch_list.append(local_graph_nodes) - knn_graph_neighbor_batch_list.append(knn_graph_neighbors) - dist_sort_graph_batch_list.append(dist_sort_complete_graph) - - (node_feats_batch, adjacent_matrices_batch, knn_inx_batch, - gt_linkage_batch) = \ - self.generate_gcn_input(node_feature_batch_list, - belong_label_batch_list, - local_graph_node_batch_list, - knn_graph_neighbor_batch_list, - dist_sort_graph_batch_list) - - return (node_feats_batch, adjacent_matrices_batch, knn_inx_batch, - gt_linkage_batch) diff --git a/mmocr/models/textdet/modules/proposal_local_graph.py b/mmocr/models/textdet/modules/proposal_local_graph.py deleted file mode 100644 index 901740d6..00000000 --- a/mmocr/models/textdet/modules/proposal_local_graph.py +++ /dev/null @@ -1,418 +0,0 @@ -import cv2 -import numpy as np -import torch - -# from mmocr.models.textdet.postprocess import la_nms -from mmocr.models.utils import RROIAlign -from .utils import (embed_geo_feats, euclidean_distance_matrix, - normalize_adjacent_matrix) - - -class ProposalLocalGraphs: - """Propose text components and generate local graphs. This was partially - adapted from https://github.com/GXYM/DRRG: Deep Relational Reasoning Graph - Network for Arbitrary Shape Text Detection. - - [https://arxiv.org/abs/2003.07493] - - Args: - k_at_hops (tuple(int)): The number of i-hop neighbors, - i = 1, 2, ..., h. - active_connection (int): The number of two hop neighbors deem as linked - to a pivot. - node_geo_feat_dim (int): The dimension of embedded geometric features - of a component. - pooling_scale (float): The spatial scale of RRoI-Aligning. - pooling_output_size (tuple(int)): The size of RRoI-Aligning output. - nms_thr (float): The locality-aware NMS threshold. - min_width (float): The minimum width of text components. - max_width (float): The maximum width of text components. - comp_shrink_ratio (float): The shrink ratio of text components. - comp_ratio (float): The reciprocal of aspect ratio of text components. - text_region_thr (float): The threshold for text region probability map. - center_region_thr (float): The threshold for text center region - probability map. - center_region_area_thr (int): The threshold for filtering out - small-size text center region. - """ - - def __init__(self, k_at_hops, active_connection, node_geo_feat_dim, - pooling_scale, pooling_output_size, nms_thr, min_width, - max_width, comp_shrink_ratio, comp_ratio, text_region_thr, - center_region_thr, center_region_area_thr): - - assert isinstance(k_at_hops, tuple) - assert isinstance(active_connection, int) - assert isinstance(node_geo_feat_dim, int) - assert isinstance(pooling_scale, float) - assert isinstance(pooling_output_size, tuple) - assert isinstance(nms_thr, float) - assert isinstance(min_width, float) - assert isinstance(max_width, float) - assert isinstance(comp_shrink_ratio, float) - assert isinstance(comp_ratio, float) - assert isinstance(text_region_thr, float) - assert isinstance(center_region_thr, float) - assert isinstance(center_region_area_thr, int) - - self.k_at_hops = k_at_hops - self.active_connection = active_connection - self.local_graph_depth = len(self.k_at_hops) - self.node_geo_feat_dim = node_geo_feat_dim - self.pooling = RROIAlign(pooling_output_size, pooling_scale) - self.nms_thr = nms_thr - self.min_width = min_width - self.max_width = max_width - self.comp_shrink_ratio = comp_shrink_ratio - self.comp_ratio = comp_ratio - self.text_region_thr = text_region_thr - self.center_region_thr = center_region_thr - self.center_region_area_thr = center_region_area_thr - - def fill_hole(self, input_mask): - h, w = input_mask.shape - canvas = np.zeros((h + 2, w + 2), np.uint8) - canvas[1:h + 1, 1:w + 1] = input_mask.copy() - - mask = np.zeros((h + 4, w + 4), np.uint8) - - cv2.floodFill(canvas, mask, (0, 0), 1) - canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool) - - return (~canvas | input_mask.astype(np.uint8)) - - def propose_comps(self, top_radius_map, bot_radius_map, sin_map, cos_map, - score_map, min_width, max_width, comp_shrink_ratio, - comp_ratio): - """Generate text components. - - Args: - top_radius_map (ndarray): The predicted distance map from each - pixel in text center region to top sideline. - bot_radius_map (ndarray): The predicted distance map from each - pixel in text center region to bottom sideline. - sin_map (ndarray): The predicted sin(theta) map. - cos_map (ndarray): The predicted cos(theta) map. - score_map (ndarray): The score map for NMS. - min_width (float): The minimum width of text components. - max_width (float): The maximum width of text components. - comp_shrink_ratio (float): The shrink ratio of text components. - comp_ratio (float): The reciprocal of aspect ratio of text - components. - - Returns: - text_comps (ndarray): The text components. - """ - - comp_centers = np.argwhere(score_map > 0) - comp_centers = comp_centers[np.argsort(comp_centers[:, 0])] - y = comp_centers[:, 0] - x = comp_centers[:, 1] - - top_radius = top_radius_map[y, x].reshape((-1, 1)) * comp_shrink_ratio - bot_radius = bot_radius_map[y, x].reshape((-1, 1)) * comp_shrink_ratio - sin = sin_map[y, x].reshape((-1, 1)) - cos = cos_map[y, x].reshape((-1, 1)) - - top_mid_x_offset = top_radius * cos - top_mid_y_offset = top_radius * sin - bot_mid_x_offset = bot_radius * cos - bot_mid_y_offset = bot_radius * sin - - top_mid_pnt = comp_centers + np.hstack( - [top_mid_y_offset, top_mid_x_offset]) - bot_mid_pnt = comp_centers - np.hstack( - [bot_mid_y_offset, bot_mid_x_offset]) - - width = (top_radius + bot_radius) * comp_ratio - width = np.clip(width, min_width, max_width) - - top_left = top_mid_pnt - np.hstack([width * cos, -width * sin - ])[:, ::-1] - top_right = top_mid_pnt + np.hstack([width * cos, -width * sin - ])[:, ::-1] - bot_right = bot_mid_pnt + np.hstack([width * cos, -width * sin - ])[:, ::-1] - bot_left = bot_mid_pnt - np.hstack([width * cos, -width * sin - ])[:, ::-1] - - text_comps = np.hstack([top_left, top_right, bot_right, bot_left]) - score = score_map[y, x].reshape((-1, 1)) - text_comps = np.hstack([text_comps, score]) - - return text_comps - - def propose_comps_and_attribs(self, text_region_map, center_region_map, - top_radius_map, bot_radius_map, sin_map, - cos_map): - """Generate text components and attributes. - - Args: - text_region_map (ndarray): The predicted text region probability - map. - center_region_map (ndarray): The predicted text center region - probability map. - top_radius_map (ndarray): The predicted distance map from each - pixel in text center region to top sideline. - bot_radius_map (ndarray): The predicted distance map from each - pixel in text center region to bottom sideline. - sin_map (ndarray): The predicted sin(theta) map. - cos_map (ndarray): The predicted cos(theta) map. - - Returns: - comp_attribs (ndarray): The text components attributes. - text_comps (ndarray): The text components. - """ - - assert (text_region_map.shape == center_region_map.shape == - top_radius_map.shape == bot_radius_map == sin_map.shape == - cos_map.shape) - text_mask = text_region_map > self.text_region_thr - center_region_mask = (center_region_map > - self.center_region_thr) * text_mask - - scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2)) - sin_map, cos_map = sin_map * scale, cos_map * scale - - center_region_mask = self.fill_hole(center_region_mask) - center_region_contours, _ = cv2.findContours( - center_region_mask.astype(np.uint8), cv2.RETR_TREE, - cv2.CHAIN_APPROX_SIMPLE) - - mask = np.zeros_like(center_region_mask) - comp_list = [] - for contour in center_region_contours: - current_center_mask = mask.copy() - cv2.drawContours(current_center_mask, [contour], -1, 1, -1) - if current_center_mask.sum() <= self.center_region_area_thr: - continue - score_map = text_region_map * current_center_mask - - text_comp = self.propose_comps(top_radius_map, bot_radius_map, - sin_map, cos_map, score_map, - self.min_width, self.max_width, - self.comp_shrink_ratio, - self.comp_ratio) - - # text_comp = la_nms(text_comp.astype('float32'), self.nms_thr) - - text_comp_mask = mask.copy() - text_comps_bboxes = text_comp[:, :8].reshape( - (-1, 4, 2)).astype(np.int32) - - cv2.drawContours(text_comp_mask, text_comps_bboxes, -1, 1, -1) - if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5: - continue - - comp_list.append(text_comp) - - if len(comp_list) <= 0: - return None, None - - text_comps = np.vstack(comp_list) - - centers = np.mean( - text_comps[:, :8].reshape((-1, 4, 2)), axis=1).astype(np.int32) - - x = centers[:, 0] - y = centers[:, 1] - - h = top_radius_map[y, x].reshape( - (-1, 1)) + bot_radius_map[y, x].reshape((-1, 1)) - w = np.clip(h * self.comp_ratio, self.min_width, self.max_width) - sin = sin_map[y, x].reshape((-1, 1)) - cos = cos_map[y, x].reshape((-1, 1)) - x = x.reshape((-1, 1)) - y = y.reshape((-1, 1)) - comp_attribs = np.hstack([x, y, h, w, cos, sin]) - - return comp_attribs, text_comps - - def generate_local_graphs(self, sorted_complete_graph, node_feats): - """Generate local graphs and Graph Convolution Network input data. - - Args: - sorted_complete_graph (ndarray): The complete graph where nodes are - sorted according to their Euclidean distance. - node_feats (tensor): The graph nodes features. - - Returns: - node_feats_tensor (tensor): The graph nodes features. - adjacent_matrix_tensor (tensor): The adjacent matrix of graph. - pivot_inx_tensor (tensor): The pivot indices in local graph. - knn_inx_tensor (tensor): The k nearest neighbor nodes indexes in - local graph. - local_graph_node_tensor (tensor): The indices of nodes in local - graph. - """ - - assert sorted_complete_graph.ndim == 2 - assert (sorted_complete_graph.shape[0] == - sorted_complete_graph.shape[1] == node_feats.shape[0]) - - knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1] - local_graph_node_list = list() - knn_graph_neighbor_list = list() - for pivot_inx, knn_graph in enumerate(knn_graphs): - - h_hop_neighbor_list = list() - one_hop_neighbors = set(knn_graph[1:]) - h_hop_neighbor_list.append(one_hop_neighbors) - - for hop_inx in range(1, self.local_graph_depth): - h_hop_neighbor_list.append(set()) - for last_hop_neighbor_inx in h_hop_neighbor_list[-2]: - h_hop_neighbor_list[-1].update( - set(sorted_complete_graph[last_hop_neighbor_inx] - [1:self.k_at_hops[hop_inx] + 1])) - - hops_neighbor_set = set( - [node for hop in h_hop_neighbor_list for node in hop]) - hops_neighbor_list = list(hops_neighbor_set) - hops_neighbor_list.insert(0, pivot_inx) - - local_graph_node_list.append(hops_neighbor_list) - knn_graph_neighbor_list.append(one_hop_neighbors) - - max_graph_node_num = max([ - len(local_graph_nodes) - for local_graph_nodes in local_graph_node_list - ]) - - node_normalized_feats = list() - adjacent_matrix_list = list() - knn_inx = list() - pivot_graph_inx = list() - local_graph_tensor_list = list() - - for graph_inx in range(len(local_graph_node_list)): - - local_graph_nodes = local_graph_node_list[graph_inx] - local_graph_node_num = len(local_graph_nodes) - pivot_inx = local_graph_nodes[0] - knn_graph_neighbors = knn_graph_neighbor_list[graph_inx] - node_to_graph_inx = {j: i for i, j in enumerate(local_graph_nodes)} - - pivot_node_inx = torch.tensor([ - node_to_graph_inx[pivot_inx], - ]).type(torch.long) - knn_inx_in_local_graph = torch.tensor( - [node_to_graph_inx[i] for i in knn_graph_neighbors], - dtype=torch.long) - pivot_feats = node_feats[torch.tensor(pivot_inx, dtype=torch.long)] - normalized_feats = node_feats[torch.tensor( - local_graph_nodes, dtype=torch.long)] - pivot_feats - - adjacent_matrix = np.zeros( - (local_graph_node_num, local_graph_node_num)) - pad_normalized_feats = torch.cat([ - normalized_feats, - torch.zeros(max_graph_node_num - local_graph_node_num, - normalized_feats.shape[1]).to(node_feats.device) - ], - dim=0) - - for node in local_graph_nodes: - neighbors = sorted_complete_graph[node, - 1:self.active_connection + 1] - for neighbor in neighbors: - if neighbor in local_graph_nodes: - adjacent_matrix[node_to_graph_inx[node], - node_to_graph_inx[neighbor]] = 1 - adjacent_matrix[node_to_graph_inx[neighbor], - node_to_graph_inx[node]] = 1 - - adjacent_matrix = normalize_adjacent_matrix( - adjacent_matrix, type='DAD') - adjacent_matrix_tensor = torch.zeros( - max_graph_node_num, max_graph_node_num).to(node_feats.device) - adjacent_matrix_tensor[:local_graph_node_num, : - local_graph_node_num] = adjacent_matrix - - local_graph_tensor = torch.tensor(local_graph_nodes) - local_graph_tensor = torch.cat([ - local_graph_tensor, - torch.zeros( - max_graph_node_num - local_graph_node_num, - dtype=torch.long) - ], - dim=0) - - node_normalized_feats.append(pad_normalized_feats) - adjacent_matrix_list.append(adjacent_matrix_tensor) - pivot_graph_inx.append(pivot_node_inx) - knn_inx.append(knn_inx_in_local_graph) - local_graph_tensor_list.append(local_graph_tensor) - - node_feats_tensor = torch.stack(node_normalized_feats, 0) - adjacent_matrix_tensor = torch.stack(adjacent_matrix_list, 0) - pivot_inx_tensor = torch.stack(pivot_graph_inx, 0) - knn_inx_tensor = torch.stack(knn_inx, 0) - local_graph_node_tensor = torch.stack(local_graph_tensor_list, 0) - - return (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor, - knn_inx_tensor, local_graph_node_tensor) - - def __call__(self, preds, feat_maps): - """Generate local graphs and Graph Convolution Network input data. - - Args: - preds (tensor): The predicted maps. - feat_maps (tensor): The feature maps to extract content features of - text components. - - Returns: - node_feats_tensor (tensor): The graph nodes features. - adjacent_matrix_tensor (tensor): The adjacent matrix of graph. - pivot_inx_tensor (tensor): The pivot indices in local graph. - knn_inx_tensor (tensor): The k nearest neighbor nodes indices in - local graph. - local_graph_node_tensor (tensor): The indices of nodes in local - graph. - text_comps (ndarray): The predicted text components. - """ - - pred_text_region = torch.sigmoid(preds[0, 0]).data.cpu().numpy() - pred_center_region = torch.sigmoid(preds[0, 1]).data.cpu().numpy() - pred_sin_map = preds[0, 2].data.cpu().numpy() - pred_cos_map = preds[0, 3].data.cpu().numpy() - pred_top_radius_map = preds[0, 4].data.cpu().numpy() - pred_bot_radius_map = preds[0, 5].data.cpu().numpy() - - comp_attribs, text_comps = self.propose_comps_and_attribs( - pred_text_region, pred_center_region, pred_top_radius_map, - pred_bot_radius_map, pred_sin_map, pred_cos_map) - - if comp_attribs is None: - none_flag = True - return none_flag, (0, 0, 0, 0, 0, 0) - - comp_centers = comp_attribs[:, 0:2] - distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers) - - graph_node_geo_feats = embed_geo_feats(comp_attribs, - self.node_geo_feat_dim) - graph_node_geo_feats = torch.from_numpy( - graph_node_geo_feats).float().to(preds.device) - - batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32) - text_comps_bboxes = np.hstack( - (batch_id, comp_attribs.astype(np.float32, copy=False))) - text_comps_bboxes = torch.from_numpy(text_comps_bboxes).float().to( - preds.device) - - comp_content_feats = self.pooling(feat_maps, text_comps_bboxes) - comp_content_feats = comp_content_feats.view( - comp_content_feats.shape[0], -1).to(preds.device) - node_feats = torch.cat((comp_content_feats, graph_node_geo_feats), - dim=-1) - - dist_sort_complete_graph = np.argsort(distance_matrix, axis=1) - (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor, - knn_inx_tensor, local_graph_node_tensor) = self.generate_local_graphs( - dist_sort_complete_graph, node_feats) - - none_flag = False - return none_flag, (node_feats_tensor, adjacent_matrix_tensor, - pivot_inx_tensor, knn_inx_tensor, - local_graph_node_tensor, text_comps) diff --git a/mmocr/models/textdet/modules/utils.py b/mmocr/models/textdet/modules/utils.py deleted file mode 100644 index c1d427ef..00000000 --- a/mmocr/models/textdet/modules/utils.py +++ /dev/null @@ -1,354 +0,0 @@ -import functools -import operator -from typing import List - -import cv2 -import numpy as np -import torch -from numpy.linalg import norm - - -def normalize_adjacent_matrix(A, type='AD'): - """Normalize adjacent matrix for GCN. - - This was from repo https://github.com/GXYM/DRRG. - """ - if type == 'DAD': - # d is Degree of nodes A=A+I - # L = D^-1/2 A D^-1/2 - A = A + np.eye(A.shape[0]) # A=A+I - d = np.sum(A, axis=0) - d_inv = np.power(d, -0.5).flatten() - d_inv[np.isinf(d_inv)] = 0.0 - d_inv = np.diag(d_inv) - G = A.dot(d_inv).transpose().dot(d_inv) - G = torch.from_numpy(G) - elif type == 'AD': - A = A + np.eye(A.shape[0]) # A=A+I - A = torch.from_numpy(A) - D = A.sum(1, keepdim=True) - G = A.div(D) - else: - A = A + np.eye(A.shape[0]) # A=A+I - A = torch.from_numpy(A) - D = A.sum(1, keepdim=True) - D = np.diag(D) - G = D - A - return G - - -def euclidean_distance_matrix(A, B): - """Calculate the Euclidean distance matrix.""" - - M = A.shape[0] - N = B.shape[0] - - assert A.shape[1] == B.shape[1] - - A_dots = (A * A).sum(axis=1).reshape((M, 1)) * np.ones(shape=(1, N)) - B_dots = (B * B).sum(axis=1) * np.ones(shape=(M, 1)) - D_squared = A_dots + B_dots - 2 * A.dot(B.T) - - zero_mask = np.less(D_squared, 0.0) - D_squared[zero_mask] = 0.0 - return np.sqrt(D_squared) - - -def embed_geo_feats(geo_feats, out_dim): - """Embed geometric features of text components. This was partially adapted - from https://github.com/GXYM/DRRG. - - Args: - geo_feats (ndarray): The geometric features of text components. - out_dim (int): The output dimension. - - Returns: - embedded_feats (ndarray): The embedded geometric features. - """ - assert isinstance(out_dim, int) - assert out_dim >= geo_feats.shape[1] - comp_num = geo_feats.shape[0] - feat_dim = geo_feats.shape[1] - feat_repeat_times = out_dim // feat_dim - residue_dim = out_dim % feat_dim - - if residue_dim > 0: - embed_wave = np.array([ - np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1) - for j in range(feat_repeat_times + 1) - ]).reshape((feat_repeat_times + 1, 1, 1)) - repeat_feats = np.repeat( - np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0) - residue_feats = np.hstack([ - geo_feats[:, 0:residue_dim], - np.zeros((comp_num, feat_dim - residue_dim)) - ]) - repeat_feats = np.stack([repeat_feats, residue_feats], axis=0) - embedded_feats = repeat_feats / embed_wave - embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2]) - embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2]) - embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape( - (comp_num, -1))[:, 0:out_dim] - else: - embed_wave = np.array([ - np.power(1000, 2.0 * (j // 2) / feat_repeat_times) - for j in range(feat_repeat_times) - ]).reshape((feat_repeat_times, 1, 1)) - repeat_feats = np.repeat( - np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0) - embedded_feats = repeat_feats / embed_wave - embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2]) - embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2]) - embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape( - (comp_num, -1)) - - return embedded_feats - - -def min_connect_path(list_all: List[list]): - """This is from https://github.com/GXYM/DRRG.""" - - list_nodo = list_all.copy() - res: List[List[int]] = [] - ept = [0, 0] - - def norm2(a, b): - return ((a[0] - b[0])**2 + (a[1] - b[1])**2)**0.5 - - dict00 = {} - dict11 = {} - ept[0] = list_nodo[0] - ept[1] = list_nodo[0] - list_nodo.remove(list_nodo[0]) - while list_nodo: - for i in list_nodo: - length0 = norm2(i, ept[0]) - dict00[length0] = [i, ept[0]] - length1 = norm2(ept[1], i) - dict11[length1] = [ept[1], i] - key0 = min(dict00.keys()) - key1 = min(dict11.keys()) - - if key0 <= key1: - ss = dict00[key0][0] - ee = dict00[key0][1] - res.insert(0, [list_all.index(ss), list_all.index(ee)]) - list_nodo.remove(ss) - ept[0] = ss - else: - ss = dict11[key1][0] - ee = dict11[key1][1] - res.append([list_all.index(ss), list_all.index(ee)]) - list_nodo.remove(ee) - ept[1] = ee - - dict00 = {} - dict11 = {} - - path = functools.reduce(operator.concat, res) - path = sorted(set(path), key=path.index) - - return res, path - - -def clusters2labels(clusters, node_num): - """This is from https://github.com/GXYM/DRRG.""" - labels = (-1) * np.ones((node_num, )) - for cluster_inx, cluster in enumerate(clusters): - for node in cluster: - labels[node.inx] = cluster_inx - assert np.sum(labels < 0) < 1 - return labels - - -def remove_single(text_comps, pred): - """Remove isolated single text components. - - This is from https://github.com/GXYM/DRRG. - """ - single_flags = np.zeros_like(pred) - pred_labels = np.unique(pred) - for label in pred_labels: - current_label_flag = pred == label - if np.sum(current_label_flag) == 1: - single_flags[np.where(current_label_flag)[0][0]] = 1 - remain_inx = [i for i in range(len(pred)) if not single_flags[i]] - remain_inx = np.asarray(remain_inx) - return text_comps[remain_inx, :], pred[remain_inx] - - -class Node: - - def __init__(self, inx): - self.__inx = inx - self.__links = set() - - @property - def inx(self): - return self.__inx - - @property - def links(self): - return set(self.__links) - - def add_link(self, other, score): - self.__links.add(other) - other.__links.add(self) - - -def connected_components(nodes, score_dict, thr): - """Connected components searching. - - This is from https://github.com/GXYM/DRRG. - """ - - result = [] - nodes = set(nodes) - while nodes: - node = nodes.pop() - group = {node} - queue = [node] - while queue: - node = queue.pop(0) - if thr is not None: - neighbors = { - linked_neighbor - for linked_neighbor in node.links if score_dict[tuple( - sorted([node.inx, linked_neighbor.inx]))] >= thr - } - else: - neighbors = node.links - neighbors.difference_update(group) - nodes.difference_update(neighbors) - group.update(neighbors) - queue.extend(neighbors) - result.append(group) - return result - - -def graph_propagation(edges, - scores, - link_thr, - bboxes=None, - dis_thr=50, - pool='avg'): - """Propagate graph linkage score information. - - This is from repo https://github.com/GXYM/DRRG. - """ - edges = np.sort(edges, axis=1) - - score_dict = {} - if pool is None: - for i, edge in enumerate(edges): - score_dict[edge[0], edge[1]] = scores[i] - elif pool == 'avg': - for i, edge in enumerate(edges): - if bboxes is not None: - box1 = bboxes[edge[0]][:8].reshape(4, 2) - box2 = bboxes[edge[1]][:8].reshape(4, 2) - center1 = np.mean(box1, axis=0) - center2 = np.mean(box2, axis=0) - dst = norm(center1 - center2) - if dst > dis_thr: - scores[i] = 0 - if (edge[0], edge[1]) in score_dict: - score_dict[edge[0], edge[1]] = 0.5 * ( - score_dict[edge[0], edge[1]] + scores[i]) - else: - score_dict[edge[0], edge[1]] = scores[i] - - elif pool == 'max': - for i, edge in enumerate(edges): - if (edge[0], edge[1]) in score_dict: - score_dict[edge[0], - edge[1]] = max(score_dict[edge[0], edge[1]], - scores[i]) - else: - score_dict[edge[0], edge[1]] = scores[i] - else: - raise ValueError('Pooling operation not supported') - - nodes = np.sort(np.unique(edges.flatten())) - mapping = -1 * np.ones((nodes.max() + 1), dtype=np.int) - mapping[nodes] = np.arange(nodes.shape[0]) - link_inx = mapping[edges] - vertex = [Node(node) for node in nodes] - for link, score in zip(link_inx, scores): - vertex[link[0]].add_link(vertex[link[1]], score) - - clusters = connected_components(vertex, score_dict, link_thr) - - return clusters - - -def in_contour(cont, point): - x, y = point - return cv2.pointPolygonTest(cont, (x, y), False) > 0 - - -def select_edge(cont, box): - """This is from repo https://github.com/GXYM/DRRG.""" - cont = np.array(cont) - box = box.astype(np.int32) - c1 = np.array(0.5 * (box[0, :] + box[3, :]), dtype=np.int) - c2 = np.array(0.5 * (box[1, :] + box[2, :]), dtype=np.int) - - if not in_contour(cont, c1): - return [box[0, :].tolist(), box[3, :].tolist()] - elif not in_contour(cont, c2): - return [box[1, :].tolist(), box[2, :].tolist()] - else: - return None - - -def comps2boundary(text_comps, final_pred): - """Propose text components and generate local graphs. - - This is from repo https://github.com/GXYM/DRRG. - """ - bbox_contours = list() - for inx in range(0, int(np.max(final_pred)) + 1): - current_instance = np.where(final_pred == inx) - boxes = text_comps[current_instance, :8].reshape( - (-1, 4, 2)).astype(np.int32) - - boundary_point = None - if boxes.shape[0] > 1: - centers = np.mean(boxes, axis=1).astype(np.int32).tolist() - paths, routes_path = min_connect_path(centers) - boxes = boxes[routes_path] - top = np.mean(boxes[:, 0:2, :], axis=1).astype(np.int32).tolist() - bot = np.mean(boxes[:, 2:4, :], axis=1).astype(np.int32).tolist() - edge1 = select_edge(top + bot[::-1], boxes[0]) - edge2 = select_edge(top + bot[::-1], boxes[-1]) - if edge1 is not None: - top.insert(0, edge1[0]) - bot.insert(0, edge1[1]) - if edge2 is not None: - top.append(edge2[0]) - bot.append(edge2[1]) - boundary_point = np.array(top + bot[::-1]) - - elif boxes.shape[0] == 1: - top = boxes[0, 0:2, :].astype(np.int32).tolist() - bot = boxes[0, 2:4:-1, :].astype(np.int32).tolist() - boundary_point = np.array(top + bot) - - if boundary_point is None: - continue - - boundary_point = [p for p in boundary_point.flatten().tolist()] - bbox_contours.append(boundary_point) - - return bbox_contours - - -def merge_text_comps(edges, scores, text_comps, link_thr): - """Merge text components into text instance.""" - clusters = graph_propagation(edges, scores, link_thr) - pred_labels = clusters2labels(clusters, text_comps.shape[0]) - text_comps, final_pred = remove_single(text_comps, pred_labels) - boundaries = comps2boundary(text_comps, final_pred) - - return boundaries diff --git a/mmocr/models/textrecog/losses/__init__.py b/mmocr/models/textrecog/losses/__init__.py index 4b5a24a2..226aa006 100755 --- a/mmocr/models/textrecog/losses/__init__.py +++ b/mmocr/models/textrecog/losses/__init__.py @@ -1,5 +1,5 @@ from .ce_loss import CELoss, SARLoss, TFLoss from .ctc_loss import CTCLoss -from .seg_loss import CAFCNLoss, SegLoss +from .seg_loss import SegLoss -__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'CAFCNLoss'] +__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss'] diff --git a/mmocr/models/textrecog/losses/seg_loss.py b/mmocr/models/textrecog/losses/seg_loss.py index 0da9e61e..9a2911e8 100644 --- a/mmocr/models/textrecog/losses/seg_loss.py +++ b/mmocr/models/textrecog/losses/seg_loss.py @@ -3,7 +3,6 @@ import torch.nn as nn import torch.nn.functional as F from mmdet.models.builder import LOSSES -from mmocr.models.common.losses import DiceLoss @LOSSES.register_module() @@ -68,109 +67,3 @@ class SegLoss(nn.Module): losses['loss_seg'] = loss_seg return losses - - -@LOSSES.register_module() -class CAFCNLoss(SegLoss): - """Implementation of loss module in `CA-FCN. - - `_ - - Args: - alpha (float): Weight ratio of attention loss. - attn_s2_downsample_ratio (float): Downsample ratio - of attention map from output stage 2. - attn_s3_downsample_ratio (float): Downsample ratio - of attention map from output stage 3. - seg_downsample_ratio (float): Downsample ratio of - segmentation map. - attn_with_dice_loss (bool): If True, use dice_loss for attention, - else BCELoss. - with_attn (bool): If True, include attention loss, else - segmentation loss only. - seg_with_loss_weight (bool): If True, set weight for - segmentation loss. - ignore_index (int): Specifies a target value that is ignored - and does not contribute to the input gradient. - """ - - def __init__(self, - alpha=1.0, - attn_s2_downsample_ratio=0.25, - attn_s3_downsample_ratio=0.125, - seg_downsample_ratio=0.5, - attn_with_dice_loss=False, - with_attn=True, - seg_with_loss_weight=True, - ignore_index=255): - super().__init__(seg_downsample_ratio, seg_with_loss_weight, - ignore_index) - assert isinstance(alpha, (int, float)) - assert isinstance(attn_s2_downsample_ratio, (int, float)) - assert isinstance(attn_s3_downsample_ratio, (int, float)) - assert 0 < attn_s2_downsample_ratio <= 1 - assert 0 < attn_s3_downsample_ratio <= 1 - - self.alpha = alpha - self.attn_s2_downsample_ratio = attn_s2_downsample_ratio - self.attn_s3_downsample_ratio = attn_s3_downsample_ratio - self.with_attn = with_attn - self.attn_with_dice_loss = attn_with_dice_loss - - # attention loss - if with_attn: - if attn_with_dice_loss: - self.criterion_attn = DiceLoss() - else: - self.criterion_attn = nn.BCELoss(reduction='none') - - def attn_loss(self, out_neck, gt_kernels): - attn_map_s2 = out_neck[0] # bsz * 2 * H/4 * W/4 - - mask_s2 = torch.stack([ - item[2].rescale(self.attn_s2_downsample_ratio).to_tensor( - torch.float, attn_map_s2.device) for item in gt_kernels - ]) - - attn_target_s2 = torch.stack([ - item[0].rescale(self.attn_s2_downsample_ratio).to_tensor( - torch.float, attn_map_s2.device) for item in gt_kernels - ]) - - mask_s3 = torch.stack([ - item[2].rescale(self.attn_s3_downsample_ratio).to_tensor( - torch.float, attn_map_s2.device) for item in gt_kernels - ]) - - attn_target_s3 = torch.stack([ - item[0].rescale(self.attn_s3_downsample_ratio).to_tensor( - torch.float, attn_map_s2.device) for item in gt_kernels - ]) - - targets = [ - attn_target_s2, attn_target_s3, attn_target_s3, attn_target_s3 - ] - - masks = [mask_s2, mask_s3, mask_s3, mask_s3] - - loss_attn = 0. - for i in range(len(out_neck) - 1): - pred = out_neck[i] - if self.attn_with_dice_loss: - loss_attn += self.criterion_attn(pred, targets[i], masks[i]) - else: - loss_attn += torch.sum( - self.criterion_attn(pred, targets[i]) * - masks[i]) / torch.sum(masks[i]) - - return loss_attn - - def forward(self, out_neck, out_head, gt_kernels): - - losses = super().forward(out_neck, out_head, gt_kernels) - - if self.with_attn: - loss_attn = self.attn_loss(out_neck, gt_kernels) - losses['loss_attn'] = loss_attn - - return losses diff --git a/mmocr/models/textrecog/necks/__init__.py b/mmocr/models/textrecog/necks/__init__.py index c10a46a5..71ceadc1 100755 --- a/mmocr/models/textrecog/necks/__init__.py +++ b/mmocr/models/textrecog/necks/__init__.py @@ -1,5 +1,3 @@ -from .cafcn_neck import CAFCNNeck from .fpn_ocr import FPNOCR -from .fpn_seg import FPNSeg -__all__ = ['CAFCNNeck', 'FPNSeg', 'FPNOCR'] +__all__ = ['FPNOCR'] diff --git a/mmocr/models/textrecog/necks/cafcn_neck.py b/mmocr/models/textrecog/necks/cafcn_neck.py deleted file mode 100644 index b5fca8d2..00000000 --- a/mmocr/models/textrecog/necks/cafcn_neck.py +++ /dev/null @@ -1,223 +0,0 @@ -import torch -import torch.nn.functional as F -from mmcv.cnn import ConvModule -from mmcv.ops import DeformConv2dPack -from torch import nn - -from mmdet.models.builder import NECKS - - -class CharAttn(nn.Module): - """Implementation of Character attention module in `CA-FCN. - - `_ - """ - - def __init__(self, in_channels=128, out_channels=128, deformable=False): - super().__init__() - assert isinstance(in_channels, int) - assert isinstance(deformable, bool) - - self.in_channels = in_channels - self.out_channels = out_channels - self.deformable = deformable - - # attention layers - self.attn_layer = nn.Sequential( - ConvModule( - in_channels, - in_channels, - 3, - stride=1, - padding=1, - norm_cfg=dict(type='BN')), - ConvModule( - in_channels, - 1, - 3, - stride=1, - padding=1, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='Sigmoid'))) - - conv_kwargs = dict( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=(3, 3), - stride=1, - padding=1) - if self.deformable: - self.conv = DeformConv2dPack(**conv_kwargs) - else: - self.conv = nn.Conv2d(**conv_kwargs) - self.bn = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - - def forward(self, in_feat): - # Calculate attn map - attn_map = self.attn_layer(in_feat) # N * 1 * H * W - - in_feat = self.relu(self.bn(self.conv(in_feat))) - - out_feat_map = self._upsample_mul(in_feat, 1 + attn_map) - - return out_feat_map, attn_map - - def _upsample_add(self, x, y): - return F.interpolate(x, size=y.size()[2:]) + y - - def _upsample_mul(self, x, y): - return F.interpolate(x, size=y.size()[2:]) * y - - -class FeatGenerator(nn.Module): - """Generate attention-augmented stage feature from backbone stage - feature.""" - - def __init__(self, - in_channels=512, - out_channels=128, - deformable=True, - concat=False, - upsample=False, - with_attn=True): - super().__init__() - - self.concat = concat - self.upsample = upsample - self.with_attn = with_attn - - if with_attn: - self.char_attn = CharAttn( - in_channels=in_channels, - out_channels=out_channels, - deformable=deformable) - else: - self.char_attn = ConvModule( - in_channels, - out_channels, - 3, - stride=1, - padding=1, - norm_cfg=dict(type='BN')) - - if concat: - self.conv_to_concat = ConvModule( - out_channels, - out_channels, - 3, - stride=1, - padding=1, - norm_cfg=dict(type='BN')) - - kernel_size = (3, 1) if deformable else 3 - padding = (1, 0) if deformable else 1 - tmp_in_channels = out_channels * 2 if concat else out_channels - - self.conv_after_concat = ConvModule( - tmp_in_channels, - out_channels, - kernel_size, - stride=1, - padding=padding, - norm_cfg=dict(type='BN')) - - def forward(self, x, y=None, size=None): - if self.with_attn: - feat_map, attn_map = self.char_attn(x) - else: - feat_map = self.char_attn(x) - attn_map = feat_map - - if self.concat: - y = self.conv_to_concat(y) - feat_map = torch.cat((y, feat_map), dim=1) - - feat_map = self.conv_after_concat(feat_map) - - if self.upsample: - feat_map = F.interpolate(feat_map, size) - - return attn_map, feat_map - - -@NECKS.register_module() -class CAFCNNeck(nn.Module): - """Implementation of neck module in `CA-FCN. - - `_ - - Args: - in_channels (list[int]): Number of input channels for each scale. - out_channels (int): Number of output channels for each scale. - deformable (bool): If True, use deformable conv. - with_attn (bool): If True, add attention for each output feature map. - """ - - def __init__(self, - in_channels=[128, 256, 512, 512], - out_channels=128, - deformable=True, - with_attn=True): - super().__init__() - - self.deformable = deformable - self.with_attn = with_attn - - # stage_in5_to_out5 - self.s5 = FeatGenerator( - in_channels=in_channels[-1], - out_channels=out_channels, - deformable=deformable, - concat=False, - with_attn=with_attn) - - # stage_in4_to_out4 - self.s4 = FeatGenerator( - in_channels=in_channels[-2], - out_channels=out_channels, - deformable=deformable, - concat=True, - with_attn=with_attn) - - # stage_in3_to_out3 - self.s3 = FeatGenerator( - in_channels=in_channels[-3], - out_channels=out_channels, - deformable=False, - concat=True, - upsample=True, - with_attn=with_attn) - - # stage_in2_to_out2 - self.s2 = FeatGenerator( - in_channels=in_channels[-4], - out_channels=out_channels, - deformable=False, - concat=True, - upsample=True, - with_attn=with_attn) - - def init_weights(self): - pass - - def forward(self, feats): - in_stage1 = feats[0] - in_stage2, in_stage3 = feats[1], feats[2] - in_stage4, in_stage5 = feats[3], feats[4] - # out stage 5 - out_s5_attn_map, out_s5 = self.s5(in_stage5) - - # out stage 4 - out_s4_attn_map, out_s4 = self.s4(in_stage4, out_s5) - - # out stage 3 - out_s3_attn_map, out_s3 = self.s3(in_stage3, out_s4, - in_stage2.size()[2:]) - - # out stage 2 - out_s2_attn_map, out_s2 = self.s2(in_stage2, out_s3, - in_stage1.size()[2:]) - - return (out_s2_attn_map, out_s3_attn_map, out_s4_attn_map, - out_s5_attn_map, out_s2) diff --git a/mmocr/models/textrecog/necks/fpn_seg.py b/mmocr/models/textrecog/necks/fpn_seg.py deleted file mode 100644 index 997951e4..00000000 --- a/mmocr/models/textrecog/necks/fpn_seg.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch.nn.functional as F -from mmcv.runner import auto_fp16 - -from mmdet.models.builder import NECKS -from mmdet.models.necks import FPN - - -@NECKS.register_module() -class FPNSeg(FPN): - """Feature Pyramid Network for segmentation based text recognition. - - Args: - in_channels (list[int]): Number of input channels for each scale. - out_channels (int): Number of output channels for each scale. - num_outs (int): Number of output scales. - upsample_param (dict | None): Config dict for interpolate layer. - Default: `dict(scale_factor=1.0, mode='nearest')` - last_stage_only (bool): If True, output last stage of FPN only. - """ - - def __init__(self, - in_channels, - out_channels, - num_outs, - upsample_param=None, - last_stage_only=True): - super().__init__(in_channels, out_channels, num_outs) - self.upsample_param = upsample_param - self.last_stage_only = last_stage_only - - @auto_fp16() - def forward(self, inputs): - outs = super().forward(inputs) - - outs = list(outs) - - if self.upsample_param is not None: - outs[0] = F.interpolate(outs[0], **self.upsample_param) - - if self.last_stage_only: - return tuple(outs[0:1]) - - return tuple(outs[::-1]) diff --git a/mmocr/models/textrecog/recognizer/__init__.py b/mmocr/models/textrecog/recognizer/__init__.py index 379cef57..91af5666 100644 --- a/mmocr/models/textrecog/recognizer/__init__.py +++ b/mmocr/models/textrecog/recognizer/__init__.py @@ -1,5 +1,4 @@ from .base import BaseRecognizer -from .cafcn import CAFCNNet from .crnn import CRNNNet from .encode_decode_recognizer import EncodeDecodeRecognizer from .nrtr import NRTR @@ -9,5 +8,5 @@ from .seg_recognizer import SegRecognizer __all__ = [ 'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR', - 'SegRecognizer', 'RobustScanner', 'CAFCNNet' + 'SegRecognizer', 'RobustScanner' ] diff --git a/mmocr/models/textrecog/recognizer/cafcn.py b/mmocr/models/textrecog/recognizer/cafcn.py deleted file mode 100644 index 6acade5e..00000000 --- a/mmocr/models/textrecog/recognizer/cafcn.py +++ /dev/null @@ -1,7 +0,0 @@ -from mmdet.models.builder import DETECTORS -from .seg_recognizer import SegRecognizer - - -@DETECTORS.register_module() -class CAFCNNet(SegRecognizer): - """Implementation of `CAFCN `_""" diff --git a/tests/test_models/test_ocr_loss.py b/tests/test_models/test_ocr_loss.py index fa118f5e..6dad9e2d 100644 --- a/tests/test_models/test_ocr_loss.py +++ b/tests/test_models/test_ocr_loss.py @@ -1,11 +1,8 @@ -import numpy as np import pytest import torch -from mmdet.core import BitmapMasks from mmocr.models.common.losses import DiceLoss -from mmocr.models.textrecog.losses import (CAFCNLoss, CELoss, CTCLoss, SARLoss, - TFLoss) +from mmocr.models.textrecog.losses import CELoss, CTCLoss, SARLoss, TFLoss def test_ctc_loss(): @@ -69,44 +66,6 @@ def test_tf_loss(): assert new_target.shape == torch.Size([1, 9]) -def test_cafcn_loss(): - with pytest.raises(AssertionError): - CAFCNLoss(alpha='1') - with pytest.raises(AssertionError): - CAFCNLoss(attn_s2_downsample_ratio='2') - with pytest.raises(AssertionError): - CAFCNLoss(attn_s3_downsample_ratio='1.5') - with pytest.raises(AssertionError): - CAFCNLoss(seg_downsample_ratio='1.5') - with pytest.raises(AssertionError): - CAFCNLoss(attn_s2_downsample_ratio=2) - with pytest.raises(AssertionError): - CAFCNLoss(attn_s3_downsample_ratio=1.5) - with pytest.raises(AssertionError): - CAFCNLoss(seg_downsample_ratio=1.5) - - bsz = 1 - H = W = 64 - out_neck = (torch.ones(bsz, 1, H // 4, W // 4) * 0.5, - torch.ones(bsz, 1, H // 8, W // 8) * 0.5, - torch.ones(bsz, 1, H // 8, W // 8) * 0.5, - torch.ones(bsz, 1, H // 8, W // 8) * 0.5, - torch.ones(bsz, 1, H // 2, W // 2) * 0.5) - out_head = torch.rand(bsz, 37, H // 2, W // 2) - - attn_tgt = np.zeros((H, W), dtype=np.float32) - segm_tgt = np.zeros((H, W), dtype=np.float32) - mask = np.ones((H, W), dtype=np.float32) - gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], H, W) - - cafcn_loss = CAFCNLoss() - losses = cafcn_loss(out_neck, out_head, [gt_kernels]) - assert isinstance(losses, dict) - assert 'loss_seg' in losses - assert torch.allclose(losses['loss_seg'], - torch.tensor(losses['loss_seg'].item()).float()) - - def test_dice_loss(): with pytest.raises(AssertionError): DiceLoss(eps='1') diff --git a/tests/test_models/test_ocr_neck.py b/tests/test_models/test_ocr_neck.py deleted file mode 100644 index c37cc71a..00000000 --- a/tests/test_models/test_ocr_neck.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -import torch - -from mmocr.models.textrecog.necks.cafcn_neck import (CAFCNNeck, CharAttn, - FeatGenerator) - - -def test_char_attn(): - with pytest.raises(AssertionError): - CharAttn(in_channels=5.0) - with pytest.raises(AssertionError): - CharAttn(deformable='deformabel') - - in_feat = torch.rand(1, 128, 32, 32) - char_attn = CharAttn() - out_feat_map, attn_map = char_attn(in_feat) - assert attn_map.shape == torch.Size([1, 1, 32, 32]) - assert out_feat_map.shape == torch.Size([1, 128, 32, 32]) - - -def test_feat_generator(): - in_feat = torch.rand(1, 128, 32, 32) - feat_generator = FeatGenerator( - in_channels=128, out_channels=128, deformable=False) - - attn_map, feat_map = feat_generator(in_feat) - assert attn_map.shape == torch.Size([1, 1, 32, 32]) - assert feat_map.shape == torch.Size([1, 128, 32, 32]) - - -def test_cafcn_neck(): - in_s1 = torch.rand(1, 64, 64, 64) - in_s2 = torch.rand(1, 128, 32, 32) - in_s3 = torch.rand(1, 256, 16, 16) - in_s4 = torch.rand(1, 512, 16, 16) - in_s5 = torch.rand(1, 512, 16, 16) - - cafcn_neck = CAFCNNeck(deformable=False) - cafcn_neck.init_weights() - cafcn_neck.train() - - out_neck = cafcn_neck((in_s1, in_s2, in_s3, in_s4, in_s5)) - assert out_neck[0].shape == torch.Size([1, 1, 32, 32]) - assert out_neck[1].shape == torch.Size([1, 1, 16, 16]) - assert out_neck[2].shape == torch.Size([1, 1, 16, 16]) - assert out_neck[3].shape == torch.Size([1, 1, 16, 16]) - assert out_neck[4].shape == torch.Size([1, 128, 64, 64])