[Fix]: Support numpy array in random crop resize with two pic
parent
4dc2ff1b79
commit
efb9255d7e
|
@ -4,6 +4,7 @@ import random
|
|||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
|
@ -265,7 +266,8 @@ class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform):
|
|||
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
||||
sized crop.
|
||||
"""
|
||||
area = img.size[0] * img.size[1]
|
||||
img_h, img_w = img.shape[:2]
|
||||
area = img_h * img_w
|
||||
|
||||
for _ in range(10):
|
||||
target_area = random.uniform(*scale) * area
|
||||
|
@ -275,24 +277,24 @@ class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform):
|
|||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if w <= img.size[0] and h <= img.size[1]:
|
||||
i = random.randint(0, img.size[1] - h)
|
||||
j = random.randint(0, img.size[0] - w)
|
||||
if w <= img_w and h <= img_h:
|
||||
i = random.randint(0, img_h - h)
|
||||
j = random.randint(0, img_w - w)
|
||||
return i, j, h, w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = img.size[0] / img.size[1]
|
||||
in_ratio = img_w / img_h
|
||||
if in_ratio < min(ratio):
|
||||
w = img.size[0]
|
||||
w = img_w
|
||||
h = int(round(w / min(ratio)))
|
||||
elif in_ratio > max(ratio):
|
||||
h = img.size[1]
|
||||
h = img_h
|
||||
w = int(round(h * max(ratio)))
|
||||
else: # whole image
|
||||
w = img.size[0]
|
||||
h = img.size[1]
|
||||
i = (img.size[1] - h) // 2
|
||||
j = (img.size[0] - w) // 2
|
||||
w = img_w
|
||||
h = img_h
|
||||
i = (img_h - h) // 2
|
||||
j = (img_w - w) // 2
|
||||
return i, j, h, w
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
|
@ -303,14 +305,17 @@ class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform):
|
|||
else:
|
||||
interpolation = self.interpolation
|
||||
if self.second_size is None:
|
||||
img = F.resized_crop(img, i, j, h, w, self.size, interpolation)
|
||||
img = img[i:i + h, j:j + w]
|
||||
img = cv2.resize(img, self.size, interpolation=interpolation)
|
||||
results.update({'img': img})
|
||||
|
||||
else:
|
||||
img = F.resized_crop(img, i, j, h, w, self.size, interpolation)
|
||||
img_target = F.resized_crop(img, i, j, h, w, self.second_size,
|
||||
self.second_interpolation)
|
||||
results.update({'img': img, 'target_img': img_target})
|
||||
img = img[i:i + h, j:j + w]
|
||||
img_sample = cv2.resize(
|
||||
img, self.size, interpolation=interpolation)
|
||||
img_target = cv2.resize(
|
||||
img, self.second_size, interpolation=self.second_interpolation)
|
||||
results.update({'img': [img_sample, img_target]})
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from mmselfsup.datasets.pipelines import (
|
||||
BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
|
||||
|
@ -55,14 +54,12 @@ def test_random_resize_crop_with_two_pic():
|
|||
scale=(0.08, 1.0))
|
||||
module = RandomResizedCropAndInterpolationWithTwoPic(**transform)
|
||||
fake_input = torch.rand((224, 224, 3)).numpy().astype(np.uint8)
|
||||
fake_input = Image.fromarray(fake_input)
|
||||
|
||||
results = {'img': fake_input}
|
||||
results = module(results)
|
||||
|
||||
# test transform
|
||||
assert list(results['img'].size) == [224, 224]
|
||||
assert list(results['target_img'].size) == [112, 112]
|
||||
assert list(results['img'][0].shape) == [224, 224, 3]
|
||||
assert list(results['img'][1].shape) == [112, 112, 3]
|
||||
|
||||
# test repr
|
||||
assert isinstance(str(module), str)
|
||||
|
|
Loading…
Reference in New Issue