add bda
parent
8a5a98707c
commit
feca63686d
|
@ -63,8 +63,7 @@ Train:
|
|||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
use_tia: False
|
||||
- BaseDataAugmentation:
|
||||
- RandAugment:
|
||||
- SSLRotateResize:
|
||||
image_shape: [3, 48, 320]
|
||||
|
|
|
@ -60,8 +60,7 @@ Train:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ClsLabelEncode: # Class handling label
|
||||
- RecAug:
|
||||
use_tia: False
|
||||
- BaseDataAugmentation:
|
||||
- RandAugment:
|
||||
- ClsResizeImg:
|
||||
image_shape: [3, 48, 192]
|
||||
|
|
|
@ -682,7 +682,7 @@ lr:
|
|||
|
||||
#### Q: 关于dygraph分支中,文本识别模型训练,要使用数据增强应该如何设置?
|
||||
|
||||
**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4.由于tia数据增强特殊性,默认不采用,可以通过添加use_tia设置,使tia数据增强生效。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。
|
||||
**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。
|
||||
|
||||
#### Q: 训练过程中,训练程序意外退出/挂起,应该如何解决?
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap
|
|||
from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
||||
from .make_pse_gt import MakePseGt
|
||||
|
||||
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
|
||||
from .ssl_img_aug import SSLRotateResize
|
||||
from .randaugment import RandAugment
|
||||
|
|
|
@ -22,13 +22,74 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
|||
|
||||
|
||||
class RecAug(object):
|
||||
def __init__(self, use_tia=True, aug_prob=0.4, **kwargs):
|
||||
self.use_tia = use_tia
|
||||
self.aug_prob = aug_prob
|
||||
def __init__(self,
|
||||
tia_prob=True,
|
||||
crop_prob=0.4,
|
||||
reverse_prob=0.4,
|
||||
noise_prob=0.4,
|
||||
jitter_prob=0.4,
|
||||
blur_prob=0.4,
|
||||
hsv_aug_prob=0.4,
|
||||
**kwargs):
|
||||
self.tia_prob = tia_prob
|
||||
self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob,
|
||||
jitter_prob, blur_prob, hsv_aug_prob)
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = warp(img, 10, self.use_tia, self.aug_prob)
|
||||
h, w, _ = img.shape
|
||||
|
||||
# tia
|
||||
if random.random() <= self.tia_prob:
|
||||
if h >= 20 and w >= 20:
|
||||
img = tia_distort(img, random.randint(3, 6))
|
||||
img = tia_stretch(img, random.randint(3, 6))
|
||||
img = tia_perspective(img)
|
||||
|
||||
# bda
|
||||
data['image'] = img
|
||||
data = self.bda(data)
|
||||
return data
|
||||
|
||||
|
||||
class BaseDataAugmentation(object):
|
||||
def __init__(self,
|
||||
crop_prob=0.4,
|
||||
reverse_prob=0.4,
|
||||
noise_prob=0.4,
|
||||
jitter_prob=0.4,
|
||||
blur_prob=0.4,
|
||||
hsv_aug_prob=0.4,
|
||||
**kwargs):
|
||||
self.crop_prob = crop_prob
|
||||
self.reverse_prob = reverse_prob
|
||||
self.noise_prob = noise_prob
|
||||
self.jitter_prob = jitter_prob
|
||||
self.blur_prob = blur_prob
|
||||
self.hsv_aug_prob = hsv_aug_prob
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
h, w, _ = img.shape
|
||||
|
||||
if random.random() <= self.crop_prob and h >= 20 and w >= 20:
|
||||
img = get_crop(img)
|
||||
|
||||
if random.random() <= self.blur_prob:
|
||||
img = blur(img)
|
||||
|
||||
if random.random() <= self.hsv_aug_prob:
|
||||
img = hsv_aug(img)
|
||||
|
||||
if random.random() <= self.jitter_prob:
|
||||
img = jitter(img)
|
||||
|
||||
if random.random() <= self.noise_prob:
|
||||
img = add_gasuss_noise(img)
|
||||
|
||||
if random.random() <= self.reverse_prob:
|
||||
img = 255 - img
|
||||
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
@ -359,7 +420,7 @@ def flag():
|
|||
return 1 if random.random() > 0.5000001 else -1
|
||||
|
||||
|
||||
def cvtColor(img):
|
||||
def hsv_aug(img):
|
||||
"""
|
||||
cvtColor
|
||||
"""
|
||||
|
@ -427,50 +488,6 @@ def get_crop(image):
|
|||
return crop_img
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Config
|
||||
"""
|
||||
|
||||
def __init__(self, use_tia):
|
||||
self.anglex = random.random() * 30
|
||||
self.angley = random.random() * 15
|
||||
self.anglez = random.random() * 10
|
||||
self.fov = 42
|
||||
self.r = 0
|
||||
self.shearx = random.random() * 0.3
|
||||
self.sheary = random.random() * 0.05
|
||||
self.borderMode = cv2.BORDER_REPLICATE
|
||||
self.use_tia = use_tia
|
||||
|
||||
def make(self, w, h, ang):
|
||||
"""
|
||||
make
|
||||
"""
|
||||
self.anglex = random.random() * 5 * flag()
|
||||
self.angley = random.random() * 5 * flag()
|
||||
self.anglez = -1 * random.random() * int(ang) * flag()
|
||||
self.fov = 42
|
||||
self.r = 0
|
||||
self.shearx = 0
|
||||
self.sheary = 0
|
||||
self.borderMode = cv2.BORDER_REPLICATE
|
||||
self.w = w
|
||||
self.h = h
|
||||
|
||||
self.perspective = self.use_tia
|
||||
self.stretch = self.use_tia
|
||||
self.distort = self.use_tia
|
||||
|
||||
self.crop = True
|
||||
self.affine = False
|
||||
self.reverse = True
|
||||
self.noise = True
|
||||
self.jitter = True
|
||||
self.blur = True
|
||||
self.color = True
|
||||
|
||||
|
||||
def rad(x):
|
||||
"""
|
||||
rad
|
||||
|
@ -554,48 +571,3 @@ def get_warpAffine(config):
|
|||
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
|
||||
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
|
||||
return rz
|
||||
|
||||
|
||||
def warp(img, ang, use_tia=True, prob=0.4):
|
||||
"""
|
||||
warp
|
||||
"""
|
||||
h, w, _ = img.shape
|
||||
config = Config(use_tia=use_tia)
|
||||
config.make(w, h, ang)
|
||||
new_img = img
|
||||
|
||||
if config.distort:
|
||||
img_height, img_width = img.shape[0:2]
|
||||
if random.random() <= prob and img_height >= 20 and img_width >= 20:
|
||||
new_img = tia_distort(new_img, random.randint(3, 6))
|
||||
|
||||
if config.stretch:
|
||||
img_height, img_width = img.shape[0:2]
|
||||
if random.random() <= prob and img_height >= 20 and img_width >= 20:
|
||||
new_img = tia_stretch(new_img, random.randint(3, 6))
|
||||
|
||||
if config.perspective:
|
||||
if random.random() <= prob:
|
||||
new_img = tia_perspective(new_img)
|
||||
|
||||
if config.crop:
|
||||
img_height, img_width = img.shape[0:2]
|
||||
if random.random() <= prob and img_height >= 20 and img_width >= 20:
|
||||
new_img = get_crop(new_img)
|
||||
|
||||
if config.blur:
|
||||
if random.random() <= prob:
|
||||
new_img = blur(new_img)
|
||||
if config.color:
|
||||
if random.random() <= prob:
|
||||
new_img = cvtColor(new_img)
|
||||
if config.jitter:
|
||||
new_img = jitter(new_img)
|
||||
if config.noise:
|
||||
if random.random() <= prob:
|
||||
new_img = add_gasuss_noise(new_img)
|
||||
if config.reverse:
|
||||
if random.random() <= prob:
|
||||
new_img = 255 - new_img
|
||||
return new_img
|
||||
|
|
Loading…
Reference in New Issue