refactor: to be pythonic

pull/2357/head
gaotingquan 2022-09-28 03:51:34 +00:00 committed by Tingquan Gao
parent e823f17853
commit 19840cb046
3 changed files with 41 additions and 57 deletions

View File

@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90
from .ops.operators import format_data
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
@ -102,8 +103,8 @@ class TimmAutoAugment(RawTimmAutoAugment):
super().__init__(*args, **kwargs)
self.prob = prob
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
@ -111,9 +112,5 @@ class TimmAutoAugment(RawTimmAutoAugment):
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img

View File

@ -34,6 +34,23 @@ from .functional import augmentations
from ppcls.utils import logger
def format_data(func):
def warpper(self, data):
if isinstance(data, dict):
img = data["img"]
result = func(self, img)
if not isinstance(result, dict):
result = {"img": result}
return { ** data, ** result}
else:
result = func(self, data)
if isinstance(result, dict):
result = result["img"]
return result
return warpper
class UnifiedResize(object):
def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
_cv2_interp_from_str = {
@ -161,8 +178,8 @@ class DecodeImage(object):
f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
)
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
if isinstance(img, Image.Image):
assert self.backend == "pil", "invalid input 'img' in DecodeImage"
elif isinstance(img, np.ndarray):
@ -189,12 +206,7 @@ class DecodeImage(object):
if self.channel_first:
img = img.transpose((2, 0, 1))
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
class ResizeImage(object):
@ -421,8 +433,8 @@ class RandCropImage(object):
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
size = self.size
scale = self.scale
ratio = self.ratio
@ -447,12 +459,7 @@ class RandCropImage(object):
j = random.randint(0, img_h - h)
img = self._resize_func(img[j:j + h, i:i + w, :], size)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
class RandCropImageV2(object):
@ -557,8 +564,8 @@ class NormalizeImage(object):
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
@ -580,12 +587,7 @@ class NormalizeImage(object):
(img, pad_zeros), axis=2))
img = img.astype(self.output_dtype)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
class ToCHWImage(object):
@ -772,15 +774,9 @@ class RandomRot90(object):
def __init__(self):
pass
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
@format_data
def __call__(self, img):
orientation = random.choice([0, 1, 2, 3])
if orientation:
img = np.rot90(img, orientation)
processed_data = {
**
ori_data,
"img": img,
"random_rot90_orientation": orientation
} if isinstance(ori_data, dict) else img
return processed_data
return {"img": img, "random_rot90_orientation": orientation}

View File

@ -22,6 +22,8 @@ import random
import numpy as np
from .operators import format_data
class Pixels(object):
def __init__(self, mode="const", mean=[0., 0., 0.]):
@ -70,11 +72,10 @@ class RandomErasing(object):
self.attempt = attempt
self.get_pixels = Pixels(mode, mean)
def __call__(self, ori_data):
@format_data
def __call__(self, img):
if random.random() > self.EPSILON:
return ori_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
return img
for _ in range(self.attempt):
if isinstance(img, np.ndarray):
@ -107,16 +108,6 @@ class RandomErasing(object):
img[0, x1:x1 + h, y1:y1 + w] = pixels[0]
else:
img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0]
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
return img