Add albumentations (#45)
* Add Albu transform * pre-commit * Create optional.txt * Update requirements.txt * Update transforms.pypull/48/head
parent
8d3acce307
commit
99115fddbc
|
@ -18,8 +18,8 @@ class CIFAR10(BaseDataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
base_folder = 'cifar-10-batches-py'
|
base_folder = 'cifar-10-batches-py'
|
||||||
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
|
||||||
filename = "cifar-10-python.tar.gz"
|
filename = 'cifar-10-python.tar.gz'
|
||||||
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
|
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
|
||||||
train_list = [
|
train_list = [
|
||||||
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
|
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
|
||||||
|
@ -110,8 +110,8 @@ class CIFAR100(CIFAR10):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
base_folder = 'cifar-100-python'
|
base_folder = 'cifar-100-python'
|
||||||
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
|
url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
|
||||||
filename = "cifar-100-python.tar.gz"
|
filename = 'cifar-100-python.tar.gz'
|
||||||
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
|
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
|
||||||
train_list = [
|
train_list = [
|
||||||
['train', '16019d7e3df5f24257cddd939b257f8d'],
|
['train', '16019d7e3df5f24257cddd939b257f8d'],
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import inspect
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
@ -6,6 +7,13 @@ import numpy as np
|
||||||
|
|
||||||
from ..builder import PIPELINES
|
from ..builder import PIPELINES
|
||||||
|
|
||||||
|
try:
|
||||||
|
import albumentations
|
||||||
|
from albumentations import Compose
|
||||||
|
except ImportError:
|
||||||
|
albumentations = None
|
||||||
|
Compose = None
|
||||||
|
|
||||||
|
|
||||||
@PIPELINES.register_module()
|
@PIPELINES.register_module()
|
||||||
class RandomCrop(object):
|
class RandomCrop(object):
|
||||||
|
@ -155,8 +163,8 @@ class RandomResizedCrop(object):
|
||||||
else:
|
else:
|
||||||
self.size = (size, size)
|
self.size = (size, size)
|
||||||
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
||||||
raise ValueError("range should be of kind (min, max). "
|
raise ValueError('range should be of kind (min, max). '
|
||||||
f"But received {scale}")
|
f'But received {scale}')
|
||||||
if backend not in ['cv2', 'pillow']:
|
if backend not in ['cv2', 'pillow']:
|
||||||
raise ValueError(f'backend: {backend} is not supported for resize.'
|
raise ValueError(f'backend: {backend} is not supported for resize.'
|
||||||
'Supported backends are "cv2", "pillow"')
|
'Supported backends are "cv2", "pillow"')
|
||||||
|
@ -363,8 +371,8 @@ class Resize(object):
|
||||||
assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
|
assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
|
||||||
if size[1] == -1:
|
if size[1] == -1:
|
||||||
self.resize_w_short_side = True
|
self.resize_w_short_side = True
|
||||||
assert interpolation in ("nearest", "bilinear", "bicubic", "area",
|
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
|
||||||
"lanczos")
|
'lanczos')
|
||||||
if backend not in ['cv2', 'pillow']:
|
if backend not in ['cv2', 'pillow']:
|
||||||
raise ValueError(f'backend: {backend} is not supported for resize.'
|
raise ValueError(f'backend: {backend} is not supported for resize.'
|
||||||
'Supported backends are "cv2", "pillow"')
|
'Supported backends are "cv2", "pillow"')
|
||||||
|
@ -486,3 +494,131 @@ class Normalize(object):
|
||||||
repr_str += f'std={list(self.std)}, '
|
repr_str += f'std={list(self.std)}, '
|
||||||
repr_str += f'to_rgb={self.to_rgb})'
|
repr_str += f'to_rgb={self.to_rgb})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module()
|
||||||
|
class Albu(object):
|
||||||
|
"""Albumentation augmentation.
|
||||||
|
|
||||||
|
Adds custom transformations from Albumentations library.
|
||||||
|
Please, visit `https://albumentations.readthedocs.io`
|
||||||
|
to get more information.
|
||||||
|
An example of ``transforms`` is as followed:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
[
|
||||||
|
dict(
|
||||||
|
type='ShiftScaleRotate',
|
||||||
|
shift_limit=0.0625,
|
||||||
|
scale_limit=0.0,
|
||||||
|
rotate_limit=0,
|
||||||
|
interpolation=1,
|
||||||
|
p=0.5),
|
||||||
|
dict(
|
||||||
|
type='RandomBrightnessContrast',
|
||||||
|
brightness_limit=[0.1, 0.3],
|
||||||
|
contrast_limit=[0.1, 0.3],
|
||||||
|
p=0.2),
|
||||||
|
dict(type='ChannelShuffle', p=0.1),
|
||||||
|
dict(
|
||||||
|
type='OneOf',
|
||||||
|
transforms=[
|
||||||
|
dict(type='Blur', blur_limit=3, p=1.0),
|
||||||
|
dict(type='MedianBlur', blur_limit=3, p=1.0)
|
||||||
|
],
|
||||||
|
p=0.1),
|
||||||
|
]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transforms (list[dict]): A list of albu transformations
|
||||||
|
keymap (dict): Contains {'input key':'albumentation-style key'}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, transforms, keymap=None, update_pad_shape=False):
|
||||||
|
if Compose is None:
|
||||||
|
raise RuntimeError('albumentations is not installed')
|
||||||
|
|
||||||
|
self.transforms = transforms
|
||||||
|
self.filter_lost_elements = False
|
||||||
|
self.update_pad_shape = update_pad_shape
|
||||||
|
|
||||||
|
self.aug = Compose([self.albu_builder(t) for t in self.transforms])
|
||||||
|
|
||||||
|
if not keymap:
|
||||||
|
self.keymap_to_albu = {
|
||||||
|
'img': 'image',
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.keymap_to_albu = keymap
|
||||||
|
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
|
||||||
|
|
||||||
|
def albu_builder(self, cfg):
|
||||||
|
"""Import a module from albumentations.
|
||||||
|
It inherits some of :func:`build_from_cfg` logic.
|
||||||
|
Args:
|
||||||
|
cfg (dict): Config dict. It should at least contain the key "type".
|
||||||
|
Returns:
|
||||||
|
obj: The constructed object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert isinstance(cfg, dict) and 'type' in cfg
|
||||||
|
args = cfg.copy()
|
||||||
|
|
||||||
|
obj_type = args.pop('type')
|
||||||
|
if mmcv.is_str(obj_type):
|
||||||
|
if albumentations is None:
|
||||||
|
raise RuntimeError('albumentations is not installed')
|
||||||
|
obj_cls = getattr(albumentations, obj_type)
|
||||||
|
elif inspect.isclass(obj_type):
|
||||||
|
obj_cls = obj_type
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||||
|
|
||||||
|
if 'transforms' in args:
|
||||||
|
args['transforms'] = [
|
||||||
|
self.albu_builder(transform)
|
||||||
|
for transform in args['transforms']
|
||||||
|
]
|
||||||
|
|
||||||
|
return obj_cls(**args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mapper(d, keymap):
|
||||||
|
"""Dictionary mapper. Renames keys according to keymap provided.
|
||||||
|
Args:
|
||||||
|
d (dict): old dict
|
||||||
|
keymap (dict): {'old_key':'new_key'}
|
||||||
|
Returns:
|
||||||
|
dict: new dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
updated_dict = {}
|
||||||
|
for k, v in zip(d.keys(), d.values()):
|
||||||
|
new_k = keymap.get(k, k)
|
||||||
|
updated_dict[new_k] = d[k]
|
||||||
|
return updated_dict
|
||||||
|
|
||||||
|
def __call__(self, results):
|
||||||
|
# dict to albumentations format
|
||||||
|
results = self.mapper(results, self.keymap_to_albu)
|
||||||
|
|
||||||
|
results = self.aug(**results)
|
||||||
|
|
||||||
|
if 'gt_labels' in results:
|
||||||
|
if isinstance(results['gt_labels'], list):
|
||||||
|
results['gt_labels'] = np.array(results['gt_labels'])
|
||||||
|
results['gt_labels'] = results['gt_labels'].astype(np.int64)
|
||||||
|
|
||||||
|
# back to the original format
|
||||||
|
results = self.mapper(results, self.keymap_back)
|
||||||
|
|
||||||
|
# update final shape
|
||||||
|
if self.update_pad_shape:
|
||||||
|
results['pad_shape'] = results['img'].shape
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
|
||||||
|
return repr_str
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
-r requirements/runtime.txt
|
-r requirements/runtime.txt
|
||||||
|
-r requirements/optional.txt
|
||||||
-r requirements/tests.txt
|
-r requirements/tests.txt
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
albumentations>=0.3.2
|
|
@ -780,3 +780,27 @@ def test_randomflip():
|
||||||
flipped_img = np.array(flipped_img)
|
flipped_img = np.array(flipped_img)
|
||||||
assert np.equal(results['img'], results['img2']).all()
|
assert np.equal(results['img'], results['img2']).all()
|
||||||
assert np.equal(results['img'], flipped_img).all()
|
assert np.equal(results['img'], flipped_img).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_albu_transform():
|
||||||
|
results = dict(
|
||||||
|
img_prefix=osp.join(osp.dirname(__file__), '../data'),
|
||||||
|
img_info=dict(filename='color.jpg'))
|
||||||
|
|
||||||
|
# Define simple pipeline
|
||||||
|
load = dict(type='LoadImageFromFile')
|
||||||
|
load = build_from_cfg(load, PIPELINES)
|
||||||
|
|
||||||
|
albu_transform = dict(
|
||||||
|
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
|
||||||
|
albu_transform = build_from_cfg(albu_transform, PIPELINES)
|
||||||
|
|
||||||
|
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
|
||||||
|
normalize = build_from_cfg(normalize, PIPELINES)
|
||||||
|
|
||||||
|
# Execute transforms
|
||||||
|
results = load(results)
|
||||||
|
results = albu_transform(results)
|
||||||
|
results = normalize(results)
|
||||||
|
|
||||||
|
assert results['img'].dtype == np.float32
|
||||||
|
|
Loading…
Reference in New Issue