add torchvisionwrapper

pull/1178/head
gaotongxiao 2022-05-19 19:31:24 +08:00
parent 0b5d2df310
commit 7b6778c5d8
6 changed files with 100 additions and 237 deletions

View File

@ -14,21 +14,18 @@ from .processing import (PyramidRescale, RandomRotate, Resize,
from .test_time_aug import MultiRotateAugOCR
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
TextSnakeTargets)
from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper
from .transforms import (ColorJitter, RandomCropInstances,
RandomCropPolyInstances, RandomScaling,
ScaleAspectJitter, SquareResizePad)
from .wrappers import ImgAug
from .transforms import (RandomCropInstances, RandomCropPolyInstances,
RandomScaling, ScaleAspectJitter, SquareResizePad)
from .wrappers import ImgAug, TorchVisionWrapper
__all__ = [
'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
'ToTensorOCR', 'DBNetTargets', 'PANetTargets', 'ColorJitter',
'RandomCropInstances', 'RandomRotate', 'ScaleAspectJitter',
'MultiRotateAugOCR', 'OCRSegTargets', 'FancyPCA',
'RandomCropPolyInstances', 'RandomPaddingOCR', 'ImgAug', 'EastRandomCrop',
'RandomRotateImageBox', 'OpencvToPil', 'PilToOpencv', 'SquareResizePad',
'TextSnakeTargets', 'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8',
'FCENetTargets', 'RandomScaling', 'TextDetRandomCropFlip', 'NerTransform',
'ToTensorNER', 'ResizeNoImg', 'PyramidRescale', 'OneOfWrapper',
'RandomWrapper', 'TorchVisionWrapper', 'LoadImageFromLMDB', 'Resize'
'ToTensorOCR', 'DBNetTargets', 'PANetTargets', 'RandomCropInstances',
'RandomRotate', 'ScaleAspectJitter', 'MultiRotateAugOCR', 'OCRSegTargets',
'FancyPCA', 'RandomCropPolyInstances', 'RandomPaddingOCR', 'ImgAug',
'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil', 'PilToOpencv',
'SquareResizePad', 'TextSnakeTargets', 'sort_vertex',
'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets', 'RandomScaling',
'TextDetRandomCropFlip', 'NerTransform', 'ToTensorNER', 'ResizeNoImg',
'PyramidRescale', 'TorchVisionWrapper', 'LoadImageFromLMDB', 'Resize'
]

View File

@ -1,128 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import random
import mmcv
import numpy as np
import torchvision.transforms as torchvision_transforms
from mmdet.datasets.pipelines import Compose
from PIL import Image
from mmocr.registry import TRANSFORMS
@TRANSFORMS.register_module()
class OneOfWrapper:
"""Randomly select and apply one of the transforms, each with the equal
chance.
Warning:
Different from albumentations, this wrapper only runs the selected
transform, but doesn't guarantee the transform can always be applied to
the input if the transform comes with a probability to run.
Args:
transforms (list[dict|callable]): Candidate transforms to be applied.
"""
def __init__(self, transforms):
assert isinstance(transforms, list) or isinstance(transforms, tuple)
assert len(transforms) > 0, 'Need at least one transform.'
self.transforms = []
for t in transforms:
if isinstance(t, dict):
self.transforms.append(TRANSFORMS.build(t))
elif callable(t):
self.transforms.append(t)
else:
raise TypeError('transform must be callable or a dict')
def __call__(self, results):
return random.choice(self.transforms)(results)
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms})'
return repr_str
@TRANSFORMS.register_module()
class RandomWrapper:
"""Run a transform or a sequence of transforms with probability p.
Args:
transforms (list[dict|callable]): Transform(s) to be applied.
p (int|float): Probability of running transform(s).
"""
def __init__(self, transforms, p):
assert 0 <= p <= 1
self.transforms = Compose(transforms)
self.p = p
def __call__(self, results):
return results if np.random.uniform() > self.p else self.transforms(
results)
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
repr_str += f'p={self.p})'
return repr_str
@TRANSFORMS.register_module()
class TorchVisionWrapper:
"""A wrapper of torchvision trasnforms. It applies specific transform to
``img`` and updates ``img_shape`` accordingly.
Warning:
This transform only affects the image but not its associated
annotations, such as word bounding boxes and polygon masks. Therefore,
it may only be applicable to text recognition tasks.
Args:
op (str): The name of any transform class in
:func:`torchvision.transforms`.
**kwargs: Arguments that will be passed to initializer of torchvision
transform.
:Required Keys:
- | ``img`` (ndarray): The input image.
:Affected Keys:
:Modified:
- | ``img`` (ndarray): The modified image.
:Added:
- | ``img_shape`` (tuple(int)): Size of the modified image.
"""
def __init__(self, op, **kwargs):
assert type(op) is str
if mmcv.is_str(op):
obj_cls = getattr(torchvision_transforms, op)
elif inspect.isclass(op):
obj_cls = op
else:
raise TypeError(
f'type must be a str or valid type, but got {type(type)}')
self.transform = obj_cls(**kwargs)
self.kwargs = kwargs
def __call__(self, results):
assert 'img' in results
# BGR -> RGB
img = results['img'][..., ::-1]
img = Image.fromarray(img)
img = self.transform(img)
img = np.asarray(img)
img = img[..., ::-1]
results['img'] = img
results['img_shape'] = img.shape
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(transform={self.transform})'
return repr_str

View File

@ -3,10 +3,8 @@ import math
import mmcv
import numpy as np
import torchvision.transforms as transforms
from mmdet.core import BitmapMasks, PolygonMasks
from mmdet.datasets.pipelines.transforms import Resize
from PIL import Image
import mmocr.core.evaluation.utils as eval_utils
from mmocr.registry import TRANSFORMS
@ -174,29 +172,6 @@ class RandomCropInstances:
return repr_str
@TRANSFORMS.register_module()
class ColorJitter:
"""An interface for torch color jitter so that it can be invoked in
mmdetection pipeline."""
def __init__(self, **kwargs):
self.transform = transforms.ColorJitter(**kwargs)
def __call__(self, results):
# img is bgr
img = results['img'][..., ::-1]
img = Image.fromarray(img)
img = self.transform(img)
img = np.asarray(img)
img = img[..., ::-1]
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@TRANSFORMS.register_module()
class ScaleAspectJitter(Resize):
"""Resize image and segmentation mask encoded by coordinates.
@ -286,41 +261,6 @@ class ScaleAspectJitter(Resize):
results['scale_idx'] = None
@TRANSFORMS.register_module()
class AffineJitter:
"""An interface for torchvision random affine so that it can be invoked in
mmdet pipeline."""
def __init__(self,
degrees=4,
translate=(0.02, 0.04),
scale=(0.9, 1.1),
shear=None,
resample=False,
fillcolor=0):
self.transform = transforms.RandomAffine(
degrees=degrees,
translate=translate,
scale=scale,
shear=shear,
resample=resample,
fillcolor=fillcolor)
def __call__(self, results):
# img is bgr
img = results['img'][..., ::-1]
img = Image.fromarray(img)
img = self.transform(img)
img = np.asarray(img)
img = img[..., ::-1]
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@TRANSFORMS.register_module()
class RandomCropPolyInstances:
"""Randomly crop images and make sure to contain at least one intact

View File

@ -4,7 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import imgaug
import imgaug.augmenters as iaa
import numpy as np
import torchvision.transforms as torchvision_transforms
from mmcv.transforms.base import BaseTransform
from PIL import Image
from mmocr.registry import TRANSFORMS
@ -233,3 +235,66 @@ class ImgAug(BaseTransform):
repr_str = self.__class__.__name__
repr_str += f'(args = {self.args})'
return repr_str
@TRANSFORMS.register_module()
class TorchVisionWrapper(BaseTransform):
"""A wrapper around torchvision trasnforms. It applies specific transform
to ``img`` and updates ``height`` and ``width`` accordingly.
Required Keys:
- img (ndarray): The input image.
Modified Keys:
- img (ndarray): The modified image.
- img_shape (tuple(int, int)): The shape of the image in (height, width).
Warning:
This transform only affects the image but not its associated
annotations, such as word bounding boxes and polygons. Therefore,
it may only be applicable to text recognition tasks.
Args:
op (str): The name of any transform class in
:func:`torchvision.transforms`.
**kwargs: Arguments that will be passed to initializer of torchvision
transform.
"""
def __init__(self, op: str, **kwargs) -> None:
assert isinstance(op, str)
obj_cls = getattr(torchvision_transforms, op)
self.torchvision = obj_cls(**kwargs)
self.op = op
self.kwargs = kwargs
def transform(self, results):
"""Transform the image.
Args:
results (dict): Result dict from the data loader.
Returns:
dict: Transformed results.
"""
assert 'img' in results
# BGR -> RGB
img = results['img'][..., ::-1]
img = Image.fromarray(img)
img = self.torchvision(img)
img = np.asarray(img)
img = img[..., ::-1]
results['img'] = img
results['img_shape'] = img.shape[:2]
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(op = {self.op}'
for k, v in self.kwargs.items():
repr_str += f', {k} = {v}'
repr_str += ')'
return repr_str

View File

@ -2,9 +2,7 @@
import unittest.mock as mock
import numpy as np
import torchvision.transforms as TF
from mmdet.core import BitmapMasks, PolygonMasks
from PIL import Image
import mmocr.datasets.pipelines.transforms as transforms
@ -131,38 +129,6 @@ def test_scale_aspect_jitter(mock_random):
assert results['scale'] == (650, 2600)
def test_color_jitter():
img = np.ones((64, 256, 3), dtype=np.uint8)
results = {'img': img}
pt_official_color_jitter = TF.ColorJitter()
output1 = pt_official_color_jitter(img)
color_jitter = transforms.ColorJitter()
output2 = color_jitter(results)
assert np.allclose(output1, output2['img'])
def test_affine_jitter():
img = np.ones((64, 256, 3), dtype=np.uint8)
results = {'img': img}
pt_official_affine_jitter = TF.RandomAffine(degrees=0)
output1 = pt_official_affine_jitter(Image.fromarray(img))
affine_jitter = transforms.AffineJitter(
degrees=0,
translate=None,
scale=None,
shear=None,
resample=False,
fillcolor=0)
output2 = affine_jitter(results)
assert np.allclose(np.array(output1), output2['img'])
def test_random_scale():
h, w, c = 100, 100, 3
img = np.ones((h, w, c), dtype=np.uint8)

View File

@ -6,7 +6,7 @@ from typing import Dict, List, Optional
import numpy as np
from shapely.geometry import Polygon
from mmocr.datasets.pipelines import ImgAug
from mmocr.datasets.pipelines import ImgAug, TorchVisionWrapper
class TestImgAug(unittest.TestCase):
@ -140,3 +140,26 @@ class TestImgAug(unittest.TestCase):
self.assertEqual(
repr(transform),
("ImgAug(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"))
class TestTorchVisionWrapper(unittest.TestCase):
def test_transform(self):
x = {'img': np.ones((128, 100, 3), dtype=np.uint8)}
# object not found error
with self.assertRaises(Exception):
TorchVisionWrapper(op='NonExist')
with self.assertRaises(TypeError):
TorchVisionWrapper()
f = TorchVisionWrapper('Grayscale')
with self.assertRaises(AssertionError):
f({})
results = f(x)
assert results['img'].shape == (128, 100)
assert results['img_shape'] == (128, 100)
def test_repr(self):
f = TorchVisionWrapper('Grayscale', num_output_channels=3)
self.assertEqual(
repr(f),
'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')