More TFDS parser cleanup, support improved TFDS even_split impl (on tfds-nightly only currently).
This commit is contained in:
parent ba65dfe2c6
commit 9ec3210c2d
@@ -6,8 +6,6 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification

Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import io
import math
import torch
import torch.distributed as dist
@@ -17,6 +15,13 @@ try:
    import tensorflow as tf
    tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
    import tensorflow_datasets as tfds
    try:
        tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
        has_buggy_even_splits = False
    except TypeError:
        print("Warning: This version of tfds doesn't have the latest even_splits impl. "
              "Please update or use tfds-nightly for better fine-grained split behaviour.")
        has_buggy_even_splits = True
except ImportError as e:
    print(e)
    print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
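For context, a minimal sketch of what the probe above distinguishes, assuming a tfds release with the newer even_splits signature (tfds-nightly at the time of this commit); the dataset name and data dir are illustrative:

import tensorflow_datasets as tfds

# Newer tfds: split 'train' into 4 near-equal sub-splits without dropping remainder samples.
# Older releases lack the drop_remainder argument, which is what the TypeError probe detects.
subsplits = tfds.even_splits('train', 4, drop_remainder=False)
# Each element can be passed straight to builder.as_dataset(split=...)
builder = tfds.builder('imagenet2012', data_dir='/data/tfds')  # illustrative values
ds = builder.as_dataset(split=subsplits[0])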
@@ -25,7 +30,7 @@ from .parser import Parser


MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 20480  # samples to shuffle in DS queue
SHUFFLE_SIZE = 16384  # samples to shuffle in DS queue
PREFETCH_SIZE = 2048  # samples to prefetch
@@ -58,8 +63,34 @@ class ParserTfds(Parser):

    """
    def __init__(
            self, root, name, split='train', is_training=False, batch_size=None,
            download=False, repeats=0, seed=42):
            self,
            root,
            name,
            split='train',
            is_training=False,
            batch_size=None,
            download=False,
            repeats=0,
            seed=42,
            prefetch_size=None,
            shuffle_size=None,
            max_threadpool_size=None
    ):
        """ Tensorflow-datasets Wrapper

        Args:
            root: root data dir (ie your TFDS_DATA_DIR, not dataset specific sub-dir)
            name: tfds dataset name (eg `imagenet2012`)
            split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
            is_training: training mode, shuffle enabled, dataset len rounded by batch_size
            batch_size: batch_size to use to ensure total samples % batch_size == 0 in training across all dist nodes
            download: download and build TFDS dataset if set, otherwise must use tfds CLI
            repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
            seed: common seed for shard shuffle across all distributed/worker instances
            prefetch_size: override default tf.data prefetch buffer size
            shuffle_size: override default tf.data shuffle buffer size
            max_threadpool_size: override default threadpool size for tf.data
        """
        super().__init__()
        self.root = root
        self.split = split
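As an aside, a minimal sketch of how the expanded constructor might be called directly, assuming the module path at the time of this commit (`timm.data.parsers.parser_tfds`); all argument values below are illustrative, and in practice the parser is normally created via timm's dataset factory with a `tfds/` name prefix:

from timm.data.parsers.parser_tfds import ParserTfds

parser = ParserTfds(
    root='/data/tfds',          # your TFDS_DATA_DIR, not a dataset-specific sub-dir
    name='imagenet2012',
    split='validation',
    is_training=False,
    batch_size=256,
    prefetch_size=1024,         # overrides module-level PREFETCH_SIZE
    shuffle_size=8192,          # overrides module-level SHUFFLE_SIZE
    max_threadpool_size=4,      # overrides module-level MAX_TP_SIZE
)
for img, label in parser:       # yields (PIL.Image, label) pairs
    pass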
@@ -69,25 +100,33 @@ class ParserTfds(Parser):
                "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
        self.batch_size = batch_size
        self.repeats = repeats
        self.common_seed = seed  # seed across all worker / dist nodes
        self.worker_seed = 0  # seed specific to each worker instance
        self.subsplit = None
        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
        self.prefetch_size = prefetch_size or PREFETCH_SIZE
        self.shuffle_size = shuffle_size or SHUFFLE_SIZE
        self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE

        # TFDS builder and split information
        self.builder = tfds.builder(name, data_dir=root)
        # NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable download flag
        if download:
            self.builder.download_and_prepare()
        self.split_info = self.builder.info.splits[split]
        self.num_samples = self.split_info.num_examples
        self.ds = None  # initialized lazily on each dataloader worker process

        self.worker_info = None
        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            self.dist_rank = dist.get_rank()
            self.dist_num_replicas = dist.get_world_size()

        # Attributes that are updated in _lazy_init, including the tf.data pipeline itself
        self.global_num_workers = 1
        self.worker_info = None
        self.worker_seed = 0  # seed unique to each worker instance
        self.subsplit = None  # set when data is distributed across workers using sub-splits
        self.ds = None  # initialized lazily on each dataloader worker process

    def _lazy_init(self):
        """ Lazily initialize the dataset.

@@ -102,38 +141,44 @@ class ParserTfds(Parser):
        worker_info = torch.utils.data.get_worker_info()

        # setup input context to split dataset across distributed processes
        global_num_workers = num_workers = 1
        global_worker_id = 1
        num_workers = 1
        global_worker_id = 0
        if worker_info is not None:
            self.worker_info = worker_info
            self.worker_seed = worker_info.seed
            num_workers = worker_info.num_workers
            global_num_workers = self.dist_num_replicas * num_workers
            worker_id = worker_info.id
            global_worker_id = self.dist_rank * num_workers + worker_id
            self.global_num_workers = self.dist_num_replicas * num_workers
            global_worker_id = self.dist_rank * num_workers + worker_info.id

        # FIXME verify best sharding approach
        """ Data sharding
        InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
        My understanding is that using split, the underlying TFRecord files will shuffle (shuffle_files=True)
        between the splits each iteration, but that understanding could be wrong.
        Possible split options include:
        * InputContext for both distributed & worker processes (current)
        * InputContext for distributed and sub-splits for worker processes
        * sub-splits for both

        I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
        the data across workers. For training, InputContext is used to assign shards to nodes unless num_shards
        in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
        for validation where we can't drop samples and need to minimize uneven splits to avoid padding.
        """
        can_subsplit = '[' not in self.split  # can't subsplit a subsplit
        should_subsplit = global_num_workers > 1 and (
            self.split_info.num_shards < global_num_workers or not self.is_training)
        if can_subsplit and should_subsplit:
            # manually split the dataset w/o sharding for more even samples / worker
            self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id]
        should_subsplit = self.global_num_workers > 1 and (
            self.split_info.num_shards < self.global_num_workers or not self.is_training)
        if should_subsplit:
            # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
            # read patterns for distributed training (overlap across shards) so better to use InputContext there
            if has_buggy_even_splits:
                # my even_split workaround doesn't work on subsplits, upgrade tfds!
                if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
                    subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
                    self.subsplit = subsplits[global_worker_id]
            else:
                subsplits = tfds.even_splits(self.split, self.global_num_workers)
                self.subsplit = subsplits[global_worker_id]

        input_context = None
        if global_num_workers > 1 and self.subsplit is None:
        if self.global_num_workers > 1 and self.subsplit is None:
            # set input context to divide shards among distributed replicas
            input_context = tf.distribute.InputContext(
                num_input_pipelines=global_num_workers,
                num_input_pipelines=self.global_num_workers,
                input_pipeline_id=global_worker_id,
                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
            )
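A rough worked example of the sharding decision above; all numbers are assumptions for illustration:

# 2 distributed replicas x 4 DataLoader workers = 8 global input pipelines
dist_num_replicas, num_workers = 2, 4
global_num_workers = dist_num_replicas * num_workers

num_shards = 16        # TFRecord shards in the chosen split (illustrative)
is_training = False    # validation pass

# Mirrors the should_subsplit condition in _lazy_init: even with enough shards,
# validation prefers explicit sub-splits so per-pipeline sample counts stay as even as possible.
should_subsplit = global_num_workers > 1 and (num_shards < global_num_workers or not is_training)
print(should_subsplit)  # True -> each pipeline reads subsplits[global_worker_id] instead of using an InputContext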
@@ -143,10 +188,10 @@ class ParserTfds(Parser):
            input_context=input_context)
        ds = self.builder.as_dataset(
            split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
        # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
        # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
        options = tf.data.Options()
        thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
        getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
        getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
        getattr(options, thread_member).max_intra_op_parallelism = 1
        ds = ds.with_options(options)
        if self.is_training or self.repeats > 1:
@@ -154,22 +199,25 @@ class ParserTfds(Parser):
            # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
            ds = ds.repeat()  # allow wrap around and break iteration manually
        if self.is_training:
            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed)
        ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
            ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
        ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
        self.ds = tfds.as_numpy(ds)

    def __iter__(self):
        if self.ds is None:
            self._lazy_init()
        # compute a rounded up sample count that is used to:

        # Compute a rounded up sample count that is used to:
        # 1. make batches even across workers & replicas in distributed validation.
        #    This adds extra samples and will slightly alter validation results.
        # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
        #    batches are produced (underlying tfds iter wraps around)
        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers)
        if self.is_training:
            # round up to nearest batch_size per worker-replica
            target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size

        # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
        sample_count = 0
        for sample in self.ds:
            img = Image.fromarray(sample['image'], mode='RGB')
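To make the target-count rounding concrete, a small worked example (all numbers are illustrative assumptions):

import math

num_samples = 50000          # e.g. a 50k validation split
global_num_workers = 8       # 2 replicas x 4 dataloader workers
repeats, batch_size = 0, 256

# Per-pipeline target before batch rounding: ceil(50000 / 8) = 6250
target_sample_count = math.ceil(max(1, repeats) * num_samples / global_num_workers)

is_training = True
if is_training:
    # Round up to a whole number of batches per pipeline: ceil(6250 / 256) * 256 = 6400
    target_sample_count = math.ceil(target_sample_count / batch_size) * batch_size
print(target_sample_count)  # 6400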
@@ -180,21 +228,17 @@ class ParserTfds(Parser):
                # this results in extra samples per epoch but seems more desirable than dropping
                # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
                break
        if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:

        # Pad across distributed nodes (make counts equal by adding samples)
        if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
                0 < sample_count < target_sample_count:
            # Validation batch padding only done for distributed training where results are reduced across nodes.
            # For single process case, it won't matter if workers return different batch sizes.
            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
            # approach is not optimal
            yield img, sample['label']  # yield prev sample again
            sample_count += 1

    @property
    def _num_workers(self):
        return 1 if self.worker_info is None else self.worker_info.num_workers

    @property
    def _num_pipelines(self):
        return self._num_workers * self.dist_num_replicas
            # If using input_context or % based splits, sample count can vary significantly across workers and this
            # approach should not be used (hence disabled if self.subsplit isn't set).
            while sample_count < target_sample_count:
                yield img, sample['label']  # yield prev sample again
                sample_count += 1

    def __len__(self):
        # this is just an estimate and does not factor in extra samples added to pad batches based on
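As a concrete illustration of the padding behaviour above (numbers are assumptions, continuing the earlier example):

# Padding equalizes per-pipeline counts for distributed validation only.
target_sample_count = 6250   # per-pipeline target (see previous example)
sample_count = 6249          # what this worker's sub-split actually produced
pad = target_sample_count - sample_count if 0 < sample_count < target_sample_count else 0
print(pad)  # 1 -> the last (img, label) pair is yielded one extra time so all replicas report equal counts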
@@ -202,7 +246,7 @@ class ParserTfds(Parser):
        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"  # no random access to samples
        assert False, "Not supported"  # no random access to samples

    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base"""