mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
Fix copy_paste no texts augment.
This commit is contained in:
parent
473a811667
commit
de91f9a010
@ -35,10 +35,12 @@ class CopyPaste(object):
|
|||||||
point_num = data['polys'].shape[1]
|
point_num = data['polys'].shape[1]
|
||||||
src_img = data['image']
|
src_img = data['image']
|
||||||
src_polys = data['polys'].tolist()
|
src_polys = data['polys'].tolist()
|
||||||
|
src_texts = data['texts']
|
||||||
src_ignores = data['ignore_tags'].tolist()
|
src_ignores = data['ignore_tags'].tolist()
|
||||||
ext_data = data['ext_data'][0]
|
ext_data = data['ext_data'][0]
|
||||||
ext_image = ext_data['image']
|
ext_image = ext_data['image']
|
||||||
ext_polys = ext_data['polys']
|
ext_polys = ext_data['polys']
|
||||||
|
ext_texts = ext_data['texts']
|
||||||
ext_ignores = ext_data['ignore_tags']
|
ext_ignores = ext_data['ignore_tags']
|
||||||
|
|
||||||
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
|
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
|
||||||
@ -53,7 +55,7 @@ class CopyPaste(object):
|
|||||||
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
|
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
|
||||||
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
|
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
|
||||||
src_img = Image.fromarray(src_img).convert('RGBA')
|
src_img = Image.fromarray(src_img).convert('RGBA')
|
||||||
for poly, tag in zip(select_polys, select_ignores):
|
for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
|
||||||
box_img = get_rotate_crop_image(ext_image, poly)
|
box_img = get_rotate_crop_image(ext_image, poly)
|
||||||
|
|
||||||
src_img, box = self.paste_img(src_img, box_img, src_polys)
|
src_img, box = self.paste_img(src_img, box_img, src_polys)
|
||||||
@ -62,6 +64,7 @@ class CopyPaste(object):
|
|||||||
for _ in range(len(box), point_num):
|
for _ in range(len(box), point_num):
|
||||||
box.append(box[-1])
|
box.append(box[-1])
|
||||||
src_polys.append(box)
|
src_polys.append(box)
|
||||||
|
src_texts.append(ext_texts[idx])
|
||||||
src_ignores.append(tag)
|
src_ignores.append(tag)
|
||||||
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
|
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
|
||||||
h, w = src_img.shape[:2]
|
h, w = src_img.shape[:2]
|
||||||
@ -70,6 +73,7 @@ class CopyPaste(object):
|
|||||||
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
|
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
|
||||||
data['image'] = src_img
|
data['image'] = src_img
|
||||||
data['polys'] = src_polys
|
data['polys'] = src_polys
|
||||||
|
data['texts'] = src_texts
|
||||||
data['ignore_tags'] = np.array(src_ignores)
|
data['ignore_tags'] = np.array(src_ignores)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user