mmselfsup/openselfsup/datasets/rotation_pred.py
2020-09-02 18:49:39 +08:00

46 lines
1.3 KiB
Python

import torch
from PIL import Image
from .registry import DATASETS
from .base import BaseDataset
def rotate(img):
    """Produce the four right-angle rotations of an image tensor.

    Args:
        img (Tensor): input image of shape (C, H, W).

    Returns:
        list[Tensor]: ``[img, rot90, rot180, rot270]``, where entry ``k``
        equals ``torch.rot90(img, k, [1, 2])`` (rotation from the H axis
        towards the W axis). Entry 0 is ``img`` itself, not a copy. For
        non-square inputs the 90/270-degree views have shape (C, W, H).
    """
    quarter_turns = (1, 2, 3)
    return [img] + [torch.rot90(img, k, [1, 2]) for k in quarter_turns]
@DATASETS.register_module
class RotationPredDataset(BaseDataset):
    """Dataset for self-supervised rotation prediction.

    Each sample is expanded into four copies rotated by 0, 90, 180 and 270
    degrees (via ``rotate``); the target of each copy is its rotation index.

    Args:
        data_source: Forwarded unchanged to ``BaseDataset``. Its
            ``get_sample(idx)`` must return a ``PIL.Image.Image``.
        pipeline: Forwarded unchanged to ``BaseDataset``. Presumably a list
            of transform configs whose composed output is a (C, H, W) tensor
            -- TODO confirm against ``BaseDataset``.
    """

    def __init__(self, data_source, pipeline):
        super(RotationPredDataset, self).__init__(data_source, pipeline)

    def __getitem__(self, idx):
        """Return the four rotated views of sample ``idx`` and their labels.

        Args:
            idx (int): index into the data source.

        Returns:
            dict: ``img`` -- the rotated views stacked along a new leading
            dimension, in rotation order; ``rot_label`` -- int64 tensor
            ``[0, 1, 2, 3]`` where label ``k`` means ``k * 90`` degrees.

        Note:
            The 90/270-degree views transpose the last two dimensions, so
            ``torch.stack`` only succeeds when the pipeline output is square
            (H == W).
        """
        img = self.data_source.get_sample(idx)
        # Implicit literal concatenation (not a backslash continuation) so
        # indentation whitespace never leaks into the message text.
        assert isinstance(img, Image.Image), (
            'The output from the data source must be an Image, got: {}. '
            'Please ensure that the list file does not contain labels.'.format(
                type(img)))
        img = self.pipeline(img)
        rotated = rotate(img)
        img = torch.stack(rotated, dim=0)
        # One label per view; index k corresponds to a k * 90 degree turn.
        rotation_labels = torch.arange(len(rotated), dtype=torch.long)
        return dict(img=img, rot_label=rotation_labels)

    def evaluate(self, scores, keyword, logger=None):
        """Evaluation is not supported for this pretext-task dataset.

        Raises:
            NotImplementedError: always.
        """
        # ``raise NotImplemented`` would raise TypeError instead:
        # NotImplemented is a constant, not an exception class.
        raise NotImplementedError