Add albumentations (#45)

* Add Albu transform

* pre-commit

* Create optional.txt

* Update requirements.txt

* Update transforms.py
pull/48/head
David de la Iglesia Castro 2020-09-22 11:35:39 +02:00 committed by GitHub
parent 8d3acce307
commit 99115fddbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 192 additions and 30 deletions

View File

@ -18,8 +18,8 @@ class CIFAR10(BaseDataset):
"""
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
filename = 'cifar-10-python.tar.gz'
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
@ -110,8 +110,8 @@ class CIFAR100(CIFAR10):
"""
base_folder = 'cifar-100-python'
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
filename = 'cifar-100-python.tar.gz'
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],

View File

@ -1,3 +1,4 @@
import inspect
import math
import random
@ -6,6 +7,13 @@ import numpy as np
from ..builder import PIPELINES
try:
import albumentations
from albumentations import Compose
except ImportError:
albumentations = None
Compose = None
@PIPELINES.register_module()
class RandomCrop(object):
@ -155,8 +163,8 @@ class RandomResizedCrop(object):
else:
self.size = (size, size)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
raise ValueError("range should be of kind (min, max). "
f"But received {scale}")
raise ValueError('range should be of kind (min, max). '
f'But received {scale}')
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported for resize.'
'Supported backends are "cv2", "pillow"')
@ -363,8 +371,8 @@ class Resize(object):
assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
if size[1] == -1:
self.resize_w_short_side = True
assert interpolation in ("nearest", "bilinear", "bicubic", "area",
"lanczos")
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
'lanczos')
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported for resize.'
'Supported backends are "cv2", "pillow"')
@ -486,3 +494,131 @@ class Normalize(object):
repr_str += f'std={list(self.std)}, '
repr_str += f'to_rgb={self.to_rgb})'
return repr_str
@PIPELINES.register_module()
class Albu(object):
"""Albumentation augmentation.
Adds custom transformations from Albumentations library.
Please, visit `https://albumentations.readthedocs.io`
to get more information.
An example of ``transforms`` is as followed:
.. code-block::
[
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=0.5),
dict(
type='RandomBrightnessContrast',
brightness_limit=[0.1, 0.3],
contrast_limit=[0.1, 0.3],
p=0.2),
dict(type='ChannelShuffle', p=0.1),
dict(
type='OneOf',
transforms=[
dict(type='Blur', blur_limit=3, p=1.0),
dict(type='MedianBlur', blur_limit=3, p=1.0)
],
p=0.1),
]
Args:
transforms (list[dict]): A list of albu transformations
keymap (dict): Contains {'input key':'albumentation-style key'}
"""
def __init__(self, transforms, keymap=None, update_pad_shape=False):
if Compose is None:
raise RuntimeError('albumentations is not installed')
self.transforms = transforms
self.filter_lost_elements = False
self.update_pad_shape = update_pad_shape
self.aug = Compose([self.albu_builder(t) for t in self.transforms])
if not keymap:
self.keymap_to_albu = {
'img': 'image',
}
else:
self.keymap_to_albu = keymap
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
def albu_builder(self, cfg):
"""Import a module from albumentations.
It inherits some of :func:`build_from_cfg` logic.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
Returns:
obj: The constructed object.
"""
assert isinstance(cfg, dict) and 'type' in cfg
args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
if albumentations is None:
raise RuntimeError('albumentations is not installed')
obj_cls = getattr(albumentations, obj_type)
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
if 'transforms' in args:
args['transforms'] = [
self.albu_builder(transform)
for transform in args['transforms']
]
return obj_cls(**args)
@staticmethod
def mapper(d, keymap):
"""Dictionary mapper. Renames keys according to keymap provided.
Args:
d (dict): old dict
keymap (dict): {'old_key':'new_key'}
Returns:
dict: new dict.
"""
updated_dict = {}
for k, v in zip(d.keys(), d.values()):
new_k = keymap.get(k, k)
updated_dict[new_k] = d[k]
return updated_dict
def __call__(self, results):
# dict to albumentations format
results = self.mapper(results, self.keymap_to_albu)
results = self.aug(**results)
if 'gt_labels' in results:
if isinstance(results['gt_labels'], list):
results['gt_labels'] = np.array(results['gt_labels'])
results['gt_labels'] = results['gt_labels'].astype(np.int64)
# back to the original format
results = self.mapper(results, self.keymap_back)
# update final shape
if self.update_pad_shape:
results['pad_shape'] = results['img'].shape
return results
def __repr__(self):
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
return repr_str

View File

@ -1,2 +1,3 @@
-r requirements/runtime.txt
-r requirements/optional.txt
-r requirements/tests.txt

View File

@ -0,0 +1 @@
albumentations>=0.3.2

View File

@ -780,3 +780,27 @@ def test_randomflip():
flipped_img = np.array(flipped_img)
assert np.equal(results['img'], results['img2']).all()
assert np.equal(results['img'], flipped_img).all()
def test_albu_transform():
results = dict(
img_prefix=osp.join(osp.dirname(__file__), '../data'),
img_info=dict(filename='color.jpg'))
# Define simple pipeline
load = dict(type='LoadImageFromFile')
load = build_from_cfg(load, PIPELINES)
albu_transform = dict(
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
albu_transform = build_from_cfg(albu_transform, PIPELINES)
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
normalize = build_from_cfg(normalize, PIPELINES)
# Execute transforms
results = load(results)
results = albu_transform(results)
results = normalize(results)
assert results['img'].dtype == np.float32