[Feature] Register torchvision transforms into mmcls (#1265)

* [Enhance] Add stochastic depth decay rule in resnet. (#1363)

* add stochastic depth decay rule to drop path rate

* add default value

* update

* pass ut

* update

* pass ut

* remove np

* rebase

* update ToPIL and ToNumpy

* rebase

* rebase

* rebase

* rebase

* add readme

* fix review suggestions

* rebase

* fix conflicts

* fix conflicts

* fix lint

* remove comments

* remove useless code

* update docstring

* update doc API

* update doc

---------

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Ezra-Yu 2023-04-13 18:05:57 +08:00 committed by GitHub
parent 0826df8963
commit 99e48116aa
8 changed files with 367 additions and 36 deletions


@@ -61,8 +61,8 @@ Loading and Formatting
LoadImageFromFile
PackInputs
PackMultiTaskInputs
ToNumpy
ToPIL
PILToNumpy
NumpyToPIL
Transpose
Collect
@@ -147,6 +147,88 @@ Transform Wrapper
.. module:: mmpretrain.models.utils.data_preprocessor
TorchVision Transforms
^^^^^^^^^^^^^^^^^^^^^^
We also provide all the transforms in TorchVision. You can use them like the following examples:
**1. Use some TorchVision Augs Surrounded by NumpyToPIL and PILToNumpy (Recommended)**
Wrap the TorchVision Augs between ``dict(type='NumpyToPIL', to_rgb=True)`` and ``dict(type='PILToNumpy', to_bgr=True)``:
.. code:: python
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='NumpyToPIL', to_rgb=True), # from BGR in cv2 to RGB in PIL
dict(type='torchvision/RandomResizedCrop', size=176),
dict(type='PILToNumpy', to_bgr=True), # from RGB in PIL to BGR in cv2
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
data_preprocessor = dict(
num_classes=1000,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True, # from BGR in cv2 to RGB in PIL
)
**2. Use TorchVision Augs and ToTensor&Normalize**
Make sure the 'img' has been converted from the BGR NumPy format to the PIL format before it is processed by the TorchVision Augs.
.. code:: python
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='NumpyToPIL', to_rgb=True), # from BGR in cv2 to RGB in PIL
dict(
type='torchvision/RandomResizedCrop',
size=176,
interpolation='bilinear'), # accept str format interpolation mode
dict(type='torchvision/RandomHorizontalFlip', p=0.5),
dict(
type='torchvision/TrivialAugmentWide',
interpolation='bilinear'),
dict(type='torchvision/PILToTensor'),
dict(type='torchvision/ConvertImageDtype', dtype=torch.float),
dict(
type='torchvision/Normalize',
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
),
dict(type='torchvision/RandomErasing', p=0.1),
dict(type='PackInputs'),
]
data_preprocessor = dict(num_classes=1000, mean=None, std=None, to_rgb=False) # Normalize in dataset pipeline
**3. Use TorchVision Augs Except ToTensor&Normalize**
.. code:: python
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='NumpyToPIL', to_rgb=True), # from BGR in cv2 to RGB in PIL
dict(type='torchvision/RandomResizedCrop', size=176, interpolation='bilinear'),
dict(type='torchvision/RandomHorizontalFlip', p=0.5),
dict(type='torchvision/TrivialAugmentWide', interpolation='bilinear'),
dict(type='PackInputs'),
]
# here the Normalize params are for the RGB format
data_preprocessor = dict(
num_classes=1000,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=False,
)
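The wrapped transforms can also be built and called directly from the registry, which is handy for a quick sanity check. Below is a minimal sketch (the random input image is only for illustration):

.. code:: python

    import numpy as np
    from PIL import Image

    from mmpretrain.registry import TRANSFORMS

    # the wrapper converts the str 'bilinear' into the corresponding
    # torchvision InterpolationMode enum before building the transform
    t = TRANSFORMS.build(
        dict(
            type='torchvision/RandomResizedCrop',
            size=176,
            interpolation='bilinear'))

    img = Image.fromarray(
        np.random.randint(0, 256, (256, 256, 3), dtype='uint8'))
    results = t({'img': img})  # the wrapper reads and writes results['img']
    print(results['img'].size)  # (176, 176)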
Data Preprocessors
------------------


@@ -8,8 +8,8 @@ from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform,
Equalize, GaussianBlur, Invert, Posterize,
RandAugment, Rotate, Sharpness, Shear, Solarize,
SolarizeAdd, Translate)
from .formatting import (Collect, PackInputs, PackMultiTaskInputs, ToNumpy,
ToPIL, Transpose)
from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
PILToNumpy, Transpose)
from .processing import (Albumentations, BEiTMaskGenerator, ColorJitter,
EfficientNetCenterCrop, EfficientNetRandomCrop,
Lighting, RandomCrop, RandomErasing,
@@ -21,7 +21,7 @@ for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
TRANSFORMS.register_module(module=t)
__all__ = [
'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'RandomCrop',
'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop',
'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',


@@ -2,6 +2,7 @@
from collections import defaultdict
from collections.abc import Sequence
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
@@ -256,55 +257,70 @@ class Transpose(BaseTransform):
f'(keys={self.keys}, order={self.order})'
@TRANSFORMS.register_module()
class ToPIL(BaseTransform):
@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL'))
class NumpyToPIL(BaseTransform):
"""Convert the image from OpenCV format to :obj:`PIL.Image.Image`.
**Required Keys:**
- img
- ``img``
**Modified Keys:**
- img
- ``img``
Args:
to_rgb (bool): Whether to convert img to RGB. Defaults to False.
"""
def transform(self, results):
def __init__(self, to_rgb: bool = False) -> None:
self.to_rgb = to_rgb
def transform(self, results: dict) -> dict:
"""Method to convert images to :obj:`PIL.Image.Image`."""
results['img'] = Image.fromarray(results['img'])
img = results['img']
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img
results['img'] = Image.fromarray(img)
return results
def __repr__(self) -> str:
return self.__class__.__name__ + f'(to_rgb={self.to_rgb})'
@TRANSFORMS.register_module()
class ToNumpy(BaseTransform):
"""Convert object to :obj:`numpy.ndarray`.
@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy'))
class PILToNumpy(BaseTransform):
"""Convert img to :obj:`numpy.ndarray`.
**Required Keys:**
- ``*keys**``
- ``img``
**Modified Keys:**
- ``*keys**``
- ``img``
Args:
to_bgr (bool): Whether to convert img to BGR. Defaults to False.
dtype (str, optional): The dtype of the converted numpy array.
Defaults to None.
"""
def __init__(self, keys, dtype=None):
self.keys = keys
def __init__(self, to_bgr: bool = False, dtype=None) -> None:
self.to_bgr = to_bgr
self.dtype = dtype
def transform(self, results):
"""Method to convert object to :obj:`numpy.ndarray`."""
for key in self.keys:
results[key] = np.array(results[key], dtype=self.dtype)
def transform(self, results: dict) -> dict:
"""Method to convert img to :obj:`numpy.ndarray`."""
img = np.array(results['img'], dtype=self.dtype)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img
results['img'] = img
return results
def __repr__(self):
def __repr__(self) -> str:
return self.__class__.__name__ + \
f'(keys={self.keys}, dtype={self.dtype})'
f'(to_bgr={self.to_bgr}, dtype={self.dtype})'
@TRANSFORMS.register_module()


@@ -2,14 +2,19 @@
import inspect
import math
import numbers
import re
import traceback
from enum import EnumMeta
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import mmengine
import numpy as np
import torchvision
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from torchvision.transforms.transforms import InterpolationMode
from mmpretrain.registry import TRANSFORMS
@@ -19,6 +24,92 @@ except ImportError:
albumentations = None
def _str_to_torch_dtype(t: str):
"""mapping str format dtype to torch.dtype."""
import torch # noqa: F401,F403
return eval(f'torch.{t}')
def _interpolation_modes_from_str(t: str):
"""mapping str format to Interpolation."""
t = t.lower()
inverse_modes_mapping = {
'nearest': InterpolationMode.NEAREST,
'bilinear': InterpolationMode.BILINEAR,
'bicubic': InterpolationMode.BICUBIC,
'box': InterpolationMode.BOX,
'hamming': InterpolationMode.HAMMING,
'lanczos': InterpolationMode.LANCZOS,
}
return inverse_modes_mapping[t]
def _wrapper_vision_transform_cls(vision_transform_cls, new_name):
"""Build a transform wrapper class for a specific torchvision.transform to
handle the different input types between torchvision.transforms and
mmcls.datasets.transforms."""
def new_init(self, *args, **kwargs):
if 'interpolation' in kwargs and isinstance(kwargs['interpolation'],
str):
kwargs['interpolation'] = _interpolation_modes_from_str(
kwargs['interpolation'])
if 'dtype' in kwargs and isinstance(kwargs['dtype'], str):
kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype'])
try:
self.t = vision_transform_cls(*args, **kwargs)
except TypeError as e:
traceback.print_exc()
raise TypeError(
f'Error when initializing {vision_transform_cls}, please '
f'check the arguments {args} and {kwargs}. \n{e}')
def new_call(self, input):
try:
input['img'] = self.t(input['img'])
except Exception as e:
traceback.print_exc()
raise Exception('Error when processing the transform (`torchvision/'
f'{vision_transform_cls.__name__}`). \n{e}')
return input
def new_str(self):
return str(self.t)
new_transforms_cls = type(
new_name, (),
dict(__init__=new_init, __call__=new_call, __str__=new_str))
return new_transforms_cls
def register_vision_transforms() -> List[str]:
"""Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS``
registry.
Returns:
List[str]: A list of registered transforms' name.
"""
vision_transforms = []
for module_name in dir(torchvision.transforms):
if not re.match('[A-Z]', module_name):
# must start with a capital letter
continue
_transform = getattr(torchvision.transforms, module_name)
if inspect.isclass(_transform) and callable(
_transform) and not isinstance(_transform, (EnumMeta)):
new_cls = _wrapper_vision_transform_cls(
_transform, f'TorchVision{module_name}')
TRANSFORMS.register_module(
module=new_cls, name=f'torchvision/{module_name}')
vision_transforms.append(f'torchvision/{module_name}')
return vision_transforms
# register all the transforms in torchvision by using a transform wrapper
VISION_TRANSFORMS = register_vision_transforms()
@TRANSFORMS.register_module()
class RandomCrop(BaseTransform):
"""Crop the given Image at a random location.


@@ -143,6 +143,8 @@ class Res2Layer(Sequential):
Default: dict(type='BN')
scales (int): Scales used in Res2Net. Default: 4
base_width (int): Basic width of each scale. Default: 26
drop_path_rate (float or list): stochastic depth rate.
Default: 0.
"""
def __init__(self,
@@ -156,9 +158,16 @@
norm_cfg=dict(type='BN'),
scales=4,
base_width=26,
drop_path_rate=0.0,
**kwargs):
self.block = block
if isinstance(drop_path_rate, float):
drop_path_rate = [drop_path_rate] * num_blocks
assert len(drop_path_rate) == num_blocks, \
    'Please check the length of drop_path_rate'
downsample = None
if stride != 1 or in_channels != out_channels:
if avg_down:
@@ -201,9 +210,10 @@
scales=scales,
base_width=base_width,
stage_type='stage',
drop_path_rate=drop_path_rate[0],
**kwargs))
in_channels = out_channels
for _ in range(1, num_blocks):
for i in range(1, num_blocks):
layers.append(
block(
in_channels=in_channels,
@@ -213,6 +223,7 @@
norm_cfg=norm_cfg,
scales=scales,
base_width=base_width,
drop_path_rate=drop_path_rate[i],
**kwargs))
super(Res2Layer, self).__init__(*layers)


@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
@@ -334,6 +334,8 @@ class ResLayer(nn.Sequential):
layer. Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
drop_path_rate (float or list): stochastic depth rate.
Default: 0.
"""
def __init__(self,
@@ -346,10 +348,17 @@
avg_down=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
drop_path_rate=0.0,
**kwargs):
self.block = block
self.expansion = get_expansion(block, expansion)
if isinstance(drop_path_rate, float):
drop_path_rate = [drop_path_rate] * num_blocks
assert len(drop_path_rate) == num_blocks, \
    'Please check the length of drop_path_rate'
downsample = None
if stride != 1 or in_channels != out_channels:
downsample = []
@@ -384,6 +393,7 @@
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
drop_path_rate=drop_path_rate[0],
**kwargs))
in_channels = out_channels
for i in range(1, num_blocks):
@@ -395,6 +405,7 @@
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
drop_path_rate=drop_path_rate[i],
**kwargs))
super(ResLayer, self).__init__(*layers)
@@ -518,6 +529,13 @@
self.res_layers = []
_in_channels = stem_channels
_out_channels = base_channels * self.expansion
# stochastic depth decay rule
total_depth = sum(stage_blocks)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
]
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
@@ -534,9 +552,10 @@
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
drop_path_rate=drop_path_rate)
drop_path_rate=dpr[:num_blocks])
_in_channels = _out_channels
_out_channels *= 2
dpr = dpr[num_blocks:]
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
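The decay rule above linearly scales the per-block drop rate over the total depth of the network, and each stage then consumes its slice of the schedule. A small sketch of the arithmetic (assuming ResNet-50's stage depths and ``drop_path_rate=0.1``):

.. code:: python

    import torch

    stage_blocks = [3, 4, 6, 3]  # ResNet-50
    drop_path_rate = 0.1
    dpr = [
        x.item() for x in torch.linspace(0, drop_path_rate, sum(stage_blocks))
    ]

    for num_blocks in stage_blocks:
        stage_dpr, dpr = dpr[:num_blocks], dpr[num_blocks:]
        print([round(r, 3) for r in stage_dpr])
    # [0.0, 0.007, 0.013]
    # [0.02, 0.027, 0.033, 0.04]
    # [0.047, 0.053, 0.06, 0.067, 0.073, 0.08]
    # [0.087, 0.093, 0.1]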


@@ -104,28 +104,47 @@ class TestToPIL(unittest.TestCase):
results = transform(copy.deepcopy(data))
self.assertIsInstance(results['img'], Image.Image)
cfg = dict(type='ToPIL', to_rgb=True)
transform = TRANSFORMS.build(cfg)
data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
results = transform(copy.deepcopy(data))
self.assertIsInstance(results['img'], Image.Image)
np.testing.assert_array_equal(
np.array(results['img']), data['img'][:, :, ::-1])
def test_repr(self):
cfg = dict(type='ToPIL', to_rgb=True)
transform = TRANSFORMS.build(cfg)
self.assertEqual(repr(transform), 'NumpyToPIL(to_rgb=True)')
class TestToNumpy(unittest.TestCase):
def test_transform(self):
img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
data = {
'tensor': torch.tensor([1, 2, 3]),
'Image': Image.open(img_path),
'img': Image.open(img_path),
}
cfg = dict(type='ToNumpy', keys=['tensor', 'Image'], dtype='uint8')
cfg = dict(type='ToNumpy')
transform = TRANSFORMS.build(cfg)
results = transform(copy.deepcopy(data))
self.assertIsInstance(results['tensor'], np.ndarray)
self.assertEqual(results['tensor'].dtype, 'uint8')
self.assertIsInstance(results['Image'], np.ndarray)
self.assertEqual(results['Image'].dtype, 'uint8')
self.assertIsInstance(results['img'], np.ndarray)
self.assertEqual(results['img'].dtype, 'uint8')
cfg = dict(type='ToNumpy', to_bgr=True)
transform = TRANSFORMS.build(cfg)
results = transform(copy.deepcopy(data))
self.assertIsInstance(results['img'], np.ndarray)
self.assertEqual(results['img'].dtype, 'uint8')
np.testing.assert_array_equal(
results['img'], np.array(data['img'])[:, :, ::-1])
def test_repr(self):
cfg = dict(type='ToNumpy', keys=['img'], dtype='uint8')
cfg = dict(type='ToNumpy', to_bgr=True)
transform = TRANSFORMS.build(cfg)
self.assertEqual(repr(transform), "ToNumpy(keys=['img'], dtype=uint8)")
self.assertEqual(
repr(transform), 'PILToNumpy(to_bgr=True, dtype=None)')
class TestCollect(unittest.TestCase):


@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import os.path as osp
import random
from unittest import TestCase
from unittest.mock import ANY, call, patch
@@ -8,7 +9,14 @@ from unittest.mock import ANY, call, patch
import mmengine
import numpy as np
import pytest
import torch
import torchvision
from mmcv.transforms import Compose
from mmengine.utils import digit_version
from PIL import Image
from torchvision import transforms
from mmpretrain.datasets.transforms.processing import VISION_TRANSFORMS
from mmpretrain.registry import TRANSFORMS
try:
@@ -864,3 +872,88 @@
repr(transform), 'BEiTMaskGenerator(height=14, width=14, '
'num_patches=196, num_masking_patches=75, min_num_patches=16, '
f'max_num_patches=75, log_aspect_ratio={log_aspect_ratio})')
class TestVisionTransformWrapper(TestCase):
def test_register(self):
for t in VISION_TRANSFORMS:
self.assertIn('torchvision/', t)
self.assertIn(t, TRANSFORMS)
def test_transform(self):
img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
data = {'img': Image.open(img_path)}
# test normal transform
vision_trans = transforms.RandomResizedCrop(224)
vision_transformed_img = vision_trans(data['img'])
mmcls_trans = TRANSFORMS.build(
dict(type='torchvision/RandomResizedCrop', size=224))
mmcls_transformed_img = mmcls_trans(data)['img']
np.equal(
np.array(vision_transformed_img), np.array(mmcls_transformed_img))
# test convert type dtype
data = {'img': torch.randn(3, 224, 224)}
vision_trans = transforms.ConvertImageDtype(torch.float)
vision_transformed_img = vision_trans(data['img'])
mmcls_trans = TRANSFORMS.build(
dict(type='torchvision/ConvertImageDtype', dtype='float'))
mmcls_transformed_img = mmcls_trans(data)['img']
np.testing.assert_array_equal(
np.array(vision_transformed_img), np.array(mmcls_transformed_img))
# test transform with interpolation
data = {'img': Image.open(img_path)}
if digit_version(torchvision.__version__) > digit_version('0.8.0'):
from torchvision.transforms import InterpolationMode
interpolation_t = InterpolationMode.NEAREST
else:
interpolation_t = Image.NEAREST
vision_trans = transforms.Resize(224, interpolation_t)
vision_transformed_img = vision_trans(data['img'])
mmcls_trans = TRANSFORMS.build(
dict(type='torchvision/Resize', size=224, interpolation='nearest'))
mmcls_transformed_img = mmcls_trans(data)['img']
np.testing.assert_array_equal(
np.array(vision_transformed_img), np.array(mmcls_transformed_img))
# test compose transforms
data = {'img': Image.open(img_path)}
vision_trans = transforms.Compose([
transforms.Resize(176),
transforms.RandomHorizontalFlip(),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
vision_transformed_img = vision_trans(data['img'])
pipeline_cfg = [
dict(type='LoadImageFromFile'),
dict(type='NumpyToPIL', to_rgb=True),
dict(type='torchvision/Resize', size=176),
dict(type='torchvision/RandomHorizontalFlip'),
dict(type='torchvision/PILToTensor'),
dict(type='torchvision/ConvertImageDtype', dtype='float'),
dict(
type='torchvision/Normalize',
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
)
]
pipeline = [TRANSFORMS.build(t) for t in pipeline_cfg]
mmcls_trans = Compose(transforms=pipeline)
mmcls_data = {'img_path': img_path}
mmcls_transformed_img = mmcls_trans(mmcls_data)['img']
np.equal(
np.array(vision_transformed_img), np.array(mmcls_transformed_img))
def test_repr(self):
vision_trans = transforms.RandomResizedCrop(224)
mmcls_trans = TRANSFORMS.build(
dict(type='torchvision/RandomResizedCrop', size=224))
self.assertEqual(str(vision_trans), str(mmcls_trans))