Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Initial NaFlex ViT model and training support

Commit: 0893f5d296
Parent: e44f14d7d2
timm/data/__init__.py

@@ -8,6 +8,13 @@ from .dataset_info import DatasetInfo, CustomDatasetInfo
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
+from .naflex_transforms import (
+    ResizeToSequence,
+    CenterCropToSequence,
+    RandomCropToSequence,
+    RandomResizedCropToSequence,
+    ResizeKeepRatioToSequence,
+)
 from .readers import create_reader
 from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
 from .real_labels import RealLabelsImagenet
timm/data/naflex_dataset.py (new file, 422 lines)

@@ -0,0 +1,422 @@
"""
Dynamic Sequence Length Datasets for Variable Resolution Image Processing

Implements a map-style dataset wrapper (VariableSeqMapWrapper) that yields batches
with variable sequence lengths; an iterable-dataset variant is a work in progress.

The wrapper supports:
- Pre-initialized transforms for efficiency
- Distributed training
- Multiple workers
- Variable batch sizes based on sequence length
"""

import math
import random
import warnings
from functools import partial
from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable

import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
from torchvision import transforms
from PIL import Image

from .naflex_transforms import Patchify, patchify

def calculate_batch_size(
        tokens_per_batch: int,
        seq_len: int,
        max_size: Optional[int] = None,
        divisor: int = 1,
        rounding: str = 'floor',
):
    """Calculate batch size based on sequence length with divisibility constraints."""
    # Calculate raw batch size based on sequence length
    raw_batch_size = tokens_per_batch / seq_len

    # Apply divisibility with the specified rounding method
    if divisor > 1:
        if rounding == 'floor':
            batch_size = math.floor(raw_batch_size / divisor) * divisor
        elif rounding == 'ceil':
            batch_size = math.ceil(raw_batch_size / divisor) * divisor
        else:  # 'round' is the default
            batch_size = round(raw_batch_size / divisor) * divisor
    else:
        # If no divisor specified, just use integer division
        batch_size = int(raw_batch_size)

    # Ensure batch size is valid
    batch_size = max(1, batch_size)  # At least 1

    if max_size is not None:
        batch_size = min(batch_size, max_size)

    return batch_size

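# Illustrative only (not part of the commit): with the training defaults used
# below (max_tokens_per_batch = 4096 * 4 = 16384, batch_divisor = 8), the
# schedule math works out as:
#
#   calculate_batch_size(16384, seq_len=1024, divisor=8)               -> 16   (16384 / 1024 = 16)
#   calculate_batch_size(16384, seq_len=576, divisor=8)                -> 24   (floor(28.4 / 8) * 8)
#   calculate_batch_size(16384, seq_len=128, divisor=8, max_size=100)  -> 100  (capped by max_size)
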
def _collate_batch(
        batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
        target_seq_len: int,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    """Collates processed samples into a batch, padding/truncating to target_seq_len."""
    batch_patch_data = [item[0] for item in batch_samples]
    batch_labels = [item[1] for item in batch_samples]

    if not batch_patch_data:
        return {}, torch.empty(0)

    batch_size = len(batch_patch_data)
    patch_dim = batch_patch_data[0]['patches'].shape[1]

    # Initialize tensors with target sequence length
    patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
    patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
    patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool)

    for i, data in enumerate(batch_patch_data):
        num_patches = data['patches'].shape[0]
        # Copy at most target_seq_len patches; any extras are truncated
        n_copy = min(num_patches, target_seq_len)

        patches_batch[i, :n_copy] = data['patches'][:n_copy]
        patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
        patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy]

    # Create the final input dict
    input_dict = {
        'patches': patches_batch,
        'patch_coord': patch_coord_batch,
        'patch_valid': patch_valid_batch,
        # Note: no explicit sequence-length entry. The padded length is
        # target_seq_len, the number of valid patches per sample varies, and
        # the 'patch_valid' mask is the reliable source of truth.
    }

    # Stack labels if they are tensors, convert numerical labels to a tensor,
    # otherwise fall back to returning the list
    try:
        if isinstance(batch_labels[0], torch.Tensor):
            labels_tensor = torch.stack(batch_labels)
        else:
            # Assumes numerical labels (e.g. int class indices); non-numerical
            # labels raise and fall through to the list fallback below.
            labels_tensor = torch.tensor(batch_labels)
    except Exception:
        # Fallback if stacking/conversion fails (e.g., different shapes or types)
        warnings.warn("Could not stack labels into a tensor. Returning list of labels.")
        labels_tensor = batch_labels

    return input_dict, labels_tensor

class VariableSeqMapWrapper(IterableDataset):
    """
    IterableDataset wrapper for a map-style base dataset.

    Yields batches with variable sequence lengths. It calculates a canonical
    batch schedule (sequence length, batch size pairs) once, based on the
    total dataset size (padded for distribution). Each epoch, it shuffles
    the *order* of this canonical schedule and the dataset indices.

    This ensures a consistent number of batches and samples per epoch
    across all ranks. Handles distributed training and multiple workers.
    """

    def __init__(
            self,
            base_dataset: Dataset,
            patch_size: Union[int, Tuple[int, int]] = 16,
            seq_lens: List[int] = (128, 256, 576, 784, 1024),
            max_tokens_per_batch: int = 4096 * 4,  # 16k tokens
            transform_factory: Optional[Callable] = None,
            seed: int = 42,
            shuffle: bool = True,
            distributed: bool = False,
            rank: int = 0,
            world_size: int = 1,
            epoch: int = 0,
            batch_divisor: int = 8,  # Ensure batch size is divisible by this
    ):
        super().__init__()
        if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'):
            raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)")

        self.base_dataset = base_dataset
        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
        self.seq_lens = sorted(set(seq_lens))  # Ensure unique and sorted
        self.max_tokens_per_batch = max_tokens_per_batch
        self.seed = seed
        self.shuffle = shuffle
        self.distributed = distributed
        self.rank = rank if distributed else 0
        self.world_size = world_size if distributed else 1
        self.epoch = epoch
        self.batch_divisor = batch_divisor

        # Pre-initialize transforms for each sequence length
        self.transforms: Dict[int, Optional[Callable]] = {}
        if transform_factory:
            for seq_len in self.seq_lens:
                self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
        else:
            for seq_len in self.seq_lens:
                self.transforms[seq_len] = None  # No transform

        self.patchifier = Patchify(self.patch_size)

        # --- Canonical Schedule Calculation (Done Once) ---
        self._canonical_batch_schedule: List[Tuple[int, int]] = []
        self._num_batches_per_rank: int = 0
        self._padded_samples_per_rank: int = 0
        self._create_canonical_schedule()  # Calculate schedule based on padded size

        # --- Per-Epoch State ---
        # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank
        self._epoch_batches: List[Tuple[int, List[int]]] = []
        self._prepare_epoch_batches(self.epoch)  # Set up for the initial epoch

    def _create_canonical_schedule(self):
        """
        Calculates the canonical batch schedule (seq_len, batch_size pairs)
        based on the dataset size, padded for distributed training.
        This schedule is the *same* for all ranks and ensures consistent
        epoch length. It is calculated once during initialization.
        """
        total_len = len(self.base_dataset)
        padded_total_len = total_len
        num_samples_per_rank = total_len

        if self.distributed and self.world_size > 1:
            # Calculate padding needed for even distribution
            if total_len % self.world_size != 0:
                pad_size = self.world_size - (total_len % self.world_size)
                padded_total_len += pad_size
                print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).")
            else:
                pad_size = 0

            if padded_total_len % self.world_size != 0:
                # This should not happen with the padding logic, but safeguard
                raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}")

            num_samples_per_rank = padded_total_len // self.world_size
        elif self.distributed and self.world_size <= 1:
            # Distributed flag set but world_size is 1, treat as non-distributed
            pass  # num_samples_per_rank remains total_len

        self._padded_samples_per_rank = num_samples_per_rank

        if num_samples_per_rank == 0:
            self._canonical_batch_schedule = []
            self._num_batches_per_rank = 0
            return

        # Use a fixed seed for generating the canonical schedule structure
        g = torch.Generator()
        g.manual_seed(self.seed)  # Use base seed, NOT epoch seed

        current_schedule: List[Tuple[int, int]] = []
        remaining_samples = num_samples_per_rank
        total_scheduled_samples = 0

        while remaining_samples > 0:
            # Sample sequence length deterministically based on the base seed
            seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item()
            seq_len = self.seq_lens[seq_idx]

            # Calculate batch size, capped at remaining_samples to avoid overshooting
            batch_size = calculate_batch_size(
                tokens_per_batch=self.max_tokens_per_batch,
                seq_len=seq_len,
                max_size=remaining_samples,
                divisor=self.batch_divisor,
                rounding='floor',
            )
            # Ensure batch size is positive and doesn't exceed remaining samples
            batch_size = max(1, batch_size)
            batch_size = min(batch_size, remaining_samples)

            if batch_size <= 0:
                warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.")
                break  # Avoid infinite loop if something goes wrong

            current_schedule.append((seq_len, batch_size))
            remaining_samples -= batch_size
            total_scheduled_samples += batch_size

        # Sanity check: ensure the schedule covers all samples for the rank
        if total_scheduled_samples != num_samples_per_rank:
            warnings.warn(
                f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, "
                f"but expected {num_samples_per_rank} samples per rank. "
                f"This might happen if batch_divisor constraints prevent utilizing all samples. "
                f"Check parameters. Remaining samples: {remaining_samples}"
            )
            # Adjusting here (e.g. appending a final small batch) could violate the
            # divisor constraint, so currently any shortfall means dropped samples.

        self._canonical_batch_schedule = current_schedule
        self._num_batches_per_rank = len(current_schedule)
        print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.")

    def _prepare_epoch_batches(self, epoch: int):
        """
        Prepares the batches for the current epoch by:
        1. Shuffling the full dataset indices (using epoch seed).
        2. Applying padding if in distributed mode.
        3. Selecting indices for the current rank.
        4. Shuffling the *order* of the canonical batch schedule (using epoch seed).
        5. Assigning the rank's indices to the shuffled batches.
        """
        g = torch.Generator()
        g.manual_seed(self.seed + epoch)  # Epoch-specific seed for shuffling

        # 1. Get shuffled global indices
        total_len = len(self.base_dataset)
        if self.shuffle:
            all_indices_shuffled = torch.randperm(total_len, generator=g).tolist()
        else:
            all_indices_shuffled = list(range(total_len))

        # 2. Apply padding for distributed mode
        indices_for_ranks = all_indices_shuffled
        if self.distributed and self.world_size > 1:
            padded_total_len = self._padded_samples_per_rank * self.world_size
            if padded_total_len > total_len:
                pad_size = padded_total_len - total_len
                # Repeat initial elements from the *shuffled* list for padding
                indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size]
            # Ensure length matches expectation
            if len(indices_for_ranks) != padded_total_len:
                raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}")

        # 3. Select indices for the current rank
        if self.distributed and self.world_size > 1:
            indices_this_rank = indices_for_ranks[self.rank::self.world_size]
        else:  # Non-distributed or world_size=1
            indices_this_rank = indices_for_ranks

        # Sanity check length
        if len(indices_this_rank) != self._padded_samples_per_rank:
            # This might happen if canonical schedule generation had warnings/issues
            warnings.warn(
                f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) "
                f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). "
                f"Epoch generation might be inconsistent."
            )
            # Proceed but truncate via min() so that having fewer indices than
            # expected cannot cause an IndexError later.
            effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank)
            indices_this_rank = indices_this_rank[:effective_samples_this_rank]
        else:
            effective_samples_this_rank = self._padded_samples_per_rank

        # 4. Shuffle the order of the canonical batch schedule for this epoch
        if self.shuffle:
            schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist()
            shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm]
        else:
            shuffled_schedule = list(self._canonical_batch_schedule)  # Keep original order

        # 5. Assign indices to the shuffled batches
        self._epoch_batches = []
        idx_pos = 0
        scheduled_samples_count = 0
        for seq_len, bs in shuffled_schedule:
            # Ensure we don't try to grab more indices than available for the rank
            actual_bs = min(bs, effective_samples_this_rank - idx_pos)
            if actual_bs <= 0:
                if scheduled_samples_count < effective_samples_this_rank:
                    # This indicates a mismatch between the schedule total and actual samples
                    warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. Check schedule generation.")
                break  # Stop if no more indices or batch size is zero

            batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs]
            self._epoch_batches.append((seq_len, batch_indices))
            idx_pos += actual_bs
            scheduled_samples_count += actual_bs

        # Final check
        if scheduled_samples_count != effective_samples_this_rank:
            warnings.warn(
                f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, "
                f"but expected {effective_samples_this_rank} effective samples this epoch. "
                f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
            )

    def set_epoch(self, epoch: int):
        """Updates the epoch, regenerating the epoch-specific batches."""
        # Only regenerate if the epoch actually changes
        if epoch != self.epoch:
            self.epoch = epoch
            self._prepare_epoch_batches(epoch)

    def __len__(self) -> int:
        """
        Returns the number of batches per rank for the current epoch.
        Note this is the total for the rank; when multiple workers are used,
        each worker iterates over a subset of these batches.
        """
        return self._num_batches_per_rank

    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        """
        Iterates through the pre-calculated batches for the current epoch,
        distributing them among workers.
        """
        worker_info = torch.utils.data.get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        worker_id = worker_info.id if worker_info else 0

        # Distribute pre-calculated batches among workers for this rank.
        # Each worker processes a slice of the batches prepared in _prepare_epoch_batches.
        batches_for_worker = self._epoch_batches[worker_id::num_workers]
        for seq_len, indices in batches_for_worker:
            if not indices:  # Skip if a batch ended up with no indices (shouldn't happen often)
                continue

            # Get the pre-initialized transform for this sequence length
            transform = self.transforms.get(seq_len)

            batch_samples = []
            for idx in indices:
                try:
                    # Get original image and label from the map-style dataset
                    img, label = self.base_dataset[idx]

                    # Apply transform if available; handle transforms that return None or fail
                    processed_img = transform(img) if transform else img
                    if processed_img is None:
                        warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
                        continue

                    # Apply patching
                    patch_data = self.patchifier(processed_img)
                    batch_samples.append((patch_data, label))

                except IndexError:
                    warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
                    continue
                except Exception as e:
                    # Log other potential errors during data loading/processing
                    warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
                    continue  # Skip problematic sample

            # Collate the processed samples into a batch. If every sample in a
            # batch failed, the empty batch is skipped rather than yielded.
            if batch_samples:
                yield _collate_batch(batch_samples, seq_len)
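
A minimal usage sketch of the wrapper (illustrative only, not part of this commit; `my_dataset` is a hypothetical map-style dataset returning (PIL image, label) pairs):

    from functools import partial
    import torch
    from timm.data import create_transform
    from timm.data.naflex_dataset import VariableSeqMapWrapper

    wrapped = VariableSeqMapWrapper(
        my_dataset,  # hypothetical map-style dataset
        patch_size=16,
        seq_lens=(128, 256, 576, 784, 1024),
        max_tokens_per_batch=4096 * 4,
        transform_factory=partial(create_transform, is_training=True, naflex=True),
        shuffle=True,
        seed=42,
    )
    # The wrapper yields pre-collated (input_dict, targets) batches, so pass
    # batch_size=None to the DataLoader.
    loader = torch.utils.data.DataLoader(wrapped, batch_size=None, num_workers=4)
    for input_dict, targets in loader:
        patches = input_dict['patches']      # [B, seq_len, 16 * 16 * 3]
        valid = input_dict['patch_valid']    # [B, seq_len] bool mask
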
timm/data/naflex_loader.py (new file, 341 lines)

@@ -0,0 +1,341 @@
import math
from contextlib import suppress
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import torch

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .loader import _worker_init
from .naflex_dataset import VariableSeqMapWrapper
from .transforms_factory import create_transform

class NaFlexCollator:
    """Custom collator for batching NaFlex-style variable-resolution images."""

    def __init__(
            self,
            patch_size=16,
            max_seq_len=None,
    ):
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (576 = 24 * 24)

    def __call__(self, batch):
        """
        Args:
            batch: List of tuples (patch_dict, target)

        Returns:
            A tuple of (input_dict, targets) where input_dict contains:
            - patches: Padded tensor of patches
            - patch_coord: Coordinates for each patch (y, x)
            - patch_valid: Valid indicators
        """
        assert isinstance(batch[0], tuple)
        batch_size = len(batch)

        # FIXME get seq len from sampler schedule, then resize to the final
        # size based on seq_len and patchify

        # Extract targets
        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)

        # Get patch dictionaries
        patch_dicts = [item[0] for item in batch]

        # If we have a maximum sequence length constraint, ensure we don't exceed it
        if self.max_seq_len is not None:
            max_patches = self.max_seq_len
        else:
            # Find the maximum number of patches in this batch
            max_patches = max(item['patches'].shape[0] for item in patch_dicts)

        # Get patch dimensionality
        patch_dim = patch_dicts[0]['patches'].shape[1]

        # Prepare tensors for the batch
        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)

        # Fill in the tensors
        for i, patch_dict in enumerate(patch_dicts):
            num_patches = min(patch_dict['patches'].shape[0], max_patches)

            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]

        return {
            'patches': patches,
            'patch_coord': patch_coord,
            'patch_valid': patch_valid,
            'seq_len': max_patches,
        }, targets

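# Shape sketch (illustrative, not part of the commit): for patch_size=16 and
# max_seq_len=576, a batch of three images patchified to 300, 576 and 640
# patches collates to
#   patches     [3, 576, 768]   (768 = 16 * 16 * 3; the 640-patch image is truncated)
#   patch_coord [3, 576, 2]
#   patch_valid [3, 576]        (300, 576 and 576 True entries respectively)
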
class NaFlexPrefetchLoader:
    """Data prefetcher for NaFlex format which normalizes patches."""

    def __init__(
            self,
            loader,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            img_dtype=torch.float32,
            device=torch.device('cuda'),
    ):
        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype or torch.float32

        # Create mean/std tensors for normalization (will be applied to patches)
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)

        # Check for CUDA/NPU availability
        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
        self.is_npu = device.type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available()

    def __iter__(self):
        first = True
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        elif self.is_npu:
            stream = torch.npu.Stream()
            stream_context = partial(torch.npu.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress

        for next_input_dict, next_target in self.loader:
            with stream_context():
                # Move all tensors in input_dict to device
                for k, v in next_input_dict.items():
                    if isinstance(v, torch.Tensor):
                        dtype = self.img_dtype if k == 'patches' else None
                        next_input_dict[k] = next_input_dict[k].to(
                            device=self.device,
                            non_blocking=True,
                            dtype=dtype,
                        )

                next_target = next_target.to(device=self.device, non_blocking=True)

                # Normalize patch values (patches are in format [B, N, P*P*C])
                batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
                # View as [B, N*P*P, C] so the (1, 1, 3) mean/std broadcast over channels
                patches = next_input_dict['patches'].view(batch_size, -1, 3)
                patches = patches.sub(self.mean).div(self.std)

                # Reshape back
                next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)

            if not first:
                yield input_dict, target
            else:
                first = False

            if stream is not None:
                if self.is_cuda:
                    torch.cuda.current_stream().wait_stream(stream)
                elif self.is_npu:
                    torch.npu.current_stream().wait_stream(stream)

            input_dict = next_input_dict
            target = next_target

        yield input_dict, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

def create_naflex_loader(
        dataset,
        patch_size: Union[Tuple[int, int], int] = 16,
        train_seq_lens: List[int] = (128, 256, 576, 784, 1024),  # Training sequence lengths
        max_seq_len: int = 576,  # Fixed sequence length for validation
        batch_size: int = 32,  # Used for max_seq_len and max(train_seq_lens)
        is_training: bool = False,

        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,

        num_workers: int = 4,
        distributed: bool = False,
        rank: int = 0,
        world_size: int = 1,
        seed: int = 42,
        epoch: int = 0,
        use_prefetcher: bool = True,
        pin_memory: bool = True,
        img_dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = torch.device('cuda'),
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
):
    """Create a NaFlex data loader: dynamic sequence length sampling for training,
    a fixed sequence length for validation."""

    if is_training:
        # For training, use the dynamic sequence length mechanism
        assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader'

        transform_factory = partial(
            create_transform,
            is_training=True,
            no_aug=no_aug,
            train_crop_mode=train_crop_mode,
            scale=scale,
            ratio=ratio,
            hflip=hflip,
            vflip=vflip,
            color_jitter=color_jitter,
            color_jitter_prob=color_jitter_prob,
            grayscale_prob=grayscale_prob,
            gaussian_blur_prob=gaussian_blur_prob,
            auto_augment=auto_augment,
            interpolation=interpolation,
            mean=mean,
            std=std,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count,
            use_prefetcher=use_prefetcher,
            naflex=True,
        )

        max_train_seq_len = max(train_seq_lens)
        max_tokens_per_batch = batch_size * max_train_seq_len

        if isinstance(dataset, torch.utils.data.IterableDataset):
            raise NotImplementedError("IterableDataset wrapper is a WIP")

        naflex_dataset = VariableSeqMapWrapper(
            dataset,
            transform_factory=transform_factory,
            patch_size=patch_size,
            seq_lens=train_seq_lens,
            max_tokens_per_batch=max_tokens_per_batch,
            seed=seed,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            shuffle=True,
            epoch=epoch,
        )

        # NOTE: Collation is handled by the dataset wrapper for training, so no
        # fixed-size NaFlexCollator is passed to the DataLoader here.
        # collate_fn = NaFlexCollator(
        #     patch_size=patch_size,
        #     max_seq_len=max(train_seq_lens) + 1,  # +1 for class token
        # )

        loader = torch.utils.data.DataLoader(
            naflex_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            sampler=None,
            # collate_fn=collate_fn,
            pin_memory=pin_memory,
            worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
            persistent_workers=persistent_workers,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    else:
        # For validation, use a fixed sequence length (unchanged)
        dataset.transform = create_transform(
            is_training=False,
            interpolation=interpolation,
            mean=mean,
            std=std,
            # FIXME add crop args when sequence transforms support crop modes
            use_prefetcher=use_prefetcher,
            naflex=True,
            patch_size=patch_size,
            max_seq_len=max_seq_len,
            patchify=True,
        )

        # Create the collator
        collate_fn = NaFlexCollator(
            patch_size=patch_size,
            max_seq_len=max_seq_len,
        )

        # Handle distributed evaluation
        sampler = None
        if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
            # For validation, use OrderedDistributedSampler
            from timm.data.distributed_sampler import OrderedDistributedSampler
            sampler = OrderedDistributedSampler(dataset)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=False,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    return loader
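
A minimal call sketch (illustrative, not part of the diff; `train_ds` is a hypothetical map-style dataset):

    from timm.data.naflex_loader import create_naflex_loader

    loader = create_naflex_loader(
        train_ds,  # hypothetical (image, label) map-style dataset
        patch_size=16,
        train_seq_lens=(128, 256, 576, 784, 1024),
        batch_size=32,  # token budget = batch_size * max(train_seq_lens)
        is_training=True,
        use_prefetcher=True,
    )
    for input_dict, targets in loader:
        # input_dict['patches'] arrives on device, already normalized by
        # NaFlexPrefetchLoader
        pass
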
timm/data/naflex_transforms.py (new file, 804 lines)

@@ -0,0 +1,804 @@
""" NaFlex (NaViT + FlexiViT) Transforms and Collation

Implements PyTorch versions of the transforms described in the NaViT and FlexiViT papers:
- NaViT: https://arxiv.org/abs/2307.06304
- FlexiViT: https://arxiv.org/abs/2212.08013

Enables variable resolution/aspect ratio image handling with efficient patching.
"""

import math
import random
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode

from .transforms import str_to_interp_mode, crop_or_pad, center_crop_or_pad

def get_image_size_for_seq(
        image_hw,
        patch_size=16,
        max_seq_len=1024,
        divisible_by_patch=True,
        max_ratio=None,
        eps=1e-5,
):
    """
    Determine the scaling ratio and image size so that when `image_hw` is scaled
    by `ratio`, the total number of resulting patches does not exceed
    `max_seq_len`.

    - Patch size can be an integer (square patch) or a tuple (patch_h, patch_w).
    - Optionally cap the ratio at `max_ratio` to prevent upsampling beyond
      a certain multiple of the original size.

    Args:
        image_hw (tuple or list of int): (height, width) of the original image.
        patch_size (int or tuple[int, int]): If int, patch is square. If tuple,
            patch is rectangular (patch_h, patch_w).
        max_seq_len (int): Maximum allowed sequence length for the resulting image.
        divisible_by_patch (bool): If True, the resulting image height and width
            must be multiples of patch_size.
        max_ratio (float or None): If provided, the scaling ratio found by the
            binary search will be clamped to min(found_ratio, max_ratio). Set
            max_ratio=1.0 to ensure no upsampling beyond the original size.
        eps (float): Small number for binary search convergence.

    Returns:
        ratio (float): Found scaling ratio (capped by `max_ratio` if provided).
        target_hw (tuple of int): Target (height, width) after scaling.
    """
    # Handle patch size input, extract patch_h, patch_w
    if isinstance(patch_size, int):
        patch_h, patch_w = patch_size, patch_size
    else:
        # Assume it's a tuple/list: (patch_h, patch_w)
        if len(patch_size) != 2:
            raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
        patch_h, patch_w = patch_size

    # Safety checks
    if patch_h <= 0 or patch_w <= 0:
        raise ValueError("patch_size dimensions must be positive.")

    def prepare_target_hw(ratio):
        """Scale image_hw by ratio and optionally round dimensions to multiples of patch_h, patch_w."""
        scaled_h = image_hw[0] * ratio
        scaled_w = image_hw[1] * ratio

        # If we need the result to be divisible by patch_size
        if divisible_by_patch:
            scaled_h = patch_h * math.ceil(scaled_h / patch_h)
            scaled_w = patch_w * math.ceil(scaled_w / patch_w)

        # Ensure at least one patch in each dimension
        scaled_h = int(max(scaled_h, patch_h))
        scaled_w = int(max(scaled_w, patch_w))

        return scaled_h, scaled_w

    def is_feasible(ratio):
        """Check if scaling by 'ratio' keeps the patch count within max_seq_len."""
        t_h, t_w = prepare_target_hw(ratio)

        # Each dimension is already a multiple of patch_h, patch_w if divisible_by_patch=True.
        # Use integer division to count patches.
        num_patches_h = t_h // patch_h
        num_patches_w = t_w // patch_w
        seq_len = num_patches_h * num_patches_w

        return seq_len <= max_seq_len

    # Binary search boundaries
    lb = eps / 10.0
    rb = 100.0

    # Standard binary search loop: find the largest feasible ratio
    while (rb - lb) >= eps:
        mid = (lb + rb) / 2.0
        if is_feasible(mid):
            lb = mid
        else:
            rb = mid

    # The final ratio from the binary search
    ratio = lb

    # If max_ratio is provided, clamp the ratio to prevent upsampling beyond that threshold
    if max_ratio is not None:
        ratio = min(ratio, max_ratio)

    # Final checks
    if ratio <= eps:
        raise ValueError("Binary search failed - image might be too large?")
    if ratio >= 100.0:
        raise ValueError("Binary search failed - image might be too small?")

    # Prepare the final target dimensions with the possibly clamped ratio
    target_hw = prepare_target_hw(ratio)
    return ratio, target_hw

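# Worked example (illustrative): a 480x640 image with patch_size=16 yields
# 30 * 40 = 1200 patches, exceeding max_seq_len=1024. The binary search settles
# near ratio ~0.9, giving a target of (432, 576), i.e. 27 * 36 = 972 patches <= 1024.
#   get_image_size_for_seq((480, 640), patch_size=16, max_seq_len=1024)
#   -> (~0.9, (432, 576))
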
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))


class ResizeToSequence(torch.nn.Module):
    """Resize image to fit within a maximum sequence length constraint when patchified.

    This maintains aspect ratio while ensuring the resulting image, when divided into patches,
    will not exceed the specified maximum sequence length.
    """

    def __init__(
            self,
            patch_size: int,
            max_seq_len: int = 1024,
            divisible_by_patch: bool = True,
            max_ratio: Optional[float] = None,
            interpolation='bicubic',
    ):
        super().__init__()
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.divisible_by_patch = divisible_by_patch
        self.max_ratio = max_ratio
        if isinstance(interpolation, str):
            if interpolation == 'random':
                self.interpolation = _RANDOM_INTERPOLATION
            else:
                self.interpolation = str_to_interp_mode(interpolation)
        else:
            self.interpolation = interpolation

    def forward(self, img):
        """Resize image to maintain aspect ratio and fit the sequence constraint."""
        _, h, w = transforms.functional.get_dimensions(img)

        _, target_hw = get_image_size_for_seq(
            (h, w),
            self.patch_size,
            self.max_seq_len,
            divisible_by_patch=self.divisible_by_patch,
            max_ratio=self.max_ratio,
        )

        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation

        resized_img = transforms.functional.resize(img, target_hw, interpolation=interpolation, antialias=True)

        return resized_img

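# Usage sketch (illustrative, not part of the commit): as a drop-in
# torchvision-style transform.
#   t = transforms.Compose([
#       ResizeToSequence(patch_size=16, max_seq_len=1024, interpolation='random'),
#       transforms.ToTensor(),
#   ])
#   out = t(pil_image)  # pil_image is a hypothetical PIL input
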
class ResizeKeepRatioToSequence(torch.nn.Module):
    """
    Resize and Keep Aspect Ratio, adapted to fit sequence length constraints.
    """

    def __init__(
            self,
            patch_size=16,
            max_sequence_len=1024,
            divisible_by_patch=True,
            longest=0.,
            interpolation='bilinear',
            random_scale_prob=0.,
            random_scale_range=(0.85, 1.05),
            random_scale_area=False,
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11),
            max_ratio=None,
    ):
        """
        Args:
            patch_size: Size of patches (int or tuple of (patch_h, patch_w))
            max_sequence_len: Maximum allowed sequence length for the resulting image
            divisible_by_patch: If True, ensure dimensions are divisible by patch_size
            longest: Float between 0-1 where 0=shortest side, 1=longest side determines scale
            interpolation: Interpolation method for resizing
            random_scale_prob: Probability of applying random scaling
            random_scale_range: Range for the random scaling factor (min, max)
            random_scale_area: If True, the scale factor is treated as an area change
                (sqrt applied to the linear factor)
            random_aspect_prob: Probability of applying random aspect ratio jittering
            random_aspect_range: Range for random aspect ratio (min, max)
            max_ratio: Maximum allowed scaling ratio
        """
        super().__init__()
        self.patch_size = patch_size
        self.max_sequence_len = max_sequence_len
        self.divisible_by_patch = divisible_by_patch
        self.longest = float(longest)

        if interpolation == 'random':
            self.interpolation = _RANDOM_INTERPOLATION
        else:
            self.interpolation = str_to_interp_mode(interpolation)

        self.random_scale_prob = random_scale_prob
        self.random_scale_range = random_scale_range
        self.random_scale_area = random_scale_area
        self.random_aspect_prob = random_aspect_prob
        self.random_aspect_range = random_aspect_range
        self.max_ratio = max_ratio

    @staticmethod
    def get_params(
            img,
            patch_size,
            max_sequence_len,
            divisible_by_patch,
            longest,
            random_scale_prob=0.,
            random_scale_range=(1.0, 1.33),
            random_scale_area=False,
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11),
            max_ratio=None,
    ):
        """Get parameters for resizing."""
        # Get image dimensions
        img_h, img_w = F.get_dimensions(img)[1:]

        # Step 1: Get the maximum allowed dimensions from the sequence length constraint
        _, target_hw = get_image_size_for_seq(
            (img_h, img_w),
            patch_size,
            max_sequence_len,
            divisible_by_patch,
            max_ratio,
        )
        target_h, target_w = target_hw

        # Calculate ratio based on the sequence constraint
        ratio_h = target_h / img_h
        ratio_w = target_w / img_w
        # Blend shortest/longest side ratios according to `longest`
        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)

        # Apply random scaling
        if random_scale_prob > 0 and random.random() < random_scale_prob:
            ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
            if random_scale_area:
                # Make the ratio factor equivalent to an area change
                ratio_factor = 1. / math.sqrt(ratio_factor)
            ratio_factor = (ratio_factor, ratio_factor)
        else:
            ratio_factor = (1., 1.)

        # Apply random aspect jitter
        if random_aspect_prob > 0 and random.random() < random_aspect_prob:
            log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
            aspect_factor = math.exp(random.uniform(*log_aspect))
            aspect_factor = math.sqrt(aspect_factor)
            # Apply aspect ratio jittering in opposite directions on h and w
            ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)

        # Calculate final dimensions
        size = [round(dim * ratio * f) for dim, f in zip((img_h, img_w), ratio_factor)]

        # Ensure dimensions satisfy the sequence constraint and are divisible by patch size
        if isinstance(patch_size, int):
            ph, pw = patch_size, patch_size
        else:
            ph, pw = patch_size

        # Ensure dimensions are at least one patch
        size[0] = max(size[0], ph)
        size[1] = max(size[1], pw)

        # Make divisible by patch size if needed
        if divisible_by_patch:
            size[0] = ph * math.ceil(size[0] / ph)
            size[1] = pw * math.ceil(size[1] / pw)

        # Verify we haven't exceeded the sequence length
        num_patches_h = size[0] // ph
        num_patches_w = size[1] // pw
        seq_len = num_patches_h * num_patches_w

        if seq_len > max_sequence_len:
            # Scale back down to fit the sequence constraint
            scale_back = math.sqrt(max_sequence_len / seq_len)
            size[0] = int(size[0] * scale_back)
            size[1] = int(size[1] * scale_back)

            # Ensure divisible by patch size after scaling back
            if divisible_by_patch:
                size[0] = ph * math.ceil(size[0] / ph)
                size[1] = pw * math.ceil(size[1] / pw)

        return size

    def forward(self, img):
        """
        Resize the image with aspect ratio preservation and sequence length constraints.
        """
        size = self.get_params(
            img,
            self.patch_size,
            self.max_sequence_len,
            self.divisible_by_patch,
            self.longest,
            self.random_scale_prob,
            self.random_scale_range,
            self.random_scale_area,
            self.random_aspect_prob,
            self.random_aspect_range,
            self.max_ratio,
        )

        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation

        return F.resize(img, size, interpolation)

    def __repr__(self):
        interpolate_str = "random" if isinstance(self.interpolation, (tuple, list)) else str(self.interpolation)
        return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
                f"max_sequence_len={self.max_sequence_len}, "
                f"interpolation={interpolate_str}, "
                f"longest={self.longest:.3f}, "
                f"random_scale_prob={self.random_scale_prob:.3f}, "
                f"random_aspect_prob={self.random_aspect_prob:.3f})")

class CenterCropToSequence(torch.nn.Module):
    """Center crop the image such that the resulting patch sequence length meets constraints."""

    def __init__(
            self,
            patch_size: int,
            max_seq_len: int,
            divisible_by_patch: bool = True,
            fill: Union[int, Tuple[int, int, int]] = 0,
            padding_mode: str = 'constant',
    ):
        super().__init__()
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.divisible_by_patch = divisible_by_patch
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """Center crop the image to maintain aspect ratio and fit the sequence constraint."""
        _, h, w = transforms.functional.get_dimensions(img)
        _, target_hw = get_image_size_for_seq(
            (h, w),
            self.patch_size,
            self.max_seq_len,
            self.divisible_by_patch,
        )

        # Use center crop
        return center_crop_or_pad(img, target_hw, fill=self.fill, padding_mode=self.padding_mode)

class RandomCropToSequence(torch.nn.Module):
    """Randomly crop and/or pad the image to fit sequence length constraints.

    This maintains aspect ratio while ensuring the resulting image, when divided into patches,
    will not exceed the specified maximum sequence length. Similar to CenterCropToSequence
    but with randomized positioning.
    """

    def __init__(
            self,
            patch_size: int,
            max_sequence_len: int,
            divisible_by_patch: bool = True,
            fill: Union[int, Tuple[int, int, int]] = 0,
            padding_mode: str = 'constant',
    ):
        """
        Args:
            patch_size: Size of patches (int or tuple of (patch_h, patch_w))
            max_sequence_len: Maximum allowed sequence length for the resulting image
            divisible_by_patch: If True, resulting image dimensions will be multiples of patch_size
            fill: Fill value for padding
            padding_mode: Padding mode ('constant', 'edge', 'reflect', 'symmetric')
        """
        super().__init__()
        self.patch_size = patch_size
        self.max_sequence_len = max_sequence_len
        self.divisible_by_patch = divisible_by_patch
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, target_size):
        """Get a random position for the crop/pad."""
        _, image_height, image_width = transforms.functional.get_dimensions(img)
        delta_height = image_height - target_size[0]
        delta_width = image_width - target_size[1]

        # Handle both positive (crop) and negative (pad) deltas
        if delta_height == 0:
            top = 0
        else:
            top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))

        if delta_width == 0:
            left = 0
        else:
            left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))

        return top, left

    def forward(self, img):
        """Randomly crop or pad the image to maintain aspect ratio and fit the sequence constraint."""
        # Get current dimensions
        _, img_h, img_w = transforms.functional.get_dimensions(img)

        # Calculate target dimensions that satisfy the sequence length.
        # max_ratio=1.0 prevents upscaling - we only want to crop or keep the current size.
        _, target_hw = get_image_size_for_seq(
            (img_h, img_w),
            self.patch_size,
            self.max_sequence_len,
            self.divisible_by_patch,
            max_ratio=1.0,  # Prevent upscaling
        )

        # Get random position for crop/pad
        top, left = self.get_params(img, target_hw)

        # Apply crop or pad
        return crop_or_pad(
            img,
            top=top,
            left=left,
            height=target_hw[0],
            width=target_hw[1],
            fill=self.fill,
            padding_mode=self.padding_mode,
        )

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
                f"max_sequence_len={self.max_sequence_len}, "
                f"divisible_by_patch={self.divisible_by_patch})")

def _validate_range(value, name, length=2):
    # Validate type and length
    if not isinstance(value, Sequence) or len(value) != length:
        raise ValueError(f"{name} should be a sequence of length {length}.")

    # Validate order
    if value[0] > value[1]:
        warnings.warn(f"{name.capitalize()} range reversed. Swapping.")
        return value[1], value[0]

    return value

class RandomResizedCropToSequence(torch.nn.Module):
    """
    Randomly crop the input image to a subregion with varying area and aspect ratio
    (relative to the original), then resize that crop to a target size. The target size
    is determined such that patchifying the resized image (with `patch_size`)
    does not exceed `max_seq_len` patches, while maintaining the aspect ratio of the crop.

    This combines aspects of torchvision's RandomResizedCrop with sequence length constraints.

    Args:
        patch_size (int or tuple[int, int]):
            Patch dimensions (patch_h, patch_w) for sequence length calculation.
        max_seq_len (int):
            Maximum number of patches allowed in the final image.
        scale (tuple[float, float]):
            Range (min, max) of area fraction of the original image to crop.
        ratio (tuple[float, float]):
            Range (min, max) of aspect ratio *multipliers* for the crop, relative
            to the original image's aspect ratio. E.g., (0.75, 1.333) means the
            crop's aspect ratio will be sampled between 0.75*orig_ar and 1.333*orig_ar.
            Uses log-uniform sampling.
        interpolation (str or InterpolationMode):
            Interpolation mode for resizing. Can be 'bilinear', 'bicubic', 'nearest',
            or 'random' (chooses between bilinear and bicubic).
            Defaults to 'bicubic'.
        divisible_by_patch (bool):
            If True, the final image height and width will be multiples of the
            respective patch dimensions. Defaults to True.
        max_ratio (float, optional):
            An optional upper limit on the scaling ratio applied during resizing.
            Prevents excessive upsampling of the initial crop. `max_ratio=1.0`
            prevents any upsampling beyond the cropped size. Defaults to None (no limit).
        final_scale_range (tuple[float, float], optional):
            If provided, applies an *additional* random scaling factor to the
            final target size. The factor is sampled uniformly from this range,
            and multiplied by the size determined by `get_image_size_for_seq`.
            E.g., (0.8, 1.0) means the final size will be between 80% and 100%
            of the maximum feasible size. Defaults to None (use maximum feasible size).
        attempts (int):
            Number of attempts to sample a valid crop geometry before falling back
            to a center crop strategy. Defaults to 10.
    """
    def __init__(
            self,
            patch_size: Union[int, Tuple[int, int]] = 16,
            max_seq_len: int = 1024,
            scale: Tuple[float, float] = (0.08, 1.0),
            ratio: Tuple[float, float] = (.8, 1.25),
            interpolation: Union[str, InterpolationMode] = 'bicubic',
            divisible_by_patch: bool = True,
            max_ratio: Optional[float] = None,
            final_scale_range: Optional[Tuple[float, float]] = None,
            attempts: int = 10,
    ):
        super().__init__()
        if isinstance(patch_size, int):
            self.patch_h, self.patch_w = patch_size, patch_size
        else:
            # Assume it's a tuple/list: (patch_h, patch_w)
            if len(patch_size) != 2:
                raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
            self.patch_h, self.patch_w = patch_size
        self.max_seq_len = max_seq_len
        self.scale = scale
        self.ratio = ratio
        self.divisible_by_patch = divisible_by_patch
        self.max_ratio = max_ratio
        self.final_scale_range = final_scale_range
        self.attempts = attempts
        if isinstance(interpolation, str):
            if interpolation == 'random':
                self.interpolation = _RANDOM_INTERPOLATION
            else:
                self.interpolation = str_to_interp_mode(interpolation)
        else:
            self.interpolation = interpolation

        # Validate scale and ratio
        self.scale = _validate_range(self.scale, "scale")
        self.ratio = _validate_range(self.ratio, "ratio")

        # Validate final_scale_range if provided
        if self.final_scale_range is not None:
            self.final_scale_range = _validate_range(self.final_scale_range, "final_scale_range")

            # Additional validation for final_scale_range values
            if not (0.0 <= self.final_scale_range[0] <= self.final_scale_range[1] <= 1.0):
                warnings.warn("final_scale_range values should ideally be between 0.0 and 1.0.")

    @staticmethod
    def get_params(
            img: Union[torch.Tensor, Image.Image],
            scale: Tuple[float, float],
            ratio: Tuple[float, float],
            crop_attempts: int = 10,
            patch_h: int = 16,
            patch_w: int = 16,
            max_seq_len: int = 1024,
            divisible_by_patch: bool = True,
            max_ratio: Optional[float] = None,
            final_scale_range: Optional[Tuple[float, float]] = None,
            interpolation: Union[List[InterpolationMode], InterpolationMode] = _RANDOM_INTERPOLATION,
    ) -> Tuple[Tuple[int, int, int, int], Tuple[int, int], InterpolationMode]:
        """Get parameters for a random sized crop relative to the image aspect ratio."""
        _, height, width = F.get_dimensions(img)
        if height <= 0 or width <= 0:
            raise ValueError(f"Input image must have positive dimensions, got H={height}, W={width}")

        area = height * width
        orig_aspect = width / height
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(crop_attempts):
            target_area = area * random.uniform(scale[0], scale[1])
            aspect_ratio_factor = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
            aspect_ratio = orig_aspect * aspect_ratio_factor

            # Calculate target dimensions for the crop:
            # target_area = crop_w * crop_h and aspect_ratio = crop_w / crop_h
            # => crop_h = sqrt(target_area / aspect_ratio)
            # => crop_w = sqrt(target_area * aspect_ratio)
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))

            if 0 < crop_w <= width and 0 < crop_h <= height:
                top = random.randint(0, height - crop_h)
                left = random.randint(0, width - crop_w)
                break
        else:
            # Fallback strategy, use center crop trying to respect ratio range
            min_aspect_ratio = orig_aspect * ratio[0]
            max_aspect_ratio = orig_aspect * ratio[1]

            if orig_aspect < min_aspect_ratio:
                # Original is narrower than target min, clamp width
                crop_w = width
                crop_h = min(int(round(crop_w / min_aspect_ratio)), height)
            elif orig_aspect > max_aspect_ratio:
                # Original is wider than target max, clamp height
                crop_h = height
                crop_w = min(int(round(crop_h * max_aspect_ratio)), width)
            else:
                # Aspect ratio is within range, take the largest possible crop (full image)
                crop_w = width
                crop_h = height

            # Ensure valid dimensions after fallback calculation
            crop_h = max(1, crop_h)
            crop_w = max(1, crop_w)

            top = (height - crop_h) // 2
            left = (width - crop_w) // 2

        # Determine max feasible size for scaling of the *cropped* region
        feasible_ratio, feasible_size = get_image_size_for_seq(
            (crop_h, crop_w),
            patch_size=(patch_h, patch_w),
            max_seq_len=max_seq_len,
            divisible_by_patch=divisible_by_patch,
            max_ratio=max_ratio,
        )

        # Optionally apply final scale randomization
        final_size = feasible_size
        if final_scale_range is not None:
            min_sc, max_sc = final_scale_range
            scale_factor = random.uniform(min_sc, max_sc)
            scale_factor = min(max(scale_factor, 0.0), 1.0)  # Clamp factor just in case

            # Calculate raw scaled size
            # Note: feasible_ratio already accounts for the max_ratio clamp, if any
            raw_h = crop_h * feasible_ratio * scale_factor
            raw_w = crop_w * feasible_ratio * scale_factor

            # Re-apply divisibility constraint if needed
            if divisible_by_patch:
                # Use ceil to avoid going under minimum patch size
                target_h = patch_h * math.ceil(raw_h / patch_h)
                target_w = patch_w * math.ceil(raw_w / patch_w)
            else:
                target_h = int(round(raw_h))
                target_w = int(round(raw_w))

            # Ensure final size is at least one patch in each dimension
            target_h = max(target_h, patch_h)
            target_w = max(target_w, patch_w)
            final_size = (target_h, target_w)

            # Final check: ensure the randomized size still fits max_seq_len
            # (it should, as we scaled down, but rounding might theoretically push it over)
            num_patches_h = final_size[0] // patch_h
            num_patches_w = final_size[1] // patch_w
            if (num_patches_h * num_patches_w) > max_seq_len:
                # If it exceeds, revert to the original feasible_size (safest)
                warnings.warn(
                    f"Final scale randomization ({scale_factor:.2f}) resulted in size {final_size} "
                    f"exceeding max_seq_len={max_seq_len} after rounding. "
                    f"Reverting to feasible size {feasible_size}.")
                final_size = feasible_size

        # Select interpolation mode
        if isinstance(interpolation, (tuple, list)):
            interpolation = random.choice(interpolation)

        return (top, left, crop_h, crop_w), final_size, interpolation

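    # Worked example of the crop geometry above (one illustrative draw on a 480x640 image):
    #   area = 480 * 640 = 307200, orig_aspect = 640 / 480 ≈ 1.333
    #   with a scale draw of 0.5 and a ratio-multiplier draw of 1.1:
    #     target_area  = 0.5 * 307200 = 153600
    #     aspect_ratio = 1.333 * 1.1 ≈ 1.467
    #     crop_h = round(sqrt(153600 / 1.467)) ≈ 324
    #     crop_w = round(sqrt(153600 * 1.467)) ≈ 475   (324 * 475 ≈ 153600, as intended)
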
    def forward(self, img: Union[torch.Tensor, Image.Image]) -> torch.Tensor:
        # Sample crop, resize, and interpolation parameters
        crop_params, final_size, interpolation = self.get_params(
            img,
            scale=self.scale,
            ratio=self.ratio,
            crop_attempts=self.attempts,
            patch_h=self.patch_h,
            patch_w=self.patch_w,
            divisible_by_patch=self.divisible_by_patch,
            max_seq_len=self.max_seq_len,
            max_ratio=self.max_ratio,
            final_scale_range=self.final_scale_range,
            interpolation=self.interpolation,
        )
        top, left, crop_h, crop_w = crop_params

        output = F.resized_crop(
            img,
            top=top,
            left=left,
            height=crop_h,
            width=crop_w,
            size=final_size,
            interpolation=interpolation,
            antialias=True,
        )

        return output

    def __repr__(self) -> str:
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ', '.join(str(m).split('.')[-1] for m in self.interpolation)
        else:
            interpolate_str = str(self.interpolation)
        format_string = self.__class__.__name__ + '('
        format_string += f"patch_size=({self.patch_h}, {self.patch_w})"
        format_string += f", max_seq_len={self.max_seq_len}"
        format_string += f", scale={self.scale}"
        format_string += f", ratio={self.ratio}"
        format_string += f", interpolation=[{interpolate_str}]"
        format_string += f", divisible_by_patch={self.divisible_by_patch}"
        format_string += f", max_ratio={self.max_ratio}"
        format_string += f", final_scale_range={self.final_scale_range}"
        format_string += f", attempts={self.attempts}"
        format_string += ')'
        return format_string

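# Minimal usage sketch for the transform above (illustrative, assumes PIL is available):
#
#   from PIL import Image
#   transform = RandomResizedCropToSequence(patch_size=16, max_seq_len=576)
#   out = transform(Image.new('RGB', (640, 480)))   # stand-in image
#   # `out` is a random crop resized so that division into 16x16 patches
#   # yields at most 576 patches (e.g. 384x384 -> 24 * 24 = 576).
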
def patchify(
        img: torch.Tensor,
        patch_size: Tuple[int, int],
        pad: bool = True,
        include_info: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    c, h, w = img.shape
    ph, pw = patch_size

    # Ensure the image is divisible by patch size
    if pad and (h % ph != 0 or w % pw != 0):
        new_h = math.ceil(h / ph) * ph
        new_w = math.ceil(w / pw) * pw
        padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype)
        padded_img[:, :h, :w] = img
        img = padded_img
        c, h, w = img.shape

    # Calculate number of patches in each dimension
    nh, nw = h // ph, w // pw
    # Reshape image into flattened patches: [nh * nw, ph * pw * c]
    patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)

    if include_info:
        # Create coordinate indices
        y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
        # Stack into a single coords tensor [N, 2] with (y, x) order
        coord = torch.stack([y_idx.reshape(-1), x_idx.reshape(-1)], dim=1)
        # Create validity indicators (all True for regular patches)
        valid = torch.ones(nh * nw, dtype=torch.bool)
        return patches, coord, valid

    return patches

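# Concretely, for an input that already divides evenly by the patch size:
#
#   x = torch.randn(3, 224, 224)                  # C, H, W
#   patches, coord, valid = patchify(x, (16, 16))
#   patches.shape  -> (196, 768)                  # 14*14 patches, each 16*16*3 values
#   coord.shape    -> (196, 2)                    # (y, x) grid indices, row-major
#   valid.all()    -> True                        # no padding patches for 224x224
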
class Patchify(torch.nn.Module):
    """Transform an image into patches with corresponding coordinates and validity indicators."""

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)

    def forward(self, img):
        """
        Args:
            img: A PIL Image or tensor of shape [C, H, W]

        Returns:
            A dictionary containing:
                - patches: Tensor of shape [N, P*P*C] where N is the number of patches
                - patch_coord: Tensor of shape [N, 2] with (y, x) coordinates
                - patch_valid: Valid indicator (all True for non-padding patches)
        """
        if isinstance(img, Image.Image):
            # Convert PIL Image to tensor [C, H, W]
            img = transforms.functional.to_tensor(img)

        patches, coord, valid = patchify(img, self.patch_size)

        return {
            'patches': patches,
            'patch_coord': coord,
            'patch_valid': valid,
        }
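A short usage sketch tying the module form together (illustrative only; the stand-in image and Compose wrapper are assumptions, not part of the commit):

    from PIL import Image
    from torchvision import transforms

    pipeline = transforms.Compose([Patchify(patch_size=16)])
    sample = pipeline(Image.new('RGB', (224, 224)))  # stand-in image
    # sample['patches']: [196, 768], sample['patch_coord']: [196, 2], sample['patch_valid']: [196]
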
@@ -12,7 +12,8 @@ from torchvision import transforms
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
 from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
-    ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
+    ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, MaybeToTensor, MaybePILToTensor
+from timm.data.naflex_transforms import RandomResizedCropToSequence, ResizeToSequence, Patchify
 from timm.data.random_erasing import RandomErasing
@@ -46,7 +47,7 @@ def transforms_noaug_train(
     ]
     if use_prefetcher:
         # prefetcher and collate will handle tensor conversion and norm
-        tfl += [ToNumpy()]
+        tfl += [MaybePILToTensor()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keep original dtype
         tfl += [MaybePILToTensor()]
@@ -84,6 +85,10 @@ def transforms_imagenet_train(
         use_prefetcher: bool = False,
         normalize: bool = True,
         separate: bool = False,
+        naflex: bool = False,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        max_seq_len: int = 576,  # 24x24 for 16x16 patch
+        patchify: bool = False,
 ):
     """ ImageNet-oriented image transforms for training.
 
@@ -111,6 +116,9 @@ def transforms_imagenet_train(
         use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
         normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
         separate: Output transforms in 3-stage tuple.
+        naflex: Enable NaFlex mode, sequence constrained patch output
+        patch_size: Patch size for NaFlex mode.
+        max_seq_len: Max sequence length for NaFlex mode.
 
     Returns:
         If separate==True, the transforms are returned as a tuple of 3 separate transforms
@@ -121,35 +129,45 @@ def transforms_imagenet_train(
     """
     train_crop_mode = train_crop_mode or 'rrc'
     assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
-    if train_crop_mode in ('rkrc', 'rkrr'):
-        # FIXME integration of RKR is a WIP
-        scale = tuple(scale or (0.8, 1.00))
-        ratio = tuple(ratio or (0.9, 1/.9))
-        primary_tfl = [
-            ResizeKeepRatio(
-                img_size,
-                interpolation=interpolation,
-                random_scale_prob=0.5,
-                random_scale_range=scale,
-                random_scale_area=True,  # scale compatible with RRC
-                random_aspect_prob=0.5,
-                random_aspect_range=ratio,
-            ),
-            CenterCropOrPad(img_size, padding_mode='reflect')
-            if train_crop_mode == 'rkrc' else
-            RandomCropOrPad(img_size, padding_mode='reflect')
-        ]
-    else:
-        scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
-        ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
-        primary_tfl = [
-            RandomResizedCropAndInterpolation(
-                img_size,
-                scale=scale,
-                ratio=ratio,
-                interpolation=interpolation,
-            )
-        ]
+    primary_tfl = []
+    if naflex:
+        primary_tfl += [RandomResizedCropToSequence(
+            patch_size=patch_size,
+            max_seq_len=max_seq_len,
+            interpolation=interpolation
+        )]
+    else:
+        if train_crop_mode in ('rkrc', 'rkrr'):
+            # FIXME integration of RKR is a WIP
+            scale = tuple(scale or (0.8, 1.00))
+            ratio = tuple(ratio or (0.9, 1/.9))
+            primary_tfl += [
+                ResizeKeepRatio(
+                    img_size,
+                    interpolation=interpolation,
+                    random_scale_prob=0.5,
+                    random_scale_range=scale,
+                    random_scale_area=True,  # scale compatible with RRC
+                    random_aspect_prob=0.5,
+                    random_aspect_range=ratio,
+                ),
+                CenterCropOrPad(img_size, padding_mode='reflect')
+                if train_crop_mode == 'rkrc' else
+                RandomCropOrPad(img_size, padding_mode='reflect')
+            ]
+        else:
+            scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
+            ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
+            primary_tfl += [
+                RandomResizedCropAndInterpolation(
+                    img_size,
+                    scale=scale,
+                    ratio=ratio,
+                    interpolation=interpolation,
+                )
+            ]
     if hflip > 0.:
         primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
     if vflip > 0.:
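With the branch above in place, a NaFlex training pipeline can be requested straight from the factory. A sketch using the new keywords (values assume the defaults from the signature change earlier in this diff):

    train_tf = transforms_imagenet_train(
        img_size=224,       # sizing is driven by max_seq_len in NaFlex mode, not img_size
        naflex=True,
        patch_size=16,
        max_seq_len=576,
        patchify=True,      # emit {'patches', 'patch_coord', 'patch_valid'} dicts
    )
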
@@ -215,7 +233,7 @@ def transforms_imagenet_train(
     final_tfl = []
     if use_prefetcher:
         # prefetcher and collate will handle tensor conversion and norm
-        final_tfl += [ToNumpy()]
+        final_tfl += [MaybePILToTensor()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
         final_tfl += [MaybePILToTensor()]
@@ -238,6 +256,9 @@ def transforms_imagenet_train(
             )
         ]
 
+    if patchify:
+        final_tfl += [Patchify(patch_size=patch_size)]
+
     if separate:
         return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
     else:
@@ -254,6 +275,10 @@ def transforms_imagenet_eval(
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
         use_prefetcher: bool = False,
         normalize: bool = True,
+        naflex: bool = False,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        max_seq_len: int = 576,  # 24x24 for 16x16 patch
+        patchify: bool = False,
 ):
     """ ImageNet-oriented image transform for evaluation and inference.
 
@@ -267,6 +292,10 @@ def transforms_imagenet_eval(
         std: Image normalization standard deviation.
         use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
         normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
+        naflex: Enable NaFlex mode, sequence constrained patch output
+        patch_size: Patch size for NaFlex mode.
+        max_seq_len: Max sequence length for NaFlex mode.
+        patchify: Patchify the output instead of relying on prefetcher
 
     Returns:
         Composed transform pipeline
@@ -285,37 +314,44 @@ def transforms_imagenet_eval(
     if crop_border_pixels:
         tfl += [TrimBorder(crop_border_pixels)]
 
-    if crop_mode == 'squash':
-        # squash mode scales each edge to 1/pct of target, then crops
-        # aspect ratio is not preserved, no img lost if crop_pct == 1.0
-        tfl += [
-            transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
-            transforms.CenterCrop(img_size),
-        ]
-    elif crop_mode == 'border':
-        # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
-        # no image lost if crop_pct == 1.0
-        fill = [round(255 * v) for v in mean]
-        tfl += [
-            ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
-            CenterCropOrPad(img_size, fill=fill),
-        ]
+    if naflex:
+        tfl += [ResizeToSequence(
+            patch_size=patch_size,
+            max_seq_len=max_seq_len,
+            interpolation=interpolation
+        )]
     else:
-        # default crop mode is center
-        # aspect ratio is preserved, crops center within image, no borders are added, image is lost
-        if scale_size[0] == scale_size[1]:
-            # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
-            tfl += [
-                transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
-            ]
-        else:
-            # resize the shortest edge to matching target dim for non-square target
-            tfl += [ResizeKeepRatio(scale_size)]
-        tfl += [transforms.CenterCrop(img_size)]
+        if crop_mode == 'squash':
+            # squash mode scales each edge to 1/pct of target, then crops
+            # aspect ratio is not preserved, no img lost if crop_pct == 1.0
+            tfl += [
+                transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
+                transforms.CenterCrop(img_size),
+            ]
+        elif crop_mode == 'border':
+            # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
+            # no image lost if crop_pct == 1.0
+            fill = [round(255 * v) for v in mean]
+            tfl += [
+                ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
+                CenterCropOrPad(img_size, fill=fill),
+            ]
+        else:
+            # default crop mode is center
+            # aspect ratio is preserved, crops center within image, no borders are added, image is lost
+            if scale_size[0] == scale_size[1]:
+                # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
+                tfl += [
+                    transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
+                ]
+            else:
+                # resize the shortest edge to matching target dim for non-square target
+                tfl += [ResizeKeepRatio(scale_size)]
+            tfl += [transforms.CenterCrop(img_size)]
 
     if use_prefetcher:
         # prefetcher and collate will handle tensor conversion and norm
-        tfl += [ToNumpy()]
+        tfl += [MaybePILToTensor()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
         tfl += [MaybePILToTensor()]
@@ -328,6 +364,9 @@ def transforms_imagenet_eval(
             ),
         ]
 
+    if patchify:
+        tfl += [Patchify(patch_size=patch_size)]
+
     return transforms.Compose(tfl)
 
 
@@ -359,6 +398,10 @@ def create_transform(
         use_prefetcher: bool = False,
         normalize: bool = True,
         separate: bool = False,
+        naflex: bool = False,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        max_seq_len: int = 576,  # 24x24 for 16x16 patch
+        patchify: bool = False,
 ):
     """
 
@@ -442,6 +485,10 @@ def create_transform(
             use_prefetcher=use_prefetcher,
             normalize=normalize,
             separate=separate,
+            naflex=naflex,
+            patch_size=patch_size,
+            max_seq_len=max_seq_len,
+            patchify=patchify,
         )
     else:
         assert not separate, "Separate transforms not supported for validation preprocessing"
@@ -455,6 +502,10 @@ def create_transform(
             crop_border_pixels=crop_border_pixels,
             use_prefetcher=use_prefetcher,
             normalize=normalize,
+            naflex=naflex,
+            patch_size=patch_size,
+            max_seq_len=max_seq_len,
+            patchify=patchify,
         )
 
     return transform
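End to end, the same keywords thread through create_transform; a sketch (values illustrative):

    from timm.data import create_transform

    eval_tf = create_transform(
        input_size=(3, 384, 384),
        is_training=False,
        naflex=True,
        patch_size=16,
        max_seq_len=576,
        patchify=True,
    )
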
@@ -1,6 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention import Attention
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding

66 timm/layers/attention.py Normal file
@@ -0,0 +1,66 @@
from typing import Final, Type, Optional

import torch
from torch import nn as nn
from torch.nn import functional as F

from .config import use_fused_attn


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
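The mask is additive: it is summed with the attention logits before softmax (or handed to scaled_dot_product_attention), so padding positions are excluded via large negative values rather than boolean indexing. A sketch of building such a mask from a NaFlex patch_valid vector (the names and sizes below are illustrative assumptions):

    import torch

    patch_valid = torch.tensor([True, True, True, False])   # last patch slot is padding
    # Broadcast over (batch, heads, query, key); -inf keys receive zero attention weight.
    attn_mask = torch.where(patch_valid, 0.0, float('-inf')).view(1, 1, 1, -1)

    attn = Attention(dim=64, num_heads=4)
    x = torch.randn(1, 4, 64)
    y = attn(x, attn_mask=attn_mask)   # the padding patch contributes nothing as a key
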
@@ -75,7 +75,7 @@ class AttentionPoolLatent(nn.Module):
         trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
         trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)
 
-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
 
         if self.pos_embed is not None:
@@ -91,10 +91,12 @@ class AttentionPoolLatent(nn.Module):
         q, k = self.q_norm(q), self.k_norm(k)
 
         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v)
+            x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
+            if attn_mask is not None:
+                attn = attn + attn_mask
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, self.latent_len, C)
@@ -72,6 +72,7 @@ from .vgg import *
 from .visformer import *
 from .vision_transformer import *
 from .vision_transformer_hybrid import *
+from .vision_transformer_flex import *
 from .vision_transformer_relpos import *
 from .vision_transformer_sam import *
 from .vitamin import *
@@ -41,8 +41,8 @@ from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
+from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \
+    SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
     get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -55,58 +55,6 @@ __all__ = ['VisionTransformer']  # model_registry will add each entrypoint fn to
 _logger = logging.getLogger(__name__)
 
 
-class Attention(nn.Module):
-    fused_attn: Final[bool]
-
-    def __init__(
-            self,
-            dim: int,
-            num_heads: int = 8,
-            qkv_bias: bool = False,
-            qk_norm: bool = False,
-            proj_bias: bool = True,
-            attn_drop: float = 0.,
-            proj_drop: float = 0.,
-            norm_layer: Type[nn.Module] = nn.LayerNorm,
-    ) -> None:
-        super().__init__()
-        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.scale = self.head_dim ** -0.5
-        self.fused_attn = use_fused_attn()
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim, bias=proj_bias)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)
-        q, k = self.q_norm(q), self.k_norm(k)
-
-        if self.fused_attn:
-            x = F.scaled_dot_product_attention(
-                q, k, v,
-                dropout_p=self.attn_drop.p if self.training else 0.,
-            )
-        else:
-            q = q * self.scale
-            attn = q @ k.transpose(-2, -1)
-            attn = attn.softmax(dim=-1)
-            attn = self.attn_drop(attn)
-            x = attn @ v
-
-        x = x.transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
 class LayerScale(nn.Module):
     def __init__(
         self,
@@ -165,8 +113,8 @@ class Block(nn.Module):
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
@@ -222,8 +170,8 @@ class ResPostBlock(nn.Module):
         nn.init.constant_(self.norm1.weight, self.init_values)
         nn.init.constant_(self.norm2.weight, self.init_values)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x + self.drop_path1(self.norm1(self.attn(x)))
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x = x + self.drop_path1(self.norm1(self.attn(x, attn_mask=attn_mask)))
         x = x + self.drop_path2(self.norm2(self.mlp(x)))
         return x
@@ -282,7 +230,7 @@ class ParallelScalingBlock(nn.Module):
         self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         B, N, C = x.shape
 
         # Combined MLP fc1 & qkv projections
@@ -302,14 +250,18 @@ class ParallelScalingBlock(nn.Module):
         if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
+                attn_mask=attn_mask,
                 dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
+            if attn_mask is not None:
+                attn = attn + attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x_attn = attn @ v
 
         x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
         x_attn = self.attn_out_proj(x_attn)
@@ -379,23 +331,11 @@ class ParallelThingsBlock(nn.Module):
             ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
         ])))
 
-    def _forward_jit(self, x: torch.Tensor) -> torch.Tensor:
-        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x = x + torch.stack([attn(x, attn_mask=attn_mask) for attn in self.attns]).sum(dim=0)
         x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
         return x
 
-    @torch.jit.ignore
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x + sum(attn(x) for attn in self.attns)
-        x = x + sum(ffn(x) for ffn in self.ffns)
-        return x
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            return self._forward_jit(x)
-        else:
-            return self._forward(x)
-
 
 def global_pool_nlc(
     x: torch.Tensor,
@@ -728,7 +668,9 @@ class VisionTransformer(nn.Module):
             stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
-    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+            output_dict: bool = False,
+            attn_mask: Optional[torch.Tensor] = None,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
         """ Forward features that returns intermediates.
 
         Args:
@@ -739,8 +681,11 @@ class VisionTransformer(nn.Module):
             stop_early: Stop iterating over blocks when last desired intermediate hit
             output_fmt: Shape of intermediate feature outputs
             intermediates_only: Only return intermediate features
+            output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
+            attn_mask: Optional attention mask for masked attention (e.g., for NaFlex)
         Returns:
+            A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
+            'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
         """
         assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
@@ -759,7 +704,7 @@ class VisionTransformer(nn.Module):
         else:
             blocks = self.blocks[:max_index + 1]
         for i, blk in enumerate(blocks):
-            x = blk(x)
+            x = blk(x, attn_mask=attn_mask)
             if i in take_indices:
                 # normalize intermediates with final norm layer if enabled
                 intermediates.append(self.norm(x) if norm else x)
@@ -776,6 +721,23 @@ class VisionTransformer(nn.Module):
             # reshape to BCHW output format
             H, W = self.patch_embed.dynamic_feat_size((height, width))
             intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        # For dictionary output, handle prefix tokens separately
+        if output_dict:
+            result_dict = {}
+            # Intermediates are always included
+            result_dict['image_intermediates'] = intermediates
+            if prefix_tokens is not None and return_prefix_tokens:
+                result_dict['image_intermediates_prefix'] = prefix_tokens
+
+            # Only include features if not intermediates_only
+            if not intermediates_only:
+                x_final = self.norm(x)
+                result_dict['image_features'] = x_final
+
+            return result_dict
+
+        # For non-dictionary output, maintain the original behavior
         if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
             # return_prefix not support in torchscript due to poor type handling
             intermediates = list(zip(intermediates, prefix_tokens))
@@ -811,6 +773,7 @@ class VisionTransformer(nn.Module):
             reshape: bool = False,
             return_prefix_tokens: bool = False,
             norm: bool = False,
+            attn_mask: Optional[torch.Tensor] = None,
     ) -> List[torch.Tensor]:
         """ Intermediate layer accessor inspired by DINO / DINOv2 interface.
         NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
@@ -821,17 +784,24 @@ class VisionTransformer(nn.Module):
             norm=norm,
             output_fmt='NCHW' if reshape else 'NLC',
             intermediates_only=True,
+            attn_mask=attn_mask,
         )
 
-    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         x = self.patch_embed(x)
         x = self._pos_embed(x)
         x = self.patch_drop(x)
         x = self.norm_pre(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
+
+        if attn_mask is not None:
+            # If a mask is provided, the blocks must be applied one by one
+            for blk in self.blocks:
+                x = blk(x, attn_mask=attn_mask)
+        elif self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.blocks, x)
         else:
            x = self.blocks(x)
+
         x = self.norm(x)
         return x
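Since nn.Sequential cannot forward extra arguments, a provided mask forces the explicit per-block loop above (and bypasses gradient checkpointing). A sketch of calling it (the all-zeros mask is a no-op, shown only to illustrate the expected (B, 1, N, N) broadcast shape; model name and shapes are illustrative):

    import timm, torch

    model = timm.create_model('vit_base_patch16_224')
    x = torch.randn(2, 3, 224, 224)
    mask = torch.zeros(2, 1, 197, 197)   # 196 patch tokens + 1 class token
    feats = model.forward_features(x, attn_mask=mask)   # (2, 197, 768)
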
@@ -849,8 +819,8 @@ class VisionTransformer(nn.Module):
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.forward_features(x)
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x = self.forward_features(x, attn_mask=attn_mask)
         x = self.forward_head(x)
         return x
1045 timm/models/vision_transformer_flex.py Normal file
File diff suppressed because it is too large

108 train.py
@@ -396,6 +396,15 @@ group.add_argument('--wandb-tags', default=[], type=str, nargs='+',
 group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID',
                    help='If resuming a run, the id of the run in wandb')
 
+# NaFlex scheduled loader arguments
+group.add_argument('--naflex-loader', action='store_true', default=False,
+                   help='Use NaFlex loader (Requires NaFlex compatible model)')
+group.add_argument('--naflex-train-seq-lens', type=int, nargs='+', default=[128, 256, 576, 784, 1024],
+                   help='Sequence lengths to use for NaFlex loader')
+group.add_argument('--naflex-max-seq-len', type=int, default=576,
+                   help='Fixed maximum sequence length for NaFlex loader (validation)')
+
 
 def _parse_args():
     # Do we have a config file to parse?
@@ -669,6 +678,7 @@ def main():
         trust_remote_code=args.dataset_trust_remote_code,
     )
 
+    dataset_eval = None
     if args.val_split:
         dataset_eval = create_dataset(
             args.dataset,
@@ -690,6 +700,7 @@ def main():
     mixup_fn = None
     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
     if mixup_active:
+        assert not args.naflex_loader, "Mixup/Cutmix not currently supported for NaFlex loading."
         mixup_args = dict(
             mixup_alpha=args.mixup,
             cutmix_alpha=args.cutmix,
@@ -714,9 +725,19 @@ def main():
     train_interpolation = args.train_interpolation
     if args.no_aug or not train_interpolation:
         train_interpolation = data_config['interpolation']
-    loader_train = create_loader(
-        dataset_train,
-        input_size=data_config['input_size'],
+
+    # Check if we should use the NaFlex scheduled loader
+    common_loader_kwargs = dict(
+        mean=data_config['mean'],
+        std=data_config['std'],
+        pin_memory=args.pin_mem,
+        img_dtype=model_dtype or torch.float32,
+        device=device,
+        distributed=args.distributed,
+        use_prefetcher=args.prefetcher,
+    )
+
+    train_loader_kwargs = dict(
         batch_size=args.batch_size,
         is_training=True,
         no_aug=args.no_aug,
@@ -737,42 +758,70 @@ def main():
         num_aug_repeats=args.aug_repeats,
         num_aug_splits=num_aug_splits,
         interpolation=train_interpolation,
-        mean=data_config['mean'],
-        std=data_config['std'],
         num_workers=args.workers,
-        distributed=args.distributed,
-        collate_fn=collate_fn,
-        pin_memory=args.pin_mem,
-        img_dtype=model_dtype or torch.float32,
-        device=device,
-        use_prefetcher=args.prefetcher,
-        use_multi_epochs_loader=args.use_multi_epochs_loader,
         worker_seeding=args.worker_seeding,
     )
+
+    if args.naflex_loader:
+        from timm.data.naflex_loader import create_naflex_loader
+        if utils.is_primary(args):
+            _logger.info('Using NaFlex loader')
+
+        loader_train = create_naflex_loader(
+            dataset=dataset_train,
+            patch_size=16,  # Could be derived from model config
+            train_seq_lens=args.naflex_train_seq_lens,
+            rank=args.rank,
+            world_size=args.world_size,
+            **common_loader_kwargs,
+            **train_loader_kwargs,
+        )
+    else:
+        # Use standard loader
+        loader_train = create_loader(
+            dataset_train,
+            input_size=data_config['input_size'],
+            collate_fn=collate_fn,
+            use_multi_epochs_loader=args.use_multi_epochs_loader,
+            **common_loader_kwargs,
+            **train_loader_kwargs,
+        )
 
     loader_eval = None
     if args.val_split:
+        assert dataset_eval is not None
         eval_workers = args.workers
         if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
             # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
             eval_workers = min(2, args.workers)
-        loader_eval = create_loader(
-            dataset_eval,
-            input_size=data_config['input_size'],
+        eval_loader_kwargs = dict(
             batch_size=args.validation_batch_size or args.batch_size,
             is_training=False,
             interpolation=data_config['interpolation'],
-            mean=data_config['mean'],
-            std=data_config['std'],
            num_workers=eval_workers,
-            distributed=args.distributed,
             crop_pct=data_config['crop_pct'],
-            pin_memory=args.pin_mem,
-            img_dtype=model_dtype or torch.float32,
-            device=device,
-            use_prefetcher=args.prefetcher,
         )
+
+        if args.naflex_loader:
+            from timm.data.naflex_loader import create_naflex_loader
+            # Use largest sequence length for validation
+            loader_eval = create_naflex_loader(
+                dataset=dataset_eval,
+                patch_size=16,  # Could be derived from model config
+                max_seq_len=args.naflex_max_seq_len,
+                **common_loader_kwargs,
+                **eval_loader_kwargs,
+            )
+        else:
+            # Use standard loader
+            loader_eval = create_loader(
+                dataset_eval,
+                input_size=data_config['input_size'],
+                **common_loader_kwargs,
+                **eval_loader_kwargs,
+            )
 
     # setup loss function
     if args.jsd_loss:
         assert num_aug_splits > 1  # JSD only valid with aug splits set
@@ -1083,8 +1132,12 @@ def train_one_epoch(
                 loss = _forward()
                 _backward(loss)
 
-        losses_m.update(loss.item() * accum_steps, input.size(0))
-        update_sample_count += input.size(0)
+        if isinstance(input, dict):
+            batch_size = input['patches'].shape[0]
+        else:
+            batch_size = input.shape[0]
+        losses_m.update(loss.item() * accum_steps, batch_size)
+        update_sample_count += batch_size
 
         if not need_update:
             data_start_time = time.time()
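The branch above encodes the loader contract: NaFlex batches arrive as dicts keyed by 'patches' with shape (B, N, P*P*C), while standard batches remain (B, C, H, W) tensors. The size logic in isolation, as a sketch:

    def _batch_size(input):
        # dict batches come from the NaFlex loader, tensor batches from the standard loader
        if isinstance(input, dict):
            return input['patches'].shape[0]
        return input.shape[0]
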
@@ -1210,9 +1263,10 @@ def validate(
             elif device.type == "npu":
                 torch.npu.synchronize()
 
-            losses_m.update(reduced_loss.item(), input.size(0))
-            top1_m.update(acc1.item(), output.size(0))
-            top5_m.update(acc5.item(), output.size(0))
+            batch_size = output.shape[0]
+            losses_m.update(reduced_loss.item(), batch_size)
+            top1_m.update(acc1.item(), batch_size)
+            top5_m.update(acc5.item(), batch_size)
 
             batch_time_m.update(time.time() - end)
             end = time.time()