Add naflex loader support to validate.py, fix bug in naflex pos embed add, add classic vit weight loading for naflex model

This commit is contained in:
Ross Wightman 2025-04-25 16:00:54 -07:00
parent c527c37969
commit 3dc90ed7a7
4 changed files with 198 additions and 84 deletions

View File

@@ -760,7 +760,6 @@ def patchify(
nh, nw = h // ph, w // pw
# Reshape image to patches [nh, nw, ph, pw, c]
patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
if include_info:
# Create coordinate indices
y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
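
# As a sanity check on the patch layout above, a standalone sketch (plain
# PyTorch, not part of the commit) that applies the same view/permute/reshape
# to a single image and then inverts it:
import torch

def patchify_single(img, patch_size=(16, 16)):
    # Hypothetical helper mirroring the patchify reshape: split a (C, H, W)
    # image into N = (H/ph)*(W/pw) patches, each flattened channels-last (ph, pw, C).
    c, h, w = img.shape
    ph, pw = patch_size
    nh, nw = h // ph, w // pw
    patches = (
        img.view(c, nh, ph, nw, pw)
        .permute(1, 3, 2, 4, 0)          # -> (nh, nw, ph, pw, c)
        .reshape(nh * nw, ph * pw * c)   # -> (N, ph*pw*c)
    )
    return patches, (nh, nw)

img = torch.randn(3, 64, 48)
patches, (nh, nw) = patchify_single(img)
assert patches.shape == (4 * 3, 16 * 16 * 3)
# Round trip: invert the permute/reshape to recover the original image.
recon = (
    patches.reshape(nh, nw, 16, 16, 3)
    .permute(4, 0, 2, 1, 3)
    .reshape(3, 64, 48)
)
assert torch.equal(recon, img)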

View File

@@ -318,7 +318,7 @@ def transforms_imagenet_eval(
tfl += [ResizeToSequence(
patch_size=patch_size,
max_seq_len=max_seq_len,
interpolation=interpolation
interpolation=interpolation,
)]
else:
if crop_mode == 'squash':

View File

@@ -52,6 +52,7 @@ def batch_patchify(
nh, nw = H // ph, W // pw
patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
# FIXME confirm we want 'channels last' in the patch channel layout, e.g. (ph, pw, C) instead of (C, ph, pw)
return patches, (nh, nw)
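
# Standalone layout check for the FIXME above (not part of the commit): the
# batched patchify flattens each patch channels-last (ph, pw, C), whereas
# F.unfold yields channels-first (C, ph, pw); the data is identical up to a
# within-patch permute.
import torch
import torch.nn.functional as F
B, C, H, W, ph, pw = 2, 3, 32, 32, 16, 16
x = torch.randn(B, C, H, W)
nh, nw = H // ph, W // pw
patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
unfolded = F.unfold(x, kernel_size=(ph, pw), stride=(ph, pw)).transpose(1, 2)  # B, N, C*ph*pw
assert torch.equal(
    patches,
    unfolded.reshape(B, nh * nw, C, ph, pw).permute(0, 1, 3, 4, 2).reshape(B, nh * nw, -1),
)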
@@ -297,7 +298,7 @@ class FlexEmbeds(nn.Module):
size_to_indices[k].append(bi)
# Handle each batch element separately with its own grid size
pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2) # B,C,H,W
pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W
for k, batch_indices in size_to_indices.items():
h, w = k
#h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
@@ -312,9 +313,14 @@ class FlexEmbeds(nn.Module):
align_corners=False,
antialias=True,
).flatten(2).transpose(1, 2)
pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
seq_len = min(x.shape[1], pos_embed_flat.shape[1])
x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
x[:, :seq_len].index_add_(
0,
torch.as_tensor(batch_indices, device=x.device),
pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
)
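
# Minimal repro of the bug fixed above (standalone, plain PyTorch): advanced
# indexing with a list of batch indices returns a *copy*, so the old
# x[batch_indices, :seq_len].add_(...) modified a temporary and x was never
# updated; index_add_ on a basic-slice view writes through to x.
import torch
x_demo = torch.zeros(4, 3, 2)
pos = torch.ones(1, 3, 2)
batch_indices = [0, 2]
x_demo[batch_indices, :3].add_(pos)   # silently modifies a copy
assert x_demo.abs().sum() == 0        # x_demo unchanged -- the bug
x_demo[:, :3].index_add_(
    0,
    torch.as_tensor(batch_indices),
    pos.expand(len(batch_indices), -1, -1),
)
assert x_demo[0].sum() == 6 and x_demo[1].sum() == 0 and x_demo[2].sum() == 6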
def _apply_learned_pos_embed(
self,
@@ -328,12 +334,13 @@ class FlexEmbeds(nn.Module):
else:
# Resize if needed - directly using F.interpolate
pos_embed_flat = F.interpolate(
self.pos_embed.permute(0, 3, 1, 2), # B,C,H,W
self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W
size=grid_size,
mode=self.pos_embed_interp_mode,
align_corners=False,
antialias=True,
).flatten(2).transpose(1, 2)
pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
x.add_(pos_embed_flat)
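
# Sketch of the resize-in-float32 pattern above (standalone; 'bicubic' stands
# in for the model's pos_embed_interp_mode): antialiased interpolation is not
# reliably supported for reduced-precision dtypes, so the grid pos-embed is
# upcast for the resize and cast back to the activation dtype before the add.
import torch
import torch.nn.functional as F
pos_embed = torch.randn(1, 7, 7, 384, dtype=torch.bfloat16)  # B,H,W,C grid
x_demo = torch.randn(2, 12 * 12, 384, dtype=torch.bfloat16)  # B,N,C tokens
pos_flat = F.interpolate(
    pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W in float32
    size=(12, 12),
    mode='bicubic',
    align_corners=False,
    antialias=True,
).flatten(2).transpose(1, 2)                # back to B,N,C
x_demo.add_(pos_flat.to(dtype=x_demo.dtype))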
@@ -806,21 +813,20 @@ class VisionTransformerFlex(nn.Module):
# Apply the mask to extract only valid tokens
x = x[:, self.num_prefix_tokens:] # prefix tokens not included in pooling
patch_valid_float = patch_valid.to(x.dtype)
if pool_type == 'avg':
# Compute masked average pooling
# Sum valid tokens and divide by count of valid tokens
masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
# Compute masked average pooling, sum valid tokens and divide by count of valid tokens
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts
return pooled
elif pool_type == 'avgmax':
# For avgmax, compute masked average and masked max
# For max, we set masked positions to large negative value
masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
masked_avg = masked_sums / valid_counts
# For max pooling with mask
# For max pooling we set masked positions to a large negative value
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
masked_max = masked_x.max(dim=1)[0]
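
# Standalone version of the masked average pooling above (hypothetical helper,
# not the commit's code): only tokens flagged valid contribute to the mean, so
# padded NaFlex tokens do not dilute the pooled feature.
import torch
def masked_avg_pool(x, patch_valid):
    valid = patch_valid.to(x.dtype)                       # B,N
    sums = (x * valid.unsqueeze(-1)).sum(dim=1)           # B,C
    counts = valid.sum(dim=1, keepdim=True).clamp(min=1)  # B,1
    return sums / counts
tokens = torch.randn(2, 5, 8)
valid_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)
pooled = masked_avg_pool(tokens, valid_mask)
assert torch.allclose(pooled[0], tokens[0, :3].mean(dim=0), atol=1e-6)
assert torch.allclose(pooled[1], tokens[1].mean(dim=0), atol=1e-6)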
@@ -915,6 +921,82 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
return init_weights_vit_timm
def checkpoint_filter_fn(state_dict, model):
"""Handle state dict conversion from original ViT to the new version with combined embedding."""
from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
# Handle CombinedEmbed module pattern
out_dict = {}
for k, v in state_dict.items():
# Convert tokens and embeddings to combined_embed structure
if k == 'pos_embed':
# Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
num_cls_token = 0
num_reg_token = 0
if 'reg_token' in state_dict:
num_reg_token = state_dict['reg_token'].shape[1]
if 'cls_token' in state_dict:
num_cls_token = state_dict['cls_token'].shape[1]
num_prefix_tokens = num_cls_token + num_reg_token
# Original format is (1, N, C), need to reshape to (1, H, W, C)
num_patches = v.shape[1]
num_patches_no_prefix = num_patches - num_prefix_tokens
grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
grid_size = math.sqrt(num_patches)
if (grid_size_no_prefix != grid_size and (
grid_size_no_prefix.is_integer() and not grid_size.is_integer())):
# decide whether the original pos_embed included the prefix tokens
num_patches = num_patches_no_prefix
cls_token_emb = v[:, 0:num_cls_token]
if cls_token_emb.numel():
state_dict['cls_token'] += cls_token_emb
reg_token_emb = v[:, num_cls_token:num_cls_token + num_reg_token]
if reg_token_emb.numel():
state_dict['reg_token'] += reg_token_emb
v = v[:, num_prefix_tokens:]
grid_size = grid_size_no_prefix
grid_size = int(grid_size)
# Check if it's a perfect square for a standard grid
if grid_size * grid_size == num_patches:
# Reshape from (1, N, C) to (1, H, W, C)
v = v.reshape(1, grid_size, grid_size, v.shape[2])
else:
# Not a square grid, we need to get the actual dimensions
if hasattr(model.embeds.patch_embed, 'grid_size'):
h, w = model.embeds.patch_embed.grid_size
if h * w == num_patches:
# We have the right dimensions
v = v.reshape(1, h, w, v.shape[2])
else:
# Dimensions don't match, use interpolation
_logger.warning(
f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
f"Using default initialization and will resize in forward pass."
)
# Keep v as is, the forward pass will handle resizing
out_dict['embeds.pos_embed'] = v
elif k == 'cls_token':
out_dict['embeds.cls_token'] = v
elif k == 'reg_token':
out_dict['embeds.reg_token'] = v
# Convert patch_embed.X to embeds.X
elif k.startswith('patch_embed.'):
suffix = k[12:]
if suffix == 'proj.weight':
# FIXME confirm patchify memory layout across use cases
v = v.permute(0, 2, 3, 1).flatten(1)
new_key = 'embeds.' + suffix
out_dict[new_key] = v
else:
out_dict[k] = v
return out_dict
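
# Worked example of the prefix-token decision above (standalone, hypothetical
# tensors): a classic augreg ViT-B/16 checkpoint stores pos_embed as
# (1, 197, 768) with the cls-token slot prepended; 197 is not a perfect square
# but 197 - 1 = 196 = 14*14 is, so slot 0 folds into cls_token and the rest
# reshapes to a (1, 14, 14, 768) H,W,C grid.
import math
import torch
v = torch.randn(1, 197, 768)
num_prefix_tokens = 1
g_no_prefix = math.sqrt(v.shape[1] - num_prefix_tokens)
g = math.sqrt(v.shape[1])
assert g_no_prefix.is_integer() and not g.is_integer()
v = v[:, num_prefix_tokens:].reshape(1, int(g_no_prefix), int(g_no_prefix), v.shape[2])
# The conv2d patch-embed kernel (D, C, ph, pw) likewise flattens to a linear
# weight matching the channels-last (ph, pw, C) patch layout:
w = torch.randn(768, 3, 16, 16)
w_linear = w.permute(0, 2, 3, 1).flatten(1)  # (768, 16*16*3)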
def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
return {
'url': url,
@@ -936,6 +1018,26 @@ default_cfgs = generate_default_cfgs({
'vit_naflex_base_patch16': _cfg(),
'vit_naflex_base_patch16_gap': _cfg(),
'vit_naflex_base_patch16_map': _cfg(),
# sbb model testing
'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k',
input_size=(3, 256, 256), crop_pct=1.0),
'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k',
input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
# traditional vit testing
'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
})
@@ -948,6 +1050,18 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
return model
@register_model
def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
model = _create_vision_transformer_flex(
'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
return model
@register_model
def vit_naflex_base_patch16(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
@@ -987,54 +1101,28 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
return model
def checkpoint_filter_fn(state_dict, model):
"""Handle state dict conversion from original ViT to the new version with combined embedding."""
from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
@register_model
def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
# FIXME conversion of existing vit checkpoints has not been finished or tested
This model supports:
1. Variable aspect ratios and resolutions via patch coordinates
2. Position embedding interpolation for arbitrary grid sizes
3. Explicit patch coordinates and valid token masking
"""
model_args = dict(
patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
model = _create_vision_transformer_flex(
'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
return model
# Handle CombinedEmbed module pattern
out_dict = {}
for k, v in state_dict.items():
# Convert tokens and embeddings to combined_embed structure
if k == 'pos_embed':
# Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
# Original format is (1, N, C) - need to reshape to (1, H, W, C)
num_patches = v.shape[1]
grid_size = int(math.sqrt(num_patches))
# Check if it's a perfect square for a standard grid
if grid_size * grid_size == num_patches:
# Reshape from (1, N, C) to (1, H, W, C)
v = v.reshape(1, grid_size, grid_size, v.shape[2])
else:
# Not a square grid, we need to get the actual dimensions
if hasattr(model.embeds.patch_embed, 'grid_size'):
h, w = model.embeds.patch_embed.grid_size
if h * w == num_patches:
# We have the right dimensions
v = v.reshape(1, h, w, v.shape[2])
else:
# Dimensions don't match, use interpolation
_logger.warning(
f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
f"Using default initialization and will resize in forward pass."
)
# Keep v as is, the forward pass will handle resizing
out_dict['embeds.pos_embed'] = v
elif k == 'cls_token':
out_dict['embeds.cls_token'] = v
elif k == 'reg_token':
out_dict['embeds.reg_token'] = v
# Convert patch_embed.X to embeds.patch_embed.X
elif k.startswith('patch_embed.'):
new_key = 'embeds.' + k[12:]
out_dict[new_key] = v
else:
out_dict[k] = v
# Call the original filter function to handle other patterns
return orig_filter_fn(out_dict, model)
@register_model
def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_args = dict(
    patch_size=16, embed_dim=768, depth=12, num_heads=12,
    global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
model = _create_vision_transformer_flex(
    'vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model
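
# Usage sketch for the classic-weight path above (assumes the hub tag resolves
# and that the naflex model's forward accepts a plain NCHW tensor):
import timm
import torch
model = timm.create_model(
    'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k',  # classic augreg weights, converted on load
    pretrained=True,
).eval()
with torch.inference_mode():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])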

View File

@@ -158,6 +158,12 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
parser.add_argument('--retry', default=False, action='store_true',
help='Enable batch size decay & retry for single model validation')
# NaFlex loader arguments
parser.add_argument('--naflex-loader', action='store_true', default=False,
help='Use NaFlex loader (requires a NaFlex-compatible model)')
parser.add_argument('--naflex-max-seq-len', type=int, default=576,
help='Fixed maximum sequence length for NaFlex loader (validation)')
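
# Taken together, a NaFlex validation run would look something like this
# (hypothetical dataset path; any NaFlex-compatible model works):
#
#   python validate.py /path/to/imagenet \
#       --model vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k --pretrained \
#       --naflex-loader --naflex-max-seq-len 576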
def validate(args):
# might as well try to validate something
@@ -293,6 +299,26 @@ def validate(args):
real_labels = None
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
if args.naflex_loader:
from timm.data import create_naflex_loader
loader = create_naflex_loader(
dataset,
batch_size=args.batch_size,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=crop_pct,
crop_mode=data_config['crop_mode'],
crop_border_pixels=args.crop_border_pixels,
pin_memory=args.pin_mem,
device=device,
img_dtype=model_dtype or torch.float32,
patch_size=16, # Could be derived from model config
max_seq_len=args.naflex_max_seq_len,
)
else:
loader = create_loader(
dataset,
input_size=data_config['input_size'],
@@ -345,10 +371,11 @@ def validate(args):
real_labels.add_result(output)
# measure accuracy and record loss
batch_size = output.shape[0]
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
losses.update(loss.item(), batch_size)
top1.update(acc1.item(), batch_size)
top5.update(acc5.item(), batch_size)
# measure elapsed time
batch_time.update(time.time() - end)
@@ -364,7 +391,7 @@
batch_idx,
len(loader),
batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
rate_avg=batch_size / batch_time.avg,
loss=losses,
top1=top1,
top5=top5