mirror of https://github.com/open-mmlab/mmocr.git
add torchvisionwrapper
parent
0b5d2df310
commit
7b6778c5d8
|
@ -14,21 +14,18 @@ from .processing import (PyramidRescale, RandomRotate, Resize,
|
||||||
from .test_time_aug import MultiRotateAugOCR
|
from .test_time_aug import MultiRotateAugOCR
|
||||||
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
|
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
|
||||||
TextSnakeTargets)
|
TextSnakeTargets)
|
||||||
from .transform_wrappers import OneOfWrapper, RandomWrapper, TorchVisionWrapper
|
from .transforms import (RandomCropInstances, RandomCropPolyInstances,
|
||||||
from .transforms import (ColorJitter, RandomCropInstances,
|
RandomScaling, ScaleAspectJitter, SquareResizePad)
|
||||||
RandomCropPolyInstances, RandomScaling,
|
from .wrappers import ImgAug, TorchVisionWrapper
|
||||||
ScaleAspectJitter, SquareResizePad)
|
|
||||||
from .wrappers import ImgAug
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
|
'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
|
||||||
'ToTensorOCR', 'DBNetTargets', 'PANetTargets', 'ColorJitter',
|
'ToTensorOCR', 'DBNetTargets', 'PANetTargets', 'RandomCropInstances',
|
||||||
'RandomCropInstances', 'RandomRotate', 'ScaleAspectJitter',
|
'RandomRotate', 'ScaleAspectJitter', 'MultiRotateAugOCR', 'OCRSegTargets',
|
||||||
'MultiRotateAugOCR', 'OCRSegTargets', 'FancyPCA',
|
'FancyPCA', 'RandomCropPolyInstances', 'RandomPaddingOCR', 'ImgAug',
|
||||||
'RandomCropPolyInstances', 'RandomPaddingOCR', 'ImgAug', 'EastRandomCrop',
|
'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil', 'PilToOpencv',
|
||||||
'RandomRotateImageBox', 'OpencvToPil', 'PilToOpencv', 'SquareResizePad',
|
'SquareResizePad', 'TextSnakeTargets', 'sort_vertex',
|
||||||
'TextSnakeTargets', 'sort_vertex', 'LoadImageFromNdarray', 'sort_vertex8',
|
'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets', 'RandomScaling',
|
||||||
'FCENetTargets', 'RandomScaling', 'TextDetRandomCropFlip', 'NerTransform',
|
'TextDetRandomCropFlip', 'NerTransform', 'ToTensorNER', 'ResizeNoImg',
|
||||||
'ToTensorNER', 'ResizeNoImg', 'PyramidRescale', 'OneOfWrapper',
|
'PyramidRescale', 'TorchVisionWrapper', 'LoadImageFromLMDB', 'Resize'
|
||||||
'RandomWrapper', 'TorchVisionWrapper', 'LoadImageFromLMDB', 'Resize'
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,128 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import inspect
|
|
||||||
import random
|
|
||||||
|
|
||||||
import mmcv
|
|
||||||
import numpy as np
|
|
||||||
import torchvision.transforms as torchvision_transforms
|
|
||||||
from mmdet.datasets.pipelines import Compose
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from mmocr.registry import TRANSFORMS
|
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
class OneOfWrapper:
    """Pick one of the candidate transforms at random and apply it.

    Every candidate has the same chance of being selected.

    Warning:
        Different from albumentations, this wrapper only runs the selected
        transform, but doesn't guarantee the transform can always be applied
        to the input if the transform comes with a probability to run.

    Args:
        transforms (list[dict|callable]): Candidate transforms to be applied.
            Dict entries are built through the ``TRANSFORMS`` registry;
            callables are used as they are.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, (list, tuple))
        assert len(transforms) > 0, 'Need at least one transform.'
        candidates = []
        for item in transforms:
            if isinstance(item, dict):
                # Config dicts are turned into transform objects.
                item = TRANSFORMS.build(item)
            elif not callable(item):
                raise TypeError('transform must be callable or a dict')
            candidates.append(item)
        self.transforms = candidates

    def __call__(self, results):
        chosen = random.choice(self.transforms)
        return chosen(results)

    def __repr__(self):
        return f'{self.__class__.__name__}(transforms={self.transforms})'
|
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
class RandomWrapper:
    """Apply a transform, or a sequence of transforms, with probability p.

    Args:
        transforms (list[dict|callable]): Transform(s) to be applied. They
            are chained with ``Compose``.
        p (int|float): Probability of running transform(s); must lie in
            ``[0, 1]``.
    """

    def __init__(self, transforms, p):
        assert 0 <= p <= 1
        self.transforms = Compose(transforms)
        self.p = p

    def __call__(self, results):
        # A draw above ``p`` means the input is passed through untouched.
        if np.random.uniform() > self.p:
            return results
        return self.transforms(results)

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(transforms={self.transforms}, p={self.p})')
|
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
class TorchVisionWrapper:
    """A wrapper of torchvision transforms. It applies specific transform to
    ``img`` and updates ``img_shape`` accordingly.

    Warning:
        This transform only affects the image but not its associated
        annotations, such as word bounding boxes and polygon masks. Therefore,
        it may only be applicable to text recognition tasks.

    Args:
        op (str): The name of any transform class in
            :func:`torchvision.transforms`.
        **kwargs: Arguments that will be passed to initializer of torchvision
            transform.

    :Required Keys:
        - | ``img`` (ndarray): The input image.

    :Affected Keys:
        :Modified:
            - | ``img`` (ndarray): The modified image.
        :Added:
            - | ``img_shape`` (tuple(int)): Size of the modified image.
    """

    def __init__(self, op, **kwargs):
        # Only transform names are supported: the class is looked up by name
        # in ``torchvision.transforms``. The former class/else branches could
        # never run behind this assertion and have been dropped.
        assert isinstance(op, str)
        obj_cls = getattr(torchvision_transforms, op)
        self.transform = obj_cls(**kwargs)
        self.kwargs = kwargs

    def __call__(self, results):
        """Run the wrapped transform on ``results['img']``.

        Args:
            results (dict): Pipeline dict; must contain ``img``.

        Returns:
            dict: The same dict with ``img`` and ``img_shape`` updated.
        """
        assert 'img' in results
        img = results['img']
        # BGR -> RGB. Only reverse when a channel axis exists: on a 2-D
        # array ``[..., ::-1]`` would reverse the width axis and mirror the
        # image instead of swapping channels.
        if img.ndim == 3:
            img = img[..., ::-1]
        img = Image.fromarray(img)
        img = self.transform(img)
        img = np.asarray(img)
        # RGB -> BGR, skipping 2-D outputs (e.g. from ``Grayscale``) for the
        # same reason as above.
        if img.ndim == 3:
            img = img[..., ::-1]
        results['img'] = img
        results['img_shape'] = img.shape
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(transform={self.transform})'
|
|
|
@ -3,10 +3,8 @@ import math
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torchvision.transforms as transforms
|
|
||||||
from mmdet.core import BitmapMasks, PolygonMasks
|
from mmdet.core import BitmapMasks, PolygonMasks
|
||||||
from mmdet.datasets.pipelines.transforms import Resize
|
from mmdet.datasets.pipelines.transforms import Resize
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import mmocr.core.evaluation.utils as eval_utils
|
import mmocr.core.evaluation.utils as eval_utils
|
||||||
from mmocr.registry import TRANSFORMS
|
from mmocr.registry import TRANSFORMS
|
||||||
|
@ -174,29 +172,6 @@ class RandomCropInstances:
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
class ColorJitter:
    """Adapter exposing torchvision's ``ColorJitter`` as a pipeline
    transform that operates on ``results['img']``.

    Args:
        **kwargs: Forwarded unchanged to ``transforms.ColorJitter``.
    """

    def __init__(self, **kwargs):
        self.transform = transforms.ColorJitter(**kwargs)

    def __call__(self, results):
        # The stored image is BGR; reverse the last axis before handing it
        # to PIL and reverse it back afterwards.
        rgb = Image.fromarray(results['img'][..., ::-1])
        jittered = np.asarray(self.transform(rgb))
        results['img'] = jittered[..., ::-1]
        return results

    def __repr__(self):
        return self.__class__.__name__
|
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class ScaleAspectJitter(Resize):
|
class ScaleAspectJitter(Resize):
|
||||||
"""Resize image and segmentation mask encoded by coordinates.
|
"""Resize image and segmentation mask encoded by coordinates.
|
||||||
|
@ -286,41 +261,6 @@ class ScaleAspectJitter(Resize):
|
||||||
results['scale_idx'] = None
|
results['scale_idx'] = None
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
class AffineJitter:
    """Adapter exposing torchvision's ``RandomAffine`` as a pipeline
    transform that operates on ``results['img']``.

    All arguments are forwarded by keyword to ``transforms.RandomAffine``.

    Args:
        degrees: Forwarded as ``degrees``. Defaults to 4.
        translate: Forwarded as ``translate``. Defaults to (0.02, 0.04).
        scale: Forwarded as ``scale``. Defaults to (0.9, 1.1).
        shear: Forwarded as ``shear``. Defaults to None.
        resample: Forwarded as ``resample``. Defaults to False.
        fillcolor: Forwarded as ``fillcolor``. Defaults to 0.
    """

    def __init__(self,
                 degrees=4,
                 translate=(0.02, 0.04),
                 scale=(0.9, 1.1),
                 shear=None,
                 resample=False,
                 fillcolor=0):
        affine_cfg = dict(
            degrees=degrees,
            translate=translate,
            scale=scale,
            shear=shear,
            resample=resample,
            fillcolor=fillcolor)
        self.transform = transforms.RandomAffine(**affine_cfg)

    def __call__(self, results):
        # The stored image is BGR; reverse the last axis before handing it
        # to PIL and reverse it back afterwards.
        rgb = Image.fromarray(results['img'][..., ::-1])
        warped = np.asarray(self.transform(rgb))
        results['img'] = warped[..., ::-1]
        return results

    def __repr__(self):
        return self.__class__.__name__
|
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class RandomCropPolyInstances:
|
class RandomCropPolyInstances:
|
||||||
"""Randomly crop images and make sure to contain at least one intact
|
"""Randomly crop images and make sure to contain at least one intact
|
||||||
|
|
|
@ -4,7 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
import imgaug
|
import imgaug
|
||||||
import imgaug.augmenters as iaa
|
import imgaug.augmenters as iaa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torchvision.transforms as torchvision_transforms
|
||||||
from mmcv.transforms.base import BaseTransform
|
from mmcv.transforms.base import BaseTransform
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from mmocr.registry import TRANSFORMS
|
from mmocr.registry import TRANSFORMS
|
||||||
|
|
||||||
|
@ -233,3 +235,66 @@ class ImgAug(BaseTransform):
|
||||||
repr_str = self.__class__.__name__
|
repr_str = self.__class__.__name__
|
||||||
repr_str += f'(args = {self.args})'
|
repr_str += f'(args = {self.args})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
class TorchVisionWrapper(BaseTransform):
    """A wrapper around torchvision transforms. It applies specific transform
    to ``img`` and updates ``img_shape`` accordingly.

    Required Keys:

    - img (ndarray): The input image.

    Modified Keys:

    - img (ndarray): The modified image.
    - img_shape (tuple(int, int)): The shape of the image in (height, width).

    Warning:
        This transform only affects the image but not its associated
        annotations, such as word bounding boxes and polygons. Therefore,
        it may only be applicable to text recognition tasks.

    Args:
        op (str): The name of any transform class in
            :func:`torchvision.transforms`.
        **kwargs: Arguments that will be passed to initializer of torchvision
            transform.
    """

    def __init__(self, op: str, **kwargs) -> None:
        assert isinstance(op, str)
        # Resolve the transform class by name; an unknown name makes
        # ``getattr`` raise.
        obj_cls = getattr(torchvision_transforms, op)
        self.torchvision = obj_cls(**kwargs)
        # Kept only so that ``__repr__`` can echo the configuration.
        self.op = op
        self.kwargs = kwargs

    def transform(self, results: Dict) -> Dict:
        """Transform the image.

        Args:
            results (dict): Result dict from the data loader.

        Returns:
            dict: Transformed results.
        """
        assert 'img' in results
        img = results['img']
        # BGR -> RGB. Only reverse when a channel axis exists: on a 2-D
        # array ``[..., ::-1]`` would reverse the width axis and mirror the
        # image instead of swapping channels.
        if img.ndim == 3:
            img = img[..., ::-1]
        img = Image.fromarray(img)
        img = self.torchvision(img)
        img = np.asarray(img)
        # RGB -> BGR, skipping 2-D outputs (e.g. from ``Grayscale``) for the
        # same reason as above.
        if img.ndim == 3:
            img = img[..., ::-1]
        results['img'] = img
        results['img_shape'] = img.shape[:2]
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(op = {self.op}'
        for k, v in self.kwargs.items():
            repr_str += f', {k} = {v}'
        repr_str += ')'
        return repr_str
|
||||||
|
|
|
@ -2,9 +2,7 @@
|
||||||
import unittest.mock as mock
|
import unittest.mock as mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torchvision.transforms as TF
|
|
||||||
from mmdet.core import BitmapMasks, PolygonMasks
|
from mmdet.core import BitmapMasks, PolygonMasks
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import mmocr.datasets.pipelines.transforms as transforms
|
import mmocr.datasets.pipelines.transforms as transforms
|
||||||
|
|
||||||
|
@ -131,38 +129,6 @@ def test_scale_aspect_jitter(mock_random):
|
||||||
assert results['scale'] == (650, 2600)
|
assert results['scale'] == (650, 2600)
|
||||||
|
|
||||||
|
|
||||||
def test_color_jitter():
    """The ColorJitter wrapper with default arguments must match the output
    of torchvision's own default ``ColorJitter``."""
    image = np.ones((64, 256, 3), dtype=np.uint8)

    # Reference: torchvision's transform called directly on the array.
    expected = TF.ColorJitter()(image)

    # Under test: the pipeline wrapper, fed through a results dict.
    wrapped = transforms.ColorJitter()
    actual = wrapped({'img': image})

    assert np.allclose(expected, actual['img'])
|
|
||||||
|
|
||||||
|
|
||||||
def test_affine_jitter():
    """The AffineJitter wrapper configured with ``degrees=0`` and no
    translate/scale/shear must match torchvision's ``RandomAffine``."""
    image = np.ones((64, 256, 3), dtype=np.uint8)

    # Reference: torchvision's transform applied to the PIL image.
    reference = TF.RandomAffine(degrees=0)
    expected = reference(Image.fromarray(image))

    # Under test: the pipeline wrapper, fed through a results dict.
    wrapper_cfg = dict(
        degrees=0,
        translate=None,
        scale=None,
        shear=None,
        resample=False,
        fillcolor=0)
    wrapped = transforms.AffineJitter(**wrapper_cfg)
    actual = wrapped({'img': image})

    assert np.allclose(np.array(expected), actual['img'])
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_scale():
|
def test_random_scale():
|
||||||
h, w, c = 100, 100, 3
|
h, w, c = 100, 100, 3
|
||||||
img = np.ones((h, w, c), dtype=np.uint8)
|
img = np.ones((h, w, c), dtype=np.uint8)
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import Dict, List, Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from shapely.geometry import Polygon
|
from shapely.geometry import Polygon
|
||||||
|
|
||||||
from mmocr.datasets.pipelines import ImgAug
|
from mmocr.datasets.pipelines import ImgAug, TorchVisionWrapper
|
||||||
|
|
||||||
|
|
||||||
class TestImgAug(unittest.TestCase):
|
class TestImgAug(unittest.TestCase):
|
||||||
|
@ -140,3 +140,26 @@ class TestImgAug(unittest.TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
repr(transform),
|
repr(transform),
|
||||||
("ImgAug(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"))
|
("ImgAug(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestTorchVisionWrapper(unittest.TestCase):
    """Unit tests for the ``TorchVisionWrapper`` transform."""

    def test_transform(self):
        """Check constructor errors and the output of a ``Grayscale`` op."""
        # object not found error
        with self.assertRaises(Exception):
            TorchVisionWrapper(op='NonExist')
        # ``op`` is a required argument.
        with self.assertRaises(TypeError):
            TorchVisionWrapper()

        wrapper = TorchVisionWrapper('Grayscale')
        # A results dict without ``img`` trips the wrapper's assertion.
        with self.assertRaises(AssertionError):
            wrapper({})

        data = {'img': np.ones((128, 100, 3), dtype=np.uint8)}
        out = wrapper(data)
        expected_hw = (128, 100)
        assert out['img'].shape == expected_hw
        assert out['img_shape'] == expected_hw

    def test_repr(self):
        """``repr`` echoes the op name followed by its keyword arguments."""
        wrapper = TorchVisionWrapper('Grayscale', num_output_channels=3)
        expected = 'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)'
        self.assertEqual(repr(wrapper), expected)
|
||||||
|
|
Loading…
Reference in New Issue