mirror of https://github.com/open-mmlab/mmocr.git
add torchvisionwrapper
parent
0b5d2df310
commit
7b6778c5d8
|
@ -14,21 +14,18 @@ from .processing import (PyramidRescale, RandomRotate, Resize,
|
|||
from .test_time_aug import MultiRotateAugOCR
|
||||
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
|
||||
TextSnakeTargets)
|
||||
from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper
|
||||
from .transforms import (ColorJitter, RandomCropInstances,
|
||||
RandomCropPolyInstances, RandomScaling,
|
||||
ScaleAspectJitter, SquareResizePad)
|
||||
from .wrappers import ImgAug
|
||||
from .transforms import (RandomCropInstances, RandomCropPolyInstances,
|
||||
RandomScaling, ScaleAspectJitter, SquareResizePad)
|
||||
from .wrappers import ImgAug, TorchVisionWrapper
|
||||
|
||||
__all__ = [
|
||||
'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
|
||||
'ToTensorOCR', 'DBNetTargets', 'PANetTargets', 'ColorJitter',
|
||||
'RandomCropInstances', 'RandomRotate', 'ScaleAspectJitter',
|
||||
'MultiRotateAugOCR', 'OCRSegTargets', 'FancyPCA',
|
||||
'RandomCropPolyInstances', 'RandomPaddingOCR', 'ImgAug', 'EastRandomCrop',
|
||||
'RandomRotateImageBox', 'OpencvToPil', 'PilToOpencv', 'SquareResizePad',
|
||||
'TextSnakeTargets', 'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8',
|
||||
'FCENetTargets', 'RandomScaling', 'TextDetRandomCropFlip', 'NerTransform',
|
||||
'ToTensorNER', 'ResizeNoImg', 'PyramidRescale', 'OneOfWrapper',
|
||||
'RandomWrapper', 'TorchVisionWrapper', 'LoadImageFromLMDB', 'Resize'
|
||||
'ToTensorOCR', 'DBNetTargets', 'PANetTargets', 'RandomCropInstances',
|
||||
'RandomRotate', 'ScaleAspectJitter', 'MultiRotateAugOCR', 'OCRSegTargets',
|
||||
'FancyPCA', 'RandomCropPolyInstances', 'RandomPaddingOCR', 'ImgAug',
|
||||
'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil', 'PilToOpencv',
|
||||
'SquareResizePad', 'TextSnakeTargets', 'sort_vertex',
|
||||
'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets', 'RandomScaling',
|
||||
'TextDetRandomCropFlip', 'NerTransform', 'ToTensorNER', 'ResizeNoImg',
|
||||
'PyramidRescale', 'TorchVisionWrapper', 'LoadImageFromLMDB', 'Resize'
|
||||
]
|
||||
|
|
|
@ -1,128 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
import random
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torchvision.transforms as torchvision_transforms
|
||||
from mmdet.datasets.pipelines import Compose
|
||||
from PIL import Image
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class OneOfWrapper:
|
||||
"""Randomly select and apply one of the transforms, each with the equal
|
||||
chance.
|
||||
|
||||
Warning:
|
||||
Different from albumentations, this wrapper only runs the selected
|
||||
transform, but doesn't guarantee the transform can always be applied to
|
||||
the input if the transform comes with a probability to run.
|
||||
|
||||
Args:
|
||||
transforms (list[dict|callable]): Candidate transforms to be applied.
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
assert isinstance(transforms, list) or isinstance(transforms, tuple)
|
||||
assert len(transforms) > 0, 'Need at least one transform.'
|
||||
self.transforms = []
|
||||
for t in transforms:
|
||||
if isinstance(t, dict):
|
||||
self.transforms.append(TRANSFORMS.build(t))
|
||||
elif callable(t):
|
||||
self.transforms.append(t)
|
||||
else:
|
||||
raise TypeError('transform must be callable or a dict')
|
||||
|
||||
def __call__(self, results):
|
||||
return random.choice(self.transforms)(results)
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(transforms={self.transforms})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomWrapper:
|
||||
"""Run a transform or a sequence of transforms with probability p.
|
||||
|
||||
Args:
|
||||
transforms (list[dict|callable]): Transform(s) to be applied.
|
||||
p (int|float): Probability of running transform(s).
|
||||
"""
|
||||
|
||||
def __init__(self, transforms, p):
|
||||
assert 0 <= p <= 1
|
||||
self.transforms = Compose(transforms)
|
||||
self.p = p
|
||||
|
||||
def __call__(self, results):
|
||||
return results if np.random.uniform() > self.p else self.transforms(
|
||||
results)
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(transforms={self.transforms}, '
|
||||
repr_str += f'p={self.p})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class TorchVisionWrapper:
|
||||
"""A wrapper of torchvision trasnforms. It applies specific transform to
|
||||
``img`` and updates ``img_shape`` accordingly.
|
||||
|
||||
Warning:
|
||||
This transform only affects the image but not its associated
|
||||
annotations, such as word bounding boxes and polygon masks. Therefore,
|
||||
it may only be applicable to text recognition tasks.
|
||||
|
||||
Args:
|
||||
op (str): The name of any transform class in
|
||||
:func:`torchvision.transforms`.
|
||||
**kwargs: Arguments that will be passed to initializer of torchvision
|
||||
transform.
|
||||
|
||||
:Required Keys:
|
||||
- | ``img`` (ndarray): The input image.
|
||||
|
||||
:Affected Keys:
|
||||
:Modified:
|
||||
- | ``img`` (ndarray): The modified image.
|
||||
:Added:
|
||||
- | ``img_shape`` (tuple(int)): Size of the modified image.
|
||||
"""
|
||||
|
||||
def __init__(self, op, **kwargs):
|
||||
assert type(op) is str
|
||||
|
||||
if mmcv.is_str(op):
|
||||
obj_cls = getattr(torchvision_transforms, op)
|
||||
elif inspect.isclass(op):
|
||||
obj_cls = op
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a str or valid type, but got {type(type)}')
|
||||
self.transform = obj_cls(**kwargs)
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, results):
|
||||
assert 'img' in results
|
||||
# BGR -> RGB
|
||||
img = results['img'][..., ::-1]
|
||||
img = Image.fromarray(img)
|
||||
img = self.transform(img)
|
||||
img = np.asarray(img)
|
||||
img = img[..., ::-1]
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(transform={self.transform})'
|
||||
return repr_str
|
|
@ -3,10 +3,8 @@ import math
|
|||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torchvision.transforms as transforms
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
from mmdet.datasets.pipelines.transforms import Resize
|
||||
from PIL import Image
|
||||
|
||||
import mmocr.core.evaluation.utils as eval_utils
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
@ -174,29 +172,6 @@ class RandomCropInstances:
|
|||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ColorJitter:
|
||||
"""An interface for torch color jitter so that it can be invoked in
|
||||
mmdetection pipeline."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.transform = transforms.ColorJitter(**kwargs)
|
||||
|
||||
def __call__(self, results):
|
||||
# img is bgr
|
||||
img = results['img'][..., ::-1]
|
||||
img = Image.fromarray(img)
|
||||
img = self.transform(img)
|
||||
img = np.asarray(img)
|
||||
img = img[..., ::-1]
|
||||
results['img'] = img
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ScaleAspectJitter(Resize):
|
||||
"""Resize image and segmentation mask encoded by coordinates.
|
||||
|
@ -286,41 +261,6 @@ class ScaleAspectJitter(Resize):
|
|||
results['scale_idx'] = None
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class AffineJitter:
|
||||
"""An interface for torchvision random affine so that it can be invoked in
|
||||
mmdet pipeline."""
|
||||
|
||||
def __init__(self,
|
||||
degrees=4,
|
||||
translate=(0.02, 0.04),
|
||||
scale=(0.9, 1.1),
|
||||
shear=None,
|
||||
resample=False,
|
||||
fillcolor=0):
|
||||
self.transform = transforms.RandomAffine(
|
||||
degrees=degrees,
|
||||
translate=translate,
|
||||
scale=scale,
|
||||
shear=shear,
|
||||
resample=resample,
|
||||
fillcolor=fillcolor)
|
||||
|
||||
def __call__(self, results):
|
||||
# img is bgr
|
||||
img = results['img'][..., ::-1]
|
||||
img = Image.fromarray(img)
|
||||
img = self.transform(img)
|
||||
img = np.asarray(img)
|
||||
img = img[..., ::-1]
|
||||
results['img'] = img
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCropPolyInstances:
|
||||
"""Randomly crop images and make sure to contain at least one intact
|
||||
|
|
|
@ -4,7 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
import imgaug
|
||||
import imgaug.augmenters as iaa
|
||||
import numpy as np
|
||||
import torchvision.transforms as torchvision_transforms
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from PIL import Image
|
||||
|
||||
from mmocr.registry import TRANSFORMS
|
||||
|
||||
|
@ -233,3 +235,66 @@ class ImgAug(BaseTransform):
|
|||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(args = {self.args})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class TorchVisionWrapper(BaseTransform):
|
||||
"""A wrapper around torchvision trasnforms. It applies specific transform
|
||||
to ``img`` and updates ``height`` and ``width`` accordingly.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img (ndarray): The input image.
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img (ndarray): The modified image.
|
||||
- img_shape (tuple(int, int)): The shape of the image in (height, width).
|
||||
|
||||
|
||||
Warning:
|
||||
This transform only affects the image but not its associated
|
||||
annotations, such as word bounding boxes and polygons. Therefore,
|
||||
it may only be applicable to text recognition tasks.
|
||||
|
||||
Args:
|
||||
op (str): The name of any transform class in
|
||||
:func:`torchvision.transforms`.
|
||||
**kwargs: Arguments that will be passed to initializer of torchvision
|
||||
transform.
|
||||
"""
|
||||
|
||||
def __init__(self, op: str, **kwargs) -> None:
|
||||
assert isinstance(op, str)
|
||||
obj_cls = getattr(torchvision_transforms, op)
|
||||
self.torchvision = obj_cls(**kwargs)
|
||||
self.op = op
|
||||
self.kwargs = kwargs
|
||||
|
||||
def transform(self, results):
|
||||
"""Transform the image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from the data loader.
|
||||
|
||||
Returns:
|
||||
dict: Transformed results.
|
||||
"""
|
||||
assert 'img' in results
|
||||
# BGR -> RGB
|
||||
img = results['img'][..., ::-1]
|
||||
img = Image.fromarray(img)
|
||||
img = self.torchvision(img)
|
||||
img = np.asarray(img)
|
||||
img = img[..., ::-1]
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(op = {self.op}'
|
||||
for k, v in self.kwargs.items():
|
||||
repr_str += f', {k} = {v}'
|
||||
repr_str += ')'
|
||||
return repr_str
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
import unittest.mock as mock
|
||||
|
||||
import numpy as np
|
||||
import torchvision.transforms as TF
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
from PIL import Image
|
||||
|
||||
import mmocr.datasets.pipelines.transforms as transforms
|
||||
|
||||
|
@ -131,38 +129,6 @@ def test_scale_aspect_jitter(mock_random):
|
|||
assert results['scale'] == (650, 2600)
|
||||
|
||||
|
||||
def test_color_jitter():
|
||||
img = np.ones((64, 256, 3), dtype=np.uint8)
|
||||
results = {'img': img}
|
||||
|
||||
pt_official_color_jitter = TF.ColorJitter()
|
||||
output1 = pt_official_color_jitter(img)
|
||||
|
||||
color_jitter = transforms.ColorJitter()
|
||||
output2 = color_jitter(results)
|
||||
|
||||
assert np.allclose(output1, output2['img'])
|
||||
|
||||
|
||||
def test_affine_jitter():
|
||||
img = np.ones((64, 256, 3), dtype=np.uint8)
|
||||
results = {'img': img}
|
||||
|
||||
pt_official_affine_jitter = TF.RandomAffine(degrees=0)
|
||||
output1 = pt_official_affine_jitter(Image.fromarray(img))
|
||||
|
||||
affine_jitter = transforms.AffineJitter(
|
||||
degrees=0,
|
||||
translate=None,
|
||||
scale=None,
|
||||
shear=None,
|
||||
resample=False,
|
||||
fillcolor=0)
|
||||
output2 = affine_jitter(results)
|
||||
|
||||
assert np.allclose(np.array(output1), output2['img'])
|
||||
|
||||
|
||||
def test_random_scale():
|
||||
h, w, c = 100, 100, 3
|
||||
img = np.ones((h, w, c), dtype=np.uint8)
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import Dict, List, Optional
|
|||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from mmocr.datasets.pipelines import ImgAug
|
||||
from mmocr.datasets.pipelines import ImgAug, TorchVisionWrapper
|
||||
|
||||
|
||||
class TestImgAug(unittest.TestCase):
|
||||
|
@ -140,3 +140,26 @@ class TestImgAug(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
repr(transform),
|
||||
("ImgAug(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"))
|
||||
|
||||
|
||||
class TestTorchVisionWrapper(unittest.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
x = {'img': np.ones((128, 100, 3), dtype=np.uint8)}
|
||||
# object not found error
|
||||
with self.assertRaises(Exception):
|
||||
TorchVisionWrapper(op='NonExist')
|
||||
with self.assertRaises(TypeError):
|
||||
TorchVisionWrapper()
|
||||
f = TorchVisionWrapper('Grayscale')
|
||||
with self.assertRaises(AssertionError):
|
||||
f({})
|
||||
results = f(x)
|
||||
assert results['img'].shape == (128, 100)
|
||||
assert results['img_shape'] == (128, 100)
|
||||
|
||||
def test_repr(self):
|
||||
f = TorchVisionWrapper('Grayscale', num_output_channels=3)
|
||||
self.assertEqual(
|
||||
repr(f),
|
||||
'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')
|
||||
|
|
Loading…
Reference in New Issue