Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Move NaFlexCollator with dataset, remove standalone collate_fn and remove redundancy
parent 39eb56f875
commit e2073e32d0
@@ -59,63 +59,65 @@ def calculate_batch_size(
     return batch_size
 
 
-def _collate_batch(
-        batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
-        target_seq_len: int
-) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-    """Collates processed samples into a batch, padding/truncating to target_seq_len."""
-    batch_patch_data = [item[0] for item in batch_samples]
-    batch_labels = [item[1] for item in batch_samples]
-
-    if not batch_patch_data:
-        return {}, torch.empty(0)
-
-    batch_size = len(batch_patch_data)
-    patch_dim = batch_patch_data[0]['patches'].shape[1]
-
-    # Initialize tensors with target sequence length
-    patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
-    patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
-    patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool)  # Use bool
-
-    for i, data in enumerate(batch_patch_data):
-        num_patches = data['patches'].shape[0]
-        # Take min(num_patches, target_seq_len) patches
-        n_copy = min(num_patches, target_seq_len)
-
-        patches_batch[i, :n_copy] = data['patches'][:n_copy]
-        patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
-        patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy]  # Copy validity flags
-
-    # Create the final input dict
-    input_dict = {
-        'patches': patches_batch,
-        'patch_coord': patch_coord_batch,
-        'patch_valid': patch_valid_batch,  # Boolean mask
-        # Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
-        # The actual number of valid patches per sample varies.
-        # 'patch_valid' mask is the most reliable source of truth.
-    }
-
-    # Attempt to stack labels if they are tensors, otherwise return list
-    try:
-        if isinstance(batch_labels[0], torch.Tensor):
-            labels_tensor = torch.stack(batch_labels)
-        else:
-            # Convert numerical types to tensor, keep others as list (or handle specific types)
-            if isinstance(batch_labels[0], (int, float)):
-                labels_tensor = torch.tensor(batch_labels)
-            else:
-                # Cannot convert non-numerical labels easily, return as list
-                # Or handle specific conversion if needed
-                # For FakeDataset, labels are ints, so this works
-                labels_tensor = torch.tensor(batch_labels)  # Assuming labels are numerical
-    except Exception:
-        # Fallback if stacking fails (e.g., different shapes, types)
-        print("Warning: Could not stack labels into a tensor. Returning list of labels.")
-        labels_tensor = batch_labels  # Return as list
-
-    return input_dict, labels_tensor
+class NaFlexCollator:
+    """Custom collator for batching NaFlex-style variable-resolution images."""
+
+    def __init__(
+            self,
+            max_seq_len=None,
+    ):
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
+
+    def __call__(self, batch):
+        """
+        Args:
+            batch: List of tuples (patch_dict, target)
+
+        Returns:
+            A tuple of (input_dict, targets) where input_dict contains:
+            - patches: Padded tensor of patches
+            - patch_coord: Coordinates for each patch (y, x)
+            - patch_valid: Valid indicators
+        """
+        assert isinstance(batch[0], tuple)
+        batch_size = len(batch)
+
+        # Extract targets
+        # FIXME need to handle dense (float) targets or always done downstream of this?
+        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
+
+        # Get patch dictionaries
+        patch_dicts = [item[0] for item in batch]
+
+        # If we have a maximum sequence length constraint, ensure we don't exceed it
+        if self.max_seq_len is not None:
+            max_patches = self.max_seq_len
+        else:
+            # Find the maximum number of patches in this batch
+            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
+
+        # Get patch dimensionality
+        patch_dim = patch_dicts[0]['patches'].shape[1]
+
+        # Prepare tensors for the batch
+        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
+        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
+
+        # Fill in the tensors
+        for i, patch_dict in enumerate(patch_dicts):
+            num_patches = min(patch_dict['patches'].shape[0], max_patches)
+
+            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
+            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
+            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
+
+        return {
+            'patches': patches,
+            'patch_coord': patch_coord,
+            'patch_valid': patch_valid,
+            'seq_len': max_patches,
+        }, targets
 
 
 class VariableSeqMapWrapper(IterableDataset):
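
For reference, a minimal sketch of how the moved NaFlexCollator behaves: it pads a batch of variable-length patch sequences up to max_seq_len. The toy patch counts, the 768-dim flattened patches, and the import path are illustrative assumptions only (the path reflects this commit's layout and may differ in other timm versions).

    import torch
    # Import path as of this commit's layout; it may differ in later timm versions.
    from timm.data.naflex_dataset import NaFlexCollator

    def make_sample(n_patches, grid_w=8, patch_dim=768, label=0):
        # One NaFlex-style sample: per-patch pixels, (y, x) grid coords, validity mask.
        coords = torch.tensor([(i // grid_w, i % grid_w) for i in range(n_patches)], dtype=torch.int64)
        return (
            {
                'patches': torch.randn(n_patches, patch_dim),
                'patch_coord': coords,
                'patch_valid': torch.ones(n_patches, dtype=torch.bool),
            },
            label,
        )

    collator = NaFlexCollator(max_seq_len=64)
    input_dict, targets = collator([make_sample(48, label=1), make_sample(64, label=7)])
    print(input_dict['patches'].shape)        # torch.Size([2, 64, 768])
    print(input_dict['patch_valid'].sum(-1))  # tensor([48, 64]); rows past 48 are zero padding
    print(targets)                            # tensor([1, 7])
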
@@ -161,15 +163,15 @@ class VariableSeqMapWrapper(IterableDataset):
         self.epoch = epoch
         self.batch_divisor = batch_divisor
 
-        # Pre-initialize transforms for each sequence length
+        # Pre-initialize transforms and collate fns for each sequence length
         self.transforms: Dict[int, Optional[Callable]] = {}
-        if transform_factory:
-            for seq_len in self.seq_lens:
+        self.collate_fns: Dict[int, Callable] = {}
+        for seq_len in self.seq_lens:
+            if transform_factory:
                 self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-        else:
-            for seq_len in self.seq_lens:
-                self.transforms[seq_len] = None  # No transform
-
+            else:
+                self.transforms[seq_len] = None  # No transform
+            self.collate_fns[seq_len] = NaFlexCollator(seq_len)
         self.patchifier = Patchify(self.patch_size)
 
         # --- Canonical Schedule Calculation (Done Once) ---
@@ -417,6 +419,6 @@ class VariableSeqMapWrapper(IterableDataset):
 
             # Collate the processed samples into a batch
             if batch_samples:  # Only yield if we successfully processed samples
-                yield _collate_batch(batch_samples, seq_len)
+                yield self.collate_fns[seq_len](batch_samples)
 
             # If batch_samples is empty after processing 'indices', an empty batch is skipped.
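
The two hunks above amount to the pattern sketched below: one NaFlexCollator is built per training sequence length up front, then the matching one is looked up when a batch for that sequence length is yielded. The class and the seq_lens values here are illustrative stand-ins for the relevant slice of VariableSeqMapWrapper, not the full implementation.

    from typing import Callable, Dict
    # Import path as of this commit's layout; it may differ in later timm versions.
    from timm.data.naflex_dataset import NaFlexCollator

    class PerSeqLenCollation:
        def __init__(self, seq_lens=(128, 256, 576)):
            # Built once in __init__, mirroring self.collate_fns in the wrapper.
            self.collate_fns: Dict[int, Callable] = {
                seq_len: NaFlexCollator(seq_len) for seq_len in seq_lens
            }

        def collate(self, batch_samples, seq_len):
            # Same call shape as `yield self.collate_fns[seq_len](batch_samples)` in __iter__.
            return self.collate_fns[seq_len](batch_samples)
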
@@ -7,74 +7,10 @@ import torch
 
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .loader import _worker_init
-from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
 from .transforms_factory import create_transform
 
 
-class NaFlexCollator:
-    """Custom collator for batching NaFlex-style variable-resolution images."""
-
-    def __init__(
-            self,
-            patch_size=16,
-            max_seq_len=None,
-    ):
-        self.patch_size = patch_size
-        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
-
-    def __call__(self, batch):
-        """
-        Args:
-            batch: List of tuples (patch_dict, target)
-
-        Returns:
-            A tuple of (input_dict, targets) where input_dict contains:
-            - patches: Padded tensor of patches
-            - patch_coord: Coordinates for each patch (y, x)
-            - patch_valid: Valid indicators
-        """
-        assert isinstance(batch[0], tuple)
-        batch_size = len(batch)
-
-        # Resize to final size based on seq_len and patchify
-
-        # Extract targets
-        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
-
-        # Get patch dictionaries
-        patch_dicts = [item[0] for item in batch]
-
-        # If we have a maximum sequence length constraint, ensure we don't exceed it
-        if self.max_seq_len is not None:
-            max_patches = self.max_seq_len
-        else:
-            # Find the maximum number of patches in this batch
-            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
-
-        # Get patch dimensionality
-        patch_dim = patch_dicts[0]['patches'].shape[1]
-
-        # Prepare tensors for the batch
-        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
-        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
-        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
-
-        # Fill in the tensors
-        for i, patch_dict in enumerate(patch_dicts):
-            num_patches = min(patch_dict['patches'].shape[0], max_patches)
-
-            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
-            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
-            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
-
-        return {
-            'patches': patches,
-            'patch_coord': patch_coord,
-            'patch_valid': patch_valid,
-            'seq_len': max_patches,
-        }, targets
-
-
 class NaFlexPrefetchLoader:
     """Data prefetcher for NaFlex format which normalizes patches."""
 
@@ -261,9 +197,7 @@ def create_naflex_loader(
 
         # NOTE: Collation is handled by the dataset wrapper for training
-        # Create the collator (handles fixed-size collation)
-        # collate_fn = NaFlexCollator(
-        #     patch_size=patch_size,
-        #     max_seq_len=max(seq_lens) + 1,  # +1 for class token
-        #     use_prefetcher=use_prefetcher
-        # )
 
         loader = torch.utils.data.DataLoader(
@@ -303,10 +237,7 @@ def create_naflex_loader(
         )
 
         # Create the collator
-        collate_fn = NaFlexCollator(
-            patch_size=patch_size,
-            max_seq_len=max_seq_len,
-        )
+        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)
 
         # Handle distributed training
         sampler = None
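
For the non-training path touched above, the simplified constructor call plugs straight into a standard DataLoader as collate_fn. A hedged sketch with a synthetic stand-in dataset follows; the dataset, batch size, and sequence length are placeholder values for illustration, not part of the commit.

    import torch
    from torch.utils.data import DataLoader, Dataset
    # Import path as of this commit's layout; it may differ in later timm versions.
    from timm.data.naflex_dataset import NaFlexCollator

    class ToyPatchDataset(Dataset):
        # Synthetic stand-in for an eval dataset that already yields (patch_dict, target).
        def __init__(self, n=8, n_patches=48, patch_dim=768):
            self.samples = [
                ({'patches': torch.randn(n_patches, patch_dim),
                  'patch_coord': torch.zeros(n_patches, 2, dtype=torch.int64),
                  'patch_valid': torch.ones(n_patches, dtype=torch.bool)}, i % 2)
                for i in range(n)
            ]

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            return self.samples[idx]

    # Mirrors the simplified call in the diff: patch_size is no longer passed.
    collate_fn = NaFlexCollator(max_seq_len=576)
    loader = DataLoader(ToyPatchDataset(), batch_size=4, collate_fn=collate_fn)
    input_dict, targets = next(iter(loader))
    print(input_dict['patches'].shape)  # torch.Size([4, 576, 768])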