Add albumentations (#45)

* Add Albu transform
* pre-commit
* Create optional.txt
* Update requirements.txt
* Update transforms.py

pull/48/head
parent 8d3acce307
commit 99115fddbc
@@ -64,7 +64,7 @@ Optional arguments:

Examples:

Assume that you have already downloaded the checkpoints to the directory `checkpoints/`.

Test ResNet-50 on ImageNet validation and evaluate the top-1 and top-5.

```shell
@@ -49,7 +49,7 @@ img_norm_cfg = dict(

train_pipeline = [
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Resize', size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
@@ -27,9 +27,9 @@ from .resnet import ResNet

class ResNet_CIFAR(ResNet):
    """ResNet backbone for CIFAR.

    short description of the backbone

    Args:
        depth(int): Network depth, from {18, 34, 50, 101, 152}.
        ...
@@ -45,7 +45,7 @@ class ResNet_CIFAR(ResNet):

    def init_weights(self, pretrained=None):
        pass  # override ResNet init_weights if necessary

    def train(self, mode=True):
        pass  # override ResNet train if necessary
```
@@ -77,7 +77,7 @@ To add a new neck, we mainly implement the `forward` function, which applies some

```python
import torch.nn as nn

from ..builder import NECKS


@NECKS.register_module()
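For orientation, a neck registered this way only needs a `forward` that reshapes the backbone features before they reach the head. Below is a minimal sketch loosely modeled on the GlobalAveragePooling neck referenced later in this diff; it is illustrative only, and the actual implementation in the repository may differ.

```python
# Illustrative sketch only; the real GlobalAveragePooling neck may differ.
import torch.nn as nn

from ..builder import NECKS


@NECKS.register_module()
class GlobalAveragePooling(nn.Module):
    """Pool each feature map down to one value per channel."""

    def __init__(self):
        super(GlobalAveragePooling, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

    def init_weights(self):
        pass  # no weights to initialize

    def forward(self, inputs):
        # Backbones may return a single tensor or a tuple of feature maps.
        if isinstance(inputs, tuple):
            return tuple(self.gap(x).flatten(1) for x in inputs)
        return self.gap(inputs).flatten(1)
```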
@@ -117,11 +117,11 @@ To implement a new head, basically we need to implement `forward_train`, which t

```python
from ..builder import HEADS
from .cls_head import ClsHead


@HEADS.register_module()
class LinearClsHead(ClsHead):

    def __init__(self,
                 num_classes,
                 in_channels,
@@ -130,24 +130,24 @@ To implement a new head, basically we need to implement `forward_train`, which t

        super(LinearClsHead, self).__init__(loss=loss, topk=topk)
        self.in_channels = in_channels
        self.num_classes = num_classes

        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self._init_layers()

    def _init_layers(self):
        self.fc = nn.Linear(self.in_channels, self.num_classes)

    def init_weights(self):
        normal_init(self.fc, mean=0, std=0.01, bias=0)

    def forward_train(self, x, gt_label):
        cls_score = self.fc(x)
        losses = self.loss(cls_score, gt_label)
        return losses
```
@@ -178,37 +178,37 @@ Together with the added GlobalAveragePooling neck, an entire config for a model

        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
```

### Add new loss

To add a new loss function, we mainly implement the `forward` function in the loss module.
In addition, it is helpful to leverage the decorator `weighted_loss` to weight the loss for each element.
Assuming that we want to mimic a probabilistic distribution generated from another classification model, we implement an L1Loss to fulfill that purpose, as below.

1. Create a new file in `mmcls/models/losses/l1_loss.py`.

```python
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def l1_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@LOSSES.register_module()
class L1Loss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
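Once registered, the new loss could presumably be selected from a model config in the same way as the `CrossEntropyLoss` entry shown above. A hedged sketch follows; the `num_classes` and `in_channels` values are placeholders for illustration and are not taken from this diff.

```python
# Hypothetical head config using the newly registered L1Loss.
# num_classes / in_channels are placeholder values, not from this PR.
head = dict(
    type='LinearClsHead',
    num_classes=1000,
    in_channels=2048,
    loss=dict(type='L1Loss', loss_weight=1.0),
    topk=(1, 5),
)
```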
@@ -18,8 +18,8 @@ class CIFAR10(BaseDataset):

    """

    base_folder = 'cifar-10-batches-py'
-    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
-    filename = "cifar-10-python.tar.gz"
+    url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+    filename = 'cifar-10-python.tar.gz'
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
@@ -110,8 +110,8 @@ class CIFAR100(CIFAR10):

    """

    base_folder = 'cifar-100-python'
-    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
-    filename = "cifar-100-python.tar.gz"
+    url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
+    filename = 'cifar-100-python.tar.gz'
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
@@ -1,3 +1,4 @@

+import inspect
import math
import random
@@ -6,6 +7,13 @@ import numpy as np

from ..builder import PIPELINES

+try:
+    import albumentations
+    from albumentations import Compose
+except ImportError:
+    albumentations = None
+    Compose = None


@PIPELINES.register_module()
class RandomCrop(object):
@@ -155,8 +163,8 @@ class RandomResizedCrop(object):

        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
-            raise ValueError("range should be of kind (min, max). "
-                             f"But received {scale}")
+            raise ValueError('range should be of kind (min, max). '
+                             f'But received {scale}')
        if backend not in ['cv2', 'pillow']:
            raise ValueError(f'backend: {backend} is not supported for resize.'
                             'Supported backends are "cv2", "pillow"')
@@ -363,8 +371,8 @@ class Resize(object):

        assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
        if size[1] == -1:
            self.resize_w_short_side = True
-        assert interpolation in ("nearest", "bilinear", "bicubic", "area",
-                                 "lanczos")
+        assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
+                                 'lanczos')
        if backend not in ['cv2', 'pillow']:
            raise ValueError(f'backend: {backend} is not supported for resize.'
                             'Supported backends are "cv2", "pillow"')
@@ -486,3 +494,131 @@ class Normalize(object):

        repr_str += f'std={list(self.std)}, '
        repr_str += f'to_rgb={self.to_rgb})'
        return repr_str


@PIPELINES.register_module()
class Albu(object):
    """Albumentation augmentation.

    Adds custom transformations from the Albumentations library.
    Please visit `https://albumentations.readthedocs.io` for more information.

    An example of ``transforms`` is as follows:

    .. code-block::

        [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]

    Args:
        transforms (list[dict]): A list of Albumentations transformations.
        keymap (dict): Maps {'input key': 'albumentations-style key'}.
    """

    def __init__(self, transforms, keymap=None, update_pad_shape=False):
        if Compose is None:
            raise RuntimeError('albumentations is not installed')

        self.transforms = transforms
        self.filter_lost_elements = False
        self.update_pad_shape = update_pad_shape

        self.aug = Compose([self.albu_builder(t) for t in self.transforms])

        if not keymap:
            self.keymap_to_albu = {
                'img': 'image',
            }
        else:
            self.keymap_to_albu = keymap
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    def albu_builder(self, cfg):
        """Build a transform object from an Albumentations config dict.

        It inherits some of :func:`build_from_cfg` logic.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".

        Returns:
            obj: The constructed object.
        """
        assert isinstance(cfg, dict) and 'type' in cfg
        args = cfg.copy()

        obj_type = args.pop('type')
        if mmcv.is_str(obj_type):
            if albumentations is None:
                raise RuntimeError('albumentations is not installed')
            obj_cls = getattr(albumentations, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]

        return obj_cls(**args)

    @staticmethod
    def mapper(d, keymap):
        """Dictionary mapper. Renames keys according to the keymap provided.

        Args:
            d (dict): old dict
            keymap (dict): {'old_key': 'new_key'}

        Returns:
            dict: new dict.
        """
        updated_dict = {}
        for k, v in zip(d.keys(), d.values()):
            new_k = keymap.get(k, k)
            updated_dict[new_k] = d[k]
        return updated_dict

    def __call__(self, results):
        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)

        results = self.aug(**results)

        if 'gt_labels' in results:
            if isinstance(results['gt_labels'], list):
                results['gt_labels'] = np.array(results['gt_labels'])
            results['gt_labels'] = results['gt_labels'].astype(np.int64)

        # back to the original format
        results = self.mapper(results, self.keymap_back)

        # update final shape
        if self.update_pad_shape:
            results['pad_shape'] = results['img'].shape

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
        return repr_str
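For reference, here is how the new transform would slot into a dataset pipeline config. This is a hedged sketch rather than code from this PR: the surrounding steps mirror the `train_pipeline` shown earlier in this diff, the Albumentations entries come from the docstring example above, and `img_norm_cfg` is assumed to be defined as in that config.

```python
# Sketch of a pipeline using the new Albu transform (not taken verbatim from this PR).
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Albu',
        transforms=[
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
        ]),
    dict(type='Normalize', **img_norm_cfg),  # img_norm_cfg assumed from the config above
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
]
```

Because the default keymap maps `img` to `image`, no explicit `keymap` argument is needed for an image-only pipeline like this.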
@@ -1,2 +1,3 @@

 -r requirements/runtime.txt
+-r requirements/optional.txt
 -r requirements/tests.txt
@@ -0,0 +1 @@

+albumentations>=0.3.2
@@ -780,3 +780,27 @@ def test_randomflip():

    flipped_img = np.array(flipped_img)
    assert np.equal(results['img'], results['img2']).all()
    assert np.equal(results['img'], flipped_img).all()


def test_albu_transform():
    results = dict(
        img_prefix=osp.join(osp.dirname(__file__), '../data'),
        img_info=dict(filename='color.jpg'))

    # Define simple pipeline
    load = dict(type='LoadImageFromFile')
    load = build_from_cfg(load, PIPELINES)

    albu_transform = dict(
        type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
    albu_transform = build_from_cfg(albu_transform, PIPELINES)

    normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
    normalize = build_from_cfg(normalize, PIPELINES)

    # Execute transforms
    results = load(results)
    results = albu_transform(results)
    results = normalize(results)

    assert results['img'].dtype == np.float32
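One caveat worth noting: since albumentations is only an optional requirement, this test presupposes that the package is installed. A hedged sketch of how the test could skip gracefully otherwise, using pytest's `importorskip` (this guard is not part of the PR):

```python
import pytest


def test_albu_transform():
    # Hypothetical guard, not in this PR: skip when the optional dependency is absent.
    pytest.importorskip('albumentations')
    ...  # rest of the test as above
```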