# EasyCV/easycv/datasets/shared/__init__.py
#
# 25 lines
# 903 B
# Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from ..shared.multi_view import MultiViewDataset
from . import data_sources # pylint: disable=unused-import
from . import pipelines # pylint: disable=unused-import
from .base import BaseDataset
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .odps_reader import OdpsReader
from .raw import RawDataset
# Names that make up this package's public API. The order is kept
# unchanged so that anything iterating over ``__all__`` sees the same
# sequence as before.
__all__ = [
    'ConcatDataset',
    'RepeatDataset',
    'OdpsReader',
    'RawDataset',
    'BaseDataset',
    'MultiViewDataset',
]
# TODO: merge `DaliImageNetTFRecordDataSet` and `DaliTFRecordMultiViewDataset`
# Only import the DALI-backed datasets when CUDA is available: importing
# DALI in a CPU-only environment results in an error.
if torch.cuda.is_available():
    from .dali_tfrecord_imagenet import DaliImageNetTFRecordDataSet
    from .dali_tfrecord_multi_view import DaliTFRecordMultiViewDataset

    # Publish the DALI datasets through the package's public API only when
    # they were actually imported above.
    __all__.extend(
        ['DaliImageNetTFRecordDataSet', 'DaliTFRecordMultiViewDataset'])