"""
Dynamic Sequence Length Datasets for Variable Resolution Image Processing
Implements two dataset wrappers :
1. DynamicSeqMapDataset - Map - style dataset that returns batches with variable sequence lengths
2. DynamicSeqIterDataset - Iterable dataset that yields batches with variable sequence lengths
Both support :
- Pre - initialized transforms for efficiency
- Distributed training
- Multiple workers
- Variable batch sizes based on sequence length
"""
import math
import random
import warnings
from functools import partial
from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable

import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
from torchvision import transforms
from PIL import Image

from .naflex_transforms import Patchify, patchify
def calculate_batch_size(
        tokens_per_batch: int,
        seq_len: int,
        max_size: Optional[int] = None,
        divisor: int = 1,
        rounding: str = 'floor',
):
    """Calculate batch size based on sequence length with divisibility constraints."""
    # Calculate raw batch size based on sequence length
    raw_batch_size = tokens_per_batch / seq_len

    # Apply divisibility with specified rounding method
    if divisor > 1:
        if rounding == 'floor':
            batch_size = math.floor(raw_batch_size / divisor) * divisor
        elif rounding == 'ceil':
            batch_size = math.ceil(raw_batch_size / divisor) * divisor
        else:  # 'round' is the default
            batch_size = round(raw_batch_size / divisor) * divisor
    else:
        # If no divisor specified, just use integer division
        batch_size = int(raw_batch_size)

    # Ensure batch size is valid
    batch_size = max(1, batch_size)  # At least 1
    if max_size is not None:
        batch_size = min(batch_size, max_size)

    return batch_size
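
# Illustrative arithmetic only (not part of the original module): with a token
# budget of 16384, seq_len=576, and divisor=8, the raw batch size is
# 16384 / 576 ≈ 28.4, which becomes floor(28.4 / 8) * 8 = 24 after enforcing
# divisibility by 8.
#
#   calculate_batch_size(tokens_per_batch=16384, seq_len=576, divisor=8, rounding='floor')  # -> 24
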
class NaFlexCollator:
    """Custom collator for batching NaFlex-style variable-resolution images."""

    def __init__(
            self,
            max_seq_len=None,
    ):
        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (576 = 24 * 24 patches)
    def __call__(self, batch):
        """
        Args:
            batch: List of tuples (patch_dict, target)

        Returns:
            A tuple of (input_dict, targets) where input_dict contains:
                - patches: Padded tensor of patches
                - patch_coord: Coordinates for each patch (y, x)
                - patch_valid: Valid indicators
        """
        assert isinstance(batch[0], tuple)
        batch_size = len(batch)

        # Extract targets
        targets = [item[1] for item in batch]
        if isinstance(targets[0], torch.Tensor):
            targets = torch.stack(targets)
        else:
            targets = torch.tensor(targets, dtype=torch.int64)

        # Get patch dictionaries
        patch_dicts = [item[0] for item in batch]

        # If we have a maximum sequence length constraint, ensure we don't exceed it
        if self.max_seq_len is not None:
            max_patches = self.max_seq_len
        else:
            # Find the maximum number of patches in this batch
            max_patches = max(item['patches'].shape[0] for item in patch_dicts)

        # Get patch dimensionality
        patch_dim = patch_dicts[0]['patches'].shape[1]

        # Prepare tensors for the batch
        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)

        # Fill in the tensors
        for i, patch_dict in enumerate(patch_dicts):
            num_patches = min(patch_dict['patches'].shape[0], max_patches)
            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]

        return {
            'patches': patches,
            'patch_coord': patch_coord,
            'patch_valid': patch_valid,
            'seq_len': max_patches,
        }, targets
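
# Illustrative sketch (not part of the original module): with 16x16 RGB patches
# the per-patch dimension is 16 * 16 * 3 = 768, so collating four samples of 196
# patches each under max_seq_len=256 pads every sample out to 256 positions.
#
#   collator = NaFlexCollator(max_seq_len=256)
#   sample = {'patches': torch.zeros(196, 768),
#             'patch_coord': torch.zeros(196, 2, dtype=torch.int64),
#             'patch_valid': torch.ones(196, dtype=torch.bool)}
#   inputs, targets = collator([(sample, 0)] * 4)
#   inputs['patches'].shape      # torch.Size([4, 256, 768])
#   inputs['patch_valid'].sum()  # 4 * 196 valid positions, rest are padding
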
class VariableSeqMapWrapper(IterableDataset):
    """
    IterableDataset wrapper for a map-style base dataset.

    Yields batches with variable sequence lengths. It calculates a canonical
    batch schedule (sequence length, batch size pairs) once based on the
    total dataset size (padded for distribution). Each epoch, it shuffles
    the *order* of this canonical schedule and the dataset indices.

    This ensures a consistent number of batches and samples per epoch
    across all ranks. Handles distributed training and multiple workers.
    """
    def __init__(
            self,
            base_dataset: Dataset,
            patch_size: Union[int, Tuple[int, int]] = 16,
            seq_lens: List[int] = (128, 256, 576, 784, 1024),
            max_tokens_per_batch: int = 4096 * 4,  # Example: 16k tokens
            transform_factory: Optional[Callable] = None,
            mixup_fn: Optional[Callable] = None,
            seed: int = 42,
            shuffle: bool = True,
            distributed: bool = False,
            rank: int = 0,
            world_size: int = 1,
            epoch: int = 0,
            batch_divisor: int = 8,  # Ensure batch size is divisible by this
    ):
        super().__init__()
        if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'):
            raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)")

        self.base_dataset = base_dataset
        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
        self.seq_lens = sorted(list(set(seq_lens)))  # Ensure unique and sorted
        self.max_tokens_per_batch = max_tokens_per_batch
        self.seed = seed
        self.shuffle = shuffle
        self.distributed = distributed
        self.rank = rank if distributed else 0
        self.world_size = world_size if distributed else 1
        self.epoch = epoch
        self.batch_divisor = batch_divisor

        # Pre-initialize transforms and collate fns for each sequence length
        self.transforms: Dict[int, Optional[Callable]] = {}
        self.collate_fns: Dict[int, Callable] = {}
        for seq_len in self.seq_lens:
            if transform_factory:
                self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
            else:
                self.transforms[seq_len] = None  # No transform
            self.collate_fns[seq_len] = NaFlexCollator(seq_len)
        self.mixup_fn = mixup_fn
        self.patchifier = Patchify(self.patch_size)

        # --- Canonical Schedule Calculation (Done Once) ---
        self._canonical_batch_schedule: List[Tuple[int, int]] = []
        self._num_batches_per_rank: int = 0
        self._padded_samples_per_rank: int = 0
        self._create_canonical_schedule()  # Calculate schedule based on padded size

        # --- Per-Epoch State ---
        # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank
        self._epoch_batches: List[Tuple[int, List[int]]] = []
        self._prepare_epoch_batches(self.epoch)  # setup for initial epoch
    def _create_canonical_schedule(self):
        """
        Calculates the canonical batch schedule (seq_len, batch_size pairs)
        based on the dataset size, padded for distributed training.

        This schedule is the *same* for all ranks and ensures consistent
        epoch length. It is calculated once during initialization.
        """
        total_len = len(self.base_dataset)
        padded_total_len = total_len
        num_samples_per_rank = total_len

        if self.distributed and self.world_size > 1:
            # Calculate padding needed for even distribution
            if total_len % self.world_size != 0:
                pad_size = self.world_size - (total_len % self.world_size)
                padded_total_len += pad_size
                print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).")
            else:
                pad_size = 0

            if padded_total_len % self.world_size != 0:
                # This should not happen with the padding logic, but safeguard
                raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}")

            num_samples_per_rank = padded_total_len // self.world_size
        elif self.distributed and self.world_size <= 1:
            # Distributed flag set but world_size is 1, treat as non-distributed
            pass  # num_samples_per_rank remains total_len

        self._padded_samples_per_rank = num_samples_per_rank

        if num_samples_per_rank == 0:
            self._canonical_batch_schedule = []
            self._num_batches_per_rank = 0
            return

        # Use a fixed seed for generating the canonical schedule structure
        g = torch.Generator()
        g.manual_seed(self.seed)  # Use base seed, NOT epoch seed

        current_schedule: List[Tuple[int, int]] = []
        remaining_samples = num_samples_per_rank
        total_scheduled_samples = 0

        while remaining_samples > 0:
            # Sample sequence length deterministically based on base seed
            seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item()
            seq_len = self.seq_lens[seq_idx]

            # Calculate batch size
            batch_size = calculate_batch_size(
                tokens_per_batch=self.max_tokens_per_batch,
                seq_len=seq_len,
                # max_size should be remaining_samples to avoid overshooting
                max_size=remaining_samples,
                divisor=self.batch_divisor,
                rounding='floor',
            )

            # Ensure batch size is positive and doesn't exceed remaining samples
            batch_size = max(1, batch_size)
            batch_size = min(batch_size, remaining_samples)

            if batch_size <= 0:
                warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.")
                break  # Avoid infinite loop if something goes wrong

            current_schedule.append((seq_len, batch_size))
            remaining_samples -= batch_size
            total_scheduled_samples += batch_size

        # Sanity check: Ensure the schedule covers all samples for the rank
        if total_scheduled_samples != num_samples_per_rank:
            warnings.warn(
                f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, "
                f"but expected {num_samples_per_rank} samples per rank. "
                f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. "
                f"Check parameters. Remaining samples: {remaining_samples}"
            )
            # Adjust if needed? Could add a final small batch, but might violate constraints.
            # Current behavior: some samples might be dropped if schedule logic fails.

        self._canonical_batch_schedule = current_schedule
        self._num_batches_per_rank = len(current_schedule)
        print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.")
    def _prepare_epoch_batches(self, epoch: int):
        """
        Prepares the batches for the current epoch by:
        1. Shuffling the full dataset indices (using epoch seed).
        2. Applying padding if in distributed mode.
        3. Selecting indices for the current rank.
        4. Shuffling the *order* of the canonical batch schedule (using epoch seed).
        5. Assigning the rank's indices to the shuffled batches.
        """
        g = torch.Generator()
        g.manual_seed(self.seed + epoch)  # Epoch-specific seed for shuffling

        # 1. Get shuffled global indices
        total_len = len(self.base_dataset)
        if self.shuffle:
            all_indices_shuffled = torch.randperm(total_len, generator=g).tolist()
        else:
            all_indices_shuffled = list(range(total_len))

        # 2. Apply padding for distributed mode
        indices_for_ranks = all_indices_shuffled
        if self.distributed and self.world_size > 1:
            padded_total_len = self._padded_samples_per_rank * self.world_size
            if padded_total_len > total_len:
                pad_size = padded_total_len - total_len
                # Repeat initial elements from the *shuffled* list for padding
                indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size]
            # Ensure length matches expectation
            if len(indices_for_ranks) != padded_total_len:
                raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}")

        # 3. Select indices for the current rank
        if self.distributed and self.world_size > 1:
            indices_this_rank = indices_for_ranks[self.rank::self.world_size]
        else:  # Non-distributed or world_size=1
            indices_this_rank = indices_for_ranks

        # Sanity check length
        if len(indices_this_rank) != self._padded_samples_per_rank:
            # This might happen if canonical schedule generation had warnings/issues
            warnings.warn(
                f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) "
                f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). "
                f"Epoch generation might be inconsistent."
            )
            # Adjust expected samples? Or truncate/pad indices? Let's proceed but warn.
            # Using min() prevents IndexError later if indices are fewer than expected.
            effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank)
            indices_this_rank = indices_this_rank[:effective_samples_this_rank]
        else:
            effective_samples_this_rank = self._padded_samples_per_rank

        # 4. Shuffle the order of the canonical batch schedule for this epoch
        if self.shuffle:
            schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist()
            shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm]
        else:
            shuffled_schedule = list(self._canonical_batch_schedule)  # Keep original order

        # 5. Assign indices to the shuffled batches
        self._epoch_batches = []
        idx_pos = 0
        scheduled_samples_count = 0
        for seq_len, bs in shuffled_schedule:
            # Ensure we don't try to grab more indices than available for the rank
            actual_bs = min(bs, effective_samples_this_rank - idx_pos)
            if actual_bs <= 0:
                if scheduled_samples_count < effective_samples_this_rank:
                    # This indicates mismatch between schedule total and actual samples
                    warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. Check schedule generation.")
                break  # Stop if no more indices or batch size is zero

            batch_indices = indices_this_rank[idx_pos:idx_pos + actual_bs]
            self._epoch_batches.append((seq_len, batch_indices))
            idx_pos += actual_bs
            scheduled_samples_count += actual_bs

        # Final check
        if scheduled_samples_count != effective_samples_this_rank:
            warnings.warn(
                f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, "
                f"but expected {effective_samples_this_rank} effective samples this epoch. "
                f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
            )
    def set_epoch(self, epoch: int):
        """Updates the epoch, regenerating the epoch-specific batches."""
        # Only regenerate if the epoch actually changes
        if epoch != self.epoch:
            self.epoch = epoch
            self._prepare_epoch_batches(epoch)
    def __len__(self) -> int:
        """
        Returns the number of batches per rank for the current epoch.

        This is the rank-level count; when multiple dataloader workers are
        used, each worker iterates over a subset of these batches.
        """
        return self._num_batches_per_rank
    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        """
        Iterates through the pre-calculated batches for the current epoch,
        distributing them among workers.
        """
        worker_info = torch.utils.data.get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        worker_id = worker_info.id if worker_info else 0

        # Distribute pre-calculated batches among workers for this rank
        # Each worker processes a slice of the batches prepared in _prepare_epoch_batches
        batches_for_worker = self._epoch_batches[worker_id::num_workers]

        for seq_len, indices in batches_for_worker:
            if not indices:  # Skip if a batch ended up with no indices (shouldn't happen often)
                continue

            # Get the pre-initialized transform for this sequence length
            transform = self.transforms.get(seq_len)

            batch_imgs = []
            batch_targets = []
            for idx in indices:
                try:
                    # Get original image and label from map-style dataset
                    img, label = self.base_dataset[idx]

                    # Apply transform if available
                    # Handle cases where transform might return None or fail
                    processed_img = transform(img) if transform else img
                    if processed_img is None:
                        warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
                        continue

                    batch_imgs.append(processed_img)
                    batch_targets.append(label)
                except IndexError:
                    warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
                    continue
                except Exception as e:
                    # Log other potential errors during data loading/processing
                    warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
                    continue  # Skip problematic sample

            if self.mixup_fn is not None:
                batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)

            batch_imgs = [self.patchifier(img) for img in batch_imgs]
            batch_samples = list(zip(batch_imgs, batch_targets))
            if batch_samples:  # Only yield if we successfully processed samples
                # Collate the processed samples into a batch
                yield self.collate_fns[seq_len](batch_samples)
            # If batch_samples is empty after processing 'indices', an empty batch is skipped.
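
# Minimal usage sketch (illustrative, not part of the original module). Names
# such as `base_dataset`, `my_naflex_transform_factory`, `rank`, `world_size`
# and `num_epochs` are placeholders. The wrapper yields fully collated
# (input_dict, targets) batches, so the DataLoader is created with
# batch_size=None; set_epoch() reshuffles indices and the schedule order.
#
#   wrapper = VariableSeqMapWrapper(
#       base_dataset,                      # any map-style (image, label) dataset
#       patch_size=16,
#       seq_lens=(128, 256, 576, 784, 1024),
#       max_tokens_per_batch=4096 * 4,
#       transform_factory=my_naflex_transform_factory,  # placeholder factory
#       distributed=True, rank=rank, world_size=world_size,
#   )
#   loader = DataLoader(wrapper, batch_size=None, num_workers=4)
#   for epoch in range(num_epochs):
#       wrapper.set_epoch(epoch)
#       for input_dict, targets in loader:
#           ...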