[Fix]: Support numpy array input in RandomResizedCropAndInterpolationWithTwoPic

pull/352/head
liuyuan1.vendor 2022-05-13 10:57:27 +00:00 committed by fangyixiao18
parent 4dc2ff1b79
commit efb9255d7e
2 changed files with 23 additions and 21 deletions

View File

@ -4,6 +4,7 @@ import random
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
@ -265,7 +266,8 @@ class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform):
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
img_h, img_w = img.shape[:2]
area = img_h * img_w
for _ in range(10):
target_area = random.uniform(*scale) * area
@ -275,24 +277,24 @@ class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform):
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
if w <= img_w and h <= img_h:
i = random.randint(0, img_h - h)
j = random.randint(0, img_w - w)
return i, j, h, w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
in_ratio = img_w / img_h
if in_ratio < min(ratio):
w = img.size[0]
w = img_w
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
h = img_h
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
w = img_w
h = img_h
i = (img_h - h) // 2
j = (img_w - w) // 2
return i, j, h, w
def transform(self, results: Dict) -> Dict:
@ -303,14 +305,17 @@ class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform):
else:
interpolation = self.interpolation
if self.second_size is None:
img = F.resized_crop(img, i, j, h, w, self.size, interpolation)
img = img[i:i + h, j:j + w]
img = cv2.resize(img, self.size, interpolation=interpolation)
results.update({'img': img})
else:
img = F.resized_crop(img, i, j, h, w, self.size, interpolation)
img_target = F.resized_crop(img, i, j, h, w, self.second_size,
self.second_interpolation)
results.update({'img': img, 'target_img': img_target})
img = img[i:i + h, j:j + w]
img_sample = cv2.resize(
img, self.size, interpolation=interpolation)
img_target = cv2.resize(
img, self.second_size, interpolation=self.second_interpolation)
results.update({'img': [img_sample, img_target]})
return results
def __repr__(self) -> str:

View File

@ -2,7 +2,6 @@
import numpy as np
import pytest
import torch
from PIL import Image
from mmselfsup.datasets.pipelines import (
BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
@ -55,14 +54,12 @@ def test_random_resize_crop_with_two_pic():
scale=(0.08, 1.0))
module = RandomResizedCropAndInterpolationWithTwoPic(**transform)
fake_input = torch.rand((224, 224, 3)).numpy().astype(np.uint8)
fake_input = Image.fromarray(fake_input)
results = {'img': fake_input}
results = module(results)
# test transform
assert list(results['img'].size) == [224, 224]
assert list(results['target_img'].size) == [112, 112]
assert list(results['img'][0].shape) == [224, 224, 3]
assert list(results['img'][1].shape) == [112, 112, 3]
# test repr
assert isinstance(str(module), str)