mmselfsup/openselfsup/datasets/relative_loc.py
2020-09-02 18:49:39 +08:00

66 lines
2.3 KiB
Python

from openselfsup.utils import build_from_cfg
import torch
from PIL import Image
from torchvision.transforms import Compose, RandomCrop
import torchvision.transforms.functional as TF
from .registry import DATASETS, PIPELINES
from .base import BaseDataset
def image_to_patches(img, split_per_side=3, patch_jitter=21):
    """Crop split_per_side x split_per_side jittered patches from an image.

    The image is divided into a regular grid of ``split_per_side`` rows and
    ``split_per_side`` columns. From every grid cell one patch, smaller than
    the cell by ``patch_jitter`` pixels along each axis, is cropped at a
    random position inside that cell.

    Args:
        img (PIL Image): input image.
        split_per_side (int): number of grid cells per image side.
            Default: 3.
        patch_jitter (int): number of pixels by which a patch is smaller
            than its grid cell along each axis, i.e. the room available
            for the random offset. Default: 21.

    Returns:
        list[PIL Image]: A list of cropped patches in row-major order
            (left to right, then top to bottom).

    Raises:
        AssertionError: if a grid cell is not larger than ``patch_jitter``
            along both axes.
    """
    # PIL reports size as (width, height), not (height, width).
    w, h = img.size
    h_grid = h // split_per_side
    w_grid = w // split_per_side
    h_patch = h_grid - patch_jitter
    w_patch = w_grid - patch_jitter
    assert h_patch > 0 and w_patch > 0, \
        'Image of size (w={}, h={}) is too small for a {}x{} grid with ' \
        'patch_jitter={}.'.format(w, h, split_per_side, split_per_side,
                                  patch_jitter)

    # The crop size is identical for every cell, so build the transform once.
    random_crop = RandomCrop((h_patch, w_patch))
    patches = []
    for i in range(split_per_side):  # grid row
        for j in range(split_per_side):  # grid column
            # TF.crop takes (img, top, left, height, width).
            cell = TF.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
            patches.append(random_crop(cell))
    return patches
@DATASETS.register_module
class RelativeLocDataset(BaseDataset):
    """Dataset for relative patch location.

    Every sample is cut into 9 patches by ``image_to_patches``. Each of the
    8 non-central patches is concatenated with the central patch (index 4)
    to form a pair, and the pairs are labelled 0..7 in the order of the
    non-central patches.
    """

    def __init__(self, data_source, pipeline, format_pipeline):
        """Build the dataset.

        Args:
            data_source: forwarded unchanged to ``BaseDataset.__init__``.
            pipeline: forwarded unchanged to ``BaseDataset.__init__``;
                ``self.pipeline`` is later applied to the whole image.
            format_pipeline (list): configs of the transforms applied to
                every single patch; each one is built from the
                ``PIPELINES`` registry and the results are composed.
        """
        super(RelativeLocDataset, self).__init__(data_source, pipeline)
        format_pipeline = [
            build_from_cfg(p, PIPELINES) for p in format_pipeline
        ]
        self.format_pipeline = Compose(format_pipeline)

    def __getitem__(self, idx):
        """Return the 8 patch pairs and their labels for sample ``idx``.

        Args:
            idx (int): index of the sample in the data source.

        Returns:
            dict: ``img`` holds the 8 stacked patch pairs and
                ``patch_label`` holds the labels 0..7 as a ``LongTensor``.
        """
        img = self.data_source.get_sample(idx)
        assert isinstance(img, Image.Image), \
            'The output from the data source must be an Image, got: {}. ' \
            'Please ensure that the list file does not contain labels.' \
            .format(type(img))
        img = self.pipeline(img)
        patches = image_to_patches(img)
        patches = [self.format_pipeline(p) for p in patches]
        # create a list of patch pairs: every non-central patch joined with
        # the central one (index 4) along dim 0
        # NOTE(review): assumes format_pipeline yields CHW tensors, so dim 0
        # is the channel axis -- confirm against the config.
        perms = [
            torch.cat((patches[i], patches[4]), dim=0) for i in range(9)
            if i != 4
        ]
        # create corresponding labels for patch pairs
        patch_labels = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])
        return dict(
            img=torch.stack(perms), patch_label=patch_labels)  # 8(2C)HW, 8

    def evaluate(self, scores, keyword, logger=None):
        """Evaluation is not supported for this dataset.

        Raises:
            NotImplementedError: always.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it would produce a confusing TypeError instead.
        raise NotImplementedError