modify fcenet
parent
4cea42d51b
commit
bf7e085ea2
|
@ -3,17 +3,17 @@ Global:
|
|||
epoch_num: 1500
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 20
|
||||
save_model_dir: ./output/fce_r50_ctw/
|
||||
save_model_dir: ./output/det_r50_dcn_fce_ctw/
|
||||
save_epoch_step: 100
|
||||
# evaluation is run every 835 iterations
|
||||
eval_batch_step: [0, 835]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ../pretrain_models/ResNet50_vd_ssld_pretrained
|
||||
checkpoints: #output/fce_r50_ctw/latest
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
|
||||
checkpoints: #output/det_r50_dcn_fce_ctw/latest
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/fce_r50_ctw/predicts_ctw.txt
|
||||
save_res_path: ./output/det_fce/predicts_fce.txt
|
||||
|
||||
|
||||
Architecture:
|
||||
|
@ -65,9 +65,9 @@ Metric:
|
|||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: /data/Dataset/OCR_det/ctw1500/imgs/
|
||||
data_dir: ./train_data/ctw1500/imgs/
|
||||
label_file_list:
|
||||
- /data/Dataset/OCR_det/ctw1500/imgs/training.txt
|
||||
- ./train_data/ctw1500/imgs/training.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
@ -113,9 +113,9 @@ Train:
|
|||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: /data/Dataset/OCR_det/ctw1500/imgs/
|
||||
data_dir: ./train_data/ctw1500/imgs/
|
||||
label_file_list:
|
||||
- /data/Dataset/OCR_det/ctw1500/imgs/test.txt
|
||||
- ./train_data/ctw1500/imgs/test.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
|
@ -1,63 +1,26 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/transforms.py
|
||||
"""
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
import paddle.vision.transforms as paddle_trans
|
||||
import cv2
|
||||
import Polygon as plg
|
||||
import math
|
||||
|
||||
|
||||
def imresize(img,
             size,
             return_scale=False,
             interpolation='bilinear',
             out=None,
             backend=None):
    """Resize image to a given size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size (w, h).
        return_scale (bool): Whether to return `w_scale` and `h_scale`.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
            "hamming" for 'pillow' backend.
        out (ndarray): The output destination (only passed to 'cv2').
        backend (str | None): The image resize backend type. Options are
            `cv2`, `pillow`, `None`. `None` selects `cv2`. Default: None.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
            `resized_img`.

    Raises:
        ValueError: If `backend` is neither 'cv2' nor 'pillow'.
        KeyError: If `interpolation` is not known to the chosen backend.
    """
    if backend is None:
        backend = 'cv2'
    # Validate the backend before touching any backend-specific module.
    if backend not in ['cv2', 'pillow']:
        raise ValueError(f'backend: {backend} is not supported for resize.'
                         f"Supported backends are 'cv2', 'pillow'")

    h, w = img.shape[:2]

    if backend == 'pillow':
        assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
        # The lookup table was referenced but never defined in the original,
        # so this branch always failed with NameError.
        pillow_interp_codes = {
            'nearest': Image.NEAREST,
            'bilinear': Image.BILINEAR,
            'bicubic': Image.BICUBIC,
            'box': Image.BOX,
            'lanczos': Image.LANCZOS,
            'hamming': Image.HAMMING,
        }
        pil_image = Image.fromarray(img)
        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
        resized_img = np.array(pil_image)
    else:
        cv2_interp_codes = {
            'nearest': cv2.INTER_NEAREST,
            'bilinear': cv2.INTER_LINEAR,
            'bicubic': cv2.INTER_CUBIC,
            'area': cv2.INTER_AREA,
            'lanczos': cv2.INTER_LANCZOS4,
        }
        resized_img = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation])

    if not return_scale:
        return resized_img

    # size is (w, h), so index 0 pairs with the width and 1 with the height.
    w_scale = size[0] / w
    h_scale = size[1] / h
    return resized_img, w_scale, h_scale
|
||||
from ppocr.utils.poly_nms import poly_intersection
|
||||
|
||||
|
||||
class RandomScaling:
|
||||
|
@ -83,45 +46,16 @@ class RandomScaling:
|
|||
scales = self.size * 1.0 / max(h, w) * aspect_ratio
|
||||
scales = np.array([scales, scales])
|
||||
out_size = (int(h * scales[1]), int(w * scales[0]))
|
||||
image = imresize(image, out_size[::-1])
|
||||
image = cv2.resize(image, out_size[::-1])
|
||||
|
||||
data['image'] = image
|
||||
text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
|
||||
text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
|
||||
data['polys'] = text_polys
|
||||
|
||||
# import os
|
||||
# base_name = os.path.split(data['img_path'])[-1]
|
||||
# img = image[..., ::-1]
|
||||
# img = Image.fromarray(img)
|
||||
# draw = ImageDraw.Draw(img)
|
||||
# for box in text_polys:
|
||||
# draw.polygon(box, outline=(0, 255, 255,), )
|
||||
# import time
|
||||
# img.save('tmp/{}.jpg'.format(base_name[:-4]))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def poly_intersection(poly_det, poly_gt):
    """Intersect two polygons and report the overlapping area.

    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.

    Returns:
        tuple: `(intersection_area, poly_inter)` where `poly_inter` is the
            result of `poly_det & poly_gt` and the area is 0 when that
            result has length 0.
    """
    for candidate in (poly_det, poly_gt):
        assert isinstance(candidate, plg.Polygon)

    overlap = poly_det & poly_gt
    area = overlap.area() if len(overlap) != 0 else 0
    return area, overlap
|
||||
|
||||
|
||||
class RandomCropFlip:
|
||||
def __init__(self,
|
||||
pad_ratio=0.1,
|
||||
|
@ -352,12 +286,7 @@ class RandomCropPolyInstances:
|
|||
max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
|
||||
min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
|
||||
|
||||
# for key in results.get('mask_fields', []):
|
||||
# if len(results[key].masks) == 0:
|
||||
# continue
|
||||
# masks = results[key].masks
|
||||
for mask in key_masks:
|
||||
# assert len(mask) == 1
|
||||
mask = mask.reshape((-1, 2)).astype(np.int32)
|
||||
clip_x = np.clip(mask[:, 0], 0, w - 1)
|
||||
clip_y = np.clip(mask[:, 1], 0, h - 1)
|
||||
|
@ -501,7 +430,8 @@ class RandomRotatePolyInstances:
|
|||
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
|
||||
np.random.randint(0, w * 7 // 8))
|
||||
img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
|
||||
img_cut = imresize(img_cut, (canvas_size[1], canvas_size[0]))
|
||||
img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
|
||||
|
||||
mask = cv2.warpAffine(
|
||||
mask,
|
||||
rotation_matrix, (canvas_size[1], canvas_size[0]),
|
||||
|
@ -574,7 +504,7 @@ class SquareResizePad:
|
|||
t_w = self.target_size if h <= w else int(w * self.target_size / h)
|
||||
else:
|
||||
t_h = t_w = self.target_size
|
||||
img = imresize(img, (t_w, t_h))
|
||||
img = cv2.resize(img, (t_w, t_h))
|
||||
return img, (t_h, t_w)
|
||||
|
||||
def square_pad(self, img):
|
||||
|
@ -589,7 +519,7 @@ class SquareResizePad:
|
|||
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
|
||||
np.random.randint(0, w * 7 // 8))
|
||||
img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
|
||||
expand_img = imresize(img_cut, (pad_size, pad_size))
|
||||
expand_img = cv2.resize(img_cut, (pad_size, pad_size))
|
||||
if h > w:
|
||||
y0, x0 = 0, (h - w) // 2
|
||||
else:
|
||||
|
@ -617,13 +547,14 @@ class SquareResizePad:
|
|||
else:
|
||||
image, out_size = self.resize_img(image, keep_ratio=False)
|
||||
offset = (0, 0)
|
||||
# image, out_size = self.resize_img(image, keep_ratio=True)
|
||||
# image, offset = self.square_pad(image)
|
||||
results['image'] = image
|
||||
polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[1] / w + offset[
|
||||
0]
|
||||
polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[0] / h + offset[
|
||||
1]
|
||||
try:
|
||||
polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[
|
||||
1] / w + offset[0]
|
||||
polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[
|
||||
0] / h + offset[1]
|
||||
except:
|
||||
pass
|
||||
results['polys'] = polygons
|
||||
|
||||
return results
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.fft import fft
|
||||
|
@ -470,7 +488,6 @@ class FCENetTargets:
|
|||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
# assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
k = self.fourier_degree
|
||||
|
@ -478,9 +495,6 @@ class FCENetTargets:
|
|||
imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
||||
|
||||
for poly in text_polys:
|
||||
# assert len(poly) == 1
|
||||
# text_instance = [[poly[i], poly[i + 1]]
|
||||
# for i in range(0, len(poly), 2)]
|
||||
mask = np.zeros((h, w), dtype=np.uint8)
|
||||
polygon = np.array(poly).reshape((1, -1, 2))
|
||||
cv2.fillPoly(mask, polygon.astype(np.int32), 1)
|
||||
|
@ -512,15 +526,11 @@ class FCENetTargets:
|
|||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
# assert check_argument.is_2dlist(text_polys)
|
||||
|
||||
h, w = img_size
|
||||
text_region_mask = np.zeros((h, w), dtype=np.uint8)
|
||||
|
||||
for poly in text_polys:
|
||||
# assert len(poly) == 1
|
||||
# text_instance = [[poly[i], poly[i + 1]]
|
||||
# for i in range(0, len(poly), 2)]
|
||||
polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
|
||||
cv2.fillPoly(text_region_mask, polygon, 1)
|
||||
|
||||
|
@ -539,8 +549,6 @@ class FCENetTargets:
|
|||
mask (ndarray): The effective mask of (height, width).
|
||||
"""
|
||||
|
||||
# assert check_argument.is_2dlist(polygons_ignore)
|
||||
|
||||
mask = np.ones(mask_size, dtype=np.uint8)
|
||||
|
||||
for poly in polygons_ignore:
|
||||
|
@ -566,9 +574,6 @@ class FCENetTargets:
|
|||
lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
|
||||
level_maps = []
|
||||
for poly in text_polys:
|
||||
# assert len(poly) == 1
|
||||
# text_instance = [[poly[i], poly[i + 1]]
|
||||
# for i in range(0, len(poly), 2)]
|
||||
polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2))
|
||||
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
||||
proportion = max(box_h, box_w) / (h + 1e-8)
|
||||
|
@ -578,9 +583,6 @@ class FCENetTargets:
|
|||
lv_text_polys[ind].append(poly / lv_size_divs[ind])
|
||||
|
||||
for ignore_poly in ignore_polys:
|
||||
# assert len(ignore_poly) == 1
|
||||
# text_instance = [[ignore_poly[i], ignore_poly[i + 1]]
|
||||
# for i in range(0, len(ignore_poly), 2)]
|
||||
polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2))
|
||||
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
||||
proportion = max(box_h, box_w) / (h + 1e-8)
|
||||
|
@ -630,18 +632,6 @@ class FCENetTargets:
|
|||
ignore_tags = results['ignore_tags']
|
||||
h, w, _ = image.shape
|
||||
|
||||
# import time
|
||||
# from PIL import Image, ImageDraw
|
||||
# cur_time = time.time()
|
||||
# image = results['image']
|
||||
# text_polys = results['polys']
|
||||
# img = image[..., ::-1]
|
||||
# img = Image.fromarray(img)
|
||||
# draw = ImageDraw.Draw(img)
|
||||
# for box in text_polys:
|
||||
# draw.polygon(box, outline=(0, 255, 255,), )
|
||||
# img.save('tmp/{}_resize_pad.jpg'.format(cur_time))
|
||||
|
||||
polygon_masks = []
|
||||
polygon_masks_ignore = []
|
||||
for tag, polygon in zip(ignore_tags, polygons):
|
||||
|
@ -653,8 +643,6 @@ class FCENetTargets:
|
|||
level_maps = self.generate_level_targets((h, w), polygon_masks,
|
||||
polygon_masks_ignore)
|
||||
|
||||
# results['mask_fields'].clear() # rm gt_masks encoded by polygons
|
||||
# import remote_pdb as pdb;pdb.set_trace()
|
||||
mapping = {
|
||||
'p3_maps': level_maps[0],
|
||||
'p4_maps': level_maps[1],
|
||||
|
|
|
@ -23,6 +23,7 @@ import sys
|
|||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
|
@ -165,6 +166,27 @@ class KeepKeys(object):
|
|||
return data_list
|
||||
|
||||
|
||||
class Pad(object):
    """Zero-pad `data['image']` so height and width are multiples of `size_div`.

    Padding is added only at the bottom and on the right, so existing pixel
    coordinates are unchanged. Each side is padded to at least `size_div`.

    Args:
        size_div (int): The divisor both output sides must be a multiple of.
            Default: 32.
    """

    def __init__(self, size_div=32, **kwargs):
        if size_div < 1:
            raise ValueError(
                'size_div must be a positive integer, got {}'.format(size_div))
        self.size_div = size_div

    def __call__(self, data):
        img = data['image']
        div = self.size_div
        # The original hard-coded 32 here and silently ignored size_div.
        target_h = max(int(math.ceil(img.shape[0] / div) * div), div)
        target_w = max(int(math.ceil(img.shape[1] / div) * div), div)

        # Pad only the two leading (height, width) axes; any trailing axes
        # such as channels are left untouched.
        pad_width = [(0, target_h - img.shape[0]),
                     (0, target_w - img.shape[1])]
        pad_width += [(0, 0)] * (img.ndim - 2)
        data['image'] = np.pad(
            img, pad_width, mode='constant', constant_values=0)
        return data
|
||||
|
||||
|
||||
class Resize(object):
|
||||
def __init__(self, size=(640, 640), **kwargs):
|
||||
self.size = size
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/fce_loss.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from paddle import nn
|
||||
import paddle
|
||||
|
@ -39,7 +57,6 @@ class FCELoss(nn.Layer):
|
|||
assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
|
||||
'fourier degree not equal in FCEhead and FCEtarget'
|
||||
|
||||
# device = preds[0][0].device
|
||||
# to tensor
|
||||
gts = [p3_maps, p4_maps, p5_maps]
|
||||
for idx, maps in enumerate(gts):
|
||||
|
@ -94,7 +111,6 @@ class FCELoss(nn.Layer):
|
|||
[tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1)
|
||||
# tr loss
|
||||
loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
|
||||
# import pdb; pdb.set_trace()
|
||||
# tcl loss
|
||||
loss_tcl = paddle.to_tensor(0.).astype('float32')
|
||||
tr_neg_mask = tr_train_mask.logical_not()
|
||||
|
@ -138,7 +154,6 @@ class FCELoss(nn.Layer):
|
|||
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
|
||||
|
||||
def ohem(self, predict, target, train_mask):
|
||||
# device = train_mask.device
|
||||
|
||||
pos = (target * train_mask).astype('bool')
|
||||
neg = ((1 - target) * train_mask).astype('bool')
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/dense_heads/fce_head.py
|
||||
"""
|
||||
|
||||
from paddle import nn
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn.functional as F
|
||||
|
@ -7,22 +25,6 @@ from functools import partial
|
|||
|
||||
|
||||
def multi_apply(func, *args, **kwargs):
    """Map `func` over parallel argument lists and regroup its outputs.

    Note:
        `func` is called once per position of the iterables in `args`
        (with `kwargs` bound to every call). The i-th returned list
        collects the i-th output of every call.

    Args:
        func (Function): A function that will be applied to a list of
            arguments

    Returns:
        tuple(list): A tuple containing multiple list, each list contains
            a kind of returned results by the function
    """
    if kwargs:
        func = partial(func, **kwargs)
    per_call_outputs = map(func, *args)
    return tuple(list(group) for group in zip(*per_call_outputs))
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/ppdet/modeling/necks/fpn.py
|
||||
"""
|
||||
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
|
|
@ -1,143 +1,26 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/open-mmlab/mmocr/blob/v0.3.0/mmocr/models/textdet/postprocess/wrapper.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import paddle
|
||||
import numpy as np
|
||||
from numpy.fft import ifft
|
||||
import Polygon as plg
|
||||
|
||||
|
||||
def points2polygon(points):
    """Build one polygon from a flat sequence of k points.

    Args:
        points (ndarray or list): A ndarray or a list of shape (2k)
            that indicates k points.

    Returns:
        polygon (Polygon): A polygon object.
    """
    pts = np.array(points) if isinstance(points, list) else points
    assert isinstance(pts, np.ndarray)

    # At least 8 values (4 points) are required, and they must pair up.
    n_values = pts.size
    assert (n_values % 2 == 0) and (n_values >= 8)

    return plg.Polygon(pts.reshape([-1, 2]))
|
||||
|
||||
|
||||
def poly_intersection(poly_det, poly_gt):
    """Intersect two polygons and report the overlapping area.

    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.

    Returns:
        tuple: `(intersection_area, poly_inter)` where `poly_inter` is the
            result of `poly_det & poly_gt` and the area is 0 when that
            result has length 0.
    """
    for candidate in (poly_det, poly_gt):
        assert isinstance(candidate, plg.Polygon)

    overlap = poly_det & poly_gt
    area = overlap.area() if len(overlap) != 0 else 0
    return area, overlap
|
||||
|
||||
|
||||
def poly_union(poly_det, poly_gt):
    """Compute the union area of two polygons by inclusion-exclusion.

    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.

    Returns:
        union_area (float): The union area between two polygons.
    """
    for candidate in (poly_det, poly_gt):
        assert isinstance(candidate, plg.Polygon)

    inter_area, _ = poly_intersection(poly_det, poly_gt)
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    return poly_det.area() + poly_gt.area() - inter_area
|
||||
|
||||
|
||||
def valid_boundary(x, with_score=True):
    """Check that a flat boundary has a plausible length.

    Args:
        x (sequence): Flat boundary values, optionally followed by a score.
        with_score (bool): Whether `x` is expected to end with a score.

    Returns:
        bool: False when `x` has fewer than 8 entries. Otherwise True only
            if the length is odd with a score, or even without one.
    """
    length = len(x)
    if length < 8:
        return False
    # Coordinates come in pairs, so a trailing score makes the length odd.
    has_odd_length = length % 2 == 1
    return has_odd_length == bool(with_score)
|
||||
|
||||
|
||||
def boundary_iou(src, target):
    """Compute the IOU of two flat, score-less boundaries.

    Args:
        src (list): Source boundary.
        target (list): Target boundary.

    Returns:
        iou (float): The iou between two boundaries.
    """
    for boundary in (src, target):
        assert valid_boundary(boundary, False)

    return poly_iou(points2polygon(src), points2polygon(target))
|
||||
|
||||
|
||||
def poly_iou(poly_det, poly_gt):
    """Compute intersection-over-union for two polygons.

    Args:
        poly_det (Polygon): A polygon predicted by detector.
        poly_gt (Polygon): A gt polygon.

    Returns:
        iou (float): The IOU between two polygons, or 0.0 when the union
            area is 0.
    """
    for candidate in (poly_det, poly_gt):
        assert isinstance(candidate, plg.Polygon)

    inter_area, _ = poly_intersection(poly_det, poly_gt)
    union_area = poly_union(poly_det, poly_gt)
    return inter_area / union_area if union_area != 0 else 0.0
|
||||
|
||||
|
||||
def poly_nms(polygons, threshold):
    """Greedy non-maximum suppression over scored polygons.

    Args:
        polygons (list): Candidates; each entry is a flat boundary whose
            last element is its score.
        threshold (float): Candidates whose IOU with an already kept
            polygon is greater than this value are discarded.

    Returns:
        list: The kept entries (as lists), highest score first.
    """
    assert isinstance(polygons, list)

    # Sort ascending by score so the best remaining candidate is always last.
    ranked = np.array(sorted(polygons, key=lambda item: item[-1]))
    remaining = np.arange(ranked.shape[0])

    kept = []
    while remaining.size > 0:
        best = ranked[remaining[-1]]
        kept.append(best.tolist())
        remaining = remaining[:-1]

        overlaps = np.zeros((remaining.size, ))
        for pos, candidate in enumerate(remaining):
            overlaps[pos] = boundary_iou(best[:-1], ranked[candidate][:-1])

        # Drop everything overlapping the kept polygon too much.
        remaining = remaining[~(overlaps > threshold)]

    return kept
|
||||
from ppocr.utils.poly_nms import poly_nms, valid_boundary
|
||||
|
||||
|
||||
def fill_hole(input_mask):
|
||||
|
@ -177,96 +60,6 @@ def fourier2poly(fourier_coeff, num_reconstr_points=50):
|
|||
return polygon.astype('int32').reshape((len(fourier_coeff), -1))
|
||||
|
||||
|
||||
def fcenet_decode(preds,
                  fourier_degree,
                  num_reconstr_points,
                  scale,
                  alpha=1.0,
                  beta=2.0,
                  text_repr_type='poly',
                  score_thr=0.3,
                  nms_thr=0.1):
    """Decoding predictions of FCENet to instances.

    Args:
        preds (list): Two-element list `[cls_pred, reg_pred]`; only index 0
            along the leading axis of each is decoded.
        fourier_degree (int): The maximum Fourier transform degree k.
        num_reconstr_points (int): The points number of the polygon
            reconstructed from predicted Fourier coefficients.
        scale (int): The down-sample scale of the prediction.
        alpha (float) : The parameter to calculate final scores. Score_{final}
                = (Score_{text region} ^ alpha)
                * (Score_{text center region}^ beta)
        beta (float) : The parameter to calculate final score.
        text_repr_type (str): Boundary encoding type 'poly' or 'quad'.
        score_thr (float) : The threshold used to filter out the final
            candidates.
        nms_thr (float) : The threshold of nms.

    Returns:
        boundaries (list[list[float]]): The instance boundary and confidence
            list. For 'quad' each boundary is the minimum-area rectangle of
            the decoded polygon followed by its score.
    """
    assert isinstance(preds, list)
    assert len(preds) == 2
    assert text_repr_type in ['poly', 'quad']

    cls_pred = preds[0][0]
    # First two channels: text region; remaining channels: text center line.
    tr_pred = cls_pred[0:2]
    tcl_pred = cls_pred[2:]

    # Move channels last so a 2-D boolean mask selects per-pixel coefficients.
    reg_pred = preds[1][0].transpose([1, 2, 0])
    x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
    y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]

    score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
    tr_pred_mask = (score_pred) > score_thr
    tr_mask = fill_hole(tr_pred_mask)

    tr_contours, _ = cv2.findContours(
        tr_mask.astype(np.uint8), cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE)  # opencv4

    mask = np.zeros_like(tr_mask)
    boundaries = []
    for cont in tr_contours:
        deal_map = mask.copy().astype(np.int8)
        cv2.drawContours(deal_map, [cont], -1, 1, -1)

        score_map = score_pred * deal_map
        score_mask = score_map > 0
        xy_text = np.argwhere(score_mask)
        # argwhere yields (row, col); encode each pixel as col + row * 1j.
        dxy = xy_text[:, 1] + xy_text[:, 0] * 1j

        x, y = x_pred[score_mask], y_pred[score_mask]
        c = x + y * 1j
        # Add the pixel position to the middle (index k) coefficient.
        c[:, fourier_degree] = c[:, fourier_degree] + dxy
        c *= scale

        polygons = fourier2poly(c, num_reconstr_points)
        score = score_map[score_mask].reshape(-1, 1)
        polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)

        boundaries = boundaries + polygons

    boundaries = poly_nms(boundaries, nms_thr)

    if text_repr_type == 'quad':
        new_boundaries = []
        for boundary in boundaries:
            poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
            score = boundary[-1]
            points = cv2.boxPoints(cv2.minAreaRect(poly))
            # np.int0 (an alias of np.intp) was removed in NumPy 2.0.
            points = points.astype(np.intp)
            new_boundaries.append(points.reshape(-1).tolist() + [score])
        # The original built new_boundaries but still returned the polygons.
        boundaries = new_boundaries

    return boundaries
|
||||
|
||||
|
||||
class FCEPostProcess(object):
|
||||
"""
|
||||
The post process for FCENet.
|
||||
|
@ -316,10 +109,6 @@ class FCEPostProcess(object):
|
|||
Returns:
|
||||
boundaries (list[list[float]]): The scaled boundaries.
|
||||
"""
|
||||
# assert check_argument.is_2dlist(boundaries)
|
||||
# assert isinstance(scale_factor, np.ndarray)
|
||||
# assert scale_factor.shape[0] == 4
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
for b in boundaries:
|
||||
|
@ -335,7 +124,6 @@ class FCEPostProcess(object):
|
|||
|
||||
def get_boundary(self, score_maps, shape_list):
|
||||
assert len(score_maps) == len(self.scales)
|
||||
# import pdb;pdb.set_trace()
|
||||
boundaries = []
|
||||
for idx, score_map in enumerate(score_maps):
|
||||
scale = self.scales[idx]
|
||||
|
@ -344,8 +132,6 @@ class FCEPostProcess(object):
|
|||
|
||||
# nms
|
||||
boundaries = poly_nms(boundaries, self.nms_thr)
|
||||
# if rescale:
|
||||
# import pdb;pdb.set_trace()
|
||||
boundaries, scores = self.resize_boundary(
|
||||
boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
|
||||
|
||||
|
@ -356,7 +142,7 @@ class FCEPostProcess(object):
|
|||
assert len(score_map) == 2
|
||||
assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
|
||||
|
||||
return fcenet_decode(
|
||||
return self.fcenet_decode(
|
||||
preds=score_map,
|
||||
fourier_degree=self.fourier_degree,
|
||||
num_reconstr_points=self.num_reconstr_points,
|
||||
|
@ -366,3 +152,89 @@ class FCEPostProcess(object):
|
|||
text_repr_type=self.text_repr_type,
|
||||
score_thr=self.score_thr,
|
||||
nms_thr=self.nms_thr)
|
||||
|
||||
def fcenet_decode(self,
                  preds,
                  fourier_degree,
                  num_reconstr_points,
                  scale,
                  alpha=1.0,
                  beta=2.0,
                  text_repr_type='poly',
                  score_thr=0.3,
                  nms_thr=0.1):
    """Decoding predictions of FCENet to instances.

    Args:
        preds (list): Two-element list `[cls_pred, reg_pred]`; only index 0
            along the leading axis of each is decoded.
        fourier_degree (int): The maximum Fourier transform degree k.
        num_reconstr_points (int): The points number of the polygon
            reconstructed from predicted Fourier coefficients.
        scale (int): The down-sample scale of the prediction.
        alpha (float) : The parameter to calculate final scores. Score_{final}
                = (Score_{text region} ^ alpha)
                * (Score_{text center region}^ beta)
        beta (float) : The parameter to calculate final score.
        text_repr_type (str): Boundary encoding type 'poly' or 'quad'.
        score_thr (float) : The threshold used to filter out the final
            candidates.
        nms_thr (float) : The threshold of nms.

    Returns:
        boundaries (list[list[float]]): The instance boundary and confidence
            list. For 'quad' each boundary is the minimum-area rectangle of
            the decoded polygon followed by its score.
    """
    assert isinstance(preds, list)
    assert len(preds) == 2
    assert text_repr_type in ['poly', 'quad']

    cls_pred = preds[0][0]
    # First two channels: text region; remaining channels: text center line.
    tr_pred = cls_pred[0:2]
    tcl_pred = cls_pred[2:]

    # Move channels last so a 2-D boolean mask selects per-pixel coefficients.
    reg_pred = preds[1][0].transpose([1, 2, 0])
    x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
    y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]

    score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
    tr_pred_mask = (score_pred) > score_thr
    tr_mask = fill_hole(tr_pred_mask)

    tr_contours, _ = cv2.findContours(
        tr_mask.astype(np.uint8), cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE)  # opencv4

    mask = np.zeros_like(tr_mask)
    boundaries = []
    for cont in tr_contours:
        deal_map = mask.copy().astype(np.int8)
        cv2.drawContours(deal_map, [cont], -1, 1, -1)

        score_map = score_pred * deal_map
        score_mask = score_map > 0
        xy_text = np.argwhere(score_mask)
        # argwhere yields (row, col); encode each pixel as col + row * 1j.
        dxy = xy_text[:, 1] + xy_text[:, 0] * 1j

        x, y = x_pred[score_mask], y_pred[score_mask]
        c = x + y * 1j
        # Add the pixel position to the middle (index k) coefficient.
        c[:, fourier_degree] = c[:, fourier_degree] + dxy
        c *= scale

        polygons = fourier2poly(c, num_reconstr_points)
        score = score_map[score_mask].reshape(-1, 1)
        polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)

        boundaries = boundaries + polygons

    boundaries = poly_nms(boundaries, nms_thr)

    if text_repr_type == 'quad':
        new_boundaries = []
        for boundary in boundaries:
            poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
            score = boundary[-1]
            points = cv2.boxPoints(cv2.minAreaRect(poly))
            # np.int0 (an alias of np.intp) was removed in NumPy 2.0.
            points = points.astype(np.intp)
            new_boundaries.append(points.reshape(-1).tolist() + [score])
        # The original built new_boundaries but still returned the polygons.
        boundaries = new_boundaries

    return boundaries
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import Polygon as plg
|
||||
|
||||
|
||||
def points2polygon(points):
    """Build a single polygon object from a flat sequence of k points.

    Args:
        points (ndarray or list): Flat coordinates holding 2k values
            (two per point). A list is converted to an ndarray first.

    Returns:
        polygon (Polygon): The ``plg.Polygon`` built from the (k, 2)
            reshaped coordinates.

    Raises:
        AssertionError: If the input is neither a list nor an ndarray, if
            it has an odd number of values, or fewer than 8 values.
    """
    coords = np.array(points) if isinstance(points, list) else points

    assert isinstance(coords, np.ndarray)
    # An even count of at least 8 values, i.e. at least four (x, y) pairs.
    assert (coords.size % 2 == 0) and (coords.size >= 8)

    return plg.Polygon(coords.reshape([-1, 2]))
|
||||
|
||||
|
||||
def poly_intersection(poly_det, poly_gt):
    """Compute the intersection of two polygons and its area.

    Args:
        poly_det (Polygon): A polygon predicted by the detector.
        poly_gt (Polygon): A ground-truth polygon.

    Returns:
        tuple: ``(intersection_area, poly_inter)`` where ``poly_inter`` is
            the result of ``poly_det & poly_gt``. The area is the integer
            ``0`` when that result has length zero, otherwise
            ``poly_inter.area()``.
    """
    assert isinstance(poly_det, plg.Polygon)
    assert isinstance(poly_gt, plg.Polygon)

    overlap = poly_det & poly_gt
    # A zero-length result is reported as area 0 without calling area().
    overlap_area = 0 if len(overlap) == 0 else overlap.area()
    return overlap_area, overlap
|
||||
|
||||
|
||||
def poly_union(poly_det, poly_gt):
    """Compute the union area of two polygons.

    The union is obtained by inclusion-exclusion: the two individual areas
    added together, minus the area reported by ``poly_intersection``.

    Args:
        poly_det (Polygon): A polygon predicted by the detector.
        poly_gt (Polygon): A ground-truth polygon.

    Returns:
        union_area (float): The union area of the two polygons.
    """
    assert isinstance(poly_det, plg.Polygon)
    assert isinstance(poly_gt, plg.Polygon)

    det_area = poly_det.area()
    gt_area = poly_gt.area()
    shared_area = poly_intersection(poly_det, poly_gt)[0]
    return det_area + gt_area - shared_area
|
||||
|
||||
|
||||
def valid_boundary(x, with_score=True):
    """Check that a flat boundary sequence has an acceptable length.

    Args:
        x (sequence): The flat boundary; only ``len(x)`` is inspected.
        with_score (bool): When truthy an odd length is required,
            otherwise an even length is required.

    Returns:
        bool: True when ``len(x)`` is at least 8 and its parity matches
            ``with_score``; False otherwise.
    """
    length = len(x)
    if length < 8:
        return False
    is_odd = length % 2 == 1
    # Odd length pairs with a truthy with_score, even length with a falsy one.
    return is_odd == bool(with_score)
|
||||
|
||||
|
||||
def boundary_iou(src, target):
    """Compute the IoU between two flat boundaries.

    Both inputs are validated with ``valid_boundary(..., False)``,
    converted with ``points2polygon`` and compared with ``poly_iou``.

    Args:
        src (list): Source boundary.
        target (list): Target boundary.

    Returns:
        iou (float): The IoU between the two boundaries.
    """
    assert valid_boundary(src, False)
    assert valid_boundary(target, False)

    return poly_iou(points2polygon(src), points2polygon(target))
|
||||
|
||||
|
||||
def poly_iou(poly_det, poly_gt):
    """Calculate the IoU between two polygons.

    The polygon intersection is computed a single time and the union area
    is derived from it by inclusion-exclusion, rather than computing the
    intersection again just to obtain the union.

    Args:
        poly_det (Polygon): A polygon predicted by the detector.
        poly_gt (Polygon): A ground-truth polygon.

    Returns:
        iou (float): The IoU between the two polygons; ``0.0`` when the
            union area is zero.
    """
    assert isinstance(poly_det, plg.Polygon)
    assert isinstance(poly_gt, plg.Polygon)

    area_inters, _ = poly_intersection(poly_det, poly_gt)
    # Same expression poly_union evaluates, reusing the intersection above.
    area_union = poly_det.area() + poly_gt.area() - area_inters
    if area_union == 0:
        # Guard against division by zero.
        return 0.0
    return area_inters / area_union
|
||||
|
||||
|
||||
def poly_nms(polygons, threshold):
    """Suppress overlapping polygons, keeping the highest-scoring ones.

    Each entry is a flat sequence whose last element is used as the sort
    key (the score) and whose remaining elements are passed to
    ``boundary_iou`` as the boundary.

    Args:
        polygons (list): Candidate entries; must be a list.
        threshold (float): Candidates whose ``boundary_iou`` with an
            already kept entry is strictly greater than this are dropped.

    Returns:
        list: The kept entries as plain lists, in descending score order.
    """
    assert isinstance(polygons, list)

    # Ascending by score, so the best remaining candidate is always last.
    ranked = np.array(sorted(polygons, key=lambda item: item[-1]))

    kept = []
    remaining = list(range(ranked.shape[0]))

    while len(remaining) > 0:
        best = ranked[remaining[-1]]
        kept.append(best.tolist())
        best_boundary = best[:-1]
        remaining = np.delete(remaining, -1)

        overlaps = np.zeros((len(remaining), ))
        for pos, candidate in enumerate(remaining):
            overlaps[pos] = boundary_iou(best_boundary,
                                         ranked[candidate][:-1])

        # Drop every candidate overlapping the kept entry above the threshold.
        remaining = np.delete(remaining, np.where(overlaps > threshold))

    return kept
|
|
@ -503,7 +503,7 @@ def preprocess(is_train=False):
|
|||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'FCE'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
3
train.sh
3
train.sh
|
@ -1,3 +1,2 @@
|
|||
# recommended paddle.__version__ == 2.0.0
|
||||
# python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||
python -m paddle.distributed.launch --gpus '7' tools/train.py -c configs/det/det_r50_fce_ctw.yml
|
||||
python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||
|
|
Loading…
Reference in New Issue