# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.utils.torchacc_util import is_torchacc_enabled

if is_torchacc_enabled():
    import torchacc.torch_xla.distributed.parallel_loader as pl

    class TorchaccLoaderWrapper(pl.MpDeviceLoader):
        """``MpDeviceLoader`` subclass with a default device and a ``sampler`` view.

        The class is only defined when ``is_torchacc_enabled()`` returns a
        truthy value at import time; otherwise the module-level name is bound
        to ``None``.
        """

        def __init__(self, loader, device=None, **kwargs) -> None:
            """Initialise the parent loader, resolving the device if omitted.

            Args:
                loader: The loader to wrap; passed to the parent constructor
                    as ``loader``.
                device: Target device passed to the parent constructor. When
                    ``None``, the value returned by ``xla_model.xla_device()``
                    is used instead.
                **kwargs: Extra keyword arguments forwarded unchanged to the
                    parent constructor.
            """
            target_device = device
            if target_device is None:
                # Imported only when a default is actually needed.
                import torchacc.torch_xla.core.xla_model as xla_model
                target_device = xla_model.xla_device()

            super().__init__(loader=loader, device=target_device, **kwargs)

        @property
        def sampler(self):
            """Return the ``sampler`` attribute of the wrapped loader."""
            # NOTE(review): relies on the parent class keeping the wrapped
            # loader in ``self._loader`` -- confirm against MpDeviceLoader.
            wrapped_loader = self._loader
            return wrapped_loader.sampler
else:
    # Keep the name importable when torchacc is disabled; callers receive None
    # instead of a class in that case.
    TorchaccLoaderWrapper = None