fix #37: remove useless code (#38)

pull/2/head
Hongbin Sun 2021-04-06 11:40:48 +08:00 committed by GitHub
parent 03720f46c3
commit aa87b69f12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 8 additions and 2087 deletions

View File

@ -1,6 +1,5 @@
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .modules import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .postprocess import * # noqa: F401,F403

View File

@ -1,10 +1,7 @@
from .db_head import DBHead
from .drrg_head import DRRGHead
from .head_mixin import HeadMixin
from .pan_head import PANHead
from .pse_head import PSEHead
from .textsnake_head import TextSnakeHead
__all__ = [
'PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'DRRGHead', 'TextSnakeHead'
]
__all__ = ['PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'TextSnakeHead']

View File

@ -1,193 +0,0 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.models.builder import HEADS, build_loss
from mmocr.models.textdet.modules import (GCN, LocalGraphs,
ProposalLocalGraphs,
merge_text_comps)
from .head_mixin import HeadMixin
@HEADS.register_module()
class DRRGHead(HeadMixin, nn.Module):
    """The class for DRRG head: Deep Relational Reasoning Graph Network for
    Arbitrary Shape Text Detection.

    [https://arxiv.org/abs/2003.07493]

    Args:
        in_channels (int): The number of channels of the input feature maps.
        k_at_hops (tuple(int)): The number of i-hop neighbors,
            i = 1, 2, ..., h.
        active_connection (int): The number of two hop neighbors deem as
            linked to a pivot.
        node_geo_feat_dim (int): The dimension of embedded geometric features
            of a component.
        pooling_scale (float): The spatial scale of RRoI-Aligning.
        pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
        graph_filter_thr (float): The threshold to filter identical local
            graphs.
        comp_shrink_ratio (float): The shrink ratio of text components.
        nms_thr (float): The locality-aware NMS threshold.
        min_width (float): The minimum width of text components.
        max_width (float): The maximum width of text components.
        comp_ratio (float): The reciprocal of aspect ratio of text components.
        text_region_thr (float): The threshold for text region probability map.
        center_region_thr (float): The threshold of text center region
            probability map.
        center_region_area_thr (int): The threshold of filtering small-size
            text center region.
        link_thr (float): The threshold for connected components searching.
        loss (dict): The config handed to ``build_loss``.
        train_cfg: Stored on the head unchanged.
        test_cfg: Stored on the head unchanged.
    """

    def __init__(self,
                 in_channels,
                 k_at_hops=(8, 4),
                 active_connection=3,
                 node_geo_feat_dim=120,
                 pooling_scale=1.0,
                 pooling_output_size=(3, 4),
                 graph_filter_thr=0.75,
                 comp_shrink_ratio=0.95,
                 nms_thr=0.25,
                 min_width=8.0,
                 max_width=24.0,
                 comp_ratio=0.65,
                 text_region_thr=0.6,
                 center_region_thr=0.4,
                 center_region_area_thr=100,
                 link_thr=0.85,
                 loss=dict(type='DRRGLoss'),
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        assert isinstance(in_channels, int)
        assert isinstance(k_at_hops, tuple)
        assert isinstance(active_connection, int)
        assert isinstance(node_geo_feat_dim, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(graph_filter_thr, float)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_ratio, float)
        assert isinstance(center_region_area_thr, int)
        assert isinstance(link_thr, float)

        self.in_channels = in_channels
        # One output channel per predicted map; DRRGLoss reads channels 0-5.
        self.out_channels = 6
        self.downsample_ratio = 1.0
        self.k_at_hops = k_at_hops
        self.active_connection = active_connection
        self.node_geo_feat_dim = node_geo_feat_dim
        self.pooling_scale = pooling_scale
        self.pooling_output_size = pooling_output_size
        self.graph_filter_thr = graph_filter_thr
        self.comp_shrink_ratio = comp_shrink_ratio
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_ratio = comp_ratio
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr
        self.link_thr = link_thr
        self.loss_module = build_loss(loss)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.out_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)
        self.init_weights()

        # Graph builder used in forward() (training, needs ground truth).
        self.graph_train = LocalGraphs(self.k_at_hops, self.active_connection,
                                       self.node_geo_feat_dim,
                                       self.pooling_scale,
                                       self.pooling_output_size,
                                       self.graph_filter_thr)

        # Graph builder used in single_test() (works from predictions only).
        self.graph_test = ProposalLocalGraphs(
            self.k_at_hops, self.active_connection, self.node_geo_feat_dim,
            self.pooling_scale, self.pooling_output_size, self.nms_thr,
            self.min_width, self.max_width, self.comp_shrink_ratio,
            self.comp_ratio, self.text_region_thr, self.center_region_thr,
            self.center_region_area_thr)

        # Node feature = pooled (input + predicted) channels, flattened, plus
        # the embedded geometric features.
        pool_w, pool_h = self.pooling_output_size
        gcn_in_dim = (pool_w * pool_h) * (
            self.in_channels + self.out_channels) + self.node_geo_feat_dim
        self.gcn = GCN(gcn_in_dim, 32)

    def init_weights(self):
        """Initialize the 1x1 output convolution with N(0, 0.01) weights."""
        normal_init(self.out_conv, mean=0, std=0.01)

    def forward(self, inputs, text_comp_feats):
        """Run the head in training mode.

        Args:
            inputs (Tensor): The feature maps fed to the output convolution.
            text_comp_feats: The ground-truth text component attributes; they
                are converted with ``np.array`` before graph generation.

        Returns:
            tuple: ``(pred_maps, (gcn_pred, gt_labels))`` where ``pred_maps``
            is the output of the 1x1 convolution and the inner tuple holds
            the GCN linkage prediction with its supervision labels.
        """
        pred_maps = self.out_conv(inputs)
        feat_maps = torch.cat([inputs, pred_maps], dim=1)
        node_feats, adjacent_matrices, knn_inx, gt_labels = self.graph_train(
            feat_maps, np.array(text_comp_feats))

        gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inx)

        return (pred_maps, (gcn_pred, gt_labels))

    def single_test(self, feat_maps):
        """Run the head in inference mode.

        Args:
            feat_maps (Tensor): The feature maps fed to the output
                convolution.

        Returns:
            tuple: ``(edges, scores, text_comps)``. ``edges`` and ``scores``
            are ndarrays describing pivot-neighbor pairs and their linkage
            probabilities. All three are ``None`` when the graph proposal
            step reports that nothing was found.
        """
        pred_maps = self.out_conv(feat_maps)
        # Concatenate the full batch tensor exactly as forward() does.
        # Indexing ``feat_maps[0]`` would drop a dimension and make the
        # concatenation with ``pred_maps`` fail.
        feat_maps = torch.cat([feat_maps, pred_maps], dim=1)

        none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
        # Bail out before touching graph_data: its contents are only
        # meaningful when components were proposed.
        if none_flag:
            return None, None, None

        (node_feats, adjacent_matrix, pivot_inx, knn_inx, local_graph_nodes,
         text_comps) = graph_data

        device = feat_maps.device
        adjacent_matrix = adjacent_matrix.to(device)
        pivot_inx = pivot_inx.to(device)
        knn_inx = knn_inx.to(device)

        # The GCN is registered as ``self.gcn`` in __init__.
        gcn_pred = self.gcn(node_feats, adjacent_matrix, knn_inx)

        pred_labels = F.softmax(gcn_pred, dim=1)

        edges = []
        scores = []
        local_graph_nodes = local_graph_nodes.long().squeeze().cpu().numpy()
        graph_num = node_feats.size(0)
        neighbor_num = knn_inx.shape[1]

        for graph_inx in range(graph_num):
            pivot = pivot_inx[graph_inx].int().item()
            nodes = local_graph_nodes[graph_inx]
            for neighbor_inx, neighbor in enumerate(knn_inx[graph_inx]):
                neighbor = neighbor.item()
                edges.append([nodes[pivot], nodes[neighbor]])
                # GCN rows are laid out graph-major, neighbor-minor;
                # column 1 is the "linked" class probability.
                scores.append(pred_labels[graph_inx * neighbor_num +
                                          neighbor_inx, 1].item())

        edges = np.asarray(edges)
        scores = np.asarray(scores)

        return edges, scores, text_comps

    def get_boundary(self, edges, scores, text_comps):
        """Merge linked text components into boundaries.

        Args:
            edges (ndarray | None): The pivot-neighbor pairs from
                ``single_test``.
            scores (ndarray | None): The linkage score of each edge.
            text_comps: The proposed text components.

        Returns:
            dict: ``dict(boundary_result=boundaries)``; the list is empty
            when ``edges`` is None.
        """
        boundaries = []
        if edges is not None:
            boundaries = merge_text_comps(edges, scores, text_comps,
                                          self.link_thr)

        results = dict(boundary_result=boundaries)

        return results

View File

@ -4,10 +4,9 @@ from .dbnet import DBNet # isort:skip
from .ocr_mask_rcnn import OCRMaskRCNN # isort:skip
from .panet import PANet # isort:skip
from .psenet import PSENet # isort:skip
from .drrg import DRRG # isort:skip
from .textsnake import TextSnake # isort:skip
__all__ = [
'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
'PANet', 'PSENet', 'DRRG', 'TextSnake'
'PANet', 'PSENet', 'TextSnake'
]

View File

@ -1,38 +0,0 @@
from mmdet.models.builder import DETECTORS
from . import SingleStageTextDetector, TextDetectorMixin
@DETECTORS.register_module()
class DRRG(TextDetectorMixin, SingleStageTextDetector):
    """The class for implementing DRRG text detector: Deep Relational Reasoning
    Graph Network for Arbitrary Shape Text Detection.

    [https://arxiv.org/abs/2003.07493]
    """

    def forward_train(self, img, img_metas, **kwargs):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details of the values of these keys, see
                :class:`mmdet.datasets.pipelines.Collect`.
            **kwargs: Must contain 'gt_comp_attribs', which is popped and
                passed to the head; the remaining entries are forwarded to
                the head's ``loss``.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img)
        gt_comp_attribs = kwargs.pop('gt_comp_attribs')
        preds = self.bbox_head(x, gt_comp_attribs)
        losses = self.bbox_head.loss(preds, **kwargs)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test without augmentation.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): A list of image info dicts. Accepted for
                interface compatibility; not used here.
            rescale (bool): Accepted for interface compatibility; not used
                here.
                NOTE(review): the head's ``get_boundary(edges, scores,
                text_comps)`` takes no rescale information, so boundaries
                are returned as produced by the head - confirm whether
                rescaling must be added there.

        Returns:
            list[dict]: A single-element list holding the boundary results.
        """
        x = self.extract_feat(img)
        # DRRGHead.single_test takes only the feature maps and
        # DRRGHead.get_boundary only (edges, scores, text_comps); passing
        # extra positional arguments raised TypeError.
        outs = self.bbox_head.single_test(x)
        boundaries = self.bbox_head.get_boundary(*outs)

        return [boundaries]

View File

@ -1,7 +1,6 @@
from .db_loss import DBLoss
from .drrg_loss import DRRGLoss
from .pan_loss import PANLoss
from .pse_loss import PSELoss
from .textsnake_loss import TextSnakeLoss
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'DRRGLoss', 'TextSnakeLoss']
__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss']

View File

@ -1,206 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn
from mmdet.core import BitmapMasks
from mmdet.models.builder import LOSSES
from mmocr.utils import check_argument
@LOSSES.register_module()
class DRRGLoss(nn.Module):
    """The class for implementing DRRG loss: Deep Relational Reasoning Graph
    Network for Arbitrary Shape Text Detection.

    [https://arxiv.org/abs/2003.07493] This is partially adapted from
    https://github.com/GXYM/DRRG.
    """

    def __init__(self, ohem_ratio=3.0):
        """Initialization.

        Args:
            ohem_ratio (float): The negative/positive ratio in OHEM.
        """
        super().__init__()
        self.ohem_ratio = ohem_ratio

    def balance_bce_loss(self, pred, gt, mask):
        """Binary cross entropy with hard negative mining.

        Args:
            pred (Tensor): The predicted probabilities.
            gt (Tensor): The binary target, same shape as ``pred``.
            mask (Tensor): The binary validity mask, same shape as ``pred``.

        Returns:
            Tensor: The summed positive loss plus the top-k negative losses,
            divided by the number of selected samples.
        """
        assert pred.shape == gt.shape == mask.shape
        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.float().sum())
        gt = gt.float()

        loss = F.binary_cross_entropy(pred, gt, reduction='none')
        negative_loss = loss * negative.float()
        if positive_count > 0:
            positive_loss = torch.sum(loss * positive.float())
            negative_count = min(
                int(negative.float().sum()),
                int(positive_count * self.ohem_ratio))
        else:
            # Keep the placeholder on the prediction's device so every
            # returned loss lives on the same device.
            positive_loss = torch.tensor(0.0, device=pred.device)
            # Without positives fall back to 100 negatives, but never ask
            # topk for more elements than the map contains.
            negative_count = min(100, negative_loss.numel())
        negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

        balance_loss = (positive_loss + torch.sum(negative_loss)) / (
            float(positive_count + negative_count) + 1e-5)

        return balance_loss

    def gcn_loss(self, gcn_data):
        """Cross entropy between GCN linkage predictions and their labels.

        Args:
            gcn_data (tuple): ``(gcn_pred, gt_labels)``; the labels are
                flattened and moved to the prediction's device.

        Returns:
            Tensor: The cross entropy loss.
        """
        gcn_pred, gt_labels = gcn_data
        gt_labels = gt_labels.view(-1).to(gcn_pred.device)
        loss = F.cross_entropy(gcn_pred, gt_labels)

        return loss

    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert Bitmasks to tensor.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target tensor size HxW.

        Returns:
            results (list[tensor]): The list of kernel tensors. Each
                element is for one kernel level.
        """
        assert check_argument.is_type_list(bitmasks, BitmapMasks)
        assert isinstance(target_sz, tuple)

        batch_size = len(bitmasks)
        num_masks = len(bitmasks[0])

        results = []

        for level_inx in range(num_masks):
            kernel = []
            for batch_inx in range(batch_size):
                mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
                # hxw
                mask_sz = mask.shape
                # left, right, top, bottom
                pad = [
                    0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
                ]
                mask = F.pad(mask, pad, mode='constant', value=0)
                kernel.append(mask)
            kernel = torch.stack(kernel)
            results.append(kernel)

        return results

    def forward(self, preds, downsample_ratio, gt_text_mask,
                gt_center_region_mask, gt_mask, gt_top_height_map,
                gt_bot_height_map, gt_sin_map, gt_cos_map):
        """Compute the DRRG losses.

        Args:
            preds (tuple): ``(pred_maps, gcn_data)``. ``pred_maps`` is
                indexed as (N, 6, H, W) with channels: text region, center
                region, sin, cos, top height, bottom height.
            downsample_ratio (float): Must be 1.0.
            gt_text_mask (list[BitmapMasks]): Text region masks.
            gt_center_region_mask (list[BitmapMasks]): Center region masks.
            gt_mask (list[BitmapMasks]): Effective-region masks.
            gt_top_height_map (list[BitmapMasks]): Top height maps.
            gt_bot_height_map (list[BitmapMasks]): Bottom height maps.
            gt_sin_map (list[BitmapMasks]): Sine maps.
            gt_cos_map (list[BitmapMasks]): Cosine maps.

        Returns:
            dict: The losses ``loss_text``, ``loss_center``, ``loss_height``,
            ``loss_sin``, ``loss_cos`` and ``loss_gcn``.
        """
        assert isinstance(preds, tuple)
        assert isinstance(downsample_ratio, float)
        assert abs(downsample_ratio - 1.0) < 1e-5
        assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
        assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
        assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
        assert check_argument.is_type_list(gt_cos_map, BitmapMasks)

        pred_maps, gcn_data = preds
        pred_text_region = pred_maps[:, 0, :, :]
        pred_center_region = pred_maps[:, 1, :, :]
        pred_sin_map = pred_maps[:, 2, :, :]
        pred_cos_map = pred_maps[:, 3, :, :]
        pred_top_height_map = pred_maps[:, 4, :, :]
        pred_bot_height_map = pred_maps[:, 5, :, :]
        feature_sz = pred_maps.size()
        device = pred_maps.device

        # bitmask 2 tensor
        mapping = {
            'gt_text_mask': gt_text_mask,
            'gt_center_region_mask': gt_center_region_mask,
            'gt_mask': gt_mask,
            'gt_top_height_map': gt_top_height_map,
            'gt_bot_height_map': gt_bot_height_map,
            'gt_sin_map': gt_sin_map,
            'gt_cos_map': gt_cos_map
        }
        gt = {}
        for key, value in mapping.items():
            tensors = self.bitmasks2tensor(value, feature_sz[2:])
            gt[key] = [item.to(device) for item in tensors]

        # Normalize (sin, cos) so that sin^2 + cos^2 == 1.
        scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
        pred_sin_map = pred_sin_map * scale
        pred_cos_map = pred_cos_map * scale

        loss_text = self.balance_bce_loss(
            torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
            gt['gt_mask'][0])

        text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
        negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
                              gt['gt_mask'][0]).float()
        gt_center_region_mask = gt['gt_center_region_mask'][0].float()
        loss_center = F.binary_cross_entropy(
            torch.sigmoid(pred_center_region),
            gt_center_region_mask,
            reduction='none')
        if int(text_mask.sum()) > 0:
            loss_center_positive = torch.sum(
                loss_center * text_mask) / torch.sum(text_mask)
        else:
            loss_center_positive = torch.tensor(0.0, device=device)
        loss_center_negative = torch.sum(
            loss_center * negative_text_mask) / torch.sum(negative_text_mask)
        loss_center = loss_center_positive + 0.5 * loss_center_negative

        center_mask = (gt['gt_center_region_mask'][0] *
                       gt['gt_mask'][0]).float()
        if int(center_mask.sum()) > 0:
            ones = torch.ones_like(
                gt['gt_top_height_map'][0], dtype=torch.float)
            # Heights are regressed as a ratio to the target (ideal == 1).
            loss_top = F.smooth_l1_loss(
                pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
                ones,
                reduction='none')
            loss_bot = F.smooth_l1_loss(
                pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
                ones,
                reduction='none')
            gt_height = (
                gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
            loss_height = torch.sum(
                (torch.log(gt_height + 1) *
                 (loss_top + loss_bot)) * center_mask) / torch.sum(center_mask)

            loss_sin = torch.sum(
                F.smooth_l1_loss(
                    pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
                center_mask) / torch.sum(center_mask)
            loss_cos = torch.sum(
                F.smooth_l1_loss(
                    pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
                center_mask) / torch.sum(center_mask)
        else:
            loss_height = torch.tensor(0.0, device=device)
            loss_sin = torch.tensor(0.0, device=device)
            loss_cos = torch.tensor(0.0, device=device)

        loss_gcn = self.gcn_loss(gcn_data)

        results = dict(
            loss_text=loss_text,
            loss_center=loss_center,
            loss_height=loss_height,
            loss_sin=loss_sin,
            loss_cos=loss_cos,
            loss_gcn=loss_gcn)

        return results

View File

@ -1,6 +0,0 @@
from .gcn import GCN
from .local_graph import LocalGraphs
from .proposal_local_graph import ProposalLocalGraphs
from .utils import merge_text_comps
__all__ = ['LocalGraphs', 'ProposalLocalGraphs', 'GCN', 'merge_text_comps']

View File

@ -1,80 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
class MeanAggregator(nn.Module):
    """Aggregate node features with a batched matrix product.

    The module has no parameters; it returns ``torch.bmm(A, features)``.
    Presumably ``A`` is a normalized adjacency matrix so that the product is
    a neighborhood mean - confirm against the caller.
    """

    def forward(self, features, A):
        return torch.bmm(A, features)
class GraphConv(nn.Module):
    """A graph convolution layer.

    Each node's own features are concatenated with its aggregated features,
    projected by a learned weight, shifted by a bias and passed through ReLU.

    Args:
        in_dim (int): The feature dimension of the input nodes.
        out_dim (int): The feature dimension of the output nodes.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # The weight has 2 * in_dim rows because forward() concatenates the
        # raw and the aggregated features along the last axis.
        self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
        self.bias = nn.Parameter(torch.FloatTensor(out_dim))
        init.xavier_uniform_(self.weight)
        init.constant_(self.bias, 0)
        self.agg = MeanAggregator()

    def forward(self, features, A):
        """Apply the layer.

        Args:
            features (Tensor): Node features; must be 3-dimensional with
                ``in_dim`` as the last dimension.
            A (Tensor): The matrix handed to the aggregator together with
                ``features``.

        Returns:
            Tensor: The activated node features with ``out_dim`` channels.
        """
        _, _, feat_dim = features.shape
        assert feat_dim == self.in_dim
        aggregated = self.agg(features, A)
        stacked = torch.cat([features, aggregated], dim=2)
        projected = torch.einsum('bnd,df->bnf', stacked, self.weight)
        return F.relu(projected + self.bias)
class GCN(nn.Module):
    """Predict linkage between instances. This was from repo
    https://github.com/Zhongdao/gcn_clustering: Linkage Based Face Clustering
    via Graph Convolution Network.

    [https://arxiv.org/abs/1903.11306]

    Args:
        in_dim(int): The input dimension.
        out_dim(int): The output dimension.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.bn0 = nn.BatchNorm1d(in_dim, affine=False).float()
        self.conv1 = GraphConv(in_dim, 512)
        self.conv2 = GraphConv(512, 256)
        self.conv3 = GraphConv(256, 128)
        self.conv4 = GraphConv(128, 64)
        self.classifier = nn.Sequential(
            nn.Linear(64, out_dim), nn.PReLU(out_dim), nn.Linear(out_dim, 2))

    def forward(self, x, A, one_hop_indexes, train=True):
        """Predict two-class linkage logits for each one-hop neighbor.

        Args:
            x (Tensor): Node features of shape (B, N, D).
            A (Tensor): The matrices passed to every graph convolution.
            one_hop_indexes (Tensor): For each of the B graphs, the indices
                (into the N nodes) of the neighbors to classify.
            train (bool): Unused; kept for interface compatibility.

        Returns:
            Tensor: The logits of shape (B * k1, 2), where k1 is the last
            dimension of ``one_hop_indexes``.
        """
        B, N, D = x.shape

        # BatchNorm1d expects (samples, channels): normalize over all nodes.
        x = x.view(-1, D)
        x = self.bn0(x)
        x = x.view(B, N, D)

        x = self.conv1(x, A)
        x = self.conv2(x, A)
        x = self.conv3(x, A)
        x = self.conv4(x, A)
        k1 = one_hop_indexes.size(-1)
        dout = x.size(-1)
        # Allocate directly on the feature device. Building the buffer on
        # CPU forced a device-to-host copy per graph plus a copy back.
        edge_feat = torch.zeros(B, k1, dout, device=x.device)
        for b in range(B):
            edge_feat[b, :, :] = x[b, one_hop_indexes[b]]
        edge_feat = edge_feat.view(-1, dout)
        pred = self.classifier(edge_feat)

        # shape: (B*k1)x2
        return pred

View File

@ -1,307 +0,0 @@
import numpy as np
import torch
from mmocr.models.utils import RROIAlign
from .utils import (embed_geo_feats, euclidean_distance_matrix,
normalize_adjacent_matrix)
class LocalGraphs:
    """Generate local graphs for GCN to predict which instance a text component
    belongs to. This was partially adapted from https://github.com/GXYM/DRRG:
    Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection.

    [https://arxiv.org/abs/2003.07493]

    Args:
        k_at_hops (tuple(int)): The number of h-hop neighbors.
        active_connection (int): The number of neighbors deem as linked to a
            pivot.
        node_geo_feat_dim (int): The dimension of embedded geometric features
            of a component.
        pooling_scale (float): The spatial scale of RRoI-Aligning.
        pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
        local_graph_filter_thr (float): The threshold to filter out identical
            local graphs.
    """

    def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
                 pooling_scale, pooling_output_size, local_graph_filter_thr):

        assert isinstance(k_at_hops, tuple)
        assert isinstance(active_connection, int)
        assert isinstance(node_geo_feat_dim, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(local_graph_filter_thr, float)

        self.k_at_hops = k_at_hops
        self.local_graph_depth = len(self.k_at_hops)
        self.active_connection = active_connection
        self.node_geo_feat_dim = node_geo_feat_dim
        self.pooling = RROIAlign(pooling_output_size, pooling_scale)
        self.local_graph_filter_thr = local_graph_filter_thr

    def generate_local_graphs(self, sorted_complete_graph, gt_belong_labels):
        """Generate local graphs for GCN to predict which instance a text
        component belongs to.

        Args:
            sorted_complete_graph (ndarray): The complete graph where nodes are
                sorted according to their Euclidean distance.
            gt_belong_labels (ndarray): The ground truth labels define which
                instance text components (nodes in graphs) belong to.

        Returns:
            local_graph_node_list (list): The list of local graph neighbors of
                pivots.
            knn_graph_neighbor_list (list): The list of k nearest neighbors of
                pivots.
        """

        assert sorted_complete_graph.ndim == 2
        assert (sorted_complete_graph.shape[0] ==
                sorted_complete_graph.shape[1] == gt_belong_labels.shape[0])

        # Column 0 of each row is the node itself (distance 0), so the
        # neighbors start at column 1.
        knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
        local_graph_node_list = list()
        knn_graph_neighbor_list = list()
        for pivot_inx, knn_graph in enumerate(knn_graphs):

            h_hop_neighbor_list = list()
            one_hop_neighbors = set(knn_graph[1:])
            h_hop_neighbor_list.append(one_hop_neighbors)

            for hop_inx in range(1, self.local_graph_depth):
                h_hop_neighbor_list.append(set())
                for last_hop_neighbor_inx in h_hop_neighbor_list[-2]:
                    h_hop_neighbor_list[-1].update(
                        set(sorted_complete_graph[last_hop_neighbor_inx]
                            [1:self.k_at_hops[hop_inx] + 1]))

            hops_neighbor_set = set(
                [node for hop in h_hop_neighbor_list for node in hop])
            # The pivot is usually a neighbor of its own neighbors. Drop it
            # before it is put at position 0, otherwise it appears twice and
            # the node -> index map built later no longer points at slot 0.
            hops_neighbor_set.discard(pivot_inx)
            hops_neighbor_list = list(hops_neighbor_set)
            hops_neighbor_list.insert(0, pivot_inx)

            if pivot_inx < 1:
                local_graph_node_list.append(hops_neighbor_list)
                knn_graph_neighbor_list.append(one_hop_neighbors)
            else:
                # Skip this pivot when an already accepted graph of the same
                # (non-background) instance has nearly the same 1-hop set.
                add_flag = True
                for knn_graph_neighbors, local_graph_nodes in zip(
                        knn_graph_neighbor_list, local_graph_node_list):
                    node_union_num = len(
                        knn_graph_neighbors | one_hop_neighbors)
                    node_intersect_num = len(
                        knn_graph_neighbors & one_hop_neighbors)
                    one_hop_iou = node_intersect_num / (node_union_num + 1e-8)

                    kept_pivot_label = gt_belong_labels[local_graph_nodes[0]]
                    if (one_hop_iou > self.local_graph_filter_thr
                            and pivot_inx in knn_graph_neighbors
                            and kept_pivot_label
                            == gt_belong_labels[pivot_inx]
                            and kept_pivot_label != 0):
                        add_flag = False
                        break
                if add_flag:
                    local_graph_node_list.append(hops_neighbor_list)
                    knn_graph_neighbor_list.append(one_hop_neighbors)

        return local_graph_node_list, knn_graph_neighbor_list

    def generate_gcn_input(self, node_feat_batch, belong_label_batch,
                           local_graph_node_batch, knn_graph_neighbor_batch,
                           sorted_complete_graph):
        """Generate graph convolution network input data.

        Args:
            node_feat_batch (List[Tensor]): The node feature batch.
            belong_label_batch (List[ndarray]): The text component belong label
                batch.
            local_graph_node_batch (List[List[list]]): The local graph
                neighbors batch.
            knn_graph_neighbor_batch (List[List[set]]): The knn graph neighbor
                batch.
            sorted_complete_graph (List[ndarray]): The complete graph sorted
                according to the Euclidean distance.

        Returns:
            node_feat_batch_tensor (Tensor): The node features of Graph
                Convolutional Network (GCN).
            adjacent_mat_batch_tensor (Tensor): The adjacent matrices.
            knn_inx_batch_tensor (Tensor): The indices of k nearest neighbors.
            gt_linkage_batch_tensor (Tensor): The supervision signal of GCN
                for linkage prediction.
        """

        assert isinstance(node_feat_batch, list)
        assert isinstance(belong_label_batch, list)
        assert isinstance(local_graph_node_batch, list)
        assert isinstance(knn_graph_neighbor_batch, list)
        assert isinstance(sorted_complete_graph, list)

        # Every graph is zero-padded to the largest node count in the batch
        # so that the per-graph tensors can be stacked.
        max_graph_node_num = max([
            len(local_graph_nodes)
            for local_graph_node_list in local_graph_node_batch
            for local_graph_nodes in local_graph_node_list
        ])

        node_feat_batch_list = list()
        adjacent_matrix_batch_list = list()
        knn_inx_batch_list = list()
        gt_linkage_batch_list = list()

        for batch_inx, sorted_neighbors in enumerate(sorted_complete_graph):
            node_feats = node_feat_batch[batch_inx]
            local_graph_list = local_graph_node_batch[batch_inx]
            knn_graph_neighbor_list = knn_graph_neighbor_batch[batch_inx]
            belong_labels = belong_label_batch[batch_inx]

            for graph_inx in range(len(local_graph_list)):
                local_graph_nodes = local_graph_list[graph_inx]
                local_graph_node_num = len(local_graph_nodes)
                pivot_inx = local_graph_nodes[0]
                knn_graph_neighbors = knn_graph_neighbor_list[graph_inx]
                # Global node index -> position inside this local graph.
                node_to_graph_inx = {
                    j: i
                    for i, j in enumerate(local_graph_nodes)
                }

                knn_inx_in_local_graph = torch.tensor(
                    [node_to_graph_inx[i] for i in knn_graph_neighbors],
                    dtype=torch.long)
                pivot_feats = node_feats[torch.tensor(
                    pivot_inx, dtype=torch.long)]
                # Features are expressed relative to the pivot.
                normalized_feats = node_feats[torch.tensor(
                    local_graph_nodes, dtype=torch.long)] - pivot_feats

                adjacent_matrix = np.zeros(
                    (local_graph_node_num, local_graph_node_num))
                pad_normalized_feats = torch.cat([
                    normalized_feats,
                    torch.zeros(
                        max_graph_node_num - local_graph_node_num,
                        normalized_feats.shape[1],
                        device=node_feats.device)
                ],
                                                 dim=0)

                for node in local_graph_nodes:
                    neighbors = sorted_neighbors[node,
                                                 1:self.active_connection + 1]
                    for neighbor in neighbors:
                        # Dict lookup instead of scanning the node list.
                        if neighbor in node_to_graph_inx:
                            adjacent_matrix[node_to_graph_inx[node],
                                            node_to_graph_inx[neighbor]] = 1
                            adjacent_matrix[node_to_graph_inx[neighbor],
                                            node_to_graph_inx[node]] = 1

                adjacent_matrix = normalize_adjacent_matrix(
                    adjacent_matrix, type='DAD')
                adjacent_matrix_tensor = torch.zeros(
                    max_graph_node_num,
                    max_graph_node_num,
                    device=node_feats.device)
                adjacent_matrix_tensor[:local_graph_node_num, :
                                       local_graph_node_num] = adjacent_matrix

                local_graph_labels = torch.from_numpy(
                    belong_labels[local_graph_nodes]).type(torch.long)
                knn_labels = local_graph_labels[knn_inx_in_local_graph]
                # A neighbor is linked when it shares the pivot's label and
                # that label is positive.
                edge_labels = ((belong_labels[pivot_inx] == knn_labels)
                               & (belong_labels[pivot_inx] > 0)).long()

                node_feat_batch_list.append(pad_normalized_feats)
                adjacent_matrix_batch_list.append(adjacent_matrix_tensor)
                knn_inx_batch_list.append(knn_inx_in_local_graph)
                gt_linkage_batch_list.append(edge_labels)

        node_feat_batch_tensor = torch.stack(node_feat_batch_list, 0)
        adjacent_mat_batch_tensor = torch.stack(adjacent_matrix_batch_list, 0)
        knn_inx_batch_tensor = torch.stack(knn_inx_batch_list, 0)
        gt_linkage_batch_tensor = torch.stack(gt_linkage_batch_list, 0)

        return (node_feat_batch_tensor, adjacent_mat_batch_tensor,
                knn_inx_batch_tensor, gt_linkage_batch_tensor)

    def __call__(self, feat_maps, comp_attribs):
        """Generate local graphs.

        Args:
            feat_maps (Tensor): The feature maps to propose node features in
                graph.
            comp_attribs (ndarray): The text components attributes. The last
                dimension must be 8: entry [0, 0] of each image stores the
                component count, columns 1-6 the geometric attributes and
                column 7 the instance label.

        Returns:
            node_feats_batch (Tensor): The node features of Graph Convolutional
                Network(GCN).
            adjacent_matrices_batch (Tensor): The adjacent matrices.
            knn_inx_batch (Tensor): The indices of k nearest neighbors.
            gt_linkage_batch (Tensor): The supervision signal of GCN for
                linkage prediction.
        """

        assert isinstance(feat_maps, torch.Tensor)
        assert comp_attribs.shape[2] == 8

        dist_sort_graph_batch_list = []
        local_graph_node_batch_list = []
        knn_graph_neighbor_batch_list = []
        node_feature_batch_list = []
        belong_label_batch_list = []

        for batch_inx in range(comp_attribs.shape[0]):
            comp_num = int(comp_attribs[batch_inx, 0, 0])
            comp_geo_attribs = comp_attribs[batch_inx, :comp_num, 1:7]
            node_belong_labels = comp_attribs[batch_inx, :comp_num,
                                              7].astype(np.int32)

            comp_centers = comp_geo_attribs[:, 0:2]
            distance_matrix = euclidean_distance_matrix(
                comp_centers, comp_centers)

            graph_node_geo_feats = embed_geo_feats(comp_geo_attribs,
                                                   self.node_geo_feat_dim)
            graph_node_geo_feats = torch.from_numpy(
                graph_node_geo_feats).float().to(feat_maps.device)

            # Pooling runs on a single-image slice below, so the leading
            # column prepended to every component is always 0.
            batch_id = np.zeros(
                (comp_geo_attribs.shape[0], 1), dtype=np.float32)
            text_comps = np.hstack(
                (batch_id, comp_geo_attribs.astype(np.float32)))
            text_comps = torch.from_numpy(text_comps).float().to(
                feat_maps.device)

            comp_content_feats = self.pooling(
                feat_maps[batch_inx].unsqueeze(0), text_comps)
            comp_content_feats = comp_content_feats.view(
                comp_content_feats.shape[0], -1).to(feat_maps.device)
            node_feats = torch.cat((comp_content_feats, graph_node_geo_feats),
                                   dim=-1)

            dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
            (local_graph_nodes,
             knn_graph_neighbors) = self.generate_local_graphs(
                 dist_sort_complete_graph, node_belong_labels)

            node_feature_batch_list.append(node_feats)
            belong_label_batch_list.append(node_belong_labels)
            local_graph_node_batch_list.append(local_graph_nodes)
            knn_graph_neighbor_batch_list.append(knn_graph_neighbors)
            dist_sort_graph_batch_list.append(dist_sort_complete_graph)

        (node_feats_batch, adjacent_matrices_batch, knn_inx_batch,
         gt_linkage_batch) = \
            self.generate_gcn_input(node_feature_batch_list,
                                    belong_label_batch_list,
                                    local_graph_node_batch_list,
                                    knn_graph_neighbor_batch_list,
                                    dist_sort_graph_batch_list)

        return (node_feats_batch, adjacent_matrices_batch, knn_inx_batch,
                gt_linkage_batch)

View File

@ -1,418 +0,0 @@
import cv2
import numpy as np
import torch
# from mmocr.models.textdet.postprocess import la_nms
from mmocr.models.utils import RROIAlign
from .utils import (embed_geo_feats, euclidean_distance_matrix,
normalize_adjacent_matrix)
class ProposalLocalGraphs:
"""Propose text components and generate local graphs. This was partially
adapted from https://github.com/GXYM/DRRG: Deep Relational Reasoning Graph
Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]
Args:
k_at_hops (tuple(int)): The number of i-hop neighbors,
i = 1, 2, ..., h.
active_connection (int): The number of two hop neighbors deem as linked
to a pivot.
node_geo_feat_dim (int): The dimension of embedded geometric features
of a component.
pooling_scale (float): The spatial scale of RRoI-Aligning.
pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
nms_thr (float): The locality-aware NMS threshold.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_ratio (float): The reciprocal of aspect ratio of text components.
text_region_thr (float): The threshold for text region probability map.
center_region_thr (float): The threshold for text center region
probability map.
center_region_area_thr (int): The threshold for filtering out
small-size text center region.
"""
def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
             pooling_scale, pooling_output_size, nms_thr, min_width,
             max_width, comp_shrink_ratio, comp_ratio, text_region_thr,
             center_region_thr, center_region_area_thr):
    """Validate the argument types and store the proposal settings.

    Tuples, ints and floats are checked strictly with ``isinstance``; the
    pooling size and scale are consumed to build the RRoI-Align operator
    rather than being stored directly.
    """
    for tuple_arg in (k_at_hops, pooling_output_size):
        assert isinstance(tuple_arg, tuple)
    for int_arg in (active_connection, node_geo_feat_dim,
                    center_region_area_thr):
        assert isinstance(int_arg, int)
    for float_arg in (pooling_scale, nms_thr, min_width, max_width,
                      comp_shrink_ratio, comp_ratio, text_region_thr,
                      center_region_thr):
        assert isinstance(float_arg, float)

    # Graph topology settings.
    self.k_at_hops = k_at_hops
    self.local_graph_depth = len(self.k_at_hops)
    self.active_connection = active_connection
    self.node_geo_feat_dim = node_geo_feat_dim

    # Feature pooling operator.
    self.pooling = RROIAlign(pooling_output_size, pooling_scale)

    # Text component proposal settings.
    self.nms_thr = nms_thr
    self.min_width = min_width
    self.max_width = max_width
    self.comp_shrink_ratio = comp_shrink_ratio
    self.comp_ratio = comp_ratio

    # Probability map thresholds.
    self.text_region_thr = text_region_thr
    self.center_region_thr = center_region_thr
    self.center_region_area_thr = center_region_area_thr
def fill_hole(self, input_mask):
    """Fill the holes enclosed by the foreground of a binary mask.

    Args:
        input_mask (ndarray): A 2-D mask; assumed to hold 0/1 (or boolean)
            values.

    Returns:
        ndarray: The input mask OR-ed with every background pixel that the
        flood fill started outside the mask could not reach.
    """
    h, w = input_mask.shape
    # Pad by one pixel so the background is connected around the border
    # and the seed point (0, 0) is guaranteed to be background.
    canvas = np.zeros((h + 2, w + 2), np.uint8)
    canvas[1:h + 1, 1:w + 1] = input_mask.copy()

    # cv2.floodFill requires a mask 2 pixels larger than the image.
    mask = np.zeros((h + 4, w + 4), np.uint8)

    cv2.floodFill(canvas, mask, (0, 0), 1)
    # ``np.bool`` was removed in NumPy 1.24; the builtin is equivalent.
    canvas = canvas[1:h + 1, 1:w + 1].astype(bool)

    # Pixels still 0 after the fill are enclosed holes.
    return (~canvas | input_mask.astype(np.uint8))
def propose_comps(self, top_radius_map, bot_radius_map, sin_map, cos_map,
                  score_map, min_width, max_width, comp_shrink_ratio,
                  comp_ratio):
    """Generate text components.

    One component is produced for every pixel whose score is positive.

    Args:
        top_radius_map (ndarray): The predicted distance map from each
            pixel in text center region to top sideline.
        bot_radius_map (ndarray): The predicted distance map from each
            pixel in text center region to bottom sideline.
        sin_map (ndarray): The predicted sin(theta) map.
        cos_map (ndarray): The predicted cos(theta) map.
        score_map (ndarray): The score map for NMS.
        min_width (float): The minimum width of text components.
        max_width (float): The maximum width of text components.
        comp_shrink_ratio (float): The shrink ratio of text components.
        comp_ratio (float): The reciprocal of aspect ratio of text
            components.

    Returns:
        text_comps (ndarray): The text components, one row per pixel with
            nine columns: the top-left, top-right, bottom-right and
            bottom-left corner (two values each, in the index order
            returned by ``np.argwhere``) followed by the score.
    """
    centers = np.argwhere(score_map > 0)
    centers = centers[np.argsort(centers[:, 0])]
    rows, cols = centers[:, 0], centers[:, 1]

    def _sample(value_map):
        # Read the map at every center and return an (n, 1) column.
        return value_map[rows, cols].reshape((-1, 1))

    top_r = _sample(top_radius_map) * comp_shrink_ratio
    bot_r = _sample(bot_radius_map) * comp_shrink_ratio
    sin = _sample(sin_map)
    cos = _sample(cos_map)

    # Midpoints of the top and bottom sides of each component.
    top_mid = centers + np.hstack([top_r * sin, top_r * cos])
    bot_mid = centers - np.hstack([bot_r * sin, bot_r * cos])

    width = np.clip((top_r + bot_r) * comp_ratio, min_width, max_width)
    # Offset from a side midpoint to its two corners.
    side = np.hstack([width * cos, -width * sin])[:, ::-1]

    corners = [
        top_mid - side,  # top left
        top_mid + side,  # top right
        bot_mid + side,  # bottom right
        bot_mid - side,  # bottom left
    ]

    return np.hstack(corners + [_sample(score_map)])
def propose_comps_and_attribs(self, text_region_map, center_region_map,
                              top_radius_map, bot_radius_map, sin_map,
                              cos_map):
    """Generate text components and their geometric attributes.

    Args:
        text_region_map (ndarray): The predicted text region probability
            map.
        center_region_map (ndarray): The predicted text center region
            probability map.
        top_radius_map (ndarray): The predicted distance map from each
            pixel in text center region to top sideline.
        bot_radius_map (ndarray): The predicted distance map from each
            pixel in text center region to bottom sideline.
        sin_map (ndarray): The predicted sin(theta) map.
        cos_map (ndarray): The predicted cos(theta) map.

    Returns:
        comp_attribs (ndarray): One row per component with the columns
            [x, y, h, w, cos, sin], or None when no component is kept.
        text_comps (ndarray): The rows produced by ``propose_comps`` for
            every accepted center region, or None when none is kept.

    Raises:
        AssertionError: If the six input maps do not share one shape.
    """
    # Compare the ``.shape`` of every map. Comparing a shape tuple with an
    # array itself is an elementwise comparison and has no single truth
    # value, so it cannot validate anything.
    assert (text_region_map.shape == center_region_map.shape ==
            top_radius_map.shape == bot_radius_map.shape == sin_map.shape
            == cos_map.shape)

    text_mask = text_region_map > self.text_region_thr
    center_region_mask = (center_region_map >
                          self.center_region_thr) * text_mask

    # Rescale (sin, cos) so that sin^2 + cos^2 == 1 at every pixel.
    scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2))
    sin_map, cos_map = sin_map * scale, cos_map * scale

    center_region_mask = self.fill_hole(center_region_mask)
    center_region_contours, _ = cv2.findContours(
        center_region_mask.astype(np.uint8), cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE)

    mask = np.zeros_like(center_region_mask)
    comp_list = []
    for contour in center_region_contours:
        current_center_mask = mask.copy()
        cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
        # Ignore center regions that are too small.
        if current_center_mask.sum() <= self.center_region_area_thr:
            continue
        score_map = text_region_map * current_center_mask

        text_comp = self.propose_comps(top_radius_map, bot_radius_map,
                                       sin_map, cos_map, score_map,
                                       self.min_width, self.max_width,
                                       self.comp_shrink_ratio,
                                       self.comp_ratio)
        # NOTE: no NMS is applied to the proposed components here.

        text_comp_mask = mask.copy()
        text_comps_bboxes = text_comp[:, :8].reshape(
            (-1, 4, 2)).astype(np.int32)
        cv2.drawContours(text_comp_mask, text_comps_bboxes, -1, 1, -1)
        # Drop the region when less than half of the area covered by its
        # components lies inside the text mask.
        if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
            continue

        comp_list.append(text_comp)

    if not comp_list:
        return None, None

    text_comps = np.vstack(comp_list)
    centers = np.mean(
        text_comps[:, :8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
    x = centers[:, 0]
    y = centers[:, 1]

    h = top_radius_map[y, x].reshape(
        (-1, 1)) + bot_radius_map[y, x].reshape((-1, 1))
    w = np.clip(h * self.comp_ratio, self.min_width, self.max_width)
    sin = sin_map[y, x].reshape((-1, 1))
    cos = cos_map[y, x].reshape((-1, 1))
    x = x.reshape((-1, 1))
    y = y.reshape((-1, 1))
    comp_attribs = np.hstack([x, y, h, w, cos, sin])

    return comp_attribs, text_comps
def generate_local_graphs(self, sorted_complete_graph, node_feats):
    """Generate local graphs and Graph Convolution Network input data.

    One local graph is built per node (the pivot). It holds the pivot at
    position 0 followed by its neighbors within ``self.local_graph_depth``
    hops, and every graph is zero-padded to the size of the largest one.

    Args:
        sorted_complete_graph (ndarray): Square matrix whose row ``i``
            lists node indices ordered by distance to node ``i``.
        node_feats (tensor): The graph nodes features, one row per node.

    Returns:
        node_feats_tensor (tensor): Node features minus the pivot
            features, of shape (num_graphs, max_nodes, feat_dim).
        adjacent_matrix_tensor (tensor): The 'DAD'-normalized adjacent
            matrices, of shape (num_graphs, max_nodes, max_nodes).
        pivot_inx_tensor (tensor): The pivot position inside each local
            graph, of shape (num_graphs, 1).
        knn_inx_tensor (tensor): Positions of the one-hop neighbors
            inside each local graph.
        local_graph_node_tensor (tensor): The global node indices of each
            local graph, of shape (num_graphs, max_nodes).
    """
    assert sorted_complete_graph.ndim == 2
    assert (sorted_complete_graph.shape[0] ==
            sorted_complete_graph.shape[1] == node_feats.shape[0])

    knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
    local_graph_node_list = []
    knn_graph_neighbor_list = []
    for pivot_inx, knn_graph in enumerate(knn_graphs):
        one_hop_neighbors = set(knn_graph[1:])
        h_hop_neighbor_list = [one_hop_neighbors]
        for hop_inx in range(1, self.local_graph_depth):
            next_hop_neighbors = set()
            for last_hop_neighbor_inx in h_hop_neighbor_list[-1]:
                next_hop_neighbors.update(
                    sorted_complete_graph[last_hop_neighbor_inx]
                    [1:self.k_at_hops[hop_inx] + 1])
            h_hop_neighbor_list.append(next_hop_neighbors)

        hops_neighbor_set = {
            node
            for hop in h_hop_neighbor_list for node in hop
        }
        # The pivot is usually a multi-hop neighbor of itself. Remove it
        # before prepending, otherwise it is listed twice and the graph
        # gains a disconnected duplicate node.
        hops_neighbor_set.discard(pivot_inx)
        local_graph_node_list.append([pivot_inx] + list(hops_neighbor_set))
        knn_graph_neighbor_list.append(one_hop_neighbors)

    max_graph_node_num = max(
        len(local_graph_nodes)
        for local_graph_nodes in local_graph_node_list)

    node_normalized_feats = []
    adjacent_matrix_list = []
    knn_inx = []
    pivot_graph_inx = []
    local_graph_tensor_list = []

    for local_graph_nodes, knn_graph_neighbors in zip(
            local_graph_node_list, knn_graph_neighbor_list):
        local_graph_node_num = len(local_graph_nodes)
        pad_num = max_graph_node_num - local_graph_node_num
        pivot_inx = local_graph_nodes[0]
        # Global node index -> position inside this local graph.
        node_to_graph_inx = {j: i for i, j in enumerate(local_graph_nodes)}

        pivot_node_inx = torch.tensor([
            node_to_graph_inx[pivot_inx],
        ]).type(torch.long)
        knn_inx_in_local_graph = torch.tensor(
            [node_to_graph_inx[i] for i in knn_graph_neighbors],
            dtype=torch.long)

        # Express every node feature relative to the pivot.
        pivot_feats = node_feats[torch.tensor(pivot_inx, dtype=torch.long)]
        normalized_feats = node_feats[torch.tensor(
            local_graph_nodes, dtype=torch.long)] - pivot_feats
        pad_normalized_feats = torch.cat([
            normalized_feats,
            torch.zeros(pad_num,
                        normalized_feats.shape[1]).to(node_feats.device)
        ],
                                         dim=0)

        # Link each node with those of its ``active_connection`` nearest
        # neighbors that belong to the same local graph.
        adjacent_matrix = np.zeros(
            (local_graph_node_num, local_graph_node_num))
        for node in local_graph_nodes:
            neighbors = sorted_complete_graph[node,
                                              1:self.active_connection + 1]
            for neighbor in neighbors:
                if neighbor in node_to_graph_inx:
                    adjacent_matrix[node_to_graph_inx[node],
                                    node_to_graph_inx[neighbor]] = 1
                    adjacent_matrix[node_to_graph_inx[neighbor],
                                    node_to_graph_inx[node]] = 1

        adjacent_matrix = normalize_adjacent_matrix(
            adjacent_matrix, type='DAD')
        adjacent_matrix_tensor = torch.zeros(
            max_graph_node_num, max_graph_node_num).to(node_feats.device)
        adjacent_matrix_tensor[:local_graph_node_num, :
                               local_graph_node_num] = adjacent_matrix

        local_graph_tensor = torch.tensor(local_graph_nodes)
        local_graph_tensor = torch.cat(
            [local_graph_tensor,
             torch.zeros(pad_num, dtype=torch.long)],
            dim=0)

        node_normalized_feats.append(pad_normalized_feats)
        adjacent_matrix_list.append(adjacent_matrix_tensor)
        pivot_graph_inx.append(pivot_node_inx)
        knn_inx.append(knn_inx_in_local_graph)
        local_graph_tensor_list.append(local_graph_tensor)

    node_feats_tensor = torch.stack(node_normalized_feats, 0)
    adjacent_matrix_tensor = torch.stack(adjacent_matrix_list, 0)
    pivot_inx_tensor = torch.stack(pivot_graph_inx, 0)
    knn_inx_tensor = torch.stack(knn_inx, 0)
    local_graph_node_tensor = torch.stack(local_graph_tensor_list, 0)

    return (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
            knn_inx_tensor, local_graph_node_tensor)
def __call__(self, preds, feat_maps):
    """Propose text components and build the GCN inputs for one image.

    Only the first sample of ``preds`` (``preds[0]``) is processed.

    Args:
        preds (tensor): The predicted maps. Channels 0-5 are read as text
            region, center region, sin, cos, top radius and bottom radius.
        feat_maps (tensor): The feature maps to extract content features
            of text components.

    Returns:
        none_flag (bool): True when no text component was proposed.
        graph_data (tuple): ``(0, 0, 0, 0, 0, 0)`` when ``none_flag`` is
            True. Otherwise the five tensors returned by
            ``generate_local_graphs`` (node features, adjacent matrices,
            pivot indices, k nearest neighbor indices, local graph node
            indices) followed by the predicted text components (ndarray).
    """

    def to_numpy(channel, activate=False):
        # Copy one predicted channel of the first sample to a CPU array.
        pred_map = preds[0, channel]
        if activate:
            pred_map = torch.sigmoid(pred_map)
        return pred_map.data.cpu().numpy()

    pred_text_region = to_numpy(0, activate=True)
    pred_center_region = to_numpy(1, activate=True)
    pred_sin_map = to_numpy(2)
    pred_cos_map = to_numpy(3)
    pred_top_radius_map = to_numpy(4)
    pred_bot_radius_map = to_numpy(5)

    comp_attribs, text_comps = self.propose_comps_and_attribs(
        pred_text_region, pred_center_region, pred_top_radius_map,
        pred_bot_radius_map, pred_sin_map, pred_cos_map)
    if comp_attribs is None:
        return True, (0, 0, 0, 0, 0, 0)

    device = preds.device

    # Pairwise distances between component centers drive the graph layout.
    centers = comp_attribs[:, 0:2]
    distance_matrix = euclidean_distance_matrix(centers, centers)

    geo_feats = embed_geo_feats(comp_attribs, self.node_geo_feat_dim)
    geo_feats = torch.from_numpy(geo_feats).float().to(device)

    # Prefix every attribute row with batch index 0 before pooling.
    batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
    rois = np.hstack((batch_id, comp_attribs.astype(np.float32, copy=False)))
    rois = torch.from_numpy(rois).float().to(device)

    content_feats = self.pooling(feat_maps, rois)
    content_feats = content_feats.view(content_feats.shape[0],
                                       -1).to(device)
    node_feats = torch.cat((content_feats, geo_feats), dim=-1)

    dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
    local_graph_data = self.generate_local_graphs(dist_sort_complete_graph,
                                                  node_feats)
    (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
     knn_inx_tensor, local_graph_node_tensor) = local_graph_data

    return False, (node_feats_tensor, adjacent_matrix_tensor,
                   pivot_inx_tensor, knn_inx_tensor,
                   local_graph_node_tensor, text_comps)

View File

@ -1,354 +0,0 @@
import functools
import operator
from typing import List
import cv2
import numpy as np
import torch
from numpy.linalg import norm
def normalize_adjacent_matrix(A, type='AD'):
    """Normalize adjacent matrix for GCN.

    This was from repo https://github.com/GXYM/DRRG.

    Self-loops are always added first (``A = A + I``).

    Args:
        A (ndarray): Square adjacent matrix.
        type (str): Normalization mode. 'DAD' scales by D^-1/2 on both
            sides; 'AD' divides each row by its sum; any other value
            returns ``D - A`` with D the diagonal matrix of row sums.

    Returns:
        G (tensor): The normalized matrix, same shape as ``A``.
    """
    A = A + np.eye(A.shape[0])  # A = A + I

    if type == 'DAD':
        # d is the degree of every node; G = (A D^-1/2)^T D^-1/2
        d = np.sum(A, axis=0)
        d_inv = np.power(d, -0.5).flatten()
        d_inv[np.isinf(d_inv)] = 0.0
        d_inv = np.diag(d_inv)
        G = A.dot(d_inv).transpose().dot(d_inv)
        return torch.from_numpy(G)

    A = torch.from_numpy(A)
    if type == 'AD':
        return A.div(A.sum(1, keepdim=True))

    # Build a real diagonal degree matrix. ``np.diag`` applied to an
    # (N, 1) column extracts a single element instead.
    D = torch.diag(A.sum(1))
    return D - A
def euclidean_distance_matrix(A, B):
    """Return the Euclidean distances between the rows of A and of B.

    Args:
        A (ndarray): Points of shape (M, D).
        B (ndarray): Points of shape (N, D).

    Returns:
        ndarray: Distance matrix of shape (M, N).
    """
    num_a, num_b = A.shape[0], B.shape[0]
    assert A.shape[1] == B.shape[1]

    # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, evaluated for all pairs at once.
    sq_norm_a = np.sum(A * A, axis=1).reshape(
        (num_a, 1)) * np.ones(shape=(1, num_b))
    sq_norm_b = np.sum(B * B, axis=1) * np.ones(shape=(num_a, 1))
    sq_dist = sq_norm_a + sq_norm_b - 2 * A.dot(B.T)

    # Rounding can leave tiny negative values; clamp them before the root.
    sq_dist[sq_dist < 0.0] = 0.0
    return np.sqrt(sq_dist)
def embed_geo_feats(geo_feats, out_dim):
    """Embed geometric features of text components. This was partially adapted
    from https://github.com/GXYM/DRRG.

    The features are repeated ``out_dim // feat_dim`` times, each copy is
    divided by its own wave length, and sine / cosine are applied. When
    ``out_dim`` is not a multiple of ``feat_dim`` one extra, zero-padded
    copy of the leading features is appended and the result is cut to
    ``out_dim`` columns.

    Args:
        geo_feats (ndarray): The geometric features of text components,
            of shape (comp_num, feat_dim).
        out_dim (int): The output dimension, at least ``feat_dim``.

    Returns:
        embedded_feats (ndarray): The embedded geometric features, of
            shape (comp_num, out_dim).
    """
    assert isinstance(out_dim, int)
    assert out_dim >= geo_feats.shape[1]
    comp_num, feat_dim = geo_feats.shape[0], geo_feats.shape[1]
    feat_repeat_times = out_dim // feat_dim
    residue_dim = out_dim % feat_dim

    repeat_feats = np.repeat(
        np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0)

    if residue_dim > 0:
        embed_wave = np.array([
            np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
            for j in range(feat_repeat_times + 1)
        ]).reshape((feat_repeat_times + 1, 1, 1))
        residue_feats = np.hstack([
            geo_feats[:, 0:residue_dim],
            np.zeros((comp_num, feat_dim - residue_dim))
        ])
        # ``residue_feats`` is 2-D while ``repeat_feats`` is 3-D, so it is
        # given a leading axis and appended as one more copy. ``np.stack``
        # cannot join arrays of different rank.
        repeat_feats = np.concatenate(
            [repeat_feats,
             np.expand_dims(residue_feats, axis=0)], axis=0)
    else:
        embed_wave = np.array([
            np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
            for j in range(feat_repeat_times)
        ]).reshape((feat_repeat_times, 1, 1))

    embedded_feats = repeat_feats / embed_wave
    # NOTE(review): axis 1 of this (copies, comp_num, feat_dim) array is the
    # component axis, so sine / cosine alternate between components -
    # confirm this is intended.
    embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
    embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
    embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
        (comp_num, -1))

    return embedded_feats[:, 0:out_dim]
def min_connect_path(list_all: List[list]):
    """Greedily chain 2-D points into one path grown from both ends.

    This is from https://github.com/GXYM/DRRG.

    The path starts at the first point. At every step the remaining point
    closest to either end of the path is attached to that end.

    Args:
        list_all (list[list]): The points, each indexable as ``[a, b]``.

    Returns:
        res (list[list[int]]): Index pairs (into ``list_all``) of the
            linked points, ordered along the path.
        path (list[int]): The point indices in order of first appearance
            in ``res``.
    """

    def norm2(a, b):
        return ((a[0] - b[0])**2 + (a[1] - b[1])**2)**0.5

    remaining = list_all.copy()
    res: List[List[int]] = []

    start = remaining[0]
    ends = [start, start]
    remaining.remove(start)

    while remaining:
        # Candidate links keyed by their length, one table per path end.
        # For equal lengths the later candidate replaces the earlier one.
        head_links = {norm2(pt, ends[0]): [pt, ends[0]] for pt in remaining}
        tail_links = {norm2(ends[1], pt): [ends[1], pt] for pt in remaining}
        head_best = min(head_links)
        tail_best = min(tail_links)

        if head_best <= tail_best:
            ss, ee = head_links[head_best]
            res.insert(0, [list_all.index(ss), list_all.index(ee)])
            remaining.remove(ss)
            ends[0] = ss
        else:
            ss, ee = tail_links[tail_best]
            res.append([list_all.index(ss), list_all.index(ee)])
            remaining.remove(ee)
            ends[1] = ee

    flat = functools.reduce(operator.concat, res)
    # Drop repeated indices while keeping the first-seen order.
    path = list(dict.fromkeys(flat))
    return res, path
def clusters2labels(clusters, node_num):
    """Turn clusters of nodes into one cluster label per node.

    This is from https://github.com/GXYM/DRRG.

    Args:
        clusters (iterable): Clusters, each an iterable of objects that
            expose their node index as ``inx``.
        node_num (int): The total number of nodes.

    Returns:
        ndarray: Array of shape (node_num, ) holding the index of the
            cluster every node belongs to.
    """
    labels = np.full((node_num, ), -1.0)
    for cluster_inx, cluster in enumerate(clusters):
        for member in cluster:
            labels[member.inx] = cluster_inx
    # Every node has to be assigned to some cluster.
    assert np.count_nonzero(labels < 0) < 1
    return labels
def remove_single(text_comps, pred):
    """Remove isolated single text components.

    This is from https://github.com/GXYM/DRRG.

    Args:
        text_comps (ndarray): The text components, one row per component.
        pred (ndarray): The label of every component.

    Returns:
        text_comps (ndarray): The components whose label occurs more than
            once. Has zero rows when every component is isolated.
        pred (ndarray): The labels of the remaining components.
    """
    labels, counts = np.unique(pred, return_counts=True)
    # A boolean mask also handles the case where nothing remains. An empty
    # list of indices becomes a float array, which is not a valid index.
    keep = ~np.isin(pred, labels[counts == 1])
    return text_comps[keep, :], pred[keep]
class Node:
    """Graph vertex that stores its index and the vertices linked to it."""

    def __init__(self, inx):
        self.__links = set()
        self.__inx = inx

    @property
    def inx(self):
        """The index this vertex was created with."""
        return self.__inx

    @property
    def links(self):
        """A copy of the set of linked vertices."""
        return self.__links.copy()

    def add_link(self, other, score):
        """Link this vertex and ``other`` in both directions.

        ``score`` is accepted but not stored.
        """
        for source, target in ((self, other), (other, self)):
            source.__links.add(target)
def connected_components(nodes, score_dict, thr):
    """Connected components searching.

    This is from https://github.com/GXYM/DRRG.

    Args:
        nodes (iterable): Vertices exposing ``inx`` and ``links``.
        score_dict (dict): Edge scores keyed by the sorted pair of vertex
            indices. Only read when ``thr`` is not None.
        thr (float | None): Links scoring below ``thr`` are ignored. With
            None every link is followed.

    Returns:
        list[set]: The connected components, each a set of vertices.
    """
    from collections import deque

    result = []
    nodes = set(nodes)
    while nodes:
        seed = nodes.pop()
        group = {seed}
        # Breadth-first search; a deque pops from the left in O(1).
        queue = deque([seed])
        while queue:
            node = queue.popleft()
            if thr is not None:
                neighbors = {
                    linked_neighbor
                    for linked_neighbor in node.links if score_dict[tuple(
                        sorted([node.inx, linked_neighbor.inx]))] >= thr
                }
            else:
                neighbors = node.links
            # Build a new set so the one returned by ``links`` is never
            # modified.
            neighbors = neighbors - group
            nodes.difference_update(neighbors)
            group.update(neighbors)
            queue.extend(neighbors)
        result.append(group)
    return result
def graph_propagation(edges,
                      scores,
                      link_thr,
                      bboxes=None,
                      dis_thr=50,
                      pool='avg'):
    """Propagate graph linkage score information.

    This is from repo https://github.com/GXYM/DRRG.

    Args:
        edges (ndarray): Pairs of node indices, of shape (num_edges, 2).
        scores (ndarray): The score of every edge. It is not modified.
        link_thr (float | None): Score threshold handed to
            ``connected_components``.
        bboxes (ndarray | None): Optional boxes indexed by node. With
            ``pool='avg'`` an edge whose box centers are more than
            ``dis_thr`` apart gets the score 0.
        dis_thr (float): Distance threshold used together with ``bboxes``.
        pool (str | None): How scores of duplicated edges are merged:
            None keeps the last one, 'avg' averages the stored score with
            the new one, 'max' keeps the larger one.

    Returns:
        list[set]: The clusters returned by ``connected_components``.

    Raises:
        ValueError: If ``pool`` is not None, 'avg' or 'max'.
    """
    edges = np.sort(edges, axis=1)
    score_dict = {}
    if pool is None:
        for i, edge in enumerate(edges):
            score_dict[edge[0], edge[1]] = scores[i]
    elif pool == 'avg':
        for i, edge in enumerate(edges):
            # Work on a local value instead of writing back into the
            # caller's ``scores``.
            score = scores[i]
            if bboxes is not None:
                box1 = bboxes[edge[0]][:8].reshape(4, 2)
                box2 = bboxes[edge[1]][:8].reshape(4, 2)
                center1 = np.mean(box1, axis=0)
                center2 = np.mean(box2, axis=0)
                dst = norm(center1 - center2)
                if dst > dis_thr:
                    score = 0
            key = (edge[0], edge[1])
            if key in score_dict:
                score_dict[key] = 0.5 * (score_dict[key] + score)
            else:
                score_dict[key] = score
    elif pool == 'max':
        for i, edge in enumerate(edges):
            key = (edge[0], edge[1])
            if key in score_dict:
                score_dict[key] = max(score_dict[key], scores[i])
            else:
                score_dict[key] = scores[i]
    else:
        raise ValueError('Pooling operation not supported')

    # ``np.unique`` already returns the node indices sorted.
    nodes = np.unique(edges.flatten())
    # Map every node index to its position in ``nodes``. The builtin
    # ``int`` is used because the ``np.int`` alias was removed from NumPy.
    mapping = -1 * np.ones((nodes.max() + 1), dtype=int)
    mapping[nodes] = np.arange(nodes.shape[0])
    link_inx = mapping[edges]

    vertex = [Node(node) for node in nodes]
    for link, score in zip(link_inx, scores):
        vertex[link[0]].add_link(vertex[link[1]], score)

    return connected_components(vertex, score_dict, link_thr)
def in_contour(cont, point):
    """Tell whether a point lies strictly inside a contour.

    Args:
        cont (ndarray): The contour passed to ``cv2.pointPolygonTest``.
        point (sequence): The ``(x, y)`` point to test.

    Returns:
        bool: True when the test result is positive. A point on the
            contour or outside of it gives False.
    """
    x, y = point
    # Pass plain Python floats: some OpenCV Python builds reject NumPy
    # scalar types for the point argument.
    return cv2.pointPolygonTest(cont, (float(x), float(y)), False) > 0
def select_edge(cont, box):
    """Pick the side of a box whose mid point is not inside a contour.

    This is from repo https://github.com/GXYM/DRRG.

    Args:
        cont (array-like): The contour points.
        box (ndarray): Four box points, of shape (4, 2).

    Returns:
        list | None: ``[point 0, point 3]`` when their mid point is not
            inside the contour, else ``[point 1, point 2]`` when their
            mid point is not inside, else None.
    """
    cont = np.array(cont)
    box = box.astype(np.int32)
    # The builtin ``int`` replaces the ``np.int`` alias that was removed
    # from NumPy.
    c1 = np.array(0.5 * (box[0, :] + box[3, :]), dtype=int)
    c2 = np.array(0.5 * (box[1, :] + box[2, :]), dtype=int)

    if not in_contour(cont, c1):
        return [box[0, :].tolist(), box[3, :].tolist()]
    if not in_contour(cont, c2):
        return [box[1, :].tolist(), box[2, :].tolist()]
    return None
def comps2boundary(text_comps, final_pred):
    """Turn labelled text components into one boundary per text instance.

    This is from repo https://github.com/GXYM/DRRG.

    Args:
        text_comps (ndarray): The text components; the first 8 columns of
            a row are the four box points.
        final_pred (ndarray): The instance label of every component.

    Returns:
        list[list[int]]: One flattened list of boundary coordinates per
            instance that owns at least one component. Empty when there
            are no components.
    """
    bbox_contours = []
    # ``np.max`` raises on an empty array, so return early.
    if np.size(final_pred) == 0:
        return bbox_contours

    for inx in range(0, int(np.max(final_pred)) + 1):
        current_instance = np.where(final_pred == inx)[0]
        boxes = text_comps[current_instance, :8].reshape(
            (-1, 4, 2)).astype(np.int32)

        if boxes.shape[0] > 1:
            # Order the boxes along the text line, then walk the top mid
            # points forwards and the bottom mid points backwards.
            centers = np.mean(boxes, axis=1).astype(np.int32).tolist()
            paths, routes_path = min_connect_path(centers)
            boxes = boxes[routes_path]
            top = np.mean(boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
            bot = np.mean(boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
            edge1 = select_edge(top + bot[::-1], boxes[0])
            edge2 = select_edge(top + bot[::-1], boxes[-1])
            if edge1 is not None:
                top.insert(0, edge1[0])
                bot.insert(0, edge1[1])
            if edge2 is not None:
                top.append(edge2[0])
                bot.append(edge2[1])
            boundary_point = np.array(top + bot[::-1])
        elif boxes.shape[0] == 1:
            # A single box is its own boundary: points 0-1 then 2-3.
            # (A ``2:4:-1`` slice is always empty and drops points 2-3.)
            top = boxes[0, 0:2, :].astype(np.int32).tolist()
            bot = boxes[0, 2:4, :].astype(np.int32).tolist()
            boundary_point = np.array(top + bot)
        else:
            # No component carries this label.
            continue

        bbox_contours.append(boundary_point.flatten().tolist())

    return bbox_contours
def merge_text_comps(edges, scores, text_comps, link_thr):
    """Merge text components into text instance boundaries.

    The edges are clustered, every component receives the label of its
    cluster, isolated components are dropped and the remaining ones are
    converted into boundaries.

    Args:
        edges (ndarray): Pairs of linked component indices.
        scores (ndarray): The score of every edge.
        text_comps (ndarray): The text components, one row per component.
        link_thr (float): The edge score threshold used for clustering.

    Returns:
        list: The boundaries produced by ``comps2boundary``.
    """
    node_num = text_comps.shape[0]
    labels = clusters2labels(
        graph_propagation(edges, scores, link_thr), node_num)
    kept_comps, kept_labels = remove_single(text_comps, labels)
    return comps2boundary(kept_comps, kept_labels)

View File

@ -1,5 +1,5 @@
from .ce_loss import CELoss, SARLoss, TFLoss
from .ctc_loss import CTCLoss
from .seg_loss import CAFCNLoss, SegLoss
from .seg_loss import SegLoss
__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'CAFCNLoss']
__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss']

View File

@ -3,7 +3,6 @@ import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.builder import LOSSES
from mmocr.models.common.losses import DiceLoss
@LOSSES.register_module()
@ -68,109 +67,3 @@ class SegLoss(nn.Module):
losses['loss_seg'] = loss_seg
return losses
@LOSSES.register_module()
class CAFCNLoss(SegLoss):
    """Implementation of loss module in `CA-FCN.

    <https://arxiv.org/pdf/1809.06508.pdf>`_

    Args:
        alpha (float): Weight ratio of attention loss; the attention loss
            is multiplied by it.
        attn_s2_downsample_ratio (float): Downsample ratio
            of attention map from output stage 2.
        attn_s3_downsample_ratio (float): Downsample ratio
            of attention map from output stage 3.
        seg_downsample_ratio (float): Downsample ratio of
            segmentation map.
        attn_with_dice_loss (bool): If True, use dice_loss for attention,
            else BCELoss.
        with_attn (bool): If True, include attention loss, else
            segmentation loss only.
        seg_with_loss_weight (bool): If True, set weight for
            segmentation loss.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self,
                 alpha=1.0,
                 attn_s2_downsample_ratio=0.25,
                 attn_s3_downsample_ratio=0.125,
                 seg_downsample_ratio=0.5,
                 attn_with_dice_loss=False,
                 with_attn=True,
                 seg_with_loss_weight=True,
                 ignore_index=255):
        super().__init__(seg_downsample_ratio, seg_with_loss_weight,
                         ignore_index)
        assert isinstance(alpha, (int, float))
        assert isinstance(attn_s2_downsample_ratio, (int, float))
        assert isinstance(attn_s3_downsample_ratio, (int, float))
        assert 0 < attn_s2_downsample_ratio <= 1
        assert 0 < attn_s3_downsample_ratio <= 1

        self.alpha = alpha
        self.attn_s2_downsample_ratio = attn_s2_downsample_ratio
        self.attn_s3_downsample_ratio = attn_s3_downsample_ratio
        self.with_attn = with_attn
        self.attn_with_dice_loss = attn_with_dice_loss

        # attention loss
        if with_attn:
            if attn_with_dice_loss:
                self.criterion_attn = DiceLoss()
            else:
                self.criterion_attn = nn.BCELoss(reduction='none')

    @staticmethod
    def _stack_rescaled(gt_kernels, index, ratio, device):
        """Rescale entry ``index`` of every sample and stack the results."""
        return torch.stack([
            item[index].rescale(ratio).to_tensor(torch.float, device)
            for item in gt_kernels
        ])

    def attn_loss(self, out_neck, gt_kernels):
        """Sum the attention losses of all but the last neck output.

        Args:
            out_neck (tuple(tensor)): Neck outputs; every element except
                the last one is treated as an attention map.
            gt_kernels (list): Per-sample ground truth. Entry 0 is used
                as attention target and entry 2 as mask.

        Returns:
            tensor: The summed (unweighted) attention loss.
        """
        device = out_neck[0].device  # bsz * 2 * H/4 * W/4
        s2_ratio = self.attn_s2_downsample_ratio
        s3_ratio = self.attn_s3_downsample_ratio

        mask_s2 = self._stack_rescaled(gt_kernels, 2, s2_ratio, device)
        attn_target_s2 = self._stack_rescaled(gt_kernels, 0, s2_ratio, device)
        mask_s3 = self._stack_rescaled(gt_kernels, 2, s3_ratio, device)
        attn_target_s3 = self._stack_rescaled(gt_kernels, 0, s3_ratio, device)

        # The first output uses the stage-2 scale, the others stage 3.
        targets = [
            attn_target_s2, attn_target_s3, attn_target_s3, attn_target_s3
        ]
        masks = [mask_s2, mask_s3, mask_s3, mask_s3]

        loss_attn = 0.
        for i in range(len(out_neck) - 1):
            pred = out_neck[i]
            if self.attn_with_dice_loss:
                loss_attn += self.criterion_attn(pred, targets[i], masks[i])
            else:
                # Masked mean of the element-wise BCE.
                loss_attn += torch.sum(
                    self.criterion_attn(pred, targets[i]) *
                    masks[i]) / torch.sum(masks[i])

        return loss_attn

    def forward(self, out_neck, out_head, gt_kernels):
        """Compute the segmentation loss and, optionally, attention loss.

        Returns:
            dict: The losses of ``SegLoss.forward`` plus ``loss_attn``
                (weighted by ``alpha``) when ``with_attn`` is True.
        """
        losses = super().forward(out_neck, out_head, gt_kernels)
        if self.with_attn:
            # Apply the documented attention weight. With the default
            # alpha of 1.0 the value is unchanged.
            losses['loss_attn'] = self.alpha * self.attn_loss(
                out_neck, gt_kernels)
        return losses

View File

@ -1,5 +1,3 @@
from .cafcn_neck import CAFCNNeck
from .fpn_ocr import FPNOCR
from .fpn_seg import FPNSeg
__all__ = ['CAFCNNeck', 'FPNSeg', 'FPNOCR']
__all__ = ['FPNOCR']

View File

@ -1,223 +0,0 @@
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2dPack
from torch import nn
from mmdet.models.builder import NECKS
class CharAttn(nn.Module):
    """Character attention module of `CA-FCN.

    <https://arxiv.org/pdf/1809.06508.pdf>`_

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        deformable (bool): If True, the feature branch uses a deformable
            conv, else a plain ``nn.Conv2d``.
    """

    def __init__(self, in_channels=128, out_channels=128, deformable=False):
        super().__init__()
        assert isinstance(in_channels, int)
        assert isinstance(deformable, bool)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deformable = deformable

        # Attention branch: two 3x3 conv blocks, the second one producing
        # a single channel through a sigmoid.
        shared_cfg = dict(stride=1, padding=1, norm_cfg=dict(type='BN'))
        self.attn_layer = nn.Sequential(
            ConvModule(in_channels, in_channels, 3, **shared_cfg),
            ConvModule(
                in_channels, 1, 3, act_cfg=dict(type='Sigmoid'),
                **shared_cfg))

        # Feature branch: conv -> BN -> ReLU.
        conv_cls = DeformConv2dPack if self.deformable else nn.Conv2d
        self.conv = conv_cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=1,
            padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, in_feat):
        """Return the attention-weighted features and the attention map."""
        attn_map = self.attn_layer(in_feat)  # N * 1 * H * W
        feat = self.relu(self.bn(self.conv(in_feat)))
        # Scale the features by (1 + attention).
        return self._upsample_mul(feat, 1 + attn_map), attn_map

    def _upsample_add(self, x, y):
        """Resize ``x`` to the spatial size of ``y`` and add ``y``."""
        target_size = y.size()[2:]
        return F.interpolate(x, size=target_size) + y

    def _upsample_mul(self, x, y):
        """Resize ``x`` to the spatial size of ``y`` and multiply by it."""
        target_size = y.size()[2:]
        return F.interpolate(x, size=target_size) * y
class FeatGenerator(nn.Module):
    """Generate attention-augmented stage feature from backbone stage
    feature.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        deformable (bool): Forwarded to ``CharAttn``; it also switches
            the final conv from a 3x3 to a (3, 1) kernel.
        concat (bool): If True, a second feature map is convolved and
            concatenated before the final conv.
        upsample (bool): If True, the output is resized to the size given
            at call time.
        with_attn (bool): If True, use ``CharAttn``, else a plain conv
            block whose output doubles as the attention map.
    """

    def __init__(self,
                 in_channels=512,
                 out_channels=128,
                 deformable=True,
                 concat=False,
                 upsample=False,
                 with_attn=True):
        super().__init__()

        self.concat = concat
        self.upsample = upsample
        self.with_attn = with_attn

        bn_cfg = dict(type='BN')

        if with_attn:
            self.char_attn = CharAttn(
                in_channels=in_channels,
                out_channels=out_channels,
                deformable=deformable)
        else:
            self.char_attn = ConvModule(
                in_channels,
                out_channels,
                3,
                stride=1,
                padding=1,
                norm_cfg=bn_cfg)

        if concat:
            self.conv_to_concat = ConvModule(
                out_channels,
                out_channels,
                3,
                stride=1,
                padding=1,
                norm_cfg=bn_cfg)

        if deformable:
            kernel_size, padding = (3, 1), (1, 0)
        else:
            kernel_size, padding = 3, 1
        # After concatenation the channel count is doubled.
        fused_channels = out_channels * 2 if concat else out_channels
        self.conv_after_concat = ConvModule(
            fused_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=padding,
            norm_cfg=bn_cfg)

    def forward(self, x, y=None, size=None):
        """Return ``(attn_map, feat_map)`` for one stage.

        Args:
            x (tensor): The backbone feature of this stage.
            y (tensor | None): The feature to concatenate; used only when
                ``concat`` is True.
            size (tuple | None): The output size; used only when
                ``upsample`` is True.
        """
        if self.with_attn:
            feat_map, attn_map = self.char_attn(x)
        else:
            feat_map = attn_map = self.char_attn(x)

        if self.concat:
            feat_map = torch.cat((self.conv_to_concat(y), feat_map), dim=1)

        feat_map = self.conv_after_concat(feat_map)

        if self.upsample:
            feat_map = F.interpolate(feat_map, size)

        return attn_map, feat_map
@NECKS.register_module()
class CAFCNNeck(nn.Module):
    """Implementation of neck module in `CA-FCN.

    <https://arxiv.org/pdf/1809.06508.pdf>`_

    Args:
        in_channels (list[int]): Number of input channels for each scale.
        out_channels (int): Number of output channels for each scale.
        deformable (bool): If True, use deformable conv in the two
            deepest stages (stages 3 and 2 never use it).
        with_attn (bool): If True, add attention for each output feature map.
    """

    def __init__(self,
                 in_channels=[128, 256, 512, 512],
                 out_channels=128,
                 deformable=True,
                 with_attn=True):
        super().__init__()

        self.deformable = deformable
        self.with_attn = with_attn

        def build_stage(channel_inx, **kwargs):
            # Settings shared by every stage.
            return FeatGenerator(
                in_channels=in_channels[channel_inx],
                out_channels=out_channels,
                with_attn=with_attn,
                **kwargs)

        # stage_in5_to_out5
        self.s5 = build_stage(-1, deformable=deformable, concat=False)
        # stage_in4_to_out4
        self.s4 = build_stage(-2, deformable=deformable, concat=True)
        # stage_in3_to_out3
        self.s3 = build_stage(
            -3, deformable=False, concat=True, upsample=True)
        # stage_in2_to_out2
        self.s2 = build_stage(
            -4, deformable=False, concat=True, upsample=True)

    def init_weights(self):
        """No custom weight initialization is performed."""
        pass

    def forward(self, feats):
        """Run the top-down path over five backbone features.

        Args:
            feats (tuple(tensor)): Features of stages 1 to 5.

        Returns:
            tuple(tensor): The attention maps of stages 2, 3, 4 and 5
                followed by the output feature of stage 2.
        """
        in_stage1, in_stage2, in_stage3, in_stage4, in_stage5 = (
            feats[0], feats[1], feats[2], feats[3], feats[4])

        # Deepest stage first; each result feeds the next, shallower one.
        attn_s5, out_s5 = self.s5(in_stage5)
        attn_s4, out_s4 = self.s4(in_stage4, out_s5)
        # Stage 3 is resized to the stage-2 input, stage 2 to stage 1.
        attn_s3, out_s3 = self.s3(in_stage3, out_s4, in_stage2.size()[2:])
        attn_s2, out_s2 = self.s2(in_stage2, out_s3, in_stage1.size()[2:])

        return (attn_s2, attn_s3, attn_s4, attn_s5, out_s2)

View File

@ -1,43 +0,0 @@
import torch.nn.functional as F
from mmcv.runner import auto_fp16
from mmdet.models.builder import NECKS
from mmdet.models.necks import FPN
@NECKS.register_module()
class FPNSeg(FPN):
    """Feature Pyramid Network for segmentation based text recognition.

    Args:
        in_channels (list[int]): Number of input channels for each scale.
        out_channels (int): Number of output channels for each scale.
        num_outs (int): Number of output scales.
        upsample_param (dict | None): Keyword arguments for the
            interpolation of the first output, e.g.
            `dict(scale_factor=1.0, mode='nearest')`. With None (the
            default) no interpolation is done.
        last_stage_only (bool): If True, return only the first FPN
            output (after the optional interpolation).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 upsample_param=None,
                 last_stage_only=True):
        super().__init__(in_channels, out_channels, num_outs)
        self.upsample_param = upsample_param
        self.last_stage_only = last_stage_only

    @auto_fp16()
    def forward(self, inputs):
        """Return a 1-tuple with the first FPN output, or all outputs in
        reversed order when ``last_stage_only`` is False."""
        outs = list(super().forward(inputs))

        if self.upsample_param is not None:
            outs[0] = F.interpolate(outs[0], **self.upsample_param)

        if self.last_stage_only:
            return tuple(outs[0:1])
        return tuple(outs[::-1])

View File

@ -1,5 +1,4 @@
from .base import BaseRecognizer
from .cafcn import CAFCNNet
from .crnn import CRNNNet
from .encode_decode_recognizer import EncodeDecodeRecognizer
from .nrtr import NRTR
@ -9,5 +8,5 @@ from .seg_recognizer import SegRecognizer
__all__ = [
'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR',
'SegRecognizer', 'RobustScanner', 'CAFCNNet'
'SegRecognizer', 'RobustScanner'
]

View File

@ -1,7 +0,0 @@
from mmdet.models.builder import DETECTORS
from .seg_recognizer import SegRecognizer
@DETECTORS.register_module()
class CAFCNNet(SegRecognizer):
    """Recognizer for `CAFCN <https://arxiv.org/pdf/1809.06508.pdf>`_.

    The class adds nothing to ``SegRecognizer``; it only registers the
    recognizer under its own name.
    """

View File

@ -1,11 +1,8 @@
import numpy as np
import pytest
import torch
from mmdet.core import BitmapMasks
from mmocr.models.common.losses import DiceLoss
from mmocr.models.textrecog.losses import (CAFCNLoss, CELoss, CTCLoss, SARLoss,
TFLoss)
from mmocr.models.textrecog.losses import CELoss, CTCLoss, SARLoss, TFLoss
def test_ctc_loss():
@ -69,44 +66,6 @@ def test_tf_loss():
assert new_target.shape == torch.Size([1, 9])
def test_cafcn_loss():
    """CAFCNLoss rejects invalid arguments and returns a loss dict."""
    invalid_kwargs = [
        dict(alpha='1'),
        dict(attn_s2_downsample_ratio='2'),
        dict(attn_s3_downsample_ratio='1.5'),
        dict(seg_downsample_ratio='1.5'),
        dict(attn_s2_downsample_ratio=2),
        dict(attn_s3_downsample_ratio=1.5),
        dict(seg_downsample_ratio=1.5),
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            CAFCNLoss(**kwargs)

    bsz = 1
    H = W = 64

    # Four attention maps (1/4, then three times 1/8 of the input size)
    # followed by the 1/2-size feature map.
    out_neck = tuple(
        torch.ones(bsz, 1, H // stride, W // stride) * 0.5
        for stride in (4, 8, 8, 8, 2))
    out_head = torch.rand(bsz, 37, H // 2, W // 2)

    attn_tgt = np.zeros((H, W), dtype=np.float32)
    segm_tgt = np.zeros((H, W), dtype=np.float32)
    mask = np.ones((H, W), dtype=np.float32)
    gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], H, W)

    losses = CAFCNLoss()(out_neck, out_head, [gt_kernels])

    assert isinstance(losses, dict)
    assert 'loss_seg' in losses
    loss_seg = losses['loss_seg']
    assert torch.allclose(loss_seg, torch.tensor(loss_seg.item()).float())
def test_dice_loss():
with pytest.raises(AssertionError):
DiceLoss(eps='1')

View File

@ -1,47 +0,0 @@
import pytest
import torch
from mmocr.models.textrecog.necks.cafcn_neck import (CAFCNNeck, CharAttn,
FeatGenerator)
def test_char_attn():
    """CharAttn validates its arguments and keeps the spatial size."""
    for bad_kwargs in (dict(in_channels=5.0), dict(deformable='deformabel')):
        with pytest.raises(AssertionError):
            CharAttn(**bad_kwargs)

    in_feat = torch.rand(1, 128, 32, 32)
    module = CharAttn()
    out_feat_map, attn_map = module(in_feat)

    assert attn_map.shape == torch.Size([1, 1, 32, 32])
    assert out_feat_map.shape == torch.Size([1, 128, 32, 32])
def test_feat_generator():
    """FeatGenerator returns a 1-channel attention map and the features."""
    in_feat = torch.rand(1, 128, 32, 32)
    generator = FeatGenerator(
        in_channels=128, out_channels=128, deformable=False)

    attn_map, feat_map = generator(in_feat)

    expected_shapes = ((attn_map, [1, 1, 32, 32]),
                       (feat_map, [1, 128, 32, 32]))
    for tensor, shape in expected_shapes:
        assert tensor.shape == torch.Size(shape)
def test_cafcn_neck():
    """CAFCNNeck yields four attention maps and the stage-2 feature."""
    input_shapes = [
        (1, 64, 64, 64),
        (1, 128, 32, 32),
        (1, 256, 16, 16),
        (1, 512, 16, 16),
        (1, 512, 16, 16),
    ]
    feats = tuple(torch.rand(*shape) for shape in input_shapes)

    cafcn_neck = CAFCNNeck(deformable=False)
    cafcn_neck.init_weights()
    cafcn_neck.train()

    out_neck = cafcn_neck(feats)

    expected_shapes = [
        [1, 1, 32, 32],
        [1, 1, 16, 16],
        [1, 1, 16, 16],
        [1, 1, 16, 16],
        [1, 128, 64, 64],
    ]
    for inx, shape in enumerate(expected_shapes):
        assert out_neck[inx].shape == torch.Size(shape)