mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Initial NaFlex ViT model and training support
This commit is contained in:
parent
e44f14d7d2
commit
0893f5d296
@ -8,6 +8,13 @@ from .dataset_info import DatasetInfo, CustomDatasetInfo
|
||||
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
|
||||
from .loader import create_loader
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
from .naflex_transforms import (
|
||||
ResizeToSequence,
|
||||
CenterCropToSequence,
|
||||
RandomCropToSequence,
|
||||
RandomResizedCropToSequence,
|
||||
ResizeKeepRatioToSequence,
|
||||
)
|
||||
from .readers import create_reader
|
||||
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
|
||||
from .real_labels import RealLabelsImagenet
|
||||
|
422
timm/data/naflex_dataset.py
Normal file
422
timm/data/naflex_dataset.py
Normal file
@ -0,0 +1,422 @@
|
||||
"""
|
||||
Dynamic Sequence Length Datasets for Variable Resolution Image Processing
|
||||
|
||||
Implements two dataset wrappers:
|
||||
1. DynamicSeqMapDataset - Map-style dataset that returns batches with variable sequence lengths
|
||||
2. DynamicSeqIterDataset - Iterable dataset that yields batches with variable sequence lengths
|
||||
|
||||
Both support:
|
||||
- Pre-initialized transforms for efficiency
|
||||
- Distributed training
|
||||
- Multiple workers
|
||||
- Variable batch sizes based on sequence length
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, IterableDataset, DataLoader
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
|
||||
|
||||
from .naflex_transforms import Patchify, patchify
|
||||
|
||||
|
||||
def calculate_batch_size(
|
||||
tokens_per_batch: int,
|
||||
seq_len: int,
|
||||
max_size: Optional[int] = None,
|
||||
divisor: int = 1,
|
||||
rounding: str ='floor',
|
||||
):
|
||||
"""Calculate batch size based on sequence length with divisibility constraints."""
|
||||
# Calculate raw batch size based on sequence length
|
||||
raw_batch_size = tokens_per_batch / seq_len
|
||||
|
||||
# Apply divisibility with specified rounding method
|
||||
if divisor > 1:
|
||||
if rounding == 'floor':
|
||||
batch_size = math.floor(raw_batch_size / divisor) * divisor
|
||||
elif rounding == 'ceil':
|
||||
batch_size = math.ceil(raw_batch_size / divisor) * divisor
|
||||
else: # 'round' is the default
|
||||
batch_size = round(raw_batch_size / divisor) * divisor
|
||||
else:
|
||||
# If no divisor specified, just use integer division
|
||||
batch_size = int(raw_batch_size)
|
||||
|
||||
# Ensure batch size is valid
|
||||
batch_size = max(1, batch_size) # At least 1
|
||||
|
||||
if max_size is not None:
|
||||
batch_size = min(batch_size, max_size)
|
||||
|
||||
return batch_size
|
||||
|
||||
|
||||
def _collate_batch(
|
||||
batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
|
||||
target_seq_len: int
|
||||
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Collates processed samples into a batch, padding/truncating to target_seq_len."""
|
||||
batch_patch_data = [item[0] for item in batch_samples]
|
||||
batch_labels = [item[1] for item in batch_samples]
|
||||
|
||||
if not batch_patch_data:
|
||||
return {}, torch.empty(0)
|
||||
|
||||
batch_size = len(batch_patch_data)
|
||||
patch_dim = batch_patch_data[0]['patches'].shape[1]
|
||||
|
||||
# Initialize tensors with target sequence length
|
||||
patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
|
||||
patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
|
||||
patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool) # Use bool
|
||||
|
||||
for i, data in enumerate(batch_patch_data):
|
||||
num_patches = data['patches'].shape[0]
|
||||
# Take min(num_patches, target_seq_len) patches
|
||||
n_copy = min(num_patches, target_seq_len)
|
||||
|
||||
patches_batch[i, :n_copy] = data['patches'][:n_copy]
|
||||
patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
|
||||
patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy] # Copy validity flags
|
||||
|
||||
# Create the final input dict
|
||||
input_dict = {
|
||||
'patches': patches_batch,
|
||||
'patch_coord': patch_coord_batch,
|
||||
'patch_valid': patch_valid_batch, # Boolean mask
|
||||
# Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
|
||||
# The actual number of valid patches per sample varies.
|
||||
# 'patch_valid' mask is the most reliable source of truth.
|
||||
}
|
||||
|
||||
# Attempt to stack labels if they are tensors, otherwise return list
|
||||
try:
|
||||
if isinstance(batch_labels[0], torch.Tensor):
|
||||
labels_tensor = torch.stack(batch_labels)
|
||||
else:
|
||||
# Convert numerical types to tensor, keep others as list (or handle specific types)
|
||||
if isinstance(batch_labels[0], (int, float)):
|
||||
labels_tensor = torch.tensor(batch_labels)
|
||||
else:
|
||||
# Cannot convert non-numerical labels easily, return as list
|
||||
# Or handle specific conversion if needed
|
||||
# For FakeDataset, labels are ints, so this works
|
||||
labels_tensor = torch.tensor(batch_labels) # Assuming labels are numerical
|
||||
except Exception:
|
||||
# Fallback if stacking fails (e.g., different shapes, types)
|
||||
print("Warning: Could not stack labels into a tensor. Returning list of labels.")
|
||||
labels_tensor = batch_labels # Return as list
|
||||
|
||||
return input_dict, labels_tensor
|
||||
|
||||
|
||||
class VariableSeqMapWrapper(IterableDataset):
|
||||
"""
|
||||
IterableDataset wrapper for a map-style base dataset.
|
||||
|
||||
Yields batches with variable sequence lengths. It calculates a canonical
|
||||
batch schedule (sequence length, batch size pairs) once based on the
|
||||
total dataset size (padded for distribution). Each epoch, it shuffles
|
||||
the *order* of this canonical schedule and the dataset indices.
|
||||
This ensures a consistent number of batches and samples per epoch
|
||||
across all ranks. Handles distributed training and multiple workers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dataset: Dataset,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
seq_lens: List[int] = (128, 256, 576, 784, 1024),
|
||||
max_tokens_per_batch: int = 4096 * 4, # Example: 16k tokens
|
||||
transform_factory: Optional[Callable] = None,
|
||||
seed: int = 42,
|
||||
shuffle: bool = True,
|
||||
distributed: bool = False,
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
epoch: int = 0,
|
||||
batch_divisor: int = 8, # Ensure batch size is divisible by this
|
||||
):
|
||||
super().__init__()
|
||||
if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'):
|
||||
raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)")
|
||||
|
||||
self.base_dataset = base_dataset
|
||||
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
|
||||
self.seq_lens = sorted(list(set(seq_lens))) # Ensure unique and sorted
|
||||
self.max_tokens_per_batch = max_tokens_per_batch
|
||||
self.seed = seed
|
||||
self.shuffle = shuffle
|
||||
self.distributed = distributed
|
||||
self.rank = rank if distributed else 0
|
||||
self.world_size = world_size if distributed else 1
|
||||
self.epoch = epoch
|
||||
self.batch_divisor = batch_divisor
|
||||
|
||||
# Pre-initialize transforms for each sequence length
|
||||
self.transforms: Dict[int, Optional[Callable]] = {}
|
||||
if transform_factory:
|
||||
for seq_len in self.seq_lens:
|
||||
self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
|
||||
else:
|
||||
for seq_len in self.seq_lens:
|
||||
self.transforms[seq_len] = None # No transform
|
||||
|
||||
self.patchifier = Patchify(self.patch_size)
|
||||
|
||||
# --- Canonical Schedule Calculation (Done Once) ---
|
||||
self._canonical_batch_schedule: List[Tuple[int, int]] = []
|
||||
self._num_batches_per_rank: int = 0
|
||||
self._padded_samples_per_rank: int = 0
|
||||
self._create_canonical_schedule() # Calculate schedule based on padded size
|
||||
|
||||
# --- Per-Epoch State ---
|
||||
# Stores (seq_len, list_of_indices) for the current epoch, specific to this rank
|
||||
self._epoch_batches: List[Tuple[int, List[int]]] = []
|
||||
self._prepare_epoch_batches(self.epoch) # setup for initial epoch
|
||||
|
||||
def _create_canonical_schedule(self):
|
||||
"""
|
||||
Calculates the canonical batch schedule (seq_len, batch_size pairs)
|
||||
based on the dataset size, padded for distributed training.
|
||||
This schedule is the *same* for all ranks and ensures consistent
|
||||
epoch length. It is calculated once during initialization.
|
||||
"""
|
||||
total_len = len(self.base_dataset)
|
||||
padded_total_len = total_len
|
||||
num_samples_per_rank = total_len
|
||||
|
||||
if self.distributed and self.world_size > 1:
|
||||
# Calculate padding needed for even distribution
|
||||
if total_len % self.world_size != 0:
|
||||
pad_size = self.world_size - (total_len % self.world_size)
|
||||
padded_total_len += pad_size
|
||||
print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).")
|
||||
else:
|
||||
pad_size = 0
|
||||
|
||||
if padded_total_len % self.world_size != 0:
|
||||
# This should not happen with the padding logic, but safeguard
|
||||
raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}")
|
||||
|
||||
num_samples_per_rank = padded_total_len // self.world_size
|
||||
elif self.distributed and self.world_size <= 1:
|
||||
# Distributed flag set but world_size is 1, treat as non-distributed
|
||||
pass # num_samples_per_rank remains total_len
|
||||
|
||||
self._padded_samples_per_rank = num_samples_per_rank
|
||||
|
||||
if num_samples_per_rank == 0:
|
||||
self._canonical_batch_schedule = []
|
||||
self._num_batches_per_rank = 0
|
||||
return
|
||||
|
||||
# Use a fixed seed for generating the canonical schedule structure
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed) # Use base seed, NOT epoch seed
|
||||
|
||||
current_schedule: List[Tuple[int, int]] = []
|
||||
remaining_samples = num_samples_per_rank
|
||||
total_scheduled_samples = 0
|
||||
|
||||
while remaining_samples > 0:
|
||||
# Sample sequence length deterministically based on base seed
|
||||
seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item()
|
||||
seq_len = self.seq_lens[seq_idx]
|
||||
|
||||
# Calculate batch size
|
||||
batch_size = calculate_batch_size(
|
||||
tokens_per_batch=self.max_tokens_per_batch,
|
||||
seq_len=seq_len,
|
||||
# max_size should be remaining_samples to avoid overshooting
|
||||
max_size=remaining_samples,
|
||||
divisor=self.batch_divisor,
|
||||
rounding='floor',
|
||||
)
|
||||
# Ensure batch size is positive and doesn't exceed remaining samples
|
||||
batch_size = max(1, batch_size)
|
||||
batch_size = min(batch_size, remaining_samples)
|
||||
|
||||
if batch_size <= 0:
|
||||
warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.")
|
||||
break # Avoid infinite loop if something goes wrong
|
||||
|
||||
current_schedule.append((seq_len, batch_size))
|
||||
remaining_samples -= batch_size
|
||||
total_scheduled_samples += batch_size
|
||||
|
||||
# Sanity check: Ensure the schedule covers all samples for the rank
|
||||
if total_scheduled_samples != num_samples_per_rank:
|
||||
warnings.warn(
|
||||
f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, "
|
||||
f"but expected {num_samples_per_rank} samples per rank. "
|
||||
f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. "
|
||||
f"Check parameters. Remaining samples: {remaining_samples}"
|
||||
)
|
||||
# Adjust if needed? Could add a final small batch, but might violate constraints.
|
||||
# Current behavior: some samples might be dropped if schedule logic fails.
|
||||
|
||||
self._canonical_batch_schedule = current_schedule
|
||||
self._num_batches_per_rank = len(current_schedule)
|
||||
print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.")
|
||||
|
||||
|
||||
def _prepare_epoch_batches(self, epoch: int):
|
||||
"""
|
||||
Prepares the batches for the current epoch by:
|
||||
1. Shuffling the full dataset indices (using epoch seed).
|
||||
2. Applying padding if in distributed mode.
|
||||
3. Selecting indices for the current rank.
|
||||
4. Shuffling the *order* of the canonical batch schedule (using epoch seed).
|
||||
5. Assigning the rank's indices to the shuffled batches.
|
||||
"""
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + epoch) # Epoch-specific seed for shuffling
|
||||
|
||||
# 1. Get shuffled global indices
|
||||
total_len = len(self.base_dataset)
|
||||
if self.shuffle:
|
||||
all_indices_shuffled = torch.randperm(total_len, generator=g).tolist()
|
||||
else:
|
||||
all_indices_shuffled = list(range(total_len))
|
||||
|
||||
# 2. Apply padding for distributed mode
|
||||
indices_for_ranks = all_indices_shuffled
|
||||
if self.distributed and self.world_size > 1:
|
||||
padded_total_len = self._padded_samples_per_rank * self.world_size
|
||||
if padded_total_len > total_len:
|
||||
pad_size = padded_total_len - total_len
|
||||
# Repeat initial elements from the *shuffled* list for padding
|
||||
indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size]
|
||||
# Ensure length matches expectation
|
||||
if len(indices_for_ranks) != padded_total_len:
|
||||
raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}")
|
||||
|
||||
|
||||
# 3. Select indices for the current rank
|
||||
if self.distributed and self.world_size > 1:
|
||||
indices_this_rank = indices_for_ranks[self.rank::self.world_size]
|
||||
else: # Non-distributed or world_size=1
|
||||
indices_this_rank = indices_for_ranks
|
||||
|
||||
# Sanity check length
|
||||
if len(indices_this_rank) != self._padded_samples_per_rank:
|
||||
# This might happen if canonical schedule generation had warnings/issues
|
||||
warnings.warn(
|
||||
f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) "
|
||||
f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). "
|
||||
f"Epoch generation might be inconsistent."
|
||||
)
|
||||
# Adjust expected samples? Or truncate/pad indices? Let's proceed but warn.
|
||||
# Using min() prevents IndexError later if indices are fewer than expected.
|
||||
effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank)
|
||||
indices_this_rank = indices_this_rank[:effective_samples_this_rank]
|
||||
|
||||
else:
|
||||
effective_samples_this_rank = self._padded_samples_per_rank
|
||||
|
||||
# 4. Shuffle the order of the canonical batch schedule for this epoch
|
||||
if self.shuffle:
|
||||
schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist()
|
||||
shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm]
|
||||
else:
|
||||
shuffled_schedule = list(self._canonical_batch_schedule) # Keep original order
|
||||
|
||||
# 5. Assign indices to the shuffled batches
|
||||
self._epoch_batches = []
|
||||
idx_pos = 0
|
||||
scheduled_samples_count = 0
|
||||
for seq_len, bs in shuffled_schedule:
|
||||
# Ensure we don't try to grab more indices than available for the rank
|
||||
actual_bs = min(bs, effective_samples_this_rank - idx_pos)
|
||||
if actual_bs <= 0:
|
||||
if scheduled_samples_count < effective_samples_this_rank:
|
||||
# This indicates mismatch between schedule total and actual samples
|
||||
warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. Check schedule generation.")
|
||||
break # Stop if no more indices or batch size is zero
|
||||
|
||||
batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs]
|
||||
self._epoch_batches.append((seq_len, batch_indices))
|
||||
idx_pos += actual_bs
|
||||
scheduled_samples_count += actual_bs
|
||||
|
||||
# Final check
|
||||
if scheduled_samples_count != effective_samples_this_rank:
|
||||
warnings.warn(
|
||||
f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, "
|
||||
f"but expected {effective_samples_this_rank} effective samples this epoch. "
|
||||
f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
|
||||
)
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
"""Updates the epoch, regenerating the epoch-specific batches."""
|
||||
# Only regenerate if the epoch actually changes
|
||||
if epoch != self.epoch:
|
||||
self.epoch = epoch
|
||||
self._prepare_epoch_batches(epoch)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Returns the number of batches per **worker** for the current epoch.
|
||||
Calculated based on the fixed number of batches per rank divided by
|
||||
the number of workers.
|
||||
"""
|
||||
return self._num_batches_per_rank
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
|
||||
"""
|
||||
Iterates through the pre-calculated batches for the current epoch,
|
||||
distributing them among workers.
|
||||
"""
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
num_workers = worker_info.num_workers if worker_info else 1
|
||||
worker_id = worker_info.id if worker_info else 0
|
||||
|
||||
# Distribute pre-calculated batches among workers for this rank
|
||||
# Each worker processes a slice of the batches prepared in _prepare_epoch_batches
|
||||
batches_for_worker = self._epoch_batches[worker_id::num_workers]
|
||||
for seq_len, indices in batches_for_worker:
|
||||
if not indices: # Skip if a batch ended up with no indices (shouldn't happen often)
|
||||
continue
|
||||
|
||||
# Get the pre-initialized transform for this sequence length
|
||||
transform = self.transforms.get(seq_len)
|
||||
|
||||
batch_samples = []
|
||||
for idx in indices:
|
||||
try:
|
||||
# Get original image and label from map-style dataset
|
||||
img, label = self.base_dataset[idx]
|
||||
|
||||
# Apply transform if available
|
||||
# Handle cases where transform might return None or fail
|
||||
processed_img = transform(img) if transform else img
|
||||
if processed_img is None:
|
||||
warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
|
||||
continue
|
||||
|
||||
# Apply patching
|
||||
patch_data = self.patchifier(processed_img)
|
||||
batch_samples.append((patch_data, label))
|
||||
|
||||
except IndexError:
|
||||
warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
|
||||
continue
|
||||
except Exception as e:
|
||||
# Log other potential errors during data loading/processing
|
||||
warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
|
||||
continue # Skip problematic sample
|
||||
|
||||
# Collate the processed samples into a batch
|
||||
if batch_samples: # Only yield if we successfully processed samples
|
||||
yield _collate_batch(batch_samples, seq_len)
|
||||
|
||||
# If batch_samples is empty after processing 'indices', an empty batch is skipped.
|
341
timm/data/naflex_loader.py
Normal file
341
timm/data/naflex_loader.py
Normal file
@ -0,0 +1,341 @@
|
||||
import math
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .loader import _worker_init
|
||||
from .naflex_dataset import VariableSeqMapWrapper
|
||||
from .transforms_factory import create_transform
|
||||
|
||||
|
||||
class NaFlexCollator:
|
||||
"""Custom collator for batching NaFlex-style variable-resolution images."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=16,
|
||||
max_seq_len=None,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len or 576 # Default ViT-B/16 sequence length (577 = 24*24)
|
||||
|
||||
def __call__(self, batch):
|
||||
"""
|
||||
Args:
|
||||
batch: List of tuples (patch_dict, target)
|
||||
|
||||
Returns:
|
||||
A tuple of (input_dict, targets) where input_dict contains:
|
||||
- patches: Padded tensor of patches
|
||||
- patch_coord: Coordinates for each patch (y, x)
|
||||
- patch_valid: Valid indicators
|
||||
"""
|
||||
assert isinstance(batch[0], tuple)
|
||||
batch_size = len(batch)
|
||||
|
||||
# FIXME
|
||||
# get seq len from sampler schedule
|
||||
|
||||
# resize to final size based on seq_len and patchify
|
||||
|
||||
# Extract targets
|
||||
targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
|
||||
|
||||
# Get patch dictionaries
|
||||
patch_dicts = [item[0] for item in batch]
|
||||
|
||||
# If we have a maximum sequence length constraint, ensure we don't exceed it
|
||||
if self.max_seq_len is not None:
|
||||
max_patches = self.max_seq_len
|
||||
else:
|
||||
# Find the maximum number of patches in this batch
|
||||
max_patches = max(item['patches'].shape[0] for item in patch_dicts)
|
||||
|
||||
# Get patch dimensionality
|
||||
patch_dim = patch_dicts[0]['patches'].shape[1]
|
||||
|
||||
# Prepare tensors for the batch
|
||||
patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
|
||||
patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64) # [B, N, 2] for (y, x)
|
||||
patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
|
||||
|
||||
# Fill in the tensors
|
||||
for i, patch_dict in enumerate(patch_dicts):
|
||||
num_patches = min(patch_dict['patches'].shape[0], max_patches)
|
||||
|
||||
patches[i, :num_patches] = patch_dict['patches'][:num_patches]
|
||||
patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
|
||||
patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
|
||||
|
||||
return {
|
||||
'patches': patches,
|
||||
'patch_coord': patch_coord,
|
||||
'patch_valid': patch_valid,
|
||||
'seq_len': max_patches,
|
||||
}, targets
|
||||
|
||||
|
||||
class NaFlexPrefetchLoader:
|
||||
"""Data prefetcher for NaFlex format which normalizes patches."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loader,
|
||||
mean=(0.485, 0.456, 0.406),
|
||||
std=(0.229, 0.224, 0.225),
|
||||
img_dtype=torch.float32,
|
||||
device=torch.device('cuda')
|
||||
):
|
||||
self.loader = loader
|
||||
self.device = device
|
||||
self.img_dtype = img_dtype or torch.float32
|
||||
|
||||
# Create mean/std tensors for normalization (will be applied to patches)
|
||||
self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
|
||||
self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)
|
||||
|
||||
# Check for CUDA/NPU availability
|
||||
self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
|
||||
self.is_npu = device.type == 'npu' and torch.npu.is_available()
|
||||
|
||||
def __iter__(self):
|
||||
first = True
|
||||
if self.is_cuda:
|
||||
stream = torch.cuda.Stream()
|
||||
stream_context = partial(torch.cuda.stream, stream=stream)
|
||||
elif self.is_npu:
|
||||
stream = torch.npu.Stream()
|
||||
stream_context = partial(torch.npu.stream, stream=stream)
|
||||
else:
|
||||
stream = None
|
||||
stream_context = suppress
|
||||
|
||||
for next_input_dict, next_target in self.loader:
|
||||
with stream_context():
|
||||
# Move all tensors in input_dict to device
|
||||
for k, v in next_input_dict.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
dtype = self.img_dtype if k == 'patches' else None
|
||||
next_input_dict[k] = next_input_dict[k].to(
|
||||
device=self.device,
|
||||
non_blocking=True,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
next_target = next_target.to(device=self.device, non_blocking=True)
|
||||
|
||||
# Normalize patch values (assuming patches are in format [B, N, P*P*C])
|
||||
batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
|
||||
patches = next_input_dict['patches'].view(batch_size, -1, 3) # to [B*N, P*P, C] for normalization
|
||||
patches = patches.sub(self.mean).div(self.std)
|
||||
|
||||
# Reshape back
|
||||
next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)
|
||||
|
||||
if not first:
|
||||
yield input_dict, target
|
||||
else:
|
||||
first = False
|
||||
|
||||
if stream is not None:
|
||||
if self.is_cuda:
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
elif self.is_npu:
|
||||
torch.npu.current_stream().wait_stream(stream)
|
||||
|
||||
input_dict = next_input_dict
|
||||
target = next_target
|
||||
|
||||
yield input_dict, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.loader)
|
||||
|
||||
@property
|
||||
def sampler(self):
|
||||
return self.loader.sampler
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return self.loader.dataset
|
||||
|
||||
|
||||
def create_naflex_loader(
|
||||
dataset,
|
||||
patch_size: Union[Tuple[int, int], int] = 16,
|
||||
train_seq_lens: List[int] = (128, 256, 576, 784, 1024), # Training sequence lengths
|
||||
max_seq_len: int = 576, # Fixed sequence length for validation
|
||||
batch_size: int = 32, # Used for max_seq_len and max(train_seq_lens)
|
||||
is_training: bool = False,
|
||||
|
||||
no_aug: bool = False,
|
||||
re_prob: float = 0.,
|
||||
re_mode: str = 'const',
|
||||
re_count: int = 1,
|
||||
re_split: bool = False,
|
||||
train_crop_mode: Optional[str] = None,
|
||||
scale: Optional[Tuple[float, float]] = None,
|
||||
ratio: Optional[Tuple[float, float]] = None,
|
||||
hflip: float = 0.5,
|
||||
vflip: float = 0.,
|
||||
color_jitter: float = 0.4,
|
||||
color_jitter_prob: Optional[float] = None,
|
||||
grayscale_prob: float = 0.,
|
||||
gaussian_blur_prob: float = 0.,
|
||||
auto_augment: Optional[str] = None,
|
||||
num_aug_repeats: int = 0,
|
||||
num_aug_splits: int = 0,
|
||||
interpolation: str = 'bilinear',
|
||||
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
|
||||
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
|
||||
crop_pct: Optional[float] = None,
|
||||
crop_mode: Optional[str] = None,
|
||||
crop_border_pixels: Optional[int] = None,
|
||||
|
||||
num_workers: int = 4,
|
||||
distributed: bool = False,
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
seed: int = 42,
|
||||
epoch: int = 0,
|
||||
use_prefetcher: bool = True,
|
||||
pin_memory: bool = True,
|
||||
img_dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = torch.device('cuda'),
|
||||
persistent_workers: bool = True,
|
||||
worker_seeding: str = 'all',
|
||||
):
|
||||
"""Create a data loader with dynamic sequence length sampling for training."""
|
||||
|
||||
if is_training:
|
||||
# For training, use the dynamic sequence length mechanism
|
||||
assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader'
|
||||
|
||||
transform_factory = partial(
|
||||
create_transform,
|
||||
is_training=True,
|
||||
no_aug=no_aug,
|
||||
train_crop_mode=train_crop_mode,
|
||||
scale=scale,
|
||||
ratio=ratio,
|
||||
hflip=hflip,
|
||||
vflip=vflip,
|
||||
color_jitter=color_jitter,
|
||||
color_jitter_prob=color_jitter_prob,
|
||||
grayscale_prob=grayscale_prob,
|
||||
gaussian_blur_prob=gaussian_blur_prob,
|
||||
auto_augment=auto_augment,
|
||||
interpolation=interpolation,
|
||||
mean=mean,
|
||||
std=std,
|
||||
crop_pct=crop_pct,
|
||||
crop_mode=crop_mode,
|
||||
crop_border_pixels=crop_border_pixels,
|
||||
re_prob=re_prob,
|
||||
re_mode=re_mode,
|
||||
re_count=re_count,
|
||||
use_prefetcher=use_prefetcher,
|
||||
naflex=True,
|
||||
)
|
||||
|
||||
max_train_seq_len = max(train_seq_lens)
|
||||
max_tokens_per_batch = batch_size * max_train_seq_len
|
||||
|
||||
if isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
assert False, "IterableDataset Wrapper is a WIP"
|
||||
|
||||
naflex_dataset = VariableSeqMapWrapper(
|
||||
dataset,
|
||||
transform_factory=transform_factory,
|
||||
patch_size=patch_size,
|
||||
seq_lens=train_seq_lens,
|
||||
max_tokens_per_batch=max_tokens_per_batch,
|
||||
seed=seed,
|
||||
distributed=distributed,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
shuffle=True,
|
||||
epoch=epoch,
|
||||
)
|
||||
|
||||
# NOTE: Collation is handled by the dataset wrapper for training
|
||||
# Create the collator (handles fixed-size collation)
|
||||
# collate_fn = NaFlexCollator(
|
||||
# patch_size=patch_size,
|
||||
# max_seq_len=max(seq_lens) + 1, # +1 for class token
|
||||
# use_prefetcher=use_prefetcher
|
||||
# )
|
||||
|
||||
loader = torch.utils.data.DataLoader(
|
||||
naflex_dataset,
|
||||
batch_size=None,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
sampler=None,
|
||||
#collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
|
||||
persistent_workers=persistent_workers
|
||||
)
|
||||
|
||||
if use_prefetcher:
|
||||
loader = NaFlexPrefetchLoader(
|
||||
loader,
|
||||
mean=mean,
|
||||
std=std,
|
||||
img_dtype=img_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
else:
|
||||
# For validation, use fixed sequence length (unchanged)
|
||||
dataset.transform = create_transform(
|
||||
is_training=False,
|
||||
interpolation=interpolation,
|
||||
mean=mean,
|
||||
std=std,
|
||||
# FIXME add crop args when sequence transforms support crop modes
|
||||
use_prefetcher=use_prefetcher,
|
||||
naflex=True,
|
||||
patch_size=patch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
patchify=True,
|
||||
)
|
||||
|
||||
# Create the collator
|
||||
collate_fn = NaFlexCollator(
|
||||
patch_size=patch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
# Handle distributed training
|
||||
sampler = None
|
||||
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
# For validation, use OrderedDistributedSampler
|
||||
from timm.data.distributed_sampler import OrderedDistributedSampler
|
||||
sampler = OrderedDistributedSampler(dataset)
|
||||
|
||||
loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
if use_prefetcher:
|
||||
loader = NaFlexPrefetchLoader(
|
||||
loader,
|
||||
mean=mean,
|
||||
std=std,
|
||||
img_dtype=img_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return loader
|
804
timm/data/naflex_transforms.py
Normal file
804
timm/data/naflex_transforms.py
Normal file
@ -0,0 +1,804 @@
|
||||
""" NaFlex (NaViT + FlexiViT) Transforms and Collation
|
||||
|
||||
Implements PyTorch versions of the transforms described in the NaViT and FlexiViT papers:
|
||||
- NaViT: https://arxiv.org/abs/2307.14995
|
||||
- FlexiViT: https://arxiv.org/abs/2212.08013
|
||||
|
||||
Enables variable resolution/aspect ratio image handling with efficient patching.
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
import warnings
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional as F
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from .transforms import str_to_interp_mode, crop_or_pad, center_crop_or_pad
|
||||
|
||||
|
||||
def get_image_size_for_seq(
|
||||
image_hw,
|
||||
patch_size=16,
|
||||
max_seq_len=1024,
|
||||
divisible_by_patch=True,
|
||||
max_ratio=None,
|
||||
eps = 1e-5,
|
||||
):
|
||||
"""
|
||||
Determine scaling ratio and image size so that when `image_hw` is scaled
|
||||
by 'ratio', the total number of resulting patches does not exceed
|
||||
'max_seq_len'.
|
||||
|
||||
- Patch size can be an integer (square patch) or a tuple (patch_h, patch_w).
|
||||
- Optionally cap the ratio at `max_ratio` to prevent upsampling beyond
|
||||
a certain multiple of the original size.
|
||||
|
||||
Args:
|
||||
image_hw (tuple or list of int): (height, width) of the original image.
|
||||
patch_size (int or tuple[int, int]): If int, patch is square. If tuple,
|
||||
patch is rectangular (patch_h, patch_w).
|
||||
max_seq_len (int): Maximum allowed sequence length for the resulting image.
|
||||
divisible_by_patch (bool): If True, the resulting image height and width
|
||||
must be multiples of patch_size.
|
||||
eps (float): Small number for binary search convergence.
|
||||
max_ratio (float or None): If provided, the scaling ratio found by the
|
||||
binary search will be clamped to min(found_ratio, max_ratio). Set
|
||||
max_ratio=1.0 to ensure no upsampling beyond original size.
|
||||
|
||||
Returns:
|
||||
ratio (float): Found scaling ratio (capped by `max_ratio` if provided).
|
||||
target_hw (tuple of int): Target (height, width) after scaling.
|
||||
"""
|
||||
|
||||
# Handle patch size input, extract patch_h, patch_w
|
||||
if isinstance(patch_size, int):
|
||||
patch_h, patch_w = patch_size, patch_size
|
||||
else:
|
||||
# Assume it's a tuple/list: (patch_h, patch_w)
|
||||
if len(patch_size) != 2:
|
||||
raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
|
||||
patch_h, patch_w = patch_size
|
||||
|
||||
# Safety checks
|
||||
if patch_h <= 0 or patch_w <= 0:
|
||||
raise ValueError("patch_size dimensions must be positive.")
|
||||
|
||||
def prepare_target_hw(ratio):
|
||||
"""Scale image_hw by ratio and optionally round dimensions to multiples of patch_h, patch_w."""
|
||||
scaled_h = image_hw[0] * ratio
|
||||
scaled_w = image_hw[1] * ratio
|
||||
|
||||
# If we need the result to be divisible by patch_size
|
||||
if divisible_by_patch:
|
||||
scaled_h = patch_h * math.ceil(scaled_h / patch_h)
|
||||
scaled_w = patch_w * math.ceil(scaled_w / patch_w)
|
||||
|
||||
# Ensure at least one patch in each dimension
|
||||
scaled_h = int(max(scaled_h, patch_h))
|
||||
scaled_w = int(max(scaled_w, patch_w))
|
||||
|
||||
return scaled_h, scaled_w
|
||||
|
||||
def is_feasible(ratio):
|
||||
"""Check if scaling by 'ratio' keeps patch count within max_seq_len."""
|
||||
t_h, t_w = prepare_target_hw(ratio)
|
||||
|
||||
# Each dimension is already a multiple of patch_h, patch_w if divisible_by_patch=True.
|
||||
# Use integer division to count patches.
|
||||
num_patches_h = t_h // patch_h
|
||||
num_patches_w = t_w // patch_w
|
||||
seq_len = num_patches_h * num_patches_w
|
||||
|
||||
return seq_len <= max_seq_len
|
||||
|
||||
# Binary search boundaries
|
||||
lb = eps / 10.0
|
||||
rb = 100.0
|
||||
|
||||
# Standard binary search loop
|
||||
while (rb - lb) >= eps:
|
||||
mid = (lb + rb) / 2.0
|
||||
if is_feasible(mid):
|
||||
lb = mid
|
||||
else:
|
||||
rb = mid
|
||||
|
||||
# The final ratio from the binary search
|
||||
ratio = lb
|
||||
|
||||
# If max_ratio is provided, clamp it to prevent upsampling beyond that threshold
|
||||
if max_ratio is not None:
|
||||
ratio = min(ratio, max_ratio)
|
||||
|
||||
# Final checks
|
||||
if ratio <= eps:
|
||||
raise ValueError("Binary search failed - image might be too large?")
|
||||
if ratio >= 100.0:
|
||||
raise ValueError("Binary search failed - image might be too small?")
|
||||
|
||||
# Prepare the final target dimensions with the possibly clamped ratio
|
||||
target_hw = prepare_target_hw(ratio)
|
||||
return ratio, target_hw
|
||||
|
||||
|
||||
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
|
||||
|
||||
|
||||
class ResizeToSequence(torch.nn.Module):
|
||||
"""Resize image to fit within a maximum sequence length constraint when patchified.
|
||||
|
||||
This maintains aspect ratio while ensuring the resulting image, when divided into patches,
|
||||
will not exceed the specified maximum sequence length.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
max_seq_len: int = 1024,
|
||||
divisible_by_patch: bool = True,
|
||||
max_ratio: Optional[float] = None,
|
||||
interpolation='bicubic',
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.divisible_by_patch = divisible_by_patch
|
||||
self.max_ratio = max_ratio
|
||||
if isinstance(interpolation, str):
|
||||
if interpolation == 'random':
|
||||
self.interpolation = _RANDOM_INTERPOLATION
|
||||
else:
|
||||
self.interpolation = str_to_interp_mode(interpolation)
|
||||
else:
|
||||
self.interpolation = interpolation
|
||||
|
||||
|
||||
def forward(self, img):
|
||||
"""Resize image to maintain aspect ratio and fit sequence constraint."""
|
||||
_, h, w = transforms.functional.get_dimensions(img)
|
||||
|
||||
_, target_hw = get_image_size_for_seq(
|
||||
(h, w),
|
||||
self.patch_size,
|
||||
self.max_seq_len,
|
||||
divisible_by_patch=self.divisible_by_patch,
|
||||
max_ratio=self.max_ratio,
|
||||
)
|
||||
|
||||
if isinstance(self.interpolation, (tuple, list)):
|
||||
interpolation = random.choice(self.interpolation)
|
||||
else:
|
||||
interpolation = self.interpolation
|
||||
|
||||
resized_img = transforms.functional.resize(img, target_hw, interpolation=interpolation, antialias=True)
|
||||
|
||||
return resized_img
|
||||
|
||||
|
||||
class ResizeKeepRatioToSequence(torch.nn.Module):
|
||||
"""
|
||||
Resize and Keep Aspect Ratio, adapted to fit sequence length constraints.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=16,
|
||||
max_sequence_len=1024,
|
||||
divisible_by_patch=True,
|
||||
longest=0.,
|
||||
interpolation='bilinear',
|
||||
random_scale_prob=0.,
|
||||
random_scale_range=(0.85, 1.05),
|
||||
random_scale_area=False,
|
||||
random_aspect_prob=0.,
|
||||
random_aspect_range=(0.9, 1.11),
|
||||
max_ratio=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
patch_size: Size of patches (int or tuple of (patch_h, patch_w))
|
||||
max_sequence_len: Maximum allowed sequence length for the resulting image
|
||||
divisible_by_patch: If True, ensure dimensions are divisible by patch_size
|
||||
longest: Float between 0-1 where 0=shortest side, 1=longest side determines scale
|
||||
interpolation: Interpolation method for resizing
|
||||
random_scale_prob: Probability of applying random scaling
|
||||
random_scale_range: Range for random scaling factor (min, max)
|
||||
random_scale_area: If True, scale factors affect area (√ factor)
|
||||
random_aspect_prob: Probability of applying random aspect ratio jittering
|
||||
random_aspect_range: Range for random aspect ratio (min, max)
|
||||
max_ratio: Maximum allowed scaling ratio
|
||||
"""
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.max_sequence_len = max_sequence_len
|
||||
self.divisible_by_patch = divisible_by_patch
|
||||
self.longest = float(longest)
|
||||
|
||||
if interpolation == 'random':
|
||||
self.interpolation = _RANDOM_INTERPOLATION
|
||||
else:
|
||||
self.interpolation = str_to_interp_mode(interpolation)
|
||||
|
||||
self.random_scale_prob = random_scale_prob
|
||||
self.random_scale_range = random_scale_range
|
||||
self.random_scale_area = random_scale_area
|
||||
self.random_aspect_prob = random_aspect_prob
|
||||
self.random_aspect_range = random_aspect_range
|
||||
self.max_ratio = max_ratio
|
||||
|
||||
@staticmethod
|
||||
def get_params(
|
||||
img,
|
||||
patch_size,
|
||||
max_sequence_len,
|
||||
divisible_by_patch,
|
||||
longest,
|
||||
random_scale_prob=0.,
|
||||
random_scale_range=(1.0, 1.33),
|
||||
random_scale_area=False,
|
||||
random_aspect_prob=0.,
|
||||
random_aspect_range=(0.9, 1.11),
|
||||
max_ratio=None,
|
||||
):
|
||||
"""Get parameters for resizing."""
|
||||
# Get image dimensions
|
||||
img_h, img_w = F.get_dimensions(img)[1:]
|
||||
|
||||
# Step 1: Get the maximum allowed dimensions from sequence length constraint
|
||||
_, target_hw = get_image_size_for_seq(
|
||||
(img_h, img_w),
|
||||
patch_size,
|
||||
max_sequence_len,
|
||||
divisible_by_patch,
|
||||
max_ratio,
|
||||
)
|
||||
target_h, target_w = target_hw
|
||||
|
||||
# Calculate ratio based on sequence constraint
|
||||
ratio_h = target_h / img_h
|
||||
ratio_w = target_w / img_w
|
||||
# Apply longest blending
|
||||
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
|
||||
|
||||
# Apply random scaling
|
||||
if random_scale_prob > 0 and random.random() < random_scale_prob:
|
||||
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
|
||||
if random_scale_area:
|
||||
# Make ratio factor equivalent to area change
|
||||
ratio_factor = 1. / math.sqrt(ratio_factor)
|
||||
ratio_factor = (ratio_factor, ratio_factor)
|
||||
else:
|
||||
ratio_factor = (1., 1.)
|
||||
|
||||
# Apply random aspect
|
||||
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
|
||||
log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
|
||||
aspect_factor = math.exp(random.uniform(*log_aspect))
|
||||
aspect_factor = math.sqrt(aspect_factor)
|
||||
# Apply aspect ratio jittering
|
||||
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
|
||||
|
||||
# Calculate final dimensions
|
||||
size = [round(dim * ratio * f) for dim, f in zip((img_h, img_w), ratio_factor)]
|
||||
|
||||
# Ensure dimensions satisfy sequence constraint and are divisible by patch size
|
||||
if isinstance(patch_size, int):
|
||||
ph, pw = patch_size, patch_size
|
||||
else:
|
||||
ph, pw = patch_size
|
||||
|
||||
# Ensure dimensions are at least one patch
|
||||
size[0] = max(size[0], ph)
|
||||
size[1] = max(size[1], pw)
|
||||
|
||||
# Make divisible by patch size if needed
|
||||
if divisible_by_patch:
|
||||
size[0] = ph * math.ceil(size[0] / ph)
|
||||
size[1] = pw * math.ceil(size[1] / pw)
|
||||
|
||||
# Verify we haven't exceeded sequence length
|
||||
num_patches_h = size[0] // ph
|
||||
num_patches_w = size[1] // pw
|
||||
seq_len = num_patches_h * num_patches_w
|
||||
|
||||
if seq_len > max_sequence_len:
|
||||
# Scale back down to fit sequence constraint
|
||||
scale_back = math.sqrt(max_sequence_len / seq_len)
|
||||
size[0] = int(size[0] * scale_back)
|
||||
size[1] = int(size[1] * scale_back)
|
||||
|
||||
# Ensure divisible by patch size after scaling back
|
||||
if divisible_by_patch:
|
||||
size[0] = ph * math.ceil(size[0] / ph)
|
||||
size[1] = pw * math.ceil(size[1] / pw)
|
||||
|
||||
return size
|
||||
|
||||
def forward(self, img):
|
||||
"""
|
||||
Resize the image with aspect ratio preservation and sequence length constraints.
|
||||
"""
|
||||
size = self.get_params(
|
||||
img,
|
||||
self.patch_size,
|
||||
self.max_sequence_len,
|
||||
self.divisible_by_patch,
|
||||
self.longest,
|
||||
self.random_scale_prob,
|
||||
self.random_scale_range,
|
||||
self.random_scale_area,
|
||||
self.random_aspect_prob,
|
||||
self.random_aspect_range,
|
||||
self.max_ratio,
|
||||
)
|
||||
|
||||
if isinstance(self.interpolation, (tuple, list)):
|
||||
interpolation = random.choice(self.interpolation)
|
||||
else:
|
||||
interpolation = self.interpolation
|
||||
|
||||
return F.resize(img, size, interpolation)
|
||||
|
||||
def __repr__(self):
|
||||
interpolate_str = "random" if isinstance(self.interpolation, (tuple, list)) else str(self.interpolation)
|
||||
return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
|
||||
f"max_sequence_len={self.max_sequence_len}, "
|
||||
f"longest={self.longest:.3f}, "
|
||||
f"random_scale_prob={self.random_scale_prob:.3f}, "
|
||||
f"random_aspect_prob={self.random_aspect_prob:.3f})")
|
||||
|
||||
|
||||
class CenterCropToSequence(torch.nn.Module):
|
||||
"""Center crop the image such that the resulting patch sequence length meets constraints."""
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
max_seq_len: int,
|
||||
divisible_by_patch: bool = True,
|
||||
fill: Union[int, Tuple[int, int, int]] = 0,
|
||||
padding_mode: str = 'constant'
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.divisible_by_patch = divisible_by_patch
|
||||
self.fill = fill
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
|
||||
def forward(self, img):
|
||||
"""Center crop the image to maintain aspect ratio and fit sequence constraint."""
|
||||
_, h, w = transforms.functional.get_dimensions(img)
|
||||
_, target_hw = get_image_size_for_seq(
|
||||
(h, w),
|
||||
self.patch_size,
|
||||
self.max_seq_len,
|
||||
self.divisible_by_patch
|
||||
)
|
||||
|
||||
# Use center crop
|
||||
return center_crop_or_pad(img, target_hw, fill=self.fill, padding_mode=self.padding_mode)
|
||||
|
||||
|
||||
class RandomCropToSequence(torch.nn.Module):
|
||||
"""Randomly crop and/or pad the image to fit sequence length constraints.
|
||||
|
||||
This maintains aspect ratio while ensuring the resulting image, when divided into patches,
|
||||
will not exceed the specified maximum sequence length. Similar to CentralCropToSequence
|
||||
but with randomized positioning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
max_sequence_len: int,
|
||||
divisible_by_patch: bool = True,
|
||||
fill: Union[int, Tuple[int, int, int]] = 0,
|
||||
padding_mode: str = 'constant'
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
patch_size: Size of patches (int or tuple of (patch_h, patch_w))
|
||||
max_sequence_len: Maximum allowed sequence length for the resulting image
|
||||
divisible_by_patch: If True, resulting image dimensions will be multiples of patch_size
|
||||
fill: Fill value for padding
|
||||
padding_mode: Padding mode ('constant', 'edge', 'reflect', 'symmetric')
|
||||
"""
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.max_sequence_len = max_sequence_len
|
||||
self.divisible_by_patch = divisible_by_patch
|
||||
self.fill = fill
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
@staticmethod
|
||||
def get_params(img, target_size):
|
||||
"""Get random position for crop/pad."""
|
||||
_, image_height, image_width = transforms.functional.get_dimensions(img)
|
||||
delta_height = image_height - target_size[0]
|
||||
delta_width = image_width - target_size[1]
|
||||
|
||||
# Handle both positive (crop) and negative (pad) deltas
|
||||
if delta_height == 0:
|
||||
top = 0
|
||||
else:
|
||||
top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))
|
||||
|
||||
if delta_width == 0:
|
||||
left = 0
|
||||
else:
|
||||
left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))
|
||||
|
||||
return top, left
|
||||
|
||||
def forward(self, img):
|
||||
"""Randomly crop or pad the image to maintain aspect ratio and fit sequence constraint."""
|
||||
# Get current dimensions
|
||||
_, img_h, img_w = transforms.functional.get_dimensions(img)
|
||||
|
||||
# Calculate target dimensions that satisfy sequence length
|
||||
# We use max_ratio=1.0 to prevent upscaling - we only want to crop or maintain current size
|
||||
_, target_hw = get_image_size_for_seq(
|
||||
(img_h, img_w),
|
||||
self.patch_size,
|
||||
self.max_sequence_len,
|
||||
self.divisible_by_patch,
|
||||
max_ratio=1.0 # Prevent upscaling
|
||||
)
|
||||
|
||||
# Get random position for crop/pad
|
||||
top, left = self.get_params(img, target_hw)
|
||||
|
||||
# Apply crop or pad
|
||||
return crop_or_pad(
|
||||
img,
|
||||
top=top,
|
||||
left=left,
|
||||
height=target_hw[0],
|
||||
width=target_hw[1],
|
||||
fill=self.fill,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
|
||||
f"max_sequence_len={self.max_sequence_len}, "
|
||||
f"divisible_by_patch={self.divisible_by_patch})")
|
||||
|
||||
|
||||
def _validate_range(value, name, length=2):
|
||||
# Validate type and length
|
||||
if not isinstance(value, Sequence) or len(value) != length:
|
||||
raise ValueError(f"{name} should be a sequence of length {length}.")
|
||||
|
||||
# Validate order
|
||||
if value[0] > value[1]:
|
||||
warnings.warn(f"{name.capitalize()} range reversed. Swapping.")
|
||||
return value[1], value[0]
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class RandomResizedCropToSequence(torch.nn.Module):
|
||||
"""
|
||||
Randomly crop the input image to a subregion with varying area and aspect ratio
|
||||
(relative to the original), then resize that crop to a target size. The target size
|
||||
is determined such that patchifying the resized image (with `patch_size`)
|
||||
does not exceed `max_seq_len` patches, while maintaining the aspect ratio of the crop.
|
||||
|
||||
This combines aspects of torchvision's RandomResizedCrop with sequence length constraints.
|
||||
|
||||
Args:
|
||||
patch_size (int or tuple[int, int]):
|
||||
Patch dimensions (patch_h, patch_w) for sequence length calculation.
|
||||
max_seq_len (int):
|
||||
Maximum number of patches allowed in the final image.
|
||||
scale (tuple[float, float]):
|
||||
Range (min, max) of area fraction of the original image to crop.
|
||||
ratio (tuple[float, float]):
|
||||
Range (min, max) of aspect ratio *multipliers* for the crop, relative
|
||||
to the original image's aspect ratio. E.g., (0.75, 1.333) means the
|
||||
crop's aspect ratio will be sampled between 0.75*orig_ar and 1.333*orig_ar.
|
||||
Uses log-uniform sampling.
|
||||
interpolation (str or InterpolationMode):
|
||||
Interpolation mode for resizing. Can be 'bilinear', 'bicubic', 'nearest',
|
||||
or 'random' (chooses between bilinear and bicubic).
|
||||
Defaults to 'bicubic'.
|
||||
divisible_by_patch (bool):
|
||||
If True, the final image height and width will be multiples of the
|
||||
respective patch dimensions. Defaults to True.
|
||||
max_ratio (float, optional):
|
||||
An optional upper limit on the scaling ratio applied during resizing.
|
||||
Prevents excessive upsampling of the initial crop. `max_ratio=1.0`
|
||||
prevents any upsampling beyond the cropped size. Defaults to None (no limit).
|
||||
final_scale_range (tuple[float, float], optional):
|
||||
If provided, applies an *additional* random scaling factor to the
|
||||
final target size. The factor is sampled uniformly from this range,
|
||||
and multiplied by the size determined by `get_image_size_for_seq`.
|
||||
E.g., (0.8, 1.0) means the final size will be between 80% and 100%
|
||||
of the maximum feasible size. Defaults to None (use maximum feasible size).
|
||||
attempts (int):
|
||||
Number of attempts to sample a valid crop geometry before falling back
|
||||
to a center crop strategy. Defaults to 10.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
max_seq_len: int = 1024,
|
||||
scale: Tuple[float, float] = (0.08, 1.0),
|
||||
ratio: Tuple[float, float] = (.8, 1.25),
|
||||
interpolation: Union[str, InterpolationMode] = 'bicubic',
|
||||
divisible_by_patch: bool = True,
|
||||
max_ratio: Optional[float] = None,
|
||||
final_scale_range: Optional[Tuple[float, float]] = None,
|
||||
attempts: int = 10,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(patch_size, int):
|
||||
self.patch_h, self.patch_w = patch_size, patch_size
|
||||
else:
|
||||
# Assume it's a tuple/list: (patch_h, patch_w)
|
||||
if len(patch_size) != 2:
|
||||
raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
|
||||
self.patch_h, self.patch_w = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.scale = scale
|
||||
self.ratio = ratio
|
||||
self.divisible_by_patch = divisible_by_patch
|
||||
self.max_ratio = max_ratio
|
||||
self.final_scale_range = final_scale_range
|
||||
self.attempts = attempts
|
||||
if isinstance(interpolation, str):
|
||||
if interpolation == 'random':
|
||||
self.interpolation = _RANDOM_INTERPOLATION
|
||||
else:
|
||||
self.interpolation = str_to_interp_mode(interpolation)
|
||||
else:
|
||||
self.interpolation = interpolation
|
||||
|
||||
# Validate scale and ratio
|
||||
self.scale = _validate_range(self.scale, "scale")
|
||||
self.ratio = _validate_range(self.ratio, "ratio")
|
||||
|
||||
# Validate final_scale_range if provided
|
||||
if self.final_scale_range is not None:
|
||||
self.final_scale_range = _validate_range(self.final_scale_range, "final_scale_range")
|
||||
|
||||
# Additional validation for final_scale_range values
|
||||
if not (0.0 <= self.final_scale_range[0] <= self.final_scale_range[1] <= 1.0):
|
||||
warnings.warn("final_scale_range values should ideally be between 0.0 and 1.0.")
|
||||
|
||||
@staticmethod
|
||||
def get_params(
|
||||
img: Union[torch.Tensor, Image],
|
||||
scale: Tuple[float, float],
|
||||
ratio: Tuple[float, float],
|
||||
crop_attempts: int = 10,
|
||||
patch_h: int = 16,
|
||||
patch_w: int = 16,
|
||||
max_seq_len: int = 1024,
|
||||
divisible_by_patch: bool = True,
|
||||
max_ratio: Optional[float] = None,
|
||||
final_scale_range: Optional[Tuple[float, float]] = None,
|
||||
interpolation: Union[List[InterpolationMode], InterpolationMode] = _RANDOM_INTERPOLATION,
|
||||
) -> Tuple[Tuple[int, int, int, int], Tuple[int, int], InterpolationMode]:
|
||||
""" Get parameters for a random sized crop relative to image aspect ratio.
|
||||
"""
|
||||
_, height, width = F.get_dimensions(img)
|
||||
if height <= 0 or width <= 0:
|
||||
raise ValueError(f"Input image must have positive dimensions, got H={height}, W={width}")
|
||||
|
||||
area = height * width
|
||||
orig_aspect = width / height
|
||||
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
||||
|
||||
for _ in range(crop_attempts):
|
||||
target_area = area * random.uniform(scale[0], scale[1])
|
||||
aspect_ratio_factor = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
|
||||
aspect_ratio = orig_aspect * aspect_ratio_factor
|
||||
|
||||
# Calculate target dimensions for the crop
|
||||
# target_area = crop_w * crop_h, aspect_ratio = crop_w / crop_h
|
||||
# => crop_h = sqrt(target_area / aspect_ratio)
|
||||
# => crop_w = sqrt(target_area * aspect_ratio)
|
||||
crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
|
||||
if 0 < crop_w <= width and 0 < crop_h <= height:
|
||||
top = random.randint(0, height - crop_h)
|
||||
left = random.randint(0, width - crop_w)
|
||||
break
|
||||
else:
|
||||
# Fallback strategy, use center crop trying to respect ratio range
|
||||
min_aspect_ratio = orig_aspect * ratio[0]
|
||||
max_aspect_ratio = orig_aspect * ratio[1]
|
||||
|
||||
if orig_aspect < min_aspect_ratio:
|
||||
# Original is narrower than target min, clamp width
|
||||
crop_w = width
|
||||
crop_h = min(int(round(crop_w / min_aspect_ratio)), height)
|
||||
elif orig_aspect > max_aspect_ratio:
|
||||
# Original is wider than target max, clamp height
|
||||
crop_h = height
|
||||
crop_w = min(int(round(crop_h * max_aspect_ratio)), width)
|
||||
else:
|
||||
# Aspect ratio is within range, take the largest possible crop (full image)
|
||||
crop_w = width
|
||||
crop_h = height
|
||||
|
||||
# Ensure valid dimensions after fallback calculation
|
||||
crop_h = max(1, crop_h)
|
||||
crop_w = max(1, crop_w)
|
||||
|
||||
top = (height - crop_h) // 2
|
||||
left = (width - crop_w) // 2
|
||||
|
||||
# Determine max feasible size for scaling of the *cropped* region
|
||||
feasible_ratio, feasible_size = get_image_size_for_seq(
|
||||
(crop_h, crop_w),
|
||||
patch_size=(patch_h, patch_w), # Pass as tuple
|
||||
max_seq_len=max_seq_len,
|
||||
divisible_by_patch=divisible_by_patch,
|
||||
max_ratio=max_ratio,
|
||||
)
|
||||
|
||||
# Optionally apply final scale randomization
|
||||
final_size = feasible_size
|
||||
if final_scale_range is not None:
|
||||
min_sc, max_sc = final_scale_range
|
||||
scale_factor = random.uniform(min_sc, max_sc)
|
||||
scale_factor = min(max(scale_factor, 0.0), 1.0) # Clamp factor just in case
|
||||
|
||||
# Calculate raw scaled size
|
||||
# Note: feasible_ratio already accounts for max_ratio clamp if any
|
||||
raw_h = crop_h * feasible_ratio * scale_factor
|
||||
raw_w = crop_w * feasible_ratio * scale_factor
|
||||
|
||||
# Re-apply divisibility constraint if needed
|
||||
if divisible_by_patch:
|
||||
# Use ceil to avoid going under minimum patch size
|
||||
target_h = patch_h * math.ceil(raw_h / patch_h)
|
||||
target_w = patch_w * math.ceil(raw_w / patch_w)
|
||||
else:
|
||||
target_h = int(round(raw_h))
|
||||
target_w = int(round(raw_w))
|
||||
|
||||
# Ensure final size is at least one patch dimension
|
||||
target_h = max(target_h, patch_h)
|
||||
target_w = max(target_w, patch_w)
|
||||
final_size = (target_h, target_w)
|
||||
|
||||
# Final check: Ensure this randomized size still fits max_seq_len
|
||||
# (It should, as we scaled down, but rounding might theoretically push it over)
|
||||
num_patches_h = final_size[0] // patch_h
|
||||
num_patches_w = final_size[1] // patch_w
|
||||
if (num_patches_h * num_patches_w) > max_seq_len:
|
||||
# If it exceeds, revert to the original feasible_size (safest)
|
||||
final_size = feasible_size
|
||||
warnings.warn(f"Final scale randomization ({scale_factor:.2f}) resulted in size {final_size} exceeding max_seq_len={max_seq_len} after rounding. Reverting to feasible size {feasible_size}.")
|
||||
|
||||
# Select interpolation mode
|
||||
if isinstance(interpolation, (tuple, list)):
|
||||
interpolation = random.choice(interpolation)
|
||||
else:
|
||||
interpolation = interpolation
|
||||
|
||||
return (top, left, crop_h, crop_w), final_size, interpolation
|
||||
|
||||
def forward(self, img: Union[torch.Tensor, Image]) -> torch.Tensor:
|
||||
# Sample crop, resize, and interpolation parameters
|
||||
crop_params, final_size, interpolation = self.get_params(
|
||||
img,
|
||||
scale=self.scale,
|
||||
ratio=self.ratio,
|
||||
crop_attempts=self.attempts,
|
||||
patch_h=self.patch_h,
|
||||
patch_w=self.patch_w,
|
||||
divisible_by_patch=self.divisible_by_patch,
|
||||
max_seq_len=self.max_seq_len,
|
||||
final_scale_range=self.final_scale_range,
|
||||
interpolation=self.interpolation,
|
||||
)
|
||||
top, left, crop_h, crop_w = crop_params
|
||||
|
||||
output = F.resized_crop(
|
||||
img,
|
||||
top=top,
|
||||
left=left,
|
||||
height=crop_h,
|
||||
width=crop_w,
|
||||
size=final_size,
|
||||
interpolation=interpolation,
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if isinstance(self.interpolation, (tuple, list)):
|
||||
interpolate_str = ', '.join(str(m).split('.')[-1] for m in self.interpolation)
|
||||
else:
|
||||
interpolate_str = str(self.interpolation)
|
||||
format_string = self.__class__.__name__ + '('
|
||||
format_string += f"patch_size=({self.patch_h}, {self.patch_w})"
|
||||
format_string += f", max_seq_len={self.max_seq_len}"
|
||||
format_string += f", scale={self.scale}"
|
||||
format_string += f", ratio={self.ratio}"
|
||||
format_string += f", interpolation=[{interpolate_str}]"
|
||||
format_string += f", divisible_by_patch={self.divisible_by_patch}"
|
||||
format_string += f", max_ratio={self.max_ratio}"
|
||||
format_string += f", final_scale_range={self.final_scale_range}"
|
||||
format_string += f", attempts={self.attempts}"
|
||||
format_string += ')'
|
||||
return format_string
|
||||
|
||||
|
||||
def patchify(
|
||||
img: torch.Tensor,
|
||||
patch_size: Tuple[int, int],
|
||||
pad: bool = True,
|
||||
include_info: bool = True,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
c, h, w = img.shape
|
||||
ph, pw = patch_size
|
||||
|
||||
# Ensure the image is divisible by patch size
|
||||
if pad and (h % ph != 0 or w % pw != 0):
|
||||
new_h = math.ceil(h / ph) * ph
|
||||
new_w = math.ceil(w / pw) * pw
|
||||
padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype)
|
||||
padded_img[:, :h, :w] = img
|
||||
img = padded_img
|
||||
c, h, w = img.shape
|
||||
|
||||
# Calculate number of patches in each dimension
|
||||
nh, nw = h // ph, w // pw
|
||||
# Reshape image to patches [nh, nw, ph, pw, c]
|
||||
patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
|
||||
|
||||
if include_info:
|
||||
# Create coordinate indices
|
||||
y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
|
||||
# Stack into a single coords tensor [N, 2] with (y, x) order
|
||||
coord = torch.stack([y_idx.reshape(-1), x_idx.reshape(-1)], dim=1)
|
||||
# Create type indicators (all 1s for regular patches)
|
||||
valid = torch.ones(nh * nw, dtype=torch.bool)
|
||||
return patches, coord, valid
|
||||
|
||||
return patches
|
||||
|
||||
|
||||
class Patchify(torch.nn.Module):
|
||||
"""Transform an image into patches with corresponding coordinates and type indicators."""
|
||||
|
||||
def __init__(self, patch_size):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
|
||||
|
||||
def forward(self, img):
|
||||
"""
|
||||
Args:
|
||||
img: A PIL Image or tensor of shape [C, H, W]
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- patches: Tensor of shape [N, P*P*C] where N is the number of patches
|
||||
- patch_coord: Tensor of shape [N, 2] with (y, x) coordinates
|
||||
- patch_valid: Valid indicator (all 1s for non-padding patches)
|
||||
"""
|
||||
if isinstance(img, Image.Image):
|
||||
# Convert PIL Image to tensor [C, H, W]
|
||||
img = transforms.functional.to_tensor(img)
|
||||
|
||||
patches, coord, valid = patchify(img, self.patch_size)
|
||||
|
||||
return {
|
||||
'patches': patches,
|
||||
'patch_coord': coord,
|
||||
'patch_valid': valid,
|
||||
}
|
@ -12,7 +12,8 @@ from torchvision import transforms
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
|
||||
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
|
||||
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
|
||||
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
|
||||
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, MaybeToTensor, MaybePILToTensor
|
||||
from timm.data.naflex_transforms import RandomResizedCropToSequence, ResizeToSequence, Patchify
|
||||
from timm.data.random_erasing import RandomErasing
|
||||
|
||||
|
||||
@ -46,7 +47,7 @@ def transforms_noaug_train(
|
||||
]
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
tfl += [MaybePILToTensor()]
|
||||
elif not normalize:
|
||||
# when normalize disabled, converted to tensor without scaling, keep original dtype
|
||||
tfl += [MaybePILToTensor()]
|
||||
@ -84,6 +85,10 @@ def transforms_imagenet_train(
|
||||
use_prefetcher: bool = False,
|
||||
normalize: bool = True,
|
||||
separate: bool = False,
|
||||
naflex: bool = False,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
max_seq_len: int = 576, # 24x24 for 16x16 patch
|
||||
patchify: bool = False,
|
||||
):
|
||||
""" ImageNet-oriented image transforms for training.
|
||||
|
||||
@ -111,6 +116,9 @@ def transforms_imagenet_train(
|
||||
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
|
||||
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
|
||||
separate: Output transforms in 3-stage tuple.
|
||||
naflex: Enable NaFlex mode, sequence constrained patch output
|
||||
patch_size: Patch size for NaFlex mode.
|
||||
max_seq_len: Max sequence length for NaFlex mode.
|
||||
|
||||
Returns:
|
||||
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
||||
@ -121,11 +129,20 @@ def transforms_imagenet_train(
|
||||
"""
|
||||
train_crop_mode = train_crop_mode or 'rrc'
|
||||
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
|
||||
|
||||
primary_tfl = []
|
||||
if naflex:
|
||||
primary_tfl += [RandomResizedCropToSequence(
|
||||
patch_size=patch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
interpolation=interpolation
|
||||
)]
|
||||
else:
|
||||
if train_crop_mode in ('rkrc', 'rkrr'):
|
||||
# FIXME integration of RKR is a WIP
|
||||
scale = tuple(scale or (0.8, 1.00))
|
||||
ratio = tuple(ratio or (0.9, 1/.9))
|
||||
primary_tfl = [
|
||||
primary_tfl += [
|
||||
ResizeKeepRatio(
|
||||
img_size,
|
||||
interpolation=interpolation,
|
||||
@ -142,7 +159,7 @@ def transforms_imagenet_train(
|
||||
else:
|
||||
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
|
||||
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
|
||||
primary_tfl = [
|
||||
primary_tfl += [
|
||||
RandomResizedCropAndInterpolation(
|
||||
img_size,
|
||||
scale=scale,
|
||||
@ -150,6 +167,7 @@ def transforms_imagenet_train(
|
||||
interpolation=interpolation,
|
||||
)
|
||||
]
|
||||
|
||||
if hflip > 0.:
|
||||
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
|
||||
if vflip > 0.:
|
||||
@ -215,7 +233,7 @@ def transforms_imagenet_train(
|
||||
final_tfl = []
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
final_tfl += [ToNumpy()]
|
||||
final_tfl += [MaybePILToTensor()]
|
||||
elif not normalize:
|
||||
# when normalize disable, converted to tensor without scaling, keeps original dtype
|
||||
final_tfl += [MaybePILToTensor()]
|
||||
@ -238,6 +256,9 @@ def transforms_imagenet_train(
|
||||
)
|
||||
]
|
||||
|
||||
if patchify:
|
||||
final_tfl += [Patchify(patch_size=patch_size)]
|
||||
|
||||
if separate:
|
||||
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
|
||||
else:
|
||||
@ -254,6 +275,10 @@ def transforms_imagenet_eval(
|
||||
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
|
||||
use_prefetcher: bool = False,
|
||||
normalize: bool = True,
|
||||
naflex: bool = False,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
max_seq_len: int = 576, # 24x24 for 16x16 patch
|
||||
patchify: bool = False,
|
||||
):
|
||||
""" ImageNet-oriented image transform for evaluation and inference.
|
||||
|
||||
@ -267,6 +292,10 @@ def transforms_imagenet_eval(
|
||||
std: Image normalization standard deviation.
|
||||
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
|
||||
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
|
||||
naflex: Enable NaFlex mode, sequence constrained patch output
|
||||
patch_size: Patch size for NaFlex mode.
|
||||
max_seq_len: Max sequence length for NaFlex mode.
|
||||
patchify: Patchify the output instead of relying on prefetcher
|
||||
|
||||
Returns:
|
||||
Composed transform pipeline
|
||||
@ -285,6 +314,13 @@ def transforms_imagenet_eval(
|
||||
if crop_border_pixels:
|
||||
tfl += [TrimBorder(crop_border_pixels)]
|
||||
|
||||
if naflex:
|
||||
tfl += [ResizeToSequence(
|
||||
patch_size=patch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
interpolation=interpolation
|
||||
)]
|
||||
else:
|
||||
if crop_mode == 'squash':
|
||||
# squash mode scales each edge to 1/pct of target, then crops
|
||||
# aspect ratio is not preserved, no img lost if crop_pct == 1.0
|
||||
@ -315,7 +351,7 @@ def transforms_imagenet_eval(
|
||||
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
tfl += [MaybePILToTensor()]
|
||||
elif not normalize:
|
||||
# when normalize disabled, converted to tensor without scaling, keeps original dtype
|
||||
tfl += [MaybePILToTensor()]
|
||||
@ -328,6 +364,9 @@ def transforms_imagenet_eval(
|
||||
),
|
||||
]
|
||||
|
||||
if patchify:
|
||||
tfl += [Patchify(patch_size=patch_size)]
|
||||
|
||||
return transforms.Compose(tfl)
|
||||
|
||||
|
||||
@ -359,6 +398,10 @@ def create_transform(
|
||||
use_prefetcher: bool = False,
|
||||
normalize: bool = True,
|
||||
separate: bool = False,
|
||||
naflex: bool = False,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
max_seq_len: int = 576, # 24x24 for 16x16 patch
|
||||
patchify: bool = False
|
||||
):
|
||||
"""
|
||||
|
||||
@ -442,6 +485,10 @@ def create_transform(
|
||||
use_prefetcher=use_prefetcher,
|
||||
normalize=normalize,
|
||||
separate=separate,
|
||||
naflex=naflex,
|
||||
patch_size=patch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
patchify=patchify,
|
||||
)
|
||||
else:
|
||||
assert not separate, "Separate transforms not supported for validation preprocessing"
|
||||
@ -455,6 +502,10 @@ def create_transform(
|
||||
crop_border_pixels=crop_border_pixels,
|
||||
use_prefetcher=use_prefetcher,
|
||||
normalize=normalize,
|
||||
naflex=naflex,
|
||||
patch_size=patch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
patchify=patchify,
|
||||
)
|
||||
|
||||
return transform
|
||||
|
@ -1,6 +1,7 @@
|
||||
from .activations import *
|
||||
from .adaptive_avgmax_pool import \
|
||||
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
|
||||
from .attention import Attention
|
||||
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
|
||||
from .attention_pool import AttentionPoolLatent
|
||||
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
|
||||
|
66
timm/layers/attention.py
Normal file
66
timm/layers/attention.py
Normal file
@ -0,0 +1,66 @@
|
||||
from typing import Final, Type, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .config import use_fused_attn
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
fused_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
proj_bias: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.fused_attn = use_fused_attn()
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
if attn_mask is not None:
|
||||
attn = attn + attn_mask
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
@ -75,7 +75,7 @@ class AttentionPoolLatent(nn.Module):
|
||||
trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
|
||||
trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
||||
B, N, C = x.shape
|
||||
|
||||
if self.pos_embed is not None:
|
||||
@ -91,10 +91,12 @@ class AttentionPoolLatent(nn.Module):
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
if attn_mask is not None:
|
||||
attn = attn + attn_mask
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = attn @ v
|
||||
x = x.transpose(1, 2).reshape(B, self.latent_len, C)
|
||||
|
@ -72,6 +72,7 @@ from .vgg import *
|
||||
from .visformer import *
|
||||
from .vision_transformer import *
|
||||
from .vision_transformer_hybrid import *
|
||||
from .vision_transformer_flex import *
|
||||
from .vision_transformer_relpos import *
|
||||
from .vision_transformer_sam import *
|
||||
from .vitamin import *
|
||||
|
@ -41,8 +41,8 @@ from torch.jit import Final
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
|
||||
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
|
||||
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
|
||||
from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \
|
||||
SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
|
||||
get_act_layer, get_norm_layer, LayerType
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
@ -55,58 +55,6 @@ __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
fused_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
proj_bias: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.fused_attn = use_fused_attn()
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -165,8 +113,8 @@ class Block(nn.Module):
|
||||
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
|
||||
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
||||
return x
|
||||
|
||||
@ -222,8 +170,8 @@ class ResPostBlock(nn.Module):
|
||||
nn.init.constant_(self.norm1.weight, self.init_values)
|
||||
nn.init.constant_(self.norm2.weight, self.init_values)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.drop_path1(self.norm1(self.attn(x)))
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x = x + self.drop_path1(self.norm1(self.attn(x, attn_mask=attn_mask)))
|
||||
x = x + self.drop_path2(self.norm2(self.mlp(x)))
|
||||
return x
|
||||
|
||||
@ -282,7 +230,7 @@ class ParallelScalingBlock(nn.Module):
|
||||
self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
|
||||
# Combined MLP fc1 & qkv projections
|
||||
@ -302,14 +250,18 @@ class ParallelScalingBlock(nn.Module):
|
||||
if self.fused_attn:
|
||||
x_attn = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
if attn_mask is not None:
|
||||
attn = attn + attn_mask
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x_attn = attn @ v
|
||||
|
||||
x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
|
||||
x_attn = self.attn_out_proj(x_attn)
|
||||
|
||||
@ -379,23 +331,11 @@ class ParallelThingsBlock(nn.Module):
|
||||
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
|
||||
])))
|
||||
|
||||
def _forward_jit(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x = x + torch.stack([attn(x, attn_mask=attn_mask) for attn in self.attns]).sum(dim=0)
|
||||
x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
|
||||
return x
|
||||
|
||||
@torch.jit.ignore
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + sum(attn(x) for attn in self.attns)
|
||||
x = x + sum(ffn(x) for ffn in self.ffns)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return self._forward_jit(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
def global_pool_nlc(
|
||||
x: torch.Tensor,
|
||||
@ -728,7 +668,9 @@ class VisionTransformer(nn.Module):
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
output_dict: bool = False,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
@ -739,8 +681,11 @@ class VisionTransformer(nn.Module):
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
|
||||
attn_mask: Optional attention mask for masked attention (e.g., for NaFlex)
|
||||
Returns:
|
||||
|
||||
A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
|
||||
'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
|
||||
"""
|
||||
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
|
||||
reshape = output_fmt == 'NCHW'
|
||||
@ -759,7 +704,7 @@ class VisionTransformer(nn.Module):
|
||||
else:
|
||||
blocks = self.blocks[:max_index + 1]
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x)
|
||||
x = blk(x, attn_mask=attn_mask)
|
||||
if i in take_indices:
|
||||
# normalize intermediates with final norm layer if enabled
|
||||
intermediates.append(self.norm(x) if norm else x)
|
||||
@ -776,6 +721,23 @@ class VisionTransformer(nn.Module):
|
||||
# reshape to BCHW output format
|
||||
H, W = self.patch_embed.dynamic_feat_size((height, width))
|
||||
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||
|
||||
# For dictionary output, handle prefix tokens separately
|
||||
if output_dict:
|
||||
result_dict = {}
|
||||
# Intermediates are always included
|
||||
result_dict['image_intermediates'] = intermediates
|
||||
if prefix_tokens is not None and return_prefix_tokens:
|
||||
result_dict['image_intermediates_prefix'] = prefix_tokens
|
||||
|
||||
# Only include features if not intermediates_only
|
||||
if not intermediates_only:
|
||||
x_final = self.norm(x)
|
||||
result_dict['image_features'] = x_final
|
||||
|
||||
return result_dict
|
||||
|
||||
# For non-dictionary output, maintain the original behavior
|
||||
if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
|
||||
# return_prefix not support in torchscript due to poor type handling
|
||||
intermediates = list(zip(intermediates, prefix_tokens))
|
||||
@ -811,6 +773,7 @@ class VisionTransformer(nn.Module):
|
||||
reshape: bool = False,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
""" Intermediate layer accessor inspired by DINO / DINOv2 interface.
|
||||
NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
|
||||
@ -821,17 +784,24 @@ class VisionTransformer(nn.Module):
|
||||
norm=norm,
|
||||
output_fmt='NCHW' if reshape else 'NLC',
|
||||
intermediates_only=True,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
x = self._pos_embed(x)
|
||||
x = self.patch_drop(x)
|
||||
x = self.norm_pre(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
||||
if attn_mask is not None:
|
||||
# If mask provided, we need to apply blocks one by one
|
||||
for blk in self.blocks:
|
||||
x = blk(x, attn_mask=attn_mask)
|
||||
elif self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
x = self.blocks(x)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
@ -849,8 +819,8 @@ class VisionTransformer(nn.Module):
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.forward_features(x)
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x = self.forward_features(x, attn_mask=attn_mask)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
|
1045
timm/models/vision_transformer_flex.py
Normal file
1045
timm/models/vision_transformer_flex.py
Normal file
File diff suppressed because it is too large
Load Diff
108
train.py
108
train.py
@ -396,6 +396,15 @@ group.add_argument('--wandb-tags', default=[], type=str, nargs='+',
|
||||
group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID',
|
||||
help='If resuming a run, the id of the run in wandb')
|
||||
|
||||
# NaFlex scheduled loader arguments
|
||||
group.add_argument('--naflex-loader', action='store_true', default=False,
|
||||
help='Use NaFlex loader (Requires NaFlex compatible model)')
|
||||
group.add_argument('--naflex-train-seq-lens', type=int, nargs='+', default=[128, 256, 576, 784, 1024],
|
||||
help='Sequence lengths to use for NaFlex loader')
|
||||
group.add_argument('--naflex-max-seq-len', type=int, default=576,
|
||||
help='Fixed maximum sequence length for NaFlex loader (validation)')
|
||||
|
||||
|
||||
|
||||
def _parse_args():
|
||||
# Do we have a config file to parse?
|
||||
@ -669,6 +678,7 @@ def main():
|
||||
trust_remote_code=args.dataset_trust_remote_code,
|
||||
)
|
||||
|
||||
dataset_eval = None
|
||||
if args.val_split:
|
||||
dataset_eval = create_dataset(
|
||||
args.dataset,
|
||||
@ -690,6 +700,7 @@ def main():
|
||||
mixup_fn = None
|
||||
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
||||
if mixup_active:
|
||||
assert not args.naflex_loader, "Mixup/Cutmix not currently supported for NaFlex loading."
|
||||
mixup_args = dict(
|
||||
mixup_alpha=args.mixup,
|
||||
cutmix_alpha=args.cutmix,
|
||||
@ -714,9 +725,19 @@ def main():
|
||||
train_interpolation = args.train_interpolation
|
||||
if args.no_aug or not train_interpolation:
|
||||
train_interpolation = data_config['interpolation']
|
||||
loader_train = create_loader(
|
||||
dataset_train,
|
||||
input_size=data_config['input_size'],
|
||||
|
||||
# Check if we should use the NaFlex scheduled loader
|
||||
common_loader_kwargs = dict(
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
pin_memory=args.pin_mem,
|
||||
img_dtype=model_dtype or torch.float32,
|
||||
device=device,
|
||||
distributed=args.distributed,
|
||||
use_prefetcher=args.prefetcher,
|
||||
)
|
||||
|
||||
train_loader_kwargs = dict(
|
||||
batch_size=args.batch_size,
|
||||
is_training=True,
|
||||
no_aug=args.no_aug,
|
||||
@ -737,40 +758,68 @@ def main():
|
||||
num_aug_repeats=args.aug_repeats,
|
||||
num_aug_splits=num_aug_splits,
|
||||
interpolation=train_interpolation,
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
distributed=args.distributed,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=args.pin_mem,
|
||||
img_dtype=model_dtype or torch.float32,
|
||||
device=device,
|
||||
use_prefetcher=args.prefetcher,
|
||||
use_multi_epochs_loader=args.use_multi_epochs_loader,
|
||||
worker_seeding=args.worker_seeding,
|
||||
)
|
||||
|
||||
if args.naflex_loader:
|
||||
from timm.data.naflex_loader import create_naflex_loader
|
||||
if utils.is_primary(args):
|
||||
_logger.info('Using NaFlex loader')
|
||||
|
||||
loader_train = create_naflex_loader(
|
||||
dataset=dataset_train,
|
||||
patch_size=16, # Could be derived from model config
|
||||
train_seq_lens=args.naflex_train_seq_lens,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
**common_loader_kwargs,
|
||||
**train_loader_kwargs,
|
||||
)
|
||||
else:
|
||||
# Use standard loader
|
||||
loader_train = create_loader(
|
||||
dataset_train,
|
||||
input_size=data_config['input_size'],
|
||||
collate_fn=collate_fn,
|
||||
use_multi_epochs_loader=args.use_multi_epochs_loader,
|
||||
**common_loader_kwargs,
|
||||
**train_loader_kwargs,
|
||||
)
|
||||
|
||||
loader_eval = None
|
||||
if args.val_split:
|
||||
assert dataset_eval is not None
|
||||
eval_workers = args.workers
|
||||
if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
|
||||
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
|
||||
eval_workers = min(2, args.workers)
|
||||
loader_eval = create_loader(
|
||||
dataset_eval,
|
||||
input_size=data_config['input_size'],
|
||||
|
||||
eval_loader_kwargs = dict(
|
||||
batch_size=args.validation_batch_size or args.batch_size,
|
||||
is_training=False,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=eval_workers,
|
||||
distributed=args.distributed,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
pin_memory=args.pin_mem,
|
||||
img_dtype=model_dtype or torch.float32,
|
||||
device=device,
|
||||
use_prefetcher=args.prefetcher,
|
||||
)
|
||||
|
||||
if args.naflex_loader:
|
||||
from timm.data.naflex_loader import create_naflex_loader
|
||||
# Use largest sequence length for validation
|
||||
loader_eval = create_naflex_loader(
|
||||
dataset=dataset_eval,
|
||||
patch_size=16, # Could be derived from model config
|
||||
max_seq_len=args.naflex_max_seq_len,
|
||||
**common_loader_kwargs,
|
||||
**eval_loader_kwargs
|
||||
)
|
||||
else:
|
||||
# Use standard loader
|
||||
loader_eval = create_loader(
|
||||
dataset_eval,
|
||||
input_size=data_config['input_size'],
|
||||
**common_loader_kwargs,
|
||||
**eval_loader_kwargs,
|
||||
)
|
||||
|
||||
# setup loss function
|
||||
@ -1083,8 +1132,12 @@ def train_one_epoch(
|
||||
loss = _forward()
|
||||
_backward(loss)
|
||||
|
||||
losses_m.update(loss.item() * accum_steps, input.size(0))
|
||||
update_sample_count += input.size(0)
|
||||
if isinstance(input, dict):
|
||||
batch_size = input['patches'].shape[0]
|
||||
else:
|
||||
batch_size = input.shape[0]
|
||||
losses_m.update(loss.item() * accum_steps, batch_size)
|
||||
update_sample_count += batch_size
|
||||
|
||||
if not need_update:
|
||||
data_start_time = time.time()
|
||||
@ -1210,9 +1263,10 @@ def validate(
|
||||
elif device.type == "npu":
|
||||
torch.npu.synchronize()
|
||||
|
||||
losses_m.update(reduced_loss.item(), input.size(0))
|
||||
top1_m.update(acc1.item(), output.size(0))
|
||||
top5_m.update(acc5.item(), output.size(0))
|
||||
batch_size = output.shape[0]
|
||||
losses_m.update(reduced_loss.item(), batch_size)
|
||||
top1_m.update(acc1.item(), batch_size)
|
||||
top5_m.update(acc5.item(), batch_size)
|
||||
|
||||
batch_time_m.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
Loading…
x
Reference in New Issue
Block a user