Rename dataset/parsers -> dataset/readers, create_parser to create_reader, etc
commit e9dccc918c (parent f67a7ee8bd)
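For user code, the change is a one-for-one rename of the factory function and keyword arguments. A minimal before/after sketch based on the diffs below (the paths are illustrative):

```python
from timm.data import ImageDataset, create_reader

# before: create_parser('', root='/data/imagenet', split='validation')
reader = create_reader('', root='/data/imagenet', split='validation')

# before: ImageDataset('/data/imagenet', parser=reader)
ds = ImageDataset('/data/imagenet', reader=reader)
```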
README.md
@@ -112,7 +112,7 @@ More models, more fixes
 * `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs.
 * Hugging Face Hub support fixes verified, demo notebook TBA
 * Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
-* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
+* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
 * Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
   * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
   * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
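The ConvNeXt LayerNorm bullet above describes a permute-based impl; a minimal sketch of the idea (not necessarily the exact `timm` class, which lives in `models/layers/norm.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for NCHW tensors: permute to NHWC, apply F.layer_norm, permute back. """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.permute(0, 3, 1, 2)
```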
timm/data/__init__.py
@@ -6,8 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
-from .parsers import create_parser,\
-    get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
+from .readers import create_reader
+from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
 from .real_labels import RealLabelsImagenet
 from .transforms import *
 from .transforms_factory import create_transform
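The extension helpers re-exported here make the set of scanned image extensions adjustable, per the README bullet above. A usage sketch (the `.webp` addition and the default tuple shown in the comment are illustrative assumptions):

```python
from timm.data import add_img_extensions, del_img_extensions, get_img_extensions

print(get_img_extensions())  # current extensions, e.g. ('.png', '.jpg', '.jpeg')
add_img_extensions('.webp')  # also pick up .webp files when scanning folders
del_img_extensions('.png')   # stop scanning for .png
```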
timm/data/dataset.py
@@ -10,7 +10,7 @@ import torch
 import torch.utils.data as data
 from PIL import Image
 
-from .parsers import create_parser
+from .readers import create_reader
 
 _logger = logging.getLogger(__name__)
 
@@ -23,7 +23,7 @@ class ImageDataset(data.Dataset):
     def __init__(
             self,
             root,
-            parser=None,
+            reader=None,
             split='train',
             class_map=None,
             load_bytes=False,
@@ -31,14 +31,14 @@ class ImageDataset(data.Dataset):
             transform=None,
             target_transform=None,
     ):
-        if parser is None or isinstance(parser, str):
-            parser = create_parser(
-                parser or '',
+        if reader is None or isinstance(reader, str):
+            reader = create_reader(
+                reader or '',
                 root=root,
                 split=split,
                 class_map=class_map
             )
-        self.parser = parser
+        self.reader = reader
         self.load_bytes = load_bytes
         self.img_mode = img_mode
         self.transform = transform
@@ -46,15 +46,15 @@ class ImageDataset(data.Dataset):
         self._consecutive_errors = 0
 
     def __getitem__(self, index):
-        img, target = self.parser[index]
+        img, target = self.reader[index]
 
         try:
             img = img.read() if self.load_bytes else Image.open(img)
         except Exception as e:
-            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
+            _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
             self._consecutive_errors += 1
             if self._consecutive_errors < _ERROR_RETRY:
-                return self.__getitem__((index + 1) % len(self.parser))
+                return self.__getitem__((index + 1) % len(self.reader))
             else:
                 raise e
         self._consecutive_errors = 0
@@ -72,13 +72,13 @@ class ImageDataset(data.Dataset):
         return img, target
 
     def __len__(self):
-        return len(self.parser)
+        return len(self.reader)
 
     def filename(self, index, basename=False, absolute=False):
-        return self.parser.filename(index, basename, absolute)
+        return self.reader.filename(index, basename, absolute)
 
     def filenames(self, basename=False, absolute=False):
-        return self.parser.filenames(basename, absolute)
+        return self.reader.filenames(basename, absolute)
 
 
 class IterableImageDataset(data.IterableDataset):
@@ -86,7 +86,7 @@ class IterableImageDataset(data.IterableDataset):
     def __init__(
             self,
             root,
-            parser=None,
+            reader=None,
             split='train',
             is_training=False,
             batch_size=None,
@@ -96,10 +96,10 @@ class IterableImageDataset(data.IterableDataset):
             transform=None,
             target_transform=None,
     ):
-        assert parser is not None
-        if isinstance(parser, str):
-            self.parser = create_parser(
-                parser,
+        assert reader is not None
+        if isinstance(reader, str):
+            self.reader = create_reader(
+                reader,
                 root=root,
                 split=split,
                 is_training=is_training,
@@ -109,13 +109,13 @@ class IterableImageDataset(data.IterableDataset):
                 download=download,
             )
         else:
-            self.parser = parser
+            self.reader = reader
         self.transform = transform
         self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __iter__(self):
-        for img, target in self.parser:
+        for img, target in self.reader:
             if self.transform is not None:
                 img = self.transform(img)
             if self.target_transform is not None:
@@ -123,29 +123,29 @@ class IterableImageDataset(data.IterableDataset):
             yield img, target
 
     def __len__(self):
-        if hasattr(self.parser, '__len__'):
-            return len(self.parser)
+        if hasattr(self.reader, '__len__'):
+            return len(self.reader)
         else:
             return 0
 
     def set_epoch(self, count):
         # TFDS and WDS need external epoch count for deterministic cross process shuffle
-        if hasattr(self.parser, 'set_epoch'):
-            self.parser.set_epoch(count)
+        if hasattr(self.reader, 'set_epoch'):
+            self.reader.set_epoch(count)
 
     def set_loader_cfg(
             self,
             num_workers: Optional[int] = None,
     ):
         # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
-        if hasattr(self.parser, 'set_loader_cfg'):
-            self.parser.set_loader_cfg(num_workers=num_workers)
+        if hasattr(self.reader, 'set_loader_cfg'):
+            self.reader.set_loader_cfg(num_workers=num_workers)
 
     def filename(self, index, basename=False, absolute=False):
         assert False, 'Filename lookup by index not supported, use filenames().'
 
     def filenames(self, basename=False, absolute=False):
-        return self.parser.filenames(basename, absolute)
+        return self.reader.filenames(basename, absolute)
 
 
 class AugMixDataset(torch.utils.data.Dataset):
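Both dataset classes now take a reader name or instance via `reader=`. A hedged sketch of direct construction (dataset names and paths are illustrative; the `tfds/` case needs `tensorflow-datasets` installed):

```python
from timm.data.dataset import ImageDataset, IterableImageDataset

# empty string -> create_reader folder/tar fallback on root
ds = ImageDataset('/data/flowers', reader='', split='train')

# streaming TFDS-backed dataset, reader resolved from the 'tfds/' prefix
ids = IterableImageDataset(
    '/data/tfds',
    reader='tfds/oxford_iiit_pet',
    split='train',
    is_training=True,
    batch_size=64,
)
```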
timm/data/dataset_factory.py
@@ -137,11 +137,11 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, parser=name, split=split, **kwargs)
+        ds = ImageDataset(root, reader=name, split=split, **kwargs)
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
-            parser=name,
+            reader=name,
             split=split,
             is_training=is_training,
             download=download,
@@ -153,7 +153,7 @@ def create_dataset(
     elif name.startswith('wds/'):
         ds = IterableImageDataset(
             root,
-            parser=name,
+            reader=name,
             split=split,
             is_training=is_training,
             batch_size=batch_size,
@@ -166,5 +166,5 @@ def create_dataset(
     if search_split and os.path.isdir(root):
         # look for split specific sub-folder in root
         root = _search_split(root, split)
-    ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
+    ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
     return ds
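`create_dataset` keeps routing on the name prefix, now forwarding it as `reader=`. A sketch of the four paths (dataset names are illustrative; `hfds/`, `tfds/`, `wds/` need their respective packages):

```python
from timm.data import create_dataset

ds = create_dataset('', root='/data/imagenet', split='validation')  # folder/tar fallback
ds = create_dataset('hfds/beans', root='', split='train')           # Hugging Face datasets
ds = create_dataset('tfds/oxford_iiit_pet', root='/data/tfds', split='train',
                    is_training=True, batch_size=64)                # TensorFlow datasets
ds = create_dataset('wds/my-shards', root='/data/shards', split='train',
                    is_training=True, batch_size=64)                # webdataset
```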
timm/data/parsers/__init__.py (deleted)
@@ -1,2 +0,0 @@
-from .parser_factory import create_parser
-from .img_extensions import *
timm/data/parsers/parser_factory.py (deleted)
@@ -1,35 +0,0 @@
-import os
-
-from .parser_image_folder import ParserImageFolder
-from .parser_image_in_tar import ParserImageInTar
-
-
-def create_parser(name, root, split='train', **kwargs):
-    name = name.lower()
-    name = name.split('/', 2)
-    prefix = ''
-    if len(name) > 1:
-        prefix = name[0]
-    name = name[-1]
-
-    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
-    # explicitly select other options shortly
-    if prefix == 'hfds':
-        from .parser_hfds import ParserHfds  # defer tensorflow import
-        parser = ParserHfds(root, name, split=split, **kwargs)
-    elif prefix == 'tfds':
-        from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, **kwargs)
-    elif prefix == 'wds':
-        from .parser_wds import ParserWds
-        kwargs.pop('download', False)
-        parser = ParserWds(root, name, split=split, **kwargs)
-    else:
-        assert os.path.exists(root)
-        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
-        # FIXME support split here, in parser?
-        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
-            parser = ParserImageInTar(root, **kwargs)
-        else:
-            parser = ParserImageFolder(root, **kwargs)
-    return parser
timm/data/readers/__init__.py (new file, +2)
@@ -0,0 +1,2 @@
+from .reader_factory import create_reader
+from .img_extensions import *
timm/data/readers/reader.py (renamed from timm/data/parsers/parser.py)
@@ -1,7 +1,7 @@
 from abc import abstractmethod
 
 
-class Parser:
+class Reader:
     def __init__(self):
         pass
 
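The base class is abstract; judging from how `dataset.py` above uses it, a concrete reader needs indexing, length, and filename lookup. A hypothetical subclass sketch (`ListReader` and the `_filename` hook are assumptions for illustration, not part of this commit):

```python
import os
from timm.data.readers.reader import Reader

class ListReader(Reader):
    """ Hypothetical reader over an explicit (filepath, target) list. """
    def __init__(self, samples):
        super().__init__()
        self.samples = samples  # list of (filepath, target) tuples

    def __getitem__(self, index):
        path, target = self.samples[index]
        return open(path, 'rb'), target  # ImageDataset expects a file(-like) it can open/read

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        path = self.samples[index][0]
        if basename:
            return os.path.basename(path)
        return os.path.abspath(path) if absolute else path
```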
timm/data/readers/reader_factory.py (new file, +35)
@@ -0,0 +1,35 @@
+import os
+
+from .reader_image_folder import ReaderImageFolder
+from .reader_image_in_tar import ReaderImageInTar
+
+
+def create_reader(name, root, split='train', **kwargs):
+    name = name.lower()
+    name = name.split('/', 2)
+    prefix = ''
+    if len(name) > 1:
+        prefix = name[0]
+    name = name[-1]
+
+    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
+    # explicitly select other options shortly
+    if prefix == 'hfds':
+        from .reader_hfds import ReaderHfds  # defer tensorflow import
+        reader = ReaderHfds(root, name, split=split, **kwargs)
+    elif prefix == 'tfds':
+        from .reader_tfds import ReaderTfds  # defer tensorflow import
+        reader = ReaderTfds(root, name, split=split, **kwargs)
+    elif prefix == 'wds':
+        from .reader_wds import ReaderWds
+        kwargs.pop('download', False)
+        reader = ReaderWds(root, name, split=split, **kwargs)
+    else:
+        assert os.path.exists(root)
+        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
+        # FIXME support split here or in reader?
+        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
+            reader = ReaderImageInTar(root, **kwargs)
+        else:
+            reader = ReaderImageFolder(root, **kwargs)
+    return reader
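Usage of the new factory mirrors the old one; a sketch of the prefix routing (names and paths are illustrative):

```python
from timm.data import create_reader

r = create_reader('', root='/data/imagenet/train')  # folder of images
r = create_reader('', root='/data/imagenet.tar')    # single .tar of images
r = create_reader('tfds/cifar10', root='/data/tfds', split='train')
r = create_reader('hfds/beans', root='', split='train')
```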
timm/data/readers/reader_hfds.py (renamed from timm/data/parsers/parser_hfds.py)
@@ -1,4 +1,5 @@
-""" Dataset parser interface that wraps Hugging Face datasets
+""" Dataset reader that wraps Hugging Face datasets
+
 Hacked together by / Copyright 2022 Ross Wightman
 """
 import io
@@ -12,7 +13,7 @@ try:
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
     exit(1)
-from .parser import Parser
+from .reader import Reader
 
 
 def get_class_labels(info):
@@ -23,7 +24,7 @@ def get_class_labels(info):
     return class_to_idx
 
 
-class ParserHfds(Parser):
+class ReaderHfds(Reader):
 
     def __init__(
             self,
timm/data/readers/reader_image_folder.py (renamed from timm/data/parsers/parser_image_folder.py)
@@ -1,6 +1,6 @@
-""" A dataset parser that reads images from folders
+""" A dataset reader that extracts images from folders
 
-Folders are scannerd recursively to find image files. Labels are based
+Folders are scanned recursively to find image files. Labels are based
 on the folder hierarchy, just leaf folders by default.
 
 Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key
 
 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader
 
 
 def find_images_and_targets(
@@ -56,7 +56,7 @@ def find_images_and_targets(
     return images_and_targets, class_to_idx
 
 
-class ParserImageFolder(Parser):
+class ReaderImageFolder(Reader):
 
     def __init__(
             self,
timm/data/readers/reader_image_in_tar.py (renamed from timm/data/parsers/parser_image_in_tar.py)
@@ -1,6 +1,6 @@
-""" A dataset parser that reads tarfile based datasets
+""" A dataset reader that reads tarfile based datasets
 
-This parser can read and extract image samples from:
+This reader can extract image samples from:
 * a single tar of image files
 * a folder of multiple tarfiles containing imagefiles
 * a tar of tars containing image files
@@ -22,7 +22,7 @@ from timm.utils.misc import natural_key
 
 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader
 
 _logger = logging.getLogger(__name__)
 CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -169,8 +169,8 @@ def extract_tarinfos(
     return samples, targets, class_name_to_idx, tarfiles
 
 
-class ParserImageInTar(Parser):
-    """ Multi-tarfile dataset parser where there is one .tar file per class
+class ReaderImageInTar(Reader):
+    """ Multi-tarfile dataset reader where there is one .tar file per class
     """
 
     def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
timm/data/readers/reader_image_tar.py (renamed from timm/data/parsers/parser_image_tar.py)
@@ -1,6 +1,6 @@
-""" A dataset parser that reads single tarfile based datasets
+""" A dataset reader that reads single tarfile based datasets
 
-This parser can read datasets consisting if a single tarfile containing images.
+This reader can read datasets consisting if a single tarfile containing images.
 I am planning to deprecated it in favour of ParerImageInTar.
 
 Hacked together by / Copyright 2020 Ross Wightman
@@ -12,7 +12,7 @@ from timm.utils.misc import natural_key
 
 from .class_map import load_class_map
 from .img_extensions import get_img_extensions
-from .parser import Parser
+from .reader import Reader
 
 
 def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
@@ -38,9 +38,9 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
     return tarinfo_and_targets, class_to_idx
 
 
-class ParserImageTar(Parser):
+class ReaderImageTar(Reader):
     """ Single tarfile dataset where classes are mapped to folders within tar
-    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
+    NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
     operate on folders of tars or tars in tars.
     """
     def __init__(self, root, class_map=''):
timm/data/readers/reader_tfds.py (renamed from timm/data/parsers/parser_tfds.py)
@@ -1,4 +1,4 @@
-""" Dataset parser interface that wraps TFDS datasets
+""" Dataset reader that wraps TFDS datasets
 
 Wraps many (most?) TFDS image-classification datasets
 from https://github.com/tensorflow/datasets
@@ -34,7 +34,7 @@ except ImportError as e:
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
     exit(1)
 
-from .parser import Parser
+from .reader import Reader
 from .shared_count import SharedCount
 
 
@@ -56,7 +56,7 @@ def get_class_labels(info):
     return class_to_idx
 
 
-class ParserTfds(Parser):
+class ReaderTfds(Reader):
     """ Wrap Tensorflow Datasets for use in PyTorch
 
     There several things to be aware of:
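As the `set_epoch` comment in `dataset.py` notes, TFDS and WDS readers need the epoch count pushed in externally for deterministic cross-process shuffle. A hedged training-loop sketch (dataset name, sizes, and loader args are illustrative):

```python
from timm.data import create_dataset, create_loader

dataset = create_dataset('tfds/cifar10', root='/data/tfds', split='train',
                         is_training=True, batch_size=128)
loader = create_loader(dataset, input_size=(3, 32, 32), batch_size=128, is_training=True)

for epoch in range(10):
    dataset.set_epoch(epoch)  # forwarded to the reader, seeds shuffle across workers
    for images, targets in loader:
        pass  # train step here
```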
timm/data/readers/reader_wds.py (renamed from timm/data/parsers/parser_wds.py)
@@ -1,4 +1,4 @@
-""" Dataset parser interface for webdataset
+""" Dataset reader for webdataset
 
 Hacked together by / Copyright 2022 Ross Wightman
 """
@@ -29,7 +29,7 @@ except ImportError:
 wds = None
 expand_urls = None
 
-from .parser import Parser
+from .reader import Reader
 from .shared_count import SharedCount
 
 _logger = logging.getLogger(__name__)
@@ -280,7 +280,7 @@ class ResampledShards2(IterableDataset):
         yield dict(url=self.urls[index])
 
 
-class ParserWds(Parser):
+class ReaderWds(Reader):
     def __init__(
             self,
             root,