""" Dynamic Sequence Length Datasets for Variable Resolution Image Processing Implements two dataset wrappers: 1. DynamicSeqMapDataset - Map-style dataset that returns batches with variable sequence lengths 2. DynamicSeqIterDataset - Iterable dataset that yields batches with variable sequence lengths Both support: - Pre-initialized transforms for efficiency - Distributed training - Multiple workers - Variable batch sizes based on sequence length """ import math import random import warnings from functools import partial from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable import torch from torch.utils.data import Dataset, IterableDataset, DataLoader from torchvision import transforms from PIL import Image from .naflex_transforms import Patchify, patchify def calculate_batch_size( tokens_per_batch: int, seq_len: int, max_size: Optional[int] = None, divisor: int = 1, rounding: str ='floor', ): """Calculate batch size based on sequence length with divisibility constraints.""" # Calculate raw batch size based on sequence length raw_batch_size = tokens_per_batch / seq_len # Apply divisibility with specified rounding method if divisor > 1: if rounding == 'floor': batch_size = math.floor(raw_batch_size / divisor) * divisor elif rounding == 'ceil': batch_size = math.ceil(raw_batch_size / divisor) * divisor else: # 'round' is the default batch_size = round(raw_batch_size / divisor) * divisor else: # If no divisor specified, just use integer division batch_size = int(raw_batch_size) # Ensure batch size is valid batch_size = max(1, batch_size) # At least 1 if max_size is not None: batch_size = min(batch_size, max_size) return batch_size class NaFlexCollator: """Custom collator for batching NaFlex-style variable-resolution images.""" def __init__( self, max_seq_len=None, ): self.max_seq_len = max_seq_len or 576 # Default ViT-B/16 sequence length (577 = 24*24) def __call__(self, batch): """ Args: batch: List of tuples (patch_dict, target) Returns: A tuple of (input_dict, targets) where input_dict contains: - patches: Padded tensor of patches - patch_coord: Coordinates for each patch (y, x) - patch_valid: Valid indicators """ assert isinstance(batch[0], tuple) batch_size = len(batch) # Extract targets targets = [item[1] for item in batch] if isinstance(targets[0], torch.Tensor): targets = torch.stack(targets) else: targets = torch.tensor(targets, dtype=torch.int64) # Get patch dictionaries patch_dicts = [item[0] for item in batch] # If we have a maximum sequence length constraint, ensure we don't exceed it if self.max_seq_len is not None: max_patches = self.max_seq_len else: # Find the maximum number of patches in this batch max_patches = max(item['patches'].shape[0] for item in patch_dicts) # Get patch dimensionality patch_dim = patch_dicts[0]['patches'].shape[1] # Prepare tensors for the batch patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32) patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64) # [B, N, 2] for (y, x) patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool) # Fill in the tensors for i, patch_dict in enumerate(patch_dicts): num_patches = min(patch_dict['patches'].shape[0], max_patches) patches[i, :num_patches] = patch_dict['patches'][:num_patches] patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches] patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches] return { 'patches': patches, 'patch_coord': patch_coord, 'patch_valid': patch_valid, 'seq_len': max_patches, }, targets class VariableSeqMapWrapper(IterableDataset): """ IterableDataset wrapper for a map-style base dataset. Yields batches with variable sequence lengths. It calculates a canonical batch schedule (sequence length, batch size pairs) once based on the total dataset size (padded for distribution). Each epoch, it shuffles the *order* of this canonical schedule and the dataset indices. This ensures a consistent number of batches and samples per epoch across all ranks. Handles distributed training and multiple workers. """ def __init__( self, base_dataset: Dataset, patch_size: Union[int, Tuple[int, int]] = 16, seq_lens: List[int] = (128, 256, 576, 784, 1024), max_tokens_per_batch: int = 4096 * 4, # Example: 16k tokens transform_factory: Optional[Callable] = None, mixup_fn: Optional[Callable] = None, seed: int = 42, shuffle: bool = True, distributed: bool = False, rank: int = 0, world_size: int = 1, epoch: int = 0, batch_divisor: int = 8, # Ensure batch size is divisible by this ): super().__init__() if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'): raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)") self.base_dataset = base_dataset self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size) self.seq_lens = sorted(list(set(seq_lens))) # Ensure unique and sorted self.max_tokens_per_batch = max_tokens_per_batch self.seed = seed self.shuffle = shuffle self.distributed = distributed self.rank = rank if distributed else 0 self.world_size = world_size if distributed else 1 self.epoch = epoch self.batch_divisor = batch_divisor # Pre-initialize transforms and collate fns for each sequence length self.transforms: Dict[int, Optional[Callable]] = {} self.collate_fns: Dict[int, Callable] = {} for seq_len in self.seq_lens: if transform_factory: self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size) else: self.transforms[seq_len] = None # No transform self.collate_fns[seq_len] = NaFlexCollator(seq_len) self.mixup_fn = mixup_fn self.patchifier = Patchify(self.patch_size) # --- Canonical Schedule Calculation (Done Once) --- self._canonical_batch_schedule: List[Tuple[int, int]] = [] self._num_batches_per_rank: int = 0 self._padded_samples_per_rank: int = 0 self._create_canonical_schedule() # Calculate schedule based on padded size # --- Per-Epoch State --- # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank self._epoch_batches: List[Tuple[int, List[int]]] = [] self._prepare_epoch_batches(self.epoch) # setup for initial epoch def _create_canonical_schedule(self): """ Calculates the canonical batch schedule (seq_len, batch_size pairs) based on the dataset size, padded for distributed training. This schedule is the *same* for all ranks and ensures consistent epoch length. It is calculated once during initialization. """ total_len = len(self.base_dataset) padded_total_len = total_len num_samples_per_rank = total_len if self.distributed and self.world_size > 1: # Calculate padding needed for even distribution if total_len % self.world_size != 0: pad_size = self.world_size - (total_len % self.world_size) padded_total_len += pad_size print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).") else: pad_size = 0 if padded_total_len % self.world_size != 0: # This should not happen with the padding logic, but safeguard raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}") num_samples_per_rank = padded_total_len // self.world_size elif self.distributed and self.world_size <= 1: # Distributed flag set but world_size is 1, treat as non-distributed pass # num_samples_per_rank remains total_len self._padded_samples_per_rank = num_samples_per_rank if num_samples_per_rank == 0: self._canonical_batch_schedule = [] self._num_batches_per_rank = 0 return # Use a fixed seed for generating the canonical schedule structure g = torch.Generator() g.manual_seed(self.seed) # Use base seed, NOT epoch seed current_schedule: List[Tuple[int, int]] = [] remaining_samples = num_samples_per_rank total_scheduled_samples = 0 while remaining_samples > 0: # Sample sequence length deterministically based on base seed seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item() seq_len = self.seq_lens[seq_idx] # Calculate batch size batch_size = calculate_batch_size( tokens_per_batch=self.max_tokens_per_batch, seq_len=seq_len, # max_size should be remaining_samples to avoid overshooting max_size=remaining_samples, divisor=self.batch_divisor, rounding='floor', ) # Ensure batch size is positive and doesn't exceed remaining samples batch_size = max(1, batch_size) batch_size = min(batch_size, remaining_samples) if batch_size <= 0: warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.") break # Avoid infinite loop if something goes wrong current_schedule.append((seq_len, batch_size)) remaining_samples -= batch_size total_scheduled_samples += batch_size # Sanity check: Ensure the schedule covers all samples for the rank if total_scheduled_samples != num_samples_per_rank: warnings.warn( f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, " f"but expected {num_samples_per_rank} samples per rank. " f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. " f"Check parameters. Remaining samples: {remaining_samples}" ) # Adjust if needed? Could add a final small batch, but might violate constraints. # Current behavior: some samples might be dropped if schedule logic fails. self._canonical_batch_schedule = current_schedule self._num_batches_per_rank = len(current_schedule) print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.") def _prepare_epoch_batches(self, epoch: int): """ Prepares the batches for the current epoch by: 1. Shuffling the full dataset indices (using epoch seed). 2. Applying padding if in distributed mode. 3. Selecting indices for the current rank. 4. Shuffling the *order* of the canonical batch schedule (using epoch seed). 5. Assigning the rank's indices to the shuffled batches. """ g = torch.Generator() g.manual_seed(self.seed + epoch) # Epoch-specific seed for shuffling # 1. Get shuffled global indices total_len = len(self.base_dataset) if self.shuffle: all_indices_shuffled = torch.randperm(total_len, generator=g).tolist() else: all_indices_shuffled = list(range(total_len)) # 2. Apply padding for distributed mode indices_for_ranks = all_indices_shuffled if self.distributed and self.world_size > 1: padded_total_len = self._padded_samples_per_rank * self.world_size if padded_total_len > total_len: pad_size = padded_total_len - total_len # Repeat initial elements from the *shuffled* list for padding indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size] # Ensure length matches expectation if len(indices_for_ranks) != padded_total_len: raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}") # 3. Select indices for the current rank if self.distributed and self.world_size > 1: indices_this_rank = indices_for_ranks[self.rank::self.world_size] else: # Non-distributed or world_size=1 indices_this_rank = indices_for_ranks # Sanity check length if len(indices_this_rank) != self._padded_samples_per_rank: # This might happen if canonical schedule generation had warnings/issues warnings.warn( f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) " f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). " f"Epoch generation might be inconsistent." ) # Adjust expected samples? Or truncate/pad indices? Let's proceed but warn. # Using min() prevents IndexError later if indices are fewer than expected. effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank) indices_this_rank = indices_this_rank[:effective_samples_this_rank] else: effective_samples_this_rank = self._padded_samples_per_rank # 4. Shuffle the order of the canonical batch schedule for this epoch if self.shuffle: schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist() shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm] else: shuffled_schedule = list(self._canonical_batch_schedule) # Keep original order # 5. Assign indices to the shuffled batches self._epoch_batches = [] idx_pos = 0 scheduled_samples_count = 0 for seq_len, bs in shuffled_schedule: # Ensure we don't try to grab more indices than available for the rank actual_bs = min(bs, effective_samples_this_rank - idx_pos) if actual_bs <= 0: if scheduled_samples_count < effective_samples_this_rank: # This indicates mismatch between schedule total and actual samples warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. Check schedule generation.") break # Stop if no more indices or batch size is zero batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs] self._epoch_batches.append((seq_len, batch_indices)) idx_pos += actual_bs scheduled_samples_count += actual_bs # Final check if scheduled_samples_count != effective_samples_this_rank: warnings.warn( f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, " f"but expected {effective_samples_this_rank} effective samples this epoch. " f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}." ) def set_epoch(self, epoch: int): """Updates the epoch, regenerating the epoch-specific batches.""" # Only regenerate if the epoch actually changes if epoch != self.epoch: self.epoch = epoch self._prepare_epoch_batches(epoch) def __len__(self) -> int: """ Returns the number of batches per **worker** for the current epoch. Calculated based on the fixed number of batches per rank divided by the number of workers. """ return self._num_batches_per_rank def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]: """ Iterates through the pre-calculated batches for the current epoch, distributing them among workers. """ worker_info = torch.utils.data.get_worker_info() num_workers = worker_info.num_workers if worker_info else 1 worker_id = worker_info.id if worker_info else 0 # Distribute pre-calculated batches among workers for this rank # Each worker processes a slice of the batches prepared in _prepare_epoch_batches batches_for_worker = self._epoch_batches[worker_id::num_workers] for seq_len, indices in batches_for_worker: if not indices: # Skip if a batch ended up with no indices (shouldn't happen often) continue # Get the pre-initialized transform for this sequence length transform = self.transforms.get(seq_len) batch_samples = [] batch_imgs = [] batch_targets = [] for idx in indices: try: # Get original image and label from map-style dataset img, label = self.base_dataset[idx] # Apply transform if available # Handle cases where transform might return None or fail processed_img = transform(img) if transform else img if processed_img is None: warnings.warn(f"Transform returned None for index {idx}. Skipping sample.") continue batch_imgs.append(processed_img) batch_targets.append(label) except IndexError: warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.") continue except Exception as e: # Log other potential errors during data loading/processing warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.") continue # Skip problematic sample if self.mixup_fn is not None: batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets) batch_imgs = [self.patchifier(img) for img in batch_imgs] batch_samples = list(zip(batch_imgs, batch_targets)) if batch_samples: # Only yield if we successfully processed samples # Collate the processed samples into a batch yield self.collate_fns[seq_len](batch_samples) # If batch_samples is empty after processing 'indices', an empty batch is skipped.