Mirror of https://github.com/huggingface/pytorch-image-models.git

commit be0944edae (parent b5a4fa9c3b)

    Significant transforms, dataset, dataloading enhancements.
timm/data/dataset.py

@@ -27,7 +27,7 @@ class ImageDataset(data.Dataset):
             split='train',
             class_map=None,
             load_bytes=False,
-            img_mode='RGB',
+            input_img_mode='RGB',
             transform=None,
             target_transform=None,
     ):
@@ -40,7 +40,7 @@ class ImageDataset(data.Dataset):
         )
         self.reader = reader
         self.load_bytes = load_bytes
-        self.img_mode = img_mode
+        self.input_img_mode = input_img_mode
         self.transform = transform
         self.target_transform = target_transform
         self._consecutive_errors = 0
@@ -59,8 +59,8 @@ class ImageDataset(data.Dataset):
             raise e
         self._consecutive_errors = 0

-        if self.img_mode and not self.load_bytes:
-            img = img.convert(self.img_mode)
+        if self.input_img_mode and not self.load_bytes:
+            img = img.convert(self.input_img_mode)
         if self.transform is not None:
             img = self.transform(img)

@@ -90,12 +90,17 @@ class IterableImageDataset(data.IterableDataset):
             split='train',
             class_map=None,
             is_training=False,
-            batch_size=None,
+            batch_size=1,
+            num_samples=None,
             seed=42,
             repeats=0,
             download=False,
+            input_img_mode='RGB',
+            input_key=None,
+            target_key=None,
             transform=None,
             target_transform=None,
+            max_steps=None,
     ):
         assert reader is not None
         if isinstance(reader, str):
@@ -106,9 +111,14 @@ class IterableImageDataset(data.IterableDataset):
                 class_map=class_map,
                 is_training=is_training,
                 batch_size=batch_size,
+                num_samples=num_samples,
                 seed=seed,
                 repeats=repeats,
                 download=download,
+                input_img_mode=input_img_mode,
+                input_key=input_key,
+                target_key=target_key,
+                max_steps=max_steps,
             )
         else:
             self.reader = reader
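
The img_mode -> input_img_mode rename is a breaking keyword change for code constructing ImageDataset directly. A minimal usage sketch of the updated argument, assuming a local folder dataset (path hypothetical):

    from timm.data import ImageDataset

    # input_img_mode replaces the old img_mode kwarg; 'L' yields single-channel PIL images
    ds = ImageDataset('/data/imagenet/val', input_img_mode='L')
    img, target = ds[0]
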
timm/data/dataset_factory.py

@@ -3,6 +3,7 @@
 Hacked together by / Copyright 2021, Ross Wightman
 """
 import os
+from typing import Optional

 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
 try:
@@ -60,22 +61,24 @@ def _search_split(root, split):


 def create_dataset(
-        name,
-        root,
-        split='validation',
-        search_split=True,
-        class_map=None,
-        load_bytes=False,
-        is_training=False,
-        download=False,
-        batch_size=None,
-        seed=42,
-        repeats=0,
-        **kwargs
+        name: str,
+        root: Optional[str] = None,
+        split: str = 'validation',
+        search_split: bool = True,
+        class_map: dict = None,
+        load_bytes: bool = False,
+        is_training: bool = False,
+        download: bool = False,
+        batch_size: int = 1,
+        num_samples: Optional[int] = None,
+        seed: int = 42,
+        repeats: int = 0,
+        input_img_mode: str = 'RGB',
+        **kwargs,
 ):
     """ Dataset factory method

-    In parenthesis after each arg are the type of dataset supported for each arg, one of:
+    In parentheses after each arg are the type of dataset supported for each arg, one of:
       * folder - default, timm folder (or tar) based ImageDataset
       * torch - torchvision based datasets
       * HFDS - Hugging Face Datasets
@@ -97,11 +100,13 @@ def create_dataset(
         batch_size: batch size hint for (TFDS, WDS)
         seed: seed for iterable datasets (TFDS, WDS)
         repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
+        input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
        **kwargs: other args to pass to dataset

     Returns:
         Dataset object
     """
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
     name = name.lower()
     if name.startswith('torch/'):
         name = name.split('/', 2)[-1]
@@ -151,7 +156,29 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
+        ds = ImageDataset(
+            root,
+            reader=name,
+            split=split,
+            class_map=class_map,
+            input_img_mode=input_img_mode,
+            **kwargs,
+        )
+    elif name.startswith('hfids/'):
+        ds = IterableImageDataset(
+            root,
+            reader=name,
+            split=split,
+            class_map=class_map,
+            is_training=is_training,
+            download=download,
+            batch_size=batch_size,
+            num_samples=num_samples,
+            repeats=repeats,
+            seed=seed,
+            input_img_mode=input_img_mode,
+            **kwargs
+        )
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
@@ -161,8 +188,10 @@ def create_dataset(
             is_training=is_training,
             download=download,
             batch_size=batch_size,
+            num_samples=num_samples,
             repeats=repeats,
             seed=seed,
+            input_img_mode=input_img_mode,
             **kwargs
         )
     elif name.startswith('wds/'):
@@ -173,8 +202,10 @@ def create_dataset(
             class_map=class_map,
             is_training=is_training,
             batch_size=batch_size,
+            num_samples=num_samples,
             repeats=repeats,
             seed=seed,
+            input_img_mode=input_img_mode,
             **kwargs
         )
     else:
@@ -182,5 +213,12 @@ def create_dataset(
         if search_split and os.path.isdir(root):
             # look for split specific sub-folder in root
             root = _search_split(root, split)
-        ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
+        ds = ImageDataset(
+            root,
+            reader=name,
+            class_map=class_map,
+            load_bytes=load_bytes,
+            input_img_mode=input_img_mode,
+            **kwargs,
+        )
     return ds
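
With the factory changes above, a Hugging Face dataset can be consumed in streaming form via the new hfids/ prefix, with num_samples supplied when the split metadata lacks a size. A hedged sketch (dataset id and sample count illustrative):

    from timm.data import create_dataset

    ds = create_dataset(
        'hfids/imagenet-1k',     # hypothetical hub dataset id
        root=None,
        split='train',
        is_training=True,
        batch_size=256,
        num_samples=1281167,     # required if the split size isn't in the dataset metadata
        input_img_mode='RGB',
    )
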
timm/data/loader.py

@@ -10,14 +10,14 @@ import random
 from contextlib import suppress
 from functools import partial
 from itertools import repeat
-from typing import Callable
+from typing import Callable, Optional, Tuple, Union

 import torch
 import torch.utils.data
 import numpy as np

 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .dataset import IterableImageDataset
+from .dataset import IterableImageDataset, ImageDataset
 from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
 from .random_erasing import RandomErasing
 from .mixup import FastCollateMixup
@@ -187,41 +187,91 @@ def _worker_init(worker_id, worker_seeding='all'):


 def create_loader(
-        dataset,
-        input_size,
-        batch_size,
-        is_training=False,
-        use_prefetcher=True,
-        no_aug=False,
-        re_prob=0.,
-        re_mode='const',
-        re_count=1,
-        re_split=False,
-        scale=None,
-        ratio=None,
-        hflip=0.5,
-        vflip=0.,
-        color_jitter=0.4,
-        auto_augment=None,
-        num_aug_repeats=0,
-        num_aug_splits=0,
-        interpolation='bilinear',
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD,
-        num_workers=1,
-        distributed=False,
-        crop_pct=None,
-        crop_mode=None,
-        collate_fn=None,
-        pin_memory=False,
-        fp16=False,  # deprecated, use img_dtype
-        img_dtype=torch.float32,
-        device=torch.device('cuda'),
-        tf_preprocessing=False,
-        use_multi_epochs_loader=False,
-        persistent_workers=True,
-        worker_seeding='all',
+        dataset: Union[ImageDataset, IterableImageDataset],
+        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
+        batch_size: int,
+        is_training: bool = False,
+        no_aug: bool = False,
+        re_prob: float = 0.,
+        re_mode: str = 'const',
+        re_count: int = 1,
+        re_split: bool = False,
+        scale: Optional[Tuple[float, float]] = None,
+        ratio: Optional[Tuple[float, float]] = None,
+        hflip: float = 0.5,
+        vflip: float = 0.,
+        color_jitter: float = 0.4,
+        color_jitter_prob: Optional[float] = None,
+        grayscale_prob: float = 0.,
+        gaussian_blur_prob: float = 0.,
+        auto_augment: Optional[str] = None,
+        num_aug_repeats: int = 0,
+        num_aug_splits: int = 0,
+        interpolation: str = 'bilinear',
+        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+        num_workers: int = 1,
+        distributed: bool = False,
+        crop_pct: Optional[float] = None,
+        crop_mode: Optional[str] = None,
+        crop_border_pixels: Optional[int] = None,
+        collate_fn: Optional[Callable] = None,
+        pin_memory: bool = False,
+        fp16: bool = False,  # deprecated, use img_dtype
+        img_dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device('cuda'),
+        use_prefetcher: bool = True,
+        use_multi_epochs_loader: bool = False,
+        persistent_workers: bool = True,
+        worker_seeding: str = 'all',
+        tf_preprocessing: bool = False,
 ):
+    """
+
+    Args:
+        dataset: The image dataset to load.
+        input_size: Target input size (channels, height, width) tuple or size scalar.
+        batch_size: Number of samples in a batch.
+        is_training: Return training (random) transforms.
+        no_aug: Disable augmentation for training (useful for debug).
+        re_prob: Random erasing probability.
+        re_mode: Random erasing fill mode.
+        re_count: Number of random erasing regions.
+        re_split: Control split of random erasing across batch size.
+        scale: Random resize scale range (crop area, < 1.0 => zoom in).
+        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
+        hflip: Horizontal flip probability.
+        vflip: Vertical flip probability.
+        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
+            Scalar is applied as (scalar,) * 3 (no hue).
+        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
+        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
+        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
+        auto_augment: Auto augment configuration string (see auto_augment.py).
+        num_aug_repeats: Enable special sampler to repeat same augmentation across distributed GPUs.
+        num_aug_splits: Enable mode where augmentations can be split across the batch.
+        interpolation: Image interpolation mode.
+        mean: Image normalization mean.
+        std: Image normalization standard deviation.
+        num_workers: Num worker processes per DataLoader.
+        distributed: Enable dataloading for distributed training.
+        crop_pct: Inference crop percentage (output size / resize size).
+        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
+        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
+        collate_fn: Override default collate_fn.
+        pin_memory: Pin memory for device transfer.
+        fp16: Deprecated argument for half-precision input dtype. Use img_dtype.
+        img_dtype: Data type for input image.
+        device: Device to transfer inputs and targets to.
+        use_prefetcher: Use efficient pre-fetcher to load samples onto device.
+        use_multi_epochs_loader: Use a multi-epochs DataLoader that keeps worker processes alive across epochs.
+        persistent_workers: Enable persistent worker processes.
+        worker_seeding: Control worker random seeding at init.
+        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
+
+    Returns:
+        DataLoader
+    """
     re_num_splits = 0
     if re_split:
         # apply RE to second half of batch if no aug split otherwise line up with aug split
@@ -229,24 +279,28 @@ def create_loader(
     dataset.transform = create_transform(
         input_size,
         is_training=is_training,
-        use_prefetcher=use_prefetcher,
         no_aug=no_aug,
         scale=scale,
         ratio=ratio,
         hflip=hflip,
         vflip=vflip,
         color_jitter=color_jitter,
+        color_jitter_prob=color_jitter_prob,
+        grayscale_prob=grayscale_prob,
+        gaussian_blur_prob=gaussian_blur_prob,
         auto_augment=auto_augment,
         interpolation=interpolation,
         mean=mean,
         std=std,
         crop_pct=crop_pct,
         crop_mode=crop_mode,
-        tf_preprocessing=tf_preprocessing,
+        crop_border_pixels=crop_border_pixels,
         re_prob=re_prob,
         re_mode=re_mode,
         re_count=re_count,
         re_num_splits=re_num_splits,
+        tf_preprocessing=tf_preprocessing,
+        use_prefetcher=use_prefetcher,
         separate=num_aug_splits > 0,
     )
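
The new color_jitter_prob / grayscale_prob / gaussian_blur_prob arguments pass straight through to create_transform, enabling SimCLR-style stochastic augmentation from the loader factory. A sketch, assuming a torchvision CIFAR-10 dataset (paths illustrative):

    from timm.data import create_dataset, create_loader

    ds = create_dataset('torch/cifar10', root='/data', split='train', is_training=True, download=True)
    loader = create_loader(
        ds,
        input_size=(3, 32, 32),
        batch_size=256,
        is_training=True,
        color_jitter=0.4,
        color_jitter_prob=0.8,    # apply jitter with probability 0.8 (SimCLR-style)
        grayscale_prob=0.2,
        gaussian_blur_prob=0.5,
        num_workers=4,
    )
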
timm/data/readers/reader_factory.py

@@ -1,10 +1,17 @@
 import os
+from typing import Optional

 from .reader_image_folder import ReaderImageFolder
 from .reader_image_in_tar import ReaderImageInTar


-def create_reader(name, root, split='train', **kwargs):
+def create_reader(
+        name: str,
+        root: Optional[str] = None,
+        split: str = 'train',
+        **kwargs,
+):
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
     name = name.lower()
     name = name.split('/', 1)
     prefix = ''
@@ -15,15 +22,18 @@ def create_reader(name, root, split='train', **kwargs):
     # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
     # explicitly select other options shortly
     if prefix == 'hfds':
-        from .reader_hfds import ReaderHfds  # defer tensorflow import
-        reader = ReaderHfds(root, name, split=split, **kwargs)
+        from .reader_hfds import ReaderHfds  # defer HF datasets import
+        reader = ReaderHfds(name=name, root=root, split=split, **kwargs)
+    elif prefix == 'hfids':
+        from .reader_hfids import ReaderHfids  # defer HF datasets import
+        reader = ReaderHfids(name=name, root=root, split=split, **kwargs)
     elif prefix == 'tfds':
         from .reader_tfds import ReaderTfds  # defer tensorflow import
-        reader = ReaderTfds(root, name, split=split, **kwargs)
+        reader = ReaderTfds(name=name, root=root, split=split, **kwargs)
     elif prefix == 'wds':
         from .reader_wds import ReaderWds
         kwargs.pop('download', False)
-        reader = ReaderWds(root, name, split=split, **kwargs)
+        reader = ReaderWds(root=root, name=name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
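
Readers are now constructed with keyword arguments (name=, root=), so the old positional (root, name) order no longer applies. A sketch of the factory entry point (dataset id illustrative):

    from timm.data.readers.reader_factory import create_reader

    # the prefix before '/' selects the reader implementation
    reader = create_reader('hfds/beans', root='/cache/hf', split='train')
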
timm/data/readers/reader_hfds.py

@@ -4,6 +4,8 @@ Hacked together by / Copyright 2022 Ross Wightman
 """
 import io
 import math
+from typing import Optional
+
 import torch
 import torch.distributed as dist
 from PIL import Image
@@ -12,7 +14,7 @@ try:
     import datasets
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
-    exit(1)
+    raise e
 from .class_map import load_class_map
 from .reader import Reader

@@ -29,12 +31,13 @@ class ReaderHfds(Reader):

     def __init__(
             self,
-            root,
-            name,
-            split='train',
-            class_map=None,
-            label_key='label',
-            download=False,
+            name: str,
+            root: Optional[str] = None,
+            split: str = 'train',
+            class_map: dict = None,
+            image_key: str = 'image',
+            target_key: str = 'label',
+            download: bool = False,
     ):
         """
         """
@@ -47,9 +50,10 @@ class ReaderHfds(Reader):
             cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
         )
         # leave decode for caller, plus we want easy access to original path names...
-        self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
+        self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False))

-        self.label_key = label_key
+        self.image_key = image_key
+        self.label_key = target_key
         self.remap_class = False
         if class_map:
             self.class_to_idx = load_class_map(class_map)
@@ -61,7 +65,7 @@ class ReaderHfds(Reader):

     def __getitem__(self, index):
         item = self.dataset[index]
-        image = item['image']
+        image = item[self.image_key]
         if 'bytes' in image and image['bytes']:
             image = io.BytesIO(image['bytes'])
         else:
@@ -77,4 +81,4 @@ class ReaderHfds(Reader):

     def _filename(self, index, basename=False, absolute=False):
         item = self.dataset[index]
-        return item['image']['path']
+        return item[self.image_key]['path']
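
The configurable image_key / target_key let HF datasets whose columns aren't named 'image'/'label' be read without remapping. A hedged sketch (dataset id and column names illustrative):

    from timm.data.readers.reader_hfds import ReaderHfds

    reader = ReaderHfds(name='cats_vs_dogs', split='train', image_key='image', target_key='labels')
    img_file, label = reader[0]   # image is left undecoded for the caller, per the cast_column above
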
timm/data/readers/reader_hfids.py (new file, +213)

@@ -0,0 +1,213 @@
+""" Dataset reader for HF IterableDataset
+"""
+import math
+import os
+from itertools import repeat, chain
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from PIL import Image
+
+try:
+    import datasets
+    from datasets.distributed import split_dataset_by_node
+    from datasets.splits import SplitInfo
+except ImportError as e:
+    print("Please install Hugging Face datasets package `pip install datasets`.")
+    raise e
+
+
+from .class_map import load_class_map
+from .reader import Reader
+from .shared_count import SharedCount
+
+
+SHUFFLE_SIZE = int(os.environ.get('HFIDS_SHUFFLE_SIZE', 4096))
+
+
+class ReaderHfids(Reader):
+    def __init__(
+            self,
+            name: str,
+            root: Optional[str] = None,
+            split: str = 'train',
+            is_training: bool = False,
+            batch_size: int = 1,
+            download: bool = False,
+            repeats: int = 0,
+            seed: int = 42,
+            class_map: Optional[dict] = None,
+            input_key: str = 'image',
+            input_img_mode: str = 'RGB',
+            target_key: str = 'label',
+            target_img_mode: str = '',
+            shuffle_size: Optional[int] = None,
+            num_samples: Optional[int] = None,
+    ):
+        super().__init__()
+        self.root = root
+        self.split = split
+        self.is_training = is_training
+        self.batch_size = batch_size
+        self.download = download
+        self.repeats = repeats
+        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
+        self.shuffle_size = shuffle_size or SHUFFLE_SIZE
+
+        self.input_key = input_key
+        self.input_img_mode = input_img_mode
+        self.target_key = target_key
+        self.target_img_mode = target_img_mode
+
+        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
+        if download:
+            self.builder.download_and_prepare()
+
+        split_info: Optional[SplitInfo] = None
+        if self.builder.info.splits and split in self.builder.info.splits:
+            if isinstance(self.builder.info.splits[split], SplitInfo):
+                split_info: Optional[SplitInfo] = self.builder.info.splits[split]
+
+        if num_samples:
+            self.num_samples = num_samples
+        elif split_info and split_info.num_examples:
+            self.num_samples = split_info.num_examples
+        else:
+            raise ValueError(
+                "Dataset length is unknown, please pass `num_samples` explicitly. "
+                "The number of steps needs to be known in advance for the learning rate scheduler."
+            )
+
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = {}
+
+        # Distributed world state
+        self.dist_rank = 0
+        self.dist_num_replicas = 1
+        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+            self.dist_rank = dist.get_rank()
+            self.dist_num_replicas = dist.get_world_size()
+
+        # Attributes that are updated in _lazy_init
+        self.worker_info = None
+        self.worker_id = 0
+        self.num_workers = 1
+        self.global_worker_id = 0
+        self.global_num_workers = 1
+
+        # Initialized lazily on each dataloader worker process
+        self.ds: Optional[datasets.IterableDataset] = None
+        self.epoch = SharedCount()
+
+    def set_epoch(self, count):
+        # to update the shuffling effective_seed = seed + epoch
+        self.epoch.value = count
+
+    def set_loader_cfg(
+            self,
+            num_workers: Optional[int] = None,
+    ):
+        if self.ds is not None:
+            return
+        if num_workers is not None:
+            self.num_workers = num_workers
+            self.global_num_workers = self.dist_num_replicas * self.num_workers
+
+    def _lazy_init(self):
+        """ Lazily initialize worker (in worker processes)
+        """
+        if self.worker_info is None:
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                self.worker_info = worker_info
+                self.worker_id = worker_info.id
+                self.num_workers = worker_info.num_workers
+            self.global_num_workers = self.dist_num_replicas * self.num_workers
+            self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
+
+        if self.download:
+            dataset = self.builder.as_dataset(split=self.split)
+            # to distribute evenly to workers
+            ds = dataset.to_iterable_dataset(num_shards=self.global_num_workers)
+        else:
+            # in this case the number of shards is determined by the number of remote files
+            ds = self.builder.as_streaming_dataset(split=self.split)
+
+        if self.is_training:
+            # will shuffle the list of shards and use a shuffle buffer
+            ds = ds.shuffle(seed=self.common_seed, buffer_size=self.shuffle_size)
+
+        # Distributed:
+        # The dataset has a number of shards that is a factor of `dist_num_replicas` (i.e. if `ds.n_shards % dist_num_replicas == 0`),
+        # so the shards are evenly assigned across the nodes.
+        # If it's not the case for dataset streaming, each node keeps 1 example out of `dist_num_replicas`, skipping the other examples.
+
+        # Workers:
+        # In a node, datasets.IterableDataset assigns the shards assigned to the node as evenly as possible to workers.
+        self.ds = split_dataset_by_node(ds, rank=self.dist_rank, world_size=self.dist_num_replicas)
+
+    def _num_samples_per_worker(self):
+        num_worker_samples = \
+            max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
+        if self.is_training or self.dist_num_replicas > 1:
+            num_worker_samples = math.ceil(num_worker_samples)
+        if self.is_training and self.batch_size is not None:
+            num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
+        return int(num_worker_samples)
+
+    def __iter__(self):
+        if self.ds is None:
+            self._lazy_init()
+        self.ds.set_epoch(self.epoch.value)
+
+        target_sample_count = self._num_samples_per_worker()
+        sample_count = 0
+        ds_iter = iter(self.ds)
+        if self.is_training:
+            ds_iter = chain.from_iterable(repeat(ds_iter))
+        for sample in ds_iter:
+            input_data: Image.Image = sample[self.input_key]
+            if self.input_img_mode and input_data.mode != self.input_img_mode:
+                input_data = input_data.convert(self.input_img_mode)
+            target_data = sample[self.target_key]
+            if self.target_img_mode:
+                assert isinstance(target_data, Image.Image), "target_img_mode is specified but target is not an image"
+                if target_data.mode != self.target_img_mode:
+                    target_data = target_data.convert(self.target_img_mode)
+            elif self.remap_class:
+                target_data = self.class_to_idx[target_data]
+            yield input_data, target_data
+            sample_count += 1
+            if self.is_training and sample_count >= target_sample_count:
+                break
+
+    def __len__(self):
+        num_samples = self._num_samples_per_worker() * self.num_workers
+        return num_samples
+
+    def _filename(self, index, basename=False, absolute=False):
+        assert False, "Not supported"  # no random access to examples
+
+    def filenames(self, basename=False, absolute=False):
+        """ Return all filenames in dataset, overrides base"""
+        if self.ds is None:
+            self._lazy_init()
+        names = []
+        for sample in self.ds:
+            if 'file_name' in sample:
+                name = sample['file_name']
+            elif 'filename' in sample:
+                name = sample['filename']
+            elif 'id' in sample:
+                name = sample['id']
+            elif 'image_id' in sample:
+                name = sample['image_id']
+            else:
+                assert False, "No supported name field present"
+            names.append(name)
+        return names
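
Because ReaderHfids is backed by an IterableDataset, shuffling varies per epoch only if the wrapping loader calls set_epoch; per the comment in the code, the effective shuffle seed is seed + epoch. A minimal sketch, assuming the datasets package is installed (dataset id and sample count illustrative):

    from timm.data.readers.reader_hfids import ReaderHfids

    reader = ReaderHfids(name='beans', split='train', is_training=True, batch_size=32, num_samples=1034)
    for epoch in range(2):
        reader.set_epoch(epoch)           # re-seeds shard / buffer shuffling
        for img, label in reader:
            break                         # normally consumed via a DataLoader
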
timm/data/readers/reader_image_folder.py

@@ -61,14 +61,23 @@ class ReaderImageFolder(Reader):
     def __init__(
             self,
             root,
-            class_map=''):
+            class_map='',
+            input_key=None,
+    ):
         super().__init__()

         self.root = root
         class_to_idx = None
         if class_map:
             class_to_idx = load_class_map(class_map, root)
-        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
+        find_types = None
+        if input_key:
+            find_types = input_key.split(';')
+        self.samples, self.class_to_idx = find_images_and_targets(
+            root,
+            class_to_idx=class_to_idx,
+            types=find_types,
+        )
         if len(self.samples) == 0:
             raise RuntimeError(
                 f'Found 0 images in subfolders of {root}. '
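
input_key is split on ';' and forwarded to find_images_and_targets as its types filter, restricting the folder scan to the listed image types. A sketch (path and extension format illustrative, mirroring the 'jpg;png;webp' style key used by the WDS reader):

    from timm.data.readers.reader_image_folder import ReaderImageFolder

    reader = ReaderImageFolder('/data/train', input_key='jpg;png')
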
timm/data/readers/reader_tfds.py

@@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
 import os
+import sys
 from typing import Optional

 import torch
@@ -32,7 +33,7 @@ try:
 except ImportError as e:
     print(e)
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
-    exit(1)
+    raise e

 from .class_map import load_class_map
 from .reader import Reader
@@ -45,10 +46,10 @@ PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048))  # samples to prefetch


 @tfds.decode.make_decoder()
-def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE'):
+def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
     return tf.image.decode_jpeg(
         serialized_image,
-        channels=3,
+        channels=channels,
         dct_method=dct_method,
     )

@@ -92,18 +93,18 @@ class ReaderTfds(Reader):

     def __init__(
             self,
-            root,
             name,
+            root=None,
             split='train',
             class_map=None,
             is_training=False,
-            batch_size=None,
+            batch_size=1,
             download=False,
             repeats=0,
             seed=42,
-            input_name='image',
+            input_key='image',
             input_img_mode='RGB',
-            target_name='label',
+            target_key='label',
             target_img_mode='',
             prefetch_size=None,
             shuffle_size=None,
@@ -120,9 +121,9 @@ class ReaderTfds(Reader):
             download: download and build TFDS dataset if set, otherwise must use tfds CLI
             repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
             seed: common seed for shard shuffle across all distributed/worker instances
-            input_name: name of Feature to return as data (input)
+            input_key: name of Feature to return as data (input)
             input_img_mode: image mode if input is an image (currently PIL mode string)
-            target_name: name of Feature to return as target (label)
+            target_key: name of Feature to return as target (label)
             target_img_mode: image mode if target is an image (currently PIL mode string)
             prefetch_size: override default tf.data prefetch buffer size
             shuffle_size: override default tf.data shuffle buffer size
@@ -132,9 +133,6 @@ class ReaderTfds(Reader):
         self.root = root
         self.split = split
         self.is_training = is_training
-        if self.is_training:
-            assert batch_size is not None, \
-                "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
         self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
@@ -145,10 +143,10 @@ class ReaderTfds(Reader):
         self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE

         # TFDS builder and split information
-        self.input_name = input_name  # FIXME support tuples / lists of inputs and targets and full range of Feature
+        self.input_key = input_key  # FIXME support tuples / lists of inputs and targets and full range of Feature
         self.input_img_mode = input_img_mode
-        self.target_name = target_name
-        self.target_img_mode = target_img_mode
+        self.target_key = target_key
+        self.target_img_mode = target_img_mode  # for dense pixel targets
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
         if download:
@@ -158,7 +156,7 @@ class ReaderTfds(Reader):
             self.class_to_idx = load_class_map(class_map)
             self.remap_class = True
         else:
-            self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
+            self.class_to_idx = get_class_labels(self.builder.info) if self.target_key == 'label' else {}
         self.split_info = self.builder.info.splits[split]
         self.num_samples = self.split_info.num_examples

@@ -258,7 +256,7 @@ class ReaderTfds(Reader):
         ds = self.builder.as_dataset(
             split=self.subsplit or self.split,
             shuffle_files=self.is_training,
-            decoders=dict(image=decode_example()),
+            decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
             read_config=read_config,
         )
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
@@ -282,7 +280,7 @@ class ReaderTfds(Reader):
             max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
         if self.is_training or self.dist_num_replicas > 1:
             num_worker_samples = math.ceil(num_worker_samples)
-        if self.is_training and self.batch_size is not None:
+        if self.is_training:
             num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
         return int(num_worker_samples)

@@ -300,11 +298,14 @@ class ReaderTfds(Reader):
         # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
         sample_count = 0
         for sample in self.ds:
-            input_data = sample[self.input_name]
+            input_data = sample[self.input_key]
             if self.input_img_mode:
+                if self.input_img_mode == 'L' and input_data.ndim == 3:
+                    input_data = input_data[:, :, 0]
                 input_data = Image.fromarray(input_data, mode=self.input_img_mode)
-            target_data = sample[self.target_name]
+            target_data = sample[self.target_key]
             if self.target_img_mode:
+                # dense pixel target
                 target_data = Image.fromarray(target_data, mode=self.target_img_mode)
             elif self.remap_class:
                 target_data = self.class_to_idx[target_data]
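
Tying the TFDS decoder's channels to input_img_mode means grayscale pipelines decode one channel at the TF level rather than converting after the fact. A sketch (requires tensorflow-datasets; dataset id illustrative):

    from timm.data.readers.reader_tfds import ReaderTfds

    # decode_example(channels=1) is selected internally because input_img_mode == 'L'
    reader = ReaderTfds(name='oxford_iiit_pet', root='/data/tfds', split='train', input_img_mode='L')
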
timm/data/readers/reader_wds.py

@@ -22,7 +22,7 @@ from torch.utils.data import Dataset, IterableDataset, get_worker_info

 try:
     import webdataset as wds
-    from webdataset.filters import _shuffle
+    from webdataset.filters import _shuffle, getfirst
     from webdataset.shardlists import expand_urls
     from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
 except ImportError:
@@ -35,27 +35,30 @@ from .shared_count import SharedCount

 _logger = logging.getLogger(__name__)

-SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
+SAMPLE_SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
+SAMPLE_INITIAL_SIZE = int(os.environ.get('WDS_INITIAL_SIZE', 2048))


-def _load_info(root, basename='info'):
-    info_json = os.path.join(root, basename + '.json')
-    info_yaml = os.path.join(root, basename + '.yaml')
+def _load_info(root, names=('_info.json', 'info.json')):
+    if isinstance(names, str):
+        names = (names,)
+    tried = []
     err_str = ''
-    try:
-        with wds.gopen(info_json) as f:
-            info_dict = json.load(f)
-        return info_dict
-    except Exception as e:
-        err_str = str(e)
-    try:
-        with wds.gopen(info_yaml) as f:
-            info_dict = yaml.safe_load(f)
-        return info_dict
-    except Exception:
-        pass
+    for n in names:
+        full_path = os.path.join(root, n)
+        try:
+            tried.append(full_path)
+            with wds.gopen(full_path) as f:
+                if n.endswith('.json'):
+                    info_dict = json.load(f)
+                else:
+                    info_dict = yaml.safe_load(f)
+            return info_dict
+        except Exception as e:
+            err_str = str(e)

     _logger.warning(
-        f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
+        f'Dataset info file not found at {tried}. Error: {err_str}. '
         'Falling back to provided split and size arg.')
     return {}

@@ -121,15 +124,18 @@ def _parse_split_info(split: str, info: Dict):


 def log_and_continue(exn):
-    """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
+    """Call in an exception handler to ignore exceptions, issue a warning, and continue."""
     _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    # NOTE: try force an exit on errors that are clearly code / config and not transient
+    if isinstance(exn, TypeError):
+        raise exn
     return True


 def _decode(
         sample,
         image_key='jpg',
-        image_format='RGB',
+        image_mode='RGB',
         target_key='cls',
         alt_label=''
 ):
@@ -150,47 +156,18 @@ def _decode(
         class_label = int(sample[target_key])

     # decode image
-    with io.BytesIO(sample[image_key]) as b:
+    img = getfirst(sample, image_key)
+    with io.BytesIO(img) as b:
         img = Image.open(b)
         img.load()
-        if image_format:
-            img = img.convert(image_format)
+        if image_mode:
+            img = img.convert(image_mode)

     # json passed through in undecoded state
     decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
     return decoded


-def _decode_samples(
-        data,
-        image_key='jpg',
-        image_format='RGB',
-        target_key='cls',
-        alt_label='',
-        handler=log_and_continue):
-    """Decode samples with skip."""
-    for sample in data:
-        try:
-            result = _decode(
-                sample,
-                image_key=image_key,
-                image_format=image_format,
-                target_key=target_key,
-                alt_label=alt_label
-            )
-        except Exception as exn:
-            if handler(exn):
-                continue
-            else:
-                break
-
-        # null results are skipped
-        if result is not None:
-            if isinstance(sample, dict) and isinstance(result, dict):
-                result["__key__"] = sample.get("__key__")
-            yield result
-
-
 def pytorch_worker_seed():
     """get dataloader worker seed from pytorch"""
     worker_info = get_worker_info()
@@ -203,6 +180,7 @@ def pytorch_worker_seed():

 if wds is not None:
     # conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage)
+
     class detshuffle2(wds.PipelineStage):
         def __init__(
                 self,
@@ -284,20 +262,22 @@ class ResampledShards2(IterableDataset):
 class ReaderWds(Reader):
     def __init__(
             self,
-            root,
-            name,
-            split,
-            is_training=False,
-            batch_size=None,
-            repeats=0,
-            seed=42,
-            class_map=None,
-            input_name='jpg',
-            input_image='RGB',
-            target_name='cls',
-            target_image='',
-            prefetch_size=None,
-            shuffle_size=None,
+            root: str,
+            name: Optional[str] = None,
+            split: str = 'train',
+            is_training: bool = False,
+            num_samples: Optional[int] = None,
+            batch_size: int = 1,
+            repeats: int = 0,
+            seed: int = 42,
+            class_map: Optional[dict] = None,
+            input_key: str = 'jpg;png;webp',
+            input_img_mode: str = 'RGB',
+            target_key: str = 'cls',
+            target_img_mode: str = '',
+            filename_key: str = 'filename',
+            sample_shuffle_size: Optional[int] = None,
+            sample_initial_size: Optional[int] = None,
     ):
         super().__init__()
         if wds is None:
@@ -309,19 +289,23 @@ class ReaderWds(Reader):
         self.repeats = repeats
         self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
         self.shard_shuffle_size = 500
-        self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE
+        self.sample_shuffle_size = sample_shuffle_size or SAMPLE_SHUFFLE_SIZE
+        self.sample_initial_size = sample_initial_size or SAMPLE_INITIAL_SIZE

-        self.image_key = input_name
-        self.image_format = input_image
-        self.target_key = target_name
-        self.filename_key = 'filename'
+        self.input_key = input_key
+        self.input_img_mode = input_img_mode
+        self.target_key = target_key
+        self.filename_key = filename_key
         self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)

         self.info = _load_info(self.root)
         self.split_info = _parse_split_info(split, self.info)
-        self.num_samples = self.split_info.num_samples
+        if num_samples is not None:
+            self.num_samples = num_samples
+        else:
+            self.num_samples = self.split_info.num_samples
         if not self.num_samples:
-            raise RuntimeError(f'Invalid split definition, no samples found.')
+            raise RuntimeError(f'Invalid split definition, num_samples not specified.')
         self.remap_class = False
         if class_map:
             self.class_to_idx = load_class_map(class_map)
@@ -346,7 +330,7 @@ class ReaderWds(Reader):
         self.init_count = 0
         self.epoch_count = SharedCount()

-        # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
+        # DataPipeline is lazy init, the majority of WDS DataPipeline could be init here, BUT, shuffle seed
         # is not handled in manner where it can be deterministic for each worker AND initialized up front
         self.ds = None

@@ -382,13 +366,19 @@ class ReaderWds(Reader):
         # at this point we have an iterator over all the shards
         if self.is_training:
             pipeline.extend([
-                detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count),
+                detshuffle2(
+                    self.shard_shuffle_size,
+                    seed=self.common_seed,
+                    epoch=self.epoch_count,
+                ),
                 self._split_by_node_and_worker,
                 # at this point, we have an iterator over the shards assigned to each worker
                 wds.tarfile_to_samples(handler=log_and_continue),
                 wds.shuffle(
-                    self.sample_shuffle_size,
-                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
+                    bufsize=self.sample_shuffle_size,
+                    initial=self.sample_initial_size,
+                    rng=random.Random(self.worker_seed)  # this is why we lazy-init whole DataPipeline
+                ),
             ])
         else:
             pipeline.extend([
@@ -397,12 +387,16 @@ class ReaderWds(Reader):
                 wds.tarfile_to_samples(handler=log_and_continue),
             ])
         pipeline.extend([
-            partial(
-                _decode_samples,
-                image_key=self.image_key,
-                image_format=self.image_format,
-                alt_label=self.split_info.alt_label
-            )
+            wds.map(
+                partial(
+                    _decode,
+                    image_key=self.input_key,
+                    image_mode=self.input_img_mode,
+                    alt_label=self.split_info.alt_label,
+                ),
+                handler=log_and_continue,
+            ),
+            wds.rename(image=self.input_key, target=self.target_key)
         ])
         self.ds = wds.DataPipeline(*pipeline)

@@ -418,7 +412,7 @@ class ReaderWds(Reader):
         num_worker_samples = self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
         if self.is_training or self.dist_num_replicas > 1:
             num_worker_samples = math.ceil(num_worker_samples)
-        if self.is_training and self.batch_size is not None:
+        if self.is_training:
             num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
         return int(num_worker_samples)

@@ -439,10 +433,10 @@ class ReaderWds(Reader):
         i = 0
         # _logger.info(f'start {i}, {self.worker_id}')  # FIXME temporary debug
         for sample in ds:
-            target = sample[self.target_key]
+            target = sample['target']
             if self.remap_class:
                 target = self.class_to_idx[target]
-            yield sample[self.image_key], target
+            yield sample['image'], target
             i += 1
         # _logger.info(f'end {i}, {self.worker_id}')  # FIXME temporary debug
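
The WDS shuffle buffers now honor two environment variables read at module import time, and num_samples can override a missing or wrong info file. A sketch (paths and counts illustrative):

    import os
    os.environ['WDS_SHUFFLE_SIZE'] = '4096'   # sample shuffle buffer size
    os.environ['WDS_INITIAL_SIZE'] = '1024'   # initial buffer fill before yielding

    from timm.data.readers.reader_wds import ReaderWds

    reader = ReaderWds(
        root='/data/wds/imagenet',
        split='train',
        is_training=True,
        batch_size=256,
        num_samples=1281167,    # overrides split_info when metadata is absent
    )
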
|
@ -2,7 +2,7 @@ import math
|
|||||||
import numbers
|
import numbers
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Sequence
|
from typing import List, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
@ -14,6 +14,12 @@ except ImportError:
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
|
||||||
|
"RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
|
||||||
|
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ToNumpy:
|
class ToNumpy:
|
||||||
|
|
||||||
@ -99,7 +105,7 @@ def interp_mode_to_str(mode):
|
|||||||
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
|
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
|
||||||
|
|
||||||
|
|
||||||
def _setup_size(size, error_msg):
|
def _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."):
|
||||||
if isinstance(size, numbers.Number):
|
if isinstance(size, numbers.Number):
|
||||||
return int(size), int(size)
|
return int(size), int(size)
|
||||||
|
|
||||||
@ -127,8 +133,13 @@ class RandomResizedCropAndInterpolation:
|
|||||||
interpolation: Default: PIL.Image.BILINEAR
|
interpolation: Default: PIL.Image.BILINEAR
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
|
def __init__(
|
||||||
interpolation='bilinear'):
|
self,
|
||||||
|
size,
|
||||||
|
scale=(0.08, 1.0),
|
||||||
|
ratio=(3. / 4., 4. / 3.),
|
||||||
|
interpolation='bilinear',
|
||||||
|
):
|
||||||
if isinstance(size, (list, tuple)):
|
if isinstance(size, (list, tuple)):
|
||||||
self.size = tuple(size)
|
self.size = tuple(size)
|
||||||
else:
|
else:
|
||||||
@ -156,35 +167,35 @@ class RandomResizedCropAndInterpolation:
             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                 sized crop.
         """
-        area = img.size[0] * img.size[1]
+        img_w, img_h = F.get_image_size(img)
+        area = img_w * img_h

         for attempt in range(10):
             target_area = random.uniform(*scale) * area
             log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
             aspect_ratio = math.exp(random.uniform(*log_ratio))

-            w = int(round(math.sqrt(target_area * aspect_ratio)))
-            h = int(round(math.sqrt(target_area / aspect_ratio)))
-            if w <= img.size[0] and h <= img.size[1]:
-                i = random.randint(0, img.size[1] - h)
-                j = random.randint(0, img.size[0] - w)
-                return i, j, h, w
+            target_w = int(round(math.sqrt(target_area * aspect_ratio)))
+            target_h = int(round(math.sqrt(target_area / aspect_ratio)))
+            if target_w <= img_w and target_h <= img_h:
+                i = random.randint(0, img_h - target_h)
+                j = random.randint(0, img_w - target_w)
+                return i, j, target_h, target_w

         # Fallback to central crop
-        in_ratio = img.size[0] / img.size[1]
+        in_ratio = img_w / img_h
         if in_ratio < min(ratio):
-            w = img.size[0]
-            h = int(round(w / min(ratio)))
+            target_w = img_w
+            target_h = int(round(target_w / min(ratio)))
         elif in_ratio > max(ratio):
-            h = img.size[1]
-            w = int(round(h * max(ratio)))
+            target_h = img_h
+            target_w = int(round(target_h * max(ratio)))
         else:  # whole image
-            w = img.size[0]
-            h = img.size[1]
-        i = (img.size[1] - h) // 2
-        j = (img.size[0] - w) // 2
-        return i, j, h, w
+            target_w = img_w
+            target_h = img_h
+        i = (img_h - target_h) // 2
+        j = (img_w - target_w) // 2
+        return i, j, target_h, target_w

     def __call__(self, img):
         """
@ -213,8 +224,14 @@ class RandomResizedCropAndInterpolation:
         return format_string


-def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
+def center_crop_or_pad(
+        img: torch.Tensor,
+        output_size: Union[int, List[int]],
+        fill: Union[int, Tuple[int, int, int]] = 0,
+        padding_mode: str = 'constant',
+) -> torch.Tensor:
     """Center crops and/or pads the given image.

     If the image is torch Tensor, it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
     If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
@ -228,13 +245,9 @@ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
     Returns:
         PIL Image or Tensor: Cropped image.
     """
-    if isinstance(output_size, numbers.Number):
-        output_size = (int(output_size), int(output_size))
-    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
-        output_size = (output_size[0], output_size[0])
-
-    _, image_height, image_width = F.get_dimensions(img)
+    output_size = _setup_size(output_size)
     crop_height, crop_width = output_size
+    _, image_height, image_width = F.get_dimensions(img)

     if crop_width > image_width or crop_height > image_height:
         padding_ltrb = [
@ -243,7 +256,7 @@ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
             (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
             (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
         ]
-        img = F.pad(img, padding_ltrb, fill=fill)
+        img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
         _, image_height, image_width = F.get_dimensions(img)
         if crop_width == image_width and crop_height == image_height:
             return img
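A minimal sketch of how the new padding_mode control plays out; shapes and values below are illustrative, not from the source:

    import torch
    from timm.data.transforms import center_crop_or_pad

    img = torch.rand(3, 200, 320)              # smaller than target on one edge
    out = center_crop_or_pad(img, (224, 224))  # zero-pads height, center crops width
    out_r = center_crop_or_pad(img, (224, 224), padding_mode='reflect')  # mirror-pad instead
    print(out.shape, out_r.shape)              # torch.Size([3, 224, 224]) twice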
@ -265,10 +278,16 @@ class CenterCropOrPad(torch.nn.Module):
         made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
     """

-    def __init__(self, size, fill=0):
+    def __init__(
+            self,
+            size: Union[int, List[int]],
+            fill: Union[int, Tuple[int, int, int]] = 0,
+            padding_mode: str = 'constant',
+    ):
         super().__init__()
-        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+        self.size = _setup_size(size)
         self.fill = fill
+        self.padding_mode = padding_mode

     def forward(self, img):
         """
@ -278,14 +297,111 @@ class CenterCropOrPad(torch.nn.Module):
         Returns:
             PIL Image or Tensor: Cropped image.
         """
-        return center_crop_or_pad(img, self.size, fill=self.fill)
+        return center_crop_or_pad(img, self.size, fill=self.fill, padding_mode=self.padding_mode)

     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(size={self.size})"

+
+def crop_or_pad(
+        img: torch.Tensor,
+        top: int,
+        left: int,
+        height: int,
+        width: int,
+        fill: Union[int, Tuple[int, int, int]] = 0,
+        padding_mode: str = 'constant',
+) -> torch.Tensor:
+    """ Crops and/or pads image to meet target size, with control over fill and padding_mode.
+    """
+    _, image_height, image_width = F.get_dimensions(img)
+    right = left + width
+    bottom = top + height
+    if left < 0 or top < 0 or right > image_width or bottom > image_height:
+        padding_ltrb = [
+            max(-left + min(0, right), 0),
+            max(-top + min(0, bottom), 0),
+            max(right - max(image_width, left), 0),
+            max(bottom - max(image_height, top), 0),
+        ]
+        img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
+
+    top = max(top, 0)
+    left = max(left, 0)
+    return F.crop(img, top, left, height, width)
+
+
+class RandomCropOrPad(torch.nn.Module):
+    """ Crop and/or pad image with random placement within the crop or pad margin.
+    """
+
+    def __init__(
+            self,
+            size: Union[int, List[int]],
+            fill: Union[int, Tuple[int, int, int]] = 0,
+            padding_mode: str = 'constant',
+    ):
+        super().__init__()
+        self.size = _setup_size(size)
+        self.fill = fill
+        self.padding_mode = padding_mode
+
+    @staticmethod
+    def get_params(img, size):
+        _, image_height, image_width = F.get_dimensions(img)
+        delta_height = image_height - size[0]
+        delta_width = image_width - size[1]
+        top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))
+        left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))
+        return top, left
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be cropped.
+
+        Returns:
+            PIL Image or Tensor: Cropped image.
+        """
+        top, left = self.get_params(img, self.size)
+        return crop_or_pad(
+            img,
+            top=top,
+            left=left,
+            height=self.size[0],
+            width=self.size[1],
+            fill=self.fill,
+            padding_mode=self.padding_mode,
+        )
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size})"
+
+
+class RandomPad:
+    def __init__(self, input_size, fill=0):
+        self.input_size = input_size
+        self.fill = fill
+
+    @staticmethod
+    def get_params(img, input_size):
+        width, height = F.get_image_size(img)
+        delta_width = max(input_size[1] - width, 0)
+        delta_height = max(input_size[0] - height, 0)
+        pad_left = random.randint(0, delta_width)
+        pad_top = random.randint(0, delta_height)
+        pad_right = delta_width - pad_left
+        pad_bottom = delta_height - pad_top
+        return pad_left, pad_top, pad_right, pad_bottom
+
+    def __call__(self, img):
+        padding = self.get_params(img, self.input_size)
+        img = F.pad(img, padding, self.fill)
+        return img
+
+
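A short usage sketch for the new RandomCropOrPad (sizes illustrative). Unlike a plain RandomCrop, it pads with random placement when the image is smaller than the target:

    import torch
    from timm.data.transforms import RandomCropOrPad

    t = RandomCropOrPad(size=(224, 224), padding_mode='reflect')
    img = torch.rand(3, 180, 300)   # H < 224 -> random pad, W > 224 -> random crop
    print(t(img).shape)             # torch.Size([3, 224, 224])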
 class ResizeKeepRatio:
-    """ Resize and Keep Ratio
+    """ Resize and Keep Aspect Ratio
     """

     def __init__(
@ -293,33 +409,77 @@ class ResizeKeepRatio:
             size,
             longest=0.,
             interpolation='bilinear',
-            fill=0,
+            random_scale_prob=0.,
+            random_scale_range=(0.85, 1.05),
+            random_scale_area=False,
+            random_aspect_prob=0.,
+            random_aspect_range=(0.9, 1.11),
     ):
+        """
+
+        Args:
+            size:
+            longest:
+            interpolation:
+            random_scale_prob:
+            random_scale_range:
+            random_scale_area:
+            random_aspect_prob:
+            random_aspect_range:
+        """
         if isinstance(size, (list, tuple)):
             self.size = tuple(size)
         else:
             self.size = (size, size)
-        self.interpolation = str_to_interp_mode(interpolation)
+        if interpolation == 'random':
+            self.interpolation = _RANDOM_INTERPOLATION
+        else:
+            self.interpolation = str_to_interp_mode(interpolation)
         self.longest = float(longest)
-        self.fill = fill
+        self.random_scale_prob = random_scale_prob
+        self.random_scale_range = random_scale_range
+        self.random_scale_area = random_scale_area
+        self.random_aspect_prob = random_aspect_prob
+        self.random_aspect_range = random_aspect_range

     @staticmethod
-    def get_params(img, target_size, longest):
+    def get_params(
+            img,
+            target_size,
+            longest,
+            random_scale_prob=0.,
+            random_scale_range=(1.0, 1.33),
+            random_scale_area=False,
+            random_aspect_prob=0.,
+            random_aspect_range=(0.9, 1.11)
+    ):
         """Get parameters
-
-        Args:
-            img (PIL Image): Image to be cropped.
-            target_size (Tuple[int, int]): Size of output
-        Returns:
-            tuple: params (h, w) and (l, r, t, b) to be passed to ``resize`` and ``pad`` respectively
         """
-        source_size = img.size[::-1]  # h, w
-        h, w = source_size
+        img_h, img_w = img_size = F.get_dimensions(img)[1:]
         target_h, target_w = target_size
-        ratio_h = h / target_h
-        ratio_w = w / target_w
+        ratio_h = img_h / target_h
+        ratio_w = img_w / target_w
         ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
-        size = [round(x / ratio) for x in source_size]
+
+        if random_scale_prob > 0 and random.random() < random_scale_prob:
+            ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
+            if random_scale_area:
+                # make ratio factor equivalent to RRC area crop where < 1.0 = area zoom,
+                # otherwise like affine scale where < 1.0 = linear zoom out
+                ratio_factor = 1. / math.sqrt(ratio_factor)
+            ratio_factor = (ratio_factor, ratio_factor)
+        else:
+            ratio_factor = (1., 1.)
+
+        if random_aspect_prob > 0 and random.random() < random_aspect_prob:
+            log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
+            aspect_factor = math.exp(random.uniform(*log_aspect))
+            aspect_factor = math.sqrt(aspect_factor)
+            # currently applying random aspect adjustment equally to both dims,
+            # could change to keep output sizes above their target where possible
+            ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
+
+        size = [round(x * f / ratio) for x, f in zip(img_size, ratio_factor)]
         return size

     def __call__(self, img):
@ -330,13 +490,49 @@ class ResizeKeepRatio:
         Returns:
             PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
         """
-        size = self.get_params(img, self.size, self.longest)
-        img = F.resize(img, size, self.interpolation)
+        size = self.get_params(
+            img, self.size, self.longest,
+            self.random_scale_prob, self.random_scale_range, self.random_scale_area,
+            self.random_aspect_prob, self.random_aspect_range
+        )
+        if isinstance(self.interpolation, (tuple, list)):
+            interpolation = random.choice(self.interpolation)
+        else:
+            interpolation = self.interpolation
+        img = F.resize(img, size, interpolation)
         return img

     def __repr__(self):
-        interpolate_str = interp_mode_to_str(self.interpolation)
+        if isinstance(self.interpolation, (tuple, list)):
+            interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
+        else:
+            interpolate_str = interp_mode_to_str(self.interpolation)
         format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
-        format_string += f', interpolation={interpolate_str})'
-        format_string += f', longest={self.longest:.3f})'
+        format_string += f', interpolation={interpolate_str}'
+        format_string += f', longest={self.longest:.3f}'
+        format_string += f', random_scale_prob={self.random_scale_prob:.3f}'
+        format_string += f', random_scale_range=(' \
+                         f'{self.random_scale_range[0]:.3f}, {self.random_aspect_range[1]:.3f})'
+        format_string += f', random_aspect_prob={self.random_aspect_prob:.3f}'
+        format_string += f', random_aspect_range=(' \
+                         f'{self.random_aspect_range[0]:.3f}, {self.random_aspect_range[1]:.3f}))'
         return format_string
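A sketch of the enhanced ResizeKeepRatio in use; the knob values below simply echo the new defaults and are illustrative:

    from PIL import Image
    from timm.data.transforms import ResizeKeepRatio

    t = ResizeKeepRatio(
        size=256,
        longest=0.,                   # 0. = fit shortest edge, 1. = fit longest edge
        interpolation='random',       # now picks bilinear/bicubic per call
        random_scale_prob=0.5,
        random_scale_range=(0.85, 1.05),
        random_scale_area=True,       # interpret scale like an RRC area factor
        random_aspect_prob=0.5,
    )
    img = Image.new('RGB', (640, 480))
    print(t(img).size)  # ~256 on the short edge, jittered by scale/aspect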
+
+
+class TrimBorder(torch.nn.Module):
+
+    def __init__(
+            self,
+            border_size: int,
+    ):
+        super().__init__()
+        self.border_size = border_size
+
+    def forward(self, img):
+        w, h = F.get_image_size(img)
+        top = left = self.border_size
+        top = min(top, h)
+        left = min(left, h)
+        height = max(0, h - 2 * self.border_size)
+        width = max(0, w - 2 * self.border_size)
+        return F.crop(img, top, left, height, width)
|
|||||||
Hacked together by / Copyright 2019, Ross Wightman
|
Hacked together by / Copyright 2019, Ross Wightman
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
@ -11,17 +12,29 @@ from torchvision import transforms
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
 from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
-    ResizeKeepRatio, CenterCropOrPad, ToNumpy
+    ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy
 from timm.data.random_erasing import RandomErasing


 def transforms_noaug_train(
-        img_size=224,
-        interpolation='bilinear',
-        use_prefetcher=False,
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD,
+        img_size: Union[int, Tuple[int, int]] = 224,
+        interpolation: str = 'bilinear',
+        use_prefetcher: bool = False,
+        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
 ):
+    """ No-augmentation image transforms for training.
+
+    Args:
+        img_size: Target image size.
+        interpolation: Image interpolation mode.
+        mean: Image normalization mean.
+        std: Image normalization standard deviation.
+        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
+
+    Returns:
+
+    """
     if interpolation == 'random':
         # random interpolation not supported with no-aug
         interpolation = 'bilinear'
@ -37,41 +50,97 @@ def transforms_noaug_train(
         transforms.ToTensor(),
         transforms.Normalize(
             mean=torch.tensor(mean),
-            std=torch.tensor(std))
+            std=torch.tensor(std)
+        )
     ]
     return transforms.Compose(tfl)


 def transforms_imagenet_train(
-        img_size=224,
-        scale=None,
-        ratio=None,
-        hflip=0.5,
-        vflip=0.,
-        color_jitter=0.4,
-        auto_augment=None,
-        interpolation='random',
-        use_prefetcher=False,
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD,
-        re_prob=0.,
-        re_mode='const',
-        re_count=1,
-        re_num_splits=0,
-        separate=False,
-        force_color_jitter=False,
+        img_size: Union[int, Tuple[int, int]] = 224,
+        scale: Optional[Tuple[float, float]] = None,
+        ratio: Optional[Tuple[float, float]] = None,
+        train_crop_mode: Optional[str] = None,
+        hflip: float = 0.5,
+        vflip: float = 0.,
+        color_jitter: Union[float, Tuple[float, ...]] = 0.4,
+        color_jitter_prob: Optional[float] = None,
+        force_color_jitter: bool = False,
+        grayscale_prob: float = 0.,
+        gaussian_blur_prob: float = 0.,
+        auto_augment: Optional[str] = None,
+        interpolation: str = 'random',
+        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+        re_prob: float = 0.,
+        re_mode: str = 'const',
+        re_count: int = 1,
+        re_num_splits: int = 0,
+        use_prefetcher: bool = False,
+        separate: bool = False,
 ):
-    """
-    If separate==True, the transforms are returned as a tuple of 3 separate transforms
-    for use in a mixing dataset that passes
-     * all data through the first (primary) transform, called the 'clean' data
-     * a portion of the data through the secondary transform
-     * normalizes and converts the branches above with the third, final transform
-    """
-    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
-    ratio = tuple(ratio or (3./4., 4./3.))  # default imagenet ratio range
-    primary_tfl = [
-        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
+    """ ImageNet-oriented image transforms for training.
+
+    Args:
+        img_size: Target image size.
+        scale: Random resize scale range (crop area, < 1.0 => zoom in).
+        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
+        hflip: Horizontal flip probability.
+        vflip: Vertical flip probability.
+        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
+            Scalar is applied as (scalar,) * 3 (no hue).
+        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
+        force_color_jitter: Force color jitter where it is normally disabled (ie with RandAugment on).
+        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
+        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
+        auto_augment: Auto augment configuration string (see auto_augment.py).
+        interpolation: Image interpolation mode.
+        mean: Image normalization mean.
+        std: Image normalization standard deviation.
+        re_prob: Random erasing probability.
+        re_mode: Random erasing fill mode.
+        re_count: Number of random erasing regions.
+        re_num_splits: Control split of random erasing across batch size.
+        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
+        separate: Output transforms in 3-stage tuple.
+
+    Returns:
+        If separate==True, the transforms are returned as a tuple of 3 separate transforms
+        for use in a mixing dataset that passes
+         * all data through the first (primary) transform, called the 'clean' data
+         * a portion of the data through the secondary transform
+         * normalizes and converts the branches above with the third, final transform
+    """
+    train_crop_mode = train_crop_mode or 'rrc'
+    if train_crop_mode in ('rkrc', 'rkrr'):
+        # FIXME integration of RKR is a WIP
+        scale = tuple(scale or (0.8, 1.00))
+        ratio = tuple(ratio or (0.9, 1/.9))
+        primary_tfl = [
+            ResizeKeepRatio(
+                img_size,
+                interpolation=interpolation,
+                random_scale_prob=0.5,
+                random_scale_range=scale,
+                random_scale_area=True,  # scale compatible with RRC
+                random_aspect_prob=0.5,
+                random_aspect_range=ratio,
+            ),
+            CenterCropOrPad(img_size, padding_mode='reflect')
+            if train_crop_mode == 'rkrc' else
+            RandomCropOrPad(img_size, padding_mode='reflect')
+        ]
+    else:
+        scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
+        ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
+        primary_tfl = [
+            RandomResizedCropAndInterpolation(
+                img_size,
+                scale=scale,
+                ratio=ratio,
+                interpolation=interpolation,
+            )
+        ]
     if hflip > 0.:
         primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
     if vflip > 0.:
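A sketch of selecting the new crop modes ('rkrr' = ResizeKeepRatio + RandomCropOrPad, 'rkrc' = ResizeKeepRatio + center crop/pad). Both are flagged WIP in the diff, so treat this as the intended API rather than a settled one; values are illustrative:

    from timm.data.transform_factory import transforms_imagenet_train

    train_tf = transforms_imagenet_train(
        img_size=224,
        train_crop_mode='rkrr',
        interpolation='random',
    )
    print(train_tf)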
@ -111,7 +180,28 @@ def transforms_imagenet_train(
         else:
             # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
             color_jitter = (float(color_jitter),) * 3
-        secondary_tfl += [transforms.ColorJitter(*color_jitter)]
+        if color_jitter_prob is not None:
+            secondary_tfl += [
+                transforms.RandomApply([
+                        transforms.ColorJitter(*color_jitter),
+                    ],
+                    p=color_jitter_prob
+                )
+            ]
+        else:
+            secondary_tfl += [transforms.ColorJitter(*color_jitter)]
+
+    if grayscale_prob:
+        secondary_tfl += [transforms.RandomGrayscale(p=grayscale_prob)]
+
+    if gaussian_blur_prob:
+        secondary_tfl += [
+            transforms.RandomApply([
+                    transforms.GaussianBlur(kernel_size=23),  # hardcoded for now
+                ],
+                p=gaussian_blur_prob,
+            )
+        ]

     final_tfl = []
     if use_prefetcher:
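Together these give a SimCLR-flavoured recipe: jitter, grayscale, and blur each applied with an independent probability. A sketch with illustrative probabilities (not values taken from the source):

    from timm.data.transform_factory import transforms_imagenet_train

    tf = transforms_imagenet_train(
        img_size=224,
        color_jitter=0.4,
        color_jitter_prob=0.8,
        grayscale_prob=0.2,
        gaussian_blur_prob=0.5,
    )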
@ -122,11 +212,19 @@ def transforms_imagenet_train(
             transforms.ToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
-                std=torch.tensor(std))
+                std=torch.tensor(std)
+            ),
         ]
         if re_prob > 0.:
-            final_tfl.append(
-                RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
+            final_tfl += [
+                RandomErasing(
+                    re_prob,
+                    mode=re_mode,
+                    max_count=re_count,
+                    num_splits=re_num_splits,
+                    device='cpu',
+                )
+            ]

     if separate:
         return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
@ -135,14 +233,30 @@ def transforms_imagenet_train(


 def transforms_imagenet_eval(
-        img_size=224,
-        crop_pct=None,
-        crop_mode=None,
-        interpolation='bilinear',
-        use_prefetcher=False,
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD
+        img_size: Union[int, Tuple[int, int]] = 224,
+        crop_pct: Optional[float] = None,
+        crop_mode: Optional[str] = None,
+        crop_border_pixels: Optional[int] = None,
+        interpolation: str = 'bilinear',
+        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+        use_prefetcher: bool = False,
 ):
+    """ ImageNet-oriented image transform for evaluation and inference.
+
+    Args:
+        img_size: Target image size.
+        crop_pct: Crop percentage. Defaults to 0.875 when None.
+        crop_mode: Crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
+        crop_border_pixels: Trim a border of specified # pixels around edge of original image.
+        interpolation: Image interpolation mode.
+        mean: Image normalization mean.
+        std: Image normalization standard deviation.
+        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
+
+    Returns:
+        Composed transform pipeline
+    """
     crop_pct = crop_pct or DEFAULT_CROP_PCT

     if isinstance(img_size, (tuple, list)):
@ -152,10 +266,15 @@ def transforms_imagenet_eval(
         scale_size = math.floor(img_size / crop_pct)
         scale_size = (scale_size, scale_size)

+    tfl = []
+
+    if crop_border_pixels:
+        tfl += [TrimBorder(crop_border_pixels)]
+
     if crop_mode == 'squash':
         # squash mode scales each edge to 1/pct of target, then crops
         # aspect ratio is not preserved, no img lost if crop_pct == 1.0
-        tfl = [
+        tfl += [
             transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
             transforms.CenterCrop(img_size),
         ]
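A sketch of the eval-side additions, with illustrative values: trim a fixed border first (e.g. to drop digitization edges), then let the chosen crop mode do the resize/crop:

    from timm.data.transform_factory import transforms_imagenet_eval

    eval_tf = transforms_imagenet_eval(
        img_size=224,
        crop_pct=1.0,
        crop_mode='squash',      # resize both edges, no image lost at crop_pct=1.0
        crop_border_pixels=4,
    )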
@ -163,7 +282,7 @@ def transforms_imagenet_eval(
         # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
         # no image lost if crop_pct == 1.0
         fill = [round(255 * v) for v in mean]
-        tfl = [
+        tfl += [
             ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
             CenterCropOrPad(img_size, fill=fill),
         ]
@ -172,12 +291,12 @@ def transforms_imagenet_eval(
         # aspect ratio is preserved, crops center within image, no borders are added, image is lost
         if scale_size[0] == scale_size[1]:
             # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
-            tfl = [
+            tfl += [
                 transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
             ]
         else:
-            # resize shortest edge to matching target dim for non-square target
-            tfl = [ResizeKeepRatio(scale_size)]
+            # resize the shortest edge to matching target dim for non-square target
+            tfl += [ResizeKeepRatio(scale_size)]
         tfl += [transforms.CenterCrop(img_size)]

     if use_prefetcher:
@ -196,28 +315,65 @@ def transforms_imagenet_eval(


 def create_transform(
-        input_size,
-        is_training=False,
-        use_prefetcher=False,
-        no_aug=False,
-        scale=None,
-        ratio=None,
-        hflip=0.5,
-        vflip=0.,
-        color_jitter=0.4,
-        auto_augment=None,
-        interpolation='bilinear',
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD,
-        re_prob=0.,
-        re_mode='const',
-        re_count=1,
-        re_num_splits=0,
-        crop_pct=None,
-        crop_mode=None,
-        tf_preprocessing=False,
-        separate=False):
+        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224,
+        is_training: bool = False,
+        no_aug: bool = False,
+        scale: Optional[Tuple[float, float]] = None,
+        ratio: Optional[Tuple[float, float]] = None,
+        hflip: float = 0.5,
+        vflip: float = 0.,
+        color_jitter: Union[float, Tuple[float, ...]] = 0.4,
+        color_jitter_prob: Optional[float] = None,
+        grayscale_prob: float = 0.,
+        gaussian_blur_prob: float = 0.,
+        auto_augment: Optional[str] = None,
+        interpolation: str = 'bilinear',
+        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+        re_prob: float = 0.,
+        re_mode: str = 'const',
+        re_count: int = 1,
+        re_num_splits: int = 0,
+        crop_pct: Optional[float] = None,
+        crop_mode: Optional[str] = None,
+        crop_border_pixels: Optional[int] = None,
+        tf_preprocessing: bool = False,
+        use_prefetcher: bool = False,
+        separate: bool = False,
+):
+    """
+
+    Args:
+        input_size: Target input size (channels, height, width) tuple or size scalar.
+        is_training: Return training (random) transforms.
+        no_aug: Disable augmentation for training (useful for debug).
+        scale: Random resize scale range (crop area, < 1.0 => zoom in).
+        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
+        hflip: Horizontal flip probability.
+        vflip: Vertical flip probability.
+        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
+            Scalar is applied as (scalar,) * 3 (no hue).
+        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
+        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
+        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
+        auto_augment: Auto augment configuration string (see auto_augment.py).
+        interpolation: Image interpolation mode.
+        mean: Image normalization mean.
+        std: Image normalization standard deviation.
+        re_prob: Random erasing probability.
+        re_mode: Random erasing fill mode.
+        re_count: Number of random erasing regions.
+        re_num_splits: Control split of random erasing across batch size.
+        crop_pct: Inference crop percentage (output size / resize size).
+        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
+        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
+        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
+        use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
+        separate: Output transforms in 3-stage tuple.
+
+    Returns:
+        Composed transforms or tuple thereof
+    """
     if isinstance(input_size, (tuple, list)):
         img_size = input_size[-2:]
     else:
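The new knobs surface directly through create_transform, the usual entry point. A sketch with illustrative values:

    from timm.data import create_transform

    # training side gains the SimCLR-style probabilities ...
    train_tf = create_transform(
        input_size=(3, 224, 224),
        is_training=True,
        color_jitter_prob=0.8,
        grayscale_prob=0.2,
        gaussian_blur_prob=0.5,
    )
    # ... while the inference side gains crop_mode / crop_border_pixels
    eval_tf = create_transform(
        input_size=(3, 224, 224),
        crop_mode='border',
        crop_border_pixels=4,
    )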
@ -227,7 +383,10 @@ def create_transform(
         assert not separate, "Separate transforms not supported for TF preprocessing"
         from timm.data.tf_preprocessing import TfPreprocessTransform
         transform = TfPreprocessTransform(
-            is_training=is_training, size=img_size, interpolation=interpolation)
+            is_training=is_training,
+            size=img_size,
+            interpolation=interpolation,
+        )
     else:
         if is_training and no_aug:
             assert not separate, "Cannot perform split augmentation with no_aug"
@ -246,6 +405,9 @@ def create_transform(
                 hflip=hflip,
                 vflip=vflip,
                 color_jitter=color_jitter,
+                color_jitter_prob=color_jitter_prob,
+                grayscale_prob=grayscale_prob,
+                gaussian_blur_prob=gaussian_blur_prob,
                 auto_augment=auto_augment,
                 interpolation=interpolation,
                 use_prefetcher=use_prefetcher,
@ -267,6 +429,7 @@ def create_transform(
                 std=std,
                 crop_pct=crop_pct,
                 crop_mode=crop_mode,
+                crop_border_pixels=crop_border_pixels,
             )

     return transform
35 train.py
@ -93,10 +93,20 @@ group.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
 group.add_argument('--val-split', metavar='NAME', default='validation',
                    help='dataset validation split (default: validation)')
+parser.add_argument('--train-num-samples', default=None, type=int,
+                    metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
+parser.add_argument('--val-num-samples', default=None, type=int,
+                    metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
 group.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
 group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
+group.add_argument('--input-img-mode', default=None, type=str,
+                   help='Dataset image conversion mode for input images.')
+group.add_argument('--input-key', default=None, type=str,
+                   help='Dataset key for input images.')
+group.add_argument('--target-key', default=None, type=str,
+                   help='Dataset key for target labels.')

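Taken together, these flags let non-default datasets train without code changes, e.g. a single-channel source with custom sample keys: `python train.py --data-dir /data --input-img-mode L --input-key image --target-key label` (paths and key names illustrative).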
 # Model parameters
 group = parser.add_argument_group('Model parameters')
@ -245,6 +255,12 @@ group.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
 group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
+group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
+                   help='Probability of applying any color jitter.')
+group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
+                   help='Probability of applying random grayscale conversion.')
+group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
+                   help='Probability of applying gaussian blur.')
 group.add_argument('--aa', type=str, default=None, metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)'),
 group.add_argument('--aug-repeats', type=float, default=0,
@ -594,6 +610,10 @@ def main():
     # create the train and eval datasets
     if args.data and not args.data_dir:
         args.data_dir = args.data
+    if args.input_img_mode is None:
+        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
+    else:
+        input_img_mode = args.input_img_mode
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,
@ -604,6 +624,10 @@ def main():
         batch_size=args.batch_size,
         seed=args.seed,
         repeats=args.epoch_repeats,
+        input_img_mode=input_img_mode,
+        input_key=args.input_key,
+        target_key=args.target_key,
+        num_samples=args.train_num_samples,
     )

     dataset_eval = create_dataset(
@ -614,6 +638,10 @@ def main():
         class_map=args.class_map,
         download=args.dataset_download,
         batch_size=args.batch_size,
+        input_img_mode=input_img_mode,
+        input_key=args.input_key,
+        target_key=args.target_key,
+        num_samples=args.val_num_samples,
     )

     # setup mixup / cutmix
@ -650,7 +678,6 @@ def main():
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         is_training=True,
-        use_prefetcher=args.prefetcher,
         no_aug=args.no_aug,
         re_prob=args.reprob,
         re_mode=args.remode,
@ -661,6 +688,9 @@ def main():
         hflip=args.hflip,
         vflip=args.vflip,
         color_jitter=args.color_jitter,
+        color_jitter_prob=args.color_jitter_prob,
+        grayscale_prob=args.grayscale_prob,
+        gaussian_blur_prob=args.gaussian_blur_prob,
         auto_augment=args.aa,
         num_aug_repeats=args.aug_repeats,
         num_aug_splits=num_aug_splits,
@ -672,6 +702,7 @@ def main():
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
         device=device,
+        use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
         worker_seeding=args.worker_seeding,
     )
@ -685,7 +716,6 @@ def main():
         input_size=data_config['input_size'],
         batch_size=args.validation_batch_size or args.batch_size,
         is_training=False,
-        use_prefetcher=args.prefetcher,
         interpolation=data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],
@ -694,6 +724,7 @@ def main():
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
         device=device,
+        use_prefetcher=args.prefetcher,
     )

     # setup loss function
28 validate.py
@ -61,10 +61,23 @@ parser.add_argument('--dataset', metavar='NAME', default='',
                     help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
 parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
+parser.add_argument('--num-samples', default=None, type=int,
+                    metavar='N', help='Manually specify num samples in dataset split, for IterableDatasets.')
 parser.add_argument('--dataset-download', action='store_true', default=False,
                     help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
+parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
+                    help='path to class to idx mapping file (default: "")')
+parser.add_argument('--input-key', default=None, type=str,
+                    help='Dataset key for input images.')
+parser.add_argument('--input-img-mode', default=None, type=str,
+                    help='Dataset image conversion mode for input images.')
+parser.add_argument('--target-key', default=None, type=str,
+                    help='Dataset key for target labels.')
+
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
+parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                    help='use pre-trained model')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
@ -81,6 +94,8 @@ parser.add_argument('--crop-pct', default=None, type=float,
                     metavar='N', help='Input image center crop pct')
 parser.add_argument('--crop-mode', default=None, type=str,
                     metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
+parser.add_argument('--crop-border-pixels', type=int, default=None,
+                    help='Crop pixels from image border.')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@ -89,16 +104,12 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('--num-classes', type=int, default=None,
                     help='Number classes in dataset')
-parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
-                    help='path to class to idx mapping file (default: "")')
 parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                     help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
 parser.add_argument('--log-freq', default=10, type=int,
                     metavar='N', help='batch logging frequency (default: 10)')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
-parser.add_argument('--pretrained', dest='pretrained', action='store_true',
-                    help='use pre-trained model')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
 parser.add_argument('--test-pool', dest='test_pool', action='store_true',
@ -249,6 +260,10 @@ def validate(args):
     criterion = nn.CrossEntropyLoss().to(device)

     root_dir = args.data or args.data_dir
+    if args.input_img_mode is None:
+        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
+    else:
+        input_img_mode = args.input_img_mode
     dataset = create_dataset(
         root=root_dir,
         name=args.dataset,
@ -256,6 +271,10 @@ def validate(args):
         download=args.dataset_download,
         load_bytes=args.tf_preprocessing,
         class_map=args.class_map,
+        num_samples=args.num_samples,
+        input_key=args.input_key,
+        input_img_mode=input_img_mode,
+        target_key=args.target_key,
     )

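This mirrors the train.py changes on the validation side, e.g. for a webdataset/TFDS iterable whose keys differ from the defaults: `python validate.py --data-dir /data --model resnet50 --input-key jpg --target-key cls --num-samples 50000` (paths, key names, and count illustrative).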
     if args.valid_labels:
@ -281,6 +300,7 @@ def validate(args):
         num_workers=args.workers,
         crop_pct=crop_pct,
         crop_mode=data_config['crop_mode'],
+        crop_border_pixels=args.crop_border_pixels,
         pin_memory=args.pin_mem,
         device=device,
         tf_preprocessing=args.tf_preprocessing,