import math
from contextlib import suppress
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import torch

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .loader import _worker_init
from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
from .transforms_factory import create_transform


class NaFlexPrefetchLoader:
    """Data prefetcher for NaFlex format which normalizes patches."""

    def __init__(
            self,
            loader,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            img_dtype=torch.float32,
            device=torch.device('cuda'),
    ):
        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype or torch.float32

        # Create mean/std tensors for normalization (will be applied to patches).
        # Values are scaled by 255 because patches arrive as uint8 pixels in [0, 255].
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)

        # Check for CUDA/NPU availability
        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
        self.is_npu = device.type == 'npu' and torch.npu.is_available()

    def __iter__(self):
        first = True
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        elif self.is_npu:
            stream = torch.npu.Stream()
            stream_context = partial(torch.npu.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress

        # Standard prefetch pattern: stage the next batch on a side stream while
        # yielding the previous one, overlapping host-to-device copies with compute.
        for next_input_dict, next_target in self.loader:
            with stream_context():
                # Move all tensors in input_dict to device
                for k, v in next_input_dict.items():
                    if isinstance(v, torch.Tensor):
                        dtype = self.img_dtype if k == 'patches' else None
                        next_input_dict[k] = next_input_dict[k].to(
                            device=self.device,
                            non_blocking=True,
                            dtype=dtype,
                        )
                next_target = next_target.to(device=self.device, non_blocking=True)

                # Normalize patch values (assuming patches are in format [B, N, P*P*C])
                batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
                patches = next_input_dict['patches'].view(batch_size, -1, 3)  # to [B, N*P*P, C] for normalization
                patches = patches.sub(self.mean).div(self.std)
                # Reshape back
                next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)

            if not first:
                yield input_dict, target
            else:
                first = False

            if stream is not None:
                if self.is_cuda:
                    torch.cuda.current_stream().wait_stream(stream)
                elif self.is_npu:
                    torch.npu.current_stream().wait_stream(stream)

            input_dict = next_input_dict
            target = next_target

        # Yield the final staged batch
        yield input_dict, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset
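
# Usage sketch for the prefetcher (illustrative, not part of the module API). It wraps
# any DataLoader whose batches are (input_dict, target) with input_dict['patches'] as
# raw uint8 pixel values shaped [B, N, P*P*C]; `base_loader` is a placeholder name.
#
#   prefetch_loader = NaFlexPrefetchLoader(
#       base_loader,
#       mean=IMAGENET_DEFAULT_MEAN,
#       std=IMAGENET_DEFAULT_STD,
#       img_dtype=torch.float32,
#       device=torch.device('cuda'),
#   )
#   for input_dict, target in prefetch_loader:
#       ...  # input_dict['patches'] is now on-device, cast to img_dtype, and normalized
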
def create_naflex_loader(
        dataset,
        patch_size: Union[Tuple[int, int], int] = 16,
        train_seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),  # Training sequence lengths
        max_seq_len: int = 576,  # Fixed sequence length for validation
        batch_size: int = 32,  # Batch size at max_seq_len (validation) and max(train_seq_lens) (training)
        is_training: bool = False,
        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        num_workers: int = 4,
        distributed: bool = False,
        rank: int = 0,
        world_size: int = 1,
        seed: int = 42,
        epoch: int = 0,
        use_prefetcher: bool = True,
        pin_memory: bool = True,
        img_dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = torch.device('cuda'),
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
):
    """Create a NaFlex data loader.

    Uses dynamic sequence-length sampling for training and a fixed
    sequence length for validation.
    """
    if is_training:
        # For training, use the dynamic sequence length mechanism
        assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader'

        transform_factory = partial(
            create_transform,
            is_training=True,
            no_aug=no_aug,
            train_crop_mode=train_crop_mode,
            scale=scale,
            ratio=ratio,
            hflip=hflip,
            vflip=vflip,
            color_jitter=color_jitter,
            color_jitter_prob=color_jitter_prob,
            grayscale_prob=grayscale_prob,
            gaussian_blur_prob=gaussian_blur_prob,
            auto_augment=auto_augment,
            interpolation=interpolation,
            mean=mean,
            std=std,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count,
            use_prefetcher=use_prefetcher,
            naflex=True,
        )

        max_train_seq_len = max(train_seq_lens)
        max_tokens_per_batch = batch_size * max_train_seq_len

        if isinstance(dataset, torch.utils.data.IterableDataset):
            assert False, "IterableDataset Wrapper is a WIP"

        naflex_dataset = VariableSeqMapWrapper(
            dataset,
            transform_factory=transform_factory,
            patch_size=patch_size,
            seq_lens=train_seq_lens,
            max_tokens_per_batch=max_tokens_per_batch,
            seed=seed,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            shuffle=True,
            epoch=epoch,
        )

        # NOTE: Collation is handled by the dataset wrapper for training
        # Create the collator (handles fixed-size collation)
        # collate_fn = NaFlexCollator(
        #     max_seq_len=max(train_seq_lens) + 1,  # +1 for class token
        # )

        loader = torch.utils.data.DataLoader(
            naflex_dataset,
            batch_size=None,  # batching is handled by the wrapper dataset
            shuffle=False,
            num_workers=num_workers,
            sampler=None,
            # collate_fn=collate_fn,
            pin_memory=pin_memory,
            worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
            persistent_workers=persistent_workers,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    else:
        # For validation, use a fixed sequence length
        dataset.transform = create_transform(
            is_training=False,
            interpolation=interpolation,
            mean=mean,
            std=std,
            # FIXME add crop args when sequence transforms support crop modes
            use_prefetcher=use_prefetcher,
            naflex=True,
            patch_size=patch_size,
            max_seq_len=max_seq_len,
            patchify=True,
        )

        # Create the collator (handles fixed-size collation)
        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)

        # Handle distributed evaluation
        sampler = None
        if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
            # For validation, use OrderedDistributedSampler
            from timm.data.distributed_sampler import OrderedDistributedSampler
            sampler = OrderedDistributedSampler(dataset)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=False,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    return loader
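

if __name__ == '__main__':
    # Minimal smoke-test sketch, assuming a CUDA device and a local image folder;
    # the '/path/to/train' path is a placeholder, and timm.data.ImageDataset is just
    # one possible map-style dataset to feed in.
    from timm.data import ImageDataset

    train_loader = create_naflex_loader(
        ImageDataset('/path/to/train'),
        patch_size=16,
        train_seq_lens=(128, 256, 576, 784, 1024),
        batch_size=32,
        is_training=True,
        num_workers=4,
    )
    for input_dict, target in train_loader:
        # Print tensor shapes for the first variable-length batch, then stop
        print({k: tuple(v.shape) for k, v in input_dict.items() if isinstance(v, torch.Tensor)})
        break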