diff --git a/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py b/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py new file mode 100644 index 00000000..41cb5ec1 --- /dev/null +++ b/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py @@ -0,0 +1,566 @@ +import cv2 +import numpy as np +from numpy.linalg import norm + +import mmocr.utils.check_argument as check_argument +from mmdet.core import BitmapMasks +from mmdet.datasets.builder import PIPELINES +# from mmocr.models.textdet import la_nms +from .textsnake_targets import TextSnakeTargets + + +@PIPELINES.register_module() +class DRRGTargets(TextSnakeTargets): + """Generate the ground truth targets of DRRG: Deep Relational Reasoning + Graph Network for Arbitrary Shape Text Detection. + + [https://arxiv.org/abs/2003.07493]. This was partially adapted from + https://github.com/GXYM/DRRG. + + Args: + orientation_thr (float): The threshold for distinguishing between + head edge and tail edge among the horizontal and vertical edges + of a quadrangle. + resample_step (float): The step size for resampling the text center + line (TCL). Better not exceed half of the minimum width + of the text component. + min_comp_num (int): The minimum number of text components, which + should be k_hop1 + 1 on graph. + max_comp_num (int): The maximum number of text components. + min_width (float): The minimum width of text components. + max_width (float): The maximum width of text components. + center_region_shrink_ratio (float): The shrink ratio of text center + region. + comp_shrink_ratio (float): The shrink ratio of text components. + text_comp_ratio (float): The reciprocal of aspect ratio of text + components. + min_rand_half_height(float): The minimum half-height of random text + components. + max_rand_half_height (float): The maximum half-height of random + text components. + jitter_level (float): The jitter level of text components geometric + features. 
def dist_point2line(self, pnt, line):
    """Compute the distance from one or more points to a straight line.

    Args:
        pnt (ndarray): A point of shape (2,) or a batch of points of
            shape (N, 2).
        line (tuple(ndarray, ndarray)): Two points, each of shape (2,),
            that define the line.

    Returns:
        d (float | ndarray): The perpendicular distance of every point in
            ``pnt`` to the line; a scalar for a single point, otherwise
            an array of shape (N,).
    """

    assert isinstance(line, tuple)
    pnt1, pnt2 = line
    direction = np.asarray(pnt2) - np.asarray(pnt1)
    offset = np.asarray(pnt) - np.asarray(pnt1)
    # The scalar cross product is written out by hand because calling
    # np.cross on 2-element vectors is deprecated since NumPy 2.0.
    cross = (direction[..., 0] * offset[..., 1] -
             direction[..., 1] * offset[..., 0])
    # The epsilon keeps a degenerate line (pnt1 == pnt2) from dividing by 0.
    d = np.abs(cross) / (norm(direction, axis=-1) + 1e-8)
    return d
+ sin_map (ndarray): The map of vector_sin(top_point -bot_point) + that will be drawn on text center region. + cos_map (ndarray): The map of vector_cos(top_point -bot_point) + will be drawn on text center region. + region_shrink_ratio (float): The shrink ratio of text center. + """ + + assert top_line.shape == bot_line.shape == center_line.shape + assert (center_region_mask.shape == top_height_map.shape == + bot_height_map.shape == sin_map.shape == cos_map.shape) + assert isinstance(region_shrink_ratio, float) + + h, w = center_region_mask.shape + for i in range(0, len(center_line) - 1): + + top_mid_point = (top_line[i] + top_line[i + 1]) / 2 + bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2 + + sin_theta = self.vector_sin(top_mid_point - bot_mid_point) + cos_theta = self.vector_cos(top_mid_point - bot_mid_point) + + pnt_tl = center_line[i] + (top_line[i] - + center_line[i]) * region_shrink_ratio + pnt_tr = center_line[i + 1] + ( + top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio + pnt_br = center_line[i + 1] + ( + bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio + pnt_bl = center_line[i] + (bot_line[i] - + center_line[i]) * region_shrink_ratio + current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br, + pnt_bl]).astype(np.int32) + + cv2.fillPoly(center_region_mask, [current_center_box], color=1) + cv2.fillPoly(sin_map, [current_center_box], color=sin_theta) + cv2.fillPoly(cos_map, [current_center_box], color=cos_theta) + + # x,y order + current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0, + w - 1) + current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0, + h - 1) + min_coord = np.min(current_center_box, axis=0).astype(np.int32) + max_coord = np.max(current_center_box, axis=0).astype(np.int32) + current_center_box = current_center_box - min_coord + sz = (max_coord - min_coord + 1) + + center_box_mask = np.zeros((sz[1], sz[0]), dtype=np.uint8) + cv2.fillPoly(center_box_mask, [current_center_box], color=1) + + inx = 
def generate_center_mask_attrib_maps(self, img_size, text_polys):
    """Build the text center region mask and its geometry attribute maps.

    Args:
        img_size (tuple): The image size of (height, width).
        text_polys (list[list[ndarray]]): The list of text polygons.

    Returns:
        center_lines (list): The list of text center lines.
        center_region_mask (ndarray): The text center region mask.
        top_height_map (ndarray): The distance map from each pixel in text
            center region to top sideline.
        bot_height_map (ndarray): The distance map from each pixel in text
            center region to bottom sideline.
        sin_map (ndarray): The sin(theta) map where theta is the angle
            between vector (top point - bottom point) and vector (1, 0).
        cos_map (ndarray): The cos(theta) map where theta is the angle
            between vector (top point - bottom point) and vector (1, 0).
    """

    assert isinstance(img_size, tuple)
    assert check_argument.is_2dlist(text_polys)

    def end_shrink_count(top_pnt, bot_pnt):
        # Number of resampled points to drop at one end of the center
        # line: half of the clipped component width, in resample steps.
        shrink_len = np.clip(
            norm(top_pnt - bot_pnt) * self.text_comp_ratio,
            self.min_width, self.max_width) / 2
        return int(shrink_len // self.resample_step)

    height, width = img_size
    center_region_mask = np.zeros((height, width), np.uint8)
    top_height_map, bot_height_map, sin_map, cos_map = (
        np.zeros((height, width), dtype=np.float32) for _ in range(4))
    center_lines = []

    for poly in text_polys:
        assert len(poly) == 1
        outline = poly[0].reshape(-1, 2)
        _, _, top_line, bot_line = self.reorder_poly_edge(outline)
        top_pnts, bot_pnts = self.resample_sidelines(
            top_line, bot_line, self.resample_step)
        # Reverse the bottom sideline so that index k on both resampled
        # lines can be averaged into one center-line point.
        bot_pnts = bot_pnts[::-1]
        center_pnts = (top_pnts + bot_pnts) / 2

        # NOTE(review): the end heights are measured on the raw,
        # un-reversed sidelines, so bot_line[0] may lie at the opposite
        # end of the text from top_line[0]. TextSnakeTargets measures
        # them on the resampled lines after the reversal - confirm that
        # this difference is intended.
        head_cut = end_shrink_count(top_line[0], bot_line[0])
        tail_cut = end_shrink_count(top_line[-1], bot_line[-1])

        # Trim both ends only when enough points remain afterwards.
        if len(center_pnts) > head_cut + tail_cut + 2:
            keep = slice(head_cut, len(center_pnts) - tail_cut)
            center_pnts = center_pnts[keep]
            top_pnts = top_pnts[keep]
            bot_pnts = bot_pnts[keep]
        center_lines.append(center_pnts.astype(np.int32))

        self.draw_center_region_maps(top_pnts, bot_pnts, center_pnts,
                                     center_region_mask, top_height_map,
                                     bot_height_map, sin_map, cos_map,
                                     self.center_region_shrink_ratio)

    return (center_lines, center_region_mask, top_height_map,
            bot_height_map, sin_map, cos_map)
cos_map, comp_shrink_ratio): + """Generate attributes of text components in accordance with text + center lines and geometry attribute maps. + + Args: + center_lines (list[ndarray]): The list of text center lines. + center_region_mask (ndarray): The text center region mask. + top_height_map (ndarray): The distance map from each pixel in text + center region to top sideline. + bot_height_map (ndarray): The distance map from each pixel in text + center region to bottom sideline. + sin_map (ndarray): The sin(theta) map where theta is the angle + between vector (top point - bottom point) and vector (1, 0). + cos_map (ndarray): The cos(theta) map where theta is the angle + between vector (top point - bottom point) and vector (1, 0). + comp_shrink_ratio (float): The text component shrink ratio. + + Returns: + comp_attribs (ndarray): All text components attributes(x, y, h, w, + cos, sin, comp_labels). The comp_labels of two text components + from the same text instance are equal. + """ + + assert isinstance(center_lines, list) + assert (center_region_mask.shape == top_height_map.shape == + bot_height_map.shape == sin_map.shape == cos_map.shape) + assert isinstance(comp_shrink_ratio, float) + + center_lines_mask = np.zeros_like(center_region_mask) + cv2.polylines(center_lines_mask, center_lines, 0, (1, ), 1) + center_lines_mask = center_lines_mask * center_region_mask + comp_centers = np.argwhere(center_lines_mask > 0) + + y = comp_centers[:, 0] + x = comp_centers[:, 1] + + top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio + bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio + sin = sin_map[y, x].reshape((-1, 1)) + cos = cos_map[y, x].reshape((-1, 1)) + + top_mid_x_offset = top_height * cos + top_mid_y_offset = top_height * sin + bot_mid_x_offset = bot_height * cos + bot_mid_y_offset = bot_height * sin + + top_mid_pnt = comp_centers + np.hstack( + [top_mid_y_offset, top_mid_x_offset]) + bot_mid_pnt = comp_centers - np.hstack( + 
[bot_mid_y_offset, bot_mid_x_offset]) + + width = (top_height + bot_height) * self.text_comp_ratio + width = np.clip(width, self.min_width, self.max_width) + + top_left = (top_mid_pnt - + np.hstack([width * cos, -width * sin]))[:, ::-1] + top_right = (top_mid_pnt + + np.hstack([width * cos, -width * sin]))[:, ::-1] + bot_right = (bot_mid_pnt + + np.hstack([width * cos, -width * sin]))[:, ::-1] + bot_left = (bot_mid_pnt - + np.hstack([width * cos, -width * sin]))[:, ::-1] + text_comps = np.hstack([top_left, top_right, bot_right, bot_left]) + + score = center_lines_mask[y, x].reshape((-1, 1)) + text_comps = np.hstack([text_comps, score]).astype(np.float32) + # text_comps = la_nms(text_comps, self.text_comp_nms_thr) + + if text_comps.shape[0] < 1: + return None + + img_h, img_w = center_region_mask.shape + text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1) + text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1) + + comp_centers = np.mean( + text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1).astype(np.int32) + x = comp_centers[:, 0] + y = comp_centers[:, 1] + + height = (top_height_map[y, x] + bot_height_map[y, x]).reshape((-1, 1)) + width = np.clip(height * self.text_comp_ratio, self.min_width, + self.max_width) + + cos = cos_map[y, x].reshape((-1, 1)) + sin = sin_map[y, x].reshape((-1, 1)) + + _, comp_label_mask = cv2.connectedComponents( + center_region_mask.astype(np.uint8), connectivity=8) + comp_labels = comp_label_mask[y, x].reshape((-1, 1)) + + x = x.reshape((-1, 1)) + y = y.reshape((-1, 1)) + comp_attribs = np.hstack([x, y, height, width, cos, sin, comp_labels]) + + return comp_attribs + + def generate_rand_comp_attribs(self, comp_num, center_sample_mask): + """Generate random text components and their attributes to ensure the + the number text components in a text image is larger than k_hop1, which + is the number of one hop neighbors in KNN graph. + + Args: + comp_num (int): The number of random text components. 
+ center_sample_mask (ndarray): The text component centers sampling + region mask. + + Returns: + rand_comp_attribs (ndarray): The random text components + attributes(x, y, h, w, cos, sin, belong_instance_label=0). + """ + + assert isinstance(comp_num, int) + assert comp_num > 0 + assert center_sample_mask.ndim == 2 + + h, w = center_sample_mask.shape + + max_rand_half_height = self.max_rand_half_height + min_rand_half_height = self.min_rand_half_height + max_rand_height = max_rand_half_height * 2 + max_rand_width = np.clip(max_rand_height * self.text_comp_ratio, + self.min_width, self.max_width) + margin = int( + np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1 + + if 2 * margin + 1 > min(h, w): + + assert min(h, w) > (np.sqrt(2) * (self.min_width + 1)) + max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1) + min_rand_half_height = max(max_rand_half_height / 4, + self.min_width / 2) + + max_rand_height = max_rand_half_height * 2 + max_rand_width = np.clip(max_rand_height * self.text_comp_ratio, + self.min_width, self.max_width) + margin = int( + np.sqrt((max_rand_height / 2)**2 + + (max_rand_width / 2)**2)) + 1 + + inner_center_sample_mask = np.zeros_like(center_sample_mask) + inner_center_sample_mask[margin:h-margin, margin:w-margin] = \ + center_sample_mask[margin:h - margin, margin:w - margin] + kernel_size = int(min(max(max_rand_half_height, 7), 17)) + inner_center_sample_mask = cv2.erode( + inner_center_sample_mask, + np.ones((kernel_size, kernel_size), np.uint8)) + + center_candidates = np.argwhere(inner_center_sample_mask > 0) + center_candidate_num = len(center_candidates) + sample_inx = np.random.choice(center_candidate_num, comp_num) + rand_centers = center_candidates[sample_inx] + + rand_top_height = np.random.randint( + min_rand_half_height, + max_rand_half_height, + size=(len(rand_centers), 1)) + rand_bot_height = np.random.randint( + min_rand_half_height, + max_rand_half_height, + size=(len(rand_centers), 1)) + + rand_cos 
def jitter_comp_attribs(self, comp_attribs, jitter_level):
    """Jitter text components attributes.

    The input array is left untouched; all perturbations are applied to a
    floating-point copy, so integer-valued attribute arrays are accepted
    as well.

    Args:
        comp_attribs (ndarray): The text components attributes of shape
            (N, 7), laid out as (x, y, h, w, cos, sin,
            belong_instance_label).
        jitter_level (float): The jitter level of text components
            attributes.

    Returns:
        jittered_comp_attribs (ndarray): The jittered text components
            attributes (x, y, h, w, cos, sin, belong_instance_label). The
            label column is copied unchanged and (cos, sin) is
            re-normalized to unit length.
    """

    assert comp_attribs.shape[1] == 7
    assert comp_attribs.shape[0] > 0
    assert isinstance(jitter_level, float)

    # Work on a private copy: the in-place updates below must neither
    # leak into the caller's array nor fail on an integer dtype.
    if comp_attribs.dtype.kind == 'f':
        attribs = comp_attribs.copy()
    else:
        attribs = comp_attribs.astype(np.float64)
    # Column views of shape (N, 1) into the private copy.
    x, y, h, w, cos, sin, _ = (attribs[:, i:i + 1] for i in range(7))

    def centered_noise():
        # Uniform noise in [-0.5, 0.5), one value per component.
        return np.random.random(size=(len(attribs), 1)) - 0.5

    # max jitter offset of (x, y) should be
    # ((h * abs(cos) + w * abs(sin)) / 2,
    #  (h * abs(sin) + w * abs(cos)) / 2)
    x += centered_noise() * (h * np.abs(cos) +
                             w * np.abs(sin)) * jitter_level
    y += centered_noise() * (h * np.abs(sin) +
                             w * np.abs(cos)) * jitter_level

    # max jitter offset of (h, w) should be (h, w)
    h += centered_noise() * h * jitter_level
    w += centered_noise() * w * jitter_level

    # max jitter offset of (cos, sin) should be (1, 1)
    cos += centered_noise() * 2 * jitter_level
    sin += centered_noise() * 2 * jitter_level

    # Project the perturbed direction back onto the unit circle.
    scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
    cos *= scale
    sin *= scale

    return attribs
+ """ + + assert isinstance(center_lines, list) + assert (text_mask.shape == center_region_mask.shape == + top_height_map.shape == bot_height_map.shape == sin_map.shape + == cos_map.shape) + + comp_attribs = self.generate_comp_attribs_from_maps( + center_lines, center_region_mask, top_height_map, bot_height_map, + sin_map, cos_map, self.comp_shrink_ratio) + + if comp_attribs is not None: + comp_attribs = self.jitter_comp_attribs(comp_attribs, + self.jitter_level) + + if comp_attribs.shape[0] < self.min_comp_num: + rand_sample_num = self.min_comp_num - comp_attribs.shape[0] + rand_comp_attribs = self.generate_rand_comp_attribs( + rand_sample_num, 1 - text_mask) + comp_attribs = np.vstack([comp_attribs, rand_comp_attribs]) + else: + comp_attribs = self.generate_rand_comp_attribs( + self.min_comp_num, 1 - text_mask) + + comp_num = ( + np.ones((comp_attribs.shape[0], 1), dtype=np.float32) * + comp_attribs.shape[0]) + comp_attribs = np.hstack([comp_num, comp_attribs]) + + if comp_attribs.shape[0] > self.max_comp_num: + comp_attribs = comp_attribs[:self.max_comp_num, :] + comp_attribs[:, 0] = self.max_comp_num + + pad_comp_attribs = np.zeros((self.max_comp_num, comp_attribs.shape[1])) + pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs + + return pad_comp_attribs + + def generate_targets(self, results): + """Generate the gt targets for DRRG. + + Args: + results (dict): The input result dictionary. + + Returns: + results (dict): The output result dictionary. 
+ """ + + assert isinstance(results, dict) + + polygon_masks = results['gt_masks'].masks + polygon_masks_ignore = results['gt_masks_ignore'].masks + + h, w, _ = results['img_shape'] + + gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks) + gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore) + (center_lines, gt_center_region_mask, gt_top_height_map, + gt_bot_height_map, gt_sin_map, + gt_cos_map) = self.generate_center_mask_attrib_maps((h, w), + polygon_masks) + + gt_comp_attribs = self.generate_comp_attribs(center_lines, + gt_text_mask, + gt_center_region_mask, + gt_top_height_map, + gt_bot_height_map, + gt_sin_map, gt_cos_map) + + results['mask_fields'].clear() # rm gt_masks encoded by polygons + mapping = { + 'gt_text_mask': gt_text_mask, + 'gt_center_region_mask': gt_center_region_mask, + 'gt_mask': gt_mask, + 'gt_top_height_map': gt_top_height_map, + 'gt_bot_height_map': gt_bot_height_map, + 'gt_sin_map': gt_sin_map, + 'gt_cos_map': gt_cos_map + } + for key, value in mapping.items(): + value = value if isinstance(value, list) else [value] + results[key] = BitmapMasks(value, h, w) + results['mask_fields'].append(key) + + results['gt_comp_attribs'] = gt_comp_attribs + return results diff --git a/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py new file mode 100644 index 00000000..72042cfd --- /dev/null +++ b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py @@ -0,0 +1,453 @@ +import cv2 +import numpy as np +from numpy.linalg import norm + +import mmocr.utils.check_argument as check_argument +from mmdet.core import BitmapMasks +from mmdet.datasets.builder import PIPELINES +from . import BaseTextDetTargets + + +@PIPELINES.register_module() +class TextSnakeTargets(BaseTextDetTargets): + """Generate the ground truth targets of TextSnake: TextSnake: A Flexible + Representation for Detecting Text of Arbitrary Shapes. 
def vector_angle(self, vec1, vec2):
    """Compute the angle in radians between two vectors or vector batches.

    Args:
        vec1 (ndarray): A vector, or a batch of vectors stacked along the
            first axis.
        vec2 (ndarray): A vector, or a batch of vectors stacked along the
            first axis.

    Returns:
        angle (float | ndarray): The arccos of the dot product of the
            normalized inputs, taken along the last axis.
    """

    def to_unit(vec):
        # The epsilon guards against a zero-length vector.
        length = norm(vec, axis=-1) + 1e-8
        if vec.ndim > 1:
            # One length per row, shaped as a column for broadcasting.
            length = length.reshape((-1, 1))
        return vec / length

    cosine = (to_unit(vec1) * to_unit(vec2)).sum(axis=-1)
    # Clip so that rounding error never pushes arccos out of its domain.
    return np.arccos(np.clip(cosine, -1.0, 1.0))
+ """ + + assert points.ndim == 2 + assert points.shape[0] >= 4 + assert points.shape[1] == 2 + assert isinstance(orientation_thr, float) + + if len(points) > 4: + pad_points = np.vstack([points, points[0]]) + edge_vec = pad_points[1:] - pad_points[:-1] + + theta_sum = [] + + for i, edge_vec1 in enumerate(edge_vec): + adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]] + adjacent_edge_vec = edge_vec[adjacent_ind] + temp_theta_sum = np.sum( + self.vector_angle(edge_vec1, adjacent_edge_vec)) + theta_sum.append(temp_theta_sum) + theta_sum = np.array(theta_sum) + head_start, tail_start = np.argsort(theta_sum)[::-1][0:2] + + if abs(head_start - tail_start) < 2 \ + or abs(head_start - tail_start) > 12: + tail_start = (head_start + len(points) // 2) % len(points) + head_end = (head_start + 1) % len(points) + tail_end = (tail_start + 1) % len(points) + + if head_end > tail_end: + head_start, tail_start = tail_start, head_start + head_end, tail_end = tail_end, head_end + head_inds = [head_start, head_end] + tail_inds = [tail_start, tail_end] + else: + if self.vector_slope(points[1] - points[0]) + self.vector_slope( + points[3] - points[2]) < self.vector_slope( + points[2] - points[1]) + self.vector_slope(points[0] - + points[3]): + horizontal_edge_inds = [[0, 1], [2, 3]] + vertical_edge_inds = [[3, 0], [1, 2]] + else: + horizontal_edge_inds = [[3, 0], [1, 2]] + vertical_edge_inds = [[0, 1], [2, 3]] + + vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - + points[vertical_edge_inds[0][1]]) + norm( + points[vertical_edge_inds[1][0]] - + points[vertical_edge_inds[1][1]]) + horizontal_len_sum = norm( + points[horizontal_edge_inds[0][0]] - + points[horizontal_edge_inds[0][1]]) + norm( + points[horizontal_edge_inds[1][0]] - + points[horizontal_edge_inds[1][1]]) + + if vertical_len_sum > horizontal_len_sum * orientation_thr: + head_inds = horizontal_edge_inds[0] + tail_inds = horizontal_edge_inds[1] + else: + head_inds = vertical_edge_inds[0] + tail_inds = 
def reorder_poly_edge(self, points):
    """Split a text polygon into head edge, tail edge and two sidelines.

    Args:
        points (ndarray): The points composing a text polygon.

    Returns:
        head_edge (ndarray): The two points composing the head edge of text
            polygon.
        tail_edge (ndarray): The two points composing the tail edge of text
            polygon.
        top_sideline (ndarray): The points composing top curved sideline of
            text polygon.
        bot_sideline (ndarray): The points composing bottom curved sideline
            of text polygon.
    """

    assert points.ndim == 2
    assert points.shape[0] >= 4
    assert points.shape[1] == 2

    num_points = len(points)
    head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
    head_edge = points[head_inds]
    tail_edge = points[tail_inds]

    # Duplicate the vertex ring so that a sideline wrapping past the last
    # vertex can still be taken with a single slice.
    ring = np.concatenate([points, points], axis=0)
    if tail_inds[1] < 1:
        tail_inds[1] = num_points
    first_start, second_start = head_inds[1], tail_inds[1]
    sideline1 = ring[first_start:second_start]
    sideline2 = ring[second_start:first_start + num_points]

    # The sideline with the smaller mean y coordinate is the top one.
    mean_shift = sideline1.mean(axis=0) - sideline2.mean(axis=0)
    if mean_shift[1] > 0:
        return head_edge, tail_edge, sideline2, sideline1
    return head_edge, tail_edge, sideline1, sideline2
def resample_sidelines(self, sideline1, sideline2, resample_step):
    """Resample two sidelines to be of the same points number according to
    step size.

    Args:
        sideline1 (ndarray): The points composing a sideline of a text
            polygon, of shape (N, 2) with N >= 2.
        sideline2 (ndarray): The points composing another sideline of a
            text polygon, of shape (M, 2) with M >= 2.
        resample_step (float): The resampled step size.

    Returns:
        resampled_line1 (ndarray): The resampled line 1.
        resampled_line2 (ndarray): The resampled line 2.
    """

    # Validate BOTH sidelines; comparing sideline1 with itself would let a
    # malformed sideline2 slip through unnoticed.
    assert sideline1.ndim == sideline2.ndim == 2
    assert sideline1.shape[1] == sideline2.shape[1] == 2
    assert sideline1.shape[0] >= 2
    assert sideline2.shape[0] >= 2
    assert isinstance(resample_step, float)

    def polyline_length(line):
        # Sum of the lengths of consecutive segments.
        return sum(
            norm(end - start) for start, end in zip(line[:-1], line[1:]))

    # Both lines are resampled to the same number of points, derived from
    # their mean length and the requested step size.
    total_length = (polyline_length(sideline1) +
                    polyline_length(sideline2)) / 2
    resample_point_num = int(float(total_length) / resample_step)

    resampled_line1 = self.resample_line(sideline1, resample_point_num)
    resampled_line2 = self.resample_line(sideline2, resample_point_num)

    return resampled_line1, resampled_line2
+ """ + + assert top_line.shape == bot_line.shape == center_line.shape + assert (center_region_mask.shape == radius_map.shape == sin_map.shape + == cos_map.shape) + assert isinstance(region_shrink_ratio, float) + for i in range(0, len(center_line) - 1): + + top_mid_point = (top_line[i] + top_line[i + 1]) / 2 + bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2 + radius = norm(top_mid_point - bot_mid_point) / 2 + + text_direction = center_line[i + 1] - center_line[i] + sin_theta = self.vector_sin(text_direction) + cos_theta = self.vector_cos(text_direction) + + pnt_tl = center_line[i] + (top_line[i] - + center_line[i]) * region_shrink_ratio + pnt_tr = center_line[i + 1] + ( + top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio + pnt_br = center_line[i + 1] + ( + bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio + pnt_bl = center_line[i] + (bot_line[i] - + center_line[i]) * region_shrink_ratio + current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br, + pnt_bl]).astype(np.int32) + + cv2.fillPoly(center_region_mask, [current_center_box], color=1) + cv2.fillPoly(sin_map, [current_center_box], color=sin_theta) + cv2.fillPoly(cos_map, [current_center_box], color=cos_theta) + cv2.fillPoly(radius_map, [current_center_box], color=radius) + + def generate_center_mask_attrib_maps(self, img_size, text_polys): + """Generate text center region mask and geometric attribute maps. + + Args: + img_size (tuple): The image size of (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + center_region_mask (ndarray): The text center region mask. + radius_map (ndarray): The distance map from each pixel in text + center region to top sideline. + sin_map (ndarray): The sin(theta) map where theta is the angle + between vector (top point - bottom point) and vector (1, 0). + cos_map (ndarray): The cos(theta) map where theta is the angle + between vector (top point - bottom point) and vector (1, 0). 
+ """ + + assert isinstance(img_size, tuple) + assert check_argument.is_2dlist(text_polys) + + h, w = img_size + + center_region_mask = np.zeros((h, w), np.uint8) + radius_map = np.zeros((h, w), dtype=np.float32) + sin_map = np.zeros((h, w), dtype=np.float32) + cos_map = np.zeros((h, w), dtype=np.float32) + + for poly in text_polys: + assert len(poly) == 1 + text_instance = [[poly[0][i], poly[0][i + 1]] + for i in range(0, len(poly[0]), 2)] + polygon_points = np.array( + text_instance, dtype=np.int).reshape(-1, 2) + + _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points) + resampled_top_line, resampled_bot_line = self.resample_sidelines( + top_line, bot_line, self.resample_step) + resampled_bot_line = resampled_bot_line[::-1] + center_line = (resampled_top_line + resampled_bot_line) / 2 + + if self.vector_slope(center_line[-1] - center_line[0]) > 0.9: + if (center_line[-1] - center_line[0])[1] < 0: + center_line = center_line[::-1] + resampled_top_line = resampled_top_line[::-1] + resampled_bot_line = resampled_bot_line[::-1] + else: + if (center_line[-1] - center_line[0])[0] < 0: + center_line = center_line[::-1] + resampled_top_line = resampled_top_line[::-1] + resampled_bot_line = resampled_bot_line[::-1] + + line_head_shrink_len = norm(resampled_top_line[0] - + resampled_bot_line[0]) / 4.0 + line_tail_shrink_len = norm(resampled_top_line[-1] - + resampled_bot_line[-1]) / 4.0 + head_shrink_num = int(line_head_shrink_len // self.resample_step) + tail_shrink_num = int(line_tail_shrink_len // self.resample_step) + + if len(center_line) > head_shrink_num + tail_shrink_num + 2: + center_line = center_line[head_shrink_num:len(center_line) - + tail_shrink_num] + resampled_top_line = resampled_top_line[ + head_shrink_num:len(resampled_top_line) - tail_shrink_num] + resampled_bot_line = resampled_bot_line[ + head_shrink_num:len(resampled_bot_line) - tail_shrink_num] + + self.draw_center_region_maps(resampled_top_line, + resampled_bot_line, center_line, + 
def generate_text_region_mask(self, img_size, text_polys):
    """Generate the binary text region mask.

    Args:
        img_size (tuple): The image size (height, width).
        text_polys (list[list[ndarray]]): The list of text polygons. Each
            entry wraps exactly one flat coordinate sequence
            (x0, y0, x1, y1, ...).

    Returns:
        text_region_mask (ndarray): The uint8 text region mask of shape
            (height, width); pixels inside any polygon are set to 1.
    """

    assert isinstance(img_size, tuple)
    assert check_argument.is_2dlist(text_polys)

    h, w = img_size
    text_region_mask = np.zeros((h, w), dtype=np.uint8)

    for poly in text_polys:
        assert len(poly) == 1
        # The ``np.int`` alias no longer exists in current NumPy releases;
        # int32 is the integer point type cv2.fillPoly works with.
        # Reshaping the flat sequence to (1, N, 2) pairs up (x, y) exactly
        # like the former explicit pairing loop did.
        polygon = np.array(poly[0], dtype=np.int32).reshape((1, -1, 2))
        cv2.fillPoly(text_region_mask, polygon, 1)

    return text_region_mask
@HEADS.register_module()
class DRRGHead(HeadMixin, nn.Module):
    """The class for DRRG head: Deep Relational Reasoning Graph Network for
    Arbitrary Shape Text Detection.

    [https://arxiv.org/abs/2003.07493]

    Args:
        in_channels (int): The number of channels of the input feature map.
        k_at_hops (tuple(int)): The number of i-hop neighbors,
            i = 1, 2, ..., h.
        active_connection (int): The number of two hop neighbors deem as
            linked to a pivot.
        node_geo_feat_dim (int): The dimension of embedded geometric features
            of a component.
        pooling_scale (float): The spatial scale of RRoI-Aligning.
        pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
        graph_filter_thr (float): The threshold to filter identical local
            graphs.
        comp_shrink_ratio (float): The shrink ratio of text components.
        nms_thr (float): The locality-aware NMS threshold.
        min_width (float): The minimum width of text components.
        max_width (float): The maximum width of text components.
        comp_ratio (float): The reciprocal of aspect ratio of text components.
        text_region_thr (float): The threshold for text region probability map.
        center_region_thr (float): The threshold of text center region
            probability map.
        center_region_area_thr (int): The threshold of filtering small-size
            text center region.
        link_thr (float): The threshold for connected components searching.
        loss (dict): The config passed to ``build_loss``.
        train_cfg (dict | None): Stored as-is on the head.
        test_cfg (dict | None): Stored as-is on the head.
    """

    def __init__(self,
                 in_channels,
                 k_at_hops=(8, 4),
                 active_connection=3,
                 node_geo_feat_dim=120,
                 pooling_scale=1.0,
                 pooling_output_size=(3, 4),
                 graph_filter_thr=0.75,
                 comp_shrink_ratio=0.95,
                 nms_thr=0.25,
                 min_width=8.0,
                 max_width=24.0,
                 comp_ratio=0.65,
                 text_region_thr=0.6,
                 center_region_thr=0.4,
                 center_region_area_thr=100,
                 link_thr=0.85,
                 loss=dict(type='DRRGLoss'),
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()

        assert isinstance(in_channels, int)
        assert isinstance(k_at_hops, tuple)
        assert isinstance(active_connection, int)
        assert isinstance(node_geo_feat_dim, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(graph_filter_thr, float)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_ratio, float)
        assert isinstance(center_region_area_thr, int)
        assert isinstance(link_thr, float)

        self.in_channels = in_channels
        # One channel per predicted map; DRRGLoss.forward reads channels 0-5.
        self.out_channels = 6
        self.downsample_ratio = 1.0
        self.k_at_hops = k_at_hops
        self.active_connection = active_connection
        self.node_geo_feat_dim = node_geo_feat_dim
        self.pooling_scale = pooling_scale
        self.pooling_output_size = pooling_output_size
        self.graph_filter_thr = graph_filter_thr
        self.comp_shrink_ratio = comp_shrink_ratio
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_ratio = comp_ratio
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr
        self.link_thr = link_thr
        self.loss_module = build_loss(loss)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.out_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)
        self.init_weights()

        # Local graph builder used by forward() (training path).
        self.graph_train = LocalGraphs(self.k_at_hops, self.active_connection,
                                       self.node_geo_feat_dim,
                                       self.pooling_scale,
                                       self.pooling_output_size,
                                       self.graph_filter_thr)

        # Local graph builder used by single_test() (inference path).
        self.graph_test = ProposalLocalGraphs(
            self.k_at_hops, self.active_connection, self.node_geo_feat_dim,
            self.pooling_scale, self.pooling_output_size, self.nms_thr,
            self.min_width, self.max_width, self.comp_shrink_ratio,
            self.comp_ratio, self.text_region_thr, self.center_region_thr,
            self.center_region_area_thr)

        # GCN input = pooled (backbone + prediction) channels plus the
        # embedded geometric feature of each node.
        pool_w, pool_h = self.pooling_output_size
        gcn_in_dim = (pool_w * pool_h) * (
            self.in_channels + self.out_channels) + self.node_geo_feat_dim
        self.gcn = GCN(gcn_in_dim, 32)

    def init_weights(self):
        """Initialize the prediction conv with a zero-mean normal."""
        normal_init(self.out_conv, mean=0, std=0.01)

    def forward(self, inputs, text_comp_feats):
        """Training forward pass.

        Args:
            inputs (Tensor): The input feature map; it is concatenated with
                the predicted maps along dim 1.
            text_comp_feats: The text component attributes, converted with
                ``np.array`` and handed to the training graph builder.

        Returns:
            tuple: ``(pred_maps, (gcn_pred, gt_labels))``.
        """

        pred_maps = self.out_conv(inputs)

        feat_maps = torch.cat([inputs, pred_maps], dim=1)
        node_feats, adjacent_matrices, knn_inx, gt_labels = self.graph_train(
            feat_maps, np.array(text_comp_feats))

        gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inx)

        return (pred_maps, (gcn_pred, gt_labels))

    def single_test(self, feat_maps):
        """Inference forward pass.

        Args:
            feat_maps (Tensor): The input feature map.

        Returns:
            tuple: ``(edges, scores, text_comps)``, or
            ``(None, None, None)`` when the graph builder reports that no
            local graph could be built.
        """

        pred_maps = self.out_conv(feat_maps)
        # Concatenate the full batch tensor, exactly as forward() does;
        # indexing ``feat_maps[0]`` would drop the batch dim and break cat.
        feat_maps = torch.cat([feat_maps, pred_maps], dim=1)

        none_flag, graph_data = self.graph_test(pred_maps, feat_maps)

        # Bail out before touching graph_data: nothing in it is usable here.
        if none_flag:
            return None, None, None

        (node_feats, adjacent_matrix, pivot_inx, knn_inx, local_graph_nodes,
         text_comps) = graph_data

        device = feat_maps.device
        adjacent_matrix, pivot_inx, knn_inx = (
            item.to(device) for item in (adjacent_matrix, pivot_inx, knn_inx))
        # The GCN is registered as ``self.gcn`` in __init__.
        gcn_pred = self.gcn(node_feats, adjacent_matrix, knn_inx)

        pred_labels = F.softmax(gcn_pred, dim=1)

        edges = []
        scores = []
        local_graph_nodes = local_graph_nodes.long().squeeze().cpu().numpy()
        graph_num = node_feats.size(0)
        knn_num = knn_inx.shape[1]

        for graph_inx in range(graph_num):
            pivot = pivot_inx[graph_inx].int().item()
            nodes = local_graph_nodes[graph_inx]
            for neighbor_inx, neighbor in enumerate(knn_inx[graph_inx]):
                neighbor = neighbor.item()
                edges.append([nodes[pivot], nodes[neighbor]])
                # gcn_pred is flattened to (graph_num * knn_num) rows;
                # column 1 is read as the linkage score.
                scores.append(pred_labels[graph_inx * knn_num + neighbor_inx,
                                          1].item())

        edges = np.asarray(edges)
        scores = np.asarray(scores)

        return edges, scores, text_comps

    def get_boundary(self, edges, scores, text_comps):
        """Merge linked text components into text boundaries.

        Args:
            edges (ndarray | None): The node pairs returned by single_test.
            scores (ndarray | None): The linkage score of every edge.
            text_comps: The text components returned by single_test.

        Returns:
            dict: ``dict(boundary_result=boundaries)``; the list is empty
            when ``edges`` is None.
        """

        boundaries = []
        if edges is not None:
            boundaries = merge_text_comps(edges, scores, text_comps,
                                          self.link_thr)

        results = dict(boundary_result=boundaries)

        return results
@DETECTORS.register_module()
class DRRG(TextDetectorMixin, SingleStageTextDetector):
    """The class for implementing DRRG text detector: Deep Relational Reasoning
    Graph Network for Arbitrary Shape Text Detection.

    [https://arxiv.org/abs/2003.07493]
    """

    def forward_train(self, img, img_metas, **kwargs):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details of the values of these keys, see
                :class:`mmdet.datasets.pipelines.Collect`.
            **kwargs: Must contain 'gt_comp_attribs'; it is popped and fed to
                the head, the remaining entries are forwarded to the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img)
        gt_comp_attribs = kwargs.pop('gt_comp_attribs')
        preds = self.bbox_head(x, gt_comp_attribs)
        losses = self.bbox_head.loss(preds, **kwargs)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Run inference on a batch and return the detected boundaries.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): Image info dicts (unused here).
            rescale (bool): Accepted for interface compatibility.
                NOTE(review): the head's ``get_boundary`` takes no rescale
                argument, so boundaries are returned without rescaling.

        Returns:
            list[dict]: A one-element list holding the head's boundary result.
        """

        x = self.extract_feat(img)
        # The head's signatures are single_test(feat_maps) and
        # get_boundary(edges, scores, text_comps); passing the raw image or
        # img_metas/rescale to them raises TypeError.
        outs = self.bbox_head.single_test(x)
        boundaries = self.bbox_head.get_boundary(*outs)

        return [boundaries]
def balance_bce_loss(self, pred, gt, mask):
    """Binary cross entropy balanced by online hard example mining (OHEM).

    All positive pixels contribute; only the hardest negatives do, at most
    ``self.ohem_ratio`` times the number of positives (or up to 100 when
    there is no positive pixel at all).

    Args:
        pred (Tensor): The predicted probabilities in [0, 1].
        gt (Tensor): The binary ground truth, same shape as ``pred``.
        mask (Tensor): The binary effective mask, same shape as ``pred``.

    Returns:
        Tensor: The scalar balanced loss, on the same device as ``pred``.
    """

    assert pred.shape == gt.shape == mask.shape
    positive = (gt * mask).float()
    negative = ((1 - gt) * mask).float()
    positive_count = int(positive.sum())

    loss = F.binary_cross_entropy(pred, gt.float(), reduction='none')
    negative_loss = (loss * negative).reshape(-1)

    if positive_count > 0:
        positive_loss = torch.sum(loss * positive)
        negative_count = min(
            int(negative.sum()), int(positive_count * self.ohem_ratio))
    else:
        # Keep the zero on pred's device so every returned loss lives on
        # one device.
        positive_loss = torch.tensor(0.0, device=pred.device)
        negative_count = 100

    # torch.topk raises when k exceeds the number of elements, which the
    # fixed fallback of 100 does on very small maps.
    negative_count = min(negative_count, negative_loss.numel())
    negative_loss, _ = torch.topk(negative_loss, negative_count)

    balance_loss = (positive_loss + torch.sum(negative_loss)) / (
        float(positive_count + negative_count) + 1e-5)

    return balance_loss
def forward(self, preds, downsample_ratio, gt_text_mask,
            gt_center_region_mask, gt_mask, gt_top_height_map,
            gt_bot_height_map, gt_sin_map, gt_cos_map):
    """Compute the DRRG losses.

    Args:
        preds (tuple): ``(pred_maps, gcn_data)``. ``pred_maps`` is indexed
            on dim 1 for channels 0-5 (text region, center region, sin,
            cos, top height, bottom height); ``gcn_data`` is handed to
            ``self.gcn_loss``.
        downsample_ratio (float): Must be 1.0 (asserted).
        gt_text_mask (list[BitmapMasks]): The text region masks.
        gt_center_region_mask (list[BitmapMasks]): The center region masks.
        gt_mask (list[BitmapMasks]): The effective masks.
        gt_top_height_map (list[BitmapMasks]): The top height maps.
        gt_bot_height_map (list[BitmapMasks]): The bottom height maps.
        gt_sin_map (list[BitmapMasks]): The sin(theta) maps.
        gt_cos_map (list[BitmapMasks]): The cos(theta) maps.

    Returns:
        dict: The losses ``loss_text``, ``loss_center``, ``loss_height``,
        ``loss_sin``, ``loss_cos`` and ``loss_gcn``.
    """

    assert isinstance(preds, tuple)
    assert isinstance(downsample_ratio, float)
    assert abs(downsample_ratio - 1.0) < 1e-5
    assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_mask, BitmapMasks)
    assert check_argument.is_type_list(gt_top_height_map, BitmapMasks)
    assert check_argument.is_type_list(gt_bot_height_map, BitmapMasks)
    assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
    assert check_argument.is_type_list(gt_cos_map, BitmapMasks)

    pred_maps, gcn_data = preds
    pred_text_region = pred_maps[:, 0, :, :]
    pred_center_region = pred_maps[:, 1, :, :]
    pred_sin_map = pred_maps[:, 2, :, :]
    pred_cos_map = pred_maps[:, 3, :, :]
    pred_top_height_map = pred_maps[:, 4, :, :]
    pred_bot_height_map = pred_maps[:, 5, :, :]
    feature_sz = pred_maps.size()
    device = pred_maps.device

    # bitmask 2 tensor, padded to the prediction size and moved to its device
    mapping = {
        'gt_text_mask': gt_text_mask,
        'gt_center_region_mask': gt_center_region_mask,
        'gt_mask': gt_mask,
        'gt_top_height_map': gt_top_height_map,
        'gt_bot_height_map': gt_bot_height_map,
        'gt_sin_map': gt_sin_map,
        'gt_cos_map': gt_cos_map
    }
    gt = {}
    for key, value in mapping.items():
        tensors = self.bitmasks2tensor(value, feature_sz[2:])
        gt[key] = [item.to(device) for item in tensors]

    # Normalize (sin, cos) to unit length before comparing with the targets.
    scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
    pred_sin_map = pred_sin_map * scale
    pred_cos_map = pred_cos_map * scale

    loss_text = self.balance_bce_loss(
        torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
        gt['gt_mask'][0])

    text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
    negative_text_mask = ((1 - gt['gt_text_mask'][0]) *
                          gt['gt_mask'][0]).float()
    gt_center_region_mask = gt['gt_center_region_mask'][0].float()
    loss_center = F.binary_cross_entropy(
        torch.sigmoid(pred_center_region),
        gt_center_region_mask,
        reduction='none')
    if int(text_mask.sum()) > 0:
        loss_center_positive = torch.sum(
            loss_center * text_mask) / torch.sum(text_mask)
    else:
        loss_center_positive = torch.tensor(0.0, device=device)
    # Guard the division: an empty negative mask would yield 0 / 0 = NaN.
    if int(negative_text_mask.sum()) > 0:
        loss_center_negative = torch.sum(
            loss_center * negative_text_mask) / torch.sum(negative_text_mask)
    else:
        loss_center_negative = torch.tensor(0.0, device=device)
    loss_center = loss_center_positive + 0.5 * loss_center_negative

    center_mask = (gt['gt_center_region_mask'][0] *
                   gt['gt_mask'][0]).float()
    if int(center_mask.sum()) > 0:
        center_pixel_num = torch.sum(center_mask)
        gt_top = gt['gt_top_height_map'][0]
        gt_bot = gt['gt_bot_height_map'][0]
        ones = torch.ones_like(gt_top, dtype=torch.float)
        # Heights are regressed as a ratio to the target (target ratio 1).
        loss_top = F.smooth_l1_loss(
            pred_top_height_map / (gt_top + 1e-2), ones, reduction='none')
        loss_bot = F.smooth_l1_loss(
            pred_bot_height_map / (gt_bot + 1e-2), ones, reduction='none')
        gt_height = gt_top + gt_bot
        loss_height = torch.sum(
            (torch.log(gt_height + 1) *
             (loss_top + loss_bot)) * center_mask) / center_pixel_num

        loss_sin = torch.sum(
            F.smooth_l1_loss(
                pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
            center_mask) / center_pixel_num
        loss_cos = torch.sum(
            F.smooth_l1_loss(
                pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
            center_mask) / center_pixel_num
    else:
        # Zero losses are created on the prediction device, like every
        # other returned loss.
        loss_height = torch.tensor(0.0, device=device)
        loss_sin = torch.tensor(0.0, device=device)
        loss_cos = torch.tensor(0.0, device=device)

    loss_gcn = self.gcn_loss(gcn_data)

    results = dict(
        loss_text=loss_text,
        loss_center=loss_center,
        loss_height=loss_height,
        loss_sin=loss_sin,
        loss_cos=loss_cos,
        loss_gcn=loss_gcn)

    return results
def balanced_bce_loss(self, pred, gt, mask):
    """Binary cross entropy balanced by online hard example mining (OHEM).

    All positive pixels contribute; only the hardest negatives do, at most
    ``self.ohem_ratio`` times the number of positives (or up to 100 when
    there is no positive pixel at all).

    Args:
        pred (Tensor): The predicted probabilities in [0, 1].
        gt (Tensor): The binary ground truth, same shape as ``pred``.
        mask (Tensor): The binary effective mask, same shape as ``pred``.

    Returns:
        Tensor: The scalar balanced loss, on the same device as ``pred``.
    """

    assert pred.shape == gt.shape == mask.shape
    positive = (gt * mask).float()
    negative = ((1 - gt) * mask).float()
    positive_count = int(positive.sum())

    # The element-wise BCE is the same in both branches; compute it once.
    loss = F.binary_cross_entropy(pred, gt.float(), reduction='none')
    negative_loss = (loss * negative).reshape(-1)

    if positive_count > 0:
        positive_loss = torch.sum(loss * positive)
        negative_count = min(
            int(negative.sum()), int(positive_count * self.ohem_ratio))
    else:
        positive_loss = torch.tensor(0.0, device=pred.device)
        negative_count = 100

    # torch.topk raises when k exceeds the number of elements, which the
    # fixed fallback of 100 does on very small maps.
    negative_count = min(negative_count, negative_loss.numel())
    negative_loss, _ = torch.topk(negative_loss, negative_count)

    balance_loss = (positive_loss + torch.sum(negative_loss)) / (
        float(positive_count + negative_count) + 1e-5)

    return balance_loss
+ """ + assert check_argument.is_type_list(bitmasks, BitmapMasks) + assert isinstance(target_sz, tuple) + + batch_size = len(bitmasks) + num_masks = len(bitmasks[0]) + + results = [] + + for level_inx in range(num_masks): + kernel = [] + for batch_inx in range(batch_size): + mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx]) + # hxw + mask_sz = mask.shape + # left, right, top, bottom + pad = [ + 0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0] + ] + mask = F.pad(mask, pad, mode='constant', value=0) + kernel.append(mask) + kernel = torch.stack(kernel) + results.append(kernel) + + return results + + def forward(self, pred_maps, downsample_ratio, gt_text_mask, + gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map, + gt_cos_map): + + assert isinstance(downsample_ratio, float) + assert check_argument.is_type_list(gt_text_mask, BitmapMasks) + assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks) + assert check_argument.is_type_list(gt_mask, BitmapMasks) + assert check_argument.is_type_list(gt_radius_map, BitmapMasks) + assert check_argument.is_type_list(gt_sin_map, BitmapMasks) + assert check_argument.is_type_list(gt_cos_map, BitmapMasks) + + pred_text_region = pred_maps[:, 0, :, :] + pred_center_region = pred_maps[:, 1, :, :] + pred_sin_map = pred_maps[:, 2, :, :] + pred_cos_map = pred_maps[:, 3, :, :] + pred_radius_map = pred_maps[:, 4, :, :] + feature_sz = pred_maps.size() + device = pred_maps.device + + # bitmask 2 tensor + mapping = { + 'gt_text_mask': gt_text_mask, + 'gt_center_region_mask': gt_center_region_mask, + 'gt_mask': gt_mask, + 'gt_radius_map': gt_radius_map, + 'gt_sin_map': gt_sin_map, + 'gt_cos_map': gt_cos_map + } + gt = {} + for key, value in mapping.items(): + gt[key] = value + if abs(downsample_ratio - 1.0) < 1e-2: + gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:]) + else: + gt[key] = [item.rescale(downsample_ratio) for item in gt[key]] + gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:]) + if 
key == 'gt_radius_map': + gt[key] = [item * downsample_ratio for item in gt[key]] + gt[key] = [item.to(device) for item in gt[key]] + + scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8)) + pred_sin_map = pred_sin_map * scale + pred_cos_map = pred_cos_map * scale + + loss_text = self.balanced_bce_loss( + torch.sigmoid(pred_text_region), gt['gt_text_mask'][0], + gt['gt_mask'][0]) + + text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float() + loss_center_map = F.binary_cross_entropy( + torch.sigmoid(pred_center_region), + gt['gt_center_region_mask'][0].float(), + reduction='none') + if int(text_mask.sum()) > 0: + loss_center = torch.sum( + loss_center_map * text_mask) / torch.sum(text_mask) + else: + loss_center = torch.tensor(0.0, device=device) + + center_mask = (gt['gt_center_region_mask'][0] * + gt['gt_mask'][0]).float() + if int(center_mask.sum()) > 0: + map_sz = pred_radius_map.size() + ones = torch.ones(map_sz, dtype=torch.float, device=device) + loss_radius = torch.sum( + F.smooth_l1_loss( + pred_radius_map / (gt['gt_radius_map'][0] + 1e-2), + ones, + reduction='none') * center_mask) / torch.sum(center_mask) + loss_sin = torch.sum( + F.smooth_l1_loss( + pred_sin_map, gt['gt_sin_map'][0], reduction='none') * + center_mask) / torch.sum(center_mask) + loss_cos = torch.sum( + F.smooth_l1_loss( + pred_cos_map, gt['gt_cos_map'][0], reduction='none') * + center_mask) / torch.sum(center_mask) + else: + loss_radius = torch.tensor(0.0, device=device) + loss_sin = torch.tensor(0.0, device=device) + loss_cos = torch.tensor(0.0, device=device) + + results = dict( + loss_text=loss_text, + loss_center=loss_center, + loss_radius=loss_radius, + loss_sin=loss_sin, + loss_cos=loss_cos) + + return results diff --git a/mmocr/models/textdet/modules/__init__.py b/mmocr/models/textdet/modules/__init__.py new file mode 100644 index 00000000..9d62df75 --- /dev/null +++ b/mmocr/models/textdet/modules/__init__.py @@ -0,0 +1,6 @@ +from .gcn import GCN +from 
class MeanAggregator(nn.Module):
    """Aggregate node features with a batched product ``A @ features``."""

    def forward(self, features, A):
        return torch.bmm(A, features)


class GraphConv(nn.Module):
    """Graph convolution over the concatenation of each node's own features
    and its aggregated neighbor features.

    Args:
        in_dim (int): The input feature dimension.
        out_dim (int): The output feature dimension.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # The weight acts on [own features, aggregated features], hence the
        # doubled input dimension. Values are overwritten by the init calls.
        self.weight = nn.Parameter(torch.empty(in_dim * 2, out_dim))
        self.bias = nn.Parameter(torch.empty(out_dim))
        init.xavier_uniform_(self.weight)
        init.constant_(self.bias, 0)
        self.agg = MeanAggregator()

    def forward(self, features, A):
        b, n, d = features.shape
        assert d == self.in_dim
        agg_feats = self.agg(features, A)
        cat_feats = torch.cat([features, agg_feats], dim=2)
        out = torch.einsum('bnd,df->bnf', cat_feats, self.weight)
        out = F.relu(out + self.bias)
        return out


class GCN(nn.Module):
    """Predict linkage between instances. This was from repo
    https://github.com/Zhongdao/gcn_clustering: Linkage Based Face Clustering
    via Graph Convolution Network.

    [https://arxiv.org/abs/1903.11306]

    Args:
        in_dim(int): The input dimension.
        out_dim(int): The output dimension.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.bn0 = nn.BatchNorm1d(in_dim, affine=False).float()
        self.conv1 = GraphConv(in_dim, 512)
        self.conv2 = GraphConv(512, 256)
        self.conv3 = GraphConv(256, 128)
        self.conv4 = GraphConv(128, 64)

        self.classifier = nn.Sequential(
            nn.Linear(64, out_dim), nn.PReLU(out_dim), nn.Linear(out_dim, 2))

    def forward(self, x, A, one_hop_indexes, train=True):
        """Classify the edges between each pivot and its one-hop neighbors.

        Args:
            x (Tensor): The node features of shape (B, N, D).
            A (Tensor): The adjacency matrices of shape (B, N, N).
            one_hop_indexes (Tensor): The node indices, shape (B, k1), whose
                features are classified for every graph.
            train (bool): Unused; kept for interface compatibility.

        Returns:
            Tensor: The two-class logits of shape (B * k1, 2).
        """

        B, N, D = x.shape

        # BatchNorm1d expects (num_samples, D): flatten graphs, then restore.
        x = self.bn0(x.view(-1, D)).view(B, N, D)

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(x, A)

        # Gather the selected node features in one batched indexing op that
        # stays on x's device, instead of filling a CPU buffer per graph.
        dout = x.size(-1)
        node_inx = one_hop_indexes.to(device=x.device, dtype=torch.long)
        batch_inx = torch.arange(B, device=x.device).unsqueeze(1)
        edge_feat = x[batch_inx, node_inx].reshape(-1, dout)
        pred = self.classifier(edge_feat)

        # shape: (B*k1)x2
        return pred
def generate_local_graphs(self, sorted_complete_graph, gt_belong_labels):
    """Generate local graphs for GCN to predict which instance a text
    component belongs to.

    Args:
        sorted_complete_graph (ndarray): The complete graph where nodes are
            sorted according to their Euclidean distance. Column 0 of each
            row is skipped as the node itself.
        gt_belong_labels (ndarray): The ground truth labels define which
            instance text components (nodes in graphs) belong to.

    Returns:
        local_graph_node_list (list): The list of local graph neighbors of
            pivots. Each entry starts with the pivot, followed by its
            h-hop neighbors (the pivot itself is listed only once).
        knn_graph_neighbor_list (list): The list of k nearest neighbors of
            pivots.
    """

    assert sorted_complete_graph.ndim == 2
    assert (sorted_complete_graph.shape[0] ==
            sorted_complete_graph.shape[1] == gt_belong_labels.shape[0])

    knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
    local_graph_node_list = []
    knn_graph_neighbor_list = []

    for pivot_inx, knn_graph in enumerate(knn_graphs):
        one_hop_neighbors = set(knn_graph[1:])

        # Expand hop by hop: the neighbors of hop h are the nearest
        # k_at_hops[h] nodes of every node reached at hop h - 1.
        h_hop_neighbor_list = [one_hop_neighbors]
        for hop_inx in range(1, self.local_graph_depth):
            next_hop_neighbors = set()
            for last_hop_neighbor_inx in h_hop_neighbor_list[-1]:
                next_hop_neighbors.update(
                    sorted_complete_graph[last_hop_neighbor_inx]
                    [1:self.k_at_hops[hop_inx] + 1])
            h_hop_neighbor_list.append(next_hop_neighbors)

        # A fresh set, so one_hop_neighbors (stored below) is not mutated.
        hops_neighbor_set = set().union(*h_hop_neighbor_list)
        # The pivot is usually among its neighbors' neighbors; drop it so it
        # is listed exactly once, at position 0.
        hops_neighbor_set.discard(pivot_inx)
        hops_neighbor_list = [pivot_inx, *hops_neighbor_set]

        # Skip this pivot when an already accepted local graph is nearly
        # identical: high one-hop IoU, the pivot is a one-hop neighbor of
        # that graph, and both pivots share a non-zero instance label.
        is_duplicate = False
        for knn_graph_neighbors, local_graph_nodes in zip(
                knn_graph_neighbor_list, local_graph_node_list):
            accepted_pivot = local_graph_nodes[0]
            node_union_num = len(knn_graph_neighbors | one_hop_neighbors)
            node_intersect_num = len(knn_graph_neighbors & one_hop_neighbors)
            one_hop_iou = node_intersect_num / (node_union_num + 1e-8)

            if (one_hop_iou > self.local_graph_filter_thr
                    and pivot_inx in knn_graph_neighbors
                    and gt_belong_labels[accepted_pivot]
                    == gt_belong_labels[pivot_inx]
                    and gt_belong_labels[accepted_pivot] != 0):
                is_duplicate = True
                break

        if not is_duplicate:
            local_graph_node_list.append(hops_neighbor_list)
            knn_graph_neighbor_list.append(one_hop_neighbors)

    return local_graph_node_list, knn_graph_neighbor_list
local_graph_node_batch, knn_graph_neighbor_batch, + sorted_complete_graph): + """Generate graph convolution network input data. + + Args: + node_feat_batch (List[Tensor]): The node feature batch. + belong_label_batch (List[ndarray]): The text component belong label + batch. + local_graph_node_batch (List[List[list]]): The local graph + neighbors batch. + knn_graph_neighbor_batch (List[List[set]]): The knn graph neighbor + batch. + sorted_complete_graph (List[ndarray]): The complete graph sorted + according to the Euclidean distance. + + Returns: + node_feat_batch_tensor (Tensor): The node features of Graph + Convolutional Network (GCN). + adjacent_mat_batch_tensor (Tensor): The adjacent matrices. + knn_inx_batch_tensor (Tensor): The indices of k nearest neighbors. + gt_linkage_batch_tensor (Tensor): The surpervision signal of GCN + for linkage prediction. + """ + + assert isinstance(node_feat_batch, list) + assert isinstance(belong_label_batch, list) + assert isinstance(local_graph_node_batch, list) + assert isinstance(knn_graph_neighbor_batch, list) + assert isinstance(sorted_complete_graph, list) + + max_graph_node_num = max([ + len(local_graph_nodes) + for local_graph_node_list in local_graph_node_batch + for local_graph_nodes in local_graph_node_list + ]) + + node_feat_batch_list = list() + adjacent_matrix_batch_list = list() + knn_inx_batch_list = list() + gt_linkage_batch_list = list() + + for batch_inx, sorted_neighbors in enumerate(sorted_complete_graph): + node_feats = node_feat_batch[batch_inx] + local_graph_list = local_graph_node_batch[batch_inx] + knn_graph_neighbor_list = knn_graph_neighbor_batch[batch_inx] + belong_labels = belong_label_batch[batch_inx] + + for graph_inx in range(len(local_graph_list)): + local_graph_nodes = local_graph_list[graph_inx] + local_graph_node_num = len(local_graph_nodes) + pivot_inx = local_graph_nodes[0] + knn_graph_neighbors = knn_graph_neighbor_list[graph_inx] + node_to_graph_inx = { + j: i + for i, j in 
enumerate(local_graph_nodes) + } + + knn_inx_in_local_graph = torch.tensor( + [node_to_graph_inx[i] for i in knn_graph_neighbors], + dtype=torch.long) + pivot_feats = node_feats[torch.tensor( + pivot_inx, dtype=torch.long)] + normalized_feats = node_feats[torch.tensor( + local_graph_nodes, dtype=torch.long)] - pivot_feats + + adjacent_matrix = np.zeros( + (local_graph_node_num, local_graph_node_num)) + pad_normalized_feats = torch.cat([ + normalized_feats, + torch.zeros(max_graph_node_num - local_graph_node_num, + normalized_feats.shape[1]).to( + node_feats.device) + ], + dim=0) + + for node in local_graph_nodes: + neighbors = sorted_neighbors[node, + 1:self.active_connection + 1] + for neighbor in neighbors: + if neighbor in local_graph_nodes: + adjacent_matrix[node_to_graph_inx[node], + node_to_graph_inx[neighbor]] = 1 + adjacent_matrix[node_to_graph_inx[neighbor], + node_to_graph_inx[node]] = 1 + + adjacent_matrix = normalize_adjacent_matrix( + adjacent_matrix, type='DAD') + adjacent_matrix_tensor = torch.zeros(max_graph_node_num, + max_graph_node_num).to( + node_feats.device) + adjacent_matrix_tensor[:local_graph_node_num, : + local_graph_node_num] = adjacent_matrix + + local_graph_labels = torch.from_numpy( + belong_labels[local_graph_nodes]).type(torch.long) + knn_labels = local_graph_labels[knn_inx_in_local_graph] + edge_labels = ((belong_labels[pivot_inx] == knn_labels) + & (belong_labels[pivot_inx] > 0)).long() + + node_feat_batch_list.append(pad_normalized_feats) + adjacent_matrix_batch_list.append(adjacent_matrix_tensor) + knn_inx_batch_list.append(knn_inx_in_local_graph) + gt_linkage_batch_list.append(edge_labels) + + node_feat_batch_tensor = torch.stack(node_feat_batch_list, 0) + adjacent_mat_batch_tensor = torch.stack(adjacent_matrix_batch_list, 0) + knn_inx_batch_tensor = torch.stack(knn_inx_batch_list, 0) + gt_linkage_batch_tensor = torch.stack(gt_linkage_batch_list, 0) + + return (node_feat_batch_tensor, adjacent_mat_batch_tensor, + 
knn_inx_batch_tensor, gt_linkage_batch_tensor) + + def __call__(self, feat_maps, comp_attribs): + """Generate local graphs. + + Args: + feat_maps (Tensor): The feature maps to propose node features in + graph. + comp_attribs (ndarray): The text components attributes. + + Returns: + node_feats_batch (Tensor): The node features of Graph Convolutional + Network(GCN). + adjacent_matrices_batch (Tensor): The adjacent matrices. + knn_inx_batch (Tensor): The indices of k nearest neighbors. + gt_linkage_batch (Tensor): The surpervision signal of GCN for + linkage prediction. + """ + + assert isinstance(feat_maps, torch.Tensor) + assert comp_attribs.shape[2] == 8 + + dist_sort_graph_batch_list = [] + local_graph_node_batch_list = [] + knn_graph_neighbor_batch_list = [] + node_feature_batch_list = [] + belong_label_batch_list = [] + + for batch_inx in range(comp_attribs.shape[0]): + comp_num = int(comp_attribs[batch_inx, 0, 0]) + comp_geo_attribs = comp_attribs[batch_inx, :comp_num, 1:7] + node_belong_labels = comp_attribs[batch_inx, :comp_num, + 7].astype(np.int32) + + comp_centers = comp_geo_attribs[:, 0:2] + distance_matrix = euclidean_distance_matrix( + comp_centers, comp_centers) + + graph_node_geo_feats = embed_geo_feats(comp_geo_attribs, + self.node_geo_feat_dim) + graph_node_geo_feats = torch.from_numpy( + graph_node_geo_feats).float().to(feat_maps.device) + + batch_id = np.zeros( + (comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_inx + text_comps = np.hstack( + (batch_id, comp_geo_attribs.astype(np.float32))) + text_comps = torch.from_numpy(text_comps).float().to( + feat_maps.device) + + comp_content_feats = self.pooling( + feat_maps[batch_inx].unsqueeze(0), text_comps) + comp_content_feats = comp_content_feats.view( + comp_content_feats.shape[0], -1).to(feat_maps.device) + node_feats = torch.cat((comp_content_feats, graph_node_geo_feats), + dim=-1) + + dist_sort_complete_graph = np.argsort(distance_matrix, axis=1) + (local_graph_nodes, + 
class ProposalLocalGraphs:
    """Propose text components and generate local graphs. This was partially
    adapted from https://github.com/GXYM/DRRG: Deep Relational Reasoning Graph
    Network for Arbitrary Shape Text Detection.

    [https://arxiv.org/abs/2003.07493]

    Args:
        k_at_hops (tuple(int)): The number of i-hop neighbors,
            i = 1, 2, ..., h.
        active_connection (int): The number of two hop neighbors deem as linked
            to a pivot.
        node_geo_feat_dim (int): The dimension of embedded geometric features
            of a component.
        pooling_scale (float): The spatial scale of RRoI-Aligning.
        pooling_output_size (tuple(int)): The size of RRoI-Aligning output.
        nms_thr (float): The locality-aware NMS threshold.
        min_width (float): The minimum width of text components.
        max_width (float): The maximum width of text components.
        comp_shrink_ratio (float): The shrink ratio of text components.
        comp_ratio (float): The reciprocal of aspect ratio of text components.
        text_region_thr (float): The threshold for text region probability map.
        center_region_thr (float): The threshold for text center region
            probability map.
        center_region_area_thr (int): The threshold for filtering out
            small-size text center region.
    """

    def __init__(self, k_at_hops, active_connection, node_geo_feat_dim,
                 pooling_scale, pooling_output_size, nms_thr, min_width,
                 max_width, comp_shrink_ratio, comp_ratio, text_region_thr,
                 center_region_thr, center_region_area_thr):

        assert isinstance(k_at_hops, tuple)
        assert isinstance(active_connection, int)
        assert isinstance(node_geo_feat_dim, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(comp_ratio, float)
        assert isinstance(text_region_thr, float)
        assert isinstance(center_region_thr, float)
        assert isinstance(center_region_area_thr, int)

        self.k_at_hops = k_at_hops
        self.active_connection = active_connection
        self.local_graph_depth = len(self.k_at_hops)
        self.node_geo_feat_dim = node_geo_feat_dim
        self.pooling = RROIAlign(pooling_output_size, pooling_scale)
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_shrink_ratio = comp_shrink_ratio
        self.comp_ratio = comp_ratio
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr

    def fill_hole(self, input_mask):
        """Fill the holes enclosed by the foreground of a binary mask.

        Args:
            input_mask (ndarray): A 2-D binary mask.

        Returns:
            ndarray: ``input_mask`` (as uint8) OR-ed with every pixel that
                a flood fill started outside the mask could not reach.
        """
        h, w = input_mask.shape
        # Pad by one pixel so the flood fill can travel around the border.
        canvas = np.zeros((h + 2, w + 2), np.uint8)
        canvas[1:h + 1, 1:w + 1] = input_mask

        # OpenCV requires the flood-fill mask to be 2 pixels larger than
        # the image in each dimension.
        mask = np.zeros((h + 4, w + 4), np.uint8)

        cv2.floodFill(canvas, mask, (0, 0), 1)
        # ``bool`` replaces the removed ``np.bool`` alias.
        canvas = canvas[1:h + 1, 1:w + 1].astype(bool)

        return (~canvas | input_mask.astype(np.uint8))

    def propose_comps(self, top_radius_map, bot_radius_map, sin_map, cos_map,
                      score_map, min_width, max_width, comp_shrink_ratio,
                      comp_ratio):
        """Generate text components.

        Args:
            top_radius_map (ndarray): The predicted distance map from each
                pixel in text center region to top sideline.
            bot_radius_map (ndarray): The predicted distance map from each
                pixel in text center region to bottom sideline.
            sin_map (ndarray): The predicted sin(theta) map.
            cos_map (ndarray): The predicted cos(theta) map.
            score_map (ndarray): The score map for NMS.
            min_width (float): The minimum width of text components.
            max_width (float): The maximum width of text components.
            comp_shrink_ratio (float): The shrink ratio of text components.
            comp_ratio (float): The reciprocal of aspect ratio of text
                components.

        Returns:
            text_comps (ndarray): The text components, one row per positive
                pixel of ``score_map``: four corner points (8 values)
                followed by the score.
        """
        # One candidate component per positive pixel, as (row, col) pairs.
        comp_centers = np.argwhere(score_map > 0)
        comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
        y = comp_centers[:, 0]
        x = comp_centers[:, 1]

        def column(value_map):
            # Values of ``value_map`` at the centers, as a column vector.
            return value_map[y, x].reshape((-1, 1))

        top_radius = column(top_radius_map) * comp_shrink_ratio
        bot_radius = column(bot_radius_map) * comp_shrink_ratio
        sin = column(sin_map)
        cos = column(cos_map)

        # NOTE(review): the centers and all derived points stay in
        # (row, col) order, while propose_comps_and_attribs reads column 0
        # of the box centers as x -- confirm the intended convention.
        top_mid_pnt = comp_centers + np.hstack(
            [top_radius * sin, top_radius * cos])
        bot_mid_pnt = comp_centers - np.hstack(
            [bot_radius * sin, bot_radius * cos])

        width = (top_radius + bot_radius) * comp_ratio
        width = np.clip(width, min_width, max_width)
        side_offset = np.hstack([width * cos, -width * sin])[:, ::-1]

        top_left = top_mid_pnt - side_offset
        top_right = top_mid_pnt + side_offset
        bot_right = bot_mid_pnt + side_offset
        bot_left = bot_mid_pnt - side_offset

        return np.hstack(
            [top_left, top_right, bot_right, bot_left,
             column(score_map)])

    def propose_comps_and_attribs(self, text_region_map, center_region_map,
                                  top_radius_map, bot_radius_map, sin_map,
                                  cos_map):
        """Generate text components and attributes.

        Args:
            text_region_map (ndarray): The predicted text region probability
                map.
            center_region_map (ndarray): The predicted text center region
                probability map.
            top_radius_map (ndarray): The predicted distance map from each
                pixel in text center region to top sideline.
            bot_radius_map (ndarray): The predicted distance map from each
                pixel in text center region to bottom sideline.
            sin_map (ndarray): The predicted sin(theta) map.
            cos_map (ndarray): The predicted cos(theta) map.

        Returns:
            comp_attribs (ndarray): The text components attributes, or None
                when no component survives the filtering.
            text_comps (ndarray): The text components, or None when no
                component survives the filtering.
        """
        # All six maps must share one shape.  (The comparison previously
        # used ``bot_radius_map`` itself instead of its shape.)
        assert (text_region_map.shape == center_region_map.shape ==
                top_radius_map.shape == bot_radius_map.shape ==
                sin_map.shape == cos_map.shape)
        text_mask = text_region_map > self.text_region_thr
        center_region_mask = (center_region_map >
                              self.center_region_thr) * text_mask

        # Re-normalise (sin, cos) to unit length; the epsilon avoids a
        # division by zero where both predictions are exactly 0.
        scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
        sin_map, cos_map = sin_map * scale, cos_map * scale

        center_region_mask = self.fill_hole(center_region_mask)
        center_region_contours, _ = cv2.findContours(
            center_region_mask.astype(np.uint8), cv2.RETR_TREE,
            cv2.CHAIN_APPROX_SIMPLE)

        mask = np.zeros_like(center_region_mask)
        comp_list = []
        for contour in center_region_contours:
            current_center_mask = mask.copy()
            cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
            # Skip center regions that are too small.
            if current_center_mask.sum() <= self.center_region_area_thr:
                continue
            score_map = text_region_map * current_center_mask

            text_comp = self.propose_comps(top_radius_map, bot_radius_map,
                                           sin_map, cos_map, score_map,
                                           self.min_width, self.max_width,
                                           self.comp_shrink_ratio,
                                           self.comp_ratio)

            # text_comp = la_nms(text_comp.astype('float32'), self.nms_thr)

            text_comp_mask = mask.copy()
            text_comps_bboxes = text_comp[:, :8].reshape(
                (-1, 4, 2)).astype(np.int32)

            cv2.drawContours(text_comp_mask, text_comps_bboxes, -1, 1, -1)
            # Skip regions whose components lie mostly outside the text mask.
            if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
                continue

            comp_list.append(text_comp)

        if not comp_list:
            return None, None

        text_comps = np.vstack(comp_list)

        centers = np.mean(
            text_comps[:, :8].reshape((-1, 4, 2)), axis=1).astype(np.int32)

        x = centers[:, 0]
        y = centers[:, 1]

        h = top_radius_map[y, x].reshape(
            (-1, 1)) + bot_radius_map[y, x].reshape((-1, 1))
        w = np.clip(h * self.comp_ratio, self.min_width, self.max_width)
        sin = sin_map[y, x].reshape((-1, 1))
        cos = cos_map[y, x].reshape((-1, 1))
        x = x.reshape((-1, 1))
        y = y.reshape((-1, 1))
        comp_attribs = np.hstack([x, y, h, w, cos, sin])

        return comp_attribs, text_comps

    def generate_local_graphs(self, sorted_complete_graph, node_feats):
        """Generate local graphs and Graph Convolution Network input data.

        Args:
            sorted_complete_graph (ndarray): The complete graph where nodes are
                sorted according to their Euclidean distance.
            node_feats (tensor): The graph nodes features.

        Returns:
            node_feats_tensor (tensor): The graph nodes features.
            adjacent_matrix_tensor (tensor): The adjacent matrix of graph.
            pivot_inx_tensor (tensor): The pivot indices in local graph.
            knn_inx_tensor (tensor): The k nearest neighbor nodes indexes in
                local graph.
            local_graph_node_tensor (tensor): The indices of nodes in local
                graph.
        """
        assert sorted_complete_graph.ndim == 2
        assert (sorted_complete_graph.shape[0] ==
                sorted_complete_graph.shape[1] == node_feats.shape[0])

        knn_graphs = sorted_complete_graph[:, :self.k_at_hops[0] + 1]
        local_graph_node_list = []
        knn_graph_neighbor_list = []
        for pivot_inx, knn_graph in enumerate(knn_graphs):

            # Column 0 of a row is skipped: it is taken to be the node itself.
            one_hop_neighbors = set(knn_graph[1:])
            h_hop_neighbor_list = [one_hop_neighbors]

            for hop_inx in range(1, self.local_graph_depth):
                hop_size = self.k_at_hops[hop_inx]
                next_hop = set()
                for last_hop_neighbor_inx in h_hop_neighbor_list[-1]:
                    next_hop.update(
                        set(sorted_complete_graph[last_hop_neighbor_inx]
                            [1:hop_size + 1]))
                h_hop_neighbor_list.append(next_hop)

            hops_neighbor_set = set(
                [node for hop in h_hop_neighbor_list for node in hop])
            hops_neighbor_list = list(hops_neighbor_set)
            # The pivot always occupies position 0 of its local graph.
            hops_neighbor_list.insert(0, pivot_inx)

            local_graph_node_list.append(hops_neighbor_list)
            knn_graph_neighbor_list.append(one_hop_neighbors)

        # Every local graph is zero-padded to the size of the largest one.
        max_graph_node_num = max(
            len(local_graph_nodes)
            for local_graph_nodes in local_graph_node_list)

        node_normalized_feats = []
        adjacent_matrix_list = []
        knn_inx = []
        pivot_graph_inx = []
        local_graph_tensor_list = []

        for local_graph_nodes, knn_graph_neighbors in zip(
                local_graph_node_list, knn_graph_neighbor_list):

            local_graph_node_num = len(local_graph_nodes)
            pad_num = max_graph_node_num - local_graph_node_num
            pivot_inx = local_graph_nodes[0]
            # Global node index -> position inside this local graph.
            node_to_graph_inx = {j: i for i, j in enumerate(local_graph_nodes)}

            pivot_node_inx = torch.tensor([
                node_to_graph_inx[pivot_inx],
            ]).type(torch.long)
            knn_inx_in_local_graph = torch.tensor(
                [node_to_graph_inx[i] for i in knn_graph_neighbors],
                dtype=torch.long)
            # Features are expressed relative to the pivot node.
            pivot_feats = node_feats[torch.tensor(pivot_inx, dtype=torch.long)]
            normalized_feats = node_feats[torch.tensor(
                local_graph_nodes, dtype=torch.long)] - pivot_feats
            pad_normalized_feats = torch.cat([
                normalized_feats,
                torch.zeros(pad_num,
                            normalized_feats.shape[1]).to(node_feats.device)
            ],
                                             dim=0)

            adjacent_matrix = np.zeros(
                (local_graph_node_num, local_graph_node_num))
            for node in local_graph_nodes:
                neighbors = sorted_complete_graph[node,
                                                  1:self.active_connection + 1]
                for neighbor in neighbors:
                    # Dict lookup instead of scanning the node list.
                    if neighbor in node_to_graph_inx:
                        row = node_to_graph_inx[node]
                        col = node_to_graph_inx[neighbor]
                        adjacent_matrix[row, col] = 1
                        adjacent_matrix[col, row] = 1

            adjacent_matrix = normalize_adjacent_matrix(
                adjacent_matrix, type='DAD')
            adjacent_matrix_tensor = torch.zeros(
                max_graph_node_num, max_graph_node_num).to(node_feats.device)
            adjacent_matrix_tensor[:local_graph_node_num, :
                                   local_graph_node_num] = adjacent_matrix

            local_graph_tensor = torch.tensor(local_graph_nodes)
            local_graph_tensor = torch.cat(
                [local_graph_tensor,
                 torch.zeros(pad_num, dtype=torch.long)],
                dim=0)

            node_normalized_feats.append(pad_normalized_feats)
            adjacent_matrix_list.append(adjacent_matrix_tensor)
            pivot_graph_inx.append(pivot_node_inx)
            knn_inx.append(knn_inx_in_local_graph)
            local_graph_tensor_list.append(local_graph_tensor)

        node_feats_tensor = torch.stack(node_normalized_feats, 0)
        adjacent_matrix_tensor = torch.stack(adjacent_matrix_list, 0)
        pivot_inx_tensor = torch.stack(pivot_graph_inx, 0)
        knn_inx_tensor = torch.stack(knn_inx, 0)
        local_graph_node_tensor = torch.stack(local_graph_tensor_list, 0)

        return (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
                knn_inx_tensor, local_graph_node_tensor)

    def __call__(self, preds, feat_maps):
        """Generate local graphs and Graph Convolution Network input data.

        Args:
            preds (tensor): The predicted maps.
            feat_maps (tensor): The feature maps to extract content features of
                text components.

        Returns:
            none_flag (bool): True when no text component was proposed; the
                second return value is then a tuple of six zeros.
            node_feats_tensor (tensor): The graph nodes features.
            adjacent_matrix_tensor (tensor): The adjacent matrix of graph.
            pivot_inx_tensor (tensor): The pivot indices in local graph.
            knn_inx_tensor (tensor): The k nearest neighbor nodes indices in
                local graph.
            local_graph_node_tensor (tensor): The indices of nodes in local
                graph.
            text_comps (ndarray): The predicted text components.
        """
        # Only the first sample of ``preds`` is read.  Channel order:
        # text region, center region, sin, cos, top radius, bottom radius.
        pred_text_region = torch.sigmoid(preds[0, 0]).data.cpu().numpy()
        pred_center_region = torch.sigmoid(preds[0, 1]).data.cpu().numpy()
        pred_sin_map = preds[0, 2].data.cpu().numpy()
        pred_cos_map = preds[0, 3].data.cpu().numpy()
        pred_top_radius_map = preds[0, 4].data.cpu().numpy()
        pred_bot_radius_map = preds[0, 5].data.cpu().numpy()

        comp_attribs, text_comps = self.propose_comps_and_attribs(
            pred_text_region, pred_center_region, pred_top_radius_map,
            pred_bot_radius_map, pred_sin_map, pred_cos_map)

        if comp_attribs is None:
            none_flag = True
            return none_flag, (0, 0, 0, 0, 0, 0)

        comp_centers = comp_attribs[:, 0:2]
        distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)

        graph_node_geo_feats = embed_geo_feats(comp_attribs,
                                               self.node_geo_feat_dim)
        graph_node_geo_feats = torch.from_numpy(
            graph_node_geo_feats).float().to(preds.device)

        # A single image is processed, so the batch-id column is all zeros.
        batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
        text_comps_bboxes = np.hstack(
            (batch_id, comp_attribs.astype(np.float32, copy=False)))
        text_comps_bboxes = torch.from_numpy(text_comps_bboxes).float().to(
            preds.device)

        comp_content_feats = self.pooling(feat_maps, text_comps_bboxes)
        comp_content_feats = comp_content_feats.view(
            comp_content_feats.shape[0], -1).to(preds.device)
        node_feats = torch.cat((comp_content_feats, graph_node_geo_feats),
                               dim=-1)

        dist_sort_complete_graph = np.argsort(distance_matrix, axis=1)
        (node_feats_tensor, adjacent_matrix_tensor, pivot_inx_tensor,
         knn_inx_tensor, local_graph_node_tensor) = self.generate_local_graphs(
             dist_sort_complete_graph, node_feats)

        none_flag = False
        return none_flag, (node_feats_tensor, adjacent_matrix_tensor,
                           pivot_inx_tensor, knn_inx_tensor,
                           local_graph_node_tensor, text_comps)
def normalize_adjacent_matrix(A, type='AD'):
    """Normalize adjacent matrix for GCN.

    This was from repo https://github.com/GXYM/DRRG.

    Args:
        A (ndarray): A square adjacency matrix; the identity is added to it
            before normalising.
        type (str): 'DAD' for the symmetric normalisation
            D^-1/2 (A + I) D^-1/2, 'AD' for the row normalisation
            D^-1 (A + I); any other value yields D - (A + I). The name
            shadows the builtin but is kept because callers pass it by
            keyword.

    Returns:
        Tensor: The normalized matrix.
    """
    A = A + np.eye(A.shape[0])  # A = A + I
    if type == 'DAD':
        d = np.sum(A, axis=0)
        d_inv = np.power(d, -0.5).flatten()
        d_inv[np.isinf(d_inv)] = 0.0
        d_inv = np.diag(d_inv)
        G = A.dot(d_inv).transpose().dot(d_inv)
        return torch.from_numpy(G)

    A = torch.from_numpy(A)
    if type == 'AD':
        return A.div(A.sum(1, keepdim=True))

    # D - A with D the diagonal degree matrix.  (``np.diag`` on the (N, 1)
    # row-sum tensor used to return a single element here.)
    return torch.diag(A.sum(1)) - A


def euclidean_distance_matrix(A, B):
    """Calculate the Euclidean distance matrix.

    Args:
        A (ndarray): Points of shape (M, D).
        B (ndarray): Points of shape (N, D).

    Returns:
        ndarray: The (M, N) matrix of pairwise distances.
    """
    M = A.shape[0]
    N = B.shape[0]

    assert A.shape[1] == B.shape[1]

    # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
    A_dots = (A * A).sum(axis=1).reshape((M, 1)) * np.ones(shape=(1, N))
    B_dots = (B * B).sum(axis=1) * np.ones(shape=(M, 1))
    D_squared = A_dots + B_dots - 2 * A.dot(B.T)

    # Round-off can push tiny distances below zero; clamp before the sqrt.
    D_squared[np.less(D_squared, 0.0)] = 0.0
    return np.sqrt(D_squared)


def embed_geo_feats(geo_feats, out_dim):
    """Embed geometric features of text components. This was partially adapted
    from https://github.com/GXYM/DRRG.

    Args:
        geo_feats (ndarray): The geometric features of text components.
        out_dim (int): The output dimension.

    Returns:
        embedded_feats (ndarray): The embedded geometric features.
    """
    assert isinstance(out_dim, int)
    assert out_dim >= geo_feats.shape[1]
    comp_num = geo_feats.shape[0]
    feat_dim = geo_feats.shape[1]
    feat_repeat_times = out_dim // feat_dim
    residue_dim = out_dim % feat_dim

    repeat_feats = np.repeat(
        np.expand_dims(geo_feats, axis=0), feat_repeat_times, axis=0)

    if residue_dim > 0:
        wave_num = feat_repeat_times + 1
        embed_wave = np.array([
            np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
            for j in range(wave_num)
        ]).reshape((wave_num, 1, 1))
        # One extra, zero-padded copy supplies the ``residue_dim`` columns
        # that do not fill a whole repetition.
        residue_feats = np.hstack([
            geo_feats[:, 0:residue_dim],
            np.zeros((comp_num, feat_dim - residue_dim))
        ])
        # ``repeat_feats`` is 3-D and ``residue_feats`` 2-D, so the extra
        # copy is appended along axis 0 (``np.stack`` of the two raised).
        repeat_feats = np.concatenate(
            [repeat_feats,
             np.expand_dims(residue_feats, axis=0)], axis=0)
    else:
        embed_wave = np.array([
            np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
            for j in range(feat_repeat_times)
        ]).reshape((feat_repeat_times, 1, 1))

    embedded_feats = repeat_feats / embed_wave
    # NOTE(review): these slices select along axis 1, the component axis of
    # the (repeat, component, feature) array -- confirm that is intended.
    embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
    embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
    embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
        (comp_num, -1))[:, 0:out_dim]

    return embedded_feats


def min_connect_path(list_all: List[list]):
    """Greedily chain 2-D points into a path by nearest-neighbor growth.

    This is from https://github.com/GXYM/DRRG.

    The path starts at the first point and is extended at its head or its
    tail, whichever has the closer unvisited point (the head wins ties).
    Points are tracked by index, so points with identical coordinates are
    all kept.

    Args:
        list_all (List[list]): The points, each indexable as ``[x, y]``.

    Returns:
        res (List[List[int]]): The path edges as pairs of point indices.
        path (List[int]): The point indices in path order. An empty input
            gives ``([], [])`` and a single point ``([], [0])``.
    """
    if not list_all:
        return [], []

    def norm2(a, b):
        return ((a[0] - b[0])**2 + (a[1] - b[1])**2)**0.5

    def nearest(anchor_inx):
        # Ties go to the later candidate, as in the reference implementation
        # where later entries overwrote a distance-keyed dict.
        dist, neg_pos = min((norm2(list_all[inx], list_all[anchor_inx]), -pos)
                            for pos, inx in enumerate(remaining))
        return dist, remaining[-neg_pos]

    remaining = list(range(1, len(list_all)))
    head = tail = 0
    res: List[List[int]] = []
    while remaining:
        head_dist, head_pick = nearest(head)
        tail_dist, tail_pick = nearest(tail)
        if head_dist <= tail_dist:
            res.insert(0, [head_pick, head])
            remaining.remove(head_pick)
            head = head_pick
        else:
            res.append([tail, tail_pick])
            remaining.remove(tail_pick)
            tail = tail_pick

    # Flatten the edges and drop repeats, keeping first occurrences.
    path = list(dict.fromkeys(functools.reduce(operator.concat, res, [])))

    return res, path or [0]


def clusters2labels(clusters, node_num):
    """Turn clusters of nodes into a per-node label array.

    This is from https://github.com/GXYM/DRRG.

    Args:
        clusters (iterable): Clusters of objects exposing ``inx``.
        node_num (int): The number of nodes; every node must be clustered.

    Returns:
        ndarray: The cluster index of every node.
    """
    labels = (-1) * np.ones((node_num, ))
    for cluster_inx, cluster in enumerate(clusters):
        for node in cluster:
            labels[node.inx] = cluster_inx
    assert np.sum(labels < 0) < 1
    return labels


def remove_single(text_comps, pred):
    """Remove isolated single text components.

    This is from https://github.com/GXYM/DRRG.

    Args:
        text_comps (ndarray): The text components, one per row.
        pred (ndarray): The 1-D cluster label of every component.

    Returns:
        ndarray: The components whose label occurs more than once.
        ndarray: The labels of those components.
    """
    pred = np.asarray(pred)
    _, inverse, counts = np.unique(
        pred, return_inverse=True, return_counts=True)
    # A boolean mask also works when nothing remains; an empty float index
    # array, as built before, raised an IndexError.
    keep = counts[inverse.reshape(pred.shape)] > 1
    return text_comps[keep, :], pred[keep]


class Node:
    """A graph vertex that knows its index and its linked vertices."""

    def __init__(self, inx):
        self.__inx = inx
        self.__links = set()

    @property
    def inx(self):
        return self.__inx

    @property
    def links(self):
        # A copy, so callers can modify the returned set freely.
        return set(self.__links)

    def add_link(self, other, score):
        """Link ``self`` and ``other`` both ways; ``score`` is not stored."""
        self.__links.add(other)
        other.__links.add(self)


def connected_components(nodes, score_dict, thr):
    """Connected components searching.

    This is from https://github.com/GXYM/DRRG.

    Args:
        nodes (iterable[Node]): The graph vertices.
        score_dict (dict): Edge scores keyed by the sorted index pair.
        thr (float | None): Edges scoring below ``thr`` are ignored; with
            None every link is followed.

    Returns:
        list[set[Node]]: The connected components.
    """
    from collections import deque

    result = []
    nodes = set(nodes)
    while nodes:
        node = nodes.pop()
        group = {node}
        queue = deque([node])
        while queue:
            node = queue.popleft()
            if thr is not None:
                neighbors = {
                    linked_neighbor
                    for linked_neighbor in node.links if score_dict[tuple(
                        sorted([node.inx, linked_neighbor.inx]))] >= thr
                }
            else:
                neighbors = node.links
            neighbors.difference_update(group)
            nodes.difference_update(neighbors)
            group.update(neighbors)
            queue.extend(neighbors)
        result.append(group)
    return result


def graph_propagation(edges,
                      scores,
                      link_thr,
                      bboxes=None,
                      dis_thr=50,
                      pool='avg'):
    """Propagate graph linkage score information.

    This is from repo https://github.com/GXYM/DRRG.

    Args:
        edges (ndarray): The (E, 2) node index pairs.
        scores (sequence): The score of every edge; it is not modified.
        link_thr (float): The score threshold for keeping a link.
        bboxes (ndarray | None): With ``pool='avg'``, components whose box
            centers are farther apart than ``dis_thr`` get score 0.
        dis_thr (float): The center distance threshold.
        pool (str | None): How repeated edges are merged: None (last one
            wins), 'avg' or 'max'.

    Returns:
        list[set[Node]]: The clusters of linked nodes.
    """
    edges = np.sort(edges, axis=1)
    # Work on a copy: the distance filter below overwrites scores.
    scores = np.array(scores)

    score_dict = {}
    if pool is None:
        for i, edge in enumerate(edges):
            score_dict[edge[0], edge[1]] = scores[i]
    elif pool == 'avg':
        for i, edge in enumerate(edges):
            key = (edge[0], edge[1])
            if bboxes is not None:
                box1 = bboxes[edge[0]][:8].reshape(4, 2)
                box2 = bboxes[edge[1]][:8].reshape(4, 2)
                center1 = np.mean(box1, axis=0)
                center2 = np.mean(box2, axis=0)
                if norm(center1 - center2) > dis_thr:
                    scores[i] = 0
            if key in score_dict:
                score_dict[key] = 0.5 * (score_dict[key] + scores[i])
            else:
                score_dict[key] = scores[i]
    elif pool == 'max':
        for i, edge in enumerate(edges):
            key = (edge[0], edge[1])
            if key in score_dict:
                score_dict[key] = max(score_dict[key], scores[i])
            else:
                score_dict[key] = scores[i]
    else:
        raise ValueError('Pooling operation not supported')

    nodes = np.sort(np.unique(edges.flatten()))
    # ``int`` replaces the removed ``np.int`` alias.
    mapping = -1 * np.ones((nodes.max() + 1), dtype=int)
    mapping[nodes] = np.arange(nodes.shape[0])
    link_inx = mapping[edges]
    vertex = [Node(node) for node in nodes]
    for link, score in zip(link_inx, scores):
        vertex[link[0]].add_link(vertex[link[1]], score)

    return connected_components(vertex, score_dict, link_thr)


def in_contour(cont, point):
    """Return True when ``point`` lies strictly inside contour ``cont``."""
    x, y = point
    # Plain Python floats rather than NumPy scalars are handed to OpenCV.
    return cv2.pointPolygonTest(cont, (float(x), float(y)), False) > 0


def select_edge(cont, box):
    """Pick the short edge of ``box`` whose midpoint is not inside ``cont``.

    This is from repo https://github.com/GXYM/DRRG.

    Returns:
        list | None: The two end points of edge (0, 3) or of edge (1, 2),
            or None when both midpoints lie inside the contour.
    """
    cont = np.array(cont)
    box = box.astype(np.int32)
    # ``int`` replaces the removed ``np.int`` alias.
    c1 = np.array(0.5 * (box[0, :] + box[3, :]), dtype=int)
    c2 = np.array(0.5 * (box[1, :] + box[2, :]), dtype=int)

    if not in_contour(cont, c1):
        return [box[0, :].tolist(), box[3, :].tolist()]
    elif not in_contour(cont, c2):
        return [box[1, :].tolist(), box[2, :].tolist()]
    else:
        return None


def comps2boundary(text_comps, final_pred):
    """Build one boundary polygon per cluster of text components.

    This is from repo https://github.com/GXYM/DRRG.

    Args:
        text_comps (ndarray): The text components; the first 8 columns of a
            row are the four corner points.
        final_pred (ndarray): The cluster label of every component.

    Returns:
        list[list[int]]: The flattened boundary points of every cluster.
    """
    bbox_contours = []
    final_pred = np.asarray(final_pred)
    if final_pred.size == 0:
        return bbox_contours

    for inx in range(0, int(np.max(final_pred)) + 1):
        boxes = text_comps[final_pred == inx, :8].reshape(
            (-1, 4, 2)).astype(np.int32)

        boundary_point = None
        if boxes.shape[0] > 1:
            # Order the boxes along the text line before tracing it.
            centers = np.mean(boxes, axis=1).astype(np.int32).tolist()
            paths, routes_path = min_connect_path(centers)
            boxes = boxes[routes_path]
            top = np.mean(boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
            bot = np.mean(boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
            edge1 = select_edge(top + bot[::-1], boxes[0])
            edge2 = select_edge(top + bot[::-1], boxes[-1])
            if edge1 is not None:
                top.insert(0, edge1[0])
                bot.insert(0, edge1[1])
            if edge2 is not None:
                top.append(edge2[0])
                bot.append(edge2[1])
            boundary_point = np.array(top + bot[::-1])

        elif boxes.shape[0] == 1:
            top = boxes[0, 0:2, :].astype(np.int32).tolist()
            # The former slice ``2:4:-1`` was empty and dropped two corners.
            bot = boxes[0, 2:4, :].astype(np.int32).tolist()
            boundary_point = np.array(top + bot)

        if boundary_point is None:
            continue

        bbox_contours.append(boundary_point.flatten().tolist())

    return bbox_contours


def merge_text_comps(edges, scores, text_comps, link_thr):
    """Merge text components into text instance."""
    clusters = graph_propagation(edges, scores, link_thr)
    pred_labels = clusters2labels(clusters, text_comps.shape[0])
    text_comps, final_pred = remove_single(text_comps, pred_labels)
    boundaries = comps2boundary(text_comps, final_pred)

    return boundaries
class UpBlock(nn.Module):
    """Upsample block for DRRG and TextSnake.

    Applies a 1x1 and a 3x3 convolution, each followed by ReLU, then a
    stride-2 transposed convolution.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        assert isinstance(in_channels, int)
        assert isinstance(out_channels, int)

        conv_kwargs = dict(stride=1)
        self.conv1x1 = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, padding=0, **conv_kwargs)
        self.conv3x3 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1, **conv_kwargs)
        self.deconv = nn.ConvTranspose2d(
            out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        for conv in (self.conv1x1, self.conv3x3):
            x = F.relu(conv(x))
        return self.deconv(x)
@NECKS.register_module()
class FPN_UNET(nn.Module):
    """The class for implementing DRRG and TextSnake U-Net-like FPN.

    DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape
    Text Detection [https://arxiv.org/abs/2003.07493].
    TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes
    [https://arxiv.org/abs/1807.01544].

    Args:
        in_channels (sequence[int]): The channel numbers of the four input
            feature maps, in the order they are passed to ``forward``.
        out_channels (int): The number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        assert len(in_channels) == 4
        assert isinstance(out_channels, int)

        # Output widths of blocks 0..4; blocks 1..4 double per level and
        # are capped at 256.
        blocks_out_channels = [out_channels]
        blocks_out_channels += [min(out_channels * 2**i, 256) for i in range(4)]
        # Blocks 1..3 receive the upsampled features concatenated with one
        # input map; block 0 follows block 1 and block 4 takes the last
        # input map alone.
        blocks_in_channels = [blocks_out_channels[1]]
        blocks_in_channels += [
            in_channels[i] + blocks_out_channels[i + 2] for i in range(3)
        ]
        blocks_in_channels.append(in_channels[3])

        self.up4 = nn.ConvTranspose2d(
            blocks_in_channels[4],
            blocks_out_channels[4],
            kernel_size=4,
            stride=2,
            padding=1)
        self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3])
        self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2])
        self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1])
        self.up_block0 = UpBlock(blocks_in_channels[0], blocks_out_channels[0])
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise every (transposed) convolution layer."""
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for m in self.modules():
            if isinstance(m, conv_types):
                xavier_init(m, distribution='uniform')

    def forward(self, x):
        """Fuse the four feature maps into one upsampled map.

        Args:
            x (sequence[Tensor]): The four feature maps ``c2, c3, c4, c5``.

        Returns:
            Tensor: The fused feature map.
        """
        c2, c3, c4, c5 = x

        x = F.relu(self.up4(c5))
        # Concatenate the next input map, then upsample once more.
        for up_block, skip in ((self.up_block3, c4), (self.up_block2, c3),
                               (self.up_block1, c2)):
            x = F.relu(up_block(torch.cat([x, skip], dim=1)))

        # the output should be of the same height and width as backbone input
        return self.up_block0(x)