Add albumentations (#45)
* Add Albu transform * pre-commit * Create optional.txt * Update requirements.txt * Update transforms.py
pull/48/head
parent
8d3acce307
commit
99115fddbc
|
@ -18,8 +18,8 @@ class CIFAR10(BaseDataset):
|
|||
"""
|
||||
|
||||
base_folder = 'cifar-10-batches-py'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
||||
filename = "cifar-10-python.tar.gz"
|
||||
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
|
||||
filename = 'cifar-10-python.tar.gz'
|
||||
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
|
||||
train_list = [
|
||||
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
|
||||
|
@ -110,8 +110,8 @@ class CIFAR100(CIFAR10):
|
|||
"""
|
||||
|
||||
base_folder = 'cifar-100-python'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
|
||||
filename = "cifar-100-python.tar.gz"
|
||||
url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
|
||||
filename = 'cifar-100-python.tar.gz'
|
||||
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
|
||||
train_list = [
|
||||
['train', '16019d7e3df5f24257cddd939b257f8d'],
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import inspect
|
||||
import math
|
||||
import random
|
||||
|
||||
|
@ -6,6 +7,13 @@ import numpy as np
|
|||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
try:
|
||||
import albumentations
|
||||
from albumentations import Compose
|
||||
except ImportError:
|
||||
albumentations = None
|
||||
Compose = None
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomCrop(object):
|
||||
|
@ -155,8 +163,8 @@ class RandomResizedCrop(object):
|
|||
else:
|
||||
self.size = (size, size)
|
||||
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
||||
raise ValueError("range should be of kind (min, max). "
|
||||
f"But received {scale}")
|
||||
raise ValueError('range should be of kind (min, max). '
|
||||
f'But received {scale}')
|
||||
if backend not in ['cv2', 'pillow']:
|
||||
raise ValueError(f'backend: {backend} is not supported for resize.'
|
||||
'Supported backends are "cv2", "pillow"')
|
||||
|
@ -363,8 +371,8 @@ class Resize(object):
|
|||
assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
|
||||
if size[1] == -1:
|
||||
self.resize_w_short_side = True
|
||||
assert interpolation in ("nearest", "bilinear", "bicubic", "area",
|
||||
"lanczos")
|
||||
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
|
||||
'lanczos')
|
||||
if backend not in ['cv2', 'pillow']:
|
||||
raise ValueError(f'backend: {backend} is not supported for resize.'
|
||||
'Supported backends are "cv2", "pillow"')
|
||||
|
@ -486,3 +494,131 @@ class Normalize(object):
|
|||
repr_str += f'std={list(self.std)}, '
|
||||
repr_str += f'to_rgb={self.to_rgb})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class Albu(object):
    """Albumentation augmentation.

    Adds custom transformations from Albumentations library.
    Please, visit `https://albumentations.readthedocs.io`
    to get more information.
    An example of ``transforms`` is as follows:

    .. code-block::

        [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]

    Args:
        transforms (list[dict]): A list of albu transformations
        keymap (dict, optional): Contains
            {'input key':'albumentation-style key'}. When omitted or empty,
            ``{'img': 'image'}`` is used.
        update_pad_shape (bool): If True, ``results['pad_shape']`` is set to
            the shape of ``results['img']`` after augmentation.
            Defaults to False.
    """

    def __init__(self, transforms, keymap=None, update_pad_shape=False):
        """Build the composed augmentation and the key mappings.

        Raises:
            RuntimeError: If albumentations could not be imported.
        """
        # ``Compose`` is None when the optional import at module level failed.
        if Compose is None:
            raise RuntimeError('albumentations is not installed')

        self.transforms = transforms
        # NOTE(review): not read anywhere inside this class; kept so the
        # attribute stays available to external code.
        self.filter_lost_elements = False
        self.update_pad_shape = update_pad_shape

        self.aug = Compose([self.albu_builder(t) for t in self.transforms])

        if not keymap:
            self.keymap_to_albu = {
                'img': 'image',
            }
        else:
            self.keymap_to_albu = keymap
        # Inverse mapping, used to restore the original key names afterwards.
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    def albu_builder(self, cfg):
        """Import a module from albumentations.

        It inherits some of :func:`build_from_cfg` logic.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".

        Returns:
            obj: The constructed object.

        Raises:
            RuntimeError: If ``type`` is a string and albumentations is not
                installed.
            TypeError: If ``type`` is neither a string nor a class.
        """

        assert isinstance(cfg, dict) and 'type' in cfg
        # Work on a shallow copy so the caller's config dict is not mutated.
        args = cfg.copy()

        obj_type = args.pop('type')
        if mmcv.is_str(obj_type):
            if albumentations is None:
                raise RuntimeError('albumentations is not installed')
            obj_cls = getattr(albumentations, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        # Container transforms (e.g. ``OneOf``) carry nested configs that
        # must be built recursively before the container is constructed.
        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]

        return obj_cls(**args)

    @staticmethod
    def mapper(d, keymap):
        """Dictionary mapper. Renames keys according to keymap provided.

        Keys that are absent from ``keymap`` are kept unchanged.

        Args:
            d (dict): old dict
            keymap (dict): {'old_key':'new_key'}

        Returns:
            dict: new dict.
        """

        return {keymap.get(k, k): v for k, v in d.items()}

    def __call__(self, results):
        """Apply the composed augmentation to ``results``.

        Args:
            results (dict): Pipeline result dict; its keys are renamed with
                the keymap before being passed to albumentations.

        Returns:
            dict: A new dict with the augmented values under the original
                key names.
        """
        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)

        results = self.aug(**results)

        if 'gt_labels' in results:
            if isinstance(results['gt_labels'], list):
                results['gt_labels'] = np.array(results['gt_labels'])
            results['gt_labels'] = results['gt_labels'].astype(np.int64)

        # back to the original format
        results = self.mapper(results, self.keymap_back)

        # update final shape
        if self.update_pad_shape:
            results['pad_shape'] = results['img'].shape

        return results

    def __repr__(self):
        """Return the class name together with the configured transforms."""
        repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
        return repr_str
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
-r requirements/runtime.txt
|
||||
-r requirements/optional.txt
|
||||
-r requirements/tests.txt
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
albumentations>=0.3.2
|
|
@ -780,3 +780,27 @@ def test_randomflip():
|
|||
flipped_img = np.array(flipped_img)
|
||||
assert np.equal(results['img'], results['img2']).all()
|
||||
assert np.equal(results['img'], flipped_img).all()
|
||||
|
||||
|
||||
def test_albu_transform():
    """Run load -> Albu(ChannelShuffle) -> Normalize and check the output."""
    results = dict(
        img_prefix=osp.join(osp.dirname(__file__), '../data'),
        img_info=dict(filename='color.jpg'))

    # Define simple pipeline
    load = dict(type='LoadImageFromFile')
    load = build_from_cfg(load, PIPELINES)

    albu_transform = dict(
        type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
    albu_transform = build_from_cfg(albu_transform, PIPELINES)

    # A zero std presumably makes Normalize divide by zero and fill the image
    # with inf/NaN; a unit std keeps the output finite.
    normalize = dict(type='Normalize', mean=[0] * 3, std=[1] * 3, to_rgb=True)
    normalize = build_from_cfg(normalize, PIPELINES)

    # Execute transforms
    results = load(results)
    loaded_shape = results['img'].shape
    results = albu_transform(results)
    results = normalize(results)

    assert results['img'].dtype == np.float32
    # Neither a channel shuffle nor normalization should change the shape.
    assert results['img'].shape == loaded_shape
|
||||
|
|
Loading…
Reference in New Issue