Significant transforms, dataset, and dataloading enhancements.

Ross Wightman 2024-01-04 17:06:07 -08:00 committed by Ross Wightman
parent b5a4fa9c3b
commit be0944edae
13 changed files with 1058 additions and 315 deletions

View File: timm/data/dataset.py

@ -27,7 +27,7 @@ class ImageDataset(data.Dataset):
split='train',
class_map=None,
load_bytes=False,
img_mode='RGB',
input_img_mode='RGB',
transform=None,
target_transform=None,
):
@ -40,7 +40,7 @@ class ImageDataset(data.Dataset):
)
self.reader = reader
self.load_bytes = load_bytes
self.img_mode = img_mode
self.input_img_mode = input_img_mode
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
@ -59,8 +59,8 @@ class ImageDataset(data.Dataset):
raise e
self._consecutive_errors = 0
if self.img_mode and not self.load_bytes:
img = img.convert(self.img_mode)
if self.input_img_mode and not self.load_bytes:
img = img.convert(self.input_img_mode)
if self.transform is not None:
img = self.transform(img)
@ -90,12 +90,17 @@ class IterableImageDataset(data.IterableDataset):
split='train',
class_map=None,
is_training=False,
batch_size=None,
batch_size=1,
num_samples=None,
seed=42,
repeats=0,
download=False,
input_img_mode='RGB',
input_key=None,
target_key=None,
transform=None,
target_transform=None,
max_steps=None,
):
assert reader is not None
if isinstance(reader, str):
@ -106,9 +111,14 @@ class IterableImageDataset(data.IterableDataset):
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
num_samples=num_samples,
seed=seed,
repeats=repeats,
download=download,
input_img_mode=input_img_mode,
input_key=input_key,
target_key=target_key,
max_steps=max_steps,
)
else:
self.reader = reader
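A minimal usage sketch of the renamed argument (path hypothetical): input_img_mode replaces img_mode and sets the PIL conversion applied before any transform.

    from timm.data import ImageDataset

    ds = ImageDataset('/data/imagenet/val', input_img_mode='L')  # hypothetical path, decode as grayscale
    img, target = ds[0]  # img is a PIL Image converted to mode 'L'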

View File: timm/data/dataset_factory.py

@ -3,6 +3,7 @@
Hacked together by / Copyright 2021, Ross Wightman
"""
import os
from typing import Optional
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
try:
@ -60,22 +61,24 @@ def _search_split(root, split):
def create_dataset(
name,
root,
split='validation',
search_split=True,
class_map=None,
load_bytes=False,
is_training=False,
download=False,
batch_size=None,
seed=42,
repeats=0,
**kwargs
name: str,
root: Optional[str] = None,
split: str = 'validation',
search_split: bool = True,
class_map: Optional[dict] = None,
load_bytes: bool = False,
is_training: bool = False,
download: bool = False,
batch_size: int = 1,
num_samples: Optional[int] = None,
seed: int = 42,
repeats: int = 0,
input_img_mode: str = 'RGB',
**kwargs,
):
""" Dataset factory method
In parenthesis after each arg are the type of dataset supported for each arg, one of:
In parentheses after each arg are the type of dataset supported for each arg, one of:
* folder - default, timm folder (or tar) based ImageDataset
* torch - torchvision based datasets
* HFDS - Hugging Face Datasets
@ -97,11 +100,13 @@ def create_dataset(
batch_size: batch size hint for (TFDS, WDS)
seed: seed for iterable datasets (TFDS, WDS)
repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
**kwargs: other args to pass to dataset
Returns:
Dataset object
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
name = name.lower()
if name.startswith('torch/'):
name = name.split('/', 2)[-1]
@ -151,7 +156,29 @@ def create_dataset(
elif name.startswith('hfds/'):
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
# There will be a IterableDataset variant too, TBD
ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
ds = ImageDataset(
root,
reader=name,
split=split,
class_map=class_map,
input_img_mode=input_img_mode,
**kwargs,
)
elif name.startswith('hfids/'):
ds = IterableImageDataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
download=download,
batch_size=batch_size,
num_samples=num_samples,
repeats=repeats,
seed=seed,
input_img_mode=input_img_mode,
**kwargs
)
elif name.startswith('tfds/'):
ds = IterableImageDataset(
root,
@ -161,8 +188,10 @@ def create_dataset(
is_training=is_training,
download=download,
batch_size=batch_size,
num_samples=num_samples,
repeats=repeats,
seed=seed,
input_img_mode=input_img_mode,
**kwargs
)
elif name.startswith('wds/'):
@ -173,8 +202,10 @@ def create_dataset(
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
num_samples=num_samples,
repeats=repeats,
seed=seed,
input_img_mode=input_img_mode,
**kwargs
)
else:
@ -182,5 +213,12 @@ def create_dataset(
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
ds = ImageDataset(
root,
reader=name,
class_map=class_map,
load_bytes=load_bytes,
input_img_mode=input_img_mode,
**kwargs,
)
return ds
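A hedged usage sketch of the new hfids path (dataset name and count illustrative): the hfids/ prefix streams a Hugging Face dataset, while num_samples keeps the step/scheduler math working when the split length is unknown.

    from timm.data import create_dataset

    ds = create_dataset(
        'hfids/imagenet-1k',      # hypothetical HF hub dataset name
        root=None,
        split='train',
        is_training=True,
        batch_size=256,           # hint used to round per-worker sample counts
        num_samples=1_281_167,    # required if split metadata lacks a length
        input_img_mode='RGB',
    )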

View File: timm/data/loader.py

@ -10,14 +10,14 @@ import random
from contextlib import suppress
from functools import partial
from itertools import repeat
from typing import Callable
from typing import Callable, Optional, Tuple, Union
import torch
import torch.utils.data
import numpy as np
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .dataset import IterableImageDataset
from .dataset import IterableImageDataset, ImageDataset
from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
@ -187,41 +187,91 @@ def _worker_init(worker_id, worker_seeding='all'):
def create_loader(
dataset,
input_size,
batch_size,
is_training=False,
use_prefetcher=True,
no_aug=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_split=False,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
num_aug_repeats=0,
num_aug_splits=0,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
distributed=False,
crop_pct=None,
crop_mode=None,
collate_fn=None,
pin_memory=False,
fp16=False, # deprecated, use img_dtype
img_dtype=torch.float32,
device=torch.device('cuda'),
tf_preprocessing=False,
use_multi_epochs_loader=False,
persistent_workers=True,
worker_seeding='all',
dataset: Union[ImageDataset, IterableImageDataset],
input_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
batch_size: int,
is_training: bool = False,
no_aug: bool = False,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_split: bool = False,
scale: Optional[Tuple[float, float]] = None,
ratio: Optional[Tuple[float, float]] = None,
hflip: float = 0.5,
vflip: float = 0.,
color_jitter: float = 0.4,
color_jitter_prob: Optional[float] = None,
grayscale_prob: float = 0.,
gaussian_blur_prob: float = 0.,
auto_augment: Optional[str] = None,
num_aug_repeats: int = 0,
num_aug_splits: int = 0,
interpolation: str = 'bilinear',
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
num_workers: int = 1,
distributed: bool = False,
crop_pct: Optional[float] = None,
crop_mode: Optional[str] = None,
crop_border_pixels: Optional[int] = None,
collate_fn: Optional[Callable] = None,
pin_memory: bool = False,
fp16: bool = False, # deprecated, use img_dtype
img_dtype: torch.dtype = torch.float32,
device: torch.device = torch.device('cuda'),
use_prefetcher: bool = True,
use_multi_epochs_loader: bool = False,
persistent_workers: bool = True,
worker_seeding: str = 'all',
tf_preprocessing: bool = False,
):
"""
Args:
dataset: The image dataset to load.
input_size: Target input size (channels, height, width) tuple or size scalar.
batch_size: Number of samples in a batch.
is_training: Return training (random) transforms.
no_aug: Disable augmentation for training (useful for debug).
re_prob: Random erasing probability.
re_mode: Random erasing fill mode.
re_count: Number of random erasing regions.
re_split: Control split of random erasing across batch size.
scale: Random resize scale range (crop area, < 1.0 => zoom in).
ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
hflip: Horizontal flip probability.
vflip: Vertical flip probability.
color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
Scalar is applied as (scalar,) * 3 (no hue).
color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
auto_augment: Auto augment configuration string (see auto_augment.py).
num_aug_repeats: Enable special sampler to repeat same augmentation across distributed GPUs.
num_aug_splits: Enable mode where augmentations can be split across the batch.
interpolation: Image interpolation mode.
mean: Image normalization mean.
std: Image normalization standard deviation.
num_workers: Num worker processes per DataLoader.
distributed: Enable dataloading for distributed training.
crop_pct: Inference crop percentage (output size / resize size).
crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
collate_fn: Override default collate_fn.
pin_memory: Pin memory for device transfer.
fp16: Deprecated argument for half-precision input dtype. Use img_dtype.
img_dtype: Data type for input image.
device: Device to transfer inputs and targets to.
use_prefetcher: Use efficient pre-fetcher to load samples onto device.
use_multi_epochs_loader: Use DataLoader wrapper that keeps worker processes alive across epochs.
persistent_workers: Enable persistent worker processes.
worker_seeding: Control worker random seeding at init.
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
Returns:
DataLoader
"""
re_num_splits = 0
if re_split:
# apply RE to second half of batch if no aug split otherwise line up with aug split
@ -229,24 +279,28 @@ def create_loader(
dataset.transform = create_transform(
input_size,
is_training=is_training,
use_prefetcher=use_prefetcher,
no_aug=no_aug,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter,
color_jitter_prob=color_jitter_prob,
grayscale_prob=grayscale_prob,
gaussian_blur_prob=gaussian_blur_prob,
auto_augment=auto_augment,
interpolation=interpolation,
mean=mean,
std=std,
crop_pct=crop_pct,
crop_mode=crop_mode,
tf_preprocessing=tf_preprocessing,
crop_border_pixels=crop_border_pixels,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
tf_preprocessing=tf_preprocessing,
use_prefetcher=use_prefetcher,
separate=num_aug_splits > 0,
)
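A usage sketch of the new SimCLR-style options (probabilities illustrative, not defaults):

    from timm.data import create_loader

    loader = create_loader(
        ds,                       # e.g. a dataset from create_dataset above
        input_size=(3, 224, 224),
        batch_size=256,
        is_training=True,
        color_jitter_prob=0.8,    # wraps ColorJitter in RandomApply
        grayscale_prob=0.2,
        gaussian_blur_prob=0.5,
        num_workers=8,
    )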

View File: timm/data/readers/reader_factory.py

@ -1,10 +1,17 @@
import os
from typing import Optional
from .reader_image_folder import ReaderImageFolder
from .reader_image_in_tar import ReaderImageInTar
def create_reader(name, root, split='train', **kwargs):
def create_reader(
name: str,
root: Optional[str] = None,
split: str = 'train',
**kwargs,
):
kwargs = {k: v for k, v in kwargs.items() if v is not None}
name = name.lower()
name = name.split('/', 1)
prefix = ''
@ -15,15 +22,18 @@ def create_reader(name, root, split='train', **kwargs):
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
# explicitly select other options shortly
if prefix == 'hfds':
from .reader_hfds import ReaderHfds # defer tensorflow import
reader = ReaderHfds(root, name, split=split, **kwargs)
from .reader_hfds import ReaderHfds # defer Hf datasets import
reader = ReaderHfds(name=name, root=root, split=split, **kwargs)
elif prefix == 'hfids':
from .reader_hfids import ReaderHfids # defer HF datasets import
reader = ReaderHfids(name=name, root=root, split=split, **kwargs)
elif prefix == 'tfds':
from .reader_tfds import ReaderTfds # defer tensorflow import
reader = ReaderTfds(root, name, split=split, **kwargs)
reader = ReaderTfds(name=name, root=root, split=split, **kwargs)
elif prefix == 'wds':
from .reader_wds import ReaderWds
kwargs.pop('download', False)
reader = ReaderWds(root, name, split=split, **kwargs)
reader = ReaderWds(root=root, name=name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
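For illustration (dataset names assumed, not verified): the prefix selects the reader implementation, and None-valued kwargs are now filtered out before dispatch.

    from timm.data.readers import create_reader

    reader = create_reader('hfds/food101', root='/cache/hf', split='train')  # -> ReaderHfds
    reader = create_reader('tfds/mnist', root='/data/tfds', split='train')   # -> ReaderTfds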

View File: timm/data/readers/reader_hfds.py

@ -4,6 +4,8 @@ Hacked together by / Copyright 2022 Ross Wightman
"""
import io
import math
from typing import Optional
import torch
import torch.distributed as dist
from PIL import Image
@ -12,7 +14,7 @@ try:
import datasets
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
exit(1)
raise e
from .class_map import load_class_map
from .reader import Reader
@ -29,12 +31,13 @@ class ReaderHfds(Reader):
def __init__(
self,
root,
name,
split='train',
class_map=None,
label_key='label',
download=False,
name: str,
root: Optional[str] = None,
split: str = 'train',
class_map: Optional[dict] = None,
image_key: str = 'image',
target_key: str = 'label',
download: bool = False,
):
"""
"""
@ -47,9 +50,10 @@ class ReaderHfds(Reader):
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
)
# leave decode for caller, plus we want easy access to original path names...
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False))
self.label_key = label_key
self.image_key = image_key
self.label_key = target_key
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
@ -61,7 +65,7 @@ class ReaderHfds(Reader):
def __getitem__(self, index):
item = self.dataset[index]
image = item['image']
image = item[self.image_key]
if 'bytes' in image and image['bytes']:
image = io.BytesIO(image['bytes'])
else:
@ -77,4 +81,4 @@ class ReaderHfds(Reader):
def _filename(self, index, basename=False, absolute=False):
item = self.dataset[index]
return item['image']['path']
return item[self.image_key]['path']
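A sketch of the new key overrides (the 'beans' column names are assumed, not verified): datasets whose columns differ from the 'image'/'label' defaults can now be remapped.

    from timm.data.readers.reader_hfds import ReaderHfds

    reader = ReaderHfds(name='beans', split='train', image_key='image', target_key='labels')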

View File: timm/data/readers/reader_hfids.py

@ -0,0 +1,213 @@
""" Dataset reader for HF IterableDataset
"""
import math
import os
from itertools import repeat, chain
from typing import Optional
import torch
import torch.distributed as dist
from PIL import Image
try:
import datasets
from datasets.distributed import split_dataset_by_node
from datasets.splits import SplitInfo
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
raise e
from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
SHUFFLE_SIZE = int(os.environ.get('HFIDS_SHUFFLE_SIZE', 4096))
class ReaderHfids(Reader):
def __init__(
self,
name: str,
root: Optional[str] = None,
split: str = 'train',
is_training: bool = False,
batch_size: int = 1,
download: bool = False,
repeats: int = 0,
seed: int = 42,
class_map: Optional[dict] = None,
input_key: str = 'image',
input_img_mode: str = 'RGB',
target_key: str = 'label',
target_img_mode: str = '',
shuffle_size: Optional[int] = None,
num_samples: Optional[int] = None,
):
super().__init__()
self.root = root
self.split = split
self.is_training = is_training
self.batch_size = batch_size
self.download = download
self.repeats = repeats
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
self.shuffle_size = shuffle_size or SHUFFLE_SIZE
self.input_key = input_key
self.input_img_mode = input_img_mode
self.target_key = target_key
self.target_img_mode = target_img_mode
self.builder = datasets.load_dataset_builder(name, cache_dir=root)
if download:
self.builder.download_and_prepare()
split_info: Optional[SplitInfo] = None
if self.builder.info.splits and split in self.builder.info.splits:
if isinstance(self.builder.info.splits[split], SplitInfo):
split_info: Optional[SplitInfo] = self.builder.info.splits[split]
if num_samples:
self.num_samples = num_samples
elif split_info and split_info.num_examples:
self.num_samples = split_info.num_examples
else:
raise ValueError(
"Dataset length is unknown, please pass `num_samples` explicitely. "
"The number of steps needs to be known in advance for the learning rate scheduler."
)
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = {}
# Distributed world state
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
# Attributes that are updated in _lazy_init
self.worker_info = None
self.worker_id = 0
self.num_workers = 1
self.global_worker_id = 0
self.global_num_workers = 1
# Initialized lazily on each dataloader worker process
self.ds: Optional[datasets.IterableDataset] = None
self.epoch = SharedCount()
def set_epoch(self, count):
# update the epoch so shuffling uses effective_seed = seed + epoch
self.epoch.value = count
def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
if self.ds is not None:
return
if num_workers is not None:
self.num_workers = num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
def _lazy_init(self):
""" Lazily initialize worker (in worker processes)
"""
if self.worker_info is None:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self.worker_info = worker_info
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
if self.download:
dataset = self.builder.as_dataset(split=self.split)
# to distribute evenly to workers
ds = dataset.to_iterable_dataset(num_shards=self.global_num_workers)
else:
# in this case the number of shards is determined by the number of remote files
ds = self.builder.as_streaming_dataset(split=self.split)
if self.is_training:
# will shuffle the list of shards and use a shuffle buffer
ds = ds.shuffle(seed=self.common_seed, buffer_size=self.shuffle_size)
# Distributed:
# The dataset has a number of shards that is a factor of `dist_num_replicas` (i.e. if `ds.n_shards % dist_num_replicas == 0`),
# so the shards are evenly assigned across the nodes.
# If it's not the case for dataset streaming, each node keeps 1 example out of `dist_num_replicas`, skipping the other examples.
# Workers:
# In a node, datasets.IterableDataset assigns the shards assigned to the node as evenly as possible to workers.
self.ds = split_dataset_by_node(ds, rank=self.dist_rank, world_size=self.dist_num_replicas)
def _num_samples_per_worker(self):
num_worker_samples = \
max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
if self.is_training or self.dist_num_replicas > 1:
num_worker_samples = math.ceil(num_worker_samples)
if self.is_training and self.batch_size is not None:
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
return int(num_worker_samples)
def __iter__(self):
if self.ds is None:
self._lazy_init()
self.ds.set_epoch(self.epoch.value)
target_sample_count = self._num_samples_per_worker()
sample_count = 0
ds_iter = iter(self.ds)
if self.is_training:
ds_iter = chain.from_iterable(repeat(ds_iter))
for sample in ds_iter:
input_data: Image.Image = sample[self.input_key]
if self.input_img_mode and input_data.mode != self.input_img_mode:
input_data = input_data.convert(self.input_img_mode)
target_data = sample[self.target_key]
if self.target_img_mode:
assert isinstance(target_data, Image.Image), "target_img_mode is specified but target is not an image"
if target_data.mode != self.target_img_mode:
target_data = target_data.convert(self.target_img_mode)
elif self.remap_class:
target_data = self.class_to_idx[target_data]
yield input_data, target_data
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
break
def __len__(self):
num_samples = self._num_samples_per_worker() * self.num_workers
return num_samples
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to examples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
if self.ds is None:
self._lazy_init()
names = []
for sample in self.ds:
if 'file_name' in sample:
name = sample['file_name']
elif 'filename' in sample:
name = sample['filename']
elif 'id' in sample:
name = sample['id']
elif 'image_id' in sample:
name = sample['image_id']
else:
assert False, "No supported name field present"
names.append(name)
return names
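A worked example of the per-worker accounting above (numbers hypothetical): with 2 ranks x 4 workers and batch_size=256, training mode pads each worker's count up to full batches.

    import math

    num_samples = 1_281_167
    global_num_workers = 2 * 4                       # dist_num_replicas * num_workers
    n = num_samples / global_num_workers             # 160145.875
    n = math.ceil(n)                                 # 160146
    n = math.ceil(n / 256) * 256                     # 160256, rounded up to batch_size=256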

View File: timm/data/readers/reader_image_folder.py

@ -61,14 +61,23 @@ class ReaderImageFolder(Reader):
def __init__(
self,
root,
class_map=''):
class_map='',
input_key=None,
):
super().__init__()
self.root = root
class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
find_types = None
if input_key:
find_types = input_key.split(';')
self.samples, self.class_to_idx = find_images_and_targets(
root,
class_to_idx=class_to_idx,
types=find_types,
)
if len(self.samples) == 0:
raise RuntimeError(
f'Found 0 images in subfolders of {root}. '

View File: timm/data/readers/reader_tfds.py

@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import os
import sys
from typing import Optional
import torch
@ -32,7 +33,7 @@ try:
except ImportError as e:
print(e)
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)
raise e
from .class_map import load_class_map
from .reader import Reader
@ -45,10 +46,10 @@ PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to pr
@tfds.decode.make_decoder()
def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE'):
def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
return tf.image.decode_jpeg(
serialized_image,
channels=3,
channels=channels,
dct_method=dct_method,
)
@ -92,18 +93,18 @@ class ReaderTfds(Reader):
def __init__(
self,
root,
name,
root=None,
split='train',
class_map=None,
is_training=False,
batch_size=None,
batch_size=1,
download=False,
repeats=0,
seed=42,
input_name='image',
input_key='image',
input_img_mode='RGB',
target_name='label',
target_key='label',
target_img_mode='',
prefetch_size=None,
shuffle_size=None,
@ -120,9 +121,9 @@ class ReaderTfds(Reader):
download: download and build TFDS dataset if set, otherwise must use tfds CLI
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances
input_name: name of Feature to return as data (input)
input_key: name of Feature to return as data (input)
input_img_mode: image mode if input is an image (currently PIL mode string)
target_name: name of Feature to return as target (label)
target_key: name of Feature to return as target (label)
target_img_mode: image mode if target is an image (currently PIL mode string)
prefetch_size: override default tf.data prefetch buffer size
shuffle_size: override default tf.data shuffle buffer size
@ -132,9 +133,6 @@ class ReaderTfds(Reader):
self.root = root
self.split = split
self.is_training = is_training
if self.is_training:
assert batch_size is not None, \
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.repeats = repeats
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
@ -145,10 +143,10 @@ class ReaderTfds(Reader):
self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
# TFDS builder and split information
self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature
self.input_key = input_key # FIXME support tuples / lists of inputs and targets and full range of Feature
self.input_img_mode = input_img_mode
self.target_name = target_name
self.target_img_mode = target_img_mode
self.target_key = target_key
self.target_img_mode = target_img_mode # for dense pixel targets
self.builder = tfds.builder(name, data_dir=root)
# NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
if download:
@ -158,7 +156,7 @@ class ReaderTfds(Reader):
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.class_to_idx = get_class_labels(self.builder.info) if self.target_key == 'label' else {}
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples
@ -258,7 +256,7 @@ class ReaderTfds(Reader):
ds = self.builder.as_dataset(
split=self.subsplit or self.split,
shuffle_files=self.is_training,
decoders=dict(image=decode_example()),
decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
read_config=read_config,
)
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
@ -282,7 +280,7 @@ class ReaderTfds(Reader):
max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
if self.is_training or self.dist_num_replicas > 1:
num_worker_samples = math.ceil(num_worker_samples)
if self.is_training and self.batch_size is not None:
if self.is_training:
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
return int(num_worker_samples)
@ -300,11 +298,14 @@ class ReaderTfds(Reader):
# Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
sample_count = 0
for sample in self.ds:
input_data = sample[self.input_name]
input_data = sample[self.input_key]
if self.input_img_mode:
if self.input_img_mode == 'L' and input_data.ndim == 3:
input_data = input_data[:, :, 0]
input_data = Image.fromarray(input_data, mode=self.input_img_mode)
target_data = sample[self.target_name]
target_data = sample[self.target_key]
if self.target_img_mode:
# dense pixel target
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
elif self.remap_class:
target_data = self.class_to_idx[target_data]
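e.g. a hedged sketch: with input_img_mode='L' the decoder now requests a single channel and the trailing dim is squeezed before PIL conversion.

    from timm.data import create_dataset

    ds = create_dataset(
        'tfds/mnist', root='/data/tfds', split='train',
        is_training=True, batch_size=256, download=True,
        input_img_mode='L',   # decode_example(channels=1), yields PIL mode 'L'
    )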

View File: timm/data/readers/reader_wds.py

@ -22,7 +22,7 @@ from torch.utils.data import Dataset, IterableDataset, get_worker_info
try:
import webdataset as wds
from webdataset.filters import _shuffle
from webdataset.filters import _shuffle, getfirst
from webdataset.shardlists import expand_urls
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
except ImportError:
@ -35,27 +35,30 @@ from .shared_count import SharedCount
_logger = logging.getLogger(__name__)
SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
SAMPLE_SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
SAMPLE_INITIAL_SIZE = int(os.environ.get('WDS_INITIAL_SIZE', 2048))
def _load_info(root, basename='info'):
info_json = os.path.join(root, basename + '.json')
info_yaml = os.path.join(root, basename + '.yaml')
def _load_info(root, names=('_info.json', 'info.json')):
if isinstance(names, str):
names = (names,)
tried = []
err_str = ''
for n in names:
full_path = os.path.join(root, n)
try:
with wds.gopen(info_json) as f:
tried.append(full_path)
with wds.gopen(full_path) as f:
if n.endswith('.json'):
info_dict = json.load(f)
else:
info_dict = yaml.safe_load(f)
return info_dict
except Exception as e:
err_str = str(e)
try:
with wds.gopen(info_yaml) as f:
info_dict = yaml.safe_load(f)
return info_dict
except Exception:
pass
_logger.warning(
f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
f'Dataset info file not found at {tried}. Error: {err_str}. '
'Falling back to provided split and size arg.')
return {}
@ -121,15 +124,18 @@ def _parse_split_info(split: str, info: Dict):
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
"""Call in an exception handler to ignore exceptions, isssue a warning, and continue."""
_logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
# NOTE: force an exit on errors that are clearly code / config related and not transient
if isinstance(exn, TypeError):
raise exn
return True
def _decode(
sample,
image_key='jpg',
image_format='RGB',
image_mode='RGB',
target_key='cls',
alt_label=''
):
@ -150,47 +156,18 @@ def _decode(
class_label = int(sample[target_key])
# decode image
with io.BytesIO(sample[image_key]) as b:
img = getfirst(sample, image_key)
with io.BytesIO(img) as b:
img = Image.open(b)
img.load()
if image_format:
img = img.convert(image_format)
if image_mode:
img = img.convert(image_mode)
# json passed through in undecoded state
decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
return decoded
def _decode_samples(
data,
image_key='jpg',
image_format='RGB',
target_key='cls',
alt_label='',
handler=log_and_continue):
"""Decode samples with skip."""
for sample in data:
try:
result = _decode(
sample,
image_key=image_key,
image_format=image_format,
target_key=target_key,
alt_label=alt_label
)
except Exception as exn:
if handler(exn):
continue
else:
break
# null results are skipped
if result is not None:
if isinstance(sample, dict) and isinstance(result, dict):
result["__key__"] = sample.get("__key__")
yield result
def pytorch_worker_seed():
"""get dataloader worker seed from pytorch"""
worker_info = get_worker_info()
@ -203,6 +180,7 @@ def pytorch_worker_seed():
if wds is not None:
# conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage)
class detshuffle2(wds.PipelineStage):
def __init__(
self,
@ -284,20 +262,22 @@ class ResampledShards2(IterableDataset):
class ReaderWds(Reader):
def __init__(
self,
root,
name,
split,
is_training=False,
batch_size=None,
repeats=0,
seed=42,
class_map=None,
input_name='jpg',
input_image='RGB',
target_name='cls',
target_image='',
prefetch_size=None,
shuffle_size=None,
root: str,
name: Optional[str] = None,
split: str = 'train',
is_training: bool = False,
num_samples: Optional[int] = None,
batch_size: int = 1,
repeats: int = 0,
seed: int = 42,
class_map: Optional[dict] = None,
input_key: str = 'jpg;png;webp',
input_img_mode: str = 'RGB',
target_key: str = 'cls',
target_img_mode: str = '',
filename_key: str = 'filename',
sample_shuffle_size: Optional[int] = None,
sample_initial_size: Optional[int] = None,
):
super().__init__()
if wds is None:
@ -309,19 +289,23 @@ class ReaderWds(Reader):
self.repeats = repeats
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
self.shard_shuffle_size = 500
self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE
self.sample_shuffle_size = sample_shuffle_size or SAMPLE_SHUFFLE_SIZE
self.sample_initial_size = sample_initial_size or SAMPLE_INITIAL_SIZE
self.image_key = input_name
self.image_format = input_image
self.target_key = target_name
self.filename_key = 'filename'
self.input_key = input_key
self.input_img_mode = input_img_mode
self.target_key = target_key
self.filename_key = filename_key
self.key_ext = '.JPEG' # extension to add to key for original filenames (DS specific, default ImageNet)
self.info = _load_info(self.root)
self.split_info = _parse_split_info(split, self.info)
if num_samples is not None:
self.num_samples = num_samples
else:
self.num_samples = self.split_info.num_samples
if not self.num_samples:
raise RuntimeError(f'Invalid split definition, no samples found.')
raise RuntimeError(f'Invalid split definition, num_samples not specified.')
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
@ -346,7 +330,7 @@ class ReaderWds(Reader):
self.init_count = 0
self.epoch_count = SharedCount()
# DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
# DataPipeline is lazy init, the majority of WDS DataPipeline could be init here, BUT, shuffle seed
# is not handled in manner where it can be deterministic for each worker AND initialized up front
self.ds = None
@ -382,13 +366,19 @@ class ReaderWds(Reader):
# at this point we have an iterator over all the shards
if self.is_training:
pipeline.extend([
detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count),
detshuffle2(
self.shard_shuffle_size,
seed=self.common_seed,
epoch=self.epoch_count,
),
self._split_by_node_and_worker,
# at this point, we have an iterator over the shards assigned to each worker
wds.tarfile_to_samples(handler=log_and_continue),
wds.shuffle(
self.sample_shuffle_size,
rng=random.Random(self.worker_seed)), # this is why we lazy-init whole DataPipeline
bufsize=self.sample_shuffle_size,
initial=self.sample_initial_size,
rng=random.Random(self.worker_seed) # this is why we lazy-init whole DataPipeline
),
])
else:
pipeline.extend([
@ -397,12 +387,16 @@ class ReaderWds(Reader):
wds.tarfile_to_samples(handler=log_and_continue),
])
pipeline.extend([
wds.map(
partial(
_decode_samples,
image_key=self.image_key,
image_format=self.image_format,
alt_label=self.split_info.alt_label
)
_decode,
image_key=self.input_key,
image_mode=self.input_img_mode,
alt_label=self.split_info.alt_label,
),
handler=log_and_continue,
),
wds.rename(image=self.input_key, target=self.target_key)
])
self.ds = wds.DataPipeline(*pipeline)
@ -418,7 +412,7 @@ class ReaderWds(Reader):
num_worker_samples = self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
if self.is_training or self.dist_num_replicas > 1:
num_worker_samples = math.ceil(num_worker_samples)
if self.is_training and self.batch_size is not None:
if self.is_training:
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
return int(num_worker_samples)
@ -439,10 +433,10 @@ class ReaderWds(Reader):
i = 0
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
for sample in ds:
target = sample[self.target_key]
target = sample['target']
if self.remap_class:
target = self.class_to_idx[target]
yield sample[self.image_key], target
yield sample['image'], target
i += 1
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
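To illustrate why input_key is now a ';'-separated fallback list, a minimal sketch of webdataset's getfirst semantics (sample contents hypothetical):

    from webdataset.filters import getfirst

    sample = {'__key__': 'n01440764_10026', 'png': b'...', 'cls': b'0'}
    img_bytes = getfirst(sample, 'jpg;png;webp')  # returns sample['png'], the first key present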

View File: timm/data/transforms.py

@ -2,7 +2,7 @@ import math
import numbers
import random
import warnings
from typing import List, Sequence
from typing import List, Sequence, Tuple, Union
import torch
import torchvision.transforms.functional as F
@ -14,6 +14,12 @@ except ImportError:
from PIL import Image
import numpy as np
__all__ = [
"ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
"RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder"
]
class ToNumpy:
@ -99,7 +105,7 @@ def interp_mode_to_str(mode):
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
def _setup_size(size, error_msg):
def _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."):
if isinstance(size, numbers.Number):
return int(size), int(size)
@ -127,8 +133,13 @@ class RandomResizedCropAndInterpolation:
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation='bilinear'):
def __init__(
self,
size,
scale=(0.08, 1.0),
ratio=(3. / 4., 4. / 3.),
interpolation='bilinear',
):
if isinstance(size, (list, tuple)):
self.size = tuple(size)
else:
@ -156,35 +167,35 @@ class RandomResizedCropAndInterpolation:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
img_w, img_h = F.get_image_size(img)
area = img_w * img_h
for attempt in range(10):
target_area = random.uniform(*scale) * area
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
target_w = int(round(math.sqrt(target_area * aspect_ratio)))
target_h = int(round(math.sqrt(target_area / aspect_ratio)))
if target_w <= img_w and target_h <= img_h:
i = random.randint(0, img_h - target_h)
j = random.randint(0, img_w - target_w)
return i, j, target_h, target_w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
in_ratio = img_w / img_h
if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
target_w = img_w
target_h = int(round(target_w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
target_h = img_h
target_w = int(round(target_h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, h, w
target_w = img_w
target_h = img_h
i = (img_h - target_h) // 2
j = (img_w - target_w) // 2
return i, j, target_h, target_w
def __call__(self, img):
"""
@ -213,8 +224,14 @@ class RandomResizedCropAndInterpolation:
return format_string
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
def center_crop_or_pad(
img: torch.Tensor,
output_size: Union[int, List[int]],
fill: Union[int, Tuple[int, int, int]] = 0,
padding_mode: str = 'constant',
) -> torch.Tensor:
"""Center crops and/or pads the given image.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
@ -228,13 +245,9 @@ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> tor
Returns:
PIL Image or Tensor: Cropped image.
"""
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])
_, image_height, image_width = F.get_dimensions(img)
output_size = _setup_size(output_size)
crop_height, crop_width = output_size
_, image_height, image_width = F.get_dimensions(img)
if crop_width > image_width or crop_height > image_height:
padding_ltrb = [
@ -243,7 +256,7 @@ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> tor
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
img = F.pad(img, padding_ltrb, fill=fill)
img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
_, image_height, image_width = F.get_dimensions(img)
if crop_width == image_width and crop_height == image_height:
return img
@ -265,10 +278,16 @@ class CenterCropOrPad(torch.nn.Module):
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
"""
def __init__(self, size, fill=0):
def __init__(
self,
size: Union[int, List[int]],
fill: Union[int, Tuple[int, int, int]] = 0,
padding_mode: str = 'constant',
):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.size = _setup_size(size)
self.fill = fill
self.padding_mode = padding_mode
def forward(self, img):
"""
@ -278,14 +297,111 @@ class CenterCropOrPad(torch.nn.Module):
Returns:
PIL Image or Tensor: Cropped image.
"""
return center_crop_or_pad(img, self.size, fill=self.fill)
return center_crop_or_pad(img, self.size, fill=self.fill, padding_mode=self.padding_mode)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
def crop_or_pad(
img: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
fill: Union[int, Tuple[int, int, int]] = 0,
padding_mode: str = 'constant',
) -> torch.Tensor:
""" Crops and/or pads image to meet target size, with control over fill and padding_mode.
"""
_, image_height, image_width = F.get_dimensions(img)
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > image_width or bottom > image_height:
padding_ltrb = [
max(-left + min(0, right), 0),
max(-top + min(0, bottom), 0),
max(right - max(image_width, left), 0),
max(bottom - max(image_height, top), 0),
]
img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
top = max(top, 0)
left = max(left, 0)
return F.crop(img, top, left, height, width)
class RandomCropOrPad(torch.nn.Module):
""" Crop and/or pad image with random placement within the crop or pad margin.
"""
def __init__(
self,
size: Union[int, List[int]],
fill: Union[int, Tuple[int, int, int]] = 0,
padding_mode: str = 'constant',
):
super().__init__()
self.size = _setup_size(size)
self.fill = fill
self.padding_mode = padding_mode
@staticmethod
def get_params(img, size):
_, image_height, image_width = F.get_dimensions(img)
delta_height = image_height - size[0]
delta_width = image_width - size[1]
top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))
left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))
return top, left
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
PIL Image or Tensor: Cropped image.
"""
top, left = self.get_params(img, self.size)
return crop_or_pad(
img,
top=top,
left=left,
height=self.size[0],
width=self.size[1],
fill=self.fill,
padding_mode=self.padding_mode,
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class RandomPad:
def __init__(self, input_size, fill=0):
self.input_size = input_size
self.fill = fill
@staticmethod
def get_params(img, input_size):
width, height = F.get_image_size(img)
delta_width = max(input_size[1] - width, 0)
delta_height = max(input_size[0] - height, 0)
pad_left = random.randint(0, delta_width)
pad_top = random.randint(0, delta_height)
pad_right = delta_width - pad_left
pad_bottom = delta_height - pad_top
return pad_left, pad_top, pad_right, pad_bottom
def __call__(self, img):
padding = self.get_params(img, self.input_size)
img = F.pad(img, padding, self.fill)
return img
class ResizeKeepRatio:
""" Resize and Keep Ratio
""" Resize and Keep Aspect Ratio
"""
def __init__(
@ -293,33 +409,77 @@ class ResizeKeepRatio:
size,
longest=0.,
interpolation='bilinear',
fill=0,
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_scale_area=False,
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11),
):
"""
Args:
size: Target output size, (h, w) tuple or scalar for a square target.
longest: Blend between fitting the shortest edge (0.) and the longest edge (1.) to the target.
interpolation: Interpolation mode ('random' selects bilinear or bicubic per call).
random_scale_prob: Probability of applying random scale jitter.
random_scale_range: Range of random scale jitter factors.
random_scale_area: Treat the scale factor as an area scale (RRC compatible) rather than linear.
random_aspect_prob: Probability of applying random aspect jitter.
random_aspect_range: Range of random aspect jitter factors.
"""
if isinstance(size, (list, tuple)):
self.size = tuple(size)
else:
self.size = (size, size)
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = str_to_interp_mode(interpolation)
self.longest = float(longest)
self.fill = fill
self.random_scale_prob = random_scale_prob
self.random_scale_range = random_scale_range
self.random_scale_area = random_scale_area
self.random_aspect_prob = random_aspect_prob
self.random_aspect_range = random_aspect_range
@staticmethod
def get_params(img, target_size, longest):
def get_params(
img,
target_size,
longest,
random_scale_prob=0.,
random_scale_range=(1.0, 1.33),
random_scale_area=False,
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11)
):
"""Get parameters
Args:
img (PIL Image): Image to be cropped.
target_size (Tuple[int, int]): Size of output
Returns:
tuple: params (h, w) to be passed to ``resize``
"""
source_size = img.size[::-1] # h, w
h, w = source_size
img_h, img_w = img_size = F.get_dimensions(img)[1:]
target_h, target_w = target_size
ratio_h = h / target_h
ratio_w = w / target_w
ratio_h = img_h / target_h
ratio_w = img_w / target_w
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
size = [round(x / ratio) for x in source_size]
if random_scale_prob > 0 and random.random() < random_scale_prob:
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
if random_scale_area:
# make ratio factor equivalent to RRC area crop where < 1.0 = area zoom,
# otherwise like affine scale where < 1.0 = linear zoom out
ratio_factor = 1. / math.sqrt(ratio_factor)
ratio_factor = (ratio_factor, ratio_factor)
else:
ratio_factor = (1., 1.)
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
aspect_factor = math.exp(random.uniform(*log_aspect))
aspect_factor = math.sqrt(aspect_factor)
# currently applying random aspect adjustment equally to both dims,
# could change to keep output sizes above their target where possible
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
size = [round(x * f / ratio) for x, f in zip(img_size, ratio_factor)]
return size
def __call__(self, img):
@ -330,13 +490,49 @@ class ResizeKeepRatio:
Returns:
PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
"""
size = self.get_params(img, self.size, self.longest)
img = F.resize(img, size, self.interpolation)
size = self.get_params(
img, self.size, self.longest,
self.random_scale_prob, self.random_scale_range, self.random_scale_area,
self.random_aspect_prob, self.random_aspect_range
)
if isinstance(self.interpolation, (tuple, list)):
interpolation = random.choice(self.interpolation)
else:
interpolation = self.interpolation
img = F.resize(img, size, interpolation)
return img
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
else:
interpolate_str = interp_mode_to_str(self.interpolation)
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += f', interpolation={interpolate_str})'
format_string += f', longest={self.longest:.3f})'
format_string += f', interpolation={interpolate_str}'
format_string += f', longest={self.longest:.3f}'
format_string += f', random_scale_prob={self.random_scale_prob:.3f}'
format_string += f', random_scale_range=(' \
f'{self.random_scale_range[0]:.3f}, {self.random_scale_range[1]:.3f})'
format_string += f', random_aspect_prob={self.random_aspect_prob:.3f}'
format_string += f', random_aspect_range=(' \
f'{self.random_aspect_range[0]:.3f}, {self.random_aspect_range[1]:.3f}))'
return format_string
class TrimBorder(torch.nn.Module):
def __init__(
self,
border_size: int,
):
super().__init__()
self.border_size = border_size
def forward(self, img):
w, h = F.get_image_size(img)
top = left = self.border_size
top = min(top, h)
left = min(left, w)
height = max(0, h - 2 * self.border_size)
width = max(0, w - 2 * self.border_size)
return F.crop(img, top, left, height, width)
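A short sketch of the new crop/pad transforms (sizes illustrative):

    from PIL import Image
    from timm.data.transforms import RandomCropOrPad, TrimBorder

    t = RandomCropOrPad(224, padding_mode='reflect')
    out = t(Image.new('RGB', (200, 260)))   # pads width, crops height -> 224x224
    trimmed = TrimBorder(4)(out)            # -> 216x216, 4px trimmed from each edge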

View File: timm/data/transforms_factory.py

@ -4,6 +4,7 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
Hacked together by / Copyright 2019, Ross Wightman
"""
import math
from typing import Optional, Tuple, Union
import torch
from torchvision import transforms
@ -11,17 +12,29 @@ from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
ResizeKeepRatio, CenterCropOrPad, ToNumpy
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy
from timm.data.random_erasing import RandomErasing
def transforms_noaug_train(
img_size=224,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
img_size: Union[int, Tuple[int, int]] = 224,
interpolation: str = 'bilinear',
use_prefetcher: bool = False,
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
):
""" No-augmentation image transforms for training.
Args:
img_size: Target image size.
interpolation: Image interpolation mode.
mean: Image normalization mean.
std: Image normalization standard deviation.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
Returns:
    Composed transforms
"""
if interpolation == 'random':
# random interpolation not supported with no-aug
interpolation = 'bilinear'
@ -37,41 +50,97 @@ def transforms_noaug_train(
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
std=torch.tensor(std)
)
]
return transforms.Compose(tfl)
def transforms_imagenet_train(
img_size=224,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
interpolation='random',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0,
separate=False,
force_color_jitter=False,
img_size: Union[int, Tuple[int, int]] = 224,
scale: Optional[Tuple[float, float]] = None,
ratio: Optional[Tuple[float, float]] = None,
train_crop_mode: Optional[str] = None,
hflip: float = 0.5,
vflip: float = 0.,
color_jitter: Union[float, Tuple[float, ...]] = 0.4,
color_jitter_prob: Optional[float] = None,
force_color_jitter: bool = False,
grayscale_prob: float = 0.,
gaussian_blur_prob: float = 0.,
auto_augment: Optional[str] = None,
interpolation: str = 'random',
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_num_splits: int = 0,
use_prefetcher: bool = False,
separate: bool = False,
):
"""
""" ImageNet-oriented image transforms for training.
Args:
img_size: Target image size.
scale: Random resize scale range (crop area, < 1.0 => zoom in).
ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
hflip: Horizontal flip probability.
vflip: Vertical flip probability.
color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
Scalar is applied as (scalar,) * 3 (no hue).
color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
force_color_jitter: Force color jitter where it is normally disabled (ie with RandAugment on).
grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
auto_augment: Auto augment configuration string (see auto_augment.py).
interpolation: Image interpolation mode.
mean: Image normalization mean.
std: Image normalization standard deviation.
re_prob: Random erasing probability.
re_mode: Random erasing fill mode.
re_count: Number of random erasing regions.
re_num_splits: Control split of random erasing across batch size.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
separate: Output transforms in 3-stage tuple.
Returns:
If separate==True, the transforms are returned as a tuple of 3 separate transforms
for use in a mixing dataset that passes
* all data through the first (primary) transform, called the 'clean' data
* a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform
"""
train_crop_mode = train_crop_mode or 'rrc'
if train_crop_mode in ('rkrc', 'rkrr'):
# FIXME integration of RKR is a WIP
scale = tuple(scale or (0.8, 1.00))
ratio = tuple(ratio or (0.9, 1/.9))
primary_tfl = [
ResizeKeepRatio(
img_size,
interpolation=interpolation,
random_scale_prob=0.5,
random_scale_range=scale,
random_scale_area=True, # scale compatible with RRC
random_aspect_prob=0.5,
random_aspect_range=ratio,
),
CenterCropOrPad(img_size, padding_mode='reflect')
if train_crop_mode == 'rkrc' else
RandomCropOrPad(img_size, padding_mode='reflect')
]
else:
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
primary_tfl = [
RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
RandomResizedCropAndInterpolation(
img_size,
scale=scale,
ratio=ratio,
interpolation=interpolation,
)
]
if hflip > 0.:
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
if vflip > 0.:
@ -111,8 +180,29 @@ def transforms_imagenet_train(
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
if color_jitter_prob is not None:
secondary_tfl += [
transforms.RandomApply([
transforms.ColorJitter(*color_jitter),
],
p=color_jitter_prob
)
]
else:
secondary_tfl += [transforms.ColorJitter(*color_jitter)]
if grayscale_prob:
secondary_tfl += [transforms.RandomGrayscale(p=grayscale_prob)]
if gaussian_blur_prob:
secondary_tfl += [
transforms.RandomApply([
transforms.GaussianBlur(kernel_size=23), # hardcoded for now
],
p=gaussian_blur_prob,
)
]
final_tfl = []
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
@ -122,11 +212,19 @@ def transforms_imagenet_train(
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
std=torch.tensor(std)
),
]
if re_prob > 0.:
final_tfl.append(
RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
final_tfl += [
RandomErasing(
re_prob,
mode=re_mode,
max_count=re_count,
num_splits=re_num_splits,
device='cpu',
)
]
if separate:
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
@ -135,14 +233,30 @@ def transforms_imagenet_train(
def transforms_imagenet_eval(
img_size=224,
crop_pct=None,
crop_mode=None,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
img_size: Union[int, Tuple[int, int]] = 224,
crop_pct: Optional[float] = None,
crop_mode: Optional[str] = None,
crop_border_pixels: Optional[int] = None,
interpolation: str = 'bilinear',
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
use_prefetcher: bool = False,
):
""" ImageNet-oriented image transform for evaluation and inference.
Args:
img_size: Target image size.
crop_pct: Crop percentage. Defaults to 0.875 when None.
crop_mode: Crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
crop_border_pixels: Trim a border of specified # pixels around edge of original image.
interpolation: Image interpolation mode.
mean: Image normalization mean.
std: Image normalization standard deviation.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
Returns:
Composed transform pipeline
"""
crop_pct = crop_pct or DEFAULT_CROP_PCT
if isinstance(img_size, (tuple, list)):
@ -152,10 +266,15 @@ def transforms_imagenet_eval(
scale_size = math.floor(img_size / crop_pct)
scale_size = (scale_size, scale_size)
tfl = []
if crop_border_pixels:
tfl += [TrimBorder(crop_border_pixels)]
if crop_mode == 'squash':
# squash mode scales each edge to 1/pct of target, then crops
# aspect ratio is not preserved, no img lost if crop_pct == 1.0
tfl = [
tfl += [
transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
transforms.CenterCrop(img_size),
]
@ -163,7 +282,7 @@ def transforms_imagenet_eval(
# scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
# no image lost if crop_pct == 1.0
fill = [round(255 * v) for v in mean]
tfl = [
tfl += [
ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
CenterCropOrPad(img_size, fill=fill),
]
@ -172,12 +291,12 @@ def transforms_imagenet_eval(
# aspect ratio is preserved, crops center within image, no borders are added, image is lost
if scale_size[0] == scale_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
tfl = [
tfl += [
transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
]
else:
# resize shortest edge to matching target dim for non-square target
tfl = [ResizeKeepRatio(scale_size)]
# resize the shortest edge to matching target dim for non-square target
tfl += [ResizeKeepRatio(scale_size)]
tfl += [transforms.CenterCrop(img_size)]
if use_prefetcher:
@ -196,28 +315,65 @@ def transforms_imagenet_eval(
def create_transform(
input_size,
is_training=False,
use_prefetcher=False,
no_aug=False,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0,
crop_pct=None,
crop_mode=None,
tf_preprocessing=False,
separate=False):
input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224,
is_training: bool = False,
no_aug: bool = False,
scale: Optional[Tuple[float, float]] = None,
ratio: Optional[Tuple[float, float]] = None,
hflip: float = 0.5,
vflip: float = 0.,
color_jitter: Union[float, Tuple[float, ...]] = 0.4,
color_jitter_prob: Optional[float] = None,
grayscale_prob: float = 0.,
gaussian_blur_prob: float = 0.,
auto_augment: Optional[str] = None,
interpolation: str = 'bilinear',
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_num_splits: int = 0,
crop_pct: Optional[float] = None,
crop_mode: Optional[str] = None,
crop_border_pixels: Optional[int] = None,
tf_preprocessing: bool = False,
use_prefetcher: bool = False,
separate: bool = False,
):
"""
Args:
input_size: Target input size (channels, height, width) tuple or size scalar.
is_training: Return training (random) transforms.
no_aug: Disable augmentation for training (useful for debug).
scale: Random resize scale range (crop area, < 1.0 => zoom in).
ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
hflip: Horizontal flip probability.
vflip: Vertical flip probability.
color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
Scalar is applied as (scalar,) * 3 (no hue).
color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
auto_augment: Auto augment configuration string (see auto_augment.py).
interpolation: Image interpolation mode.
mean: Image normalization mean.
std: Image normalization standard deviation.
re_prob: Random erasing probability.
re_mode: Random erasing fill mode.
re_count: Number of random erasing regions.
re_num_splits: Control split of random erasing across batch size.
crop_pct: Inference crop percentage (output size / resize size).
crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
crop_border_pixels: Inference crop border of the specified number of pixels around the edge of the original image.
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
separate: Output transforms in 3-stage tuple.
Returns:
Composed transforms or tuple thereof
"""
if isinstance(input_size, (tuple, list)):
img_size = input_size[-2:]
else:
@@ -227,7 +383,10 @@ def create_transform(
assert not separate, "Separate transforms not supported for TF preprocessing"
from timm.data.tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform(
is_training=is_training, size=img_size, interpolation=interpolation)
is_training=is_training,
size=img_size,
interpolation=interpolation,
)
else:
if is_training and no_aug:
assert not separate, "Cannot perform split augmentation with no_aug"
@@ -246,6 +405,9 @@ def create_transform(
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter,
color_jitter_prob=color_jitter_prob,
grayscale_prob=grayscale_prob,
gaussian_blur_prob=gaussian_blur_prob,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
@@ -267,6 +429,7 @@ def create_transform(
std=std,
crop_pct=crop_pct,
crop_mode=crop_mode,
crop_border_pixels=crop_border_pixels,
)
return transform
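
# Example (sketch): building pipelines with the arguments added above; names
# follow timm's public API, probability values are illustrative only.
from timm.data import create_transform

train_tf = create_transform(
    input_size=224,
    is_training=True,
    color_jitter=0.4,
    color_jitter_prob=0.8,    # apply the jitter with p=0.8 (SimCLR-like)
    grayscale_prob=0.2,
    gaussian_blur_prob=0.5,
)
eval_tf = create_transform(
    input_size=224,
    is_training=False,
    crop_mode='border',       # pad to square rather than cropping content away
    crop_border_pixels=4,     # trim a 4px frame from the original image first
)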

View File

@@ -93,10 +93,20 @@ group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
group.add_argument('--train-num-samples', default=None, type=int,
metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
group.add_argument('--val-num-samples', default=None, type=int,
metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
group.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
group.add_argument('--input-img-mode', default=None, type=str,
help='Dataset image conversion mode for input images.')
group.add_argument('--input-key', default=None, type=str,
help='Dataset key for input images.')
group.add_argument('--target-key', default=None, type=str,
help='Dataset key for target labels.')
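
# Example (sketch): how the new dataset flags reach create_dataset for an
# iterable source; the 'wds/' name, root, and sample keys are placeholders.
from timm.data import create_dataset

dataset_train = create_dataset(
    'wds/my-shards',          # hypothetical webdataset name
    root='/data/shards',      # placeholder shard directory
    split='train',
    is_training=True,
    batch_size=256,
    num_samples=1_281_167,    # --train-num-samples, manual count for IterableDatasets
    input_img_mode='RGB',     # --input-img-mode
    input_key='jpg',          # --input-key, image field of each sample
    target_key='cls',         # --target-key, label field of each sample
)
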
# Model parameters
group = parser.add_argument_group('Model parameters')
@@ -245,6 +255,12 @@ group.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
help='Probability of applying any color jitter.')
group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
help='Probability of applying random grayscale conversion.')
group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
help='Probability of applying gaussian blur.')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)')
group.add_argument('--aug-repeats', type=float, default=0,
@@ -594,6 +610,10 @@ def main():
# create the train and eval datasets
if args.data and not args.data_dir:
args.data_dir = args.data
if args.input_img_mode is None:
input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
else:
input_img_mode = args.input_img_mode
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
@@ -604,6 +624,10 @@ def main():
batch_size=args.batch_size,
seed=args.seed,
repeats=args.epoch_repeats,
input_img_mode=input_img_mode,
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.train_num_samples,
)
dataset_eval = create_dataset(
@@ -614,6 +638,10 @@ def main():
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
input_img_mode=input_img_mode,
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.val_num_samples,
)
# setup mixup / cutmix
@@ -650,7 +678,6 @@ def main():
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
use_prefetcher=args.prefetcher,
no_aug=args.no_aug,
re_prob=args.reprob,
re_mode=args.remode,
@@ -661,6 +688,9 @@ def main():
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter,
color_jitter_prob=args.color_jitter_prob,
grayscale_prob=args.grayscale_prob,
gaussian_blur_prob=args.gaussian_blur_prob,
auto_augment=args.aa,
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
@@ -672,6 +702,7 @@ def main():
collate_fn=collate_fn,
pin_memory=args.pin_mem,
device=device,
use_prefetcher=args.prefetcher,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
@@ -685,7 +716,6 @@ def main():
input_size=data_config['input_size'],
batch_size=args.validation_batch_size or args.batch_size,
is_training=False,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
@@ -694,6 +724,7 @@ def main():
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
device=device,
use_prefetcher=args.prefetcher,
)
# setup loss function

View File

@@ -61,10 +61,23 @@ parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--num-samples', default=None, type=int,
metavar='N', help='Manually specify num samples in dataset split, for IterableDatasets.')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--input-key', default=None, type=str,
help='Dataset key for input images.')
parser.add_argument('--input-img-mode', default=None, type=str,
help='Dataset image conversion mode for input images.')
parser.add_argument('--target-key', default=None, type=str,
help='Dataset key for target labels.')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
@@ -81,6 +94,8 @@ parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--crop-border-pixels', type=int, default=None,
help='Crop pixels from image border.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -89,16 +104,12 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
@@ -249,6 +260,10 @@ def validate(args):
criterion = nn.CrossEntropyLoss().to(device)
root_dir = args.data or args.data_dir
if args.input_img_mode is None:
input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
else:
input_img_mode = args.input_img_mode
dataset = create_dataset(
root=root_dir,
name=args.dataset,
@@ -256,6 +271,10 @@ def validate(args):
download=args.dataset_download,
load_bytes=args.tf_preprocessing,
class_map=args.class_map,
num_samples=args.num_samples,
input_key=args.input_key,
input_img_mode=input_img_mode,
target_key=args.target_key,
)
if args.valid_labels:
@@ -281,6 +300,7 @@ def validate(args):
num_workers=args.workers,
crop_pct=crop_pct,
crop_mode=data_config['crop_mode'],
crop_border_pixels=args.crop_border_pixels,
pin_memory=args.pin_mem,
device=device,
tf_preprocessing=args.tf_preprocessing,
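
# Example (sketch): an eval loader mirroring the call above; root and values
# are placeholders, crop_border_pixels flows through to the eval transform.
from timm.data import create_dataset, create_loader

dataset = create_dataset('', root='/data/imagenet', split='validation')
loader = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=256,
    use_prefetcher=True,
    interpolation='bicubic',
    crop_pct=0.95,
    crop_border_pixels=2,     # trim a 2px frame before resize + center crop
)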