mirror of https://github.com/alibaba/EasyCV.git
25 lines
903 B
Python
25 lines
903 B
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import torch
|
|
|
|
from ..shared.multi_view import MultiViewDataset
|
|
from . import data_sources # pylint: disable=unused-import
|
|
from . import pipelines # pylint: disable=unused-import
|
|
from .base import BaseDataset
|
|
from .dataset_wrappers import ConcatDataset, RepeatDataset
|
|
from .odps_reader import OdpsReader
|
|
from .raw import RawDataset
|
|
|
|
# Names exported by `from <package> import *`; the GPU-only DALI datasets
# are appended further down when CUDA is available.
__all__ = [
    'ConcatDataset',
    'RepeatDataset',
    'OdpsReader',
    'RawDataset',
    'BaseDataset',
    'MultiViewDataset',
]
|
|
|
|
# TODO: merge `DaliImageNetTFRecordDataSet` and `DaliTFRecordMultiViewDataset`
# Only import the DALI-backed datasets when CUDA is available: importing
# DALI in a CPU-only environment results in an error.
if torch.cuda.is_available():
    from .dali_tfrecord_imagenet import DaliImageNetTFRecordDataSet
    from .dali_tfrecord_multi_view import DaliTFRecordMultiViewDataset

    # Publish the DALI datasets alongside the always-available ones.
    __all__ += ['DaliImageNetTFRecordDataSet', 'DaliTFRecordMultiViewDataset']
|