Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Dataset work

* support some torchvision datasets
* improvements to TFDS wrapper for subsplit handling (fix #942), shuffle seed
* add class-map support to train (fix #957)

parent ddc29da974
commit ba65dfe2c6
@@ -23,15 +23,17 @@ class ImageDataset(data.Dataset):
             self,
             root,
             parser=None,
-            class_map='',
+            class_map=None,
             load_bytes=False,
             transform=None,
+            target_transform=None,
     ):
         if parser is None or isinstance(parser, str):
             parser = create_parser(parser or '', root=root, class_map=class_map)
         self.parser = parser
         self.load_bytes = load_bytes
         self.transform = transform
+        self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __getitem__(self, index):
@@ -49,7 +51,9 @@ class ImageDataset(data.Dataset):
         if self.transform is not None:
             img = self.transform(img)
         if target is None:
-            target = torch.tensor(-1, dtype=torch.long)
+            target = -1
+        elif self.target_transform is not None:
+            target = self.target_transform(target)
         return img, target
 
     def __len__(self):
@@ -71,26 +75,28 @@ class IterableImageDataset(data.IterableDataset):
             split='train',
             is_training=False,
             batch_size=None,
-            class_map='',
-            load_bytes=False,
             repeats=0,
+            download=False,
             transform=None,
+            target_transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
+                parser, root=root, split=split, is_training=is_training,
+                batch_size=batch_size, repeats=repeats, download=download)
         else:
             self.parser = parser
         self.transform = transform
+        self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __iter__(self):
         for img, target in self.parser:
             if self.transform is not None:
                 img = self.transform(img)
-            if target is None:
-                target = torch.tensor(-1, dtype=torch.long)
+            if self.target_transform is not None:
+                target = self.target_transform(target)
             yield img, target
 
     def __len__(self):
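For orientation, a minimal usage sketch (not part of the commit; the folder path and class names are invented) of the reworked ImageDataset, which now accepts a dict class_map and an optional target_transform:

from timm.data import ImageDataset

# hypothetical class -> index mapping; a text file path still works as before
class_to_idx = {'cat': 0, 'dog': 1}

dataset = ImageDataset(
    '/data/pets',                    # hypothetical image folder root
    class_map=class_to_idx,          # forwarded to the parser via create_parser()
    transform=None,                  # e.g. a timm/torchvision transform pipeline
    target_transform=lambda y: y,    # new hook, applied to the label in __getitem__
)
img, target = dataset[0]             # target is -1 when the parser yields no label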
@@ -1,7 +1,26 @@
 import os
 
+from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
+    Places365, ImageNet, ImageFolder
+try:
+    from torchvision.datasets import INaturalist
+    has_inaturalist = True
+except ImportError:
+    has_inaturalist = False
+
 from .dataset import IterableImageDataset, ImageDataset
+
+_TORCH_BASIC_DS = dict(
+    cifar10=CIFAR10,
+    cifar100=CIFAR100,
+    mnist=MNIST,
+    qmist=QMNIST,
+    kmnist=KMNIST,
+    fashion_mnist=FashionMNIST,
+)
+_TRAIN_SYNONYM = {'train', 'training'}
+_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
 
 
 def _search_split(root, split):
     # look for sub-folder with name of split in root and use that if it exists
@@ -9,22 +28,107 @@ def _search_split(root, split):
     try_root = os.path.join(root, split_name)
     if os.path.exists(try_root):
         return try_root
-    if split_name == 'validation':
-        try_root = os.path.join(root, 'val')
-        if os.path.exists(try_root):
-            return try_root
+
+    def _try(syn):
+        for s in syn:
+            try_root = os.path.join(root, s)
+            if os.path.exists(try_root):
+                return try_root
+        return root
+    if split_name in _TRAIN_SYNONYM:
+        root = _try(_TRAIN_SYNONYM)
+    elif split_name in _EVAL_SYNONYM:
+        root = _try(_EVAL_SYNONYM)
     return root
 
 
-def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
+def create_dataset(
+        name,
+        root,
+        split='validation',
+        search_split=True,
+        class_map=None,
+        load_bytes=False,
+        is_training=False,
+        download=False,
+        batch_size=None,
+        repeats=0,
+        **kwargs
+):
+    """ Dataset factory method
+
+    In parenthesis after each arg are the type of dataset supported for each arg, one of:
+      * folder - default, timm folder (or tar) based ImageDataset
+      * torch - torchvision based datasets
+      * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
+      * all - any of the above
+
+    Args:
+        name: dataset name, empty is okay for folder based datasets
+        root: root folder of dataset (all)
+        split: dataset split (all)
+        search_split: search for split specific child fold from root so one can specify
+            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
+        class_map: specify class -> index mapping via text file or dict (folder)
+        load_bytes: load data, return images as undecoded bytes (folder)
+        download: download dataset if not present and supported (TFDS, torch)
+        is_training: create dataset in train mode, this is different from the split.
+            For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)
+        batch_size: batch size hint for (TFDS)
+        repeats: dataset repeats per iteration i.e. epoch (TFDS)
+        **kwargs: other args to pass to dataset
+
+    Returns:
+        Dataset object
+    """
     name = name.lower()
-    if name.startswith('tfds'):
+    if name.startswith('torch/'):
+        name = name.split('/', 2)[-1]
+        torch_kwargs = dict(root=root, download=download, **kwargs)
+        if name in _TORCH_BASIC_DS:
+            ds_class = _TORCH_BASIC_DS[name]
+            use_train = split in _TRAIN_SYNONYM
+            ds = ds_class(train=use_train, **torch_kwargs)
+        elif name == 'inaturalist' or name == 'inat':
+            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
+            target_type = 'full'
+            split_split = split.split('/')
+            if len(split_split) > 1:
+                target_type = split_split[0].split('_')
+                if len(target_type) == 1:
+                    target_type = target_type[0]
+                split = split_split[-1]
+            if split in _TRAIN_SYNONYM:
+                split = '2021_train'
+            elif split in _EVAL_SYNONYM:
+                split = '2021_valid'
+            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
+        elif name == 'places365':
+            if split in _TRAIN_SYNONYM:
+                split = 'train-standard'
+            elif split in _EVAL_SYNONYM:
+                split = 'val'
+            ds = Places365(split=split, **torch_kwargs)
+        elif name == 'imagenet':
+            if split in _EVAL_SYNONYM:
+                split = 'val'
+            ds = ImageNet(split=split, **torch_kwargs)
+        elif name == 'image_folder' or name == 'folder':
+            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
+            if search_split and os.path.isdir(root):
+                # look for split specific sub-folder in root
+                root = _search_split(root, split)
+            ds = ImageFolder(root, **kwargs)
+        else:
+            assert False, f"Unknown torchvision dataset {name}"
+    elif name.startswith('tfds/'):
         ds = IterableImageDataset(
-            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
+            root, parser=name, split=split, is_training=is_training,
+            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
         kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
             # look for split specific sub-folder in root
             root = _search_split(root, split)
-        ds = ImageDataset(root, parser=name, **kwargs)
+        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
     return ds
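As a usage illustration (not from the commit; roots and dataset choices are placeholders), the factory can now be pointed at torchvision and TFDS datasets as well as plain folders:

from timm.data import create_dataset

# torchvision-backed dataset via the 'torch/' prefix, downloaded on first use
ds_cifar = create_dataset('torch/cifar10', root='/data/cifar10', split='train', download=True)

# INaturalist 2021 with a target type encoded in the split as '<target_type>/<split>'
ds_inat = create_dataset('torch/inaturalist', root='/data/inat', split='kingdom/train', download=True)

# TFDS wrapper still routes through IterableImageDataset; batch_size is a sharding/shuffle hint
# (assumes the TFDS data was already prepared, otherwise pass download=True where supported)
ds_tfds = create_dataset(
    'tfds/imagenet2012', root='/data/tfds', split='train',
    is_training=True, batch_size=256, repeats=1)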
@@ -1,16 +1,19 @@
 import os
 
 
-def load_class_map(filename, root=''):
-    class_map_path = filename
+def load_class_map(map_or_filename, root=''):
+    if isinstance(map_or_filename, dict):
+        assert dict, 'class_map dict must be non-empty'
+        return map_or_filename
+    class_map_path = map_or_filename
     if not os.path.exists(class_map_path):
-        class_map_path = os.path.join(root, filename)
-        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
-    class_map_ext = os.path.splitext(filename)[-1].lower()
+        class_map_path = os.path.join(root, class_map_path)
+        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
+    class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
     if class_map_ext == '.txt':
         with open(class_map_path) as f:
             class_to_idx = {v.strip(): k for k, v in enumerate(f)}
     else:
-        assert False, 'Unsupported class map extension'
+        assert False, f'Unsupported class map file extension ({class_map_ext}).'
     return class_to_idx
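To make the new behaviour concrete (values are illustrative, and the import path assumes the helper lives in timm.data.parsers.class_map in this tree), load_class_map now accepts either a dict or a text file:

from timm.data.parsers.class_map import load_class_map

# a dict passes straight through
class_to_idx = load_class_map({'cat': 0, 'dog': 1})

# a .txt file is still read as one class name per line, indexed by line number,
# e.g. a file containing "cat\ndog\n" yields {'cat': 0, 'dog': 1}
class_to_idx = load_class_map('class_map.txt', root='/data/pets')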
@@ -17,7 +17,7 @@ def create_parser(name, root, split='train', **kwargs):
     # explicitly select other options shortly
     if prefix == 'tfds':
         from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
+        parser = ParserTfds(root, name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
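For reference, a hedged sketch of constructing a TFDS parser through the factory after this change (dataset name, root, and the exact import path are assumptions); shuffling is now keyed off is_training rather than a separate shuffle kwarg:

from timm.data.parsers.parser_factory import create_parser

parser = create_parser(
    'tfds/imagenet2012', root='/data/tfds', split='train',
    is_training=True, batch_size=256)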
@@ -57,23 +57,28 @@ class ParserTfds(Parser):
     components.
 
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
+    def __init__(
+            self, root, name, split='train', is_training=False, batch_size=None,
+            download=False, repeats=0, seed=42):
         super().__init__()
         self.root = root
         self.split = split
-        self.shuffle = shuffle
         self.is_training = is_training
         if self.is_training:
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
+        self.common_seed = seed  # seed across all worker / dist nodes
+        self.worker_seed = 0  # seed specific to each work instance
         self.subsplit = None
 
         self.builder = tfds.builder(name, data_dir=root)
-        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
-        # download_and_prepare() by default here as it's caused issues generating unwanted paths.
-        self.num_samples = self.builder.info.splits[split].num_examples
+        # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
+        if download:
+            self.builder.download_and_prepare()
+        self.split_info = self.builder.info.splits[split]
+        self.num_samples = self.split_info.num_examples
         self.ds = None  # initialized lazily on each dataloader worker process
 
         self.worker_info = None
@@ -97,17 +102,18 @@ class ParserTfds(Parser):
         worker_info = torch.utils.data.get_worker_info()
 
         # setup input context to split dataset across distributed processes
         split = self.split
-        num_workers = 1
+        global_num_workers = num_workers = 1
+        global_worker_id = 1
         if worker_info is not None:
             self.worker_info = worker_info
+            self.worker_seed = worker_info.seed
             num_workers = worker_info.num_workers
             global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id
+            global_worker_id = self.dist_rank * num_workers + worker_id
 
-        # FIXME I need to spend more time figuring out the best way to distribute/split data across
-        # combo of distributed replicas + dataloader worker processes
-        """
+        # FIXME verify best sharding approach
+        """ Data sharding
         InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
         My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
         between the splits each iteration, but that understanding could be wrong.
@@ -116,44 +122,39 @@ class ParserTfds(Parser):
         * InputContext for distributed and sub-splits for worker processes
         * sub-splits for both
         """
-        # split_size = self.num_samples // num_workers
-        # start = worker_id * split_size
-        # if worker_id == num_workers - 1:
-        #     split = split + '[{}:]'.format(start)
-        # else:
-        #     split = split + '[{}:{}]'.format(start, start + split_size)
-        if not self.is_training and '[' not in self.split:
-            # If not training, and split doesn't define a subsplit, manually split the dataset
-            # for more even samples / worker
-            self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
-                self.dist_rank * num_workers + worker_id]
+        can_subsplit = '[' not in self.split  # can't subsplit a subsplit
+        should_subsplit = global_num_workers > 1 and (
+            self.split_info.num_shards < global_num_workers or not self.is_training)
+        if can_subsplit and should_subsplit:
+            # manually split the dataset w/o sharding for more even samples / worker
+            self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id]
 
-        if self.subsplit is None:
+        input_context = None
+        if global_num_workers > 1 and self.subsplit is None:
+            # set input context to divide shards among distributed replicas
             input_context = tf.distribute.InputContext(
-                num_input_pipelines=self.dist_num_replicas * num_workers,
-                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_input_pipelines=global_num_workers,
+                input_pipeline_id=global_worker_id,
                 num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
-        else:
-            input_context = None
 
         read_config = tfds.ReadConfig(
-            shuffle_seed=42,
+            shuffle_seed=self.common_seed,
             shuffle_reshuffle_each_iteration=True,
             input_context=input_context)
         ds = self.builder.as_dataset(
-            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
+            split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
         # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
         options = tf.data.Options()
-        options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
-        options.experimental_threading.max_intra_op_parallelism = 1
+        thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
+        getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
+        getattr(options, thread_member).max_intra_op_parallelism = 1
         ds = ds.with_options(options)
         if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
-        if self.shuffle:
-            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
+        if self.is_training:
+            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed)
         ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
         self.ds = tfds.as_numpy(ds)
 
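For intuition only (this is not the library's implementation), the subsplit path above amounts to carving the TFDS split string into one roughly even index range per global worker; a simplified stand-in for even_split_indices might look like:

def even_split_indices(split, n, num_samples):
    # simplified stand-in: one 'split[start:end]' spec per worker, using TFDS slicing syntax
    bounds = [round(i * num_samples / n) for i in range(n + 1)]
    return [f'{split}[{bounds[i]}:{bounds[i + 1]}]' for i in range(n)]

# e.g. 2 distributed replicas x 2 dataloader workers over a 1000 sample validation split
subsplits = even_split_indices('validation', 4, 1000)
# -> ['validation[0:250]', 'validation[250:500]', 'validation[500:750]', 'validation[750:1000]']
# each global worker then reads only subsplits[global_worker_id]

This is what keeps per-worker sample counts nearly equal when the underlying shard count is too small to divide evenly across replicas and dataloader workers.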
train.py
@@ -70,7 +70,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
-# Dataset / Model parameters
+# Dataset parameters
 parser.add_argument('data_dir', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--dataset', '-d', metavar='NAME', default='',
@@ -79,6 +79,12 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
                     help='dataset train split (default: train)')
 parser.add_argument('--val-split', metavar='NAME', default='validation',
                     help='dataset validation split (default: validation)')
+parser.add_argument('--dataset-download', action='store_true', default=False,
+                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
+parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
+                    help='path to class to idx mapping file (default: "")')
+
+# Model parameters
 parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                     help='Name of model to train (default: "resnet50"')
 parser.add_argument('--pretrained', action='store_true', default=False,
@@ -484,11 +490,16 @@ def main():
 
     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset,
-        root=args.data_dir, split=args.train_split, is_training=True,
-        batch_size=args.batch_size, repeats=args.epoch_repeats)
+        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
+        class_map=args.class_map,
+        download=args.dataset_download,
+        batch_size=args.batch_size,
+        repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
+        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
+        class_map=args.class_map,
+        download=args.dataset_download,
+        batch_size=args.batch_size)
 
     # setup mixup / cutmix
     collate_fn = None
@@ -48,6 +48,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                     help='dataset type (default: ImageFolder/ImageTar if empty)')
 parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
+parser.add_argument('--dataset-download', action='store_true', default=False,
+                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
@@ -175,7 +177,7 @@ def validate(args):
 
     dataset = create_dataset(
         root=args.data, name=args.dataset, split=args.split,
-        load_bytes=args.tf_preprocessing, class_map=args.class_map)
+        download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
 
     if args.valid_labels:
         with open(args.valid_labels, 'r') as f: