Initial NaFlex ViT model and training support

Ross Wightman 2025-04-07 21:27:10 -07:00
parent e44f14d7d2
commit 0893f5d296
12 changed files with 2928 additions and 164 deletions

timm/data/__init__.py

@ -8,6 +8,13 @@ from .dataset_info import DatasetInfo, CustomDatasetInfo
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .naflex_transforms import (
ResizeToSequence,
CenterCropToSequence,
RandomCropToSequence,
RandomResizedCropToSequence,
ResizeKeepRatioToSequence,
)
from .readers import create_reader
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet

422
timm/data/naflex_dataset.py Normal file

@ -0,0 +1,422 @@
"""
Dynamic Sequence Length Datasets for Variable Resolution Image Processing
Implements two dataset wrappers:
1. DynamicSeqMapDataset - Map-style dataset that returns batches with variable sequence lengths
2. DynamicSeqIterDataset - Iterable dataset that yields batches with variable sequence lengths
Both support:
- Pre-initialized transforms for efficiency
- Distributed training
- Multiple workers
- Variable batch sizes based on sequence length
"""
import math
import random
import warnings
from functools import partial
from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
from torchvision import transforms
from PIL import Image
from .naflex_transforms import Patchify, patchify
def calculate_batch_size(
tokens_per_batch: int,
seq_len: int,
max_size: Optional[int] = None,
divisor: int = 1,
rounding: str ='floor',
):
"""Calculate batch size based on sequence length with divisibility constraints."""
# Calculate raw batch size based on sequence length
raw_batch_size = tokens_per_batch / seq_len
# Apply divisibility with specified rounding method
if divisor > 1:
if rounding == 'floor':
batch_size = math.floor(raw_batch_size / divisor) * divisor
elif rounding == 'ceil':
batch_size = math.ceil(raw_batch_size / divisor) * divisor
else: # 'round' is the default
batch_size = round(raw_batch_size / divisor) * divisor
else:
# If no divisor specified, just use integer division
batch_size = int(raw_batch_size)
# Ensure batch size is valid
batch_size = max(1, batch_size) # At least 1
if max_size is not None:
batch_size = min(batch_size, max_size)
return batch_size
def _collate_batch(
batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
target_seq_len: int
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""Collates processed samples into a batch, padding/truncating to target_seq_len."""
batch_patch_data = [item[0] for item in batch_samples]
batch_labels = [item[1] for item in batch_samples]
if not batch_patch_data:
return {}, torch.empty(0)
batch_size = len(batch_patch_data)
patch_dim = batch_patch_data[0]['patches'].shape[1]
# Initialize tensors with target sequence length
patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool) # Use bool
for i, data in enumerate(batch_patch_data):
num_patches = data['patches'].shape[0]
# Take min(num_patches, target_seq_len) patches
n_copy = min(num_patches, target_seq_len)
patches_batch[i, :n_copy] = data['patches'][:n_copy]
patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy] # Copy validity flags
# Create the final input dict
input_dict = {
'patches': patches_batch,
'patch_coord': patch_coord_batch,
'patch_valid': patch_valid_batch, # Boolean mask
# Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
# The actual number of valid patches per sample varies.
# 'patch_valid' mask is the most reliable source of truth.
}
# Attempt to stack labels if they are tensors, otherwise return list
try:
if isinstance(batch_labels[0], torch.Tensor):
labels_tensor = torch.stack(batch_labels)
else:
# Convert numerical types to tensor, keep others as list (or handle specific types)
if isinstance(batch_labels[0], (int, float)):
labels_tensor = torch.tensor(batch_labels)
else:
# Cannot convert non-numerical labels easily, return as list
# Or handle specific conversion if needed
# For FakeDataset, labels are ints, so this works
labels_tensor = torch.tensor(batch_labels) # Assuming labels are numerical
except Exception:
# Fallback if stacking fails (e.g., different shapes, types)
print("Warning: Could not stack labels into a tensor. Returning list of labels.")
labels_tensor = batch_labels # Return as list
return input_dict, labels_tensor
class VariableSeqMapWrapper(IterableDataset):
"""
IterableDataset wrapper for a map-style base dataset.
Yields batches with variable sequence lengths. It calculates a canonical
batch schedule (sequence length, batch size pairs) once based on the
total dataset size (padded for distribution). Each epoch, it shuffles
the *order* of this canonical schedule and the dataset indices.
This ensures a consistent number of batches and samples per epoch
across all ranks. Handles distributed training and multiple workers.
"""
def __init__(
self,
base_dataset: Dataset,
patch_size: Union[int, Tuple[int, int]] = 16,
seq_lens: List[int] = (128, 256, 576, 784, 1024),
max_tokens_per_batch: int = 4096 * 4, # Example: 16k tokens
transform_factory: Optional[Callable] = None,
seed: int = 42,
shuffle: bool = True,
distributed: bool = False,
rank: int = 0,
world_size: int = 1,
epoch: int = 0,
batch_divisor: int = 8, # Ensure batch size is divisible by this
):
super().__init__()
if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'):
raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)")
self.base_dataset = base_dataset
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
self.seq_lens = sorted(list(set(seq_lens))) # Ensure unique and sorted
self.max_tokens_per_batch = max_tokens_per_batch
self.seed = seed
self.shuffle = shuffle
self.distributed = distributed
self.rank = rank if distributed else 0
self.world_size = world_size if distributed else 1
self.epoch = epoch
self.batch_divisor = batch_divisor
# Pre-initialize transforms for each sequence length
self.transforms: Dict[int, Optional[Callable]] = {}
if transform_factory:
for seq_len in self.seq_lens:
self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
else:
for seq_len in self.seq_lens:
self.transforms[seq_len] = None # No transform
self.patchifier = Patchify(self.patch_size)
# --- Canonical Schedule Calculation (Done Once) ---
self._canonical_batch_schedule: List[Tuple[int, int]] = []
self._num_batches_per_rank: int = 0
self._padded_samples_per_rank: int = 0
self._create_canonical_schedule() # Calculate schedule based on padded size
# --- Per-Epoch State ---
# Stores (seq_len, list_of_indices) for the current epoch, specific to this rank
self._epoch_batches: List[Tuple[int, List[int]]] = []
self._prepare_epoch_batches(self.epoch) # setup for initial epoch
def _create_canonical_schedule(self):
"""
Calculates the canonical batch schedule (seq_len, batch_size pairs)
based on the dataset size, padded for distributed training.
This schedule is the *same* for all ranks and ensures consistent
epoch length. It is calculated once during initialization.
"""
total_len = len(self.base_dataset)
padded_total_len = total_len
num_samples_per_rank = total_len
if self.distributed and self.world_size > 1:
# Calculate padding needed for even distribution
if total_len % self.world_size != 0:
pad_size = self.world_size - (total_len % self.world_size)
padded_total_len += pad_size
print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).")
else:
pad_size = 0
if padded_total_len % self.world_size != 0:
# This should not happen with the padding logic, but safeguard
raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}")
num_samples_per_rank = padded_total_len // self.world_size
elif self.distributed and self.world_size <= 1:
# Distributed flag set but world_size is 1, treat as non-distributed
pass # num_samples_per_rank remains total_len
self._padded_samples_per_rank = num_samples_per_rank
if num_samples_per_rank == 0:
self._canonical_batch_schedule = []
self._num_batches_per_rank = 0
return
# Use a fixed seed for generating the canonical schedule structure
g = torch.Generator()
g.manual_seed(self.seed) # Use base seed, NOT epoch seed
current_schedule: List[Tuple[int, int]] = []
remaining_samples = num_samples_per_rank
total_scheduled_samples = 0
while remaining_samples > 0:
# Sample sequence length deterministically based on base seed
seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item()
seq_len = self.seq_lens[seq_idx]
# Calculate batch size
batch_size = calculate_batch_size(
tokens_per_batch=self.max_tokens_per_batch,
seq_len=seq_len,
# max_size should be remaining_samples to avoid overshooting
max_size=remaining_samples,
divisor=self.batch_divisor,
rounding='floor',
)
# Ensure batch size is positive and doesn't exceed remaining samples
batch_size = max(1, batch_size)
batch_size = min(batch_size, remaining_samples)
if batch_size <= 0:
warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.")
break # Avoid infinite loop if something goes wrong
current_schedule.append((seq_len, batch_size))
remaining_samples -= batch_size
total_scheduled_samples += batch_size
# Sanity check: Ensure the schedule covers all samples for the rank
if total_scheduled_samples != num_samples_per_rank:
warnings.warn(
f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, "
f"but expected {num_samples_per_rank} samples per rank. "
f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. "
f"Check parameters. Remaining samples: {remaining_samples}"
)
# Adjust if needed? Could add a final small batch, but might violate constraints.
# Current behavior: some samples might be dropped if schedule logic fails.
self._canonical_batch_schedule = current_schedule
self._num_batches_per_rank = len(current_schedule)
print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.")
def _prepare_epoch_batches(self, epoch: int):
"""
Prepares the batches for the current epoch by:
1. Shuffling the full dataset indices (using epoch seed).
2. Applying padding if in distributed mode.
3. Selecting indices for the current rank.
4. Shuffling the *order* of the canonical batch schedule (using epoch seed).
5. Assigning the rank's indices to the shuffled batches.
"""
g = torch.Generator()
g.manual_seed(self.seed + epoch) # Epoch-specific seed for shuffling
# 1. Get shuffled global indices
total_len = len(self.base_dataset)
if self.shuffle:
all_indices_shuffled = torch.randperm(total_len, generator=g).tolist()
else:
all_indices_shuffled = list(range(total_len))
# 2. Apply padding for distributed mode
indices_for_ranks = all_indices_shuffled
if self.distributed and self.world_size > 1:
padded_total_len = self._padded_samples_per_rank * self.world_size
if padded_total_len > total_len:
pad_size = padded_total_len - total_len
# Repeat initial elements from the *shuffled* list for padding
indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size]
# Ensure length matches expectation
if len(indices_for_ranks) != padded_total_len:
raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}")
# 3. Select indices for the current rank
if self.distributed and self.world_size > 1:
indices_this_rank = indices_for_ranks[self.rank::self.world_size]
else: # Non-distributed or world_size=1
indices_this_rank = indices_for_ranks
# Sanity check length
if len(indices_this_rank) != self._padded_samples_per_rank:
# This might happen if canonical schedule generation had warnings/issues
warnings.warn(
f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) "
f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). "
f"Epoch generation might be inconsistent."
)
# Adjust expected samples? Or truncate/pad indices? Let's proceed but warn.
# Using min() prevents IndexError later if indices are fewer than expected.
effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank)
indices_this_rank = indices_this_rank[:effective_samples_this_rank]
else:
effective_samples_this_rank = self._padded_samples_per_rank
# 4. Shuffle the order of the canonical batch schedule for this epoch
if self.shuffle:
schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist()
shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm]
else:
shuffled_schedule = list(self._canonical_batch_schedule) # Keep original order
# 5. Assign indices to the shuffled batches
self._epoch_batches = []
idx_pos = 0
scheduled_samples_count = 0
for seq_len, bs in shuffled_schedule:
# Ensure we don't try to grab more indices than available for the rank
actual_bs = min(bs, effective_samples_this_rank - idx_pos)
if actual_bs <= 0:
if scheduled_samples_count < effective_samples_this_rank:
# This indicates mismatch between schedule total and actual samples
warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. Check schedule generation.")
break # Stop if no more indices or batch size is zero
batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs]
self._epoch_batches.append((seq_len, batch_indices))
idx_pos += actual_bs
scheduled_samples_count += actual_bs
# Final check
if scheduled_samples_count != effective_samples_this_rank:
warnings.warn(
f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, "
f"but expected {effective_samples_this_rank} effective samples this epoch. "
f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
)
def set_epoch(self, epoch: int):
"""Updates the epoch, regenerating the epoch-specific batches."""
# Only regenerate if the epoch actually changes
if epoch != self.epoch:
self.epoch = epoch
self._prepare_epoch_batches(epoch)
    def __len__(self) -> int:
        """
        Returns the number of batches per rank for the current epoch.
        Batches are divided among dataloader workers in __iter__, so a
        single worker iterates only a subset of this count.
        """
        return self._num_batches_per_rank
def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
"""
Iterates through the pre-calculated batches for the current epoch,
distributing them among workers.
"""
worker_info = torch.utils.data.get_worker_info()
num_workers = worker_info.num_workers if worker_info else 1
worker_id = worker_info.id if worker_info else 0
# Distribute pre-calculated batches among workers for this rank
# Each worker processes a slice of the batches prepared in _prepare_epoch_batches
batches_for_worker = self._epoch_batches[worker_id::num_workers]
for seq_len, indices in batches_for_worker:
if not indices: # Skip if a batch ended up with no indices (shouldn't happen often)
continue
# Get the pre-initialized transform for this sequence length
transform = self.transforms.get(seq_len)
batch_samples = []
for idx in indices:
try:
# Get original image and label from map-style dataset
img, label = self.base_dataset[idx]
# Apply transform if available
# Handle cases where transform might return None or fail
processed_img = transform(img) if transform else img
if processed_img is None:
warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
continue
# Apply patching
patch_data = self.patchifier(processed_img)
batch_samples.append((patch_data, label))
except IndexError:
warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
continue
except Exception as e:
# Log other potential errors during data loading/processing
warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
continue # Skip problematic sample
# Collate the processed samples into a batch
if batch_samples: # Only yield if we successfully processed samples
yield _collate_batch(batch_samples, seq_len)
# If batch_samples is empty after processing 'indices', an empty batch is skipped.

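The VariableSeqMapWrapper above computes one canonical (seq_len, batch_size) schedule from the token budget, then reshuffles that schedule and the sample indices every epoch. A minimal usage sketch against this commit; ToyDataset and make_transform are invented stand-ins for a real map-style dataset and transform factory, and all sizes are arbitrary:

import torch
from PIL import Image
from timm.data.naflex_dataset import VariableSeqMapWrapper
from timm.data.naflex_transforms import ResizeToSequence

class ToyDataset(torch.utils.data.Dataset):
    # map-style dataset of solid-color PIL images with integer labels
    def __len__(self):
        return 64
    def __getitem__(self, idx):
        return Image.new('RGB', (200 + idx, 180)), idx % 10

def make_transform(max_seq_len=576, patch_size=(16, 16)):
    # called once per training seq_len; resize so patchifying stays within max_seq_len tokens
    return ResizeToSequence(patch_size=patch_size, max_seq_len=max_seq_len)

wrapper = VariableSeqMapWrapper(
    ToyDataset(),
    patch_size=16,
    seq_lens=(128, 256),
    max_tokens_per_batch=256 * 8,  # ~8 images at the longest sequence length
    transform_factory=make_transform,
)
# call wrapper.set_epoch(epoch) at the start of each epoch to reshuffle
for input_dict, labels in wrapper:
    # patches: [B, seq_len, 16*16*3]; patch_valid marks real (non-padding) tokens
    print(input_dict['patches'].shape, input_dict['patch_valid'].sum(-1), labels.shape)
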
341
timm/data/naflex_loader.py Normal file

@ -0,0 +1,341 @@
import math
from contextlib import suppress
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .loader import _worker_init
from .naflex_dataset import VariableSeqMapWrapper
from .transforms_factory import create_transform
class NaFlexCollator:
"""Custom collator for batching NaFlex-style variable-resolution images."""
def __init__(
self,
patch_size=16,
max_seq_len=None,
):
self.patch_size = patch_size
        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (576 = 24 * 24)
def __call__(self, batch):
"""
Args:
batch: List of tuples (patch_dict, target)
Returns:
A tuple of (input_dict, targets) where input_dict contains:
- patches: Padded tensor of patches
- patch_coord: Coordinates for each patch (y, x)
- patch_valid: Valid indicators
"""
assert isinstance(batch[0], tuple)
batch_size = len(batch)
# FIXME
# get seq len from sampler schedule
# resize to final size based on seq_len and patchify
# Extract targets
targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
# Get patch dictionaries
patch_dicts = [item[0] for item in batch]
# If we have a maximum sequence length constraint, ensure we don't exceed it
if self.max_seq_len is not None:
max_patches = self.max_seq_len
else:
# Find the maximum number of patches in this batch
max_patches = max(item['patches'].shape[0] for item in patch_dicts)
# Get patch dimensionality
patch_dim = patch_dicts[0]['patches'].shape[1]
# Prepare tensors for the batch
patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64) # [B, N, 2] for (y, x)
patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
# Fill in the tensors
for i, patch_dict in enumerate(patch_dicts):
num_patches = min(patch_dict['patches'].shape[0], max_patches)
patches[i, :num_patches] = patch_dict['patches'][:num_patches]
patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
return {
'patches': patches,
'patch_coord': patch_coord,
'patch_valid': patch_valid,
'seq_len': max_patches,
}, targets
class NaFlexPrefetchLoader:
"""Data prefetcher for NaFlex format which normalizes patches."""
def __init__(
self,
loader,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
img_dtype=torch.float32,
device=torch.device('cuda')
):
self.loader = loader
self.device = device
self.img_dtype = img_dtype or torch.float32
# Create mean/std tensors for normalization (will be applied to patches)
self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)
# Check for CUDA/NPU availability
self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
self.is_npu = device.type == 'npu' and torch.npu.is_available()
def __iter__(self):
first = True
if self.is_cuda:
stream = torch.cuda.Stream()
stream_context = partial(torch.cuda.stream, stream=stream)
elif self.is_npu:
stream = torch.npu.Stream()
stream_context = partial(torch.npu.stream, stream=stream)
else:
stream = None
stream_context = suppress
for next_input_dict, next_target in self.loader:
with stream_context():
# Move all tensors in input_dict to device
for k, v in next_input_dict.items():
if isinstance(v, torch.Tensor):
dtype = self.img_dtype if k == 'patches' else None
next_input_dict[k] = next_input_dict[k].to(
device=self.device,
non_blocking=True,
dtype=dtype,
)
next_target = next_target.to(device=self.device, non_blocking=True)
# Normalize patch values (assuming patches are in format [B, N, P*P*C])
batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
                patches = next_input_dict['patches'].view(batch_size, -1, 3)  # view as [B, N*P*P, C] for per-channel normalization
patches = patches.sub(self.mean).div(self.std)
# Reshape back
next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)
if not first:
yield input_dict, target
else:
first = False
if stream is not None:
if self.is_cuda:
torch.cuda.current_stream().wait_stream(stream)
elif self.is_npu:
torch.npu.current_stream().wait_stream(stream)
input_dict = next_input_dict
target = next_target
yield input_dict, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
def create_naflex_loader(
dataset,
patch_size: Union[Tuple[int, int], int] = 16,
train_seq_lens: List[int] = (128, 256, 576, 784, 1024), # Training sequence lengths
max_seq_len: int = 576, # Fixed sequence length for validation
batch_size: int = 32, # Used for max_seq_len and max(train_seq_lens)
is_training: bool = False,
no_aug: bool = False,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_split: bool = False,
train_crop_mode: Optional[str] = None,
scale: Optional[Tuple[float, float]] = None,
ratio: Optional[Tuple[float, float]] = None,
hflip: float = 0.5,
vflip: float = 0.,
color_jitter: float = 0.4,
color_jitter_prob: Optional[float] = None,
grayscale_prob: float = 0.,
gaussian_blur_prob: float = 0.,
auto_augment: Optional[str] = None,
num_aug_repeats: int = 0,
num_aug_splits: int = 0,
interpolation: str = 'bilinear',
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
crop_pct: Optional[float] = None,
crop_mode: Optional[str] = None,
crop_border_pixels: Optional[int] = None,
num_workers: int = 4,
distributed: bool = False,
rank: int = 0,
world_size: int = 1,
seed: int = 42,
epoch: int = 0,
use_prefetcher: bool = True,
pin_memory: bool = True,
img_dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = torch.device('cuda'),
persistent_workers: bool = True,
worker_seeding: str = 'all',
):
"""Create a data loader with dynamic sequence length sampling for training."""
if is_training:
# For training, use the dynamic sequence length mechanism
assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader'
transform_factory = partial(
create_transform,
is_training=True,
no_aug=no_aug,
train_crop_mode=train_crop_mode,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter,
color_jitter_prob=color_jitter_prob,
grayscale_prob=grayscale_prob,
gaussian_blur_prob=gaussian_blur_prob,
auto_augment=auto_augment,
interpolation=interpolation,
mean=mean,
std=std,
crop_pct=crop_pct,
crop_mode=crop_mode,
crop_border_pixels=crop_border_pixels,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
use_prefetcher=use_prefetcher,
naflex=True,
)
max_train_seq_len = max(train_seq_lens)
max_tokens_per_batch = batch_size * max_train_seq_len
if isinstance(dataset, torch.utils.data.IterableDataset):
assert False, "IterableDataset Wrapper is a WIP"
naflex_dataset = VariableSeqMapWrapper(
dataset,
transform_factory=transform_factory,
patch_size=patch_size,
seq_lens=train_seq_lens,
max_tokens_per_batch=max_tokens_per_batch,
seed=seed,
distributed=distributed,
rank=rank,
world_size=world_size,
shuffle=True,
epoch=epoch,
)
# NOTE: Collation is handled by the dataset wrapper for training
# Create the collator (handles fixed-size collation)
# collate_fn = NaFlexCollator(
# patch_size=patch_size,
# max_seq_len=max(seq_lens) + 1, # +1 for class token
# use_prefetcher=use_prefetcher
# )
loader = torch.utils.data.DataLoader(
naflex_dataset,
batch_size=None,
shuffle=False,
num_workers=num_workers,
sampler=None,
#collate_fn=collate_fn,
pin_memory=pin_memory,
worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
persistent_workers=persistent_workers
)
if use_prefetcher:
loader = NaFlexPrefetchLoader(
loader,
mean=mean,
std=std,
img_dtype=img_dtype,
device=device,
)
else:
# For validation, use fixed sequence length (unchanged)
dataset.transform = create_transform(
is_training=False,
interpolation=interpolation,
mean=mean,
std=std,
# FIXME add crop args when sequence transforms support crop modes
use_prefetcher=use_prefetcher,
naflex=True,
patch_size=patch_size,
max_seq_len=max_seq_len,
patchify=True,
)
# Create the collator
collate_fn = NaFlexCollator(
patch_size=patch_size,
max_seq_len=max_seq_len,
)
# Handle distributed training
sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
# For validation, use OrderedDistributedSampler
from timm.data.distributed_sampler import OrderedDistributedSampler
sampler = OrderedDistributedSampler(dataset)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=False,
)
if use_prefetcher:
loader = NaFlexPrefetchLoader(
loader,
mean=mean,
std=std,
img_dtype=img_dtype,
device=device,
)
return loader

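create_naflex_loader above wires the wrapper into a standard DataLoader (batch_size=None, since the wrapper already yields whole batches) and optionally adds the patch-normalizing prefetcher; validation instead uses a fixed max_seq_len with NaFlexCollator. A rough training-side sketch; the dataset path is a placeholder and the CPU / use_prefetcher=False settings are only to keep the example device-free:

import torch
from timm.data import create_dataset
from timm.data.naflex_loader import create_naflex_loader

dataset = create_dataset('', root='path/to/train-data', split='train')  # any map-style (PIL image, label) dataset

loader = create_naflex_loader(
    dataset,
    patch_size=16,
    train_seq_lens=(128, 256, 576),
    batch_size=64,          # token budget = batch_size * max(train_seq_lens)
    is_training=True,
    num_workers=2,
    use_prefetcher=False,   # the prefetcher path would normalize patches on the target device
    device=torch.device('cpu'),
)

for input_dict, targets in loader:
    # batch size shrinks as the sampled seq_len grows, keeping tokens per batch roughly constant
    print(input_dict['patches'].shape, targets.shape)
    break
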
804
timm/data/naflex_transforms.py Normal file

@ -0,0 +1,804 @@
""" NaFlex (NaViT + FlexiViT) Transforms and Collation
Implements PyTorch versions of the transforms described in the NaViT and FlexiViT papers:
- NaViT: https://arxiv.org/abs/2307.06304
- FlexiViT: https://arxiv.org/abs/2212.08013
Enables variable resolution/aspect ratio image handling with efficient patching.
"""
import math
import random
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode
from .transforms import str_to_interp_mode, crop_or_pad, center_crop_or_pad
def get_image_size_for_seq(
image_hw,
patch_size=16,
max_seq_len=1024,
divisible_by_patch=True,
max_ratio=None,
eps = 1e-5,
):
"""
Determine scaling ratio and image size so that when `image_hw` is scaled
by 'ratio', the total number of resulting patches does not exceed
'max_seq_len'.
- Patch size can be an integer (square patch) or a tuple (patch_h, patch_w).
- Optionally cap the ratio at `max_ratio` to prevent upsampling beyond
a certain multiple of the original size.
Args:
image_hw (tuple or list of int): (height, width) of the original image.
patch_size (int or tuple[int, int]): If int, patch is square. If tuple,
patch is rectangular (patch_h, patch_w).
max_seq_len (int): Maximum allowed sequence length for the resulting image.
divisible_by_patch (bool): If True, the resulting image height and width
must be multiples of patch_size.
eps (float): Small number for binary search convergence.
max_ratio (float or None): If provided, the scaling ratio found by the
binary search will be clamped to min(found_ratio, max_ratio). Set
max_ratio=1.0 to ensure no upsampling beyond original size.
Returns:
ratio (float): Found scaling ratio (capped by `max_ratio` if provided).
target_hw (tuple of int): Target (height, width) after scaling.
"""
# Handle patch size input, extract patch_h, patch_w
if isinstance(patch_size, int):
patch_h, patch_w = patch_size, patch_size
else:
# Assume it's a tuple/list: (patch_h, patch_w)
if len(patch_size) != 2:
raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
patch_h, patch_w = patch_size
# Safety checks
if patch_h <= 0 or patch_w <= 0:
raise ValueError("patch_size dimensions must be positive.")
def prepare_target_hw(ratio):
"""Scale image_hw by ratio and optionally round dimensions to multiples of patch_h, patch_w."""
scaled_h = image_hw[0] * ratio
scaled_w = image_hw[1] * ratio
# If we need the result to be divisible by patch_size
if divisible_by_patch:
scaled_h = patch_h * math.ceil(scaled_h / patch_h)
scaled_w = patch_w * math.ceil(scaled_w / patch_w)
# Ensure at least one patch in each dimension
scaled_h = int(max(scaled_h, patch_h))
scaled_w = int(max(scaled_w, patch_w))
return scaled_h, scaled_w
def is_feasible(ratio):
"""Check if scaling by 'ratio' keeps patch count within max_seq_len."""
t_h, t_w = prepare_target_hw(ratio)
# Each dimension is already a multiple of patch_h, patch_w if divisible_by_patch=True.
# Use integer division to count patches.
num_patches_h = t_h // patch_h
num_patches_w = t_w // patch_w
seq_len = num_patches_h * num_patches_w
return seq_len <= max_seq_len
# Binary search boundaries
lb = eps / 10.0
rb = 100.0
# Standard binary search loop
while (rb - lb) >= eps:
mid = (lb + rb) / 2.0
if is_feasible(mid):
lb = mid
else:
rb = mid
# The final ratio from the binary search
ratio = lb
# If max_ratio is provided, clamp it to prevent upsampling beyond that threshold
if max_ratio is not None:
ratio = min(ratio, max_ratio)
# Final checks
if ratio <= eps:
raise ValueError("Binary search failed - image might be too large?")
if ratio >= 100.0:
raise ValueError("Binary search failed - image might be too small?")
# Prepare the final target dimensions with the possibly clamped ratio
target_hw = prepare_target_hw(ratio)
return ratio, target_hw
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
class ResizeToSequence(torch.nn.Module):
"""Resize image to fit within a maximum sequence length constraint when patchified.
This maintains aspect ratio while ensuring the resulting image, when divided into patches,
will not exceed the specified maximum sequence length.
"""
def __init__(
self,
patch_size: int,
max_seq_len: int = 1024,
divisible_by_patch: bool = True,
max_ratio: Optional[float] = None,
interpolation='bicubic',
):
super().__init__()
self.patch_size = patch_size
self.max_seq_len = max_seq_len
self.divisible_by_patch = divisible_by_patch
self.max_ratio = max_ratio
if isinstance(interpolation, str):
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = str_to_interp_mode(interpolation)
else:
self.interpolation = interpolation
def forward(self, img):
"""Resize image to maintain aspect ratio and fit sequence constraint."""
_, h, w = transforms.functional.get_dimensions(img)
_, target_hw = get_image_size_for_seq(
(h, w),
self.patch_size,
self.max_seq_len,
divisible_by_patch=self.divisible_by_patch,
max_ratio=self.max_ratio,
)
if isinstance(self.interpolation, (tuple, list)):
interpolation = random.choice(self.interpolation)
else:
interpolation = self.interpolation
resized_img = transforms.functional.resize(img, target_hw, interpolation=interpolation, antialias=True)
return resized_img
class ResizeKeepRatioToSequence(torch.nn.Module):
"""
Resize and Keep Aspect Ratio, adapted to fit sequence length constraints.
"""
def __init__(
self,
patch_size=16,
max_sequence_len=1024,
divisible_by_patch=True,
longest=0.,
interpolation='bilinear',
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_scale_area=False,
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11),
max_ratio=None,
):
"""
Args:
patch_size: Size of patches (int or tuple of (patch_h, patch_w))
max_sequence_len: Maximum allowed sequence length for the resulting image
divisible_by_patch: If True, ensure dimensions are divisible by patch_size
longest: Float between 0-1 where 0=shortest side, 1=longest side determines scale
interpolation: Interpolation method for resizing
random_scale_prob: Probability of applying random scaling
random_scale_range: Range for random scaling factor (min, max)
            random_scale_area: If True, scale factor is treated as an area scale factor; sqrt(factor) is applied to each edge (compatible with RRC scale ranges)
random_aspect_prob: Probability of applying random aspect ratio jittering
random_aspect_range: Range for random aspect ratio (min, max)
max_ratio: Maximum allowed scaling ratio
"""
super().__init__()
self.patch_size = patch_size
self.max_sequence_len = max_sequence_len
self.divisible_by_patch = divisible_by_patch
self.longest = float(longest)
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = str_to_interp_mode(interpolation)
self.random_scale_prob = random_scale_prob
self.random_scale_range = random_scale_range
self.random_scale_area = random_scale_area
self.random_aspect_prob = random_aspect_prob
self.random_aspect_range = random_aspect_range
self.max_ratio = max_ratio
@staticmethod
def get_params(
img,
patch_size,
max_sequence_len,
divisible_by_patch,
longest,
random_scale_prob=0.,
random_scale_range=(1.0, 1.33),
random_scale_area=False,
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11),
max_ratio=None,
):
"""Get parameters for resizing."""
# Get image dimensions
img_h, img_w = F.get_dimensions(img)[1:]
# Step 1: Get the maximum allowed dimensions from sequence length constraint
_, target_hw = get_image_size_for_seq(
(img_h, img_w),
patch_size,
max_sequence_len,
divisible_by_patch,
max_ratio,
)
target_h, target_w = target_hw
# Calculate ratio based on sequence constraint
ratio_h = target_h / img_h
ratio_w = target_w / img_w
# Apply longest blending
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
# Apply random scaling
if random_scale_prob > 0 and random.random() < random_scale_prob:
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
if random_scale_area:
# Make ratio factor equivalent to area change
ratio_factor = 1. / math.sqrt(ratio_factor)
ratio_factor = (ratio_factor, ratio_factor)
else:
ratio_factor = (1., 1.)
# Apply random aspect
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
aspect_factor = math.exp(random.uniform(*log_aspect))
aspect_factor = math.sqrt(aspect_factor)
# Apply aspect ratio jittering
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
# Calculate final dimensions
size = [round(dim * ratio * f) for dim, f in zip((img_h, img_w), ratio_factor)]
# Ensure dimensions satisfy sequence constraint and are divisible by patch size
if isinstance(patch_size, int):
ph, pw = patch_size, patch_size
else:
ph, pw = patch_size
# Ensure dimensions are at least one patch
size[0] = max(size[0], ph)
size[1] = max(size[1], pw)
# Make divisible by patch size if needed
if divisible_by_patch:
size[0] = ph * math.ceil(size[0] / ph)
size[1] = pw * math.ceil(size[1] / pw)
# Verify we haven't exceeded sequence length
num_patches_h = size[0] // ph
num_patches_w = size[1] // pw
seq_len = num_patches_h * num_patches_w
if seq_len > max_sequence_len:
# Scale back down to fit sequence constraint
scale_back = math.sqrt(max_sequence_len / seq_len)
size[0] = int(size[0] * scale_back)
size[1] = int(size[1] * scale_back)
# Ensure divisible by patch size after scaling back
if divisible_by_patch:
size[0] = ph * math.ceil(size[0] / ph)
size[1] = pw * math.ceil(size[1] / pw)
return size
def forward(self, img):
"""
Resize the image with aspect ratio preservation and sequence length constraints.
"""
size = self.get_params(
img,
self.patch_size,
self.max_sequence_len,
self.divisible_by_patch,
self.longest,
self.random_scale_prob,
self.random_scale_range,
self.random_scale_area,
self.random_aspect_prob,
self.random_aspect_range,
self.max_ratio,
)
if isinstance(self.interpolation, (tuple, list)):
interpolation = random.choice(self.interpolation)
else:
interpolation = self.interpolation
return F.resize(img, size, interpolation)
def __repr__(self):
interpolate_str = "random" if isinstance(self.interpolation, (tuple, list)) else str(self.interpolation)
return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
f"max_sequence_len={self.max_sequence_len}, "
f"longest={self.longest:.3f}, "
f"random_scale_prob={self.random_scale_prob:.3f}, "
f"random_aspect_prob={self.random_aspect_prob:.3f})")
class CenterCropToSequence(torch.nn.Module):
"""Center crop the image such that the resulting patch sequence length meets constraints."""
def __init__(
self,
patch_size: int,
max_seq_len: int,
divisible_by_patch: bool = True,
fill: Union[int, Tuple[int, int, int]] = 0,
padding_mode: str = 'constant'
):
super().__init__()
self.patch_size = patch_size
self.max_seq_len = max_seq_len
self.divisible_by_patch = divisible_by_patch
self.fill = fill
self.padding_mode = padding_mode
def forward(self, img):
"""Center crop the image to maintain aspect ratio and fit sequence constraint."""
_, h, w = transforms.functional.get_dimensions(img)
_, target_hw = get_image_size_for_seq(
(h, w),
self.patch_size,
self.max_seq_len,
self.divisible_by_patch
)
# Use center crop
return center_crop_or_pad(img, target_hw, fill=self.fill, padding_mode=self.padding_mode)
class RandomCropToSequence(torch.nn.Module):
"""Randomly crop and/or pad the image to fit sequence length constraints.
This maintains aspect ratio while ensuring the resulting image, when divided into patches,
    will not exceed the specified maximum sequence length. Similar to CenterCropToSequence
but with randomized positioning.
"""
def __init__(
self,
patch_size: int,
max_sequence_len: int,
divisible_by_patch: bool = True,
fill: Union[int, Tuple[int, int, int]] = 0,
padding_mode: str = 'constant'
):
"""
Args:
patch_size: Size of patches (int or tuple of (patch_h, patch_w))
max_sequence_len: Maximum allowed sequence length for the resulting image
divisible_by_patch: If True, resulting image dimensions will be multiples of patch_size
fill: Fill value for padding
padding_mode: Padding mode ('constant', 'edge', 'reflect', 'symmetric')
"""
super().__init__()
self.patch_size = patch_size
self.max_sequence_len = max_sequence_len
self.divisible_by_patch = divisible_by_patch
self.fill = fill
self.padding_mode = padding_mode
@staticmethod
def get_params(img, target_size):
"""Get random position for crop/pad."""
_, image_height, image_width = transforms.functional.get_dimensions(img)
delta_height = image_height - target_size[0]
delta_width = image_width - target_size[1]
# Handle both positive (crop) and negative (pad) deltas
if delta_height == 0:
top = 0
else:
top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))
if delta_width == 0:
left = 0
else:
left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))
return top, left
def forward(self, img):
"""Randomly crop or pad the image to maintain aspect ratio and fit sequence constraint."""
# Get current dimensions
_, img_h, img_w = transforms.functional.get_dimensions(img)
# Calculate target dimensions that satisfy sequence length
# We use max_ratio=1.0 to prevent upscaling - we only want to crop or maintain current size
_, target_hw = get_image_size_for_seq(
(img_h, img_w),
self.patch_size,
self.max_sequence_len,
self.divisible_by_patch,
max_ratio=1.0 # Prevent upscaling
)
# Get random position for crop/pad
top, left = self.get_params(img, target_hw)
# Apply crop or pad
return crop_or_pad(
img,
top=top,
left=left,
height=target_hw[0],
width=target_hw[1],
fill=self.fill,
padding_mode=self.padding_mode,
)
def __repr__(self) -> str:
return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
f"max_sequence_len={self.max_sequence_len}, "
f"divisible_by_patch={self.divisible_by_patch})")
def _validate_range(value, name, length=2):
# Validate type and length
if not isinstance(value, Sequence) or len(value) != length:
raise ValueError(f"{name} should be a sequence of length {length}.")
# Validate order
if value[0] > value[1]:
warnings.warn(f"{name.capitalize()} range reversed. Swapping.")
return value[1], value[0]
return value
class RandomResizedCropToSequence(torch.nn.Module):
"""
Randomly crop the input image to a subregion with varying area and aspect ratio
(relative to the original), then resize that crop to a target size. The target size
is determined such that patchifying the resized image (with `patch_size`)
does not exceed `max_seq_len` patches, while maintaining the aspect ratio of the crop.
This combines aspects of torchvision's RandomResizedCrop with sequence length constraints.
Args:
patch_size (int or tuple[int, int]):
Patch dimensions (patch_h, patch_w) for sequence length calculation.
max_seq_len (int):
Maximum number of patches allowed in the final image.
scale (tuple[float, float]):
Range (min, max) of area fraction of the original image to crop.
ratio (tuple[float, float]):
Range (min, max) of aspect ratio *multipliers* for the crop, relative
to the original image's aspect ratio. E.g., (0.75, 1.333) means the
crop's aspect ratio will be sampled between 0.75*orig_ar and 1.333*orig_ar.
Uses log-uniform sampling.
interpolation (str or InterpolationMode):
Interpolation mode for resizing. Can be 'bilinear', 'bicubic', 'nearest',
or 'random' (chooses between bilinear and bicubic).
Defaults to 'bicubic'.
divisible_by_patch (bool):
If True, the final image height and width will be multiples of the
respective patch dimensions. Defaults to True.
max_ratio (float, optional):
An optional upper limit on the scaling ratio applied during resizing.
Prevents excessive upsampling of the initial crop. `max_ratio=1.0`
prevents any upsampling beyond the cropped size. Defaults to None (no limit).
final_scale_range (tuple[float, float], optional):
If provided, applies an *additional* random scaling factor to the
final target size. The factor is sampled uniformly from this range,
and multiplied by the size determined by `get_image_size_for_seq`.
E.g., (0.8, 1.0) means the final size will be between 80% and 100%
of the maximum feasible size. Defaults to None (use maximum feasible size).
attempts (int):
Number of attempts to sample a valid crop geometry before falling back
to a center crop strategy. Defaults to 10.
"""
def __init__(
self,
patch_size: Union[int, Tuple[int, int]] = 16,
max_seq_len: int = 1024,
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (.8, 1.25),
interpolation: Union[str, InterpolationMode] = 'bicubic',
divisible_by_patch: bool = True,
max_ratio: Optional[float] = None,
final_scale_range: Optional[Tuple[float, float]] = None,
attempts: int = 10,
):
super().__init__()
if isinstance(patch_size, int):
self.patch_h, self.patch_w = patch_size, patch_size
else:
# Assume it's a tuple/list: (patch_h, patch_w)
if len(patch_size) != 2:
raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
self.patch_h, self.patch_w = patch_size
self.max_seq_len = max_seq_len
self.scale = scale
self.ratio = ratio
self.divisible_by_patch = divisible_by_patch
self.max_ratio = max_ratio
self.final_scale_range = final_scale_range
self.attempts = attempts
if isinstance(interpolation, str):
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = str_to_interp_mode(interpolation)
else:
self.interpolation = interpolation
# Validate scale and ratio
self.scale = _validate_range(self.scale, "scale")
self.ratio = _validate_range(self.ratio, "ratio")
# Validate final_scale_range if provided
if self.final_scale_range is not None:
self.final_scale_range = _validate_range(self.final_scale_range, "final_scale_range")
# Additional validation for final_scale_range values
if not (0.0 <= self.final_scale_range[0] <= self.final_scale_range[1] <= 1.0):
warnings.warn("final_scale_range values should ideally be between 0.0 and 1.0.")
@staticmethod
def get_params(
            img: Union[torch.Tensor, Image.Image],
scale: Tuple[float, float],
ratio: Tuple[float, float],
crop_attempts: int = 10,
patch_h: int = 16,
patch_w: int = 16,
max_seq_len: int = 1024,
divisible_by_patch: bool = True,
max_ratio: Optional[float] = None,
final_scale_range: Optional[Tuple[float, float]] = None,
interpolation: Union[List[InterpolationMode], InterpolationMode] = _RANDOM_INTERPOLATION,
) -> Tuple[Tuple[int, int, int, int], Tuple[int, int], InterpolationMode]:
""" Get parameters for a random sized crop relative to image aspect ratio.
"""
_, height, width = F.get_dimensions(img)
if height <= 0 or width <= 0:
raise ValueError(f"Input image must have positive dimensions, got H={height}, W={width}")
area = height * width
orig_aspect = width / height
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
for _ in range(crop_attempts):
target_area = area * random.uniform(scale[0], scale[1])
aspect_ratio_factor = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
aspect_ratio = orig_aspect * aspect_ratio_factor
# Calculate target dimensions for the crop
# target_area = crop_w * crop_h, aspect_ratio = crop_w / crop_h
# => crop_h = sqrt(target_area / aspect_ratio)
# => crop_w = sqrt(target_area * aspect_ratio)
crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
if 0 < crop_w <= width and 0 < crop_h <= height:
top = random.randint(0, height - crop_h)
left = random.randint(0, width - crop_w)
break
else:
# Fallback strategy, use center crop trying to respect ratio range
min_aspect_ratio = orig_aspect * ratio[0]
max_aspect_ratio = orig_aspect * ratio[1]
if orig_aspect < min_aspect_ratio:
# Original is narrower than target min, clamp width
crop_w = width
crop_h = min(int(round(crop_w / min_aspect_ratio)), height)
elif orig_aspect > max_aspect_ratio:
# Original is wider than target max, clamp height
crop_h = height
crop_w = min(int(round(crop_h * max_aspect_ratio)), width)
else:
# Aspect ratio is within range, take the largest possible crop (full image)
crop_w = width
crop_h = height
# Ensure valid dimensions after fallback calculation
crop_h = max(1, crop_h)
crop_w = max(1, crop_w)
top = (height - crop_h) // 2
left = (width - crop_w) // 2
# Determine max feasible size for scaling of the *cropped* region
feasible_ratio, feasible_size = get_image_size_for_seq(
(crop_h, crop_w),
patch_size=(patch_h, patch_w), # Pass as tuple
max_seq_len=max_seq_len,
divisible_by_patch=divisible_by_patch,
max_ratio=max_ratio,
)
# Optionally apply final scale randomization
final_size = feasible_size
if final_scale_range is not None:
min_sc, max_sc = final_scale_range
scale_factor = random.uniform(min_sc, max_sc)
scale_factor = min(max(scale_factor, 0.0), 1.0) # Clamp factor just in case
# Calculate raw scaled size
# Note: feasible_ratio already accounts for max_ratio clamp if any
raw_h = crop_h * feasible_ratio * scale_factor
raw_w = crop_w * feasible_ratio * scale_factor
# Re-apply divisibility constraint if needed
if divisible_by_patch:
# Use ceil to avoid going under minimum patch size
target_h = patch_h * math.ceil(raw_h / patch_h)
target_w = patch_w * math.ceil(raw_w / patch_w)
else:
target_h = int(round(raw_h))
target_w = int(round(raw_w))
# Ensure final size is at least one patch dimension
target_h = max(target_h, patch_h)
target_w = max(target_w, patch_w)
final_size = (target_h, target_w)
# Final check: Ensure this randomized size still fits max_seq_len
# (It should, as we scaled down, but rounding might theoretically push it over)
num_patches_h = final_size[0] // patch_h
num_patches_w = final_size[1] // patch_w
if (num_patches_h * num_patches_w) > max_seq_len:
# If it exceeds, revert to the original feasible_size (safest)
final_size = feasible_size
warnings.warn(f"Final scale randomization ({scale_factor:.2f}) resulted in size {final_size} exceeding max_seq_len={max_seq_len} after rounding. Reverting to feasible size {feasible_size}.")
# Select interpolation mode
if isinstance(interpolation, (tuple, list)):
interpolation = random.choice(interpolation)
else:
interpolation = interpolation
return (top, left, crop_h, crop_w), final_size, interpolation
    def forward(self, img: Union[torch.Tensor, Image.Image]) -> torch.Tensor:
# Sample crop, resize, and interpolation parameters
crop_params, final_size, interpolation = self.get_params(
img,
scale=self.scale,
ratio=self.ratio,
crop_attempts=self.attempts,
patch_h=self.patch_h,
patch_w=self.patch_w,
divisible_by_patch=self.divisible_by_patch,
max_seq_len=self.max_seq_len,
final_scale_range=self.final_scale_range,
interpolation=self.interpolation,
)
top, left, crop_h, crop_w = crop_params
output = F.resized_crop(
img,
top=top,
left=left,
height=crop_h,
width=crop_w,
size=final_size,
interpolation=interpolation,
antialias=True,
)
return output
def __repr__(self) -> str:
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ', '.join(str(m).split('.')[-1] for m in self.interpolation)
else:
interpolate_str = str(self.interpolation)
format_string = self.__class__.__name__ + '('
format_string += f"patch_size=({self.patch_h}, {self.patch_w})"
format_string += f", max_seq_len={self.max_seq_len}"
format_string += f", scale={self.scale}"
format_string += f", ratio={self.ratio}"
format_string += f", interpolation=[{interpolate_str}]"
format_string += f", divisible_by_patch={self.divisible_by_patch}"
format_string += f", max_ratio={self.max_ratio}"
format_string += f", final_scale_range={self.final_scale_range}"
format_string += f", attempts={self.attempts}"
format_string += ')'
return format_string
def patchify(
img: torch.Tensor,
patch_size: Tuple[int, int],
pad: bool = True,
include_info: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
c, h, w = img.shape
ph, pw = patch_size
# Ensure the image is divisible by patch size
if pad and (h % ph != 0 or w % pw != 0):
new_h = math.ceil(h / ph) * ph
new_w = math.ceil(w / pw) * pw
padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype)
padded_img[:, :h, :w] = img
img = padded_img
c, h, w = img.shape
# Calculate number of patches in each dimension
nh, nw = h // ph, w // pw
# Reshape image to patches [nh, nw, ph, pw, c]
patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
if include_info:
# Create coordinate indices
y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
# Stack into a single coords tensor [N, 2] with (y, x) order
coord = torch.stack([y_idx.reshape(-1), x_idx.reshape(-1)], dim=1)
# Create type indicators (all 1s for regular patches)
valid = torch.ones(nh * nw, dtype=torch.bool)
return patches, coord, valid
return patches
class Patchify(torch.nn.Module):
"""Transform an image into patches with corresponding coordinates and type indicators."""
def __init__(self, patch_size):
super().__init__()
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
def forward(self, img):
"""
Args:
img: A PIL Image or tensor of shape [C, H, W]
Returns:
A dictionary containing:
- patches: Tensor of shape [N, P*P*C] where N is the number of patches
- patch_coord: Tensor of shape [N, 2] with (y, x) coordinates
- patch_valid: Valid indicator (all 1s for non-padding patches)
"""
if isinstance(img, Image.Image):
# Convert PIL Image to tensor [C, H, W]
img = transforms.functional.to_tensor(img)
patches, coord, valid = patchify(img, self.patch_size)
return {
'patches': patches,
'patch_coord': coord,
'patch_valid': valid,
}

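The two pieces above do the geometric work for NaFlex: get_image_size_for_seq binary-searches the largest scale whose patch grid fits the token budget, and Patchify turns the resized image into (patches, patch_coord, patch_valid). A short sketch with illustrative numbers (Image.new stands in for a real image):

import torch
from PIL import Image
from timm.data.naflex_transforms import get_image_size_for_seq, ResizeToSequence, Patchify

# largest (h, w) for a 480x640 image whose 16x16 patch grid stays within 256 tokens
ratio, target_hw = get_image_size_for_seq((480, 640), patch_size=16, max_seq_len=256)
print(ratio, target_hw)  # approx 0.45, (224, 288) -> 14 * 18 = 252 patches

img = Image.new('RGB', (640, 480))
resized = ResizeToSequence(patch_size=16, max_seq_len=256)(img)
patched = Patchify(patch_size=16)(resized)
print(patched['patches'].shape)      # [252, 16*16*3]
print(patched['patch_coord'][:3])    # (y, x) grid coordinates per patch
print(patched['patch_valid'].all())  # all True; padding only appears after collation
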
timm/data/transforms_factory.py

@ -12,7 +12,8 @@ from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \ from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, MaybeToTensor, MaybePILToTensor
from timm.data.naflex_transforms import RandomResizedCropToSequence, ResizeToSequence, Patchify
from timm.data.random_erasing import RandomErasing from timm.data.random_erasing import RandomErasing
@ -46,7 +47,7 @@ def transforms_noaug_train(
] ]
if use_prefetcher: if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm # prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()] tfl += [MaybePILToTensor()]
elif not normalize: elif not normalize:
# when normalize disabled, converted to tensor without scaling, keep original dtype # when normalize disabled, converted to tensor without scaling, keep original dtype
tfl += [MaybePILToTensor()] tfl += [MaybePILToTensor()]
@ -84,6 +85,10 @@ def transforms_imagenet_train(
use_prefetcher: bool = False, use_prefetcher: bool = False,
normalize: bool = True, normalize: bool = True,
separate: bool = False, separate: bool = False,
naflex: bool = False,
patch_size: Union[int, Tuple[int, int]] = 16,
max_seq_len: int = 576, # 24x24 for 16x16 patch
patchify: bool = False,
): ):
""" ImageNet-oriented image transforms for training. """ ImageNet-oriented image transforms for training.
@ -111,6 +116,9 @@ def transforms_imagenet_train(
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
separate: Output transforms in 3-stage tuple. separate: Output transforms in 3-stage tuple.
naflex: Enable NaFlex mode, sequence constrained patch output
patch_size: Patch size for NaFlex mode.
max_seq_len: Max sequence length for NaFlex mode.
Returns: Returns:
If separate==True, the transforms are returned as a tuple of 3 separate transforms If separate==True, the transforms are returned as a tuple of 3 separate transforms
@ -121,35 +129,45 @@ def transforms_imagenet_train(
""" """
train_crop_mode = train_crop_mode or 'rrc' train_crop_mode = train_crop_mode or 'rrc'
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'} assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
if train_crop_mode in ('rkrc', 'rkrr'):
# FIXME integration of RKR is a WIP primary_tfl = []
scale = tuple(scale or (0.8, 1.00)) if naflex:
ratio = tuple(ratio or (0.9, 1/.9)) primary_tfl += [RandomResizedCropToSequence(
primary_tfl = [ patch_size=patch_size,
ResizeKeepRatio( max_seq_len=max_seq_len,
img_size, interpolation=interpolation
interpolation=interpolation, )]
random_scale_prob=0.5,
random_scale_range=scale,
random_scale_area=True, # scale compatible with RRC
random_aspect_prob=0.5,
random_aspect_range=ratio,
),
CenterCropOrPad(img_size, padding_mode='reflect')
if train_crop_mode == 'rkrc' else
RandomCropOrPad(img_size, padding_mode='reflect')
]
else: else:
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range if train_crop_mode in ('rkrc', 'rkrr'):
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range # FIXME integration of RKR is a WIP
primary_tfl = [ scale = tuple(scale or (0.8, 1.00))
RandomResizedCropAndInterpolation( ratio = tuple(ratio or (0.9, 1/.9))
img_size, primary_tfl += [
scale=scale, ResizeKeepRatio(
ratio=ratio, img_size,
interpolation=interpolation, interpolation=interpolation,
) random_scale_prob=0.5,
] random_scale_range=scale,
random_scale_area=True, # scale compatible with RRC
random_aspect_prob=0.5,
random_aspect_range=ratio,
),
CenterCropOrPad(img_size, padding_mode='reflect')
if train_crop_mode == 'rkrc' else
RandomCropOrPad(img_size, padding_mode='reflect')
]
else:
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
primary_tfl += [
RandomResizedCropAndInterpolation(
img_size,
scale=scale,
ratio=ratio,
interpolation=interpolation,
)
]
if hflip > 0.: if hflip > 0.:
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
if vflip > 0.: if vflip > 0.:
@ -215,7 +233,7 @@ def transforms_imagenet_train(
final_tfl = [] final_tfl = []
if use_prefetcher: if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm # prefetcher and collate will handle tensor conversion and norm
final_tfl += [ToNumpy()] final_tfl += [MaybePILToTensor()]
elif not normalize: elif not normalize:
# when normalize disabled, converted to tensor without scaling, keeps original dtype # when normalize disabled, converted to tensor without scaling, keeps original dtype
final_tfl += [MaybePILToTensor()] final_tfl += [MaybePILToTensor()]
@ -238,6 +256,9 @@ def transforms_imagenet_train(
) )
] ]
if patchify:
final_tfl += [Patchify(patch_size=patch_size)]
if separate: if separate:
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
else: else:
@ -254,6 +275,10 @@ def transforms_imagenet_eval(
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
use_prefetcher: bool = False, use_prefetcher: bool = False,
normalize: bool = True, normalize: bool = True,
naflex: bool = False,
patch_size: Union[int, Tuple[int, int]] = 16,
max_seq_len: int = 576, # 24x24 for 16x16 patch
patchify: bool = False,
): ):
""" ImageNet-oriented image transform for evaluation and inference. """ ImageNet-oriented image transform for evaluation and inference.
@ -267,6 +292,10 @@ def transforms_imagenet_eval(
std: Image normalization standard deviation. std: Image normalization standard deviation.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
naflex: Enable NaFlex mode, sequence-constrained patch output.
patch_size: Patch size for NaFlex mode.
max_seq_len: Max sequence length for NaFlex mode.
patchify: Patchify the output instead of relying on the prefetcher.
Returns: Returns:
Composed transform pipeline Composed transform pipeline
@ -285,37 +314,44 @@ def transforms_imagenet_eval(
if crop_border_pixels: if crop_border_pixels:
tfl += [TrimBorder(crop_border_pixels)] tfl += [TrimBorder(crop_border_pixels)]
if crop_mode == 'squash': if naflex:
# squash mode scales each edge to 1/pct of target, then crops tfl += [ResizeToSequence(
# aspect ratio is not preserved, no img lost if crop_pct == 1.0 patch_size=patch_size,
tfl += [ max_seq_len=max_seq_len,
transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)), interpolation=interpolation
transforms.CenterCrop(img_size), )]
]
elif crop_mode == 'border':
# scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
# no image lost if crop_pct == 1.0
fill = [round(255 * v) for v in mean]
tfl += [
ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
CenterCropOrPad(img_size, fill=fill),
]
else: else:
# default crop mode is center if crop_mode == 'squash':
# aspect ratio is preserved, crops center within image, no borders are added, image is lost # squash mode scales each edge to 1/pct of target, then crops
if scale_size[0] == scale_size[1]: # aspect ratio is not preserved, no img lost if crop_pct == 1.0
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
tfl += [ tfl += [
transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation)) transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
transforms.CenterCrop(img_size),
]
elif crop_mode == 'border':
# scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
# no image lost if crop_pct == 1.0
fill = [round(255 * v) for v in mean]
tfl += [
ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
CenterCropOrPad(img_size, fill=fill),
] ]
else: else:
# resize the shortest edge to matching target dim for non-square target # default crop mode is center
tfl += [ResizeKeepRatio(scale_size)] # aspect ratio is preserved, crops center within image, no borders are added, image is lost
tfl += [transforms.CenterCrop(img_size)] if scale_size[0] == scale_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
tfl += [
transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
]
else:
# resize the shortest edge to matching target dim for non-square target
tfl += [ResizeKeepRatio(scale_size)]
tfl += [transforms.CenterCrop(img_size)]
if use_prefetcher: if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm # prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()] tfl += [MaybePILToTensor()]
elif not normalize: elif not normalize:
# when normalize disabled, converted to tensor without scaling, keeps original dtype # when normalize disabled, converted to tensor without scaling, keeps original dtype
tfl += [MaybePILToTensor()] tfl += [MaybePILToTensor()]
@ -328,6 +364,9 @@ def transforms_imagenet_eval(
), ),
] ]
if patchify:
tfl += [Patchify(patch_size=patch_size)]
return transforms.Compose(tfl) return transforms.Compose(tfl)
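A hedged usage sketch for the NaFlex eval path above (not part of this diff; the exact structure of the patchified output is an assumption):

    from PIL import Image
    from timm.data.transforms_factory import transforms_imagenet_eval

    tf = transforms_imagenet_eval(
        img_size=224,        # still accepted; in NaFlex mode sizing is driven by the sequence budget
        naflex=True,
        patch_size=16,
        max_seq_len=576,     # 24x24 patches of 16x16 pixels
        patchify=True,       # emit patches + coordinates instead of a dense image tensor
    )
    sample = tf(Image.new('RGB', (640, 480)))
    # Expected (assumed) output: a dict-like sample with per-patch pixel values,
    # e.g. shape (seq_len, 16*16*3), plus integer patch coordinates of shape (seq_len, 2).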
@ -359,6 +398,10 @@ def create_transform(
use_prefetcher: bool = False, use_prefetcher: bool = False,
normalize: bool = True, normalize: bool = True,
separate: bool = False, separate: bool = False,
naflex: bool = False,
patch_size: Union[int, Tuple[int, int]] = 16,
max_seq_len: int = 576, # 24x24 for 16x16 patch
patchify: bool = False
): ):
""" """
@ -442,6 +485,10 @@ def create_transform(
use_prefetcher=use_prefetcher, use_prefetcher=use_prefetcher,
normalize=normalize, normalize=normalize,
separate=separate, separate=separate,
naflex=naflex,
patch_size=patch_size,
max_seq_len=max_seq_len,
patchify=patchify,
) )
else: else:
assert not separate, "Separate transforms not supported for validation preprocessing" assert not separate, "Separate transforms not supported for validation preprocessing"
@ -455,6 +502,10 @@ def create_transform(
crop_border_pixels=crop_border_pixels, crop_border_pixels=crop_border_pixels,
use_prefetcher=use_prefetcher, use_prefetcher=use_prefetcher,
normalize=normalize, normalize=normalize,
naflex=naflex,
patch_size=patch_size,
max_seq_len=max_seq_len,
patchify=patchify,
) )
return transform return transform
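And the same flags routed through the create_transform factory (a minimal sketch; the argument names come from this diff, the comments describe assumed behaviour):

    from timm.data import create_transform

    # Training: RandomResizedCropToSequence + patchification, capped at 1024 tokens
    train_tf = create_transform(
        input_size=(3, 384, 384),
        is_training=True,
        naflex=True,
        patch_size=16,
        max_seq_len=1024,
        patchify=True,
    )

    # Evaluation: ResizeToSequence to a fixed 576-token budget
    eval_tf = create_transform(
        input_size=(3, 384, 384),
        is_training=False,
        naflex=True,
        patch_size=16,
        max_seq_len=576,
        patchify=True,
    )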


@ -1,6 +1,7 @@
from .activations import * from .activations import *
from .adaptive_avgmax_pool import \ from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention import Attention
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding

timm/layers/attention.py (new file, 66 lines)

@ -0,0 +1,66 @@
from typing import Final, Type, Optional
import torch
from torch import nn as nn
from torch.nn import functional as F
from .config import use_fused_attn
class Attention(nn.Module):
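    """Multi-head self-attention with optional QK norm and fused SDPA.

    Accepts an optional additive attention mask so padded tokens can be excluded,
    e.g. for NaFlex variable sequence-length batches.
    """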
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
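For the new layer above, a small standalone sketch with a padding-style mask (not part of this diff; the additive mask construction is one plausible convention that works for both the fused and non-fused paths):

    import torch
    from timm.layers import Attention

    attn = Attention(dim=384, num_heads=6, qk_norm=True)
    x = torch.randn(2, 196, 384)                      # (B, N, C)

    # Mask out the last 60 tokens of the second sample (e.g. padding patches)
    valid = torch.ones(2, 196, dtype=torch.bool)
    valid[1, -60:] = False
    attn_mask = torch.zeros(2, 1, 1, 196)             # broadcasts over heads and query positions
    attn_mask.masked_fill_(~valid[:, None, None, :], float('-inf'))

    y = attn(x, attn_mask=attn_mask)                  # (2, 196, 384)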


@ -75,7 +75,7 @@ class AttentionPoolLatent(nn.Module):
trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5) trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)
def forward(self, x): def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
B, N, C = x.shape B, N, C = x.shape
if self.pos_embed is not None: if self.pos_embed is not None:
@ -91,10 +91,12 @@ class AttentionPoolLatent(nn.Module):
q, k = self.q_norm(q), self.k_norm(k) q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn: if self.fused_attn:
x = F.scaled_dot_product_attention(q, k, v) x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
else: else:
q = q * self.scale q = q * self.scale
attn = q @ k.transpose(-2, -1) attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn = attn + attn_mask
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
x = attn @ v x = attn @ v
x = x.transpose(1, 2).reshape(B, self.latent_len, C) x = x.transpose(1, 2).reshape(B, self.latent_len, C)


@ -72,6 +72,7 @@ from .vgg import *
from .visformer import * from .visformer import *
from .vision_transformer import * from .vision_transformer import *
from .vision_transformer_hybrid import * from .vision_transformer_hybrid import *
from .vision_transformer_flex import *
from .vision_transformer_relpos import * from .vision_transformer_relpos import *
from .vision_transformer_sam import * from .vision_transformer_sam import *
from .vitamin import * from .vitamin import *


@ -41,8 +41,8 @@ from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \ from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
get_act_layer, get_norm_layer, LayerType get_act_layer, get_norm_layer, LayerType
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
@ -55,58 +55,6 @@ __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
class Attention(nn.Module):
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LayerScale(nn.Module): class LayerScale(nn.Module):
def __init__( def __init__(
self, self,
@ -165,8 +113,8 @@ class Block(nn.Module):
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x return x
@ -222,8 +170,8 @@ class ResPostBlock(nn.Module):
nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm1.weight, self.init_values)
nn.init.constant_(self.norm2.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + self.drop_path1(self.norm1(self.attn(x))) x = x + self.drop_path1(self.norm1(self.attn(x, attn_mask=attn_mask)))
x = x + self.drop_path2(self.norm2(self.mlp(x))) x = x + self.drop_path2(self.norm2(self.mlp(x)))
return x return x
@ -282,7 +230,7 @@ class ParallelScalingBlock(nn.Module):
self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, N, C = x.shape B, N, C = x.shape
# Combined MLP fc1 & qkv projections # Combined MLP fc1 & qkv projections
@ -302,14 +250,18 @@ class ParallelScalingBlock(nn.Module):
if self.fused_attn: if self.fused_attn:
x_attn = F.scaled_dot_product_attention( x_attn = F.scaled_dot_product_attention(
q, k, v, q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0., dropout_p=self.attn_drop.p if self.training else 0.,
) )
else: else:
q = q * self.scale q = q * self.scale
attn = q @ k.transpose(-2, -1) attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn = attn + attn_mask
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn) attn = self.attn_drop(attn)
x_attn = attn @ v x_attn = attn @ v
x_attn = x_attn.transpose(1, 2).reshape(B, N, C) x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
x_attn = self.attn_out_proj(x_attn) x_attn = self.attn_out_proj(x_attn)
@ -379,23 +331,11 @@ class ParallelThingsBlock(nn.Module):
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
]))) ])))
def _forward_jit(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) x = x + torch.stack([attn(x, attn_mask=attn_mask) for attn in self.attns]).sum(dim=0)
x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
return x return x
@torch.jit.ignore
def _forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + sum(attn(x) for attn in self.attns)
x = x + sum(ffn(x) for ffn in self.ffns)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing():
return self._forward_jit(x)
else:
return self._forward(x)
def global_pool_nlc( def global_pool_nlc(
x: torch.Tensor, x: torch.Tensor,
@ -728,7 +668,9 @@ class VisionTransformer(nn.Module):
stop_early: bool = False, stop_early: bool = False,
output_fmt: str = 'NCHW', output_fmt: str = 'NCHW',
intermediates_only: bool = False, intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: output_dict: bool = False,
attn_mask: Optional[torch.Tensor] = None,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
""" Forward features that returns intermediates. """ Forward features that returns intermediates.
Args: Args:
@ -739,8 +681,11 @@ class VisionTransformer(nn.Module):
stop_early: Stop iterating over blocks when last desired intermediate hit stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features intermediates_only: Only return intermediate features
output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
attn_mask: Optional attention mask applied within each block (e.g., to mask NaFlex padding tokens)
Returns: Returns:
A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
""" """
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW' reshape = output_fmt == 'NCHW'
@ -759,7 +704,7 @@ class VisionTransformer(nn.Module):
else: else:
blocks = self.blocks[:max_index + 1] blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks): for i, blk in enumerate(blocks):
x = blk(x) x = blk(x, attn_mask=attn_mask)
if i in take_indices: if i in take_indices:
# normalize intermediates with final norm layer if enabled # normalize intermediates with final norm layer if enabled
intermediates.append(self.norm(x) if norm else x) intermediates.append(self.norm(x) if norm else x)
@ -776,6 +721,23 @@ class VisionTransformer(nn.Module):
# reshape to BCHW output format # reshape to BCHW output format
H, W = self.patch_embed.dynamic_feat_size((height, width)) H, W = self.patch_embed.dynamic_feat_size((height, width))
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
# For dictionary output, handle prefix tokens separately
if output_dict:
result_dict = {}
# Intermediates are always included
result_dict['image_intermediates'] = intermediates
if prefix_tokens is not None and return_prefix_tokens:
result_dict['image_intermediates_prefix'] = prefix_tokens
# Only include features if not intermediates_only
if not intermediates_only:
x_final = self.norm(x)
result_dict['image_features'] = x_final
return result_dict
# For non-dictionary output, maintain the original behavior
if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
# return_prefix not support in torchscript due to poor type handling # return_prefix not support in torchscript due to poor type handling
intermediates = list(zip(intermediates, prefix_tokens)) intermediates = list(zip(intermediates, prefix_tokens))
@ -811,6 +773,7 @@ class VisionTransformer(nn.Module):
reshape: bool = False, reshape: bool = False,
return_prefix_tokens: bool = False, return_prefix_tokens: bool = False,
norm: bool = False, norm: bool = False,
attn_mask: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
""" Intermediate layer accessor inspired by DINO / DINOv2 interface. """ Intermediate layer accessor inspired by DINO / DINOv2 interface.
NOTE: This API is for backwards compat, favour using forward_intermediates() directly. NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
@ -821,17 +784,24 @@ class VisionTransformer(nn.Module):
norm=norm, norm=norm,
output_fmt='NCHW' if reshape else 'NLC', output_fmt='NCHW' if reshape else 'NLC',
intermediates_only=True, intermediates_only=True,
attn_mask=attn_mask,
) )
def forward_features(self, x: torch.Tensor) -> torch.Tensor: def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.patch_embed(x) x = self.patch_embed(x)
x = self._pos_embed(x) x = self._pos_embed(x)
x = self.patch_drop(x) x = self.patch_drop(x)
x = self.norm_pre(x) x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
if attn_mask is not None:
# blocks are an nn.Sequential and can't forward extra kwargs, so apply them one by one when a mask is provided
for blk in self.blocks:
x = blk(x, attn_mask=attn_mask)
elif self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x) x = checkpoint_seq(self.blocks, x)
else: else:
x = self.blocks(x) x = self.blocks(x)
x = self.norm(x) x = self.norm(x)
return x return x
@ -849,8 +819,8 @@ class VisionTransformer(nn.Module):
x = self.head_drop(x) x = self.head_drop(x)
return x if pre_logits else self.head(x) return x if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.forward_features(x) x = self.forward_features(x, attn_mask=attn_mask)
x = self.forward_head(x) x = self.forward_head(x)
return x return x
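To tie the mask plumbing together, a hedged end-to-end sketch (model name and shapes are illustrative; the all-zeros mask is a no-op stand-in for a real NaFlex padding mask):

    import torch
    import timm

    model = timm.create_model('vit_small_patch16_224', pretrained=False)
    model.eval()
    x = torch.randn(2, 3, 224, 224)

    # All-zeros additive mask == no masking; shape broadcasts over heads and query positions
    num_tokens = model.patch_embed.num_patches + model.num_prefix_tokens
    attn_mask = torch.zeros(2, 1, 1, num_tokens)

    feats = model.forward_features(x, attn_mask=attn_mask)   # (2, num_tokens, embed_dim)
    out = model.forward_intermediates(
        x,
        indices=[-2, -1],
        output_dict=True,
        attn_mask=attn_mask,
    )
    # out['image_features']: final normalized token sequence
    # out['image_intermediates']: list of NCHW feature maps from the selected blocks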

File diff suppressed because it is too large.

train.py (108 changed lines)

@ -396,6 +396,15 @@ group.add_argument('--wandb-tags', default=[], type=str, nargs='+',
group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID', group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID',
help='If resuming a run, the id of the run in wandb') help='If resuming a run, the id of the run in wandb')
# NaFlex scheduled loader arguments
group.add_argument('--naflex-loader', action='store_true', default=False,
help='Use NaFlex loader (requires a NaFlex-compatible model)')
group.add_argument('--naflex-train-seq-lens', type=int, nargs='+', default=[128, 256, 576, 784, 1024],
help='Sequence lengths to use for NaFlex loader')
group.add_argument('--naflex-max-seq-len', type=int, default=576,
help='Fixed maximum sequence length for NaFlex loader (validation)')
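For intuition on --naflex-train-seq-lens: a rough sketch of how a constant per-batch token budget trades sequence length against batch size; the exact pairing used by the scheduled loader and the divisor-of-8 rounding are assumptions here:

    # Assume the loader targets roughly base_batch_size * base_seq_len tokens per batch.
    base_batch_size = 256
    base_seq_len = 576
    tokens_per_batch = base_batch_size * base_seq_len   # 147456 tokens

    for seq_len in [128, 256, 576, 784, 1024]:
        # round down to a multiple of 8 so per-device batches stay nicely divisible (illustrative)
        batch_size = max(8, tokens_per_batch // seq_len // 8 * 8)
        print(f'seq_len={seq_len:4d} -> batch_size={batch_size:4d}')
    # seq_len= 128 -> batch_size=1152
    # seq_len= 256 -> batch_size= 576
    # seq_len= 576 -> batch_size= 256
    # seq_len= 784 -> batch_size= 184
    # seq_len=1024 -> batch_size= 144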
def _parse_args(): def _parse_args():
# Do we have a config file to parse? # Do we have a config file to parse?
@ -669,6 +678,7 @@ def main():
trust_remote_code=args.dataset_trust_remote_code, trust_remote_code=args.dataset_trust_remote_code,
) )
dataset_eval = None
if args.val_split: if args.val_split:
dataset_eval = create_dataset( dataset_eval = create_dataset(
args.dataset, args.dataset,
@ -690,6 +700,7 @@ def main():
mixup_fn = None mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active: if mixup_active:
assert not args.naflex_loader, "Mixup/Cutmix not currently supported for NaFlex loading."
mixup_args = dict( mixup_args = dict(
mixup_alpha=args.mixup, mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix, cutmix_alpha=args.cutmix,
@ -714,9 +725,19 @@ def main():
train_interpolation = args.train_interpolation train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation: if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation'] train_interpolation = data_config['interpolation']
loader_train = create_loader(
dataset_train, # Check if we should use the NaFlex scheduled loader
input_size=data_config['input_size'], common_loader_kwargs = dict(
mean=data_config['mean'],
std=data_config['std'],
pin_memory=args.pin_mem,
img_dtype=model_dtype or torch.float32,
device=device,
distributed=args.distributed,
use_prefetcher=args.prefetcher,
)
train_loader_kwargs = dict(
batch_size=args.batch_size, batch_size=args.batch_size,
is_training=True, is_training=True,
no_aug=args.no_aug, no_aug=args.no_aug,
@ -737,42 +758,70 @@ def main():
num_aug_repeats=args.aug_repeats, num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits, num_aug_splits=num_aug_splits,
interpolation=train_interpolation, interpolation=train_interpolation,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers, num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
img_dtype=model_dtype or torch.float32,
device=device,
use_prefetcher=args.prefetcher,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding, worker_seeding=args.worker_seeding,
) )
if args.naflex_loader:
from timm.data.naflex_loader import create_naflex_loader
if utils.is_primary(args):
_logger.info('Using NaFlex loader')
loader_train = create_naflex_loader(
dataset=dataset_train,
patch_size=16, # Could be derived from model config
train_seq_lens=args.naflex_train_seq_lens,
rank=args.rank,
world_size=args.world_size,
**common_loader_kwargs,
**train_loader_kwargs,
)
else:
# Use standard loader
loader_train = create_loader(
dataset_train,
input_size=data_config['input_size'],
collate_fn=collate_fn,
use_multi_epochs_loader=args.use_multi_epochs_loader,
**common_loader_kwargs,
**train_loader_kwargs,
)
loader_eval = None loader_eval = None
if args.val_split: if args.val_split:
assert dataset_eval is not None
eval_workers = args.workers eval_workers = args.workers
if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset): if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
eval_workers = min(2, args.workers) eval_workers = min(2, args.workers)
loader_eval = create_loader(
dataset_eval, eval_loader_kwargs = dict(
input_size=data_config['input_size'],
batch_size=args.validation_batch_size or args.batch_size, batch_size=args.validation_batch_size or args.batch_size,
is_training=False, is_training=False,
interpolation=data_config['interpolation'], interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=eval_workers, num_workers=eval_workers,
distributed=args.distributed,
crop_pct=data_config['crop_pct'], crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
img_dtype=model_dtype or torch.float32,
device=device,
use_prefetcher=args.prefetcher,
) )
if args.naflex_loader:
from timm.data.naflex_loader import create_naflex_loader
# Use a single fixed (maximum) sequence length for validation
loader_eval = create_naflex_loader(
dataset=dataset_eval,
patch_size=16, # Could be derived from model config
max_seq_len=args.naflex_max_seq_len,
**common_loader_kwargs,
**eval_loader_kwargs
)
else:
# Use standard loader
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
**common_loader_kwargs,
**eval_loader_kwargs,
)
# setup loss function # setup loss function
if args.jsd_loss: if args.jsd_loss:
assert num_aug_splits > 1 # JSD only valid with aug splits set assert num_aug_splits > 1 # JSD only valid with aug splits set
@ -1083,8 +1132,12 @@ def train_one_epoch(
loss = _forward() loss = _forward()
_backward(loss) _backward(loss)
losses_m.update(loss.item() * accum_steps, input.size(0)) if isinstance(input, dict):
update_sample_count += input.size(0) batch_size = input['patches'].shape[0]
else:
batch_size = input.shape[0]
losses_m.update(loss.item() * accum_steps, batch_size)
update_sample_count += batch_size
if not need_update: if not need_update:
data_start_time = time.time() data_start_time = time.time()
@ -1210,9 +1263,10 @@ def validate(
elif device.type == "npu": elif device.type == "npu":
torch.npu.synchronize() torch.npu.synchronize()
losses_m.update(reduced_loss.item(), input.size(0)) batch_size = output.shape[0]
top1_m.update(acc1.item(), output.size(0)) losses_m.update(reduced_loss.item(), batch_size)
top5_m.update(acc5.item(), output.size(0)) top1_m.update(acc1.item(), batch_size)
top5_m.update(acc5.item(), batch_size)
batch_time_m.update(time.time() - end) batch_time_m.update(time.time() - end)
end = time.time() end = time.time()