Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Fix tf.data options setting for newer TF versions
parent 94d4b53352
commit d53e91218e
@@ -25,8 +25,8 @@ from .parser import Parser
 MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
-SHUFFLE_SIZE = 16834  # samples to shuffle in DS queue
-PREFETCH_SIZE = 4096  # samples to prefetch
+SHUFFLE_SIZE = 20480  # samples to shuffle in DS queue
+PREFETCH_SIZE = 2048  # samples to prefetch


 def even_split_indices(split, n, num_samples):
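For context, these constants bound per-process buffer sizes once they are divided across PyTorch workers and TFDS pipelines. A minimal sketch of the arithmetic as it stands after this commit; the helper name per_pipeline_sizes is illustrative, while the expressions and the names num_samples, num_pipelines, and num_workers mirror the parser code in the second hunk below:

    MAX_TP_SIZE = 8       # maximum TF threadpool size
    SHUFFLE_SIZE = 20480  # samples to shuffle in DS queue
    PREFETCH_SIZE = 2048  # samples to prefetch

    def per_pipeline_sizes(num_samples, num_pipelines, num_workers):
        # Hypothetical helper mirroring the expressions used in the parser below.
        tp_size = max(1, MAX_TP_SIZE // num_workers)                      # threads per worker, never below 1
        shuffle_size = min(num_samples, SHUFFLE_SIZE) // num_pipelines    # global cap, then split per pipeline
        prefetch_size = min(num_samples // num_pipelines, PREFETCH_SIZE)  # per-pipeline cap
        return tp_size, shuffle_size, prefetch_size

Note the shuffle expression also changes in this commit: the cap is applied to the total sample count before dividing by the pipeline count, rather than after, so the combined shuffle buffer across all pipelines stays bounded by SHUFFLE_SIZE rather than SHUFFLE_SIZE per pipeline.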
@@ -144,14 +144,16 @@ class ParserTfds(Parser):
         ds = self.builder.as_dataset(
             split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
-        ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
-        ds.options().experimental_threading.max_intra_op_parallelism = 1
+        options = tf.data.Options()
+        options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
+        options.experimental_threading.max_intra_op_parallelism = 1
+        ds = ds.with_options(options)
         if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
         if self.shuffle:
-            ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
+            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
         ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
         self.ds = tfds.as_numpy(ds)
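The fix itself: per the commit title, the object returned by ds.options() is effectively read-only in newer TF versions, so the parser now builds an explicit tf.data.Options() object and attaches it with with_options(). A standalone sketch of the new pattern, assuming TF 2.x; the range dataset and the num_workers value are stand-ins for builder.as_dataset(...) and the PyTorch DataLoader worker count:

    import tensorflow as tf

    MAX_TP_SIZE = 8
    num_workers = 4  # hypothetical PyTorch DataLoader worker count

    # Stand-in for self.builder.as_dataset(...) in the parser.
    ds = tf.data.Dataset.range(100)

    # Build an Options object and attach it with with_options(), rather than
    # mutating ds.options() in place as the removed lines did.
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
    options.experimental_threading.max_intra_op_parallelism = 1
    ds = ds.with_options(options)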