EasyCV/easycv/datasets/shared/multi_view.py
2022-04-02 20:01:06 +08:00

55 lines
1.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from PIL import Image, ImageFilter, ImageOps
from torchvision import transforms
from easycv.datasets.builder import build_datasource
from easycv.datasets.registry import DATASETS, PIPELINES
from easycv.datasets.shared.base import BaseDataset
from easycv.datasets.shared.pipelines.transforms import Compose
from easycv.utils import build_from_cfg
@DATASETS.register_module
class MultiViewDataset(BaseDataset):
    """The dataset outputs multiple views of an image.

    Each sample is read from ``data_source`` and passed through a list of
    pipelines. Pipeline ``i`` is applied ``num_views[i]`` times, so the
    output holds ``sum(num_views)`` views, grouped in pipeline order. The
    image can be processed by one pipeline or by multiple pipelines.

    Args:
        data_source (dict): Config of the data source, built with
            ``build_datasource``.
        num_views (list[int]): Number of views produced by each pipeline.
            Must have the same length as ``pipelines``.
        pipelines (list[list[dict]]): A list of pipelines, each given as a
            list of transform configs built from the ``PIPELINES`` registry.

    Raises:
        TypeError: If ``num_views`` is not a list or tuple.
        ValueError: If ``num_views`` and ``pipelines`` differ in length.
    """

    def __init__(self, data_source, num_views, pipelines):
        self.data_source = build_datasource(data_source)
        # Explicit exceptions instead of ``assert`` so the checks survive
        # ``python -O``.
        if not isinstance(num_views, (list, tuple)):
            raise TypeError(
                f'num_views must be a list of ints, got: {type(num_views)}')

        pipelines_list = [
            Compose([build_from_cfg(p, PIPELINES) for p in pipe])
            for pipe in pipelines
        ]
        # Without this check a longer num_views raises an opaque IndexError
        # and a shorter one silently drops the trailing pipelines.
        if len(num_views) != len(pipelines_list):
            raise ValueError(
                f'num_views and pipelines must have the same length, got '
                f'{len(num_views)} and {len(pipelines_list)}')

        # Flat list with pipeline i repeated num_views[i] times. The same
        # Compose object is reused for all its views, so those views differ
        # only if its transforms are random.
        self.transforms_list = []
        for pipeline, n_views in zip(pipelines_list, num_views):
            self.transforms_list.extend([pipeline] * n_views)

    def __getitem__(self, idx):
        """Return ``dict(img=views)`` for sample ``idx``.

        ``views`` is a list with one entry per element of
        ``self.transforms_list``, each the result of applying that pipeline
        to the image. Labels from the data source, if any, are discarded.

        Raises:
            TypeError: If the data source does not yield a PIL image.
        """
        if getattr(self.data_source, 'has_labels', False):
            img, _ = self.data_source.get_sample(idx)
        else:
            img = self.data_source.get_sample(idx)

        if not isinstance(img, Image.Image):
            raise TypeError(
                f'The output from the data source must be an Image, got: '
                f'{type(img)}. Please ensure that the list file does not '
                'contain labels.')

        return dict(img=[trans(img) for trans in self.transforms_list])

    def evaluate(self, results, evaluators, logger=None):
        """Evaluation is not implemented for this dataset."""
        raise NotImplementedError