From e2073e32d0782465606f6b6be5f516df881cdf61 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 29 Apr 2025 10:44:46 -0700
Subject: [PATCH] Move NaFlexCollator to dataset module, remove standalone
 collate_fn and remove redundancy

---
 timm/data/naflex_dataset.py | 116 ++++++++++++++++++------------------
 timm/data/naflex_loader.py  |  73 +----------------------
 2 files changed, 61 insertions(+), 128 deletions(-)

diff --git a/timm/data/naflex_dataset.py b/timm/data/naflex_dataset.py
index 201d0325..858a182f 100644
--- a/timm/data/naflex_dataset.py
+++ b/timm/data/naflex_dataset.py
@@ -59,63 +59,65 @@ def calculate_batch_size(
     return batch_size


-def _collate_batch(
-        batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
-        target_seq_len: int
-    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-    """Collates processed samples into a batch, padding/truncating to target_seq_len."""
-    batch_patch_data = [item[0] for item in batch_samples]
-    batch_labels = [item[1] for item in batch_samples]
+class NaFlexCollator:
+    """Custom collator for batching NaFlex-style variable-resolution images."""

-    if not batch_patch_data:
-        return {}, torch.empty(0)
+    def __init__(
+            self,
+            max_seq_len=None,
+    ):
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (576 = 24*24)

-    batch_size = len(batch_patch_data)
-    patch_dim = batch_patch_data[0]['patches'].shape[1]
+    def __call__(self, batch):
+        """
+        Args:
+            batch: List of tuples (patch_dict, target)

-    # Initialize tensors with target sequence length
-    patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
-    patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
-    patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool)  # Use bool
+        Returns:
+            A tuple of (input_dict, targets) where input_dict contains:
+                - patches: Padded tensor of patches
+                - patch_coord: Coordinates for each patch (y, x)
+                - patch_valid: Valid indicators
+        """
+        assert isinstance(batch[0], tuple)
+        batch_size = len(batch)

-    for i, data in enumerate(batch_patch_data):
-        num_patches = data['patches'].shape[0]
-        # Take min(num_patches, target_seq_len) patches
-        n_copy = min(num_patches, target_seq_len)
+        # Extract targets
+        # FIXME need to handle dense (float) targets or always done downstream of this?
+        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)

-        patches_batch[i, :n_copy] = data['patches'][:n_copy]
-        patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
-        patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy]  # Copy validity flags
+        # Get patch dictionaries
+        patch_dicts = [item[0] for item in batch]

-    # Create the final input dict
-    input_dict = {
-        'patches': patches_batch,
-        'patch_coord': patch_coord_batch,
-        'patch_valid': patch_valid_batch,  # Boolean mask
-        # Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
-        # The actual number of valid patches per sample varies.
-        # 'patch_valid' mask is the most reliable source of truth.
-    }
-
-    # Attempt to stack labels if they are tensors, otherwise return list
-    try:
-        if isinstance(batch_labels[0], torch.Tensor):
-            labels_tensor = torch.stack(batch_labels)
+        # If we have a maximum sequence length constraint, ensure we don't exceed it
+        if self.max_seq_len is not None:
+            max_patches = self.max_seq_len
         else:
-            # Convert numerical types to tensor, keep others as list (or handle specific types)
-            if isinstance(batch_labels[0], (int, float)):
-                labels_tensor = torch.tensor(batch_labels)
-            else:
-                # Cannot convert non-numerical labels easily, return as list
-                # Or handle specific conversion if needed
-                # For FakeDataset, labels are ints, so this works
-                labels_tensor = torch.tensor(batch_labels)  # Assuming labels are numerical
-    except Exception:
-        # Fallback if stacking fails (e.g., different shapes, types)
-        print("Warning: Could not stack labels into a tensor. Returning list of labels.")
-        labels_tensor = batch_labels  # Return as list
+            # Find the maximum number of patches in this batch
+            max_patches = max(item['patches'].shape[0] for item in patch_dicts)

-    return input_dict, labels_tensor
+        # Get patch dimensionality
+        patch_dim = patch_dicts[0]['patches'].shape[1]
+
+        # Prepare tensors for the batch
+        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
+        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
+
+        # Fill in the tensors
+        for i, patch_dict in enumerate(patch_dicts):
+            num_patches = min(patch_dict['patches'].shape[0], max_patches)
+
+            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
+            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
+            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
+
+        return {
+            'patches': patches,
+            'patch_coord': patch_coord,
+            'patch_valid': patch_valid,
+            'seq_len': max_patches,
+        }, targets


 class VariableSeqMapWrapper(IterableDataset):
@@ -161,15 +163,15 @@ class VariableSeqMapWrapper(IterableDataset):
         self.epoch = epoch
         self.batch_divisor = batch_divisor

-        # Pre-initialize transforms for each sequence length
+        # Pre-initialize transforms and collate fns for each sequence length
         self.transforms: Dict[int, Optional[Callable]] = {}
-        if transform_factory:
-            for seq_len in self.seq_lens:
+        self.collate_fns: Dict[int, Callable] = {}
+        for seq_len in self.seq_lens:
+            if transform_factory:
                 self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-        else:
-            for seq_len in self.seq_lens:
-                self.transforms[seq_len] = None  # No transform
-
+            else:
+                self.transforms[seq_len] = None  # No transform
+            self.collate_fns[seq_len] = NaFlexCollator(seq_len)
         self.patchifier = Patchify(self.patch_size)

     # --- Canonical Schedule Calculation (Done Once) ---
@@ -417,6 +419,6 @@ class VariableSeqMapWrapper(IterableDataset):

             # Collate the processed samples into a batch
             if batch_samples:  # Only yield if we successfully processed samples
-                yield _collate_batch(batch_samples, seq_len)
+                yield self.collate_fns[seq_len](batch_samples)
             # If batch_samples is empty after processing 'indices', an empty batch is skipped.
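Quick usage sketch (illustrative, not part of the patch): the contract the relocated NaFlexCollator is expected to satisfy per the code added above. The grid sizes, patch_dim, and targets below are made up for illustration.

    import torch
    from timm.data.naflex_dataset import NaFlexCollator

    patch_dim = 16 * 16 * 3  # flattened 16x16 RGB patch
    samples = [
        ({'patches': torch.randn(196, patch_dim),  # e.g. a 14x14 patch grid
          'patch_coord': torch.zeros(196, 2, dtype=torch.int64),
          'patch_valid': torch.ones(196, dtype=torch.bool)}, 0),
        ({'patches': torch.randn(144, patch_dim),  # e.g. a 12x12 patch grid
          'patch_coord': torch.zeros(144, 2, dtype=torch.int64),
          'patch_valid': torch.ones(144, dtype=torch.bool)}, 1),
    ]

    # Same call pattern VariableSeqMapWrapper now uses per sequence length: NaFlexCollator(seq_len)
    collate = NaFlexCollator(256)
    input_dict, targets = collate(samples)

    # Every sample is padded (or truncated) to max_seq_len; padding rows stay invalid
    assert input_dict['patches'].shape == (2, 256, patch_dim)
    assert input_dict['patch_valid'][1].sum().item() == 144
    assert input_dict['seq_len'] == 256
    assert targets.tolist() == [0, 1]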
diff --git a/timm/data/naflex_loader.py b/timm/data/naflex_loader.py
index 917dfa0b..bb96d07d 100644
--- a/timm/data/naflex_loader.py
+++ b/timm/data/naflex_loader.py
@@ -7,74 +7,10 @@
 import torch

 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .loader import _worker_init
-from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
 from .transforms_factory import create_transform


-class NaFlexCollator:
-    """Custom collator for batching NaFlex-style variable-resolution images."""
-
-    def __init__(
-            self,
-            patch_size=16,
-            max_seq_len=None,
-    ):
-        self.patch_size = patch_size
-        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
-
-    def __call__(self, batch):
-        """
-        Args:
-            batch: List of tuples (patch_dict, target)
-
-        Returns:
-            A tuple of (input_dict, targets) where input_dict contains:
-                - patches: Padded tensor of patches
-                - patch_coord: Coordinates for each patch (y, x)
-                - patch_valid: Valid indicators
-        """
-        assert isinstance(batch[0], tuple)
-        batch_size = len(batch)
-
-        # Resize to final size based on seq_len and patchify
-
-        # Extract targets
-        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
-
-        # Get patch dictionaries
-        patch_dicts = [item[0] for item in batch]
-
-        # If we have a maximum sequence length constraint, ensure we don't exceed it
-        if self.max_seq_len is not None:
-            max_patches = self.max_seq_len
-        else:
-            # Find the maximum number of patches in this batch
-            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
-
-        # Get patch dimensionality
-        patch_dim = patch_dicts[0]['patches'].shape[1]
-
-        # Prepare tensors for the batch
-        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
-        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
-        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
-
-        # Fill in the tensors
-        for i, patch_dict in enumerate(patch_dicts):
-            num_patches = min(patch_dict['patches'].shape[0], max_patches)
-
-            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
-            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
-            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
-
-        return {
-            'patches': patches,
-            'patch_coord': patch_coord,
-            'patch_valid': patch_valid,
-            'seq_len': max_patches,
-        }, targets
-
-
 class NaFlexPrefetchLoader:
     """Data prefetcher for NaFlex format which normalizes patches."""

@@ -261,9 +197,7 @@ def create_naflex_loader(
         # NOTE: Collation is handled by the dataset wrapper for training
         # Create the collator (handles fixed-size collation)
         # collate_fn = NaFlexCollator(
-        #     patch_size=patch_size,
         #     max_seq_len=max(seq_lens) + 1,  # +1 for class token
-        #     use_prefetcher=use_prefetcher
         # )

         loader = torch.utils.data.DataLoader(
@@ -303,10 +237,7 @@ def create_naflex_loader(
         )

         # Create the collator
-        collate_fn = NaFlexCollator(
-            patch_size=patch_size,
-            max_seq_len=max_seq_len,
-        )
+        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)

         # Handle distributed training
         sampler = None
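Sketch of the eval-path wiring after this change (illustrative, not from the patch): the collator is now built with only max_seq_len and handed to a standard DataLoader, mirroring collate_fn = NaFlexCollator(max_seq_len=max_seq_len) above. ToyPatchDataset is a made-up stand-in for the real patchified eval dataset.

    import torch
    from torch.utils.data import Dataset, DataLoader
    from timm.data.naflex_dataset import NaFlexCollator

    class ToyPatchDataset(Dataset):
        """Made-up dataset yielding (patch_dict, target) tuples like the real pipeline."""
        def __init__(self, num_samples=8, num_patches=144, patch_dim=768):
            self.items = [
                ({'patches': torch.randn(num_patches, patch_dim),
                  'patch_coord': torch.zeros(num_patches, 2, dtype=torch.int64),
                  'patch_valid': torch.ones(num_patches, dtype=torch.bool)}, i % 2)
                for i in range(num_samples)
            ]

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    collate_fn = NaFlexCollator(max_seq_len=576)  # fixed eval sequence length, patch_size arg no longer needed
    loader = DataLoader(ToyPatchDataset(), batch_size=4, collate_fn=collate_fn)

    for input_dict, targets in loader:
        # [4, 576, 768] patches per batch; 'patch_valid' marks real vs. padded rows
        print(input_dict['patches'].shape, targets.shape)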