add fcenet
parent 5876f3f475, commit 9f62b610de
Changed paths: configs/det, ppocr, metrics, modeling, postprocess
@ -0,0 +1,141 @@
Global:
  use_gpu: true
  epoch_num: 1500
  log_smooth_window: 20
  print_batch_step: 20
  save_model_dir: ./output/fce_r50_ctw/
  save_epoch_step: 100
  # evaluation is run every 835 iterations
  eval_batch_step: [0, 835]
  cal_metric_during_train: False
  pretrained_model: ../pretrain_models/ResNet50_vd_ssld_pretrained
  checkpoints: #output/fce_r50_ctw/latest
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/fce_r50_ctw/predicts_ctw.txt


Architecture:
  model_type: det
  algorithm: FCE
  Transform:
  Backbone:
    name: ResNet
    layers: 50
    dcn_stage: [False, True, True, True]
    out_indices: [1,2,3]
  Neck:
    name: FCEFPN
    in_channels: [512, 1024, 2048]
    out_channels: 256
    has_extra_convs: False
    extra_stage: 0
  Head:
    name: FCEHead
    in_channels: 256
    scales: [8, 16, 32]
    fourier_degree: 5

Loss:
  name: FCELoss
  fourier_degree: 5
  num_sample: 50

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.0001
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: FCEPostProcess
  scales: [8, 16, 32]
  alpha: 1.0
  beta: 1.0
  fourier_degree: 5

Metric:
  name: DetFCEMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: /data/Dataset/OCR_det/ctw1500/imgs/
    label_file_list:
      - /data/Dataset/OCR_det/ctw1500/imgs/training.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
          ignore_orientation: True
      - DetLabelEncode: # Class handling label
      - ColorJitter:
          brightness: 0.142
          saturation: 0.5
          contrast: 0.5
      - RandomScaling:
      - RandomCropFlip:
          crop_ratio: 0.5
      - RandomCropPolyInstances:
          crop_ratio: 0.8
          min_side_ratio: 0.3
      - RandomRotatePolyInstances:
          rotate_ratio: 0.5
          max_angle: 30
          pad_with_fixed_color: False
      - SquareResizePad:
          target_size: 800
          pad_ratio: 0.6
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
      - FCENetTargets:
          fourier_degree: 5
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'p3_maps', 'p4_maps', 'p5_maps'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 6
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: /data/Dataset/OCR_det/ctw1500/imgs/
    label_file_list:
      - /data/Dataset/OCR_det/ctw1500/imgs/test.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
          ignore_orientation: True
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          # resize_long: 1280
          rescale_img: [1080, 736]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - Pad:
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
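A small sanity check, illustrative only and not part of the diff: fourier_degree appears in Head, Loss, PostProcess and the FCENetTargets transform above and must match everywhere, because FCELoss later asserts that every per-level target map has 4*k + 5 channels (3 mask channels plus two (2k+1)-channel Fourier maps).

# Illustrative consistency check (not part of the diff).
fourier_degree = 5
expected_channels = 3 + 2 * (2 * fourier_degree + 1)   # text region + center + effective mask + real/imag maps
assert expected_channels == 4 * fourier_degree + 5      # 25 for k = 5
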
@ -36,6 +36,9 @@ from .gen_table_mask import *
 
 from .vqa import *
 
+from .fce_aug import *
+from .fce_targets import FCENetTargets
+
 
 def transform(data, ops=None):
     """ transform """
@ -0,0 +1,633 @@
import numpy as np
from PIL import Image, ImageDraw
import paddle.vision.transforms as paddle_trans
import cv2
import Polygon as plg
import math

# The pillow interpolation table is referenced below but was missing from the
# original snippet; it is added here from the standard PIL constants.
pillow_interp_codes = {
    'nearest': Image.NEAREST,
    'bilinear': Image.BILINEAR,
    'bicubic': Image.BICUBIC,
    'lanczos': Image.LANCZOS,
}


def imresize(img,
             size,
             return_scale=False,
             interpolation='bilinear',
             out=None,
             backend=None):
    """Resize image to a given size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size (w, h).
        return_scale (bool): Whether to return `w_scale` and `h_scale`.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for the 'cv2'
            backend, and "nearest", "bilinear" for the 'pillow' backend.
        out (ndarray): The output destination.
        backend (str | None): The image resize backend type. Options are `cv2`,
            `pillow`, `None`. If backend is None, the global imread_backend
            specified by ``mmcv.use_backend()`` will be used. Default: None.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
            `resized_img`.
    """
    cv2_interp_codes = {
        'nearest': cv2.INTER_NEAREST,
        'bilinear': cv2.INTER_LINEAR,
        'bicubic': cv2.INTER_CUBIC,
        'area': cv2.INTER_AREA,
        'lanczos': cv2.INTER_LANCZOS4
    }
    h, w = img.shape[:2]
    if backend is None:
        backend = 'cv2'
    if backend not in ['cv2', 'pillow']:
        raise ValueError(f'backend: {backend} is not supported for resize.'
                         f"Supported backends are 'cv2', 'pillow'")

    if backend == 'pillow':
        assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
        pil_image = Image.fromarray(img)
        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
        resized_img = np.array(pil_image)
    else:
        resized_img = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
    if not return_scale:
        return resized_img
    else:
        w_scale = size[0] / w
        h_scale = size[1] / h
        return resized_img, w_scale, h_scale
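Illustrative usage of imresize (not part of the diff), assuming the function above is in scope: `size` is (w, h) while numpy shapes are (h, w), which is why the transforms below pass `out_size[::-1]`.

# Illustrative only.
import numpy as np

dummy = np.zeros((480, 640, 3), dtype=np.uint8)            # h=480, w=640
resized, w_scale, h_scale = imresize(dummy, (320, 240), return_scale=True)
print(resized.shape, w_scale, h_scale)                      # (240, 320, 3) 0.5 0.5
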

class RandomScaling:
    def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs):
        """Random scale the image while keeping aspect.

        Args:
            size (int) : Base size before scaling.
            scale (tuple(float)) : The range of scaling.
        """
        assert isinstance(size, int)
        assert isinstance(scale, float) or isinstance(scale, tuple)
        self.size = size
        self.scale = scale if isinstance(scale, tuple) \
            else (1 - scale, 1 + scale)

    def __call__(self, data):
        image = data['image']
        text_polys = data['polys']
        h, w, _ = image.shape

        aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
        scales = self.size * 1.0 / max(h, w) * aspect_ratio
        scales = np.array([scales, scales])
        out_size = (int(h * scales[1]), int(w * scales[0]))
        image = imresize(image, out_size[::-1])

        data['image'] = image
        text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
        text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
        data['polys'] = text_polys

        return data


def poly_intersection(poly_det, poly_gt):
    """Calculate the intersection area between two polygons.

    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.

    Returns:
        intersection_area (float): The intersection area between two polygons.
    """
    assert isinstance(poly_det, plg.Polygon)
    assert isinstance(poly_gt, plg.Polygon)

    poly_inter = poly_det & poly_gt
    if len(poly_inter) == 0:
        return 0, poly_inter
    return poly_inter.area(), poly_inter
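A quick check of poly_intersection (illustrative, not part of the diff), using only the Polygon3 calls already relied on above (plg.Polygon, the `&` operator, .area()):

# Illustrative only.
import numpy as np
import Polygon as plg

a = plg.Polygon(np.array([[0, 0], [4, 0], [4, 4], [0, 4]]))   # 4x4 square
b = plg.Polygon(np.array([[2, 2], [6, 2], [6, 6], [2, 6]]))   # overlapping square
area, inter = poly_intersection(a, b)
print(area)   # expected 4.0: the 2x2 overlap
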

class RandomCropFlip:
    def __init__(self,
                 pad_ratio=0.1,
                 crop_ratio=0.5,
                 iter_num=1,
                 min_area_ratio=0.2,
                 **kwargs):
        """Random crop and flip a patch of the image.

        Args:
            crop_ratio (float): The ratio of cropping.
            iter_num (int): Number of operations.
            min_area_ratio (float): Minimal area ratio between cropped patch
                and original image.
        """
        assert isinstance(crop_ratio, float)
        assert isinstance(iter_num, int)
        assert isinstance(min_area_ratio, float)

        self.pad_ratio = pad_ratio
        self.epsilon = 1e-2
        self.crop_ratio = crop_ratio
        self.iter_num = iter_num
        self.min_area_ratio = min_area_ratio

    def __call__(self, results):
        for i in range(self.iter_num):
            results = self.random_crop_flip(results)

        return results

    def random_crop_flip(self, results):
        image = results['image']
        polygons = results['polys']
        ignore_tags = results['ignore_tags']
        if len(polygons) == 0:
            return results

        if np.random.random() >= self.crop_ratio:
            return results

        h, w, _ = image.shape
        area = h * w
        pad_h = int(h * self.pad_ratio)
        pad_w = int(w * self.pad_ratio)
        h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h,
                                                   pad_w)
        if len(h_axis) == 0 or len(w_axis) == 0:
            return results

        attempt = 0
        while attempt < 50:
            attempt += 1
            polys_keep = []
            polys_new = []
            ignore_tags_keep = []
            ignore_tags_new = []
            xx = np.random.choice(w_axis, size=2)
            xmin = np.min(xx) - pad_w
            xmax = np.max(xx) - pad_w
            xmin = np.clip(xmin, 0, w - 1)
            xmax = np.clip(xmax, 0, w - 1)
            yy = np.random.choice(h_axis, size=2)
            ymin = np.min(yy) - pad_h
            ymax = np.max(yy) - pad_h
            ymin = np.clip(ymin, 0, h - 1)
            ymax = np.clip(ymax, 0, h - 1)
            if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
                # area too small
                continue

            pts = np.stack([[xmin, xmax, xmax, xmin],
                            [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
            pp = plg.Polygon(pts)
            fail_flag = False
            for polygon, ignore_tag in zip(polygons, ignore_tags):
                ppi = plg.Polygon(polygon.reshape(-1, 2))
                ppiou, _ = poly_intersection(ppi, pp)
                if np.abs(ppiou - float(ppi.area())) > self.epsilon and \
                        np.abs(ppiou) > self.epsilon:
                    fail_flag = True
                    break
                elif np.abs(ppiou - float(ppi.area())) < self.epsilon:
                    polys_new.append(polygon)
                    ignore_tags_new.append(ignore_tag)
                else:
                    polys_keep.append(polygon)
                    ignore_tags_keep.append(ignore_tag)

            if fail_flag:
                continue
            else:
                break

        cropped = image[ymin:ymax, xmin:xmax, :]
        select_type = np.random.randint(3)
        if select_type == 0:
            img = np.ascontiguousarray(cropped[:, ::-1])
        elif select_type == 1:
            img = np.ascontiguousarray(cropped[::-1, :])
        else:
            img = np.ascontiguousarray(cropped[::-1, ::-1])
        image[ymin:ymax, xmin:xmax, :] = img
        results['image'] = image  # keep the key consistent ('image', not 'img')

        if len(polys_new) != 0:
            height, width, _ = cropped.shape
            if select_type == 0:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
                    polys_new[idx] = poly
            elif select_type == 1:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
                    polys_new[idx] = poly
            else:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
                    polys_new[idx] = poly
            polygons = polys_keep + polys_new
            ignore_tags = ignore_tags_keep + ignore_tags_new
            results['polys'] = np.array(polygons)
            results['ignore_tags'] = ignore_tags

        return results

    def generate_crop_target(self, image, all_polys, pad_h, pad_w):
        """Generate crop target and make sure not to crop the polygon
        instances.

        Args:
            image (ndarray): The image to be cropped.
            all_polys (list[list[ndarray]]): All polygons including ground
                truth polygons and ground truth ignored polygons.
            pad_h (int): Padding length of height.
            pad_w (int): Padding length of width.
        Returns:
            h_axis (ndarray): Vertical cropping range.
            w_axis (ndarray): Horizontal cropping range.
        """
        h, w, _ = image.shape
        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)

        text_polys = []
        for polygon in all_polys:
            rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
            box = cv2.boxPoints(rect)
            box = np.int0(box)
            text_polys.append([box[0], box[1], box[2], box[3]])

        polys = np.array(text_polys, dtype=np.int32)
        for poly in polys:
            poly = np.round(poly, decimals=0).astype(np.int32)
            minx = np.min(poly[:, 0])
            maxx = np.max(poly[:, 0])
            w_array[minx + pad_w:maxx + pad_w] = 1
            miny = np.min(poly[:, 1])
            maxy = np.max(poly[:, 1])
            h_array[miny + pad_h:maxy + pad_h] = 1

        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        return h_axis, w_axis
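A small numeric sketch (not part of the diff) of the generate_crop_target idea: every column or row covered by a text box is marked 1 in a padded array, and only the remaining zero positions are valid places to cut.

# Illustrative only.
import numpy as np

w, pad_w = 10, 2
w_array = np.zeros(w + 2 * pad_w, dtype=np.int32)
w_array[3 + pad_w:6 + pad_w] = 1          # a box spanning columns 3..5
w_axis = np.where(w_array == 0)[0]        # candidate x cut positions
print(w_axis)                              # positions over columns 3..5 (shifted by pad_w) are excluded
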

class RandomCropPolyInstances:
    """Randomly crop images and make sure to contain at least one intact
    instance."""

    def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
        super().__init__()
        self.crop_ratio = crop_ratio
        self.min_side_ratio = min_side_ratio

    def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):

        assert isinstance(min_len, int)
        assert len(valid_array) > min_len

        start_array = valid_array.copy()
        max_start = min(len(start_array) - min_len, max_start)
        start_array[max_start:] = 0
        start_array[0] = 1
        diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
        region_starts = np.where(diff_array < 0)[0]
        region_ends = np.where(diff_array > 0)[0]
        region_ind = np.random.randint(0, len(region_starts))
        start = np.random.randint(region_starts[region_ind],
                                  region_ends[region_ind])

        end_array = valid_array.copy()
        min_end = max(start + min_len, min_end)
        end_array[:min_end] = 0
        end_array[-1] = 1
        diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
        region_starts = np.where(diff_array < 0)[0]
        region_ends = np.where(diff_array > 0)[0]
        region_ind = np.random.randint(0, len(region_starts))
        end = np.random.randint(region_starts[region_ind],
                                region_ends[region_ind])
        return start, end

    def sample_crop_box(self, img_size, results):
        """Generate crop box and make sure not to crop the polygon instances.

        Args:
            img_size (tuple(int)): The image size (h, w).
            results (dict): The results dict.
        """

        assert isinstance(img_size, tuple)
        h, w = img_size[:2]

        key_masks = results['polys']

        x_valid_array = np.ones(w, dtype=np.int32)
        y_valid_array = np.ones(h, dtype=np.int32)

        selected_mask = key_masks[np.random.randint(0, len(key_masks))]
        selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
        max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
        min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
        max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
        min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)

        for mask in key_masks:
            mask = mask.reshape((-1, 2)).astype(np.int32)
            clip_x = np.clip(mask[:, 0], 0, w - 1)
            clip_y = np.clip(mask[:, 1], 0, h - 1)
            min_x, max_x = np.min(clip_x), np.max(clip_x)
            min_y, max_y = np.min(clip_y), np.max(clip_y)

            x_valid_array[min_x - 2:max_x + 3] = 0
            y_valid_array[min_y - 2:max_y + 3] = 0

        min_w = int(w * self.min_side_ratio)
        min_h = int(h * self.min_side_ratio)

        x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
                                             min_x_end)
        y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
                                             min_y_end)

        return np.array([x1, y1, x2, y2])

    def crop_img(self, img, bbox):
        assert img.ndim == 3
        h, w, _ = img.shape
        assert 0 <= bbox[1] < bbox[3] <= h
        assert 0 <= bbox[0] < bbox[2] <= w
        return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]

    def __call__(self, results):
        image = results['image']
        polygons = results['polys']
        ignore_tags = results['ignore_tags']
        if len(polygons) < 1:
            return results

        if np.random.random_sample() < self.crop_ratio:

            crop_box = self.sample_crop_box(image.shape, results)
            img = self.crop_img(image, crop_box)
            results['image'] = img
            # crop and filter masks
            x1, y1, x2, y2 = crop_box
            w = max(x2 - x1, 1)
            h = max(y2 - y1, 1)
            polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
            polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1

            valid_masks_list = []
            valid_tags_list = []
            for ind, polygon in enumerate(polygons):
                if (polygon[:, ::2] > -4).all() and (
                        polygon[:, ::2] < w + 4).all() and (
                            polygon[:, 1::2] > -4).all() and (
                                polygon[:, 1::2] < h + 4).all():
                    polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
                    polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
                    valid_masks_list.append(polygon)
                    valid_tags_list.append(ignore_tags[ind])

            results['polys'] = np.array(valid_masks_list)
            results['ignore_tags'] = valid_tags_list

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str


class RandomRotatePolyInstances:
    def __init__(self,
                 rotate_ratio=0.5,
                 max_angle=10,
                 pad_with_fixed_color=False,
                 pad_value=(0, 0, 0),
                 **kwargs):
        """Randomly rotate images and polygon masks.

        Args:
            rotate_ratio (float): The ratio of samples to operate rotation.
            max_angle (int): The maximum rotation angle.
            pad_with_fixed_color (bool): Whether to pad the rotated image with
                a fixed color value. If set to False, the rotated image is
                padded with a rescaled crop of the image itself.
            pad_value (tuple(int)): The color value for padding rotated image.
        """
        self.rotate_ratio = rotate_ratio
        self.max_angle = max_angle
        self.pad_with_fixed_color = pad_with_fixed_color
        self.pad_value = pad_value

    def rotate(self, center, points, theta, center_shift=(0, 0)):
        # rotate points.
        (center_x, center_y) = center
        center_y = -center_y
        x, y = points[:, ::2], points[:, 1::2]
        y = -y

        theta = theta / 180 * math.pi
        cos = math.cos(theta)
        sin = math.sin(theta)

        x = (x - center_x)
        y = (y - center_y)

        _x = center_x + x * cos - y * sin + center_shift[0]
        _y = -(center_y + x * sin + y * cos) + center_shift[1]

        points[:, ::2], points[:, 1::2] = _x, _y
        return points

    def cal_canvas_size(self, ori_size, degree):
        assert isinstance(ori_size, tuple)
        angle = degree * math.pi / 180.0
        h, w = ori_size[:2]

        cos = math.cos(angle)
        sin = math.sin(angle)
        canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
        canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))

        canvas_size = (canvas_h, canvas_w)
        return canvas_size
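A quick numeric check of cal_canvas_size (illustrative, not part of the diff): rotating an 800x600 (h x w) image by 30 degrees needs a canvas of roughly w*|sin| + h*|cos| by w*|cos| + h*|sin|.

# Illustrative only.
import math

h, w, degree = 800, 600, 30
angle = math.radians(degree)
canvas_h = int(w * abs(math.sin(angle)) + h * abs(math.cos(angle)))   # 992
canvas_w = int(w * abs(math.cos(angle)) + h * abs(math.sin(angle)))   # 919
print(canvas_h, canvas_w)
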
    def sample_angle(self, max_angle):
        angle = np.random.random_sample() * 2 * max_angle - max_angle
        return angle

    def rotate_img(self, img, angle, canvas_size):
        h, w = img.shape[:2]
        rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
        rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)

        if self.pad_with_fixed_color:
            target_img = cv2.warpAffine(
                img,
                rotation_matrix, (canvas_size[1], canvas_size[0]),
                flags=cv2.INTER_NEAREST,
                borderValue=self.pad_value)
        else:
            mask = np.zeros_like(img)
            (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
                              np.random.randint(0, w * 7 // 8))
            img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
            img_cut = imresize(img_cut, (canvas_size[1], canvas_size[0]))
            mask = cv2.warpAffine(
                mask,
                rotation_matrix, (canvas_size[1], canvas_size[0]),
                borderValue=[1, 1, 1])
            target_img = cv2.warpAffine(
                img,
                rotation_matrix, (canvas_size[1], canvas_size[0]),
                borderValue=[0, 0, 0])
            target_img = target_img + img_cut * mask

        return target_img

    def __call__(self, results):
        if np.random.random_sample() < self.rotate_ratio:
            image = results['image']
            polygons = results['polys']
            h, w = image.shape[:2]

            angle = self.sample_angle(self.max_angle)
            canvas_size = self.cal_canvas_size((h, w), angle)
            center_shift = (int((canvas_size[1] - w) / 2), int(
                (canvas_size[0] - h) / 2))
            image = self.rotate_img(image, angle, canvas_size)
            results['image'] = image
            # rotate polygons
            rotated_masks = []
            for mask in polygons:
                rotated_mask = self.rotate((w / 2, h / 2), mask, angle,
                                           center_shift)
                rotated_masks.append(rotated_mask)
            results['polys'] = np.array(rotated_masks)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str


class SquareResizePad:
    def __init__(self,
                 target_size,
                 pad_ratio=0.6,
                 pad_with_fixed_color=False,
                 pad_value=(0, 0, 0),
                 **kwargs):
        """Resize or pad images to be square shape.

        Args:
            target_size (int): The target size of square shaped image.
            pad_with_fixed_color (bool): Whether to pad the rescaled image with
                a fixed color value. If set to False, the rescaled image is
                padded with a rescaled crop of the image itself.
            pad_value (tuple(int)): The color value for padding rotated image.
        """
        assert isinstance(target_size, int)
        assert isinstance(pad_ratio, float)
        assert isinstance(pad_with_fixed_color, bool)
        assert isinstance(pad_value, tuple)

        self.target_size = target_size
        self.pad_ratio = pad_ratio
        self.pad_with_fixed_color = pad_with_fixed_color
        self.pad_value = pad_value

    def resize_img(self, img, keep_ratio=True):
        h, w, _ = img.shape
        if keep_ratio:
            t_h = self.target_size if h >= w else int(h * self.target_size / w)
            t_w = self.target_size if h <= w else int(w * self.target_size / h)
        else:
            t_h = t_w = self.target_size
        img = imresize(img, (t_w, t_h))
        return img, (t_h, t_w)

    def square_pad(self, img):
        h, w = img.shape[:2]
        if h == w:
            return img, (0, 0)
        pad_size = max(h, w)
        if self.pad_with_fixed_color:
            expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
            expand_img[:] = self.pad_value
        else:
            (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
                              np.random.randint(0, w * 7 // 8))
            img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
            expand_img = imresize(img_cut, (pad_size, pad_size))
        if h > w:
            y0, x0 = 0, (h - w) // 2
        else:
            y0, x0 = (w - h) // 2, 0
        expand_img[y0:y0 + h, x0:x0 + w] = img
        offset = (x0, y0)

        return expand_img, offset

    def square_pad_mask(self, points, offset):
        x0, y0 = offset
        pad_points = points.copy()
        pad_points[::2] = pad_points[::2] + x0
        pad_points[1::2] = pad_points[1::2] + y0
        return pad_points

    def __call__(self, results):
        image = results['image']
        polygons = results['polys']
        h, w = image.shape[:2]

        if np.random.random_sample() < self.pad_ratio:
            image, out_size = self.resize_img(image, keep_ratio=True)
            image, offset = self.square_pad(image)
        else:
            image, out_size = self.resize_img(image, keep_ratio=False)
            offset = (0, 0)
        results['image'] = image
        polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[1] / w + offset[
            0]
        polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[0] / h + offset[
            1]
        results['polys'] = polygons

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str

@ -0,0 +1,670 @@
import cv2
import numpy as np
from numpy.fft import fft
from numpy.linalg import norm
import sys


class FCENetTargets:
    """Generate the ground truth targets of FCENet: Fourier Contour Embedding
    for Arbitrary-Shaped Text Detection.

    [https://arxiv.org/abs/2104.10442]

    Args:
        fourier_degree (int): The maximum Fourier transform degree k.
        resample_step (float): The step size for resampling the text center
            line (TCL). It's better not to exceed half of the minimum width.
        center_region_shrink_ratio (float): The shrink ratio of text center
            region.
        level_size_divisors (tuple(int)): The downsample ratio on each level.
        level_proportion_range (tuple(tuple(int))): The range of text sizes
            assigned to each level.
    """

    def __init__(self,
                 fourier_degree=5,
                 resample_step=4.0,
                 center_region_shrink_ratio=0.3,
                 level_size_divisors=(8, 16, 32),
                 level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
                 orientation_thr=2.0,
                 **kwargs):

        super().__init__()
        assert isinstance(level_size_divisors, tuple)
        assert isinstance(level_proportion_range, tuple)
        assert len(level_size_divisors) == len(level_proportion_range)
        self.fourier_degree = fourier_degree
        self.resample_step = resample_step
        self.center_region_shrink_ratio = center_region_shrink_ratio
        self.level_size_divisors = level_size_divisors
        self.level_proportion_range = level_proportion_range

        self.orientation_thr = orientation_thr

    def vector_angle(self, vec1, vec2):
        if vec1.ndim > 1:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
        else:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
        if vec2.ndim > 1:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
        else:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
        return np.arccos(
            np.clip(
                np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))

    def resample_line(self, line, n):
        """Resample n points on a line.

        Args:
            line (ndarray): The points composing a line.
            n (int): The resampled points number.

        Returns:
            resampled_line (ndarray): The points composing the resampled line.
        """

        assert line.ndim == 2
        assert line.shape[0] >= 2
        assert line.shape[1] == 2
        assert isinstance(n, int)
        assert n > 0

        length_list = [
            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
        ]
        total_length = sum(length_list)
        length_cumsum = np.cumsum([0.0] + length_list)
        delta_length = total_length / (float(n) + 1e-8)

        current_edge_ind = 0
        resampled_line = [line[0]]

        for i in range(1, n):
            current_line_len = i * delta_length

            while current_line_len >= length_cumsum[current_edge_ind + 1]:
                current_edge_ind += 1
            current_edge_end_shift = current_line_len - length_cumsum[
                current_edge_ind]
            end_shift_ratio = current_edge_end_shift / length_list[
                current_edge_ind]
            current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
                                                      - line[current_edge_ind]
                                                      ) * end_shift_ratio
            resampled_line.append(current_point)

        resampled_line.append(line[-1])
        resampled_line = np.array(resampled_line)

        return resampled_line

    def reorder_poly_edge(self, points):
        """Get the respective points composing head edge, tail edge, top
        sideline and bottom sideline.

        Args:
            points (ndarray): The points composing a text polygon.

        Returns:
            head_edge (ndarray): The two points composing the head edge of text
                polygon.
            tail_edge (ndarray): The two points composing the tail edge of text
                polygon.
            top_sideline (ndarray): The points composing top curved sideline of
                text polygon.
            bot_sideline (ndarray): The points composing bottom curved sideline
                of text polygon.
        """

        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2

        head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
        head_edge, tail_edge = points[head_inds], points[tail_inds]

        pad_points = np.vstack([points, points])
        if tail_inds[1] < 1:
            tail_inds[1] = len(points)
        sideline1 = pad_points[head_inds[1]:tail_inds[1]]
        sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
        sideline_mean_shift = np.mean(
            sideline1, axis=0) - np.mean(
                sideline2, axis=0)

        if sideline_mean_shift[1] > 0:
            top_sideline, bot_sideline = sideline2, sideline1
        else:
            top_sideline, bot_sideline = sideline1, sideline2

        return head_edge, tail_edge, top_sideline, bot_sideline

    def find_head_tail(self, points, orientation_thr):
        """Find the head edge and tail edge of a text polygon.

        Args:
            points (ndarray): The points composing a text polygon.
            orientation_thr (float): The threshold for distinguishing between
                head edge and tail edge among the horizontal and vertical edges
                of a quadrangle.

        Returns:
            head_inds (list): The indexes of two points composing head edge.
            tail_inds (list): The indexes of two points composing tail edge.
        """

        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2
        assert isinstance(orientation_thr, float)

        if len(points) > 4:
            pad_points = np.vstack([points, points[0]])
            edge_vec = pad_points[1:] - pad_points[:-1]

            theta_sum = []
            adjacent_vec_theta = []
            for i, edge_vec1 in enumerate(edge_vec):
                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
                adjacent_edge_vec = edge_vec[adjacent_ind]
                temp_theta_sum = np.sum(
                    self.vector_angle(edge_vec1, adjacent_edge_vec))
                temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
                                                        adjacent_edge_vec[1])
                theta_sum.append(temp_theta_sum)
                adjacent_vec_theta.append(temp_adjacent_theta)
            theta_sum_score = np.array(theta_sum) / np.pi
            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
            poly_center = np.mean(points, axis=0)
            edge_dist = np.maximum(
                norm(
                    pad_points[1:] - poly_center, axis=-1),
                norm(
                    pad_points[:-1] - poly_center, axis=-1))
            dist_score = edge_dist / np.max(edge_dist)
            position_score = np.zeros(len(edge_vec))
            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
            score += 0.35 * dist_score
            if len(points) % 2 == 0:
                position_score[(len(score) // 2 - 1)] += 1
                position_score[-1] += 1
            score += 0.1 * position_score
            pad_score = np.concatenate([score, score])
            score_matrix = np.zeros((len(score), len(score) - 3))
            x = np.arange(len(score) - 3) / float(len(score) - 4)
            gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
                (x - 0.5) / 0.5, 2.) / 2)
            gaussian = gaussian / np.max(gaussian)
            for i in range(len(score)):
                score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
                    score) - 1)] * gaussian * 0.3

            head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
                                                          score_matrix.shape)
            tail_start = (head_start + tail_increment + 2) % len(points)
            head_end = (head_start + 1) % len(points)
            tail_end = (tail_start + 1) % len(points)

            if head_end > tail_end:
                head_start, tail_start = tail_start, head_start
                head_end, tail_end = tail_end, head_end
            head_inds = [head_start, head_end]
            tail_inds = [tail_start, tail_end]
        else:
            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
                    points[3] - points[2]) < self.vector_slope(points[
                        2] - points[1]) + self.vector_slope(points[0] - points[
                            3]):
                horizontal_edge_inds = [[0, 1], [2, 3]]
                vertical_edge_inds = [[3, 0], [1, 2]]
            else:
                horizontal_edge_inds = [[3, 0], [1, 2]]
                vertical_edge_inds = [[0, 1], [2, 3]]

            vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
                vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
                    0]] - points[vertical_edge_inds[1][1]])
            horizontal_len_sum = norm(points[horizontal_edge_inds[0][
                0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
                    horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
                                                         [1]])

            if vertical_len_sum > horizontal_len_sum * orientation_thr:
                head_inds = horizontal_edge_inds[0]
                tail_inds = horizontal_edge_inds[1]
            else:
                head_inds = vertical_edge_inds[0]
                tail_inds = vertical_edge_inds[1]

        return head_inds, tail_inds

    def resample_sidelines(self, sideline1, sideline2, resample_step):
        """Resample two sidelines to be of the same points number according to
        step size.

        Args:
            sideline1 (ndarray): The points composing a sideline of a text
                polygon.
            sideline2 (ndarray): The points composing another sideline of a
                text polygon.
            resample_step (float): The resampled step size.

        Returns:
            resampled_line1 (ndarray): The resampled line 1.
            resampled_line2 (ndarray): The resampled line 2.
        """

        assert sideline1.ndim == sideline2.ndim == 2
        assert sideline1.shape[1] == sideline2.shape[1] == 2
        assert sideline1.shape[0] >= 2
        assert sideline2.shape[0] >= 2
        assert isinstance(resample_step, float)

        length1 = sum([
            norm(sideline1[i + 1] - sideline1[i])
            for i in range(len(sideline1) - 1)
        ])
        length2 = sum([
            norm(sideline2[i + 1] - sideline2[i])
            for i in range(len(sideline2) - 1)
        ])

        total_length = (length1 + length2) / 2
        resample_point_num = max(int(float(total_length) / resample_step), 1)

        resampled_line1 = self.resample_line(sideline1, resample_point_num)
        resampled_line2 = self.resample_line(sideline2, resample_point_num)

        return resampled_line1, resampled_line2

    def generate_center_region_mask(self, img_size, text_polys):
        """Generate text center region mask.

        Args:
            img_size (tuple): The image size of (height, width).
            text_polys (list[list[ndarray]]): The list of text polygons.

        Returns:
            center_region_mask (ndarray): The text center region mask.
        """

        assert isinstance(img_size, tuple)

        h, w = img_size

        center_region_mask = np.zeros((h, w), np.uint8)

        center_region_boxes = []
        for poly in text_polys:
            polygon_points = poly.reshape(-1, 2)
            _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
            resampled_top_line, resampled_bot_line = self.resample_sidelines(
                top_line, bot_line, self.resample_step)
            resampled_bot_line = resampled_bot_line[::-1]
            center_line = (resampled_top_line + resampled_bot_line) / 2

            line_head_shrink_len = norm(resampled_top_line[0] -
                                        resampled_bot_line[0]) / 4.0
            line_tail_shrink_len = norm(resampled_top_line[-1] -
                                        resampled_bot_line[-1]) / 4.0
            head_shrink_num = int(line_head_shrink_len // self.resample_step)
            tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
            if len(center_line) > head_shrink_num + tail_shrink_num + 2:
                center_line = center_line[head_shrink_num:len(center_line) -
                                          tail_shrink_num]
                resampled_top_line = resampled_top_line[head_shrink_num:len(
                    resampled_top_line) - tail_shrink_num]
                resampled_bot_line = resampled_bot_line[head_shrink_num:len(
                    resampled_bot_line) - tail_shrink_num]

            for i in range(0, len(center_line) - 1):
                tl = center_line[i] + (resampled_top_line[i] - center_line[i]
                                       ) * self.center_region_shrink_ratio
                tr = center_line[i + 1] + (resampled_top_line[i + 1] -
                                           center_line[i + 1]
                                           ) * self.center_region_shrink_ratio
                br = center_line[i + 1] + (resampled_bot_line[i + 1] -
                                           center_line[i + 1]
                                           ) * self.center_region_shrink_ratio
                bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
                                       ) * self.center_region_shrink_ratio
                current_center_box = np.vstack([tl, tr, br,
                                                bl]).astype(np.int32)
                center_region_boxes.append(current_center_box)

        cv2.fillPoly(center_region_mask, center_region_boxes, 1)
        return center_region_mask

    def resample_polygon(self, polygon, n=400):
        """Resample one polygon with n points on its boundary.

        Args:
            polygon (list[float]): The input polygon.
            n (int): The number of resampled points.
        Returns:
            resampled_polygon (list[float]): The resampled polygon.
        """
        length = []

        for i in range(len(polygon)):
            p1 = polygon[i]
            if i == len(polygon) - 1:
                p2 = polygon[0]
            else:
                p2 = polygon[i + 1]
            length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)

        total_length = sum(length)
        n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
        n_on_each_line = n_on_each_line.astype(np.int32)
        new_polygon = []

        for i in range(len(polygon)):
            num = n_on_each_line[i]
            p1 = polygon[i]
            if i == len(polygon) - 1:
                p2 = polygon[0]
            else:
                p2 = polygon[i + 1]

            if num == 0:
                continue

            dxdy = (p2 - p1) / num
            for j in range(num):
                point = p1 + dxdy * j
                new_polygon.append(point)

        return np.array(new_polygon)

    def normalize_polygon(self, polygon):
        """Normalize one polygon so that its start point is at right most.

        Args:
            polygon (list[float]): The origin polygon.
        Returns:
            new_polygon (list[float]): The polygon with start point at right.
        """
        temp_polygon = polygon - polygon.mean(axis=0)
        x = np.abs(temp_polygon[:, 0])
        y = temp_polygon[:, 1]
        index_x = np.argsort(x)
        index_y = np.argmin(y[index_x[:8]])
        index = index_x[index_y]
        new_polygon = np.concatenate([polygon[index:], polygon[:index]])
        return new_polygon

    def poly2fourier(self, polygon, fourier_degree):
        """Perform Fourier transformation to generate Fourier coefficients ck
        from polygon.

        Args:
            polygon (ndarray): An input polygon.
            fourier_degree (int): The maximum Fourier degree K.
        Returns:
            c (ndarray(complex)): Fourier coefficients.
        """
        points = polygon[:, 0] + polygon[:, 1] * 1j
        c_fft = fft(points) / len(points)
        c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1]))
        return c

    def clockwise(self, c, fourier_degree):
        """Make sure the polygon reconstructed from Fourier coefficients c is
        in the clockwise direction.

        Args:
            c (ndarray(complex)): The Fourier coefficients.
            fourier_degree (int): The maximum Fourier degree K.
        Returns:
            c (ndarray(complex)): The Fourier coefficients in clockwise order.
        """
        if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
            return c
        elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
            return c[::-1]
        else:
            if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
                return c
            else:
                return c[::-1]

    def cal_fourier_signature(self, polygon, fourier_degree):
        """Calculate Fourier signature from input polygon.

        Args:
            polygon (ndarray): The input polygon.
            fourier_degree (int): The maximum Fourier degree K.
        Returns:
            fourier_signature (ndarray): An array shaped (2k+1, 2) containing
                the real part and imaginary part of 2k+1 Fourier coefficients.
        """
        resampled_polygon = self.resample_polygon(polygon)
        resampled_polygon = self.normalize_polygon(resampled_polygon)

        fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
        fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)

        real_part = np.real(fourier_coeff).reshape((-1, 1))
        image_part = np.imag(fourier_coeff).reshape((-1, 1))
        fourier_signature = np.hstack([real_part, image_part])

        return fourier_signature
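An illustrative sketch (not part of the diff) of the Fourier contour idea used above: sample a closed contour as complex points x + iy, keep only the 2k+1 lowest-frequency FFT coefficients in the same layout as poly2fourier, and check that the truncated series reproduces the contour.

# Illustrative only.
import numpy as np
from numpy.fft import fft

k = 5
t = np.linspace(0, 2 * np.pi, 400, endpoint=False)
contour = 100 * np.cos(t) + 1j * 50 * np.sin(t)        # an ellipse as x + iy

c = fft(contour) / len(contour)                        # coefficients c_m
c_trunc = np.hstack((c[-k:], c[:k + 1]))               # same layout as poly2fourier

# reconstruct n points from the truncated coefficients
n = 50
tt = np.arange(n) / n
recon = sum(c_trunc[i + k] * np.exp(2j * np.pi * i * tt)
            for i in range(-k, k + 1))
target = 100 * np.cos(2 * np.pi * tt) + 1j * 50 * np.sin(2 * np.pi * tt)
print(np.abs(recon - target).max())                    # ~0: an ellipse only needs m = +/-1
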
    def generate_fourier_maps(self, img_size, text_polys):
        """Generate Fourier coefficient maps.

        Args:
            img_size (tuple): The image size of (height, width).
            text_polys (list[list[ndarray]]): The list of text polygons.

        Returns:
            fourier_real_map (ndarray): The Fourier coefficient real part maps.
            fourier_image_map (ndarray): The Fourier coefficient imaginary part
                maps.
        """

        assert isinstance(img_size, tuple)

        h, w = img_size
        k = self.fourier_degree
        real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
        imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)

        for poly in text_polys:
            mask = np.zeros((h, w), dtype=np.uint8)
            polygon = np.array(poly).reshape((1, -1, 2))
            cv2.fillPoly(mask, polygon.astype(np.int32), 1)
            fourier_coeff = self.cal_fourier_signature(polygon[0], k)
            for i in range(-k, k + 1):
                if i != 0:
                    real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
                        1 - mask) * real_map[i + k, :, :]
                    imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
                        1 - mask) * imag_map[i + k, :, :]
                else:
                    yx = np.argwhere(mask > 0.5)
                    k_ind = np.ones((len(yx)), dtype=np.int64) * k
                    y, x = yx[:, 0], yx[:, 1]
                    real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
                    imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y

        return real_map, imag_map

    def generate_text_region_mask(self, img_size, text_polys):
        """Generate text center region mask and geometry attribute maps.

        Args:
            img_size (tuple): The image size (height, width).
            text_polys (list[list[ndarray]]): The list of text polygons.

        Returns:
            text_region_mask (ndarray): The text region mask.
        """

        assert isinstance(img_size, tuple)

        h, w = img_size
        text_region_mask = np.zeros((h, w), dtype=np.uint8)

        for poly in text_polys:
            polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
            cv2.fillPoly(text_region_mask, polygon, 1)

        return text_region_mask

    def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
        """Generate effective mask by setting the ineffective regions to 0 and
        effective regions to 1.

        Args:
            mask_size (tuple): The mask size.
            polygons_ignore (list[ndarray]): The list of ignored text
                polygons.

        Returns:
            mask (ndarray): The effective mask of (height, width).
        """

        mask = np.ones(mask_size, dtype=np.uint8)

        for poly in polygons_ignore:
            instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
            cv2.fillPoly(mask, instance, 0)

        return mask

    def generate_level_targets(self, img_size, text_polys, ignore_polys):
        """Generate ground truth target on each level.

        Args:
            img_size (list[int]): Shape of input image.
            text_polys (list[list[ndarray]]): A list of ground truth polygons.
            ignore_polys (list[list[ndarray]]): A list of ignored polygons.
        Returns:
            level_maps (list(ndarray)): A list of ground truth targets on each
                level.
        """
        h, w = img_size
        lv_size_divs = self.level_size_divisors
        lv_proportion_range = self.level_proportion_range
        lv_text_polys = [[] for i in range(len(lv_size_divs))]
        lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
        level_maps = []
        for poly in text_polys:
            # np.int32 instead of the deprecated np.int
            polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
            _, _, box_w, box_h = cv2.boundingRect(polygon)
            proportion = max(box_h, box_w) / (h + 1e-8)

            for ind, proportion_range in enumerate(lv_proportion_range):
                if proportion_range[0] < proportion < proportion_range[1]:
                    lv_text_polys[ind].append(poly / lv_size_divs[ind])

        for ignore_poly in ignore_polys:
            polygon = np.array(ignore_poly, dtype=np.int32).reshape((1, -1, 2))
            _, _, box_w, box_h = cv2.boundingRect(polygon)
            proportion = max(box_h, box_w) / (h + 1e-8)

            for ind, proportion_range in enumerate(lv_proportion_range):
                if proportion_range[0] < proportion < proportion_range[1]:
                    lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind])

        for ind, size_divisor in enumerate(lv_size_divs):
            current_level_maps = []
            level_img_size = (h // size_divisor, w // size_divisor)

            text_region = self.generate_text_region_mask(
                level_img_size, lv_text_polys[ind])[None]
            current_level_maps.append(text_region)

            center_region = self.generate_center_region_mask(
                level_img_size, lv_text_polys[ind])[None]
            current_level_maps.append(center_region)

            effective_mask = self.generate_effective_mask(
                level_img_size, lv_ignore_polys[ind])[None]
            current_level_maps.append(effective_mask)

            fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
                level_img_size, lv_text_polys[ind])
            current_level_maps.append(fourier_real_map)
            current_level_maps.append(fourier_image_maps)

            level_maps.append(np.concatenate(current_level_maps))

        return level_maps
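A small numeric illustration (not part of the diff) of how generate_level_targets routes a text instance to a pyramid level: the proportion of its longer side to the image height decides which of the (0, 0.25) / (0.2, 0.65) / (0.55, 1.0) ranges it falls into, and the deliberate overlaps mean a box can be assigned to two levels.

# Illustrative only.
h = 800
level_proportion_range = ((0, 0.25), (0.2, 0.65), (0.55, 1.0))
box_long_side = 180                       # pixels
proportion = box_long_side / (h + 1e-8)   # 0.225
levels = [ind for ind, r in enumerate(level_proportion_range)
          if r[0] < proportion < r[1]]
print(levels)                             # [0, 1] -> assigned to strides 8 and 16
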
    def generate_targets(self, results):
        """Generate the ground truth targets for FCENet.

        Args:
            results (dict): The input result dictionary.

        Returns:
            results (dict): The output result dictionary.
        """

        assert isinstance(results, dict)
        image = results['image']
        polygons = results['polys']
        ignore_tags = results['ignore_tags']
        h, w, _ = image.shape

        polygon_masks = []
        polygon_masks_ignore = []
        for tag, polygon in zip(ignore_tags, polygons):
            if tag is True:
                polygon_masks_ignore.append(polygon)
            else:
                polygon_masks.append(polygon)

        level_maps = self.generate_level_targets((h, w), polygon_masks,
                                                 polygon_masks_ignore)

        mapping = {
            'p3_maps': level_maps[0],
            'p4_maps': level_maps[1],
            'p5_maps': level_maps[2]
        }
        for key, value in mapping.items():
            results[key] = value

        return results

    def __call__(self, results):
        results = self.generate_targets(results)
        return results

@ -60,9 +60,14 @@ class DecodeImage(object):
 class NRTRDecodeImage(object):
     """ decode image """
 
-    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+    def __init__(self,
+                 img_mode='RGB',
+                 channel_first=False,
+                 ignore_orientation=False,
+                 **kwargs):
         self.img_mode = img_mode
         self.channel_first = channel_first
+        self.ignore_orientation = ignore_orientation
 
     def __call__(self, data):
         img = data['image']

@ -74,7 +79,11 @@ class NRTRDecodeImage(object):
             img) > 0, "invalid input 'img' in DecodeImage"
         img = np.frombuffer(img, dtype='uint8')
 
-        img = cv2.imdecode(img, 1)
+        if self.ignore_orientation:
+            img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
+                               cv2.IMREAD_COLOR)
+        else:
+            img = cv2.imdecode(img, 1)
 
         if img is None:
             return None

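Illustrative use of the new ignore_orientation behaviour (not part of the diff): with cv2.IMREAD_IGNORE_ORIENTATION the EXIF orientation tag is not applied, so the decoded array matches the stored pixel layout. The file path below is a placeholder.

# Illustrative only; 'some_image.jpg' is a hypothetical path.
import cv2
import numpy as np

with open('some_image.jpg', 'rb') as f:
    buf = np.frombuffer(f.read(), dtype='uint8')
img = cv2.imdecode(buf, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
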
@ -24,6 +24,7 @@ from .det_db_loss import DBLoss
 from .det_east_loss import EASTLoss
 from .det_sast_loss import SASTLoss
 from .det_pse_loss import PSELoss
+from .det_fce_loss import FCELoss
 
 # rec loss
 from .rec_ctc_loss import CTCLoss

@ -55,9 +56,9 @@ from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
 
 def build_loss(config):
     support_dict = [
-        'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
-        'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
-        'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
+        'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
+        'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
+        'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput'
     ]
     config = copy.deepcopy(config)

@ -0,0 +1,212 @@
import numpy as np
from paddle import nn
import paddle
import paddle.nn.functional as F
from functools import partial


def multi_apply(func, *args, **kwargs):
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))
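A quick illustration of multi_apply (not part of the diff): it maps a function over per-level inputs and regroups the per-level result tuples into tuples of lists, which is how FCELoss.forward below collects the per-FPN-level losses.

# Illustrative only.
def square_and_double(x):
    return x * x, 2 * x

squares, doubles = multi_apply(square_and_double, [1, 2, 3])
print(squares, doubles)   # [1, 4, 9] [2, 4, 6]
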
|
||||
class FCELoss(nn.Layer):
|
||||
"""The class for implementing FCENet loss
|
||||
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
|
||||
Text Detection
|
||||
|
||||
[https://arxiv.org/abs/2104.10442]
|
||||
|
||||
Args:
|
||||
fourier_degree (int) : The maximum Fourier transform degree k.
|
||||
num_sample (int) : The sampling points number of regression
|
||||
loss. If it is too small, fcenet tends to be overfitting.
|
||||
ohem_ratio (float): the negative/positive ratio in OHEM.
|
||||
"""
|
||||
|
||||
def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
|
||||
super().__init__()
|
||||
self.fourier_degree = fourier_degree
|
||||
self.num_sample = num_sample
|
||||
self.ohem_ratio = ohem_ratio
|
||||
|
||||
def forward(self, preds, labels):
|
||||
assert isinstance(preds, dict)
|
||||
preds = preds['levels']
|
||||
|
||||
p3_maps, p4_maps, p5_maps = labels[1:]
|
||||
assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
|
||||
'fourier degree not equal in FCEhead and FCEtarget'
|
||||
|
||||
# device = preds[0][0].device
|
||||
# to tensor
|
||||
gts = [p3_maps, p4_maps, p5_maps]
|
||||
for idx, maps in enumerate(gts):
|
||||
gts[idx] = paddle.to_tensor(np.stack(maps))
|
||||
|
||||
losses = multi_apply(self.forward_single, preds, gts)
|
||||
|
||||
loss_tr = paddle.to_tensor(0.).astype('float32')
|
||||
loss_tcl = paddle.to_tensor(0.).astype('float32')
|
||||
loss_reg_x = paddle.to_tensor(0.).astype('float32')
|
||||
loss_reg_y = paddle.to_tensor(0.).astype('float32')
|
||||
loss_all = paddle.to_tensor(0.).astype('float32')
|
||||
|
||||
for idx, loss in enumerate(losses):
|
||||
loss_all += sum(loss)
|
||||
if idx == 0:
|
||||
loss_tr += sum(loss)
|
||||
elif idx == 1:
|
||||
loss_tcl += sum(loss)
|
||||
elif idx == 2:
|
||||
loss_reg_x += sum(loss)
|
||||
else:
|
||||
loss_reg_y += sum(loss)
|
||||
|
||||
results = dict(
|
||||
loss=loss_all,
|
||||
loss_text=loss_tr,
|
||||
loss_center=loss_tcl,
|
||||
loss_reg_x=loss_reg_x,
|
||||
loss_reg_y=loss_reg_y, )
|
||||
return results
|
||||
|
||||
def forward_single(self, pred, gt):
|
||||
cls_pred = paddle.transpose(pred[0], (0, 2, 3, 1))
|
||||
reg_pred = paddle.transpose(pred[1], (0, 2, 3, 1))
|
||||
gt = paddle.transpose(gt, (0, 2, 3, 1))
|
||||
|
||||
k = 2 * self.fourier_degree + 1
|
||||
tr_pred = paddle.reshape(cls_pred[:, :, :, :2], (-1, 2))
|
||||
tcl_pred = paddle.reshape(cls_pred[:, :, :, 2:], (-1, 2))
|
||||
x_pred = paddle.reshape(reg_pred[:, :, :, 0:k], (-1, k))
|
||||
y_pred = paddle.reshape(reg_pred[:, :, :, k:2 * k], (-1, k))
|
||||
|
||||
tr_mask = gt[:, :, :, :1].reshape([-1])
|
||||
tcl_mask = gt[:, :, :, 1:2].reshape([-1])
|
||||
train_mask = gt[:, :, :, 2:3].reshape([-1])
|
||||
x_map = paddle.reshape(gt[:, :, :, 3:3 + k], (-1, k))
|
||||
y_map = paddle.reshape(gt[:, :, :, 3 + k:], (-1, k))
|
||||
|
||||
tr_train_mask = (train_mask * tr_mask).astype('bool')
|
||||
tr_train_mask2 = paddle.concat(
|
||||
[tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1)
|
||||
# tr loss
|
||||
loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
|
||||
# import pdb; pdb.set_trace()
|
||||
# tcl loss
|
||||
loss_tcl = paddle.to_tensor(0.).astype('float32')
|
||||
tr_neg_mask = tr_train_mask.logical_not()
|
||||
tr_neg_mask2 = paddle.concat(
|
||||
[tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1)
|
||||
if tr_train_mask.sum().item() > 0:
|
||||
loss_tcl_pos = F.cross_entropy(
|
||||
tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]),
|
||||
tcl_mask.masked_select(tr_train_mask).astype('int64'))
|
||||
loss_tcl_neg = F.cross_entropy(
|
||||
tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]),
|
||||
tcl_mask.masked_select(tr_neg_mask).astype('int64'))
|
||||
loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
|
||||
|
||||
# regression loss
|
||||
loss_reg_x = paddle.to_tensor(0.).astype('float32')
|
||||
loss_reg_y = paddle.to_tensor(0.).astype('float32')
|
||||
if tr_train_mask.sum().item() > 0:
|
||||
weight = (tr_mask.masked_select(tr_train_mask.astype('bool'))
|
||||
.astype('float32') + tcl_mask.masked_select(
|
||||
tr_train_mask.astype('bool')).astype('float32')) / 2
|
||||
weight = weight.reshape([-1, 1])
|
||||
|
||||
ft_x, ft_y = self.fourier2poly(x_map, y_map)
|
||||
ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
|
||||
|
||||
dim = ft_x.shape[1]
|
||||
|
||||
tr_train_mask3 = paddle.concat(
|
||||
[tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1)
|
||||
|
||||
loss_reg_x = paddle.mean(weight * F.smooth_l1_loss(
|
||||
ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
|
||||
ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
|
||||
reduction='none'))
|
||||
loss_reg_y = paddle.mean(weight * F.smooth_l1_loss(
|
||||
ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
|
||||
ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
|
||||
reduction='none'))
|
||||
|
||||
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
|
||||
|
||||
def ohem(self, predict, target, train_mask):
|
||||
|
||||
|
||||
pos = (target * train_mask).astype('bool')
|
||||
neg = ((1 - target) * train_mask).astype('bool')
|
||||
|
||||
pos2 = paddle.concat([pos.unsqueeze(1), pos.unsqueeze(1)], axis=1)
|
||||
neg2 = paddle.concat([neg.unsqueeze(1), neg.unsqueeze(1)], axis=1)
|
||||
|
||||
n_pos = pos.astype('float32').sum()
|
||||
|
||||
if n_pos.item() > 0:
|
||||
loss_pos = F.cross_entropy(
|
||||
predict.masked_select(pos2).reshape([-1, 2]),
|
||||
target.masked_select(pos).astype('int64'),
|
||||
reduction='sum')
|
||||
loss_neg = F.cross_entropy(
|
||||
predict.masked_select(neg2).reshape([-1, 2]),
|
||||
target.masked_select(neg).astype('int64'),
|
||||
reduction='none')
|
||||
n_neg = min(
|
||||
int(neg.astype('float32').sum().item()),
|
||||
int(self.ohem_ratio * n_pos.astype('float32')))
|
||||
else:
|
||||
loss_pos = paddle.to_tensor(0.)
|
||||
loss_neg = F.cross_entropy(
|
||||
predict.masked_select(neg2).reshape([-1, 2]),
|
||||
target.masked_select(neg).astype('int64'),
|
||||
reduction='none')
|
||||
n_neg = 100
|
||||
if len(loss_neg) > n_neg:
|
||||
loss_neg, _ = paddle.topk(loss_neg, n_neg)
|
||||
|
||||
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype('float32')
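# Hedged numpy sketch (not part of the original commit): the OHEM rule above keeps all
# positives but only the hardest negatives, capped at ohem_ratio * n_pos. A minimal
# version of that selection with hypothetical per-pixel negative losses:
import numpy as np
demo_neg_losses = np.array([0.9, 0.1, 0.7, 0.05, 0.3])
demo_n_pos, demo_ohem_ratio = 2, 3.0
demo_n_neg = min(len(demo_neg_losses), int(demo_ohem_ratio * demo_n_pos))
demo_hard_negatives = np.sort(demo_neg_losses)[::-1][:demo_n_neg]
# all five negatives survive here because 5 < ohem_ratio * n_pos = 6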
|
||||
|
||||
def fourier2poly(self, real_maps, imag_maps):
|
||||
"""Transform Fourier coefficient maps to polygon maps.
|
||||
|
||||
Args:
|
||||
real_maps (tensor): A map composed of the real parts of the
|
||||
Fourier coefficients, whose shape is (-1, 2k+1)
|
||||
imag_maps (tensor): A map composed of the imaginary parts of the
|
||||
Fourier coefficients, whose shape is (-1, 2k+1)
|
||||
|
||||
Returns:
|
||||
x_maps (tensor): A map composed of the x value of the polygon
|
||||
represented by n sample points (xn, yn), whose shape is (-1, n)
|
||||
y_maps (tensor): A map composed of the y value of the polygon
|
||||
represented by n sample points (xn, yn), whose shape is (-1, n)
|
||||
"""
|
||||
|
||||
k_vect = paddle.arange(
|
||||
-self.fourier_degree, self.fourier_degree + 1,
|
||||
dtype='float32').reshape([-1, 1])
|
||||
i_vect = paddle.arange(
|
||||
0, self.num_sample, dtype='float32').reshape([1, -1])
|
||||
|
||||
transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect,
|
||||
i_vect)
|
||||
|
||||
x1 = paddle.einsum('ak, kn-> an', real_maps,
|
||||
paddle.cos(transform_matrix))
|
||||
x2 = paddle.einsum('ak, kn-> an', imag_maps,
|
||||
paddle.sin(transform_matrix))
|
||||
y1 = paddle.einsum('ak, kn-> an', real_maps,
|
||||
paddle.sin(transform_matrix))
|
||||
y2 = paddle.einsum('ak, kn-> an', imag_maps,
|
||||
paddle.cos(transform_matrix))
|
||||
|
||||
x_maps = x1 - x2
|
||||
y_maps = y1 + y2
|
||||
|
||||
return x_maps, y_maps
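# Hedged numpy check (not part of the original commit): the matrix form above evaluates
# the inverse Fourier series x_n + i*y_n = sum_k c_k * exp(2*pi*i*k*n/N) for k in [-K, K],
# which is the same series the ifft-based fourier2poly helper in the FCE post process uses.
import numpy as np
K, N = 5, 50
c = np.random.randn(2 * K + 1) + 1j * np.random.randn(2 * K + 1)  # hypothetical coefficients
k = np.arange(-K, K + 1).reshape(-1, 1)
n = np.arange(N).reshape(1, -1)
poly_matrix = (c.reshape(1, -1) @ np.exp(2j * np.pi * k * n / N)).ravel()
a = np.zeros(N, dtype=complex)
a[:K + 1] = c[K:]  # non-negative frequencies
a[-K:] = c[:K]     # negative frequencies
assert np.allclose(poly_matrix, np.fft.ifft(a) * N)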
|
|
@ -21,7 +21,7 @@ import copy
|
|||
|
||||
__all__ = ["build_metric"]
|
||||
|
||||
from .det_metric import DetMetric
|
||||
from .det_metric import DetMetric, DetFCEMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
|
@ -34,7 +34,7 @@ from .vqa_token_re_metric import VQAReTokenMetric
|
|||
|
||||
def build_metric(config):
|
||||
support_dict = [
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
|
||||
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
|
||||
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
|
||||
'VQAReTokenMetric'
|
||||
]
|
||||
|
|
|
@ -16,7 +16,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ = ['DetMetric']
|
||||
__all__ = ['DetMetric', 'DetFCEMetric']
|
||||
|
||||
from .eval_det_iou import DetectionIoUEvaluator
|
||||
|
||||
|
@ -55,7 +55,6 @@ class DetMetric(object):
|
|||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
|
@ -71,3 +70,85 @@ class DetMetric(object):
|
|||
|
||||
def reset(self):
|
||||
self.results = [] # clear results
|
||||
|
||||
|
||||
class DetFCEMetric(object):
|
||||
def __init__(self, main_indicator='hmean', **kwargs):
|
||||
self.evaluator = DetectionIoUEvaluator()
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
'''
|
||||
batch: a list produced by dataloaders.
|
||||
image: np.ndarray of shape (N, C, H, W).
|
||||
ratio_list: np.ndarray of shape(N,2)
|
||||
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
|
||||
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
|
||||
preds: a list of dict produced by post process
|
||||
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
|
||||
'''
|
||||
gt_polyons_batch = batch[2]
|
||||
ignore_tags_batch = batch[3]
|
||||
|
||||
for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
|
||||
ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
'text': '',
|
||||
'ignore': ignore_tag
|
||||
} for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
|
||||
# prepare det
|
||||
det_info_list = [{
|
||||
'points': det_polyon,
|
||||
'text': '',
|
||||
'score': score
|
||||
} for det_polyon, score in zip(pred['points'], pred['scores'])]
|
||||
|
||||
for score_thr in self.results.keys():
|
||||
det_info_list_thr = [
|
||||
det_info for det_info in det_info_list
|
||||
if det_info['score'] >= score_thr
|
||||
]
|
||||
result = self.evaluator.evaluate_image(gt_info_list,
|
||||
det_info_list_thr)
|
||||
self.results[score_thr].append(result)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {'hmean': 0,
|
||||
'thr 0.3':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.4':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.5':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.6':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.7':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.8':'precision: 0 recall: 0 hmean: 0',
|
||||
'thr 0.9':'precision: 0 recall: 0 hmean: 0',
|
||||
}
|
||||
"""
|
||||
metircs = {}
|
||||
hmean = 0
|
||||
for score_thr in self.results.keys():
|
||||
metirc = self.evaluator.combine_results(self.results[score_thr])
|
||||
# for key, value in metirc.items():
|
||||
# metircs['{}_{}'.format(key, score_thr)] = value
|
||||
metirc_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format(
|
||||
metirc['precision'], metirc['recall'], metirc['hmean'])
|
||||
metircs['\n thr {}'.format(score_thr)] = metirc_str
|
||||
hmean = max(hmean, metirc['hmean'])
|
||||
metircs['hmean'] = hmean
|
||||
|
||||
self.reset()
|
||||
return metircs
|
||||
|
||||
def reset(self):
|
||||
self.results = {
|
||||
0.3: [],
|
||||
0.4: [],
|
||||
0.5: [],
|
||||
0.6: [],
|
||||
0.7: [],
|
||||
0.8: [],
|
||||
0.9: []
|
||||
} # clear results
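# Hedged usage sketch (not part of the original commit): DetFCEMetric buckets results
# by detection-score threshold (0.3 ... 0.9) and reports the best hmean across them.
# The dict format below mirrors what DetMetric feeds DetectionIoUEvaluator above;
# the polygon values are hypothetical.
import numpy as np
demo_metric = DetFCEMetric()
gt_polygons = np.array([[[[0, 0], [10, 0], [10, 10], [0, 10]]]], dtype=np.float32)  # (N=1, K=1, 4, 2)
ignore_tags = np.array([[False]])
batch = [None, None, gt_polygons, ignore_tags]  # image and ratio_list are unused by the metric
preds = [{'points': np.array([[[0, 0], [10, 0], [10, 10], [0, 10]]], dtype=np.float32),
'scores': [0.95]}]
demo_metric(preds, batch)
print(demo_metric.get_metric()['hmean'])  # expected to be 1.0 for this perfect match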
|
||||
|
|
|
@ -21,9 +21,82 @@ from paddle import ParamAttr
|
|||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from paddle.vision.ops import DeformConv2D
|
||||
from paddle.regularizer import L2Decay
|
||||
from paddle.nn.initializer import Normal, Constant, XavierUniform
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class DeformableConvV2(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
weight_attr=None,
|
||||
bias_attr=None,
|
||||
lr_scale=1,
|
||||
regularizer=None,
|
||||
skip_quant=False,
|
||||
dcn_bias_regularizer=L2Decay(0.),
|
||||
dcn_bias_lr_scale=2.):
|
||||
super(DeformableConvV2, self).__init__()
|
||||
self.offset_channel = 2 * kernel_size**2 * groups
|
||||
self.mask_channel = kernel_size**2 * groups
|
||||
|
||||
if bias_attr:
|
||||
# in FCOS-DCN head, specifically need learning_rate and regularizer
|
||||
dcn_bias_attr = ParamAttr(
|
||||
initializer=Constant(value=0),
|
||||
regularizer=dcn_bias_regularizer,
|
||||
learning_rate=dcn_bias_lr_scale)
|
||||
else:
|
||||
# in ResNet backbone, do not need bias
|
||||
dcn_bias_attr = False
|
||||
self.conv_dcn = DeformConv2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2 * dilation,
|
||||
dilation=dilation,
|
||||
deformable_groups=groups,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=dcn_bias_attr)
|
||||
|
||||
if lr_scale == 1 and regularizer is None:
|
||||
offset_bias_attr = ParamAttr(initializer=Constant(0.))
|
||||
else:
|
||||
offset_bias_attr = ParamAttr(
|
||||
initializer=Constant(0.),
|
||||
learning_rate=lr_scale,
|
||||
regularizer=regularizer)
|
||||
self.conv_offset = nn.Conv2D(
|
||||
in_channels,
|
||||
groups * 3 * kernel_size**2,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
weight_attr=ParamAttr(initializer=Constant(0.0)),
|
||||
bias_attr=offset_bias_attr)
|
||||
if skip_quant:
|
||||
self.conv_offset.skip_quant = True
|
||||
|
||||
def forward(self, x):
|
||||
offset_mask = self.conv_offset(x)
|
||||
offset, mask = paddle.split(
|
||||
offset_mask,
|
||||
num_or_sections=[self.offset_channel, self.mask_channel],
|
||||
axis=1)
|
||||
mask = F.sigmoid(mask)
|
||||
y = self.conv_dcn(x, offset, mask=mask)
|
||||
return y
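# Hedged note (illustration, not part of the original commit): conv_offset predicts
# 3*k*k*groups channels per position; the first 2*k*k*groups are (dx, dy) offsets for
# the k*k sampling points and the remaining k*k*groups form the modulation mask that
# is squashed with sigmoid before DeformConv2D. Hypothetical sizes:
demo_kernel_size, demo_groups = 3, 1
demo_offset_channel = 2 * demo_kernel_size ** 2 * demo_groups  # 18
demo_mask_channel = demo_kernel_size ** 2 * demo_groups        # 9
assert demo_offset_channel + demo_mask_channel == demo_groups * 3 * demo_kernel_size ** 2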
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
|
@ -32,20 +105,31 @@ class ConvBNLayer(nn.Layer):
|
|||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None):
|
||||
act=None,
|
||||
is_dcn=False):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
bias_attr=False)
|
||||
if not is_dcn:
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
bias_attr=False)
|
||||
else:
|
||||
self._conv = DeformableConvV2(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=2,  # NOTE: deformable groups are hard-coded to 2 here instead of the incoming groups argument
|
||||
bias_attr=False)
|
||||
self._batch_norm = nn.BatchNorm(out_channels, act=act)
|
||||
|
||||
def forward(self, inputs):
|
||||
|
@ -57,12 +141,14 @@ class ConvBNLayer(nn.Layer):
|
|||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
is_dcn=False, ):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
|
@ -75,7 +161,8 @@ class BottleneckBlock(nn.Layer):
|
|||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu')
|
||||
act='relu',
|
||||
is_dcn=is_dcn)
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels * 4,
|
||||
|
@ -152,7 +239,12 @@ class BasicBlock(nn.Layer):
|
|||
|
||||
|
||||
class ResNet(nn.Layer):
|
||||
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
layers=50,
|
||||
dcn_stage=None,
|
||||
out_indices=None,
|
||||
**kwargs):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
|
@ -175,6 +267,13 @@ class ResNet(nn.Layer):
|
|||
1024] if layers >= 50 else [64, 64, 128, 256]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
|
||||
self.dcn_stage = dcn_stage if dcn_stage is not None else [
|
||||
False, False, False, False
|
||||
]
|
||||
self.out_indices = out_indices if out_indices is not None else [
|
||||
0, 1, 2, 3
|
||||
]
|
||||
|
||||
self.conv1_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=32,
|
||||
|
@ -201,6 +300,7 @@ class ResNet(nn.Layer):
|
|||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
is_dcn = self.dcn_stage[block]
|
||||
for i in range(depth[block]):
|
||||
bottleneck_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
|
@ -210,15 +310,18 @@ class ResNet(nn.Layer):
|
|||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0))
|
||||
if_first=block == i == 0,
|
||||
is_dcn=is_dcn))
|
||||
shortcut = True
|
||||
block_list.append(bottleneck_block)
|
||||
self.out_channels.append(num_filters[block] * 4)
|
||||
if block in self.out_indices:
|
||||
self.out_channels.append(num_filters[block] * 4)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
# is_dcn = self.dcn_stage[block]
|
||||
for i in range(depth[block]):
|
||||
basic_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
|
@ -231,7 +334,8 @@ class ResNet(nn.Layer):
|
|||
if_first=block == i == 0))
|
||||
shortcut = True
|
||||
block_list.append(basic_block)
|
||||
self.out_channels.append(num_filters[block])
|
||||
if block in self.out_indices:
|
||||
self.out_channels.append(num_filters[block])
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
@ -240,7 +344,8 @@ class ResNet(nn.Layer):
|
|||
y = self.conv1_3(y)
|
||||
y = self.pool2d_max(y)
|
||||
out = []
|
||||
for block in self.stages:
|
||||
for i, block in enumerate(self.stages):
|
||||
y = block(y)
|
||||
out.append(y)
|
||||
if i in self.out_indices:
|
||||
out.append(y)
|
||||
return out
|
||||
|
|
|
@ -21,6 +21,7 @@ def build_head(config):
|
|||
from .det_east_head import EASTHead
|
||||
from .det_sast_head import SASTHead
|
||||
from .det_pse_head import PSEHead
|
||||
from .det_fce_head import FCEHead
|
||||
from .e2e_pg_head import PGHead
|
||||
|
||||
# rec head
|
||||
|
@ -40,8 +41,8 @@ def build_head(config):
|
|||
from .table_att_head import TableAttentionHead
|
||||
|
||||
support_dict = [
|
||||
'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
|
||||
'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
|
||||
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead'
|
||||
]
|
||||
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
from paddle import nn
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn.initializer import Normal
|
||||
import paddle
|
||||
from functools import partial
|
||||
|
||||
|
||||
def multi_apply(func, *args, **kwargs):
|
||||
"""Apply function to a list of arguments.
|
||||
|
||||
Note:
|
||||
This function applies the ``func`` to multiple inputs and
|
||||
maps the multiple outputs of the ``func`` into different
|
||||
lists. Each list contains the same type of outputs corresponding
|
||||
to different inputs.
|
||||
|
||||
Args:
|
||||
func (Function): A function that will be applied to a list of
|
||||
arguments
|
||||
|
||||
Returns:
|
||||
tuple(list): A tuple containing multiple lists, where each list contains \
|
||||
one kind of result returned by the function
|
||||
"""
|
||||
pfunc = partial(func, **kwargs) if kwargs else func
|
||||
map_results = map(pfunc, *args)
|
||||
return tuple(map(list, zip(*map_results)))
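# Hedged usage example (not part of the original commit): multi_apply maps a function
# over parallel argument lists and regroups the per-call outputs into one list per
# output position, which is how FCEHead.forward fans out over the FPN levels.
sums, prods = multi_apply(lambda a, b: (a + b, a * b), [1, 2, 3], [4, 5, 6])
assert sums == [5, 7, 9] and prods == [4, 10, 18]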
|
||||
|
||||
|
||||
class FCEHead(nn.Layer):
|
||||
"""The class for implementing FCENet head.
|
||||
FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
|
||||
Detection.
|
||||
|
||||
[https://arxiv.org/abs/2104.10442]
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
scales (list[int]) : The scale of each layer.
|
||||
fourier_degree (int) : The maximum Fourier transform degree k.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, scales, fourier_degree=5):
|
||||
super().__init__()
|
||||
assert isinstance(in_channels, int)
|
||||
|
||||
self.downsample_ratio = 1.0
|
||||
self.in_channels = in_channels
|
||||
self.scales = scales
|
||||
self.fourier_degree = fourier_degree
|
||||
self.out_channels_cls = 4
|
||||
self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
|
||||
|
||||
self.out_conv_cls = nn.Conv2D(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels_cls,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='cls_weights',
|
||||
initializer=Normal(
|
||||
mean=0., std=0.01)),
|
||||
bias_attr=True)
|
||||
self.out_conv_reg = nn.Conv2D(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels_reg,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='reg_weights',
|
||||
initializer=Normal(
|
||||
mean=0., std=0.01)),
|
||||
bias_attr=True)
|
||||
|
||||
def forward(self, feats, targets=None):
|
||||
cls_res, reg_res = multi_apply(self.forward_single, feats)
|
||||
level_num = len(cls_res)
|
||||
|
||||
outs = {}
|
||||
|
||||
if not self.training:
|
||||
for i in range(level_num):
|
||||
tr_pred = F.softmax(cls_res[i][:, 0:2, :, :], axis=1)
|
||||
tcl_pred = F.softmax(cls_res[i][:, 2:, :, :], axis=1)
|
||||
outs['level_{}'.format(i)] = paddle.concat(
|
||||
[tr_pred, tcl_pred, reg_res[i]], axis=1)
|
||||
else:
|
||||
preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
|
||||
outs['levels'] = preds
|
||||
return outs
|
||||
|
||||
def forward_single(self, x):
|
||||
cls_predict = self.out_conv_cls(x)
|
||||
reg_predict = self.out_conv_reg(x)
|
||||
return cls_predict, reg_predict
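# Hedged shape sketch (not part of the original commit): with fourier_degree=5 the head
# predicts 4 classification channels (2 for text region, 2 for text center) and
# (2*5+1)*2 = 22 regression channels per FPN level. Dummy sizes below are hypothetical.
import paddle
demo_head = FCEHead(in_channels=256, scales=[8, 16, 32], fourier_degree=5)
demo_feats = [paddle.rand([1, 256, 80 // s, 80 // s]) for s in [1, 2, 4]]  # toy P3/P4/P5 maps
demo_outs = demo_head(demo_feats)  # a fresh Layer is in training mode -> {'levels': [[cls, reg], ...]}
cls0, reg0 = demo_outs['levels'][0]
assert cls0.shape[1] == 4 and reg0.shape[1] == 22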
|
|
@ -23,7 +23,11 @@ def build_neck(config):
|
|||
from .pg_fpn import PGFPN
|
||||
from .table_fpn import TableFPN
|
||||
from .fpn import FPN
|
||||
support_dict = ['FPN','DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
|
||||
from .fce_fpn import FCEFPN
|
||||
support_dict = [
|
||||
'FPN', 'FCEFPN', 'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder',
|
||||
'PGFPN', 'TableFPN'
|
||||
]
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,262 @@
|
|||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
from paddle.nn.initializer import XavierUniform
|
||||
from paddle.nn.initializer import Normal
|
||||
from paddle.regularizer import L2Decay
|
||||
|
||||
__all__ = ['FCEFPN']
|
||||
|
||||
|
||||
class ConvNormLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
ch_in,
|
||||
ch_out,
|
||||
filter_size,
|
||||
stride,
|
||||
groups=1,
|
||||
norm_type='bn',
|
||||
norm_decay=0.,
|
||||
norm_groups=32,
|
||||
lr_scale=1.,
|
||||
freeze_norm=False,
|
||||
initializer=Normal(
|
||||
mean=0., std=0.01)):
|
||||
super(ConvNormLayer, self).__init__()
|
||||
assert norm_type in ['bn', 'sync_bn', 'gn']
|
||||
|
||||
bias_attr = False
|
||||
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=ch_in,
|
||||
out_channels=ch_out,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(
|
||||
initializer=initializer, learning_rate=1.),
|
||||
bias_attr=bias_attr)
|
||||
|
||||
norm_lr = 0. if freeze_norm else 1.
|
||||
param_attr = ParamAttr(
|
||||
learning_rate=norm_lr,
|
||||
regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
|
||||
bias_attr = ParamAttr(
|
||||
learning_rate=norm_lr,
|
||||
regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
|
||||
if norm_type == 'bn':
|
||||
self.norm = nn.BatchNorm2D(
|
||||
ch_out, weight_attr=param_attr, bias_attr=bias_attr)
|
||||
elif norm_type == 'sync_bn':
|
||||
self.norm = nn.SyncBatchNorm(
|
||||
ch_out, weight_attr=param_attr, bias_attr=bias_attr)
|
||||
elif norm_type == 'gn':
|
||||
self.norm = nn.GroupNorm(
|
||||
num_groups=norm_groups,
|
||||
num_channels=ch_out,
|
||||
weight_attr=param_attr,
|
||||
bias_attr=bias_attr)
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
return out
|
||||
|
||||
|
||||
class FCEFPN(nn.Layer):
|
||||
"""
|
||||
Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
|
||||
Args:
|
||||
in_channels (list[int]): input channels of each level which can be
|
||||
derived from the output shape of backbone by from_config
|
||||
out_channels (list[int]): output channel of each level
|
||||
spatial_scales (list[float]): the spatial scales between input feature
|
||||
maps and original input image which can be derived from the output
|
||||
shape of backbone by from_config
|
||||
has_extra_convs (bool): whether to add extra conv to the last level.
|
||||
default False
|
||||
extra_stage (int): the number of extra stages added to the last level.
|
||||
default 1
|
||||
use_c5 (bool): Whether to use c5 as the input of extra stage,
|
||||
otherwise p5 is used. default True
|
||||
norm_type (string|None): The normalization type in FPN module. If
|
||||
norm_type is None, norm will not be used after conv and if
|
||||
norm_type is string, bn, gn, sync_bn are available. default None
|
||||
norm_decay (float): weight decay for normalization layer weights.
|
||||
default 0.
|
||||
freeze_norm (bool): whether to freeze normalization layer.
|
||||
default False
|
||||
relu_before_extra_convs (bool): whether to add relu before extra convs.
|
||||
default False
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
spatial_scales=[0.25, 0.125, 0.0625, 0.03125],
|
||||
has_extra_convs=False,
|
||||
extra_stage=1,
|
||||
use_c5=True,
|
||||
norm_type=None,
|
||||
norm_decay=0.,
|
||||
freeze_norm=False,
|
||||
relu_before_extra_convs=True):
|
||||
super(FCEFPN, self).__init__()
|
||||
self.out_channels = out_channels
|
||||
for s in range(extra_stage):
|
||||
spatial_scales = spatial_scales + [spatial_scales[-1] / 2.]
|
||||
self.spatial_scales = spatial_scales
|
||||
self.has_extra_convs = has_extra_convs
|
||||
self.extra_stage = extra_stage
|
||||
self.use_c5 = use_c5
|
||||
self.relu_before_extra_convs = relu_before_extra_convs
|
||||
self.norm_type = norm_type
|
||||
self.norm_decay = norm_decay
|
||||
self.freeze_norm = freeze_norm
|
||||
|
||||
self.lateral_convs = []
|
||||
self.fpn_convs = []
|
||||
fan = out_channels * 3 * 3
|
||||
|
||||
# stage index 0,1,2,3 stands for res2,res3,res4,res5 on ResNet Backbone
|
||||
# 0 <= st_stage < ed_stage <= 3
|
||||
st_stage = 4 - len(in_channels)
|
||||
ed_stage = st_stage + len(in_channels) - 1
|
||||
for i in range(st_stage, ed_stage + 1):
|
||||
if i == 3:
|
||||
lateral_name = 'fpn_inner_res5_sum'
|
||||
else:
|
||||
lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2)
|
||||
in_c = in_channels[i - st_stage]
|
||||
if self.norm_type is not None:
|
||||
lateral = self.add_sublayer(
|
||||
lateral_name,
|
||||
ConvNormLayer(
|
||||
ch_in=in_c,
|
||||
ch_out=out_channels,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
norm_type=self.norm_type,
|
||||
norm_decay=self.norm_decay,
|
||||
freeze_norm=self.freeze_norm,
|
||||
initializer=XavierUniform(fan_out=in_c)))
|
||||
else:
|
||||
lateral = self.add_sublayer(
|
||||
lateral_name,
|
||||
nn.Conv2D(
|
||||
in_channels=in_c,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
initializer=XavierUniform(fan_out=in_c))))
|
||||
self.lateral_convs.append(lateral)
|
||||
|
||||
for i in range(st_stage, ed_stage + 1):
|
||||
fpn_name = 'fpn_res{}_sum'.format(i + 2)
|
||||
if self.norm_type is not None:
|
||||
fpn_conv = self.add_sublayer(
|
||||
fpn_name,
|
||||
ConvNormLayer(
|
||||
ch_in=out_channels,
|
||||
ch_out=out_channels,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
norm_type=self.norm_type,
|
||||
norm_decay=self.norm_decay,
|
||||
freeze_norm=self.freeze_norm,
|
||||
initializer=XavierUniform(fan_out=fan)))
|
||||
else:
|
||||
fpn_conv = self.add_sublayer(
|
||||
fpn_name,
|
||||
nn.Conv2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
initializer=XavierUniform(fan_out=fan))))
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
|
||||
if self.has_extra_convs:
|
||||
for i in range(self.extra_stage):
|
||||
lvl = ed_stage + 1 + i
|
||||
if i == 0 and self.use_c5:
|
||||
in_c = in_channels[-1]
|
||||
else:
|
||||
in_c = out_channels
|
||||
extra_fpn_name = 'fpn_{}'.format(lvl + 2)
|
||||
if self.norm_type is not None:
|
||||
extra_fpn_conv = self.add_sublayer(
|
||||
extra_fpn_name,
|
||||
ConvNormLayer(
|
||||
ch_in=in_c,
|
||||
ch_out=out_channels,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
norm_type=self.norm_type,
|
||||
norm_decay=self.norm_decay,
|
||||
freeze_norm=self.freeze_norm,
|
||||
initializer=XavierUniform(fan_out=fan)))
|
||||
else:
|
||||
extra_fpn_conv = self.add_sublayer(
|
||||
extra_fpn_name,
|
||||
nn.Conv2D(
|
||||
in_channels=in_c,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
initializer=XavierUniform(fan_out=fan))))
|
||||
self.fpn_convs.append(extra_fpn_conv)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg, input_shape):
|
||||
return {
|
||||
'in_channels': [i.channels for i in input_shape],
|
||||
'spatial_scales': [1.0 / i.stride for i in input_shape],
|
||||
}
|
||||
|
||||
def forward(self, body_feats):
|
||||
laterals = []
|
||||
num_levels = len(body_feats)
|
||||
|
||||
for i in range(num_levels):
|
||||
laterals.append(self.lateral_convs[i](body_feats[i]))
|
||||
|
||||
for i in range(1, num_levels):
|
||||
lvl = num_levels - i
|
||||
upsample = F.interpolate(
|
||||
laterals[lvl],
|
||||
scale_factor=2.,
|
||||
mode='nearest', )
|
||||
laterals[lvl - 1] += upsample
|
||||
|
||||
fpn_output = []
|
||||
for lvl in range(num_levels):
|
||||
fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))
|
||||
|
||||
if self.extra_stage > 0:
|
||||
# use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
|
||||
if not self.has_extra_convs:
|
||||
assert self.extra_stage == 1, 'extra_stage should be 1 if FPN has no extra convs'
|
||||
fpn_output.append(F.max_pool2d(fpn_output[-1], 1, stride=2))
|
||||
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
|
||||
else:
|
||||
if self.use_c5:
|
||||
extra_source = body_feats[-1]
|
||||
else:
|
||||
extra_source = fpn_output[-1]
|
||||
fpn_output.append(self.fpn_convs[num_levels](extra_source))
|
||||
|
||||
for i in range(1, self.extra_stage):
|
||||
if self.relu_before_extra_convs:
|
||||
fpn_output.append(self.fpn_convs[num_levels + i](F.relu(
|
||||
fpn_output[-1])))
|
||||
else:
|
||||
fpn_output.append(self.fpn_convs[num_levels + i](
|
||||
fpn_output[-1]))
|
||||
return fpn_output
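# Hedged sketch (not part of the original commit): with the det_r50_fce_ctw config
# (in_channels=[512, 1024, 2048], out_channels=256, has_extra_convs=False, extra_stage=0)
# the neck returns three 256-channel maps at the spatial sizes of its inputs.
# The feature sizes below are hypothetical.
import paddle
demo_fpn = FCEFPN(in_channels=[512, 1024, 2048], out_channels=256,
has_extra_convs=False, extra_stage=0)
c3 = paddle.rand([1, 512, 100, 100])
c4 = paddle.rand([1, 1024, 50, 50])
c5 = paddle.rand([1, 2048, 25, 25])
p3, p4, p5 = demo_fpn([c3, c4, c5])
assert [p.shape[1] for p in (p3, p4, p5)] == [256, 256, 256]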
|
|
@ -24,6 +24,7 @@ __all__ = ['build_post_process']
|
|||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .fce_postprocess import FCEPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
|
@ -34,9 +35,9 @@ from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
|
|||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
|
||||
'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
|
||||
'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
|
||||
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
|
||||
'VQAReTokenLayoutLMPostProcess'
|
||||
|
|
|
@ -0,0 +1,368 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import paddle
|
||||
from numpy.fft import ifft
|
||||
import Polygon as plg
|
||||
|
||||
|
||||
def points2polygon(points):
|
||||
"""Convert k points to 1 polygon.
|
||||
|
||||
Args:
|
||||
points (ndarray or list): An ndarray or a list of shape (2k)
|
||||
that indicates k points.
|
||||
|
||||
Returns:
|
||||
polygon (Polygon): A polygon object.
|
||||
"""
|
||||
if isinstance(points, list):
|
||||
points = np.array(points)
|
||||
|
||||
assert isinstance(points, np.ndarray)
|
||||
assert (points.size % 2 == 0) and (points.size >= 8)
|
||||
|
||||
point_mat = points.reshape([-1, 2])
|
||||
return plg.Polygon(point_mat)
|
||||
|
||||
|
||||
def poly_intersection(poly_det, poly_gt):
|
||||
"""Calculate the intersection area between two polygon.
|
||||
|
||||
Args:
|
||||
poly_det (Polygon): A polygon predicted by detector.
|
||||
poly_gt (Polygon): A gt polygon.
|
||||
|
||||
Returns:
|
||||
intersection_area (float): The intersection area between two polygons.
|
||||
"""
|
||||
assert isinstance(poly_det, plg.Polygon)
|
||||
assert isinstance(poly_gt, plg.Polygon)
|
||||
|
||||
poly_inter = poly_det & poly_gt
|
||||
if len(poly_inter) == 0:
|
||||
return 0, poly_inter
|
||||
return poly_inter.area(), poly_inter
|
||||
|
||||
|
||||
def poly_union(poly_det, poly_gt):
|
||||
"""Calculate the union area between two polygon.
|
||||
|
||||
Args:
|
||||
poly_det (Polygon): A polygon predicted by detector.
|
||||
poly_gt (Polygon): A gt polygon.
|
||||
|
||||
Returns:
|
||||
union_area (float): The union area between two polygons.
|
||||
"""
|
||||
assert isinstance(poly_det, plg.Polygon)
|
||||
assert isinstance(poly_gt, plg.Polygon)
|
||||
|
||||
area_det = poly_det.area()
|
||||
area_gt = poly_gt.area()
|
||||
area_inters, _ = poly_intersection(poly_det, poly_gt)
|
||||
return area_det + area_gt - area_inters
|
||||
|
||||
|
||||
def valid_boundary(x, with_score=True):
|
||||
num = len(x)
|
||||
if num < 8:
|
||||
return False
|
||||
if num % 2 == 0 and (not with_score):
|
||||
return True
|
||||
if num % 2 == 1 and with_score:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def boundary_iou(src, target):
|
||||
"""Calculate the IOU between two boundaries.
|
||||
|
||||
Args:
|
||||
src (list): Source boundary.
|
||||
target (list): Target boundary.
|
||||
|
||||
Returns:
|
||||
iou (float): The iou between two boundaries.
|
||||
"""
|
||||
assert valid_boundary(src, False)
|
||||
assert valid_boundary(target, False)
|
||||
src_poly = points2polygon(src)
|
||||
target_poly = points2polygon(target)
|
||||
|
||||
return poly_iou(src_poly, target_poly)
|
||||
|
||||
|
||||
def poly_iou(poly_det, poly_gt):
|
||||
"""Calculate the IOU between two polygons.
|
||||
|
||||
Args:
|
||||
poly_det (Polygon): A polygon predicted by detector.
|
||||
poly_gt (Polygon): A gt polygon.
|
||||
|
||||
Returns:
|
||||
iou (float): The IOU between two polygons.
|
||||
"""
|
||||
assert isinstance(poly_det, plg.Polygon)
|
||||
assert isinstance(poly_gt, plg.Polygon)
|
||||
area_inters, _ = poly_intersection(poly_det, poly_gt)
|
||||
area_union = poly_union(poly_det, poly_gt)
|
||||
if area_union == 0:
|
||||
return 0.0
|
||||
return area_inters / area_union
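# Hedged example (not part of the original commit): two axis-aligned unit squares that
# overlap by half share area 0.5 and union 1.5, so their IoU is 1/3.
sq_a = points2polygon([0, 0, 1, 0, 1, 1, 0, 1])
sq_b = points2polygon([0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1])
assert abs(poly_iou(sq_a, sq_b) - 1.0 / 3.0) < 1e-6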
|
||||
|
||||
|
||||
def poly_nms(polygons, threshold):
|
||||
assert isinstance(polygons, list)
|
||||
|
||||
polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
|
||||
|
||||
keep_poly = []
|
||||
index = [i for i in range(polygons.shape[0])]
|
||||
|
||||
while len(index) > 0:
|
||||
keep_poly.append(polygons[index[-1]].tolist())
|
||||
A = polygons[index[-1]][:-1]
|
||||
index = np.delete(index, -1)
|
||||
|
||||
iou_list = np.zeros((len(index), ))
|
||||
for i in range(len(index)):
|
||||
B = polygons[index[i]][:-1]
|
||||
|
||||
iou_list[i] = boundary_iou(A, B)
|
||||
remove_index = np.where(iou_list > threshold)
|
||||
index = np.delete(index, remove_index)
|
||||
|
||||
return keep_poly
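# Hedged example (not part of the original commit): each candidate is a flat polygon plus
# a trailing score; candidates that heavily overlap a higher-scoring one are suppressed.
box_hi = [0, 0, 10, 0, 10, 10, 0, 10, 0.9]            # square, score 0.9
box_lo = [1, 1, 11, 1, 11, 11, 1, 11, 0.6]            # nearly the same square, score 0.6
box_far = [100, 100, 110, 100, 110, 110, 100, 110, 0.8]
kept = poly_nms([box_hi, box_lo, box_far], threshold=0.1)
# kept holds box_hi and box_far; box_lo is dropped because its IoU with box_hi exceeds 0.1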
|
||||
|
||||
|
||||
def fill_hole(input_mask):
|
||||
h, w = input_mask.shape
|
||||
canvas = np.zeros((h + 2, w + 2), np.uint8)
|
||||
canvas[1:h + 1, 1:w + 1] = input_mask.copy()
|
||||
|
||||
mask = np.zeros((h + 4, w + 4), np.uint8)
|
||||
|
||||
cv2.floodFill(canvas, mask, (0, 0), 1)
|
||||
canvas = canvas[1:h + 1, 1:w + 1].astype(bool)  # np.bool alias was removed from NumPy
|
||||
|
||||
return ~canvas | input_mask
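# Hedged example (not part of the original commit): a ring-shaped mask has its enclosed
# hole filled while the outside background is left untouched.
import numpy as np
demo_ring = np.zeros((5, 5), np.uint8)
demo_ring[1:4, 1:4] = 1
demo_ring[2, 2] = 0            # the hole
demo_filled = fill_hole(demo_ring)
assert demo_filled[2, 2] and not demo_filled[0, 0]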
|
||||
|
||||
|
||||
def fourier2poly(fourier_coeff, num_reconstr_points=50):
|
||||
""" Inverse Fourier transform
|
||||
Args:
|
||||
fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
|
||||
with n and k being the number of candidates and the Fourier degree
|
||||
respectively.
|
||||
num_reconstr_points (int): Number of reconstructed polygon points.
|
||||
Returns:
|
||||
Polygons (ndarray): The reconstructed polygons shaped (n, n')
|
||||
"""
|
||||
|
||||
a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype='complex')
|
||||
k = (len(fourier_coeff[0]) - 1) // 2
|
||||
|
||||
a[:, 0:k + 1] = fourier_coeff[:, k:]
|
||||
a[:, -k:] = fourier_coeff[:, :k]
|
||||
|
||||
poly_complex = ifft(a) * num_reconstr_points
|
||||
polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
|
||||
polygon[:, :, 0] = poly_complex.real
|
||||
polygon[:, :, 1] = poly_complex.imag
|
||||
return polygon.astype('int32').reshape((len(fourier_coeff), -1))
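# Hedged example (not part of the original commit): a single +1-frequency coefficient of
# magnitude r around a complex center c_0 reconstructs a circle of radius r (up to the
# int32 rounding inside fourier2poly). The degree, radius and center are hypothetical.
demo_k, demo_r, demo_cx, demo_cy = 5, 20.0, 100.0, 50.0
demo_coeff = np.zeros((1, 2 * demo_k + 1), dtype=complex)
demo_coeff[0, demo_k] = demo_cx + 1j * demo_cy   # c_0: the contour center
demo_coeff[0, demo_k + 1] = demo_r               # c_1: the circle radius
demo_poly = fourier2poly(demo_coeff, num_reconstr_points=50).reshape(-1, 2)
demo_radii = np.hypot(demo_poly[:, 0] - demo_cx, demo_poly[:, 1] - demo_cy)
# demo_radii stays within about one pixel of demo_r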
|
||||
|
||||
|
||||
def fcenet_decode(preds,
|
||||
fourier_degree,
|
||||
num_reconstr_points,
|
||||
scale,
|
||||
alpha=1.0,
|
||||
beta=2.0,
|
||||
text_repr_type='poly',
|
||||
score_thr=0.3,
|
||||
nms_thr=0.1):
|
||||
"""Decoding predictions of FCENet to instances.
|
||||
|
||||
Args:
|
||||
preds (list(Tensor)): The head output tensors.
|
||||
fourier_degree (int): The maximum Fourier transform degree k.
|
||||
num_reconstr_points (int): The points number of the polygon
|
||||
reconstructed from predicted Fourier coefficients.
|
||||
scale (int): The down-sample scale of the prediction.
|
||||
alpha (float) : The parameter to calculate final scores. Score_{final}
|
||||
= (Score_{text region} ^ alpha)
|
||||
* (Score_{text center region}^ beta)
|
||||
beta (float) : The parameter to calculate final score.
|
||||
text_repr_type (str): Boundary encoding type 'poly' or 'quad'.
|
||||
score_thr (float) : The threshold used to filter out the final
|
||||
candidates.
|
||||
nms_thr (float) : The threshold of nms.
|
||||
|
||||
Returns:
|
||||
boundaries (list[list[float]]): The instance boundary and confidence
|
||||
list.
|
||||
"""
|
||||
assert isinstance(preds, list)
|
||||
assert len(preds) == 2
|
||||
assert text_repr_type in ['poly', 'quad']
|
||||
|
||||
|
||||
cls_pred = preds[0][0]
|
||||
# tr_pred = F.softmax(cls_pred[0:2], axis=0).cpu().numpy()
|
||||
# tcl_pred = F.softmax(cls_pred[2:], axis=0).cpu().numpy()
|
||||
|
||||
tr_pred = cls_pred[0:2]
|
||||
tcl_pred = cls_pred[2:]
|
||||
|
||||
reg_pred = preds[1][0].transpose([1, 2, 0]) #.cpu().numpy()
|
||||
x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
|
||||
y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]
|
||||
|
||||
score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
|
||||
tr_pred_mask = (score_pred) > score_thr
|
||||
tr_mask = fill_hole(tr_pred_mask)
|
||||
|
||||
tr_contours, _ = cv2.findContours(
|
||||
tr_mask.astype(np.uint8), cv2.RETR_TREE,
|
||||
cv2.CHAIN_APPROX_SIMPLE) # opencv4
|
||||
|
||||
mask = np.zeros_like(tr_mask)
|
||||
boundaries = []
|
||||
for cont in tr_contours:
|
||||
deal_map = mask.copy().astype(np.int8)
|
||||
cv2.drawContours(deal_map, [cont], -1, 1, -1)
|
||||
|
||||
score_map = score_pred * deal_map
|
||||
score_mask = score_map > 0
|
||||
xy_text = np.argwhere(score_mask)
|
||||
dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
|
||||
|
||||
x, y = x_pred[score_mask], y_pred[score_mask]
|
||||
c = x + y * 1j
|
||||
c[:, fourier_degree] = c[:, fourier_degree] + dxy
|
||||
c *= scale
|
||||
|
||||
polygons = fourier2poly(c, num_reconstr_points)
|
||||
score = score_map[score_mask].reshape(-1, 1)
|
||||
polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)
|
||||
|
||||
boundaries = boundaries + polygons
|
||||
|
||||
boundaries = poly_nms(boundaries, nms_thr)
|
||||
|
||||
if text_repr_type == 'quad':
|
||||
new_boundaries = []
|
||||
for boundary in boundaries:
|
||||
poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
|
||||
score = boundary[-1]
|
||||
points = cv2.boxPoints(cv2.minAreaRect(poly))
|
||||
points = np.int0(points)
|
||||
new_boundaries.append(points.reshape(-1).tolist() + [score])
|
||||
|
||||
boundaries = new_boundaries  # use the quadrilateral boxes when text_repr_type == 'quad'
return boundaries
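# Hedged toy illustration (not part of the original commit): the final pixel score used
# above combines the text-region and text-center probabilities multiplicatively.
demo_tr = np.array([0.9, 0.6, 0.2])    # hypothetical text-region probabilities
demo_tcl = np.array([0.8, 0.3, 0.9])   # hypothetical text-center probabilities
demo_alpha, demo_beta = 1.0, 2.0       # the defaults of fcenet_decode above
demo_score = (demo_tr ** demo_alpha) * (demo_tcl ** demo_beta)
demo_keep = demo_score > 0.3           # the same score_thr filtering used above
# demo_score -> [0.576, 0.054, 0.162]; only the first pixel passes the 0.3 threshold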
|
||||
|
||||
|
||||
class FCEPostProcess(object):
|
||||
"""
|
||||
The post process for FCENet.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
scales,
|
||||
fourier_degree=5,
|
||||
num_reconstr_points=50,
|
||||
decoding_type='fcenet',
|
||||
score_thr=0.3,
|
||||
nms_thr=0.1,
|
||||
alpha=1.0,
|
||||
beta=1.0,
|
||||
text_repr_type='poly',
|
||||
**kwargs):
|
||||
|
||||
self.scales = scales
|
||||
self.fourier_degree = fourier_degree
|
||||
self.num_reconstr_points = num_reconstr_points
|
||||
self.decoding_type = decoding_type
|
||||
self.score_thr = score_thr
|
||||
self.nms_thr = nms_thr
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.text_repr_type = text_repr_type
|
||||
|
||||
def __call__(self, preds, shape_list):
|
||||
score_maps = []
|
||||
for key, value in preds.items():
|
||||
if isinstance(value, paddle.Tensor):
|
||||
value = value.numpy()
|
||||
cls_res = value[:, :4, :, :]
|
||||
reg_res = value[:, 4:, :, :]
|
||||
score_maps.append([cls_res, reg_res])
|
||||
|
||||
return self.get_boundary(score_maps, shape_list)
|
||||
|
||||
def resize_boundary(self, boundaries, scale_factor):
|
||||
"""Rescale boundaries via scale_factor.
|
||||
|
||||
Args:
|
||||
boundaries (list[list[float]]): The boundary list. Each boundary
|
||||
with size 2k+1 with k>=4.
|
||||
scale_factor(ndarray): The scale factor of size (4,).
|
||||
|
||||
Returns:
|
||||
boundaries (list[list[float]]): The scaled boundaries.
|
||||
"""
|
||||
# assert check_argument.is_2dlist(boundaries)
|
||||
# assert isinstance(scale_factor, np.ndarray)
|
||||
# assert scale_factor.shape[0] == 4
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
for b in boundaries:
|
||||
sz = len(b)
|
||||
valid_boundary(b, True)
|
||||
scores.append(b[-1])
|
||||
b = (np.array(b[:sz - 1]) *
|
||||
(np.tile(scale_factor[:2], int(
|
||||
(sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
|
||||
boxes.append(np.array(b).reshape([-1, 2]))
|
||||
|
||||
return np.array(boxes, dtype=np.float32), scores
|
||||
|
||||
def get_boundary(self, score_maps, shape_list):
|
||||
assert len(score_maps) == len(self.scales)
|
||||
|
||||
boundaries = []
|
||||
for idx, score_map in enumerate(score_maps):
|
||||
scale = self.scales[idx]
|
||||
boundaries = boundaries + self._get_boundary_single(score_map,
|
||||
scale)
|
||||
|
||||
# nms
|
||||
boundaries = poly_nms(boundaries, self.nms_thr)
|
||||
|
||||
|
||||
boundaries, scores = self.resize_boundary(
|
||||
boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
|
||||
|
||||
boxes_batch = [dict(points=boundaries, scores=scores)]
|
||||
return boxes_batch
|
||||
|
||||
def _get_boundary_single(self, score_map, scale):
|
||||
assert len(score_map) == 2
|
||||
assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
|
||||
|
||||
return fcenet_decode(
|
||||
preds=score_map,
|
||||
fourier_degree=self.fourier_degree,
|
||||
num_reconstr_points=self.num_reconstr_points,
|
||||
scale=scale,
|
||||
alpha=self.alpha,
|
||||
beta=self.beta,
|
||||
text_repr_type=self.text_repr_type,
|
||||
score_thr=self.score_thr,
|
||||
nms_thr=self.nms_thr)
|
train.sh
|
@ -1,2 +1,3 @@
|
|||
# recommended paddle.__version__ == 2.0.0
|
||||
python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||
# python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||
python -m paddle.distributed.launch --gpus '7' tools/train.py -c configs/det/det_r50_fce_ctw.yml
|
||||
|
|