mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import inspect
|
|
import time
|
|
|
|
import numpy as np
|
|
# not useful in this CR, future support albumentations
|
|
# import mkl
|
|
# mkl.get_max_threads()
|
|
from albumentations import (CLAHE, Blur, Flip, GaussNoise, GridDistortion,
|
|
HorizontalFlip, HueSaturationValue, IAAEmboss,
|
|
IAAPerspective, IAAPiecewiseAffine, IAASharpen,
|
|
MedianBlur, MotionBlur, OneOf, OpticalDistortion,
|
|
RandomBrightness, RandomContrast, RandomRotate90,
|
|
ShiftScaleRotate, Transpose)
|
|
from PIL import Image
|
|
from torchvision import transforms as _transforms
|
|
|
|
from easycv.datasets.registry import PIPELINES
|
|
|
|
# Albumentations transform classes. ``Compose.__call__`` checks each transform
# against this list to decide whether to call it with a numpy ``image=``
# keyword argument instead of passing the image positionally.
# NOTE: ``ShiftScaleRotate`` is listed twice; this is harmless for the
# ``isinstance`` check and is kept so the list contents stay unchanged.
albumentation_list = [
    HorizontalFlip,
    IAAPerspective,
    ShiftScaleRotate,
    CLAHE,
    RandomRotate90,
    Transpose,
    ShiftScaleRotate,
    Blur,
    OpticalDistortion,
    GridDistortion,
    HueSaturationValue,
    GaussNoise,
    MotionBlur,
    MedianBlur,
    IAAPiecewiseAffine,
    IAASharpen,
    IAAEmboss,
    RandomContrast,
    RandomBrightness,
    Flip,
    OneOf,
]
|
|
|
|
# Register the albumentations transforms exposed through the pipeline registry.
for _albu_transform in (ShiftScaleRotate, GaussNoise, MotionBlur):
    PIPELINES.register_module(_albu_transform)

# Register every class found in torchvision.transforms, except ``Compose``:
# this module provides its own ``Compose`` implementation below.
for _name, _transform_cls in inspect.getmembers(_transforms, inspect.isclass):
    if _name != 'Compose':
        PIPELINES.register_module(_transform_cls)
|
|
|
|
|
|
@PIPELINES.register_module
class Compose(object):
    """Composes several transforms together.

    A transform that is an instance of one of the classes in
    ``albumentation_list`` is called as ``t(image=<numpy array>)`` and the
    ``'image'`` entry of its result is converted back to a PIL image. Every
    other transform is called directly as ``t(img)``.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
        profiling (bool): if True, print the elapsed time of each transform
            every time the pipeline is called. Defaults to False.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms, profiling=False):
        self.transforms = transforms
        self.profiling = profiling

    def __call__(self, img):
        """Apply every transform in order to ``img`` and return the result."""
        # Loop-invariant: build the class tuple for isinstance() once per call
        # instead of once per transform.
        albu_types = tuple(albumentation_list)

        for t in self.transforms:
            if self.profiling:
                # perf_counter is meant for measuring intervals; time.time()
                # follows the wall clock, which may jump.
                start = time.perf_counter()

            if isinstance(t, albu_types):
                # Albumentations transforms take a numpy array through the
                # ``image`` keyword and return a dict holding the result.
                augmented = t(image=np.array(img))
                img = Image.fromarray(augmented['image'])
            else:
                img = t(img)

            if self.profiling:
                print(f'{t} time {time.perf_counter() - start}')
        return img

    def __repr__(self):
        """Return the class name followed by one line per transform."""
        parts = [self.__class__.__name__ + '(']
        parts.extend(' {0}'.format(t) for t in self.transforms)
        parts.append(')')
        return '\n'.join(parts)
|