Data improvements. Improve train support for in_chans != 3. Add wds dataset support from bits_and_tpu branch w/ fixes and tweaks. TFDS tweaks.
parent 87939e6fab · commit b8c8550841
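For orientation, a minimal usage sketch of the pieces this commit touches; the dataset name, root path, and split string below are hypothetical examples, not values taken from the commit:

    from timm.data import create_dataset

    # 'wds/...' selects the new webdataset parser; 'tfds/...' and 'hfds/...' follow the same pattern.
    dataset = create_dataset(
        'wds/my_imagenet',                                       # hypothetical dataset name
        root='/data/my_imagenet-wds',                            # directory holding .tar shards and info.json
        split='my_imagenet-train-{000000..001023}.tar|1281167',  # brace-expanded shards, optional '|<num_samples>'
        is_training=True,
        batch_size=256,
        seed=42,          # new arg, threaded through to the iterable dataset parsers
    )
    dataset.set_epoch(0)  # new hook, forwards the epoch count to the parser for deterministic shuffling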
@@ -89,6 +89,7 @@ class IterableImageDataset(data.IterableDataset):
             split='train',
             is_training=False,
             batch_size=None,
+            seed=42,
             repeats=0,
             download=False,
             transform=None,
@@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
             split=split,
             is_training=is_training,
             batch_size=batch_size,
+            seed=seed,
             repeats=repeats,
             download=download,
         )
@@ -125,6 +127,11 @@ class IterableImageDataset(data.IterableDataset):
         else:
             return 0
 
+    def set_epoch(self, count):
+        # TFDS and WDS need external epoch count for deterministic cross process shuffle
+        if hasattr(self.parser, 'set_epoch'):
+            self.parser.set_epoch(count)
+
     def filename(self, index, basename=False, absolute=False):
         assert False, 'Filename lookup by index not supported, use filenames().'
@@ -60,6 +60,7 @@ def create_dataset(
         is_training=False,
         download=False,
         batch_size=None,
+        seed=42,
         repeats=0,
         **kwargs
 ):
@@ -68,7 +69,9 @@ def create_dataset(
     In parenthesis after each arg are the type of dataset supported for each arg, one of:
       * folder - default, timm folder (or tar) based ImageDataset
       * torch - torchvision based datasets
+      * HFDS - Hugging Face Datasets
       * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
+      * WDS - Webdataset
       * all - any of the above
 
     Args:
@@ -79,11 +82,12 @@ def create_dataset(
             `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
         class_map: specify class -> index mapping via text file or dict (folder)
         load_bytes: load data, return images as undecoded bytes (folder)
-        download: download dataset if not present and supported (TFDS, torch)
+        download: download dataset if not present and supported (HFDS, TFDS, torch)
         is_training: create dataset in train mode, this is different from the split.
-            For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)
-        batch_size: batch size hint for (TFDS)
-        repeats: dataset repeats per iteration i.e. epoch (TFDS)
+            For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS)
+        batch_size: batch size hint for (TFDS, WDS)
+        seed: seed for iterable datasets (TFDS, WDS)
+        repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
         **kwargs: other args to pass to dataset
 
     Returns:
@@ -130,14 +134,33 @@ def create_dataset(
             ds = ImageFolder(root, **kwargs)
         else:
             assert False, f"Unknown torchvision dataset {name}"
-    elif name.startswith('tfds/'):
-        ds = IterableImageDataset(
-            root, parser=name, split=split, is_training=is_training,
-            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
         ds = ImageDataset(root, parser=name, split=split, **kwargs)
+    elif name.startswith('tfds/'):
+        ds = IterableImageDataset(
+            root,
+            parser=name,
+            split=split,
+            is_training=is_training,
+            download=download,
+            batch_size=batch_size,
+            repeats=repeats,
+            seed=seed,
+            **kwargs
+        )
+    elif name.startswith('wds/'):
+        ds = IterableImageDataset(
+            root,
+            parser=name,
+            split=split,
+            is_training=is_training,
+            batch_size=batch_size,
+            repeats=repeats,
+            seed=seed,
+            **kwargs
+        )
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
         if search_split and os.path.isdir(root):
@@ -5,6 +5,7 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d
 
 Hacked together by / Copyright 2019, Ross Wightman
 """
+import logging
 import random
 from contextlib import suppress
 from functools import partial
@@ -22,6 +23,9 @@ from .random_erasing import RandomErasing
 from .mixup import FastCollateMixup
 
 
+_logger = logging.getLogger(__name__)
+
+
 def fast_collate(batch):
     """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
     assert isinstance(batch[0], tuple)
@@ -57,11 +61,13 @@ def fast_collate(batch):
         assert False
 
 
-def expand_to_chs(x, n):
+def adapt_to_chs(x, n):
     if not isinstance(x, (tuple, list)):
         x = tuple(repeat(x, n))
-    elif len(x) == 1:
-        x = x * n
+    elif len(x) != n:
+        x_mean = np.mean(x).item()
+        x = (x_mean,) * n
+        _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
     else:
         assert len(x) == n, 'normalization stats must match image channels'
     return x
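To illustrate what the renamed helper does for models with in_chans != 3, here is a small standalone sketch (a simplified mirror of adapt_to_chs above, fed the usual ImageNet mean):

    import numpy as np
    from itertools import repeat

    def adapt_to_chs(x, n):
        # scalars are repeated n times; stats of mismatched length fall back to their average value
        if not isinstance(x, (tuple, list)):
            return tuple(repeat(x, n))
        if len(x) != n:
            return (np.mean(x).item(),) * n
        return x

    print(adapt_to_chs((0.485, 0.456, 0.406), 1))  # -> roughly (0.449,) for a 1-channel model
    print(adapt_to_chs(0.5, 4))                    # -> (0.5, 0.5, 0.5, 0.5)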
@@ -83,8 +89,8 @@ class PrefetchLoader:
             re_count=1,
             re_num_splits=0):
 
-        mean = expand_to_chs(mean, channels)
-        std = expand_to_chs(std, channels)
+        mean = adapt_to_chs(mean, channels)
+        std = adapt_to_chs(std, channels)
         normalization_shape = (1, channels, 1, 1)
 
         self.loader = loader
@@ -14,12 +14,16 @@ def create_parser(name, root, split='train', **kwargs):
 
     # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
     # explicitly select other options shortly
-    if prefix == 'tfds':
-        from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, **kwargs)
-    elif prefix == 'hfds':
+    if prefix == 'hfds':
         from .parser_hfds import ParserHfds  # defer tensorflow import
         parser = ParserHfds(root, name, split=split, **kwargs)
+    elif prefix == 'tfds':
+        from .parser_tfds import ParserTfds  # defer tensorflow import
+        parser = ParserTfds(root, name, split=split, **kwargs)
+    elif prefix == 'wds':
+        from .parser_wds import ParserWds
+        kwargs.pop('download', False)
+        parser = ParserWds(root, name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
@@ -7,6 +7,8 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
+import os
+
 import torch
 import torch.distributed as dist
 from PIL import Image
@@ -30,12 +32,14 @@ except ImportError as e:
     print(e)
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
     exit(1)
 
 from .parser import Parser
+from .shared_count import SharedCount
 
 
-MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
-SHUFFLE_SIZE = 8192  # examples to shuffle in DS queue
-PREFETCH_SIZE = 2048  # examples to prefetch
+MAX_TP_SIZE = int(os.environ.get('TFDS_TP_SIZE', 8))  # maximum TF threadpool size, for jpeg decodes and queuing activities
+SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192))  # examples to shuffle in DS queue
+PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048))  # examples to prefetch
 
 
 def even_split_indices(split, n, num_examples):
@@ -154,6 +158,14 @@ class ParserTfds(Parser):
         self.worker_seed = 0  # seed unique to each work instance
         self.subsplit = None  # set when data is distributed across workers using sub-splits
         self.ds = None  # initialized lazily on each dataloader worker process
+        self.init_count = 0  # number of ds TF data pipeline initializations
+        self.epoch_count = SharedCount()
+        # FIXME need to determine if reinit_each_iter is necessary. I don't completely trust behaviour
+        # of `shuffle_reshuffle_each_iteration` when there are multiple workers / nodes across epochs
+        self.reinit_each_iter = self.is_training
+
+    def set_epoch(self, count):
+        self.epoch_count.value = count
 
     def _lazy_init(self):
         """ Lazily initialize the dataset.
@@ -211,11 +223,15 @@ class ParserTfds(Parser):
                 num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
         read_config = tfds.ReadConfig(
-            shuffle_seed=self.common_seed,
+            shuffle_seed=self.common_seed + self.epoch_count.value,
             shuffle_reshuffle_each_iteration=True,
-            input_context=input_context)
+            input_context=input_context,
+        )
         ds = self.builder.as_dataset(
-            split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
+            split=self.subsplit or self.split,
+            shuffle_files=self.is_training,
+            read_config=read_config,
+        )
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
         options = tf.data.Options()
         thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
@@ -230,9 +246,10 @@ class ParserTfds(Parser):
             ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
         ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
         self.ds = tfds.as_numpy(ds)
+        self.init_count += 1
 
     def __iter__(self):
-        if self.ds is None:
+        if self.ds is None or self.reinit_each_iter:
             self._lazy_init()
 
         # Compute a rounded up sample count that is used to:
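The TFDS threadpool, shuffle, and prefetch sizes above become overridable via environment variables. A hedged sketch of how that might be used; the values are arbitrary, and they need to be in place before the first 'tfds/' dataset is created, since the constants are read when the parser module is imported:

    import os

    # set before the first create_dataset('tfds/...') call
    os.environ['TFDS_TP_SIZE'] = '4'           # smaller TF decode threadpool
    os.environ['TFDS_SHUFFLE_SIZE'] = '16384'  # larger shuffle buffer
    os.environ['TFDS_PREFETCH_SIZE'] = '4096'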
@@ -0,0 +1,448 @@
+""" Dataset parser interface for webdataset
+
+Hacked together by / Copyright 2022 Ross Wightman
+"""
+import io
+import json
+import logging
+import math
+import os
+import random
+import sys
+from dataclasses import dataclass
+from functools import partial
+from itertools import islice
+from typing import Dict, Tuple
+
+import torch
+import torch.distributed as dist
+import yaml
+from PIL import Image
+from torch.utils.data import Dataset, IterableDataset, get_worker_info
+
+try:
+    import webdataset as wds
+    from webdataset.filters import _shuffle
+    from webdataset.shardlists import expand_urls
+    from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
+except ImportError:
+    wds = None
+    expand_urls = None
+
+from .parser import Parser
+from .shared_count import SharedCount
+
+_logger = logging.getLogger(__name__)
+
+SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
+
+
+def _load_info(root, basename='info'):
+    info_json = os.path.join(root, basename + '.json')
+    info_yaml = os.path.join(root, basename + '.yaml')
+    err_str = ''
+    try:
+        with wds.gopen.gopen(info_json) as f:
+            info_dict = json.load(f)
+        return info_dict
+    except Exception as e:
+        err_str = str(e)
+    try:
+        with wds.gopen.gopen(info_yaml) as f:
+            info_dict = yaml.safe_load(f)
+        return info_dict
+    except Exception:
+        pass
+    _logger.warning(
+        f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
+        'Falling back to provided split and size arg.')
+    return {}
+
+
+@dataclass
+class SplitInfo:
+    num_samples: int
+    filenames: Tuple[str]
+    shard_lengths: Tuple[int] = ()
+    alt_label: str = ''
+    name: str = ''
+
+
+def _parse_split_info(split: str, info: Dict):
+    def _info_convert(dict_info):
+        return SplitInfo(
+            num_samples=dict_info['num_samples'],
+            filenames=tuple(dict_info['filenames']),
+            shard_lengths=tuple(dict_info['shard_lengths']),
+            alt_label=dict_info.get('alt_label', ''),
+            name=dict_info['name'],
+        )
+
+    if 'tar' in split or '..' in split:
+        # split in WDS string braceexpand format, sample count can be included with a | separator
+        # ex: `dataset-split-{0000..9999}.tar|100000` for 10,000 shards, covering 100,000 samples
+        split = split.split('|')
+        num_samples = 0
+        split_name = ''
+        if len(split) > 1:
+            num_samples = int(split[1])
+        split = split[0]
+        if '::' not in split:
+            split_parts = split.split('-', 3)
+            split_idx = len(split_parts) - 1
+            if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
+                split_name = split_parts[split_idx]
+
+        split_filenames = expand_urls(split)
+        if split_name:
+            split_info = info['splits'][split_name]
+            if not num_samples:
+                _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
+                num_samples = sum(_fc[f] for f in split_filenames)
+                split_info['filenames'] = tuple(_fc.keys())
+                split_info['shard_lengths'] = tuple(_fc.values())
+                split_info['num_samples'] = num_samples
+            split_info = _info_convert(split_info)
+        else:
+            split_info = SplitInfo(
+                name=split_name,
+                num_samples=num_samples,
+                filenames=split_filenames,
+            )
+    else:
+        if split not in info['splits']:
+            raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
+        split = split
+        split_info = info['splits'][split]
+        split_info = _info_convert(split_info)
+
+    return split_info
+
+
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+    _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    return True
+
+
+def _decode(
+        sample,
+        image_key='jpg',
+        image_format='RGB',
+        target_key='cls',
+        alt_label=''
+):
+    """ Custom sample decode
+    * decode and convert PIL Image
+    * cls byte string label to int
+    * pass through JSON byte string (if it exists) without parse
+    """
+    # decode class label, skip if alternate label not valid
+    if alt_label:
+        # alternative labels are encoded in json metadata
+        meta = json.loads(sample['json'])
+        class_label = int(meta[alt_label])
+        if class_label < 0:
+            # skipped labels currently encoded as -1, may change to a null/None value
+            return None
+    else:
+        class_label = int(sample[target_key])
+
+    # decode image
+    with io.BytesIO(sample[image_key]) as b:
+        img = Image.open(b)
+        img.load()
+    if image_format:
+        img = img.convert(image_format)
+
+    # json passed through in undecoded state
+    decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
+    return decoded
+
+
+def _decode_samples(
+        data,
+        image_key='jpg',
+        image_format='RGB',
+        target_key='cls',
+        alt_label='',
+        handler=log_and_continue):
+    """Decode samples with skip."""
+    for sample in data:
+        try:
+            result = _decode(
+                sample,
+                image_key=image_key,
+                image_format=image_format,
+                target_key=target_key,
+                alt_label=alt_label
+            )
+        except Exception as exn:
+            if handler(exn):
+                continue
+            else:
+                break
+
+        # null results are skipped
+        if result is not None:
+            if isinstance(sample, dict) and isinstance(result, dict):
+                result["__key__"] = sample.get("__key__")
+            yield result
+
+
+def pytorch_worker_seed():
+    """get dataloader worker seed from pytorch"""
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        # favour the seed already created for pytorch dataloader workers if it exists
+        return worker_info.seed
+    # fallback to wds rank based seed
+    return wds.utils.pytorch_worker_seed()
+
+
+if wds is not None:
+    # conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage)
+    class detshuffle2(wds.PipelineStage):
+        def __init__(
+                self,
+                bufsize=1000,
+                initial=100,
+                seed=0,
+                epoch=-1,
+        ):
+            self.bufsize = bufsize
+            self.initial = initial
+            self.seed = seed
+            self.epoch = epoch
+
+        def run(self, src):
+            if isinstance(self.epoch, SharedCount):
+                epoch = self.epoch.value
+            else:
+                # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
+                # situation as different workers may wrap at different times (or not at all).
+                self.epoch += 1
+                epoch = self.epoch
+
+            if self.seed < 0:
+                seed = pytorch_worker_seed() + epoch
+            else:
+                seed = self.seed + epoch
+            _logger.info(f'shuffle seed={self.seed} epoch={epoch} -> {seed}')  # FIXME temporary
+            rng = random.Random(seed)
+            return _shuffle(src, self.bufsize, self.initial, rng)
+
+else:
+    detshuffle2 = None
+
+
+class ResampledShards2(IterableDataset):
+    """An iterable dataset yielding a list of urls."""
+
+    def __init__(
+            self,
+            urls,
+            nshards=sys.maxsize,
+            worker_seed=None,
+            deterministic=True,
+            epoch=-1,
+    ):
+        """Sample shards from the shard list with replacement.
+
+        :param urls: a list of URLs as a Python list or brace notation string
+        """
+        super().__init__()
+        urls = wds.shardlists.expand_urls(urls)
+        self.urls = urls
+        assert isinstance(self.urls[0], str)
+        self.nshards = nshards
+        self.rng = random.Random()
+        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
+        self.deterministic = deterministic
+        self.epoch = epoch
+
+    def __iter__(self):
+        """Return an iterator over the shards."""
+        if isinstance(self.epoch, SharedCount):
+            epoch = self.epoch.value
+        else:
+            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
+            # situation as different workers may wrap at different times (or not at all).
+            self.epoch += 1
+            epoch = self.epoch
+
+        if self.deterministic:
+            # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
+            self.rng = random.Random(self.worker_seed() + epoch)
+
+        for _ in range(self.nshards):
+            index = self.rng.randint(0, len(self.urls) - 1)
+            yield dict(url=self.urls[index])
+
+
+class ParserWds(Parser):
+    def __init__(
+            self,
+            root,
+            name,
+            split,
+            is_training=False,
+            batch_size=None,
+            repeats=0,
+            seed=42,
+            input_name='jpg',
+            input_image='RGB',
+            target_name='cls',
+            target_image='',
+            prefetch_size=None,
+            shuffle_size=None,
+    ):
+        super().__init__()
+        if wds is None:
+            raise RuntimeError(
+                'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.')
+        self.root = root
+        self.is_training = is_training
+        self.batch_size = batch_size
+        self.repeats = repeats
+        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
+        self.shard_shuffle_size = 500
+        self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE
+
+        self.image_key = input_name
+        self.image_format = input_image
+        self.target_key = target_name
+        self.filename_key = 'filename'
+        self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)
+
+        self.info = _load_info(self.root)
+        self.split_info = _parse_split_info(split, self.info)
+        self.num_samples = self.split_info.num_samples
+        if not self.num_samples:
+            raise RuntimeError(f'Invalid split definition, no samples found.')
+
+        # Distributed world state
+        self.dist_rank = 0
+        self.dist_num_replicas = 1
+        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+            self.dist_rank = dist.get_rank()
+            self.dist_num_replicas = dist.get_world_size()
+
+        # Attributes that are updated in _lazy_init
+        self.worker_info = None
+        self.worker_id = 0
+        self.worker_seed = seed  # seed unique to each worker instance
+        self.num_workers = 1
+        self.global_worker_id = 0
+        self.global_num_workers = 1
+        self.init_count = 0
+        self.epoch_count = SharedCount()
+
+        # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
+        # is not handled in manner where it can be deterministic for each worker AND initialized up front
+        self.ds = None
+
+    def set_epoch(self, count):
+        self.epoch_count.value = count
+
+    def _lazy_init(self):
+        """ Lazily initialize worker (in worker processes)
+        """
+        if self.worker_info is None:
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                self.worker_info = worker_info
+                self.worker_id = worker_info.id
+                self.worker_seed = worker_info.seed
+                self.num_workers = worker_info.num_workers
+            self.global_num_workers = self.dist_num_replicas * self.num_workers
+            self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
+
+        # init data pipeline
+        abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
+        pipeline = [wds.SimpleShardList(abs_shard_filenames)]
+        # at this point we have an iterator over all the shards
+        if self.is_training:
+            pipeline.extend([
+                detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count),
+                self._split_by_node_and_worker,
+                # at this point, we have an iterator over the shards assigned to each worker
+                wds.tarfile_to_samples(handler=log_and_continue),
+                wds.shuffle(
+                    self.sample_shuffle_size,
+                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
+            ])
+        else:
+            pipeline.extend([
+                self._split_by_node_and_worker,
+                # at this point, we have an iterator over the shards assigned to each worker
+                wds.tarfile_to_samples(handler=log_and_continue),
+            ])
+        pipeline.extend([
+            partial(
+                _decode_samples,
+                image_key=self.image_key,
+                image_format=self.image_format,
+                alt_label=self.split_info.alt_label
+            )
+        ])
+        self.ds = wds.DataPipeline(*pipeline)
+
+    def _split_by_node_and_worker(self, src):
+        if self.global_num_workers > 1:
+            for s in islice(src, self.global_worker_id, None, self.global_num_workers):
+                yield s
+        else:
+            for s in src:
+                yield s
+
+    def __iter__(self):
+        if self.ds is None:
+            self._lazy_init()
+
+        if self.is_training:
+            num_worker_samples = math.floor(self.num_samples / self.global_num_workers)
+            if self.batch_size is not None:
+                num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size
+            ds = self.ds.with_epoch(num_worker_samples)
+        else:
+            if self.dist_num_replicas > 1:
+                # doing distributed validation w/ WDS is messy, hard to meet constraints that
+                # same # of batches needed across all replicas w/ seeing each sample once.
+                # with_epoch() is simple but could miss a shard's worth of samples in some workers,
+                # and duplicate in others. Best to keep num DL workers low and a divisor of #val shards.
+                num_worker_samples = math.ceil(self.num_samples / self.global_num_workers)
+                ds = self.ds.with_epoch(num_worker_samples)
+            else:
+                ds = self.ds
+
+        i = 0
+        _logger.info(f'start {i}, worker {self.worker_id}')  # FIXME temporary debug
+        for sample in ds:
+            yield sample[self.image_key], sample[self.target_key]
+            i += 1
+        _logger.info(f'end {i}, worker {self.worker_id}')  # FIXME temporary debug
+
+    def __len__(self):
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
+
+    def _filename(self, index, basename=False, absolute=False):
+        assert False, "Not supported"  # no random access to examples
+
+    def filenames(self, basename=False, absolute=False):
+        """ Return all filenames in dataset, overrides base"""
+        if self.ds is None:
+            self._lazy_init()
+
+        names = []
+        for sample in self.ds:
+            if self.filename_key in sample:
+                name = sample[self.filename_key]
+            elif '__key__' in sample:
+                name = sample['__key__'] + self.key_ext
+            else:
+                assert False, "No supported name field present"
+            names.append(name)
+            if len(names) >= self.num_samples:
+                break  # safety for ds.repeat() case
+        return names
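For reference, a sketch of the info.json layout that _load_info / _parse_split_info above expect; the field names come from the code, while the dataset name, filenames, and counts are made up:

    # hypothetical contents of <root>/info.json consumed by ParserWds
    info = {
        'name': 'my_imagenet',
        'splits': {
            'train': {
                'name': 'train',
                'num_samples': 2000,
                'filenames': ['my_imagenet-train-000000.tar', 'my_imagenet-train-000001.tar'],
                'shard_lengths': [1000, 1000],
                'alt_label': '',
            },
        },
    }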
@@ -0,0 +1,14 @@
+from multiprocessing import Value
+
+
+class SharedCount:
+    def __init__(self, epoch: int = 0):
+        self.shared_epoch = Value('i', epoch)
+
+    @property
+    def value(self):
+        return self.shared_epoch.value
+
+    @value.setter
+    def value(self, epoch):
+        self.shared_epoch.value = epoch
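SharedCount wraps a multiprocessing.Value so that an epoch value set in the parent process can be observed by dataloader worker processes that inherit it. A minimal sketch of the propagation path; the import path is inferred from the relative imports above:

    from timm.data.parsers.shared_count import SharedCount  # assumed module location

    epoch_count = SharedCount(0)
    epoch_count.value = 5     # main process: dataset.set_epoch(5) ultimately lands here
    print(epoch_count.value)  # dataloader workers read the same underlying shared integer -> 5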
train.py (26 changed lines)
@@ -111,7 +111,9 @@ group.add_argument('--num-classes', type=int, default=None, metavar='N',
 group.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
 group.add_argument('--img-size', type=int, default=None, metavar='N',
-                   help='Image patch size (default: None => model default)')
+                   help='Image size (default: None => model default)')
+group.add_argument('--in-chans', type=int, default=None, metavar='N',
+                   help='Image input channels (default: None => 3)')
 group.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 group.add_argument('--crop-pct', default=None, type=float,
@@ -394,9 +396,16 @@ def main():
     if args.fast_norm:
         set_fast_norm()
 
+    in_chans = 3
+    if args.in_chans is not None:
+        in_chans = args.in_chans
+    elif args.input_size is not None:
+        in_chans = args.input_size[0]
+
     model = create_model(
         args.model,
         pretrained=args.pretrained,
+        in_chans=in_chans,
         num_classes=args.num_classes,
         drop_rate=args.drop,
         drop_path_rate=args.drop_path,
@@ -537,7 +546,8 @@ def main():
         class_map=args.class_map,
         download=args.dataset_download,
         batch_size=args.batch_size,
-        repeats=args.epoch_repeats
+        seed=args.seed,
+        repeats=args.epoch_repeats,
     )
 
     dataset_eval = create_dataset(
@@ -547,7 +557,7 @@ def main():
         is_training=False,
         class_map=args.class_map,
         download=args.dataset_download,
-        batch_size=args.batch_size
+        batch_size=args.batch_size,
     )
 
     # setup mixup / cutmix
@@ -610,6 +620,10 @@ def main():
         worker_seeding=args.worker_seeding,
     )
 
+    eval_workers = args.workers
+    if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
+        # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
+        eval_workers = min(2, args.workers)
     loader_eval = create_loader(
         dataset_eval,
         input_size=data_config['input_size'],
@@ -619,7 +633,7 @@ def main():
         interpolation=data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],
-        num_workers=args.workers,
+        num_workers=eval_workers,
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
@@ -679,7 +693,9 @@ def main():
 
     try:
         for epoch in range(start_epoch, num_epochs):
-            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
+            if hasattr(dataset_train, 'set_epoch'):
+                dataset_train.set_epoch(epoch)
+            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                 loader_train.sampler.set_epoch(epoch)
 
             train_metrics = train_one_epoch(