2022-04-02 20:01:06 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
2022-04-22 15:22:43 +08:00
|
|
|
from easycv.utils.registry import build_from_cfg
|
2022-04-02 20:01:06 +08:00
|
|
|
from ..builder import build_datasource
|
|
|
|
from ..registry import PIPELINES
|
|
|
|
from .pipelines.transforms import Compose
|
|
|
|
|
|
|
|
|
|
|
|
class BaseDataset(Dataset, metaclass=ABCMeta):
    """Abstract base class for datasets built from a data source + pipeline.

    The data source is built from its config with ``build_datasource`` and
    each pipeline step is built from the ``PIPELINES`` registry; the steps
    are chained with ``Compose``. Subclasses must implement ``__getitem__``
    and ``evaluate``.

    Args:
        data_source: Config passed to ``build_datasource``.
        pipeline: Sequence of transform specs. Each item is either a config
            (built with ``build_from_cfg(item, PIPELINES)``) or an already
            instantiated callable transform, which is used as-is.
        profiling (bool): Forwarded to ``Compose`` as its ``profiling`` flag.
    """

    def __init__(self, data_source, pipeline, profiling=False):
        self.data_source = build_datasource(data_source)
        # Config dicts are not callable, so they still go through the
        # registry exactly as before; pre-built transforms are passed through
        # instead of being handed to build_from_cfg (which expects a config).
        transforms = [
            p if callable(p) else build_from_cfg(p, PIPELINES)
            for p in pipeline
        ]
        self.pipeline = Compose(transforms, profiling=profiling)

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.data_source)``."""
        return len(self.data_source)

    @abstractmethod
    def __getitem__(self, idx):
        """Fetch the item at index ``idx``. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def evaluate(self, results, evaluators, logger=None, **kwargs):
        """Evaluate ``results`` using ``evaluators``.

        Must be implemented by subclasses.

        Args:
            results: Model outputs to evaluate; the expected format is
                defined by the concrete subclass.
            evaluators: Evaluator(s) to apply — presumably built evaluator
                objects; confirm against callers.
            logger: Optional logger, ``None`` by default.
            **kwargs: Extra subclass-specific options.
        """
        pass

    def visualize(self, results, **kwargs):
        """Visualize the model output results on validation data.

        The base implementation performs no visualization and returns an
        empty dict; subclasses may override it.

        Returns: A dictionary
            If add image visualization, return dict containing
                images: List of visualized images.
                img_metas: List of length number of test images,
                    dict of image meta info, containing filename, img_shape,
                    origin_img_shape, scale_factor and so on.
        """
        return {}
|