refactor: to be pythonic

pull/2357/head
gaotingquan 2022-09-28 03:51:34 +00:00 committed by Tingquan Gao
parent e823f17853
commit 19840cb046
3 changed files with 41 additions and 57 deletions

View File

@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2 from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90 from ppcls.data.preprocess.ops.operators import RandomRot90
from .ops.operators import format_data
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
@ -102,8 +103,8 @@ class TimmAutoAugment(RawTimmAutoAugment):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.prob = prob self.prob = prob
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
if not isinstance(img, Image.Image): if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
img = Image.fromarray(img) img = Image.fromarray(img)
@ -111,9 +112,5 @@ class TimmAutoAugment(RawTimmAutoAugment):
img = super().__call__(img) img = super().__call__(img)
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img = np.asarray(img) img = np.asarray(img)
processed_data = {
** return img
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data

View File

@ -34,6 +34,23 @@ from .functional import augmentations
from ppcls.utils import logger from ppcls.utils import logger
def format_data(func):
def warpper(self, data):
if isinstance(data, dict):
img = data["img"]
result = func(self, img)
if not isinstance(result, dict):
result = {"img": result}
return { ** data, ** result}
else:
result = func(self, data)
if isinstance(result, dict):
result = result["img"]
return result
return warpper
class UnifiedResize(object): class UnifiedResize(object):
def __init__(self, interpolation=None, backend="cv2", return_numpy=True): def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
_cv2_interp_from_str = { _cv2_interp_from_str = {
@ -161,8 +178,8 @@ class DecodeImage(object):
f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}." f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
) )
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
assert self.backend == "pil", "invalid input 'img' in DecodeImage" assert self.backend == "pil", "invalid input 'img' in DecodeImage"
elif isinstance(img, np.ndarray): elif isinstance(img, np.ndarray):
@ -189,12 +206,7 @@ class DecodeImage(object):
if self.channel_first: if self.channel_first:
img = img.transpose((2, 0, 1)) img = img.transpose((2, 0, 1))
processed_data = { return img
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class ResizeImage(object): class ResizeImage(object):
@ -421,8 +433,8 @@ class RandCropImage(object):
self._resize_func = UnifiedResize( self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend) interpolation=interpolation, backend=backend)
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
size = self.size size = self.size
scale = self.scale scale = self.scale
ratio = self.ratio ratio = self.ratio
@ -447,12 +459,7 @@ class RandCropImage(object):
j = random.randint(0, img_h - h) j = random.randint(0, img_h - h)
img = self._resize_func(img[j:j + h, i:i + w, :], size) img = self._resize_func(img[j:j + h, i:i + w, :], size)
processed_data = { return img
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class RandCropImageV2(object): class RandCropImageV2(object):
@ -557,8 +564,8 @@ class NormalizeImage(object):
self.mean = np.array(mean).reshape(shape).astype('float32') self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32') self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
from PIL import Image from PIL import Image
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img = np.array(img) img = np.array(img)
@ -580,12 +587,7 @@ class NormalizeImage(object):
(img, pad_zeros), axis=2)) (img, pad_zeros), axis=2))
img = img.astype(self.output_dtype) img = img.astype(self.output_dtype)
processed_data = { return img
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class ToCHWImage(object): class ToCHWImage(object):
@ -772,15 +774,9 @@ class RandomRot90(object):
def __init__(self): def __init__(self):
pass pass
def __call__(self, ori_data): @format_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data def __call__(self, img):
orientation = random.choice([0, 1, 2, 3]) orientation = random.choice([0, 1, 2, 3])
if orientation: if orientation:
img = np.rot90(img, orientation) img = np.rot90(img, orientation)
processed_data = { return {"img": img, "random_rot90_orientation": orientation}
**
ori_data,
"img": img,
"random_rot90_orientation": orientation
} if isinstance(ori_data, dict) else img
return processed_data

View File

@ -22,6 +22,8 @@ import random
import numpy as np import numpy as np
from .operators import format_data
class Pixels(object): class Pixels(object):
def __init__(self, mode="const", mean=[0., 0., 0.]): def __init__(self, mode="const", mean=[0., 0., 0.]):
@ -70,11 +72,10 @@ class RandomErasing(object):
self.attempt = attempt self.attempt = attempt
self.get_pixels = Pixels(mode, mean) self.get_pixels = Pixels(mode, mean)
def __call__(self, ori_data): @format_data
def __call__(self, img):
if random.random() > self.EPSILON: if random.random() > self.EPSILON:
return ori_data return img
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
for _ in range(self.attempt): for _ in range(self.attempt):
if isinstance(img, np.ndarray): if isinstance(img, np.ndarray):
@ -107,16 +108,6 @@ class RandomErasing(object):
img[0, x1:x1 + h, y1:y1 + w] = pixels[0] img[0, x1:x1 + h, y1:y1 + w] = pixels[0]
else: else:
img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0] img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0]
processed_data = { return img
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
processed_data = { return img
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data