From 99115fddbc9e326856ebab743281ee908d8f6641 Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Tue, 22 Sep 2020 11:35:39 +0200 Subject: [PATCH] Add albumentations (#45) * Add Albu transform * pre-commit * Create optional.txt * Update requirements.txt * Update transforms.py --- docs/getting_started.md | 2 +- docs/tutorials/finetune.md | 2 +- docs/tutorials/new_modules.md | 40 +++---- mmcls/datasets/cifar.py | 8 +- mmcls/datasets/pipelines/transforms.py | 144 ++++++++++++++++++++++++- requirements.txt | 1 + requirements/optional.txt | 1 + tests/test_pipelines/test_transform.py | 24 +++++ 8 files changed, 192 insertions(+), 30 deletions(-) create mode 100644 requirements/optional.txt diff --git a/docs/getting_started.md b/docs/getting_started.md index 2ee36f66..75134e38 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -64,7 +64,7 @@ Optional arguments: Examples: -Assume that you have already downloaded the checkpoints to the directory `checkpoints/`. +Assume that you have already downloaded the checkpoints to the directory `checkpoints/`. Test ResNet-50 on ImageNet validation and evaluate the top-1 and top-5. ```shell diff --git a/docs/tutorials/finetune.md b/docs/tutorials/finetune.md index e1890c24..0895ec94 100644 --- a/docs/tutorials/finetune.md +++ b/docs/tutorials/finetune.md @@ -49,7 +49,7 @@ img_norm_cfg = dict( train_pipeline = [ dict(type='RandomCrop', size=32, padding=4), dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), - dict(type='Resize', size=224) + dict(type='Resize', size=224) dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='ToTensor', keys=['gt_label']), diff --git a/docs/tutorials/new_modules.md b/docs/tutorials/new_modules.md index b111b304..49974e32 100644 --- a/docs/tutorials/new_modules.md +++ b/docs/tutorials/new_modules.md @@ -27,9 +27,9 @@ from .resnet import ResNet class ResNet_CIFAR(ResNet): """ResNet backbone for CIFAR. - + short description of the backbone - + Args: depth(int): Network depth, from {18, 34, 50, 101, 152}. ... @@ -45,7 +45,7 @@ class ResNet_CIFAR(ResNet): def init_weights(self, pretrained=None): pass # override ResNet init_weights if necessary - + def train(self, mode=True): pass # override ResNet train if necessary ``` @@ -77,7 +77,7 @@ To add a new neck, we mainly implement the `forward` function, which applies som ```python import torch.nn as nn - + from ..builder import NECKS @NECKS.register_module() @@ -117,11 +117,11 @@ To implement a new head, basically we need to implement `forward_train`, which t ```python from ..builder import HEADS from .cls_head import ClsHead - - + + @HEADS.register_module() class LinearClsHead(ClsHead): - + def __init__(self, num_classes, in_channels, @@ -130,24 +130,24 @@ To implement a new head, basically we need to implement `forward_train`, which t super(LinearClsHead, self).__init__(loss=loss, topk=topk) self.in_channels = in_channels self.num_classes = num_classes - + if self.num_classes <= 0: raise ValueError( f'num_classes={num_classes} must be a positive integer') - + self._init_layers() - + def _init_layers(self): self.fc = nn.Linear(self.in_channels, self.num_classes) - + def init_weights(self): normal_init(self.fc, mean=0, std=0.01, bias=0) - + def forward_train(self, x, gt_label): cls_score = self.fc(x) losses = self.loss(cls_score, gt_label) return losses - + ``` @@ -178,37 +178,37 @@ Together with the added GlobalAveragePooling neck, an entire config for a model loss=dict(type='CrossEntropyLoss', loss_weight=1.0), topk=(1, 5), )) - + ``` ### Add new loss To add a new loss function, we mainly implement the `forward` function in the loss module. In addition, it is helpful to leverage the decorator `weighted_loss` to weight the loss for each element. -Assuming that we want to mimic a probablistic distribution generated from anther classification model, we implement a L1Loss to fulfil the purpose as below. +Assuming that we want to mimic a probablistic distribution generated from anther classification model, we implement a L1Loss to fulfil the purpose as below. 1. Create a new file in `mmcls/models/losses/l1_loss.py`. ```python import torch import torch.nn as nn - + from ..builder import LOSSES from .utils import weighted_loss - + @weighted_loss def l1_loss(pred, target): assert pred.size() == target.size() and target.numel() > 0 loss = torch.abs(pred - target) return loss - + @LOSSES.register_module() class L1Loss(nn.Module): - + def __init__(self, reduction='mean', loss_weight=1.0): super(L1Loss, self).__init__() self.reduction = reduction self.loss_weight = loss_weight - + def forward(self, pred, target, diff --git a/mmcls/datasets/cifar.py b/mmcls/datasets/cifar.py index a5dea0a0..ba7e7fc9 100644 --- a/mmcls/datasets/cifar.py +++ b/mmcls/datasets/cifar.py @@ -18,8 +18,8 @@ class CIFAR10(BaseDataset): """ base_folder = 'cifar-10-batches-py' - url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" - filename = "cifar-10-python.tar.gz" + url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + filename = 'cifar-10-python.tar.gz' tgz_md5 = 'c58f30108f718f92721af3b95e74349a' train_list = [ ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], @@ -110,8 +110,8 @@ class CIFAR100(CIFAR10): """ base_folder = 'cifar-100-python' - url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" - filename = "cifar-100-python.tar.gz" + url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' + filename = 'cifar-100-python.tar.gz' tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' train_list = [ ['train', '16019d7e3df5f24257cddd939b257f8d'], diff --git a/mmcls/datasets/pipelines/transforms.py b/mmcls/datasets/pipelines/transforms.py index 7242bcd2..2bd96bbc 100644 --- a/mmcls/datasets/pipelines/transforms.py +++ b/mmcls/datasets/pipelines/transforms.py @@ -1,3 +1,4 @@ +import inspect import math import random @@ -6,6 +7,13 @@ import numpy as np from ..builder import PIPELINES +try: + import albumentations + from albumentations import Compose +except ImportError: + albumentations = None + Compose = None + @PIPELINES.register_module() class RandomCrop(object): @@ -155,8 +163,8 @@ class RandomResizedCrop(object): else: self.size = (size, size) if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - raise ValueError("range should be of kind (min, max). " - f"But received {scale}") + raise ValueError('range should be of kind (min, max). ' + f'But received {scale}') if backend not in ['cv2', 'pillow']: raise ValueError(f'backend: {backend} is not supported for resize.' 'Supported backends are "cv2", "pillow"') @@ -363,8 +371,8 @@ class Resize(object): assert size[0] > 0 and (size[1] > 0 or size[1] == -1) if size[1] == -1: self.resize_w_short_side = True - assert interpolation in ("nearest", "bilinear", "bicubic", "area", - "lanczos") + assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area', + 'lanczos') if backend not in ['cv2', 'pillow']: raise ValueError(f'backend: {backend} is not supported for resize.' 'Supported backends are "cv2", "pillow"') @@ -486,3 +494,131 @@ class Normalize(object): repr_str += f'std={list(self.std)}, ' repr_str += f'to_rgb={self.to_rgb})' return repr_str + + +@PIPELINES.register_module() +class Albu(object): + """Albumentation augmentation. + + Adds custom transformations from Albumentations library. + Please, visit `https://albumentations.readthedocs.io` + to get more information. + An example of ``transforms`` is as followed: + + .. code-block:: + [ + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=0.5), + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + + Args: + transforms (list[dict]): A list of albu transformations + keymap (dict): Contains {'input key':'albumentation-style key'} + """ + + def __init__(self, transforms, keymap=None, update_pad_shape=False): + if Compose is None: + raise RuntimeError('albumentations is not installed') + + self.transforms = transforms + self.filter_lost_elements = False + self.update_pad_shape = update_pad_shape + + self.aug = Compose([self.albu_builder(t) for t in self.transforms]) + + if not keymap: + self.keymap_to_albu = { + 'img': 'image', + } + else: + self.keymap_to_albu = keymap + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg): + """Import a module from albumentations. + It inherits some of :func:`build_from_cfg` logic. + Args: + cfg (dict): Config dict. It should at least contain the key "type". + Returns: + obj: The constructed object. + """ + + assert isinstance(cfg, dict) and 'type' in cfg + args = cfg.copy() + + obj_type = args.pop('type') + if mmcv.is_str(obj_type): + if albumentations is None: + raise RuntimeError('albumentations is not installed') + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d, keymap): + """Dictionary mapper. Renames keys according to keymap provided. + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. + """ + + updated_dict = {} + for k, v in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + def __call__(self, results): + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + + results = self.aug(**results) + + if 'gt_labels' in results: + if isinstance(results['gt_labels'], list): + results['gt_labels'] = np.array(results['gt_labels']) + results['gt_labels'] = results['gt_labels'].astype(np.int64) + + # back to the original format + results = self.mapper(results, self.keymap_back) + + # update final shape + if self.update_pad_shape: + results['pad_shape'] = results['img'].shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(transforms={self.transforms})' + return repr_str diff --git a/requirements.txt b/requirements.txt index ec4ca05e..307c21c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ -r requirements/runtime.txt +-r requirements/optional.txt -r requirements/tests.txt diff --git a/requirements/optional.txt b/requirements/optional.txt new file mode 100644 index 00000000..e20fb3f6 --- /dev/null +++ b/requirements/optional.txt @@ -0,0 +1 @@ +albumentations>=0.3.2 diff --git a/tests/test_pipelines/test_transform.py b/tests/test_pipelines/test_transform.py index 2884d9ea..77baa30f 100644 --- a/tests/test_pipelines/test_transform.py +++ b/tests/test_pipelines/test_transform.py @@ -780,3 +780,27 @@ def test_randomflip(): flipped_img = np.array(flipped_img) assert np.equal(results['img'], results['img2']).all() assert np.equal(results['img'], flipped_img).all() + + +def test_albu_transform(): + results = dict( + img_prefix=osp.join(osp.dirname(__file__), '../data'), + img_info=dict(filename='color.jpg')) + + # Define simple pipeline + load = dict(type='LoadImageFromFile') + load = build_from_cfg(load, PIPELINES) + + albu_transform = dict( + type='Albu', transforms=[dict(type='ChannelShuffle', p=1)]) + albu_transform = build_from_cfg(albu_transform, PIPELINES) + + normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True) + normalize = build_from_cfg(normalize, PIPELINES) + + # Execute transforms + results = load(results) + results = albu_transform(results) + results = normalize(results) + + assert results['img'].dtype == np.float32