From feca63686d0dfdb0d46dc9c140e26643201bbbef Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Mon, 30 May 2022 16:44:50 +0800 Subject: [PATCH 1/2] add bda --- .../cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml | 3 +- configs/cls/cls_mv3.yml | 3 +- doc/doc_ch/FAQ.md | 2 +- ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/rec_img_aug.py | 160 ++++++++---------- 5 files changed, 70 insertions(+), 100 deletions(-) diff --git a/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml b/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml index 1ffeba0799..f7e327d1e0 100644 --- a/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml +++ b/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml @@ -63,8 +63,7 @@ Train: - DecodeImage: img_mode: BGR channel_first: false - - RecAug: - use_tia: False + - BaseDataAugmentation: - RandAugment: - SSLRotateResize: image_shape: [3, 48, 320] diff --git a/configs/cls/cls_mv3.yml b/configs/cls/cls_mv3.yml index 5e643dc383..0c46ff5602 100644 --- a/configs/cls/cls_mv3.yml +++ b/configs/cls/cls_mv3.yml @@ -60,8 +60,7 @@ Train: img_mode: BGR channel_first: False - ClsLabelEncode: # Class handling label - - RecAug: - use_tia: False + - BaseDataAugmentation: - RandAugment: - ClsResizeImg: image_shape: [3, 48, 192] diff --git a/doc/doc_ch/FAQ.md b/doc/doc_ch/FAQ.md index 24f8a3e92b..2dad829284 100644 --- a/doc/doc_ch/FAQ.md +++ b/doc/doc_ch/FAQ.md @@ -682,7 +682,7 @@ lr: #### Q: 关于dygraph分支中,文本识别模型训练,要使用数据增强应该如何设置? -**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4.由于tia数据增强特殊性,默认不采用,可以通过添加use_tia设置,使tia数据增强生效。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。 +**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。 #### Q: 训练过程中,训练程序意外退出/挂起,应该如何解决? diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 548832fb0d..f0fd578f61 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, RandomCropImgMask from .make_pse_gt import MakePseGt -from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ +from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 7483dffe5b..67b92d71bf 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -22,13 +22,74 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort class RecAug(object): - def __init__(self, use_tia=True, aug_prob=0.4, **kwargs): - self.use_tia = use_tia - self.aug_prob = aug_prob + def __init__(self, + tia_prob=True, + crop_prob=0.4, + reverse_prob=0.4, + noise_prob=0.4, + jitter_prob=0.4, + blur_prob=0.4, + hsv_aug_prob=0.4, + **kwargs): + self.tia_prob = tia_prob + self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob, + jitter_prob, blur_prob, hsv_aug_prob) def __call__(self, data): img = data['image'] - img = warp(img, 10, self.use_tia, self.aug_prob) + h, w, _ = img.shape + + # tia + if random.random() <= self.tia_prob: + if h >= 20 and w >= 20: + img = tia_distort(img, random.randint(3, 6)) + img = tia_stretch(img, random.randint(3, 6)) + img = tia_perspective(img) + + # bda + data['image'] = img + data = self.bda(data) + return data + + +class BaseDataAugmentation(object): + def __init__(self, + crop_prob=0.4, + reverse_prob=0.4, + noise_prob=0.4, + jitter_prob=0.4, + blur_prob=0.4, + hsv_aug_prob=0.4, + **kwargs): + self.crop_prob = crop_prob + self.reverse_prob = reverse_prob + self.noise_prob = noise_prob + self.jitter_prob = jitter_prob + self.blur_prob = blur_prob + self.hsv_aug_prob = hsv_aug_prob + + def __call__(self, data): + img = data['image'] + h, w, _ = img.shape + + if random.random() <= self.crop_prob and h >= 20 and w >= 20: + img = get_crop(img) + + if random.random() <= self.blur_prob: + img = blur(img) + + if random.random() <= self.hsv_aug_prob: + img = hsv_aug(img) + + if random.random() <= self.jitter_prob: + img = jitter(img) + + if random.random() <= self.noise_prob: + img = add_gasuss_noise(img) + + if random.random() <= self.reverse_prob: + img = 255 - img + data['image'] = img return data @@ -359,7 +420,7 @@ def flag(): return 1 if random.random() > 0.5000001 else -1 -def cvtColor(img): +def hsv_aug(img): """ cvtColor """ @@ -427,50 +488,6 @@ def get_crop(image): return crop_img -class Config: - """ - Config - """ - - def __init__(self, use_tia): - self.anglex = random.random() * 30 - self.angley = random.random() * 15 - self.anglez = random.random() * 10 - self.fov = 42 - self.r = 0 - self.shearx = random.random() * 0.3 - self.sheary = random.random() * 0.05 - self.borderMode = cv2.BORDER_REPLICATE - self.use_tia = use_tia - - def make(self, w, h, ang): - """ - make - """ - self.anglex = random.random() * 5 * flag() - self.angley = random.random() * 5 * flag() - self.anglez = -1 * random.random() * int(ang) * flag() - self.fov = 42 - self.r = 0 - self.shearx = 0 - self.sheary = 0 - self.borderMode = cv2.BORDER_REPLICATE - self.w = w - self.h = h - - self.perspective = self.use_tia - self.stretch = self.use_tia - self.distort = self.use_tia - - self.crop = True - self.affine = False - self.reverse = True - self.noise = True - self.jitter = True - self.blur = True - self.color = True - - def rad(x): """ rad @@ -554,48 +571,3 @@ def get_warpAffine(config): rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0], [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32) return rz - - -def warp(img, ang, use_tia=True, prob=0.4): - """ - warp - """ - h, w, _ = img.shape - config = Config(use_tia=use_tia) - config.make(w, h, ang) - new_img = img - - if config.distort: - img_height, img_width = img.shape[0:2] - if random.random() <= prob and img_height >= 20 and img_width >= 20: - new_img = tia_distort(new_img, random.randint(3, 6)) - - if config.stretch: - img_height, img_width = img.shape[0:2] - if random.random() <= prob and img_height >= 20 and img_width >= 20: - new_img = tia_stretch(new_img, random.randint(3, 6)) - - if config.perspective: - if random.random() <= prob: - new_img = tia_perspective(new_img) - - if config.crop: - img_height, img_width = img.shape[0:2] - if random.random() <= prob and img_height >= 20 and img_width >= 20: - new_img = get_crop(new_img) - - if config.blur: - if random.random() <= prob: - new_img = blur(new_img) - if config.color: - if random.random() <= prob: - new_img = cvtColor(new_img) - if config.jitter: - new_img = jitter(new_img) - if config.noise: - if random.random() <= prob: - new_img = add_gasuss_noise(new_img) - if config.reverse: - if random.random() <= prob: - new_img = 255 - new_img - return new_img From 32bcea9c1c72c7324371fcb59c9e56ceff78d789 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Mon, 30 May 2022 16:47:12 +0800 Subject: [PATCH 2/2] fix params error --- ppocr/data/imaug/rec_img_aug.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 67b92d71bf..32de2b3fc3 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -23,7 +23,7 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort class RecAug(object): def __init__(self, - tia_prob=True, + tia_prob=0.4, crop_prob=0.4, reverse_prob=0.4, noise_prob=0.4,