Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Move NaFlexCollator in with the dataset, remove standalone collate_fn and remove redundancy
This commit is contained in:
parent 39eb56f875
commit e2073e32d0
naflex_dataset.py

@@ -59,63 +59,65 @@ def calculate_batch_size(
     return batch_size


-def _collate_batch(
-        batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
-        target_seq_len: int
-) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-    """Collates processed samples into a batch, padding/truncating to target_seq_len."""
-    batch_patch_data = [item[0] for item in batch_samples]
-    batch_labels = [item[1] for item in batch_samples]
-
-    if not batch_patch_data:
-        return {}, torch.empty(0)
-
-    batch_size = len(batch_patch_data)
-    patch_dim = batch_patch_data[0]['patches'].shape[1]
-
-    # Initialize tensors with target sequence length
-    patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
-    patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
-    patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool)  # Use bool
-
-    for i, data in enumerate(batch_patch_data):
-        num_patches = data['patches'].shape[0]
-        # Take min(num_patches, target_seq_len) patches
-        n_copy = min(num_patches, target_seq_len)
-
-        patches_batch[i, :n_copy] = data['patches'][:n_copy]
-        patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
-        patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy]  # Copy validity flags
-
-    # Create the final input dict
-    input_dict = {
-        'patches': patches_batch,
-        'patch_coord': patch_coord_batch,
-        'patch_valid': patch_valid_batch,  # Boolean mask
-        # Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
-        # The actual number of valid patches per sample varies.
-        # 'patch_valid' mask is the most reliable source of truth.
-    }
-
-    # Attempt to stack labels if they are tensors, otherwise return list
-    try:
-        if isinstance(batch_labels[0], torch.Tensor):
-            labels_tensor = torch.stack(batch_labels)
-        else:
-            # Convert numerical types to tensor, keep others as list (or handle specific types)
-            if isinstance(batch_labels[0], (int, float)):
-                labels_tensor = torch.tensor(batch_labels)
-            else:
-                # Cannot convert non-numerical labels easily, return as list
-                # Or handle specific conversion if needed
-                # For FakeDataset, labels are ints, so this works
-                labels_tensor = torch.tensor(batch_labels)  # Assuming labels are numerical
-    except Exception:
-        # Fallback if stacking fails (e.g., different shapes, types)
-        print("Warning: Could not stack labels into a tensor. Returning list of labels.")
-        labels_tensor = batch_labels  # Return as list
-
-    return input_dict, labels_tensor
+class NaFlexCollator:
+    """Custom collator for batching NaFlex-style variable-resolution images."""
+
+    def __init__(
+            self,
+            max_seq_len=None,
+    ):
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
+
+    def __call__(self, batch):
+        """
+        Args:
+            batch: List of tuples (patch_dict, target)
+
+        Returns:
+            A tuple of (input_dict, targets) where input_dict contains:
+                - patches: Padded tensor of patches
+                - patch_coord: Coordinates for each patch (y, x)
+                - patch_valid: Valid indicators
+        """
+        assert isinstance(batch[0], tuple)
+        batch_size = len(batch)
+
+        # Extract targets
+        # FIXME need to handle dense (float) targets or always done downstream of this?
+        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
+
+        # Get patch dictionaries
+        patch_dicts = [item[0] for item in batch]
+
+        # If we have a maximum sequence length constraint, ensure we don't exceed it
+        if self.max_seq_len is not None:
+            max_patches = self.max_seq_len
+        else:
+            # Find the maximum number of patches in this batch
+            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
+
+        # Get patch dimensionality
+        patch_dim = patch_dicts[0]['patches'].shape[1]
+
+        # Prepare tensors for the batch
+        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
+        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
+
+        # Fill in the tensors
+        for i, patch_dict in enumerate(patch_dicts):
+            num_patches = min(patch_dict['patches'].shape[0], max_patches)
+
+            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
+            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
+            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
+
+        return {
+            'patches': patches,
+            'patch_coord': patch_coord,
+            'patch_valid': patch_valid,
+            'seq_len': max_patches,
+        }, targets


 class VariableSeqMapWrapper(IterableDataset):
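For reference, a minimal sketch of how the relocated NaFlexCollator pads a variable-length batch. Everything below is illustrative rather than part of the commit: the fake_sample helper, the concrete shapes, and the timm.data.naflex_dataset import path are assumptions based on the keys and constructor shown in the hunk above.

# Illustrative only: exercise NaFlexCollator as introduced above.
import torch
from timm.data.naflex_dataset import NaFlexCollator  # assumed import path after this commit

def fake_sample(num_patches, patch_dim=768, label=0):
    # Hypothetical stand-in for one (patch_dict, target) sample from the dataset pipeline.
    return {
        'patches': torch.randn(num_patches, patch_dim),
        'patch_coord': torch.zeros(num_patches, 2, dtype=torch.int64),
        'patch_valid': torch.ones(num_patches, dtype=torch.bool),
    }, label

collate = NaFlexCollator(max_seq_len=256)
inputs, targets = collate([fake_sample(100, label=1), fake_sample(256, label=7)])
print(inputs['patches'].shape)       # torch.Size([2, 256, 768])
print(inputs['patch_valid'].sum(1))  # tensor([100, 256]); padded rows stay False
print(inputs['seq_len'], targets)    # 256 tensor([1, 7])

Padding to a fixed max_seq_len (rather than the per-batch maximum) keeps every batch for a given sequence length the same shape, which is what the per-sequence-length collators added in the next hunk rely on.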
@@ -161,15 +163,15 @@ class VariableSeqMapWrapper(IterableDataset):
         self.epoch = epoch
         self.batch_divisor = batch_divisor

-        # Pre-initialize transforms for each sequence length
+        # Pre-initialize transforms and collate fns for each sequence length
         self.transforms: Dict[int, Optional[Callable]] = {}
-        if transform_factory:
-            for seq_len in self.seq_lens:
+        self.collate_fns: Dict[int, Callable] = {}
+        for seq_len in self.seq_lens:
+            if transform_factory:
                 self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-        else:
-            for seq_len in self.seq_lens:
+            else:
                 self.transforms[seq_len] = None  # No transform
+            self.collate_fns[seq_len] = NaFlexCollator(seq_len)

         self.patchifier = Patchify(self.patch_size)

         # --- Canonical Schedule Calculation (Done Once) ---
@@ -417,6 +419,6 @@ class VariableSeqMapWrapper(IterableDataset):

             # Collate the processed samples into a batch
             if batch_samples:  # Only yield if we successfully processed samples
-                yield _collate_batch(batch_samples, seq_len)
+                yield self.collate_fns[seq_len](batch_samples)

             # If batch_samples is empty after processing 'indices', an empty batch is skipped.
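With this change the wrapper itself yields fully collated (input_dict, targets) batches, so a consumer would typically disable the DataLoader's own batching. A rough, self-contained sketch under that assumption; ToyPreCollated is a hypothetical stand-in that only mimics the wrapper's output contract, not its construction.

# Sketch (not from the commit): consuming an IterableDataset that yields pre-collated batches,
# the way VariableSeqMapWrapper does after this change.
import torch
from torch.utils.data import IterableDataset, DataLoader

class ToyPreCollated(IterableDataset):
    def __iter__(self):
        # Each yielded item is already a full batch in the NaFlexCollator output format.
        for seq_len in (128, 256):
            yield {
                'patches': torch.randn(8, seq_len, 768),
                'patch_coord': torch.zeros(8, seq_len, 2, dtype=torch.int64),
                'patch_valid': torch.ones(8, seq_len, dtype=torch.bool),
                'seq_len': seq_len,
            }, torch.zeros(8, dtype=torch.int64)

loader = DataLoader(ToyPreCollated(), batch_size=None)  # batch_size=None: no re-collation
for inputs, targets in loader:
    print(inputs['patches'].shape, targets.shape)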
naflex_loader.py

@@ -7,74 +7,10 @@ import torch

 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .loader import _worker_init
-from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
 from .transforms_factory import create_transform


-class NaFlexCollator:
-    """Custom collator for batching NaFlex-style variable-resolution images."""
-
-    def __init__(
-            self,
-            patch_size=16,
-            max_seq_len=None,
-    ):
-        self.patch_size = patch_size
-        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (577 = 24*24)
-
-    def __call__(self, batch):
-        """
-        Args:
-            batch: List of tuples (patch_dict, target)
-
-        Returns:
-            A tuple of (input_dict, targets) where input_dict contains:
-                - patches: Padded tensor of patches
-                - patch_coord: Coordinates for each patch (y, x)
-                - patch_valid: Valid indicators
-        """
-        assert isinstance(batch[0], tuple)
-        batch_size = len(batch)
-
-        # Resize to final size based on seq_len and patchify
-
-        # Extract targets
-        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
-
-        # Get patch dictionaries
-        patch_dicts = [item[0] for item in batch]
-
-        # If we have a maximum sequence length constraint, ensure we don't exceed it
-        if self.max_seq_len is not None:
-            max_patches = self.max_seq_len
-        else:
-            # Find the maximum number of patches in this batch
-            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
-
-        # Get patch dimensionality
-        patch_dim = patch_dicts[0]['patches'].shape[1]
-
-        # Prepare tensors for the batch
-        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
-        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
-        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
-
-        # Fill in the tensors
-        for i, patch_dict in enumerate(patch_dicts):
-            num_patches = min(patch_dict['patches'].shape[0], max_patches)
-
-            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
-            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
-            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
-
-        return {
-            'patches': patches,
-            'patch_coord': patch_coord,
-            'patch_valid': patch_valid,
-            'seq_len': max_patches,
-        }, targets
-
-
 class NaFlexPrefetchLoader:
     """Data prefetcher for NaFlex format which normalizes patches."""
@@ -261,9 +197,7 @@ def create_naflex_loader(
         # NOTE: Collation is handled by the dataset wrapper for training
         # Create the collator (handles fixed-size collation)
         # collate_fn = NaFlexCollator(
-        #     patch_size=patch_size,
         #     max_seq_len=max(seq_lens) + 1,  # +1 for class token
-        #     use_prefetcher=use_prefetcher
         # )

         loader = torch.utils.data.DataLoader(
@@ -303,10 +237,7 @@ def create_naflex_loader(
         )

         # Create the collator
-        collate_fn = NaFlexCollator(
-            patch_size=patch_size,
-            max_seq_len=max_seq_len,
-        )
+        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)

         # Handle distributed training
         sampler = None
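For the eval path shown above, the collator is used as an ordinary collate_fn. A minimal self-contained sketch of that wiring; ToyPatchDataset and its shapes are hypothetical, and the timm.data.naflex_dataset import path is assumed from the earlier hunk.

# Sketch: NaFlexCollator as a standard collate_fn for a map-style (validation) dataset.
import torch
from torch.utils.data import Dataset, DataLoader
from timm.data.naflex_dataset import NaFlexCollator  # assumed import path after this commit

class ToyPatchDataset(Dataset):
    def __len__(self):
        return 32

    def __getitem__(self, idx):
        n = 64 + (idx % 4) * 48  # variable number of patches per sample
        return {
            'patches': torch.randn(n, 768),
            'patch_coord': torch.zeros(n, 2, dtype=torch.int64),
            'patch_valid': torch.ones(n, dtype=torch.bool),
        }, idx % 10

max_seq_len = 256
loader = DataLoader(
    ToyPatchDataset(),
    batch_size=8,
    collate_fn=NaFlexCollator(max_seq_len=max_seq_len),
)
for inputs, targets in loader:
    assert inputs['patches'].shape[1] == max_seq_len  # every batch padded to max_seq_len

Because max_seq_len is fixed here, every validation batch comes out with the same [B, max_seq_len, patch_dim] shape regardless of how many patches each image produced.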