Move NaFlexCollator with dataset, remove standalone collate_fn and remove redundancy

Ross Wightman 2025-04-29 10:44:46 -07:00
parent 39eb56f875
commit e2073e32d0
2 changed files with 61 additions and 128 deletions

timm/data/naflex_dataset.py

@@ -59,63 +59,65 @@ def calculate_batch_size(
     return batch_size
 
 
-def _collate_batch(
-        batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
-        target_seq_len: int
-) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-    """Collates processed samples into a batch, padding/truncating to target_seq_len."""
-    batch_patch_data = [item[0] for item in batch_samples]
-    batch_labels = [item[1] for item in batch_samples]
+class NaFlexCollator:
+    """Custom collator for batching NaFlex-style variable-resolution images."""
 
-    if not batch_patch_data:
-        return {}, torch.empty(0)
+    def __init__(
+            self,
+            max_seq_len=None,
+    ):
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
 
-    batch_size = len(batch_patch_data)
-    patch_dim = batch_patch_data[0]['patches'].shape[1]
+    def __call__(self, batch):
+        """
+        Args:
+            batch: List of tuples (patch_dict, target)
 
-    # Initialize tensors with target sequence length
-    patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
-    patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
-    patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool)  # Use bool
+        Returns:
+            A tuple of (input_dict, targets) where input_dict contains:
+                - patches: Padded tensor of patches
+                - patch_coord: Coordinates for each patch (y, x)
+                - patch_valid: Valid indicators
+        """
+        assert isinstance(batch[0], tuple)
+        batch_size = len(batch)
 
-    for i, data in enumerate(batch_patch_data):
-        num_patches = data['patches'].shape[0]
-        # Take min(num_patches, target_seq_len) patches
-        n_copy = min(num_patches, target_seq_len)
+        # Extract targets
+        # FIXME need to handle dense (float) targets or always done downstream of this?
+        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
 
-        patches_batch[i, :n_copy] = data['patches'][:n_copy]
-        patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
-        patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy]  # Copy validity flags
+        # Get patch dictionaries
+        patch_dicts = [item[0] for item in batch]
 
-    # Create the final input dict
-    input_dict = {
-        'patches': patches_batch,
-        'patch_coord': patch_coord_batch,
-        'patch_valid': patch_valid_batch,  # Boolean mask
-        # Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
-        # The actual number of valid patches per sample varies.
-        # 'patch_valid' mask is the most reliable source of truth.
-    }
-
-    # Attempt to stack labels if they are tensors, otherwise return list
-    try:
-        if isinstance(batch_labels[0], torch.Tensor):
-            labels_tensor = torch.stack(batch_labels)
+        # If we have a maximum sequence length constraint, ensure we don't exceed it
+        if self.max_seq_len is not None:
+            max_patches = self.max_seq_len
         else:
-            # Convert numerical types to tensor, keep others as list (or handle specific types)
-            if isinstance(batch_labels[0], (int, float)):
-                labels_tensor = torch.tensor(batch_labels)
-            else:
-                # Cannot convert non-numerical labels easily, return as list
-                # Or handle specific conversion if needed
-                # For FakeDataset, labels are ints, so this works
-                labels_tensor = torch.tensor(batch_labels)  # Assuming labels are numerical
-    except Exception:
-        # Fallback if stacking fails (e.g., different shapes, types)
-        print("Warning: Could not stack labels into a tensor. Returning list of labels.")
-        labels_tensor = batch_labels  # Return as list
+            # Find the maximum number of patches in this batch
+            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
 
-    return input_dict, labels_tensor
+        # Get patch dimensionality
+        patch_dim = patch_dicts[0]['patches'].shape[1]
+
+        # Prepare tensors for the batch
+        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
+        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
+
+        # Fill in the tensors
+        for i, patch_dict in enumerate(patch_dicts):
+            num_patches = min(patch_dict['patches'].shape[0], max_patches)
+            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
+            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
+            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
+
+        return {
+            'patches': patches,
+            'patch_coord': patch_coord,
+            'patch_valid': patch_valid,
+            'seq_len': max_patches,
+        }, targets
 
 
 class VariableSeqMapWrapper(IterableDataset):
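A quick sanity sketch of the relocated collator (illustrative only, not part of the diff; the import path and the patch_dim value are assumptions, with 768 = 16*16*3 for 16px RGB patches):

import torch
from timm.data.naflex_dataset import NaFlexCollator  # module path assumed per this commit

def fake_sample(num_patches, patch_dim=768):
    # Minimal stand-in for the patch dict produced by the dataset pipeline
    return {
        'patches': torch.randn(num_patches, patch_dim),
        'patch_coord': torch.zeros(num_patches, 2, dtype=torch.int64),
        'patch_valid': torch.ones(num_patches, dtype=torch.bool),
    }

collator = NaFlexCollator(max_seq_len=64)
inputs, targets = collator([(fake_sample(48), 0), (fake_sample(64), 1)])
assert inputs['patches'].shape == (2, 64, 768)  # padded to max_seq_len
assert inputs['patch_valid'][0].sum() == 48     # padding rows flagged invalid
assert targets.tolist() == [0, 1]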
@@ -161,15 +163,15 @@ class VariableSeqMapWrapper(IterableDataset):
         self.epoch = epoch
         self.batch_divisor = batch_divisor
 
-        # Pre-initialize transforms for each sequence length
+        # Pre-initialize transforms and collate fns for each sequence length
         self.transforms: Dict[int, Optional[Callable]] = {}
-        if transform_factory:
-            for seq_len in self.seq_lens:
+        self.collate_fns: Dict[int, Callable] = {}
+        for seq_len in self.seq_lens:
+            if transform_factory:
                 self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-        else:
-            for seq_len in self.seq_lens:
-                self.transforms[seq_len] = None  # No transform
+            else:
+                self.transforms[seq_len] = None  # No transform
+            self.collate_fns[seq_len] = NaFlexCollator(seq_len)
 
         self.patchifier = Patchify(self.patch_size)
 
         # --- Canonical Schedule Calculation (Done Once) ---
@@ -417,6 +419,6 @@ class VariableSeqMapWrapper(IterableDataset):
 
             # Collate the processed samples into a batch
             if batch_samples:  # Only yield if we successfully processed samples
-                yield _collate_batch(batch_samples, seq_len)
+                yield self.collate_fns[seq_len](batch_samples)
             # If batch_samples is empty after processing 'indices', an empty batch is skipped.
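Taken together, the two wrapper hunks give each scheduled sequence length its own collator, looked up at yield time. A standalone sketch of that pattern (the seq_lens values and sample sizes are illustrative, and fake_sample is the helper from the sketch above):

seq_lens = (256, 576, 1024)  # example schedule, not from the diff
collate_fns = {s: NaFlexCollator(s) for s in seq_lens}

# At iteration time, the wrapper picks the collator matching the batch's scheduled length
seq_len = 576
batch_samples = [(fake_sample(500), 3), (fake_sample(576), 7)]
inputs, targets = collate_fns[seq_len](batch_samples)
assert inputs['seq_len'] == 576
assert inputs['patches'].shape == (2, 576, 768)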

timm/data/naflex_loader.py

@@ -7,74 +7,10 @@
 import torch
 
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .loader import _worker_init
-from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
 from .transforms_factory import create_transform
 
 
-class NaFlexCollator:
-    """Custom collator for batching NaFlex-style variable-resolution images."""
-
-    def __init__(
-            self,
-            patch_size=16,
-            max_seq_len=None,
-    ):
-        self.patch_size = patch_size
-        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
-
-    def __call__(self, batch):
-        """
-        Args:
-            batch: List of tuples (patch_dict, target)
-
-        Returns:
-            A tuple of (input_dict, targets) where input_dict contains:
-                - patches: Padded tensor of patches
-                - patch_coord: Coordinates for each patch (y, x)
-                - patch_valid: Valid indicators
-        """
-        assert isinstance(batch[0], tuple)
-        batch_size = len(batch)
-
-        # Resize to final size based on seq_len and patchify
-
-        # Extract targets
-        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
-
-        # Get patch dictionaries
-        patch_dicts = [item[0] for item in batch]
-
-        # If we have a maximum sequence length constraint, ensure we don't exceed it
-        if self.max_seq_len is not None:
-            max_patches = self.max_seq_len
-        else:
-            # Find the maximum number of patches in this batch
-            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
-
-        # Get patch dimensionality
-        patch_dim = patch_dicts[0]['patches'].shape[1]
-
-        # Prepare tensors for the batch
-        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
-        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
-        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
-
-        # Fill in the tensors
-        for i, patch_dict in enumerate(patch_dicts):
-            num_patches = min(patch_dict['patches'].shape[0], max_patches)
-            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
-            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
-            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
-
-        return {
-            'patches': patches,
-            'patch_coord': patch_coord,
-            'patch_valid': patch_valid,
-            'seq_len': max_patches,
-        }, targets
-
-
 class NaFlexPrefetchLoader:
     """Data prefetcher for NaFlex format which normalizes patches."""
@@ -261,9 +197,7 @@ def create_naflex_loader(
 
         # NOTE: Collation is handled by the dataset wrapper for training
         # Create the collator (handles fixed-size collation)
         # collate_fn = NaFlexCollator(
-        #     patch_size=patch_size,
         #     max_seq_len=max(seq_lens) + 1,  # +1 for class token
-        #     use_prefetcher=use_prefetcher
         # )
 
         loader = torch.utils.data.DataLoader(
@@ -303,10 +237,7 @@ def create_naflex_loader(
         )
 
         # Create the collator
-        collate_fn = NaFlexCollator(
-            patch_size=patch_size,
-            max_seq_len=max_seq_len,
-        )
+        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)
 
         # Handle distributed training
         sampler = None
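On the validation path above, the collator now takes only max_seq_len and is handed straight to the DataLoader. A minimal end-to-end sketch (the dataset class is a hypothetical stand-in for the real transform + Patchify pipeline, and fake_sample is the helper from the first sketch):

from torch.utils.data import DataLoader, Dataset

class PatchDictDataset(Dataset):  # hypothetical stand-in yielding (patch_dict, target)
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return fake_sample(400 + idx * 16), idx % 2  # variable patch counts per image

loader = DataLoader(
    PatchDictDataset(),
    batch_size=4,
    collate_fn=NaFlexCollator(max_seq_len=576),  # patch_size arg removed by this commit
)
for inputs, targets in loader:
    assert inputs['patches'].shape == (4, 576, 768)  # every batch padded to 576
    assert targets.shape == (4,)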