Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Rename dataset/parsers -> dataset/readers, create_parser to create_reader, etc
parent f67a7ee8bd
commit e9dccc918c
README.md
@@ -112,7 +112,7 @@ More models, more fixes
 * `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs.
 * Hugging Face Hub support fixes verified, demo notebook TBA
 * Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
-* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
+* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
 * Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
   * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
   * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
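The extension hooks named in that changelog entry are the ones this commit re-exports from `timm.data` (see the `__init__.py` hunk below). A minimal usage sketch; the default extension set and return types shown in the comments are assumptions:

```python
# Sketch only: function names match the `from .readers import ...` exports in
# this commit; the default extension tuple shown is an assumption.
from timm.data import (
    add_img_extensions,
    del_img_extensions,
    get_img_extensions,
    is_img_extension,
)

print(get_img_extensions())      # e.g. ('.png', '.jpg', '.jpeg')
add_img_extensions('.webp')      # also scan .webp when building folder datasets
del_img_extensions('.png')       # stop scanning .png
assert is_img_extension('.webp')
```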
timm/data/__init__.py
@@ -6,8 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
-from .parsers import create_parser,\
-    get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
+from .readers import create_reader
+from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
 from .real_labels import RealLabelsImagenet
 from .transforms import *
 from .transforms_factory import create_transform
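For downstream code, the rename is a one-line import change; a minimal before/after sketch (the data root path is illustrative):

```python
# Before this commit:
from timm.data import create_parser
parser = create_parser('', root='/data/imagenet', split='validation')

# After this commit (same arguments, same behaviour):
from timm.data import create_reader
reader = create_reader('', root='/data/imagenet', split='validation')
```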
timm/data/dataset.py
@@ -10,7 +10,7 @@ import torch
 import torch.utils.data as data
 from PIL import Image

-from .parsers import create_parser
+from .readers import create_reader

 _logger = logging.getLogger(__name__)

@@ -23,7 +23,7 @@ class ImageDataset(data.Dataset):
     def __init__(
             self,
             root,
-            parser=None,
+            reader=None,
             split='train',
             class_map=None,
             load_bytes=False,
@@ -31,14 +31,14 @@ class ImageDataset(data.Dataset):
             transform=None,
             target_transform=None,
     ):
-        if parser is None or isinstance(parser, str):
-            parser = create_parser(
-                parser or '',
+        if reader is None or isinstance(reader, str):
+            reader = create_reader(
+                reader or '',
                 root=root,
                 split=split,
                 class_map=class_map
             )
-        self.parser = parser
+        self.reader = reader
         self.load_bytes = load_bytes
         self.img_mode = img_mode
         self.transform = transform
@@ -46,15 +46,15 @@ class ImageDataset(data.Dataset):
         self._consecutive_errors = 0

     def __getitem__(self, index):
-        img, target = self.parser[index]
+        img, target = self.reader[index]

         try:
             img = img.read() if self.load_bytes else Image.open(img)
         except Exception as e:
-            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
+            _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
             self._consecutive_errors += 1
             if self._consecutive_errors < _ERROR_RETRY:
-                return self.__getitem__((index + 1) % len(self.parser))
+                return self.__getitem__((index + 1) % len(self.reader))
             else:
                 raise e
         self._consecutive_errors = 0
@@ -72,13 +72,13 @@ class ImageDataset(data.Dataset):
         return img, target

     def __len__(self):
-        return len(self.parser)
+        return len(self.reader)

     def filename(self, index, basename=False, absolute=False):
-        return self.parser.filename(index, basename, absolute)
+        return self.reader.filename(index, basename, absolute)

     def filenames(self, basename=False, absolute=False):
-        return self.parser.filenames(basename, absolute)
+        return self.reader.filenames(basename, absolute)


 class IterableImageDataset(data.IterableDataset):
@@ -86,7 +86,7 @@ class IterableImageDataset(data.IterableDataset):
     def __init__(
             self,
             root,
-            parser=None,
+            reader=None,
             split='train',
             is_training=False,
             batch_size=None,
@@ -96,10 +96,10 @@ class IterableImageDataset(data.IterableDataset):
             transform=None,
             target_transform=None,
     ):
-        assert parser is not None
-        if isinstance(parser, str):
-            self.parser = create_parser(
-                parser,
+        assert reader is not None
+        if isinstance(reader, str):
+            self.reader = create_reader(
+                reader,
                 root=root,
                 split=split,
                 is_training=is_training,
@@ -109,13 +109,13 @@ class IterableImageDataset(data.IterableDataset):
                 download=download,
             )
         else:
-            self.parser = parser
+            self.reader = reader
         self.transform = transform
         self.target_transform = target_transform
         self._consecutive_errors = 0

     def __iter__(self):
-        for img, target in self.parser:
+        for img, target in self.reader:
             if self.transform is not None:
                 img = self.transform(img)
             if self.target_transform is not None:
@@ -123,29 +123,29 @@ class IterableImageDataset(data.IterableDataset):
             yield img, target

     def __len__(self):
-        if hasattr(self.parser, '__len__'):
-            return len(self.parser)
+        if hasattr(self.reader, '__len__'):
+            return len(self.reader)
         else:
             return 0

     def set_epoch(self, count):
         # TFDS and WDS need external epoch count for deterministic cross process shuffle
-        if hasattr(self.parser, 'set_epoch'):
-            self.parser.set_epoch(count)
+        if hasattr(self.reader, 'set_epoch'):
+            self.reader.set_epoch(count)

     def set_loader_cfg(
             self,
             num_workers: Optional[int] = None,
     ):
         # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
-        if hasattr(self.parser, 'set_loader_cfg'):
-            self.parser.set_loader_cfg(num_workers=num_workers)
+        if hasattr(self.reader, 'set_loader_cfg'):
+            self.reader.set_loader_cfg(num_workers=num_workers)

     def filename(self, index, basename=False, absolute=False):
         assert False, 'Filename lookup by index not supported, use filenames().'

     def filenames(self, basename=False, absolute=False):
-        return self.parser.filenames(basename, absolute)
+        return self.reader.filenames(basename, absolute)


 class AugMixDataset(torch.utils.data.Dataset):
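`set_epoch` above is how a training loop keeps TFDS/WDS shuffling deterministic across worker processes; a hedged sketch of the expected call pattern (the dataset name, root path, and loop details are illustrative, not from this commit):

```python
from timm.data import create_dataset

ds = create_dataset(
    'tfds/cifar10',        # illustrative TFDS dataset name
    root='/data/tfds',     # illustrative data root
    split='train',
    is_training=True,
    batch_size=128,
)
for epoch in range(10):
    ds.set_epoch(epoch)    # push the epoch count into the reader before iterating
    for img, target in ds:
        pass               # training step goes here
```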
timm/data/dataset_factory.py
@@ -137,11 +137,11 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, parser=name, split=split, **kwargs)
+        ds = ImageDataset(root, reader=name, split=split, **kwargs)
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
-            parser=name,
+            reader=name,
             split=split,
             is_training=is_training,
             download=download,
@@ -153,7 +153,7 @@ def create_dataset(
     elif name.startswith('wds/'):
         ds = IterableImageDataset(
             root,
-            parser=name,
+            reader=name,
             split=split,
             is_training=is_training,
             batch_size=batch_size,
@@ -166,5 +166,5 @@ def create_dataset(
     if search_split and os.path.isdir(root):
         # look for split specific sub-folder in root
         root = _search_split(root, split)
-    ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
+    ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
     return ds
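Callers of `create_dataset` only see the rename if they passed `parser=` explicitly; typical usage is unchanged. A hedged sketch (dataset names and paths are illustrative):

```python
from timm.data import create_dataset

# Folder-based dataset; the reader is inferred from the root contents.
train_ds = create_dataset('', root='/data/imagenet', split='train')

# Hugging Face datasets backend selected via the 'hfds/' prefix.
val_ds = create_dataset('hfds/beans', root='/data/hf_cache', split='validation')
```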
timm/data/parsers/__init__.py (deleted)
@@ -1,2 +0,0 @@
-from .parser_factory import create_parser
-from .img_extensions import *
timm/data/parsers/parser_factory.py (deleted)
@@ -1,35 +0,0 @@
-import os
-
-from .parser_image_folder import ParserImageFolder
-from .parser_image_in_tar import ParserImageInTar
-
-
-def create_parser(name, root, split='train', **kwargs):
-    name = name.lower()
-    name = name.split('/', 2)
-    prefix = ''
-    if len(name) > 1:
-        prefix = name[0]
-    name = name[-1]
-
-    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
-    # explicitly select other options shortly
-    if prefix == 'hfds':
-        from .parser_hfds import ParserHfds  # defer tensorflow import
-        parser = ParserHfds(root, name, split=split, **kwargs)
-    elif prefix == 'tfds':
-        from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, **kwargs)
-    elif prefix == 'wds':
-        from .parser_wds import ParserWds
-        kwargs.pop('download', False)
-        parser = ParserWds(root, name, split=split, **kwargs)
-    else:
-        assert os.path.exists(root)
-        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
-        # FIXME support split here, in parser?
-        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
-            parser = ParserImageInTar(root, **kwargs)
-        else:
-            parser = ParserImageFolder(root, **kwargs)
-    return parser
timm/data/readers/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+from .reader_factory import create_reader
+from .img_extensions import *
timm/data/parsers/parser.py → timm/data/readers/reader.py
@@ -1,7 +1,7 @@
 from abc import abstractmethod


-class Parser:
+class Reader:
     def __init__(self):
         pass

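Only the head of the renamed base class is shown here. Judging from how `ImageDataset` uses its reader elsewhere in this diff (`reader[index]`, `len(reader)`, `reader.filename(...)`, `reader.filenames(...)`), a minimal custom reader might look like the sketch below; `ReaderFlatFolder` is hypothetical and not part of this commit:

```python
import os

from timm.data.readers.reader import Reader  # module path per this commit's layout


class ReaderFlatFolder(Reader):
    """Hypothetical reader: all images in one folder, every sample labelled 0."""

    def __init__(self, root):
        super().__init__()
        self.samples = sorted(
            os.path.join(root, f)
            for f in os.listdir(root)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

    def __getitem__(self, index):
        # ImageDataset opens/decodes the sample itself, so return a file handle.
        return open(self.samples[index], 'rb'), 0

    def __len__(self):
        return len(self.samples)

    def filename(self, index, basename=False, absolute=False):
        path = self.samples[index]
        return os.path.basename(path) if basename else path

    def filenames(self, basename=False, absolute=False):
        return [self.filename(i, basename) for i in range(len(self))]
```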
timm/data/readers/reader_factory.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+import os
+
+from .reader_image_folder import ReaderImageFolder
+from .reader_image_in_tar import ReaderImageInTar
+
+
+def create_reader(name, root, split='train', **kwargs):
+    name = name.lower()
+    name = name.split('/', 2)
+    prefix = ''
+    if len(name) > 1:
+        prefix = name[0]
+    name = name[-1]
+
+    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
+    # explicitly select other options shortly
+    if prefix == 'hfds':
+        from .reader_hfds import ReaderHfds  # defer tensorflow import
+        reader = ReaderHfds(root, name, split=split, **kwargs)
+    elif prefix == 'tfds':
+        from .reader_tfds import ReaderTfds  # defer tensorflow import
+        reader = ReaderTfds(root, name, split=split, **kwargs)
+    elif prefix == 'wds':
+        from .reader_wds import ReaderWds
+        kwargs.pop('download', False)
+        reader = ReaderWds(root, name, split=split, **kwargs)
+    else:
+        assert os.path.exists(root)
+        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
+        # FIXME support split here or in reader?
+        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
+            reader = ReaderImageInTar(root, **kwargs)
+        else:
+            reader = ReaderImageFolder(root, **kwargs)
+    return reader
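The factory dispatches on the `name` prefix and otherwise falls back to a tar or folder reader based on `root`; a hedged usage sketch (dataset names and paths are illustrative):

```python
from timm.data import create_reader

r1 = create_reader('tfds/cifar10', root='/data/tfds', split='train')    # -> ReaderTfds
r2 = create_reader('hfds/beans', root='/data/hf_cache', split='train')  # -> ReaderHfds
r3 = create_reader('', root='/data/imagenet/train')                     # -> ReaderImageFolder
r4 = create_reader('', root='/data/imagenet_train.tar')                 # -> ReaderImageInTar
```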
timm/data/parsers/parser_hfds.py → timm/data/readers/reader_hfds.py
@@ -1,4 +1,5 @@
-""" Dataset parser interface that wraps Hugging Face datasets
+""" Dataset reader that wraps Hugging Face datasets
+
 Hacked together by / Copyright 2022 Ross Wightman
 """
 import io
@@ -12,7 +13,7 @@ try:
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
     exit(1)
-from .parser import Parser
+from .reader import Reader


 def get_class_labels(info):
@@ -23,7 +24,7 @@ def get_class_labels(info):
     return class_to_idx


-class ParserHfds(Parser):
+class ReaderHfds(Reader):

     def __init__(
             self,
timm/data/parsers/parser_image_folder.py → timm/data/readers/reader_image_folder.py
@@ -1,6 +1,6 @@
-""" A dataset parser that reads images from folders
+""" A dataset reader that extracts images from folders

-Folders are scannerd recursively to find image files. Labels are based
+Folders are scanned recursively to find image files. Labels are based
 on the folder hierarchy, just leaf folders by default.

 Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key

 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader


 def find_images_and_targets(
@@ -56,7 +56,7 @@ def find_images_and_targets(
     return images_and_targets, class_to_idx


-class ParserImageFolder(Parser):
+class ReaderImageFolder(Reader):

     def __init__(
             self,
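The docstring above implies the expected on-disk layout; a hedged illustration (the paths and class names are made up):

```python
# Assumed layout for ReaderImageFolder; labels come from leaf folder names:
#   /data/flowers/train/daisy/001.jpg -> class 'daisy'
#   /data/flowers/train/rose/002.jpg  -> class 'rose'
from timm.data import ImageDataset

ds = ImageDataset('/data/flowers/train')  # reader inferred from the folder root
img, target = ds[0]
```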
timm/data/parsers/parser_image_in_tar.py → timm/data/readers/reader_image_in_tar.py
@@ -1,6 +1,6 @@
-""" A dataset parser that reads tarfile based datasets
+""" A dataset reader that reads tarfile based datasets

-This parser can read and extract image samples from:
+This reader can extract image samples from:
 * a single tar of image files
 * a folder of multiple tarfiles containing imagefiles
 * a tar of tars containing image files
@@ -22,7 +22,7 @@ from timm.utils.misc import natural_key

 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader

 _logger = logging.getLogger(__name__)
 CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -169,8 +169,8 @@ def extract_tarinfos(
     return samples, targets, class_name_to_idx, tarfiles


-class ParserImageInTar(Parser):
-    """ Multi-tarfile dataset parser where there is one .tar file per class
+class ReaderImageInTar(Reader):
+    """ Multi-tarfile dataset reader where there is one .tar file per class
     """

     def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
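Because `create_reader` falls back to `ReaderImageInTar` for a `.tar` root, tar-based datasets need no explicit reader argument; a hedged sketch (the tar path is illustrative):

```python
from timm.data import ImageDataset

# Per the docstring above: a single tar of images, a folder of tars, or a tar of tars.
ds = ImageDataset('/data/imagenet_train.tar')  # reader inferred from the .tar root
img, target = ds[0]
```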
timm/data/parsers/parser_image_tar.py → timm/data/readers/reader_image_tar.py
@@ -1,6 +1,6 @@
-""" A dataset parser that reads single tarfile based datasets
+""" A dataset reader that reads single tarfile based datasets

-This parser can read datasets consisting if a single tarfile containing images.
+This reader can read datasets consisting if a single tarfile containing images.
 I am planning to deprecated it in favour of ParerImageInTar.

 Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key

 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader


 def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
@@ -38,9 +38,9 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
     return tarinfo_and_targets, class_to_idx


-class ParserImageTar(Parser):
+class ReaderImageTar(Reader):
     """ Single tarfile dataset where classes are mapped to folders within tar
-    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
+    NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
     operate on folders of tars or tars in tars.
     """
     def __init__(self, root, class_map=''):
timm/data/parsers/parser_tfds.py → timm/data/readers/reader_tfds.py
@@ -1,4 +1,4 @@
-""" Dataset parser interface that wraps TFDS datasets
+""" Dataset reader that wraps TFDS datasets

 Wraps many (most?) TFDS image-classification datasets
 from https://github.com/tensorflow/datasets
@@ -34,7 +34,7 @@ except ImportError as e:
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
     exit(1)

-from .parser import Parser
+from .reader import Reader
 from .shared_count import SharedCount


@@ -56,7 +56,7 @@ def get_class_labels(info):
     return class_to_idx


-class ParserTfds(Parser):
+class ReaderTfds(Reader):
     """ Wrap Tensorflow Datasets for use in PyTorch

     There several things to be aware of:
timm/data/parsers/parser_wds.py → timm/data/readers/reader_wds.py
@@ -1,4 +1,4 @@
-""" Dataset parser interface for webdataset
+""" Dataset reader for webdataset

 Hacked together by / Copyright 2022 Ross Wightman
 """
@@ -29,7 +29,7 @@ except ImportError:
     wds = None
     expand_urls = None

-from .parser import Parser
+from .reader import Reader
 from .shared_count import SharedCount

 _logger = logging.getLogger(__name__)
@@ -280,7 +280,7 @@ class ResampledShards2(IterableDataset):
             yield dict(url=self.urls[index])


-class ParserWds(Parser):
+class ReaderWds(Reader):
     def __init__(
             self,
             root,