refactor: to be pythonic
parent
e823f17853
commit
19840cb046
|
@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage
|
|||
from ppcls.data.preprocess.ops.operators import RandomRotation
|
||||
from ppcls.data.preprocess.ops.operators import Padv2
|
||||
from ppcls.data.preprocess.ops.operators import RandomRot90
|
||||
from .ops.operators import format_data
|
||||
|
||||
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
|
||||
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
|
||||
|
@ -102,8 +103,8 @@ class TimmAutoAugment(RawTimmAutoAugment):
|
|||
super().__init__(*args, **kwargs)
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, ori_data):
|
||||
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
|
||||
@format_data
|
||||
def __call__(self, img):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
|
@ -111,9 +112,5 @@ class TimmAutoAugment(RawTimmAutoAugment):
|
|||
img = super().__call__(img)
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
processed_data = {
|
||||
**
|
||||
ori_data,
|
||||
"img": img
|
||||
} if isinstance(ori_data, dict) else img
|
||||
return processed_data
|
||||
|
||||
return img
|
||||
|
|
|
@ -34,6 +34,23 @@ from .functional import augmentations
|
|||
from ppcls.utils import logger
|
||||
|
||||
|
||||
def format_data(func):
|
||||
def warpper(self, data):
|
||||
if isinstance(data, dict):
|
||||
img = data["img"]
|
||||
result = func(self, img)
|
||||
if not isinstance(result, dict):
|
||||
result = {"img": result}
|
||||
return { ** data, ** result}
|
||||
else:
|
||||
result = func(self, data)
|
||||
if isinstance(result, dict):
|
||||
result = result["img"]
|
||||
return result
|
||||
|
||||
return warpper
|
||||
|
||||
|
||||
class UnifiedResize(object):
|
||||
def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
|
||||
_cv2_interp_from_str = {
|
||||
|
@ -161,8 +178,8 @@ class DecodeImage(object):
|
|||
f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
|
||||
)
|
||||
|
||||
def __call__(self, ori_data):
|
||||
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
|
||||
@format_data
|
||||
def __call__(self, img):
|
||||
if isinstance(img, Image.Image):
|
||||
assert self.backend == "pil", "invalid input 'img' in DecodeImage"
|
||||
elif isinstance(img, np.ndarray):
|
||||
|
@ -189,12 +206,7 @@ class DecodeImage(object):
|
|||
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
processed_data = {
|
||||
**
|
||||
ori_data,
|
||||
"img": img
|
||||
} if isinstance(ori_data, dict) else img
|
||||
return processed_data
|
||||
return img
|
||||
|
||||
|
||||
class ResizeImage(object):
|
||||
|
@ -421,8 +433,8 @@ class RandCropImage(object):
|
|||
self._resize_func = UnifiedResize(
|
||||
interpolation=interpolation, backend=backend)
|
||||
|
||||
def __call__(self, ori_data):
|
||||
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
|
||||
@format_data
|
||||
def __call__(self, img):
|
||||
size = self.size
|
||||
scale = self.scale
|
||||
ratio = self.ratio
|
||||
|
@ -447,12 +459,7 @@ class RandCropImage(object):
|
|||
j = random.randint(0, img_h - h)
|
||||
|
||||
img = self._resize_func(img[j:j + h, i:i + w, :], size)
|
||||
processed_data = {
|
||||
**
|
||||
ori_data,
|
||||
"img": img
|
||||
} if isinstance(ori_data, dict) else img
|
||||
return processed_data
|
||||
return img
|
||||
|
||||
|
||||
class RandCropImageV2(object):
|
||||
|
@ -557,8 +564,8 @@ class NormalizeImage(object):
|
|||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, ori_data):
|
||||
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
|
||||
@format_data
|
||||
def __call__(self, img):
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
|
@ -580,12 +587,7 @@ class NormalizeImage(object):
|
|||
(img, pad_zeros), axis=2))
|
||||
|
||||
img = img.astype(self.output_dtype)
|
||||
processed_data = {
|
||||
**
|
||||
ori_data,
|
||||
"img": img
|
||||
} if isinstance(ori_data, dict) else img
|
||||
return processed_data
|
||||
return img
|
||||
|
||||
|
||||
class ToCHWImage(object):
|
||||
|
@ -772,15 +774,9 @@ class RandomRot90(object):
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, ori_data):
|
||||
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
|
||||
@format_data
|
||||
def __call__(self, img):
|
||||
orientation = random.choice([0, 1, 2, 3])
|
||||
if orientation:
|
||||
img = np.rot90(img, orientation)
|
||||
processed_data = {
|
||||
**
|
||||
ori_data,
|
||||
"img": img,
|
||||
"random_rot90_orientation": orientation
|
||||
} if isinstance(ori_data, dict) else img
|
||||
return processed_data
|
||||
return {"img": img, "random_rot90_orientation": orientation}
|
||||
|
|
|
@ -22,6 +22,8 @@ import random
|
|||
|
||||
import numpy as np
|
||||
|
||||
from .operators import format_data
|
||||
|
||||
|
||||
class Pixels(object):
|
||||
def __init__(self, mode="const", mean=[0., 0., 0.]):
|
||||
|
@ -70,11 +72,10 @@ class RandomErasing(object):
|
|||
self.attempt = attempt
|
||||
self.get_pixels = Pixels(mode, mean)
|
||||
|
||||
def __call__(self, ori_data):
|
||||
@format_data
|
||||
def __call__(self, img):
|
||||
if random.random() > self.EPSILON:
|
||||
return ori_data
|
||||
|
||||
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
|
||||
return img
|
||||
|
||||
for _ in range(self.attempt):
|
||||
if isinstance(img, np.ndarray):
|
||||
|
@ -107,16 +108,6 @@ class RandomErasing(object):
|
|||
img[0, x1:x1 + h, y1:y1 + w] = pixels[0]
|
||||
else:
|
||||
img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0]
|
||||
processed_data = {
|
||||
**
|
||||
ori_data,
|
||||
"img": img
|
||||
} if isinstance(ori_data, dict) else img
|
||||
return processed_data
|
||||
return img
|
||||
|
||||
processed_data = {
|
||||
**
|
||||
ori_data,
|
||||
"img": img
|
||||
} if isinstance(ori_data, dict) else img
|
||||
return processed_data
|
||||
return img
|
||||
|
|
Loading…
Reference in New Issue