diff --git a/mmcls/datasets/pipelines/transforms.py b/mmcls/datasets/pipelines/transforms.py index 08a9e410e..750b3ee0e 100644 --- a/mmcls/datasets/pipelines/transforms.py +++ b/mmcls/datasets/pipelines/transforms.py @@ -146,7 +146,8 @@ class RandomResizedCrop(object): size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), - interpolation='bilinear'): + interpolation='bilinear', + backend='cv2'): if isinstance(size, (tuple, list)): self.size = size else: @@ -154,10 +155,14 @@ class RandomResizedCrop(object): if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): raise ValueError("range should be of kind (min, max). " f"But received {scale}") + if backend not in ['cv2', 'pillow']: + raise ValueError(f'backend: {backend} is not supported for resize.' + 'Supported backends are "cv2", "pillow"') self.interpolation = interpolation self.scale = scale self.ratio = ratio + self.backend = backend @staticmethod def get_params(img, scale, ratio): @@ -225,7 +230,10 @@ class RandomResizedCrop(object): xmin + target_height - 1 ])) results[key] = mmcv.imresize( - img, tuple(self.size[::-1]), interpolation=self.interpolation) + img, + tuple(self.size[::-1]), + interpolation=self.interpolation, + backend=self.backend) return results def __repr__(self): @@ -333,7 +341,7 @@ class Resize(object): More details can be found in `mmcv.image.geometric`. """ - def __init__(self, size, interpolation='bilinear'): + def __init__(self, size, interpolation='bilinear', backend='cv2'): assert isinstance(size, int) or (isinstance(size, tuple) and len(size) == 2) if isinstance(size, int): @@ -341,11 +349,15 @@ class Resize(object): assert size[0] > 0 and size[1] > 0 assert interpolation in ("nearest", "bilinear", "bicubic", "area", "lanczos") + if backend not in ['cv2', 'pillow']: + raise ValueError(f'backend: {backend} is not supported for resize.' + 'Supported backends are "cv2", "pillow"') self.height = size[0] self.width = size[1] self.size = size self.interpolation = interpolation + self.backend = backend def _resize_img(self, results): for key in results.get('img_fields', ['img']): @@ -353,7 +365,8 @@ class Resize(object): results[key], size=(self.width, self.height), interpolation=self.interpolation, - return_scale=False) + return_scale=False, + backend=self.backend) results[key] = img results['img_shape'] = img.shape