""" Loader Factory, Fast Collate, CUDA Prefetcher
|
|
|
|
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
|
|
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
"""
|
|
|
|
import torch.utils.data
|
|
import numpy as np
|
|
|
|
from .transforms_factory import create_transform
|
|
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
from .distributed_sampler import OrderedDistributedSampler
|
|
from .random_erasing import RandomErasing
|
|
from .mixup import FastCollateMixup
|
|
|
|
|
|
def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
        # such that all tuples of position n will end up in nth position of a torch.split(tensor, batch_size)
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        # uint8 numpy array inputs, accumulate into a pre-allocated uint8 tensor
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        # torch.Tensor inputs, copy into a pre-allocated uint8 tensor
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        assert False, 'unexpected batch element type'
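
# A minimal sketch of what fast_collate produces (illustrative only; the
# arrays and labels below are made up for the example):
#
#   batch = [(np.zeros((3, 224, 224), dtype=np.uint8), 0),
#            (np.ones((3, 224, 224), dtype=np.uint8), 1)]
#   tensor, targets = fast_collate(batch)
#   # tensor: torch.uint8 of shape (2, 3, 224, 224); targets: torch.int64 of shape (2,)
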
class PrefetchLoader:
    """ Wraps a DataLoader so batches are moved to GPU and normalized on a side CUDA
    stream, optionally applying GPU random erasing, overlapping host-to-device
    copies with compute."""

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 fp16=False,
                 re_prob=0.,
                 re_mode='const',
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
        # mean/std are scaled by 255 so normalization applies directly to uint8 [0, 255] inputs
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
        self.fp16 = fp16
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half().sub_(self.mean).div_(self.std)
                else:
                    next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            # yield the previous batch while the next one is transferred on the side stream
            if not first:
                yield input, target
            else:
                first = False

            # make the default stream wait on the prefetch stream before handing the batch out
            torch.cuda.current_stream().wait_stream(stream)
            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x

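# Illustrative sketch only (requires a CUDA device; `base_loader` is a
# hypothetical DataLoader built with fast_collate so batches arrive as uint8):
#
#   loader = PrefetchLoader(base_loader, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
#   for images, targets in loader:
#       ...  # images are float CUDA tensors, normalized as (x - 255 * mean) / (255 * std)
#
# Scaling mean/std by 255 folds the usual to-float divide-by-255 step into the
# normalization, so uint8 inputs never need an explicit division by 255.
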
def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        no_aug=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        crop_pct=None,
        collate_fn=None,
        pin_memory=False,
        fp16=False,
        tf_preprocessing=False,
        use_multi_epochs_loader=False,
        persistent_workers=True,
):
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split, otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_prefetcher=use_prefetcher,
        no_aug=no_aug,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=num_aug_splits > 0,
    )

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in an equal num
            # of samples per process, which will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader

    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        persistent_workers=persistent_workers)
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # persistent_workers only exists in PyTorch 1.7+
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            fp16=fp16,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader

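# A hedged usage sketch (not from this file; `train_dataset` is a hypothetical
# map-style dataset yielding (PIL image, int label) pairs, with a settable
# .transform attribute as timm datasets provide):
#
#   loader = create_loader(
#       train_dataset,
#       input_size=(3, 224, 224),
#       batch_size=64,
#       is_training=True,
#       use_prefetcher=True,   # wraps the DataLoader in PrefetchLoader
#       re_prob=0.25,          # random erasing then runs on GPU inside PrefetchLoader
#       num_workers=4,
#   )
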
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """ DataLoader that reuses its workers across epochs by wrapping the batch
    sampler in a sampler that repeats forever, avoiding worker startup cost
    at the beginning of each epoch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
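
# Illustrative sketch only (assumes a hypothetical `my_dataset` map-style dataset):
# MultiEpochsDataLoader is a drop-in replacement for torch.utils.data.DataLoader,
# typically reached via create_loader(..., use_multi_epochs_loader=True).
#
#   loader = MultiEpochsDataLoader(my_dataset, batch_size=32, num_workers=4, shuffle=True)
#   for epoch in range(3):
#       for images, labels in loader:  # workers are reused, not respawned per epoch
#           pass
#
# Because _RepeatSampler never exhausts, the underlying iterator stays alive
# across epochs and __iter__ simply draws len(self) batches from it each time.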