EasyCV/easycv/datasets/shared/multi_view.py

58 lines
2.0 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
from PIL import Image
from easycv.datasets.builder import build_datasource
from easycv.datasets.registry import DATASETS, PIPELINES
from easycv.datasets.shared.base import BaseDataset
from easycv.datasets.shared.pipelines.transforms import Compose
from easycv.framework.errors import NotImplementedError
from easycv.utils.registry import build_from_cfg
@DATASETS.register_module
class MultiViewDataset(BaseDataset):
    """Dataset that returns several transformed views of each image.

    Every sample fetched from the data source is run through a flat list of
    transforms derived from ``pipelines`` and ``num_views``: the i-th
    pipeline is applied ``num_views[i]`` times, each time to a fresh copy of
    the source image. The resulting views replace ``results['img']`` as a
    list, ordered by pipeline.

    Args:
        data_source: Config handed to ``build_datasource``. Indexing the
            built data source must yield a dict whose ``'img'`` entry is a
            ``PIL.Image.Image``.
        num_views (list[int]): Number of views to produce with each
            pipeline; ``num_views[i]`` pairs with ``pipelines[i]``.
        pipelines (list[list[dict]]): One list of transform configs per
            pipeline. Each config is built from the ``PIPELINES`` registry
            and the list is wrapped in ``Compose``.

    Raises:
        AssertionError: If ``num_views`` is not a list.
        IndexError: If ``num_views`` has more entries than ``pipelines``.
    """

    def __init__(self, data_source, num_views, pipelines):
        # NOTE(review): BaseDataset.__init__ is not invoked here; this class
        # sets up its own data source and transform list. Confirm that is
        # intentional before adding a super().__init__ call.
        self.data_source = build_datasource(data_source)

        pipelines_list = [
            Compose([build_from_cfg(p, PIPELINES) for p in pipe])
            for pipe in pipelines
        ]

        assert isinstance(num_views, list), \
            f'num_views must be a list, got: {type(num_views)}.'

        # Each num_views entry needs a pipeline to pair with. Without this
        # check the mismatch surfaces as an anonymous "list index out of
        # range"; the exception type is kept as IndexError for compatibility.
        if len(num_views) > len(pipelines_list):
            raise IndexError(
                f'num_views has {len(num_views)} entries but only '
                f'{len(pipelines_list)} pipelines were given; each entry of '
                'num_views must correspond to one pipeline.')

        # Flatten to one transform per output view. The same Compose object
        # is repeated num_views[i] times rather than rebuilt.
        self.transforms_list = []
        for pipeline, count in zip(pipelines_list, num_views):
            self.transforms_list.extend([pipeline] * count)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` with ``'img'`` replaced by its views.

        Args:
            idx: Index forwarded to the underlying data source.

        Returns:
            dict: The dict produced by the data source, where ``'img'`` is
            now a list with one transformed image per entry of
            ``self.transforms_list``. Other keys are left as returned by the
            data source.

        Raises:
            AssertionError: If the data source's ``'img'`` is not a
                ``PIL.Image.Image``.
        """
        results = self.data_source[idx]
        img = results['img']
        # Adjacent string literals (not a backslash inside the literal) so
        # that source indentation never leaks into the message text.
        assert isinstance(img, Image.Image), (
            f'The output from the data source must be an Image, got: {type(img)}. '
            'Please ensure that the list file does not contain labels.')

        # Only the image is fed to the pipelines. Each transform gets its own
        # deep copy so no view can be affected by in-place edits of another.
        results['img'] = [
            trans({'img': copy.deepcopy(img)})['img']
            for trans in self.transforms_list
        ]
        return results

    def evaluate(self, results, evaluators, logger=None):
        """Evaluation is not supported by this dataset; always raises.

        Raises:
            NotImplementedError: Always. This is the error class imported at
                the top of the file from ``easycv.framework.errors``.
        """
        raise NotImplementedError