mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Significant transforms, dataset, dataloading enhancements.

This commit is contained in:
parent b5a4fa9c3b
commit be0944edae
timm/data/dataset.py
@@ -27,7 +27,7 @@ class ImageDataset(data.Dataset):
             split='train',
             class_map=None,
             load_bytes=False,
-            img_mode='RGB',
+            input_img_mode='RGB',
             transform=None,
             target_transform=None,
     ):
@@ -40,7 +40,7 @@ class ImageDataset(data.Dataset):
             )
         self.reader = reader
         self.load_bytes = load_bytes
-        self.img_mode = img_mode
+        self.input_img_mode = input_img_mode
         self.transform = transform
         self.target_transform = target_transform
         self._consecutive_errors = 0
@@ -59,8 +59,8 @@ class ImageDataset(data.Dataset):
                 raise e
         self._consecutive_errors = 0

-        if self.img_mode and not self.load_bytes:
-            img = img.convert(self.img_mode)
+        if self.input_img_mode and not self.load_bytes:
+            img = img.convert(self.input_img_mode)
         if self.transform is not None:
             img = self.transform(img)

@@ -90,12 +90,17 @@ class IterableImageDataset(data.IterableDataset):
             split='train',
             class_map=None,
             is_training=False,
-            batch_size=None,
+            batch_size=1,
+            num_samples=None,
             seed=42,
             repeats=0,
             download=False,
+            input_img_mode='RGB',
+            input_key=None,
+            target_key=None,
             transform=None,
             target_transform=None,
+            max_steps=None,
     ):
         assert reader is not None
         if isinstance(reader, str):
@@ -106,9 +111,14 @@ class IterableImageDataset(data.IterableDataset):
                 class_map=class_map,
                 is_training=is_training,
                 batch_size=batch_size,
+                num_samples=num_samples,
                 seed=seed,
                 repeats=repeats,
                 download=download,
+                input_img_mode=input_img_mode,
+                input_key=input_key,
+                target_key=target_key,
+                max_steps=max_steps,
             )
         else:
             self.reader = reader
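The net effect of the dataset.py hunks is a rename of `img_mode` to `input_img_mode` on `ImageDataset`, plus richer control over `IterableImageDataset` (explicit `num_samples`, input/target keys, `max_steps`). A minimal sketch of the renamed argument, assuming a local image folder at a placeholder path:

```python
from timm.data.dataset import ImageDataset

# './train' is a placeholder folder of images; 'L' decodes inputs as single-channel grayscale
ds = ImageDataset('./train', input_img_mode='L')
img, target = ds[0]  # PIL image converted to mode 'L', integer target
```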
timm/data/dataset_factory.py
@@ -3,6 +3,7 @@
 Hacked together by / Copyright 2021, Ross Wightman
 """
 import os
+from typing import Optional

 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
 try:
@@ -60,22 +61,24 @@ def _search_split(root, split):


 def create_dataset(
-        name,
-        root,
-        split='validation',
-        search_split=True,
-        class_map=None,
-        load_bytes=False,
-        is_training=False,
-        download=False,
-        batch_size=None,
-        seed=42,
-        repeats=0,
-        **kwargs
+        name: str,
+        root: Optional[str] = None,
+        split: str = 'validation',
+        search_split: bool = True,
+        class_map: dict = None,
+        load_bytes: bool = False,
+        is_training: bool = False,
+        download: bool = False,
+        batch_size: int = 1,
+        num_samples: Optional[int] = None,
+        seed: int = 42,
+        repeats: int = 0,
+        input_img_mode: str = 'RGB',
+        **kwargs,
 ):
     """ Dataset factory method

-    In parenthesis after each arg are the type of dataset supported for each arg, one of:
+    In parentheses after each arg are the type of dataset supported for each arg, one of:
      * folder - default, timm folder (or tar) based ImageDataset
      * torch - torchvision based datasets
      * HFDS - Hugging Face Datasets
@@ -97,11 +100,13 @@ def create_dataset(
         batch_size: batch size hint for (TFDS, WDS)
         seed: seed for iterable datasets (TFDS, WDS)
         repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
+        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
         **kwargs: other args to pass to dataset

     Returns:
         Dataset object
     """
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     name = name.lower()
     if name.startswith('torch/'):
         name = name.split('/', 2)[-1]
@@ -151,7 +156,29 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
+        ds = ImageDataset(
+            root,
+            reader=name,
+            split=split,
+            class_map=class_map,
+            input_img_mode=input_img_mode,
+            **kwargs,
+        )
+    elif name.startswith('hfids/'):
+        ds = IterableImageDataset(
+            root,
+            reader=name,
+            split=split,
+            class_map=class_map,
+            is_training=is_training,
+            download=download,
+            batch_size=batch_size,
+            num_samples=num_samples,
+            repeats=repeats,
+            seed=seed,
+            input_img_mode=input_img_mode,
+            **kwargs
+        )
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
@@ -161,8 +188,10 @@ def create_dataset(
             is_training=is_training,
             download=download,
             batch_size=batch_size,
+            num_samples=num_samples,
             repeats=repeats,
             seed=seed,
+            input_img_mode=input_img_mode,
             **kwargs
         )
     elif name.startswith('wds/'):
@@ -173,8 +202,10 @@ def create_dataset(
             class_map=class_map,
             is_training=is_training,
             batch_size=batch_size,
+            num_samples=num_samples,
             repeats=repeats,
             seed=seed,
+            input_img_mode=input_img_mode,
             **kwargs
         )
     else:
@@ -182,5 +213,12 @@ def create_dataset(
         if search_split and os.path.isdir(root):
             # look for split specific sub-folder in root
             root = _search_split(root, split)
-        ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
+        ds = ImageDataset(
+            root,
+            reader=name,
+            class_map=class_map,
+            load_bytes=load_bytes,
+            input_img_mode=input_img_mode,
+            **kwargs,
+        )
     return ds
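Taken together, `create_dataset` now threads `input_img_mode` and `num_samples` through to every reader type. A hedged usage sketch (the dataset name, root path, and shard URL below are placeholders, not values from this commit):

```python
from timm.data import create_dataset

# folder-backed dataset, decoding inputs as grayscale
ds = create_dataset('', root='./imagenet', split='validation', input_img_mode='L')

# streaming WDS dataset; num_samples matters when split size isn't in an info file
train_ds = create_dataset(
    'wds/my-dataset',  # placeholder name
    root='https://example.com/shards',  # placeholder URL
    split='train',
    is_training=True,
    batch_size=256,
    num_samples=1_281_167,
)
```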
timm/data/loader.py
@@ -10,14 +10,14 @@ import random
 from contextlib import suppress
 from functools import partial
 from itertools import repeat
-from typing import Callable
+from typing import Callable, Optional, Tuple, Union

 import torch
 import torch.utils.data
 import numpy as np

 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .dataset import IterableImageDataset
+from .dataset import IterableImageDataset, ImageDataset
 from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
 from .random_erasing import RandomErasing
 from .mixup import FastCollateMixup
@@ -187,41 +187,91 @@ def _worker_init(worker_id, worker_seeding='all'):


 def create_loader(
-        dataset,
-        input_size,
-        batch_size,
-        is_training=False,
-        use_prefetcher=True,
-        no_aug=False,
-        re_prob=0.,
-        re_mode='const',
-        re_count=1,
-        re_split=False,
-        scale=None,
-        ratio=None,
-        hflip=0.5,
-        vflip=0.,
-        color_jitter=0.4,
-        auto_augment=None,
-        num_aug_repeats=0,
-        num_aug_splits=0,
-        interpolation='bilinear',
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD,
-        num_workers=1,
-        distributed=False,
-        crop_pct=None,
-        crop_mode=None,
-        collate_fn=None,
-        pin_memory=False,
-        fp16=False,  # deprecated, use img_dtype
-        img_dtype=torch.float32,
-        device=torch.device('cuda'),
-        tf_preprocessing=False,
-        use_multi_epochs_loader=False,
-        persistent_workers=True,
-        worker_seeding='all',
+        dataset: Union[ImageDataset, IterableImageDataset],
+        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
+        batch_size: int,
+        is_training: bool = False,
+        no_aug: bool = False,
+        re_prob: float = 0.,
+        re_mode: str = 'const',
+        re_count: int = 1,
+        re_split: bool = False,
+        scale: Optional[Tuple[float, float]] = None,
+        ratio: Optional[Tuple[float, float]] = None,
+        hflip: float = 0.5,
+        vflip: float = 0.,
+        color_jitter: float = 0.4,
+        color_jitter_prob: Optional[float] = None,
+        grayscale_prob: float = 0.,
+        gaussian_blur_prob: float = 0.,
+        auto_augment: Optional[str] = None,
+        num_aug_repeats: int = 0,
+        num_aug_splits: int = 0,
+        interpolation: str = 'bilinear',
+        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+        num_workers: int = 1,
+        distributed: bool = False,
+        crop_pct: Optional[float] = None,
+        crop_mode: Optional[str] = None,
+        crop_border_pixels: Optional[int] = None,
+        collate_fn: Optional[Callable] = None,
+        pin_memory: bool = False,
+        fp16: bool = False,  # deprecated, use img_dtype
+        img_dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device('cuda'),
+        use_prefetcher: bool = True,
+        use_multi_epochs_loader: bool = False,
+        persistent_workers: bool = True,
+        worker_seeding: str = 'all',
+        tf_preprocessing: bool = False,
 ):
+    """
+
+    Args:
+        dataset: The image dataset to load.
+        input_size: Target input size (channels, height, width) tuple or size scalar.
+        batch_size: Number of samples in a batch.
+        is_training: Return training (random) transforms.
+        no_aug: Disable augmentation for training (useful for debug).
+        re_prob: Random erasing probability.
+        re_mode: Random erasing fill mode.
+        re_count: Number of random erasing regions.
+        re_split: Control split of random erasing across batch size.
+        scale: Random resize scale range (crop area, < 1.0 => zoom in).
+        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
+        hflip: Horizontal flip probability.
+        vflip: Vertical flip probability.
+        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
+            Scalar is applied as (scalar,) * 3 (no hue).
+        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
+        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
+        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
+        auto_augment: Auto augment configuration string (see auto_augment.py).
+        num_aug_repeats: Enable special sampler to repeat same augmentation across distributed GPUs.
+        num_aug_splits: Enable mode where augmentations can be split across the batch.
+        interpolation: Image interpolation mode.
+        mean: Image normalization mean.
+        std: Image normalization standard deviation.
+        num_workers: Num worker processes per DataLoader.
+        distributed: Enable dataloading for distributed training.
+        crop_pct: Inference crop percentage (output size / resize size).
+        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
+        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
+        collate_fn: Override default collate_fn.
+        pin_memory: Pin memory for device transfer.
+        fp16: Deprecated argument for half-precision input dtype. Use img_dtype.
+        img_dtype: Data type for input image.
+        device: Device to transfer inputs and targets to.
+        use_prefetcher: Use efficient pre-fetcher to load samples onto device.
+        use_multi_epochs_loader: Use a multi-epochs DataLoader variant that reuses worker processes across epochs.
+        persistent_workers: Enable persistent worker processes.
+        worker_seeding: Control worker random seeding at init.
+        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.

+    Returns:
+        DataLoader
+    """
     re_num_splits = 0
     if re_split:
         # apply RE to second half of batch if no aug split otherwise line up with aug split
@@ -229,24 +279,28 @@ def create_loader(
     dataset.transform = create_transform(
         input_size,
         is_training=is_training,
-        use_prefetcher=use_prefetcher,
         no_aug=no_aug,
         scale=scale,
         ratio=ratio,
         hflip=hflip,
         vflip=vflip,
         color_jitter=color_jitter,
+        color_jitter_prob=color_jitter_prob,
+        grayscale_prob=grayscale_prob,
+        gaussian_blur_prob=gaussian_blur_prob,
         auto_augment=auto_augment,
         interpolation=interpolation,
         mean=mean,
         std=std,
         crop_pct=crop_pct,
         crop_mode=crop_mode,
-        tf_preprocessing=tf_preprocessing,
+        crop_border_pixels=crop_border_pixels,
         re_prob=re_prob,
         re_mode=re_mode,
        re_count=re_count,
         re_num_splits=re_num_splits,
+        tf_preprocessing=tf_preprocessing,
+        use_prefetcher=use_prefetcher,
         separate=num_aug_splits > 0,
     )
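The new `create_loader` arguments mirror `create_transform`: `color_jitter_prob`, `grayscale_prob`, and `gaussian_blur_prob` enable SimCLR-style probabilistic augmentation, and `crop_border_pixels` trims a fixed border at eval time. A sketch of a training loader using them (paths and values are illustrative):

```python
from timm.data import create_dataset, create_loader

train_ds = create_dataset('', root='./imagenet', split='train', is_training=True)
loader = create_loader(
    train_ds,
    input_size=(3, 224, 224),
    batch_size=256,
    is_training=True,
    color_jitter=0.4,
    color_jitter_prob=0.8,   # apply jitter 80% of the time
    grayscale_prob=0.2,      # random grayscale conversion
    gaussian_blur_prob=0.5,  # random gaussian blur
    num_workers=8,
)
```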
timm/data/readers/reader_factory.py
@@ -1,10 +1,17 @@
 import os
+from typing import Optional

 from .reader_image_folder import ReaderImageFolder
 from .reader_image_in_tar import ReaderImageInTar


-def create_reader(name, root, split='train', **kwargs):
+def create_reader(
+        name: str,
+        root: Optional[str] = None,
+        split: str = 'train',
+        **kwargs,
+):
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     name = name.lower()
     name = name.split('/', 1)
     prefix = ''
@@ -15,15 +22,18 @@ def create_reader(
     # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
     # explicitly select other options shortly
     if prefix == 'hfds':
-        from .reader_hfds import ReaderHfds  # defer tensorflow import
-        reader = ReaderHfds(root, name, split=split, **kwargs)
+        from .reader_hfds import ReaderHfds  # defer Hf datasets import
+        reader = ReaderHfds(name=name, root=root, split=split, **kwargs)
+    elif prefix == 'hfids':
+        from .reader_hfids import ReaderHfids  # defer HF datasets import
+        reader = ReaderHfids(name=name, root=root, split=split, **kwargs)
     elif prefix == 'tfds':
         from .reader_tfds import ReaderTfds  # defer tensorflow import
-        reader = ReaderTfds(root, name, split=split, **kwargs)
+        reader = ReaderTfds(name=name, root=root, split=split, **kwargs)
     elif prefix == 'wds':
         from .reader_wds import ReaderWds
         kwargs.pop('download', False)
-        reader = ReaderWds(root, name, split=split, **kwargs)
+        reader = ReaderWds(root=root, name=name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
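All readers now take keyword arguments in a uniform `name=..., root=...` order, and the factory routes on the name prefix. A hedged sketch of the routing (dataset ids are examples, and constructing a reader may trigger a download):

```python
from timm.data.readers import create_reader

# prefix selects the backend; everything after '/' is the dataset name
r1 = create_reader('hfds/beans', root='./hf_cache', split='train')   # HF random-access dataset
r2 = create_reader('tfds/mnist', root='./tfds_data', split='train')  # TFDS-backed iterable
```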
timm/data/readers/reader_hfds.py
@@ -4,6 +4,8 @@ Hacked together by / Copyright 2022 Ross Wightman
 """
 import io
 import math
+from typing import Optional
+
 import torch
 import torch.distributed as dist
 from PIL import Image
@@ -12,7 +14,7 @@ try:
     import datasets
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
-    exit(1)
+    raise e
 from .class_map import load_class_map
 from .reader import Reader

@@ -29,12 +31,13 @@ class ReaderHfds(Reader):

     def __init__(
             self,
-            root,
-            name,
-            split='train',
-            class_map=None,
-            label_key='label',
-            download=False,
+            name: str,
+            root: Optional[str] = None,
+            split: str = 'train',
+            class_map: dict = None,
+            image_key: str = 'image',
+            target_key: str = 'label',
+            download: bool = False,
     ):
         """
         """
@@ -47,9 +50,10 @@ class ReaderHfds(Reader):
             cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
         )
         # leave decode for caller, plus we want easy access to original path names...
-        self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
+        self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False))

-        self.label_key = label_key
+        self.image_key = image_key
+        self.label_key = target_key
         self.remap_class = False
         if class_map:
             self.class_to_idx = load_class_map(class_map)
@@ -61,7 +65,7 @@ class ReaderHfds(Reader):

     def __getitem__(self, index):
         item = self.dataset[index]
-        image = item['image']
+        image = item[self.image_key]
         if 'bytes' in image and image['bytes']:
             image = io.BytesIO(image['bytes'])
         else:
@@ -77,4 +81,4 @@ class ReaderHfds(Reader):

     def _filename(self, index, basename=False, absolute=False):
         item = self.dataset[index]
-        return item['image']['path']
+        return item[self.image_key]['path']
timm/data/readers/reader_hfids.py (new file, 213 lines)
@@ -0,0 +1,213 @@
+""" Dataset reader for HF IterableDataset
+"""
+import math
+import os
+from itertools import repeat, chain
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from PIL import Image
+
+try:
+    import datasets
+    from datasets.distributed import split_dataset_by_node
+    from datasets.splits import SplitInfo
+except ImportError as e:
+    print("Please install Hugging Face datasets package `pip install datasets`.")
+    raise e
+
+from .class_map import load_class_map
+from .reader import Reader
+from .shared_count import SharedCount
+
+
+SHUFFLE_SIZE = int(os.environ.get('HFIDS_SHUFFLE_SIZE', 4096))
+
+
+class ReaderHfids(Reader):
+    def __init__(
+            self,
+            name: str,
+            root: Optional[str] = None,
+            split: str = 'train',
+            is_training: bool = False,
+            batch_size: int = 1,
+            download: bool = False,
+            repeats: int = 0,
+            seed: int = 42,
+            class_map: Optional[dict] = None,
+            input_key: str = 'image',
+            input_img_mode: str = 'RGB',
+            target_key: str = 'label',
+            target_img_mode: str = '',
+            shuffle_size: Optional[int] = None,
+            num_samples: Optional[int] = None,
+    ):
+        super().__init__()
+        self.root = root
+        self.split = split
+        self.is_training = is_training
+        self.batch_size = batch_size
+        self.download = download
+        self.repeats = repeats
+        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
+        self.shuffle_size = shuffle_size or SHUFFLE_SIZE
+
+        self.input_key = input_key
+        self.input_img_mode = input_img_mode
+        self.target_key = target_key
+        self.target_img_mode = target_img_mode
+
+        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
+        if download:
+            self.builder.download_and_prepare()
+
+        split_info: Optional[SplitInfo] = None
+        if self.builder.info.splits and split in self.builder.info.splits:
+            if isinstance(self.builder.info.splits[split], SplitInfo):
+                split_info: Optional[SplitInfo] = self.builder.info.splits[split]
+
+        if num_samples:
+            self.num_samples = num_samples
+        elif split_info and split_info.num_examples:
+            self.num_samples = split_info.num_examples
+        else:
+            raise ValueError(
+                "Dataset length is unknown, please pass `num_samples` explicitly. "
+                "The number of steps needs to be known in advance for the learning rate scheduler."
+            )
+
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = {}
+
+        # Distributed world state
+        self.dist_rank = 0
+        self.dist_num_replicas = 1
+        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+            self.dist_rank = dist.get_rank()
+            self.dist_num_replicas = dist.get_world_size()
+
+        # Attributes that are updated in _lazy_init
+        self.worker_info = None
+        self.worker_id = 0
+        self.num_workers = 1
+        self.global_worker_id = 0
+        self.global_num_workers = 1
+
+        # Initialized lazily on each dataloader worker process
+        self.ds: Optional[datasets.IterableDataset] = None
+        self.epoch = SharedCount()
+
+    def set_epoch(self, count):
+        # to update the shuffling effective_seed = seed + epoch
+        self.epoch.value = count
+
+    def set_loader_cfg(
+            self,
+            num_workers: Optional[int] = None,
+    ):
+        if self.ds is not None:
+            return
+        if num_workers is not None:
+            self.num_workers = num_workers
+            self.global_num_workers = self.dist_num_replicas * self.num_workers
+
+    def _lazy_init(self):
+        """ Lazily initialize worker (in worker processes)
+        """
+        if self.worker_info is None:
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                self.worker_info = worker_info
+                self.worker_id = worker_info.id
+                self.num_workers = worker_info.num_workers
+            self.global_num_workers = self.dist_num_replicas * self.num_workers
+            self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
+
+        if self.download:
+            dataset = self.builder.as_dataset(split=self.split)
+            # to distribute evenly to workers
+            ds = dataset.to_iterable_dataset(num_shards=self.global_num_workers)
+        else:
+            # in this case the number of shards is determined by the number of remote files
+            ds = self.builder.as_streaming_dataset(split=self.split)
+
+        if self.is_training:
+            # will shuffle the list of shards and use a shuffle buffer
+            ds = ds.shuffle(seed=self.common_seed, buffer_size=self.shuffle_size)
+
+        # Distributed:
+        # The dataset has a number of shards that is a factor of `dist_num_replicas` (i.e. if `ds.n_shards % dist_num_replicas == 0`),
+        # so the shards are evenly assigned across the nodes.
+        # If it's not the case for dataset streaming, each node keeps 1 example out of `dist_num_replicas`, skipping the other examples.
+
+        # Workers:
+        # In a node, datasets.IterableDataset assigns the shards assigned to the node as evenly as possible to workers.
+        self.ds = split_dataset_by_node(ds, rank=self.dist_rank, world_size=self.dist_num_replicas)
+
+    def _num_samples_per_worker(self):
+        num_worker_samples = \
+            max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
+        if self.is_training or self.dist_num_replicas > 1:
+            num_worker_samples = math.ceil(num_worker_samples)
+        if self.is_training and self.batch_size is not None:
+            num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
+        return int(num_worker_samples)
+
+    def __iter__(self):
+        if self.ds is None:
+            self._lazy_init()
+        self.ds.set_epoch(self.epoch.value)
+
+        target_sample_count = self._num_samples_per_worker()
+        sample_count = 0
+        ds_iter = iter(self.ds)
+        if self.is_training:
+            ds_iter = chain.from_iterable(repeat(ds_iter))
+        for sample in ds_iter:
+            input_data: Image.Image = sample[self.input_key]
+            if self.input_img_mode and input_data.mode != self.input_img_mode:
+                input_data = input_data.convert(self.input_img_mode)
+            target_data = sample[self.target_key]
+            if self.target_img_mode:
+                assert isinstance(target_data, Image.Image), "target_img_mode is specified but target is not an image"
+                if target_data.mode != self.target_img_mode:
+                    target_data = target_data.convert(self.target_img_mode)
+            elif self.remap_class:
+                target_data = self.class_to_idx[target_data]
+            yield input_data, target_data
+            sample_count += 1
+            if self.is_training and sample_count >= target_sample_count:
+                break
+
+    def __len__(self):
+        num_samples = self._num_samples_per_worker() * self.num_workers
+        return num_samples
+
+    def _filename(self, index, basename=False, absolute=False):
+        assert False, "Not supported"  # no random access to examples
+
+    def filenames(self, basename=False, absolute=False):
+        """ Return all filenames in dataset, overrides base"""
+        if self.ds is None:
+            self._lazy_init()
+        names = []
+        for sample in self.ds:
+            if 'file_name' in sample:
+                name = sample['file_name']
+            elif 'filename' in sample:
+                name = sample['filename']
+            elif 'id' in sample:
+                name = sample['id']
+            elif 'image_id' in sample:
+                name = sample['image_id']
+            else:
+                assert False, "No supported name field present"
+            names.append(name)
+        return names
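The new `ReaderHfids` wraps a streaming `datasets.IterableDataset`, splits shards across distributed ranks and dataloader workers, and pads/repeats the sample stream so each worker yields a batch-aligned count. It is reached via the `hfids/` prefix; a hedged sketch (the hub dataset id and cache path are placeholders):

```python
from timm.data import create_dataset

# streams the dataset from the HF Hub without materializing it on disk;
# pass download=True to download_and_prepare locally instead
ds = create_dataset(
    'hfids/imagenet-1k',  # placeholder HF hub dataset id
    root='./hf_cache',
    split='train',
    is_training=True,
    batch_size=256,
)
```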
timm/data/readers/reader_image_folder.py
@@ -61,14 +61,23 @@ class ReaderImageFolder(Reader):
     def __init__(
             self,
             root,
-            class_map=''):
+            class_map='',
+            input_key=None,
+    ):
         super().__init__()

         self.root = root
         class_to_idx = None
         if class_map:
             class_to_idx = load_class_map(class_map, root)
-        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
+        find_types = None
+        if input_key:
+            find_types = input_key.split(';')
+        self.samples, self.class_to_idx = find_images_and_targets(
+            root,
+            class_to_idx=class_to_idx,
+            types=find_types,
+        )
         if len(self.samples) == 0:
             raise RuntimeError(
                 f'Found 0 images in subfolders of {root}. '
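`ReaderImageFolder`'s new `input_key` restricts which file extensions are indexed, as a `;`-separated list. A small sketch (the folder path and extension choices are assumptions):

```python
from timm.data.readers.reader_image_folder import ReaderImageFolder

# only index .png and .webp files under ./data instead of the default image types
reader = ReaderImageFolder('./data', input_key='.png;.webp')
```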
timm/data/readers/reader_tfds.py
@@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
 import os
 import sys
+from typing import Optional

 import torch
@@ -32,7 +33,7 @@ try:
 except ImportError as e:
     print(e)
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
-    exit(1)
+    raise e

 from .class_map import load_class_map
 from .reader import Reader
@@ -45,10 +46,10 @@ PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048))  # samples to prefetch


 @tfds.decode.make_decoder()
-def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE'):
+def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
     return tf.image.decode_jpeg(
         serialized_image,
-        channels=3,
+        channels=channels,
         dct_method=dct_method,
     )
@@ -92,18 +93,18 @@ class ReaderTfds(Reader):

     def __init__(
             self,
-            root,
             name,
+            root=None,
             split='train',
             class_map=None,
             is_training=False,
-            batch_size=None,
+            batch_size=1,
             download=False,
             repeats=0,
             seed=42,
-            input_name='image',
+            input_key='image',
             input_img_mode='RGB',
-            target_name='label',
+            target_key='label',
             target_img_mode='',
             prefetch_size=None,
             shuffle_size=None,
@@ -120,9 +121,9 @@ class ReaderTfds(Reader):
             download: download and build TFDS dataset if set, otherwise must use tfds CLI
             repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
             seed: common seed for shard shuffle across all distributed/worker instances
-            input_name: name of Feature to return as data (input)
+            input_key: name of Feature to return as data (input)
             input_img_mode: image mode if input is an image (currently PIL mode string)
-            target_name: name of Feature to return as target (label)
+            target_key: name of Feature to return as target (label)
             target_img_mode: image mode if target is an image (currently PIL mode string)
             prefetch_size: override default tf.data prefetch buffer size
             shuffle_size: override default tf.data shuffle buffer size
@@ -132,9 +133,6 @@ class ReaderTfds(Reader):
         self.root = root
         self.split = split
         self.is_training = is_training
-        if self.is_training:
-            assert batch_size is not None, \
-                "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
         self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
@@ -145,10 +143,10 @@ class ReaderTfds(Reader):
         self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE

         # TFDS builder and split information
-        self.input_name = input_name  # FIXME support tuples / lists of inputs and targets and full range of Feature
+        self.input_key = input_key  # FIXME support tuples / lists of inputs and targets and full range of Feature
         self.input_img_mode = input_img_mode
-        self.target_name = target_name
-        self.target_img_mode = target_img_mode
+        self.target_key = target_key
+        self.target_img_mode = target_img_mode  # for dense pixel targets
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
         if download:
@@ -158,7 +156,7 @@ class ReaderTfds(Reader):
             self.class_to_idx = load_class_map(class_map)
             self.remap_class = True
         else:
-            self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
+            self.class_to_idx = get_class_labels(self.builder.info) if self.target_key == 'label' else {}
         self.split_info = self.builder.info.splits[split]
         self.num_samples = self.split_info.num_examples

@@ -258,7 +256,7 @@ class ReaderTfds(Reader):
         ds = self.builder.as_dataset(
             split=self.subsplit or self.split,
             shuffle_files=self.is_training,
-            decoders=dict(image=decode_example()),
+            decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
             read_config=read_config,
         )
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
@@ -282,7 +280,7 @@ class ReaderTfds(Reader):
             max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
         if self.is_training or self.dist_num_replicas > 1:
             num_worker_samples = math.ceil(num_worker_samples)
-        if self.is_training and self.batch_size is not None:
+        if self.is_training:
             num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
         return int(num_worker_samples)

@@ -300,11 +298,14 @@ class ReaderTfds(Reader):
         # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
         sample_count = 0
         for sample in self.ds:
-            input_data = sample[self.input_name]
+            input_data = sample[self.input_key]
             if self.input_img_mode:
+                if self.input_img_mode == 'L' and input_data.ndim == 3:
+                    input_data = input_data[:, :, 0]
                 input_data = Image.fromarray(input_data, mode=self.input_img_mode)
-            target_data = sample[self.target_name]
+            target_data = sample[self.target_key]
             if self.target_img_mode:
+                # dense pixel target
                 target_data = Image.fromarray(target_data, mode=self.target_img_mode)
             elif self.remap_class:
                 target_data = self.class_to_idx[target_data]
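In the TFDS reader, `decode_example` now takes a `channels` argument so `input_img_mode='L'` decodes JPEGs single-channel, and the iterator squeezes a trailing channel dim before `Image.fromarray`. A hedged sketch using a real TFDS catalog name (the root path is a placeholder):

```python
from timm.data import create_dataset

# MNIST via tensorflow-datasets, decoded directly as PIL mode 'L'
ds = create_dataset(
    'tfds/mnist',
    root='./tfds_data',  # placeholder data dir
    split='train',
    is_training=True,
    batch_size=256,
    download=True,
    input_img_mode='L',
)
```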
timm/data/readers/reader_wds.py
@@ -22,7 +22,7 @@ from torch.utils.data import Dataset, IterableDataset, get_worker_info

 try:
     import webdataset as wds
-    from webdataset.filters import _shuffle
+    from webdataset.filters import _shuffle, getfirst
     from webdataset.shardlists import expand_urls
     from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
 except ImportError:
@@ -35,27 +35,30 @@ from .shared_count import SharedCount

 _logger = logging.getLogger(__name__)

-SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
+SAMPLE_SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
+SAMPLE_INITIAL_SIZE = int(os.environ.get('WDS_INITIAL_SIZE', 2048))


-def _load_info(root, basename='info'):
-    info_json = os.path.join(root, basename + '.json')
-    info_yaml = os.path.join(root, basename + '.yaml')
-    err_str = ''
-    try:
-        with wds.gopen(info_json) as f:
-            info_dict = json.load(f)
-            return info_dict
-    except Exception as e:
-        err_str = str(e)
-    try:
-        with wds.gopen(info_yaml) as f:
-            info_dict = yaml.safe_load(f)
-            return info_dict
-    except Exception:
-        pass
+def _load_info(root, names=('_info.json', 'info.json')):
+    if isinstance(names, str):
+        names = (names,)
+    tried = []
+    err_str = ''
+    for n in names:
+        full_path = os.path.join(root, n)
+        try:
+            tried.append(full_path)
+            with wds.gopen(full_path) as f:
+                if n.endswith('.json'):
+                    info_dict = json.load(f)
+                else:
+                    info_dict = yaml.safe_load(f)
+                return info_dict
+        except Exception as e:
+            err_str = str(e)

     _logger.warning(
-        f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
+        f'Dataset info file not found at {tried}. Error: {err_str}. '
         'Falling back to provided split and size arg.')
     return {}

@@ -121,15 +124,18 @@ def _parse_split_info(split: str, info: Dict):


 def log_and_continue(exn):
-    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+    """Call in an exception handler to ignore exceptions, issue a warning, and continue."""
     _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    # NOTE: try force an exit on errors that are clearly code / config and not transient
+    if isinstance(exn, TypeError):
+        raise exn
     return True


 def _decode(
         sample,
         image_key='jpg',
-        image_format='RGB',
+        image_mode='RGB',
         target_key='cls',
         alt_label=''
 ):
@@ -150,47 +156,18 @@ def _decode(
         class_label = int(sample[target_key])

     # decode image
-    with io.BytesIO(sample[image_key]) as b:
+    img = getfirst(sample, image_key)
+    with io.BytesIO(img) as b:
         img = Image.open(b)
         img.load()
-    if image_format:
-        img = img.convert(image_format)
+    if image_mode:
+        img = img.convert(image_mode)

     # json passed through in undecoded state
     decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
     return decoded


-def _decode_samples(
-        data,
-        image_key='jpg',
-        image_format='RGB',
-        target_key='cls',
-        alt_label='',
-        handler=log_and_continue):
-    """Decode samples with skip."""
-    for sample in data:
-        try:
-            result = _decode(
-                sample,
-                image_key=image_key,
-                image_format=image_format,
-                target_key=target_key,
-                alt_label=alt_label
-            )
-        except Exception as exn:
-            if handler(exn):
-                continue
-            else:
-                break
-
-        # null results are skipped
-        if result is not None:
-            if isinstance(sample, dict) and isinstance(result, dict):
-                result["__key__"] = sample.get("__key__")
-            yield result
-
-
 def pytorch_worker_seed():
     """get dataloader worker seed from pytorch"""
     worker_info = get_worker_info()
@@ -203,6 +180,7 @@ def pytorch_worker_seed():


 if wds is not None:
+    # conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage)

     class detshuffle2(wds.PipelineStage):
         def __init__(
                 self,
@@ -284,20 +262,22 @@ class ResampledShards2(IterableDataset):
 class ReaderWds(Reader):
     def __init__(
             self,
-            root,
-            name,
-            split,
-            is_training=False,
-            batch_size=None,
-            repeats=0,
-            seed=42,
-            class_map=None,
-            input_name='jpg',
-            input_image='RGB',
-            target_name='cls',
-            target_image='',
-            prefetch_size=None,
-            shuffle_size=None,
+            root: str,
+            name: Optional[str] = None,
+            split: str = 'train',
+            is_training: bool = False,
+            num_samples: Optional[int] = None,
+            batch_size: int = 1,
+            repeats: int = 0,
+            seed: int = 42,
+            class_map: Optional[dict] = None,
+            input_key: str = 'jpg;png;webp',
+            input_img_mode: str = 'RGB',
+            target_key: str = 'cls',
+            target_img_mode: str = '',
+            filename_key: str = 'filename',
+            sample_shuffle_size: Optional[int] = None,
+            sample_initial_size: Optional[int] = None,
     ):
         super().__init__()
         if wds is None:
@@ -309,19 +289,23 @@ class ReaderWds(Reader):
         self.repeats = repeats
         self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
         self.shard_shuffle_size = 500
-        self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE
+        self.sample_shuffle_size = sample_shuffle_size or SAMPLE_SHUFFLE_SIZE
+        self.sample_initial_size = sample_initial_size or SAMPLE_INITIAL_SIZE

-        self.image_key = input_name
-        self.image_format = input_image
-        self.target_key = target_name
-        self.filename_key = 'filename'
+        self.input_key = input_key
+        self.input_img_mode = input_img_mode
+        self.target_key = target_key
+        self.filename_key = filename_key
         self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)

         self.info = _load_info(self.root)
         self.split_info = _parse_split_info(split, self.info)
-        self.num_samples = self.split_info.num_samples
-        if not self.num_samples:
-            raise RuntimeError(f'Invalid split definition, no samples found.')
+        if num_samples is not None:
+            self.num_samples = num_samples
+        else:
+            self.num_samples = self.split_info.num_samples
+        if not self.num_samples:
+            raise RuntimeError(f'Invalid split definition, num_samples not specified.')
         self.remap_class = False
         if class_map:
             self.class_to_idx = load_class_map(class_map)
@@ -346,7 +330,7 @@ class ReaderWds(Reader):
         self.init_count = 0
         self.epoch_count = SharedCount()

-        # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
+        # DataPipeline is lazy init, the majority of WDS DataPipeline could be init here, BUT, shuffle seed
         # is not handled in manner where it can be deterministic for each worker AND initialized up front
         self.ds = None

@@ -382,13 +366,19 @@ class ReaderWds(Reader):
         # at this point we have an iterator over all the shards
         if self.is_training:
             pipeline.extend([
-                detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count),
+                detshuffle2(
+                    self.shard_shuffle_size,
+                    seed=self.common_seed,
+                    epoch=self.epoch_count,
+                ),
                 self._split_by_node_and_worker,
                 # at this point, we have an iterator over the shards assigned to each worker
                 wds.tarfile_to_samples(handler=log_and_continue),
                 wds.shuffle(
-                    self.sample_shuffle_size,
-                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
+                    bufsize=self.sample_shuffle_size,
+                    initial=self.sample_initial_size,
+                    rng=random.Random(self.worker_seed)  # this is why we lazy-init whole DataPipeline
+                ),
             ])
         else:
             pipeline.extend([
@@ -397,12 +387,16 @@ class ReaderWds(Reader):
                 wds.tarfile_to_samples(handler=log_and_continue),
             ])
         pipeline.extend([
-            partial(
-                _decode_samples,
-                image_key=self.image_key,
-                image_format=self.image_format,
-                alt_label=self.split_info.alt_label
-            )
+            wds.map(
+                partial(
+                    _decode,
+                    image_key=self.input_key,
+                    image_mode=self.input_img_mode,
+                    alt_label=self.split_info.alt_label,
+                ),
+                handler=log_and_continue,
+            ),
+            wds.rename(image=self.input_key, target=self.target_key)
         ])
         self.ds = wds.DataPipeline(*pipeline)

@@ -418,7 +412,7 @@ class ReaderWds(Reader):
         num_worker_samples = self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
         if self.is_training or self.dist_num_replicas > 1:
             num_worker_samples = math.ceil(num_worker_samples)
-        if self.is_training and self.batch_size is not None:
+        if self.is_training:
             num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
         return int(num_worker_samples)

@@ -439,10 +433,10 @@ class ReaderWds(Reader):
         i = 0
         # _logger.info(f'start {i}, {self.worker_id}')  # FIXME temporary debug
         for sample in ds:
-            target = sample[self.target_key]
+            target = sample['target']
             if self.remap_class:
                 target = self.class_to_idx[target]
-            yield sample[self.image_key], target
+            yield sample['image'], target
             i += 1
         # _logger.info(f'end {i}, {self.worker_id}')  # FIXME temporary debug
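For WDS, `input_key` now accepts a `;`-separated list of extensions (default `'jpg;png;webp'`) resolved per-sample via `webdataset.filters.getfirst`, and `num_samples` can stand in for a missing `_info.json` / `info.json`. A hedged sketch with placeholder shard locations:

```python
from timm.data import create_dataset

ds = create_dataset(
    'wds/example',  # placeholder name
    root='https://example.com/data',  # placeholder; expects an info file or num_samples
    split='train',
    is_training=True,
    batch_size=256,
    num_samples=100_000,  # required here if no info file is found at root
)
```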
@ -2,7 +2,7 @@ import math
|
||||
import numbers
|
||||
import random
|
||||
import warnings
|
||||
from typing import List, Sequence
|
||||
from typing import List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
@ -14,6 +14,12 @@ except ImportError:
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
"ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
|
||||
"RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
|
||||
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder"
|
||||
]
|
||||
|
||||
|
||||
class ToNumpy:
|
||||
|
||||
@ -99,7 +105,7 @@ def interp_mode_to_str(mode):
|
||||
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
|
||||
|
||||
|
||||
def _setup_size(size, error_msg):
|
||||
def _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."):
|
||||
if isinstance(size, numbers.Number):
|
||||
return int(size), int(size)
|
||||
|
||||
@ -127,8 +133,13 @@ class RandomResizedCropAndInterpolation:
|
||||
interpolation: Default: PIL.Image.BILINEAR
|
||||
"""
|
||||
|
||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
||||
interpolation='bilinear'):
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3. / 4., 4. / 3.),
|
||||
interpolation='bilinear',
|
||||
):
|
||||
if isinstance(size, (list, tuple)):
|
||||
self.size = tuple(size)
|
||||
else:
|
||||
@ -156,35 +167,35 @@ class RandomResizedCropAndInterpolation:
|
||||
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
||||
sized crop.
|
||||
"""
|
||||
area = img.size[0] * img.size[1]
|
||||
img_w, img_h = F.get_image_size(img)
|
||||
area = img_w * img_h
|
||||
|
||||
for attempt in range(10):
|
||||
target_area = random.uniform(*scale) * area
|
||||
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
|
||||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if w <= img.size[0] and h <= img.size[1]:
|
||||
i = random.randint(0, img.size[1] - h)
|
||||
j = random.randint(0, img.size[0] - w)
|
||||
return i, j, h, w
|
||||
target_w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
target_h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if target_w <= img_w and target_h <= img_h:
|
||||
i = random.randint(0, img_h - target_h)
|
||||
j = random.randint(0, img_w - target_w)
|
||||
return i, j, target_h, target_w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = img.size[0] / img.size[1]
|
||||
in_ratio = img_w / img_h
|
||||
if in_ratio < min(ratio):
|
||||
w = img.size[0]
|
||||
h = int(round(w / min(ratio)))
|
||||
target_w = img_w
|
||||
target_h = int(round(target_w / min(ratio)))
|
||||
elif in_ratio > max(ratio):
|
||||
h = img.size[1]
|
||||
w = int(round(h * max(ratio)))
|
||||
target_h = img_h
|
||||
target_w = int(round(target_h * max(ratio)))
|
||||
else: # whole image
|
||||
w = img.size[0]
|
||||
h = img.size[1]
|
||||
i = (img.size[1] - h) // 2
|
||||
j = (img.size[0] - w) // 2
|
||||
return i, j, h, w
|
||||
target_w = img_w
|
||||
target_h = img_h
|
||||
i = (img_h - target_h) // 2
|
||||
j = (img_w - target_w) // 2
|
||||
return i, j, target_h, target_w
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
@ -213,8 +224,14 @@ class RandomResizedCropAndInterpolation:
|
||||
return format_string
|
||||
|
||||
|
||||
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
|
||||
def center_crop_or_pad(
|
||||
img: torch.Tensor,
|
||||
output_size: Union[int, List[int]],
|
||||
fill: Union[int, Tuple[int, int, int]] = 0,
|
||||
padding_mode: str = 'constant',
|
||||
) -> torch.Tensor:
|
||||
"""Center crops and/or pads the given image.
|
||||
|
||||
If the image is torch Tensor, it is expected
|
||||
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
|
||||
@ -228,13 +245,9 @@ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> tor
|
||||
Returns:
|
||||
PIL Image or Tensor: Cropped image.
|
||||
"""
|
||||
if isinstance(output_size, numbers.Number):
|
||||
output_size = (int(output_size), int(output_size))
|
||||
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
|
||||
output_size = (output_size[0], output_size[0])
|
||||
|
||||
_, image_height, image_width = F.get_dimensions(img)
|
||||
output_size = _setup_size(output_size)
|
||||
crop_height, crop_width = output_size
|
||||
_, image_height, image_width = F.get_dimensions(img)
|
||||
|
||||
if crop_width > image_width or crop_height > image_height:
|
||||
padding_ltrb = [
|
||||
@ -243,7 +256,7 @@ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> tor
|
||||
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
|
||||
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
|
||||
]
|
||||
img = F.pad(img, padding_ltrb, fill=fill)
|
||||
img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
|
||||
_, image_height, image_width = F.get_dimensions(img)
|
||||
if crop_width == image_width and crop_height == image_height:
|
||||
return img
|
||||
@ -265,10 +278,16 @@ class CenterCropOrPad(torch.nn.Module):
|
||||
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
|
||||
"""
|
||||
|
||||
def __init__(self, size, fill=0):
|
||||
def __init__(
|
||||
self,
|
||||
size: Union[int, List[int]],
|
||||
fill: Union[int, Tuple[int, int, int]] = 0,
|
||||
padding_mode: str = 'constant',
|
||||
):
|
||||
super().__init__()
|
||||
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
|
||||
self.size = _setup_size(size)
|
||||
self.fill = fill
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
def forward(self, img):
|
||||
"""
|
||||
@ -278,14 +297,111 @@ class CenterCropOrPad(torch.nn.Module):
|
||||
Returns:
|
||||
PIL Image or Tensor: Cropped image.
|
||||
"""
|
||||
return center_crop_or_pad(img, self.size, fill=self.fill)
|
||||
return center_crop_or_pad(img, self.size, fill=self.fill, padding_mode=self.padding_mode)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size})"
|
||||
|
||||
|
||||
def crop_or_pad(
|
||||
img: torch.Tensor,
|
||||
top: int,
|
||||
left: int,
|
||||
height: int,
|
||||
width: int,
|
||||
fill: Union[int, Tuple[int, int, int]] = 0,
|
||||
padding_mode: str = 'constant',
|
||||
) -> torch.Tensor:
|
||||
""" Crops and/or pads image to meet target size, with control over fill and padding_mode.
|
||||
"""
|
||||
_, image_height, image_width = F.get_dimensions(img)
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
if left < 0 or top < 0 or right > image_width or bottom > image_height:
|
||||
padding_ltrb = [
|
||||
max(-left + min(0, right), 0),
|
||||
max(-top + min(0, bottom), 0),
|
||||
max(right - max(image_width, left), 0),
|
||||
max(bottom - max(image_height, top), 0),
|
||||
]
|
||||
img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
|
||||
|
||||
top = max(top, 0)
|
||||
left = max(left, 0)
|
||||
return F.crop(img, top, left, height, width)
|
||||
|
||||
|
||||
class RandomCropOrPad(torch.nn.Module):
|
||||
""" Crop and/or pad image with random placement within the crop or pad margin.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: Union[int, List[int]],
|
||||
fill: Union[int, Tuple[int, int, int]] = 0,
|
||||
padding_mode: str = 'constant',
|
||||
):
|
||||
super().__init__()
|
||||
self.size = _setup_size(size)
|
||||
self.fill = fill
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
@staticmethod
|
||||
def get_params(img, size):
|
||||
_, image_height, image_width = F.get_dimensions(img)
|
||||
delta_height = image_height - size[0]
|
||||
delta_width = image_width - size[1]
|
||||
top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))
|
||||
left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))
|
||||
return top, left
|
||||
|
||||
def forward(self, img):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image or Tensor): Image to be cropped.
|
||||
|
||||
Returns:
|
||||
PIL Image or Tensor: Cropped image.
|
||||
"""
|
||||
top, left = self.get_params(img, self.size)
|
||||
return crop_or_pad(
|
||||
img,
|
||||
top=top,
|
||||
left=left,
|
||||
height=self.size[0],
|
||||
width=self.size[1],
|
||||
fill=self.fill,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size})"
|
||||
|
||||
|
||||
class RandomPad:
|
||||
def __init__(self, input_size, fill=0):
|
||||
self.input_size = input_size
|
||||
self.fill = fill
|
||||
|
||||
@staticmethod
|
||||
def get_params(img, input_size):
|
||||
width, height = F.get_image_size(img)
|
||||
delta_width = max(input_size[1] - width, 0)
|
||||
delta_height = max(input_size[0] - height, 0)
|
||||
pad_left = random.randint(0, delta_width)
|
||||
pad_top = random.randint(0, delta_height)
|
||||
pad_right = delta_width - pad_left
|
||||
pad_bottom = delta_height - pad_top
|
||||
return pad_left, pad_top, pad_right, pad_bottom
|
||||
|
||||
def __call__(self, img):
|
||||
padding = self.get_params(img, self.input_size)
|
||||
img = F.pad(img, padding, self.fill)
|
||||
return img
|
||||
|
||||
|
||||
class ResizeKeepRatio:
|
||||
""" Resize and Keep Ratio
|
||||
""" Resize and Keep Aspect Ratio
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -293,33 +409,77 @@ class ResizeKeepRatio:
|
||||
size,
|
||||
longest=0.,
|
||||
interpolation='bilinear',
|
||||
fill=0,
|
||||
random_scale_prob=0.,
|
||||
random_scale_range=(0.85, 1.05),
|
||||
random_scale_area=False,
|
||||
random_aspect_prob=0.,
|
||||
random_aspect_range=(0.9, 1.11),
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
size:
|
||||
longest:
|
||||
interpolation:
|
||||
random_scale_prob:
|
||||
random_scale_range:
|
||||
random_scale_area:
|
||||
random_aspect_prob:
|
||||
random_aspect_range:
|
||||
"""
|
||||
if isinstance(size, (list, tuple)):
|
||||
self.size = tuple(size)
|
||||
else:
|
||||
self.size = (size, size)
|
||||
if interpolation == 'random':
|
||||
self.interpolation = _RANDOM_INTERPOLATION
|
||||
else:
|
||||
self.interpolation = str_to_interp_mode(interpolation)
|
||||
self.longest = float(longest)
|
||||
self.fill = fill
|
||||
self.random_scale_prob = random_scale_prob
|
||||
self.random_scale_range = random_scale_range
|
||||
self.random_scale_area = random_scale_area
|
||||
self.random_aspect_prob = random_aspect_prob
|
||||
self.random_aspect_range = random_aspect_range
|
||||
|
||||
@staticmethod
|
||||
def get_params(img, target_size, longest):
|
||||
def get_params(
|
||||
img,
|
||||
target_size,
|
||||
longest,
|
||||
random_scale_prob=0.,
|
||||
random_scale_range=(1.0, 1.33),
|
||||
random_scale_area=False,
|
||||
random_aspect_prob=0.,
|
||||
random_aspect_range=(0.9, 1.11)
|
||||
):
|
||||
"""Get parameters
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be cropped.
|
||||
target_size (Tuple[int, int]): Size of output
|
||||
Returns:
|
||||
tuple: params (h, w) and (l, r, t, b) to be passed to ``resize`` and ``pad`` respectively
|
||||
"""
|
||||
source_size = img.size[::-1] # h, w
|
||||
h, w = source_size
|
||||
img_h, img_w = img_size = F.get_dimensions(img)[1:]
|
||||
target_h, target_w = target_size
|
||||
ratio_h = h / target_h
|
||||
ratio_w = w / target_w
|
||||
ratio_h = img_h / target_h
|
||||
ratio_w = img_w / target_w
|
||||
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
|
||||
size = [round(x / ratio) for x in source_size]
|
||||
|
||||
if random_scale_prob > 0 and random.random() < random_scale_prob:
|
||||
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
|
||||
if random_scale_area:
|
||||
# make ratio factor equivalent to RRC area crop where < 1.0 = area zoom,
|
||||
# otherwise like affine scale where < 1.0 = linear zoom out
|
||||
ratio_factor = 1. / math.sqrt(ratio_factor)
|
||||
ratio_factor = (ratio_factor, ratio_factor)
|
||||
else:
|
||||
ratio_factor = (1., 1.)
|
||||
|
||||
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
|
||||
log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
|
||||
aspect_factor = math.exp(random.uniform(*log_aspect))
|
||||
aspect_factor = math.sqrt(aspect_factor)
|
||||
# currently applying random aspect adjustment equally to both dims,
|
||||
# could change to keep output sizes above their target where possible
|
||||
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
|
||||
|
||||
size = [round(x * f / ratio) for x, f in zip(img_size, ratio_factor)]
|
||||
return size
|
||||
|
||||
def __call__(self, img):
|
||||
@ -330,13 +490,49 @@ class ResizeKeepRatio:
|
||||
Returns:
|
||||
PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
|
||||
"""
|
||||
size = self.get_params(img, self.size, self.longest)
|
||||
img = F.resize(img, size, self.interpolation)
|
||||
size = self.get_params(
|
||||
img, self.size, self.longest,
|
||||
self.random_scale_prob, self.random_scale_range, self.random_scale_area,
|
||||
self.random_aspect_prob, self.random_aspect_range
|
||||
)
|
||||
if isinstance(self.interpolation, (tuple, list)):
|
||||
interpolation = random.choice(self.interpolation)
|
||||
else:
|
||||
interpolation = self.interpolation
|
||||
img = F.resize(img, size, interpolation)
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
if isinstance(self.interpolation, (tuple, list)):
|
||||
interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
|
||||
else:
|
||||
interpolate_str = interp_mode_to_str(self.interpolation)
|
||||
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
|
||||
format_string += f', interpolation={interpolate_str})'
|
||||
format_string += f', longest={self.longest:.3f})'
|
||||
format_string += f', interpolation={interpolate_str}'
|
||||
format_string += f', longest={self.longest:.3f}'
|
||||
format_string += f', random_scale_prob={self.random_scale_prob:.3f}'
|
||||
format_string += f', random_scale_range=(' \
|
||||
f'{self.random_scale_range[0]:.3f}, {self.random_aspect_range[1]:.3f})'
|
||||
format_string += f', random_aspect_prob={self.random_aspect_prob:.3f}'
|
||||
format_string += f', random_aspect_range=(' \
|
||||
f'{self.random_aspect_range[0]:.3f}, {self.random_aspect_range[1]:.3f}))'
|
||||
return format_string


class TrimBorder(torch.nn.Module):

    def __init__(
            self,
            border_size: int,
    ):
        super().__init__()
        self.border_size = border_size

    def forward(self, img):
        w, h = F.get_image_size(img)
        top = left = self.border_size
        top = min(top, h)
        left = min(left, w)
        height = max(0, h - 2 * self.border_size)
        width = max(0, w - 2 * self.border_size)
        return F.crop(img, top, left, height, width)
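
What `TrimBorder` amounts to, shown with plain torchvision functional ops (an equivalent sketch, not the module itself):

from PIL import Image
import torchvision.transforms.functional as F

img = Image.new('RGB', (640, 480))
border = 32
w, h = F.get_image_size(img)
out = F.crop(
    img,
    top=min(border, h),           # clamp for tiny images
    left=min(border, w),
    height=max(0, h - 2 * border),
    width=max(0, w - 2 * border),
)
print(out.size)  # (576, 416): 32 px trimmed from every edge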

timm/data/transforms_factory.py
@ -4,6 +4,7 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
Hacked together by / Copyright 2019, Ross Wightman
"""
import math
from typing import Optional, Tuple, Union

import torch
from torchvision import transforms
@ -11,17 +12,29 @@ from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
    ResizeKeepRatio, CenterCropOrPad, ToNumpy
    ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy
from timm.data.random_erasing import RandomErasing


def transforms_noaug_train(
        img_size=224,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        img_size: Union[int, Tuple[int, int]] = 224,
        interpolation: str = 'bilinear',
        use_prefetcher: bool = False,
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
):
    """ No-augmentation image transforms for training.

    Args:
        img_size: Target image size.
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.

    Returns:
        Composed transform pipeline
    """
    if interpolation == 'random':
        # random interpolation not supported with no-aug
        interpolation = 'bilinear'
@ -37,41 +50,97 @@ def transforms_noaug_train(
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(mean),
            std=torch.tensor(std))
            std=torch.tensor(std)
        )
    ]
    return transforms.Compose(tfl)
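
A quick usage sketch (output shape assumes a 3-channel input and the default tensor/normalize path, i.e. `use_prefetcher=False`):

from PIL import Image
from timm.data.transforms_factory import transforms_noaug_train

tf = transforms_noaug_train(img_size=224, interpolation='bilinear')
x = tf(Image.new('RGB', (500, 375)))
print(x.shape)  # torch.Size([3, 224, 224])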


def transforms_imagenet_train(
        img_size=224,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        separate=False,
        force_color_jitter=False,
        img_size: Union[int, Tuple[int, int]] = 224,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        train_crop_mode: Optional[str] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: Union[float, Tuple[float, ...]] = 0.4,
        color_jitter_prob: Optional[float] = None,
        force_color_jitter: bool = False,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        interpolation: str = 'random',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_num_splits: int = 0,
        use_prefetcher: bool = False,
        separate: bool = False,
):
    """
    """ ImageNet-oriented image transforms for training.

    Args:
        img_size: Target image size.
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        force_color_jitter: Force color jitter where it is normally disabled (i.e. with RandAugment on).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_num_splits: Control split of random erasing across batch size.
        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
        separate: Output transforms in 3-stage tuple.

    Returns:
        If separate==True, the transforms are returned as a tuple of 3 separate transforms
        for use in a mixing dataset that passes
         * all data through the first (primary) transform, called the 'clean' data
         * a portion of the data through the secondary transform
         * normalizes and converts the branches above with the third, final transform
    """
    train_crop_mode = train_crop_mode or 'rrc'
    if train_crop_mode in ('rkrc', 'rkrr'):
        # FIXME integration of RKR is a WIP
        scale = tuple(scale or (0.8, 1.00))
        ratio = tuple(ratio or (0.9, 1/.9))
        primary_tfl = [
            ResizeKeepRatio(
                img_size,
                interpolation=interpolation,
                random_scale_prob=0.5,
                random_scale_range=scale,
                random_scale_area=True,  # scale compatible with RRC
                random_aspect_prob=0.5,
                random_aspect_range=ratio,
            ),
            CenterCropOrPad(img_size, padding_mode='reflect')
            if train_crop_mode == 'rkrc' else
            RandomCropOrPad(img_size, padding_mode='reflect')
        ]
    else:
        scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
        ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
        primary_tfl = [
            RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
            RandomResizedCropAndInterpolation(
                img_size,
                scale=scale,
                ratio=ratio,
                interpolation=interpolation,
            )
        ]
    if hflip > 0.:
        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
    if vflip > 0.:
@ -111,8 +180,29 @@ def transforms_imagenet_train(
    else:
        # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
        color_jitter = (float(color_jitter),) * 3
    if color_jitter_prob is not None:
        secondary_tfl += [
            transforms.RandomApply([
                transforms.ColorJitter(*color_jitter),
                ],
                p=color_jitter_prob
            )
        ]
    else:
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    if grayscale_prob:
        secondary_tfl += [transforms.RandomGrayscale(p=grayscale_prob)]

    if gaussian_blur_prob:
        secondary_tfl += [
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=23),  # hardcoded for now
                ],
                p=gaussian_blur_prob,
            )
        ]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
@ -122,11 +212,19 @@ def transforms_imagenet_train(
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(mean),
            std=torch.tensor(std))
            std=torch.tensor(std)
        ),
    ]
    if re_prob > 0.:
        final_tfl.append(
            RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
        final_tfl += [
            RandomErasing(
                re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device='cpu',
            )
        ]

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
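
A sketch of consuming the `separate=True` three-stage tuple described in the docstring above; the wrapper function is illustrative, not part of timm:

from timm.data.transforms_factory import transforms_imagenet_train

primary_tf, secondary_tf, final_tf = transforms_imagenet_train(
    img_size=224,
    separate=True,
)

def clean_and_aug(img):
    # 'clean' branch: primary (crop/flip) only, then tensor/normalize
    # 'aug' branch: primary -> secondary (jitter/blur/etc.) -> tensor/normalize
    base = primary_tf(img)
    return final_tf(base), final_tf(secondary_tf(base))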
@ -135,14 +233,30 @@ def transforms_imagenet_train(


def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        crop_mode=None,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD
        img_size: Union[int, Tuple[int, int]] = 224,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        use_prefetcher: bool = False,
):
    """ ImageNet-oriented image transform for evaluation and inference.

    Args:
        img_size: Target image size.
        crop_pct: Crop percentage. Defaults to 0.875 when None.
        crop_mode: Crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Trim a border of specified # pixels around edge of original image.
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.

    Returns:
        Composed transform pipeline
    """
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    if isinstance(img_size, (tuple, list)):
@ -152,10 +266,15 @@ def transforms_imagenet_eval(
        scale_size = math.floor(img_size / crop_pct)
        scale_size = (scale_size, scale_size)

    tfl = []

    if crop_border_pixels:
        tfl += [TrimBorder(crop_border_pixels)]

    if crop_mode == 'squash':
        # squash mode scales each edge to 1/pct of target, then crops
        # aspect ratio is not preserved, no img lost if crop_pct == 1.0
        tfl = [
        tfl += [
            transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
            transforms.CenterCrop(img_size),
        ]
@ -163,7 +282,7 @@ def transforms_imagenet_eval(
        # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
        # no image lost if crop_pct == 1.0
        fill = [round(255 * v) for v in mean]
        tfl = [
        tfl += [
            ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
            CenterCropOrPad(img_size, fill=fill),
        ]
@ -172,12 +291,12 @@ def transforms_imagenet_eval(
        # aspect ratio is preserved, crops center within image, no borders are added, image is lost
        if scale_size[0] == scale_size[1]:
            # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
            tfl = [
            tfl += [
                transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
            ]
        else:
            # resize shortest edge to matching target dim for non-square target
            tfl = [ResizeKeepRatio(scale_size)]
            # resize the shortest edge to matching target dim for non-square target
            tfl += [ResizeKeepRatio(scale_size)]
    tfl += [transforms.CenterCrop(img_size)]
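
All three crop modes start from the same `scale_size` arithmetic; a quick worked example with the defaults:

import math

img_size, crop_pct = 224, 0.875
scale_size = math.floor(img_size / crop_pct)
print(scale_size)  # 256
# 'center' (default): Resize(256) on the shortest edge, then CenterCrop(224)
# 'squash':           Resize((256, 256)) ignoring aspect, then CenterCrop(224)
# 'border':           ResizeKeepRatio(256, longest=1.0) + mean-fill pad, then crop/pad to 224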

    if use_prefetcher:
@ -196,28 +315,65 @@ def transforms_imagenet_eval(


def create_transform(
        input_size,
        is_training=False,
        use_prefetcher=False,
        no_aug=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        crop_pct=None,
        crop_mode=None,
        tf_preprocessing=False,
        separate=False):
        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224,
        is_training: bool = False,
        no_aug: bool = False,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: Union[float, Tuple[float, ...]] = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_num_splits: int = 0,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        tf_preprocessing: bool = False,
        use_prefetcher: bool = False,
        separate: bool = False,
):
    """ Build an image transform pipeline for training or inference.

    Args:
        input_size: Target input size (channels, height, width) tuple or size scalar.
        is_training: Return training (random) transforms.
        no_aug: Disable augmentation for training (useful for debug).
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_num_splits: Control split of random erasing across batch size.
        crop_pct: Inference crop percentage (output size / resize size).
        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
        use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
        separate: Output transforms in 3-stage tuple.

    Returns:
        Composed transforms or tuple thereof
    """
    if isinstance(input_size, (tuple, list)):
        img_size = input_size[-2:]
    else:
@ -227,7 +383,10 @@ def create_transform(
        assert not separate, "Separate transforms not supported for TF preprocessing"
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(
            is_training=is_training, size=img_size, interpolation=interpolation)
            is_training=is_training,
            size=img_size,
            interpolation=interpolation,
        )
    else:
        if is_training and no_aug:
            assert not separate, "Cannot perform split augmentation with no_aug"
@ -246,6 +405,9 @@ def create_transform(
            hflip=hflip,
            vflip=vflip,
            color_jitter=color_jitter,
            color_jitter_prob=color_jitter_prob,
            grayscale_prob=grayscale_prob,
            gaussian_blur_prob=gaussian_blur_prob,
            auto_augment=auto_augment,
            interpolation=interpolation,
            use_prefetcher=use_prefetcher,
@ -267,6 +429,7 @@ def create_transform(
            std=std,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
        )

    return transform
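
Putting the factory together, a hedged end-to-end sketch (argument values are illustrative; the SimCLR-style probabilities and `crop_border_pixels` are the options added in this commit):

from timm.data.transforms_factory import create_transform

train_tf = create_transform(
    input_size=(3, 224, 224),
    is_training=True,
    color_jitter=0.4,
    color_jitter_prob=0.8,     # apply jitter 80% of the time (SimCLR-style)
    grayscale_prob=0.2,
    gaussian_blur_prob=0.5,
    re_prob=0.25,
)
eval_tf = create_transform(
    input_size=(3, 224, 224),
    crop_pct=0.95,
    crop_mode='squash',
    crop_border_pixels=4,      # trim a 4 px border before resizing
)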

35 train.py
@ -93,10 +93,20 @@ group.add_argument('--train-split', metavar='NAME', default='train',
                   help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
                   help='dataset validation split (default: validation)')
parser.add_argument('--train-num-samples', default=None, type=int,
                    metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
parser.add_argument('--val-num-samples', default=None, type=int,
                    metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
group.add_argument('--dataset-download', action='store_true', default=False,
                   help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                   help='path to class to idx mapping file (default: "")')
group.add_argument('--input-img-mode', default=None, type=str,
                   help='Dataset image conversion mode for input images.')
group.add_argument('--input-key', default=None, type=str,
                   help='Dataset key for input images.')
group.add_argument('--target-key', default=None, type=str,
                   help='Dataset key for target labels.')

# Model parameters
group = parser.add_argument_group('Model parameters')
@ -245,6 +255,12 @@ group.add_argument('--vflip', type=float, default=0.,
                   help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                   help='Color jitter factor (default: 0.4)')
group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying any color jitter.')
group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying random grayscale conversion.')
group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying gaussian blur.')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
                   help='Use AutoAugment policy. "v0" or "original". (default: None)')
group.add_argument('--aug-repeats', type=float, default=0,
@ -594,6 +610,10 @@ def main():
    # create the train and eval datasets
    if args.data and not args.data_dir:
        args.data_dir = args.data
    if args.input_img_mode is None:
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
    else:
        input_img_mode = args.input_img_mode
    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir,
@ -604,6 +624,10 @@ def main():
        batch_size=args.batch_size,
        seed=args.seed,
        repeats=args.epoch_repeats,
        input_img_mode=input_img_mode,
        input_key=args.input_key,
        target_key=args.target_key,
        num_samples=args.train_num_samples,
    )

    dataset_eval = create_dataset(
@ -614,6 +638,10 @@ def main():
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        input_img_mode=input_img_mode,
        input_key=args.input_key,
        target_key=args.target_key,
        num_samples=args.val_num_samples,
    )

    # setup mixup / cutmix
@ -650,7 +678,6 @@ def main():
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
@ -661,6 +688,9 @@ def main():
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        color_jitter_prob=args.color_jitter_prob,
        grayscale_prob=args.grayscale_prob,
        gaussian_blur_prob=args.gaussian_blur_prob,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
@ -672,6 +702,7 @@ def main():
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        device=device,
        use_prefetcher=args.prefetcher,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )
@ -685,7 +716,6 @@ def main():
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
@ -694,6 +724,7 @@ def main():
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
        device=device,
        use_prefetcher=args.prefetcher,
    )

    # setup loss function

28 validate.py
@ -61,10 +61,23 @@ parser.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--num-samples', default=None, type=int,
                    metavar='N', help='Manually specify num samples in dataset split, for IterableDatasets.')
parser.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--input-key', default=None, type=str,
                    help='Dataset key for input images.')
parser.add_argument('--input-img-mode', default=None, type=str,
                    help='Dataset image conversion mode for input images.')
parser.add_argument('--target-key', default=None, type=str,
                    help='Dataset key for target labels.')

parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
@ -81,6 +94,8 @@ parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
                    metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--crop-border-pixels', type=int, default=None,
                    help='Crop pixels from image border.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@ -89,16 +104,12 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
@ -249,6 +260,10 @@ def validate(args):
    criterion = nn.CrossEntropyLoss().to(device)

    root_dir = args.data or args.data_dir
    if args.input_img_mode is None:
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
    else:
        input_img_mode = args.input_img_mode
    dataset = create_dataset(
        root=root_dir,
        name=args.dataset,
@ -256,6 +271,10 @@ def validate(args):
        download=args.dataset_download,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map,
        num_samples=args.num_samples,
        input_key=args.input_key,
        input_img_mode=input_img_mode,
        target_key=args.target_key,
    )

    if args.valid_labels:
@ -281,6 +300,7 @@ def validate(args):
        num_workers=args.workers,
        crop_pct=crop_pct,
        crop_mode=data_config['crop_mode'],
        crop_border_pixels=args.crop_border_pixels,
        pin_memory=args.pin_mem,
        device=device,
        tf_preprocessing=args.tf_preprocessing,