Add naflex loader support to validate.py, fix bug in naflex pos embed add, add classic vit weight loading for naflex model

Ross Wightman 2025-04-25 16:00:54 -07:00
parent c527c37969
commit 3dc90ed7a7
4 changed files with 198 additions and 84 deletions

@@ -760,7 +760,6 @@ def patchify(
     nh, nw = h // ph, w // pw
     # Reshape image to patches [nh, nw, ph, pw, c]
     patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
-
     if include_info:
         # Create coordinate indices
         y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
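
For reference, a minimal sketch (not part of the diff) of the layout this patchify produces: a (c, h, w) image becomes (nh*nw, ph*pw*c) rows, channels-last within each patch.

    import torch

    c, h, w, ph, pw = 3, 64, 48, 16, 16
    img = torch.randn(c, h, w)
    nh, nw = h // ph, w // pw
    patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
    # Row 0 is the top-left patch, flattened in (ph, pw, c) order
    assert torch.equal(patches[0], img[:, :ph, :pw].permute(1, 2, 0).reshape(-1))
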

@@ -318,7 +318,7 @@ def transforms_imagenet_eval(
         tfl += [ResizeToSequence(
             patch_size=patch_size,
             max_seq_len=max_seq_len,
-            interpolation=interpolation
+            interpolation=interpolation,
         )]
     else:
         if crop_mode == 'squash':

@@ -52,6 +52,7 @@ def batch_patchify(
     nh, nw = H // ph, W // pw
     patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
+    # FIXME confirm we want 'channels last' in the patch channel layout, e.g. ph, pw, C instead of C, ph, pw
 
     return patches, (nh, nw)
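
The FIXME above ties into the patch_embed.proj.weight remap in checkpoint_filter_fn later in this commit: with channels-last patches, a linear projection using the Conv2d patch-embed weight permuted to (out, ph, pw, C) and flattened reproduces conv patch embedding. A standalone sketch (shapes illustrative, not from the diff):

    import torch
    import torch.nn.functional as F

    B, C, H, W, ph, pw, D = 2, 3, 32, 32, 16, 16, 768
    x = torch.randn(B, C, H, W)
    conv_w = torch.randn(D, C, ph, pw)  # Conv2d patch embed weight

    ref = F.conv2d(x, conv_w, stride=(ph, pw)).flatten(2).transpose(1, 2)  # (B, N, D)

    nh, nw = H // ph, W // pw
    patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
    lin_w = conv_w.permute(0, 2, 3, 1).flatten(1)  # (D, ph*pw*C), same remap as the filter fn
    assert torch.allclose(patches @ lin_w.T, ref, atol=1e-4)
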
@@ -297,7 +298,7 @@ class FlexEmbeds(nn.Module):
             size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
-        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
@@ -312,9 +313,14 @@
                 align_corners=False,
                 antialias=True,
             ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
 
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
+            x[:, :seq_len].index_add_(
+                0,
+                torch.as_tensor(batch_indices, device=x.device),
+                pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
+            )
 
     def _apply_learned_pos_embed(
             self,
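
This is the pos embed bug from the commit title: with a list-valued batch_indices, advanced indexing returns a copy, so the in-place add_ never wrote back into x. index_add_ through a basic slice writes through to the underlying tensor. A minimal repro (hypothetical shapes):

    import torch

    x = torch.zeros(4, 3, 2)    # (batch, seq_len, dim)
    pos = torch.ones(1, 3, 2)   # pos embed for one resolved grid size
    batch_indices = [0, 2]
    seq_len = 3

    x[batch_indices, :seq_len].add_(pos[:, :seq_len])   # old: modifies a copy
    assert x.abs().sum() == 0                           # x is still all zeros

    x[:, :seq_len].index_add_(                          # new: writes through the view
        0,
        torch.as_tensor(batch_indices, device=x.device),
        pos[:, :seq_len].expand(len(batch_indices), -1, -1),
    )
    assert x[0].sum() == 6 and x[1].sum() == 0
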
@@ -328,12 +334,13 @@
         else:
             # Resize if needed - directly using F.interpolate
             pos_embed_flat = F.interpolate(
-                self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
             ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
 
         x.add_(pos_embed_flat)
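
Both position-embedding paths now upcast to float32 before F.interpolate and cast back to the token dtype afterwards; antialiased bicubic resampling is unreliable (or unimplemented on some backends) in float16/bfloat16. The pattern in isolation, assuming a (1, H, W, C) stored pos_embed:

    import torch
    import torch.nn.functional as F

    pos_embed = torch.randn(1, 14, 14, 768, dtype=torch.bfloat16)
    grid_size = (10, 18)

    pos = F.interpolate(
        pos_embed.permute(0, 3, 1, 2).float(),  # resample in fp32
        size=grid_size,
        mode='bicubic',
        align_corners=False,
        antialias=True,
    ).flatten(2).transpose(1, 2)
    pos = pos.to(dtype=pos_embed.dtype)  # back to token dtype before the add
    assert pos.shape == (1, 180, 768)
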
@@ -806,21 +813,20 @@ class VisionTransformerFlex(nn.Module):
             # Apply the mask to extract only valid tokens
             x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling
 
+            patch_valid_float = patch_valid.to(x.dtype)
             if pool_type == 'avg':
-                # Compute masked average pooling
-                # Sum valid tokens and divide by count of valid tokens
-                masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-                valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+                # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
+                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
                 pooled = masked_sums / valid_counts
                 return pooled
             elif pool_type == 'avgmax':
                 # For avgmax, compute masked average and masked max
-                # For max, we set masked positions to large negative value
-                masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-                valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+                masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+                valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
                 masked_avg = masked_sums / valid_counts
-                # For max pooling with mask
+                # For max pooling we set masked positions to large negative value
                 masked_x = x.clone()
                 masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
                 masked_max = masked_x.max(dim=1)[0]
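
A small sketch of the masked average pooling above, with hypothetical shapes; padding tokens contribute neither to the sum nor to the count:

    import torch

    B, N, C = 2, 5, 4
    x = torch.randn(B, N, C)
    patch_valid = torch.tensor([
        [True, True, True, False, False],  # padded sequence, 3 valid tokens
        [True, True, True, True, True],
    ])

    patch_valid_float = patch_valid.to(x.dtype)
    masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
    valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = masked_sums / valid_counts  # (B, C)

    assert torch.allclose(pooled[0], x[0, :3].mean(dim=0), atol=1e-6)
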
@@ -915,6 +921,82 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     return init_weights_vit_timm
 
 
+def checkpoint_filter_fn(state_dict, model):
+    """Handle state dict conversion from original ViT to the new version with combined embedding."""
+    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+
+    # Handle CombinedEmbed module pattern
+    out_dict = {}
+    for k, v in state_dict.items():
+        # Convert tokens and embeddings to combined_embed structure
+        if k == 'pos_embed':
+            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
+            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
+                num_cls_token = 0
+                num_reg_token = 0
+                if 'reg_token' in state_dict:
+                    num_reg_token = state_dict['reg_token'].shape[1]
+                if 'cls_token' in state_dict:
+                    num_cls_token = state_dict['cls_token'].shape[1]
+                num_prefix_tokens = num_cls_token + num_reg_token
+
+                # Original format is (1, N, C), need to reshape to (1, H, W, C)
+                num_patches = v.shape[1]
+                num_patches_no_prefix = num_patches - num_prefix_tokens
+                grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
+                grid_size = math.sqrt(num_patches)
+                if (grid_size_no_prefix != grid_size and (
+                        grid_size_no_prefix.is_integer() and not grid_size.is_integer())):
+                    # make a decision, did the pos_embed of the original include the prefix tokens?
+                    num_patches = num_patches_no_prefix
+                    cls_token_emb = v[:, 0:num_cls_token]
+                    if cls_token_emb.numel():
+                        state_dict['cls_token'] += cls_token_emb
+                    reg_token_emb = v[:, num_cls_token:num_reg_token]
+                    if reg_token_emb.numel():
+                        state_dict['reg_token'] += reg_token_emb
+                    v = v[:, num_prefix_tokens:]
+                    grid_size = grid_size_no_prefix
+                grid_size = int(grid_size)
+
+                # Check if it's a perfect square for a standard grid
+                if grid_size * grid_size == num_patches:
+                    # Reshape from (1, N, C) to (1, H, W, C)
+                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
+                else:
+                    # Not a square grid, we need to get the actual dimensions
+                    if hasattr(model.embeds.patch_embed, 'grid_size'):
+                        h, w = model.embeds.patch_embed.grid_size
+                        if h * w == num_patches:
+                            # We have the right dimensions
+                            v = v.reshape(1, h, w, v.shape[2])
+                        else:
+                            # Dimensions don't match, use interpolation
+                            _logger.warning(
+                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
+                                f"Using default initialization and will resize in forward pass."
+                            )
+                            # Keep v as is, the forward pass will handle resizing
+
+            out_dict['embeds.pos_embed'] = v
+        elif k == 'cls_token':
+            out_dict['embeds.cls_token'] = v
+        elif k == 'reg_token':
+            out_dict['embeds.reg_token'] = v
+        # Convert patch_embed.X to embeds.patch_embed.X
+        elif k.startswith('patch_embed.'):
+            suffix = k[12:]
+            if suffix == 'proj.weight':
+                # FIXME confirm patchify memory layout across use cases
+                v = v.permute(0, 2, 3, 1).flatten(1)
+            new_key = 'embeds.' + suffix
+            out_dict[new_key] = v
+        else:
+            out_dict[k] = v
+
+    return out_dict
+
+
 def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     return {
         'url': url,
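
The core of the new conversion is deciding whether a classic checkpoint folded its prefix tokens into pos_embed: only one of N and N - num_prefix_tokens should be a perfect square. A sketch with a hypothetical ViT-B/16 checkpoint (197 = 1 cls token + 14*14 patches):

    import math
    import torch

    v = torch.randn(1, 197, 768)  # classic (1, N, C) pos_embed, cls token included
    num_prefix_tokens = 1

    num_patches = v.shape[1]
    no_prefix = num_patches - num_prefix_tokens
    if math.sqrt(no_prefix).is_integer() and not math.sqrt(num_patches).is_integer():
        v = v[:, num_prefix_tokens:]  # strip prefix rows (folded into cls/reg tokens)
        num_patches = no_prefix

    grid = int(math.sqrt(num_patches))
    v = v.reshape(1, grid, grid, v.shape[2])  # (1, 14, 14, 768), NaFlex storage layout
    assert v.shape == (1, 14, 14, 768)
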
@@ -936,6 +1018,26 @@ default_cfgs = generate_default_cfgs({
     'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
+
+    # sbb model testing
+    'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
+
+    # traditional vit testing
+    'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
+    'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
 })
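
With these cfgs plus the filter function above, classic checkpoints should load straight into the NaFlex variants. A usage sketch (assumes the hub weights resolve and remap cleanly; the model/tag name is one registered above):

    import timm
    import torch

    model = timm.create_model(
        'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k',
        pretrained=True,
    ).eval()

    with torch.inference_mode():
        out = model(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 1000)
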
@@ -948,10 +1050,22 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def vit_naflex_base_patch16(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
 
     This model supports:
     1. Variable aspect ratios and resolutions via patch coordinates
     2. Position embedding interpolation for arbitrary grid sizes
@@ -987,54 +1101,28 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     return model
 
 
-def checkpoint_filter_fn(state_dict, model):
-    """Handle state dict conversion from original ViT to the new version with combined embedding."""
-    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
-    # FIXME conversion of existing vit checkpoints has not been finished or tested
-
-    # Handle CombinedEmbed module pattern
-    out_dict = {}
-    for k, v in state_dict.items():
-        # Convert tokens and embeddings to combined_embed structure
-        if k == 'pos_embed':
-            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
-            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
-                # Original format is (1, N, C) - need to reshape to (1, H, W, C)
-                num_patches = v.shape[1]
-                grid_size = int(math.sqrt(num_patches))
-
-                # Check if it's a perfect square for a standard grid
-                if grid_size * grid_size == num_patches:
-                    # Reshape from (1, N, C) to (1, H, W, C)
-                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
-                else:
-                    # Not a square grid, we need to get the actual dimensions
-                    if hasattr(model.embeds.patch_embed, 'grid_size'):
-                        h, w = model.embeds.patch_embed.grid_size
-                        if h * w == num_patches:
-                            # We have the right dimensions
-                            v = v.reshape(1, h, w, v.shape[2])
-                        else:
-                            # Dimensions don't match, use interpolation
-                            _logger.warning(
-                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
-                                f"Using default initialization and will resize in forward pass."
-                            )
-                            # Keep v as is, the forward pass will handle resizing
-
-            out_dict['embeds.pos_embed'] = v
-        elif k == 'cls_token':
-            out_dict['embeds.cls_token'] = v
-        elif k == 'reg_token':
-            out_dict['embeds.reg_token'] = v
-        # Convert patch_embed.X to embeds.patch_embed.X
-        elif k.startswith('patch_embed.'):
-            new_key = 'embeds.' + k[12:]
-            out_dict[new_key] = v
-        else:
-            out_dict[k] = v
-
-    # Call the original filter function to handle other patterns
-    return orig_filter_fn(out_dict, model)
+@register_model
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+
+    This model supports:
+    1. Variable aspect ratios and resolutions via patch coordinates
+    2. Position embedding interpolation for arbitrary grid sizes
+    3. Explicit patch coordinates and valid token masking
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model

@@ -158,6 +158,12 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
 parser.add_argument('--retry', default=False, action='store_true',
                     help='Enable batch size decay & retry for single model validation')
 
+# NaFlex loader arguments
+parser.add_argument('--naflex-loader', action='store_true', default=False,
+                    help='Use NaFlex loader (Requires NaFlex compatible model)')
+parser.add_argument('--naflex-max-seq-len', type=int, default=576,
+                    help='Fixed maximum sequence length for NaFlex loader (validation)')
+
 
 def validate(args):
     # might as well try to validate something
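
With the new flags, a NaFlex validation run looks like the following (dataset path is illustrative; --model is an existing validate.py argument):

    python validate.py /data/imagenet --model vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k \
        --naflex-loader --naflex-max-seq-len 576
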
@@ -293,23 +299,43 @@ def validate(args):
         real_labels = None
 
     crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
-    loader = create_loader(
-        dataset,
-        input_size=data_config['input_size'],
-        batch_size=args.batch_size,
-        use_prefetcher=args.prefetcher,
-        interpolation=data_config['interpolation'],
-        mean=data_config['mean'],
-        std=data_config['std'],
-        num_workers=args.workers,
-        crop_pct=crop_pct,
-        crop_mode=data_config['crop_mode'],
-        crop_border_pixels=args.crop_border_pixels,
-        pin_memory=args.pin_mem,
-        device=device,
-        img_dtype=model_dtype or torch.float32,
-        tf_preprocessing=args.tf_preprocessing,
-    )
+    if args.naflex_loader:
+        from timm.data import create_naflex_loader
+        loader = create_naflex_loader(
+            dataset,
+            batch_size=args.batch_size,
+            use_prefetcher=args.prefetcher,
+            interpolation=data_config['interpolation'],
+            mean=data_config['mean'],
+            std=data_config['std'],
+            num_workers=args.workers,
+            crop_pct=crop_pct,
+            crop_mode=data_config['crop_mode'],
+            crop_border_pixels=args.crop_border_pixels,
+            pin_memory=args.pin_mem,
+            device=device,
+            img_dtype=model_dtype or torch.float32,
+            patch_size=16,  # Could be derived from model config
+            max_seq_len=args.naflex_max_seq_len,
+        )
+    else:
+        loader = create_loader(
+            dataset,
+            input_size=data_config['input_size'],
+            batch_size=args.batch_size,
+            use_prefetcher=args.prefetcher,
+            interpolation=data_config['interpolation'],
+            mean=data_config['mean'],
+            std=data_config['std'],
+            num_workers=args.workers,
+            crop_pct=crop_pct,
+            crop_mode=data_config['crop_mode'],
+            crop_border_pixels=args.crop_border_pixels,
+            pin_memory=args.pin_mem,
+            device=device,
+            img_dtype=model_dtype or torch.float32,
+            tf_preprocessing=args.tf_preprocessing,
+        )
 
     batch_time = AverageMeter()
     losses = AverageMeter()
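
The hardcoded patch_size=16 is flagged by its own comment; a possible follow-up derivation from the model (hypothetical, guarded, since the NaFlex embed internals may differ):

    def _derive_patch_size(model, default=16):
        # model.embeds.patch_embed appears elsewhere in this commit; whether it
        # exposes a patch_size attribute is an assumption, hence the guards.
        pe = getattr(getattr(model, 'embeds', None), 'patch_embed', None)
        ps = getattr(pe, 'patch_size', None)
        if ps is None:
            return default
        return ps[0] if isinstance(ps, (tuple, list)) else ps
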
@@ -345,10 +371,11 @@
                 real_labels.add_result(output)
 
             # measure accuracy and record loss
+            batch_size = output.shape[0]
             acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
-            losses.update(loss.item(), input.size(0))
-            top1.update(acc1.item(), input.size(0))
-            top5.update(acc5.item(), input.size(0))
+            losses.update(loss.item(), batch_size)
+            top1.update(acc1.item(), batch_size)
+            top5.update(acc5.item(), batch_size)
 
             # measure elapsed time
             batch_time.update(time.time() - end)
@@ -364,7 +391,7 @@
                         batch_idx,
                         len(loader),
                         batch_time=batch_time,
-                        rate_avg=input.size(0) / batch_time.avg,
+                        rate_avg=batch_size / batch_time.avg,
                         loss=losses,
                         top1=top1,
                         top5=top5