[Feature] Register torchvision transforms into mmcls (#1265)
* [Enhance] Add stochastic depth decay rule in resnet. (#1363)

  * add stochastic depth decay rule to drop path rate
  * add default value
  * update
  * pass ut
  * remove np
  * rebase
  * update ToPIL and ToNumpy
  * add readme
  * fix review suggestions
  * fix conflicts
  * fix lint
  * remove comments
  * remove useless code
  * update docstring
  * update doc API
  * update doc

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
parent 0826df8963
commit 99e48116aa
@ -61,8 +61,8 @@ Loading and Formatting

LoadImageFromFile
PackInputs
PackMultiTaskInputs
ToNumpy
ToPIL
PILToNumpy
NumpyToPIL
Transpose
Collect

@ -147,6 +147,88 @@ Transform Wrapper

.. module:: mmpretrain.models.utils.data_preprocessor

TorchVision Transforms
^^^^^^^^^^^^^^^^^^^^^^

We also provide all the transforms in TorchVision. You can use them as in the following examples:

**1. Use TorchVision Augs Surrounded by NumpyToPIL and PILToNumpy (Recommended)**

Wrap the TorchVision Augs between ``dict(type='NumpyToPIL', to_rgb=True)`` and ``dict(type='PILToNumpy', to_bgr=True)``:

.. code:: python

   train_pipeline = [
       dict(type='LoadImageFromFile'),
       dict(type='NumpyToPIL', to_rgb=True),  # from BGR in cv2 to RGB in PIL
       dict(type='torchvision/RandomResizedCrop', size=176),
       dict(type='PILToNumpy', to_bgr=True),  # from RGB in PIL to BGR in cv2
       dict(type='RandomFlip', prob=0.5, direction='horizontal'),
       dict(type='PackInputs'),
   ]

   data_preprocessor = dict(
       num_classes=1000,
       mean=[123.675, 116.28, 103.53],
       std=[58.395, 57.12, 57.375],
       to_rgb=True,  # convert BGR (from cv2) to RGB before normalization
   )
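
As a quick sanity check of the color handling above, the BGR → RGB → BGR round trip can be verified by building the two conversion transforms from the registry. A minimal sketch, assuming ``mmpretrain`` is installed so that the transforms below are registered:

.. code:: python

   import numpy as np

   from mmpretrain.registry import TRANSFORMS

   # A fake BGR image, as LoadImageFromFile (cv2 backend) would produce.
   img = np.random.randint(0, 256, (224, 224, 3), dtype='uint8')

   to_pil = TRANSFORMS.build(dict(type='NumpyToPIL', to_rgb=True))
   to_numpy = TRANSFORMS.build(dict(type='PILToNumpy', to_bgr=True))

   results = to_numpy(to_pil({'img': img}))
   assert np.array_equal(results['img'], img)  # the round trip is lossless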

**2. Use TorchVision Augs and ToTensor & Normalize**

Make sure the ``'img'`` has been converted from the BGR NumPy format to the PIL format before it is processed by the TorchVision Augs.

.. code:: python

   train_pipeline = [
       dict(type='LoadImageFromFile'),
       dict(type='NumpyToPIL', to_rgb=True),  # from BGR in cv2 to RGB in PIL
       dict(
           type='torchvision/RandomResizedCrop',
           size=176,
           interpolation='bilinear'),  # accepts str-format interpolation mode
       dict(type='torchvision/RandomHorizontalFlip', p=0.5),
       dict(
           type='torchvision/TrivialAugmentWide',
           interpolation='bilinear'),
       dict(type='torchvision/PILToTensor'),
       dict(type='torchvision/ConvertImageDtype', dtype=torch.float),
       dict(
           type='torchvision/Normalize',
           mean=(0.485, 0.456, 0.406),
           std=(0.229, 0.224, 0.225),
       ),
       dict(type='torchvision/RandomErasing', p=0.1),
       dict(type='PackInputs'),
   ]

   data_preprocessor = dict(num_classes=1000, mean=None, std=None, to_rgb=False)  # Normalize in the dataset pipeline
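
Note that because the transform wrapper converts string-typed ``dtype`` arguments (see ``_str_to_torch_dtype`` in the processing module below), the ``torch.float`` argument above can equivalently be written as a plain string, which avoids importing ``torch`` inside the config file:

.. code:: python

   dict(type='torchvision/ConvertImageDtype', dtype='float')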

**3. Use TorchVision Augs Except ToTensor & Normalize**

.. code:: python

   train_pipeline = [
       dict(type='LoadImageFromFile'),
       dict(type='NumpyToPIL', to_rgb=True),  # from BGR in cv2 to RGB in PIL
       dict(type='torchvision/RandomResizedCrop', size=176, interpolation='bilinear'),
       dict(type='torchvision/RandomHorizontalFlip', p=0.5),
       dict(type='torchvision/TrivialAugmentWide', interpolation='bilinear'),
       dict(type='PackInputs'),
   ]

   # here the Normalize params are for the RGB format
   data_preprocessor = dict(
       num_classes=1000,
       mean=[123.675, 116.28, 103.53],
       std=[58.395, 57.12, 57.375],
       to_rgb=False,
   )

Data Preprocessors
------------------

@ -8,8 +8,8 @@ from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform,
                           Equalize, GaussianBlur, Invert, Posterize,
                           RandAugment, Rotate, Sharpness, Shear, Solarize,
                           SolarizeAdd, Translate)
from .formatting import (Collect, PackInputs, PackMultiTaskInputs, ToNumpy,
                         ToPIL, Transpose)
from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
                         PILToNumpy, Transpose)
from .processing import (Albumentations, BEiTMaskGenerator, ColorJitter,
                         EfficientNetCenterCrop, EfficientNetRandomCrop,
                         Lighting, RandomCrop, RandomErasing,

@ -21,7 +21,7 @@ for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
    TRANSFORMS.register_module(module=t)

__all__ = [
    'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'RandomCrop',
    'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop',
    'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
    'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
    'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',

@ -2,6 +2,7 @@
from collections import defaultdict
from collections.abc import Sequence

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F

@ -256,55 +257,70 @@ class Transpose(BaseTransform):
        f'(keys={self.keys}, order={self.order})'


@TRANSFORMS.register_module()
class ToPIL(BaseTransform):
@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL'))
class NumpyToPIL(BaseTransform):
    """Convert the image from OpenCV format to :obj:`PIL.Image.Image`.

    **Required Keys:**

    - img
    - ``img``

    **Modified Keys:**

    - img
    - ``img``

    Args:
        to_rgb (bool): Whether to convert img to RGB. Defaults to False.
    """

    def transform(self, results):
    def __init__(self, to_rgb: bool = False) -> None:
        self.to_rgb = to_rgb

    def transform(self, results: dict) -> dict:
        """Method to convert images to :obj:`PIL.Image.Image`."""
        results['img'] = Image.fromarray(results['img'])
        img = results['img']
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img

        results['img'] = Image.fromarray(img)
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(to_rgb={self.to_rgb})'


@TRANSFORMS.register_module()
class ToNumpy(BaseTransform):
    """Convert object to :obj:`numpy.ndarray`.

@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy'))
class PILToNumpy(BaseTransform):
    """Convert img to :obj:`numpy.ndarray`.

    **Required Keys:**

    - ``*keys**``
    - ``img``

    **Modified Keys:**

    - ``*keys**``
    - ``img``

    Args:
        to_bgr (bool): Whether to convert img to BGR. Defaults to False.
        dtype (str, optional): The dtype of the converted numpy array.
            Defaults to None.
    """

    def __init__(self, keys, dtype=None):
        self.keys = keys
    def __init__(self, to_bgr: bool = False, dtype=None) -> None:
        self.to_bgr = to_bgr
        self.dtype = dtype

    def transform(self, results):
        """Method to convert object to :obj:`numpy.ndarray`."""
        for key in self.keys:
            results[key] = np.array(results[key], dtype=self.dtype)
    def transform(self, results: dict) -> dict:
        """Method to convert img to :obj:`numpy.ndarray`."""
        img = np.array(results['img'], dtype=self.dtype)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img

        results['img'] = img
        return results

    def __repr__(self):
    def __repr__(self) -> str:
        return self.__class__.__name__ + \
            f'(keys={self.keys}, dtype={self.dtype})'
            f'(to_bgr={self.to_bgr}, dtype={self.dtype})'


@TRANSFORMS.register_module()
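
Since the old names are registered as aliases alongside the new ones, existing configs keep working. A minimal sketch, assuming ``mmpretrain`` is installed:

.. code:: python

   from mmpretrain.registry import TRANSFORMS

   # 'ToPIL' is kept as an alias of NumpyToPIL for backward compatibility.
   t = TRANSFORMS.build(dict(type='ToPIL', to_rgb=True))
   print(repr(t))  # NumpyToPIL(to_rgb=True)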

@ -2,14 +2,19 @@
import inspect
import math
import numbers
import re
import traceback
from enum import EnumMeta
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import mmengine
import numpy as np
import torchvision
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from torchvision.transforms.transforms import InterpolationMode

from mmpretrain.registry import TRANSFORMS

@ -19,6 +24,92 @@ except ImportError:
    albumentations = None


def _str_to_torch_dtype(t: str):
    """Map a str-format dtype to torch.dtype."""
    import torch  # noqa: F401,F403
    return eval(f'torch.{t}')


def _interpolation_modes_from_str(t: str):
    """Map a str-format interpolation mode to InterpolationMode."""
    t = t.lower()
    inverse_modes_mapping = {
        'nearest': InterpolationMode.NEAREST,
        'bilinear': InterpolationMode.BILINEAR,
        'bicubic': InterpolationMode.BICUBIC,
        'box': InterpolationMode.BOX,
        'hamming': InterpolationMode.HAMMING,
        'lanczos': InterpolationMode.LANCZOS,
    }
    return inverse_modes_mapping[t]


def _wrapper_vision_transform_cls(vision_transform_cls, new_name):
    """Build a transform wrapper class for a specific torchvision transform to
    handle the different input types between torchvision.transforms and
    mmcls.datasets.transforms."""

    def new_init(self, *args, **kwargs):
        if 'interpolation' in kwargs and isinstance(kwargs['interpolation'],
                                                    str):
            kwargs['interpolation'] = _interpolation_modes_from_str(
                kwargs['interpolation'])
        if 'dtype' in kwargs and isinstance(kwargs['dtype'], str):
            kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype'])

        try:
            self.t = vision_transform_cls(*args, **kwargs)
        except TypeError as e:
            traceback.print_exc()
            raise TypeError(
                f'Error when initializing {vision_transform_cls}, please '
                f'check the arguments {args} and {kwargs}. \n{e}')

    def new_call(self, input):
        try:
            input['img'] = self.t(input['img'])
        except Exception as e:
            traceback.print_exc()
            raise Exception('Error when processing transform (`torchvision/'
                            f'{vision_transform_cls.__name__}`). \n{e}')
        return input

    def new_str(self):
        return str(self.t)

    new_transforms_cls = type(
        new_name, (),
        dict(__init__=new_init, __call__=new_call, __str__=new_str))
    return new_transforms_cls


def register_vision_transforms() -> List[str]:
    """Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS``
    registry.

    Returns:
        List[str]: A list of registered transforms' names.
    """
    vision_transforms = []
    for module_name in dir(torchvision.transforms):
        if not re.match('[A-Z]', module_name):
            # must start with a capital letter
            continue
        _transform = getattr(torchvision.transforms, module_name)
        if inspect.isclass(_transform) and callable(
                _transform) and not isinstance(_transform, EnumMeta):
            new_cls = _wrapper_vision_transform_cls(
                _transform, f'TorchVision{module_name}')
            TRANSFORMS.register_module(
                module=new_cls, name=f'torchvision/{module_name}')
            vision_transforms.append(f'torchvision/{module_name}')
    return vision_transforms


# register all the transforms in torchvision by using a transform wrapper
VISION_TRANSFORMS = register_vision_transforms()


@TRANSFORMS.register_module()
class RandomCrop(BaseTransform):
    """Crop the given Image at a random location.
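
With this in place, every registered ``torchvision/*`` transform receives and returns the usual results dict. A minimal usage sketch, assuming torchvision >= 0.8 so that ``GaussianBlur`` exists:

.. code:: python

   import numpy as np
   from PIL import Image

   from mmpretrain.registry import TRANSFORMS

   # The wrapper applies the inner torchvision transform to results['img'].
   blur = TRANSFORMS.build(dict(type='torchvision/GaussianBlur', kernel_size=3))
   results = {'img': Image.fromarray(
       np.random.randint(0, 256, (224, 224, 3), dtype='uint8'))}
   results = blur(results)  # results['img'] is now the blurred PIL image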

@ -143,6 +143,8 @@ class Res2Layer(Sequential):
            Default: dict(type='BN')
        scales (int): Scales used in Res2Net. Default: 4
        base_width (int): Basic width of each scale. Default: 26
        drop_path_rate (float or np.ndarray): Stochastic depth rate.
            Default: 0.
    """

    def __init__(self,

@ -156,9 +158,16 @@ class Res2Layer(Sequential):
                 norm_cfg=dict(type='BN'),
                 scales=4,
                 base_width=26,
                 drop_path_rate=0.0,
                 **kwargs):
        self.block = block

        if isinstance(drop_path_rate, float):
            drop_path_rate = [drop_path_rate] * num_blocks

        assert len(drop_path_rate) == num_blocks, \
            'Please check the length of drop_path_rate'

        downsample = None
        if stride != 1 or in_channels != out_channels:
            if avg_down:

@ -201,9 +210,10 @@ class Res2Layer(Sequential):
                scales=scales,
                base_width=base_width,
                stage_type='stage',
                drop_path_rate=drop_path_rate[0],
                **kwargs))
        in_channels = out_channels
        for _ in range(1, num_blocks):
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=in_channels,

@ -213,6 +223,7 @@ class Res2Layer(Sequential):
                    norm_cfg=norm_cfg,
                    scales=scales,
                    base_width=base_width,
                    drop_path_rate=drop_path_rate[i],
                    **kwargs))
        super(Res2Layer, self).__init__(*layers)

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,

@ -334,6 +334,8 @@ class ResLayer(nn.Sequential):
            layer. Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        drop_path_rate (float or list): Stochastic depth rate.
            Default: 0.
    """

    def __init__(self,

@ -346,10 +348,17 @@ class ResLayer(nn.Sequential):
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 drop_path_rate=0.0,
                 **kwargs):
        self.block = block
        self.expansion = get_expansion(block, expansion)

        if isinstance(drop_path_rate, float):
            drop_path_rate = [drop_path_rate] * num_blocks

        assert len(drop_path_rate) == num_blocks, \
            'Please check the length of drop_path_rate'

        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = []

@ -384,6 +393,7 @@ class ResLayer(nn.Sequential):
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                drop_path_rate=drop_path_rate[0],
                **kwargs))
        in_channels = out_channels
        for i in range(1, num_blocks):

@ -395,6 +405,7 @@ class ResLayer(nn.Sequential):
                    stride=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    drop_path_rate=drop_path_rate[i],
                    **kwargs))
        super(ResLayer, self).__init__(*layers)

@ -518,6 +529,13 @@ class ResNet(BaseBackbone):
        self.res_layers = []
        _in_channels = stem_channels
        _out_channels = base_channels * self.expansion

        # stochastic depth decay rule
        total_depth = sum(stage_blocks)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]

        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]

@ -534,9 +552,10 @@ class ResNet(BaseBackbone):
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                drop_path_rate=drop_path_rate)
                drop_path_rate=dpr[:num_blocks])
            _in_channels = _out_channels
            _out_channels *= 2
            dpr = dpr[num_blocks:]
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)
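
To make the decay rule concrete, here is a small illustrative sketch (not part of the diff) of how the per-block rates are generated and then sliced stage by stage, using the ResNet-50 stage configuration ``(3, 4, 6, 3)`` and ``drop_path_rate=0.1`` as an example:

.. code:: python

   import torch

   stage_blocks = (3, 4, 6, 3)  # ResNet-50
   drop_path_rate = 0.1

   # One linearly increasing rate per block across the whole network.
   total_depth = sum(stage_blocks)  # 16
   dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]

   # Each stage consumes the next `num_blocks` rates, so deeper blocks
   # are dropped with higher probability.
   for num_blocks in stage_blocks:
       stage_rates, dpr = dpr[:num_blocks], dpr[num_blocks:]
       print([round(r, 3) for r in stage_rates])
   # [0.0, 0.007, 0.013]
   # [0.02, 0.027, 0.033, 0.04]
   # [0.047, 0.053, 0.06, 0.067, 0.073, 0.08]
   # [0.087, 0.093, 0.1]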

@ -104,28 +104,47 @@ class TestToPIL(unittest.TestCase):
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], Image.Image)

        cfg = dict(type='ToPIL', to_rgb=True)
        transform = TRANSFORMS.build(cfg)

        data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}

        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], Image.Image)
        np.testing.assert_array_equal(
            np.array(results['img']), data['img'][:, :, ::-1])

    def test_repr(self):
        cfg = dict(type='ToPIL', to_rgb=True)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(repr(transform), 'NumpyToPIL(to_rgb=True)')


class TestToNumpy(unittest.TestCase):

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {
            'tensor': torch.tensor([1, 2, 3]),
            'Image': Image.open(img_path),
            'img': Image.open(img_path),
        }

        cfg = dict(type='ToNumpy', keys=['tensor', 'Image'], dtype='uint8')
        cfg = dict(type='ToNumpy')
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['tensor'], np.ndarray)
        self.assertEqual(results['tensor'].dtype, 'uint8')
        self.assertIsInstance(results['Image'], np.ndarray)
        self.assertEqual(results['Image'].dtype, 'uint8')
        self.assertIsInstance(results['img'], np.ndarray)
        self.assertEqual(results['img'].dtype, 'uint8')

        cfg = dict(type='ToNumpy', to_bgr=True)
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], np.ndarray)
        self.assertEqual(results['img'].dtype, 'uint8')
        np.testing.assert_array_equal(
            results['img'], np.array(data['img'])[:, :, ::-1])

    def test_repr(self):
        cfg = dict(type='ToNumpy', keys=['img'], dtype='uint8')
        cfg = dict(type='ToNumpy', to_bgr=True)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(repr(transform), "ToNumpy(keys=['img'], dtype=uint8)")
        self.assertEqual(
            repr(transform), 'PILToNumpy(to_bgr=True, dtype=None)')


class TestCollect(unittest.TestCase):

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import os.path as osp
import random
from unittest import TestCase
from unittest.mock import ANY, call, patch

@ -8,7 +9,14 @@ from unittest.mock import ANY, call, patch
import mmengine
import numpy as np
import pytest
import torch
import torchvision
from mmcv.transforms import Compose
from mmengine.utils import digit_version
from PIL import Image
from torchvision import transforms

from mmpretrain.datasets.transforms.processing import VISION_TRANSFORMS
from mmpretrain.registry import TRANSFORMS

try:

@ -864,3 +872,88 @@ class TestBEiTMaskGenerator(TestCase):
            repr(transform), 'BEiTMaskGenerator(height=14, width=14, '
            'num_patches=196, num_masking_patches=75, min_num_patches=16, '
            f'max_num_patches=75, log_aspect_ratio={log_aspect_ratio})')


class TestVisionTransformWrapper(TestCase):

    def test_register(self):
        for t in VISION_TRANSFORMS:
            self.assertIn('torchvision/', t)
            self.assertIn(t, TRANSFORMS)

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {'img': Image.open(img_path)}

        # test normal transform
        vision_trans = transforms.RandomResizedCrop(224)
        vision_transformed_img = vision_trans(data['img'])
        mmcls_trans = TRANSFORMS.build(
            dict(type='torchvision/RandomResizedCrop', size=224))
        mmcls_transformed_img = mmcls_trans(data)['img']
        np.equal(
            np.array(vision_transformed_img), np.array(mmcls_transformed_img))

        # test convert type dtype
        data = {'img': torch.randn(3, 224, 224)}
        vision_trans = transforms.ConvertImageDtype(torch.float)
        vision_transformed_img = vision_trans(data['img'])
        mmcls_trans = TRANSFORMS.build(
            dict(type='torchvision/ConvertImageDtype', dtype='float'))
        mmcls_transformed_img = mmcls_trans(data)['img']
        np.equal(
            np.array(vision_transformed_img), np.array(mmcls_transformed_img))

        # test transform with interpolation
        data = {'img': Image.open(img_path)}
        if digit_version(torchvision.__version__) > digit_version('0.8.0'):
            from torchvision.transforms import InterpolationMode
            interpolation_t = InterpolationMode.NEAREST
        else:
            interpolation_t = Image.NEAREST
        vision_trans = transforms.Resize(224, interpolation_t)
        vision_transformed_img = vision_trans(data['img'])
        mmcls_trans = TRANSFORMS.build(
            dict(type='torchvision/Resize', size=224, interpolation='nearest'))
        mmcls_transformed_img = mmcls_trans(data)['img']
        np.equal(
            np.array(vision_transformed_img), np.array(mmcls_transformed_img))

        # test compose transforms
        data = {'img': Image.open(img_path)}
        vision_trans = transforms.Compose([
            transforms.Resize(176),
            transforms.RandomHorizontalFlip(),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        vision_transformed_img = vision_trans(data['img'])

        pipeline_cfg = [
            dict(type='LoadImageFromFile'),
            dict(type='NumpyToPIL', to_rgb=True),
            dict(type='torchvision/Resize', size=176),
            dict(type='torchvision/RandomHorizontalFlip'),
            dict(type='torchvision/PILToTensor'),
            dict(type='torchvision/ConvertImageDtype', dtype='float'),
            dict(
                type='torchvision/Normalize',
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            )
        ]
        pipeline = [TRANSFORMS.build(t) for t in pipeline_cfg]
        mmcls_trans = Compose(transforms=pipeline)
        mmcls_data = {'img_path': img_path}
        mmcls_transformed_img = mmcls_trans(mmcls_data)['img']
        np.equal(
            np.array(vision_transformed_img), np.array(mmcls_transformed_img))

    def test_repr(self):
        vision_trans = transforms.RandomResizedCrop(224)
        mmcls_trans = TRANSFORMS.build(
            dict(type='torchvision/RandomResizedCrop', size=224))

        self.assertEqual(str(vision_trans), str(mmcls_trans))