diff --git a/timm/data/naflex_transforms.py b/timm/data/naflex_transforms.py
index cd16553b..81653e78 100644
--- a/timm/data/naflex_transforms.py
+++ b/timm/data/naflex_transforms.py
@@ -760,7 +760,6 @@ def patchify(
     nh, nw = h // ph, w // pw
     # Reshape image to patches [nh, nw, ph, pw, c]
     patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
-
     if include_info:
         # Create coordinate indices
         y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index 3d5b5a0d..ed427456 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -318,7 +318,7 @@ def transforms_imagenet_eval(
         tfl += [ResizeToSequence(
             patch_size=patch_size,
             max_seq_len=max_seq_len,
-            interpolation=interpolation
+            interpolation=interpolation,
         )]
     else:
         if crop_mode == 'squash':
diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py
index 93c2c892..c8c52586 100644
--- a/timm/models/vision_transformer_flex.py
+++ b/timm/models/vision_transformer_flex.py
@@ -52,6 +52,7 @@ def batch_patchify(
     nh, nw = H // ph, W // pw
     patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
+    # FIXME confirm we want 'channels last' in the patch channel layout, e.g. ph, pw, C instead of C, ph, pw
     return patches, (nh, nw)
@@ -297,7 +298,7 @@ class FlexEmbeds(nn.Module):
             size_to_indices[k].append(bi)

         # Handle each batch element separately with its own grid size
-        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
@@ -312,9 +313,14 @@
                 align_corners=False,
                 antialias=True,
             ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
+            x[:, :seq_len].index_add_(
+                0,
+                torch.as_tensor(batch_indices, device=x.device),
+                pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
+            )

     def _apply_learned_pos_embed(
             self,
@@ -328,12 +334,13 @@
         else:
             # Resize if needed - directly using F.interpolate
             pos_embed_flat = F.interpolate(
-                self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
             ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

         x.add_(pos_embed_flat)
@@ -806,21 +813,20 @@
         # Apply the mask to extract only valid tokens
         x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling

+        patch_valid_float = patch_valid.to(x.dtype)
         if pool_type == 'avg':
-            # Compute masked average pooling
-            # Sum valid tokens and divide by count of valid tokens
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            # Compute masked average pooling: sum valid tokens and divide by count of valid tokens
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
             pooled = masked_sums / valid_counts
             return pooled
         elif pool_type == 'avgmax':
             # For avgmax, compute masked average and masked max
-            # For max, we set masked positions to large negative value
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
            masked_avg = masked_sums / valid_counts
-            # For max pooling with mask
+            # For max pooling we set masked positions to a large negative value
             masked_x = x.clone()
             masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
             masked_max = masked_x.max(dim=1)[0]
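For reference, a minimal standalone sketch of the masked average pooling implemented in the hunk above, with `x` as a `(B, N, C)` token tensor and `patch_valid` as a `(B, N)` boolean mask (the helper name is hypothetical, not timm API):

```python
import torch

def masked_avg_pool(x: torch.Tensor, patch_valid: torch.Tensor) -> torch.Tensor:
    """Average only over valid (non-padding) tokens. x: (B, N, C), patch_valid: (B, N) bool."""
    valid = patch_valid.to(x.dtype)
    masked_sums = (x * valid.unsqueeze(-1)).sum(dim=1)          # padding rows contribute zero
    valid_counts = valid.sum(dim=1, keepdim=True).clamp(min=1)  # clamp guards divide-by-zero
    return masked_sums / valid_counts                           # (B, C)

x = torch.randn(2, 5, 8)
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True]])
out = masked_avg_pool(x, mask)
assert torch.allclose(out[0], x[0, :3].mean(dim=0), atol=1e-6)  # padded row: mean of first 3 tokens
assert torch.allclose(out[1], x[1].mean(dim=0), atol=1e-6)      # fully valid row: plain mean
```

The `clamp(min=1)` mirrors the diff and keeps an all-padding row from dividing by zero.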
@@ -915,6 +921,82 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     return init_weights_vit_timm


+def checkpoint_filter_fn(state_dict, model):
+    """Handle state dict conversion from original ViT to the new version with combined embedding."""
+    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+
+    # Handle CombinedEmbed module pattern
+    out_dict = {}
+    for k, v in state_dict.items():
+        # Convert tokens and embeddings to combined_embed structure
+        if k == 'pos_embed':
+            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
+            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
+                num_cls_token = 0
+                num_reg_token = 0
+                if 'reg_token' in state_dict:
+                    num_reg_token = state_dict['reg_token'].shape[1]
+                if 'cls_token' in state_dict:
+                    num_cls_token = state_dict['cls_token'].shape[1]
+                num_prefix_tokens = num_cls_token + num_reg_token
+
+                # Original format is (1, N, C), need to reshape to (1, H, W, C)
+                num_patches = v.shape[1]
+                num_patches_no_prefix = num_patches - num_prefix_tokens
+                grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
+                grid_size = math.sqrt(num_patches)
+                if (grid_size_no_prefix != grid_size and (
+                        grid_size_no_prefix.is_integer() and not grid_size.is_integer())):
+                    # decide whether the original pos_embed included the prefix tokens
+                    num_patches = num_patches_no_prefix
+                    cls_token_emb = v[:, 0:num_cls_token]
+                    if cls_token_emb.numel():
+                        state_dict['cls_token'] += cls_token_emb
+                    reg_token_emb = v[:, num_cls_token:num_cls_token + num_reg_token]
+                    if reg_token_emb.numel():
+                        state_dict['reg_token'] += reg_token_emb
+                    v = v[:, num_prefix_tokens:]
+                    grid_size = grid_size_no_prefix
+                grid_size = int(grid_size)
+
+                # Check if it's a perfect square for a standard grid
+                if grid_size * grid_size == num_patches:
+                    # Reshape from (1, N, C) to (1, H, W, C)
+                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
+                else:
+                    # Not a square grid, we need to get the actual dimensions
+                    if hasattr(model.embeds.patch_embed, 'grid_size'):
+                        h, w = model.embeds.patch_embed.grid_size
+                        if h * w == num_patches:
+                            # We have the right dimensions
+                            v = v.reshape(1, h, w, v.shape[2])
+                        else:
+                            # Dimensions don't match, use interpolation
+                            _logger.warning(
+                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
+                                f"Using default initialization and will resize in forward pass."
+                            )
+                            # Keep v as is, the forward pass will handle resizing
+
+            out_dict['embeds.pos_embed'] = v
+        elif k == 'cls_token':
+            out_dict['embeds.cls_token'] = v
+        elif k == 'reg_token':
+            out_dict['embeds.reg_token'] = v
+        # Convert patch_embed.X to embeds.X
+        elif k.startswith('patch_embed.'):
+            suffix = k[12:]
+            if suffix == 'proj.weight':
+                # FIXME confirm patchify memory layout across use cases
+                v = v.permute(0, 2, 3, 1).flatten(1)
+            new_key = 'embeds.' + suffix
+            out_dict[new_key] = v
+        else:
+            out_dict[k] = v
+
+    return out_dict
+
+
 def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     return {
         'url': url,
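The `proj.weight` conversion in `checkpoint_filter_fn` above can be checked numerically: a strided `Conv2d` patch embed and a plain matmul on `batch_patchify`-style patches should agree once the weight is re-laid-out with `permute(0, 2, 3, 1).flatten(1)`. A sketch under assumed shapes (this is the property the layout FIXME wants confirmed, not a timm call):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes; `weight` plays the role of a Conv2d(C, D, kernel_size=p, stride=p) weight.
B, C, H, Wd, D, p = 2, 3, 32, 32, 8, 16
x = torch.randn(B, C, H, Wd)
weight = torch.randn(D, C, p, p)

# Reference: strided conv, then flatten the spatial grid to a token sequence (B, nh*nw, D).
conv_out = F.conv2d(x, weight, stride=p).flatten(2).transpose(1, 2)

# Patchify exactly as batch_patchify does: per-patch vectors laid out as (ph, pw, C).
nh, nw = H // p, Wd // p
patches = x.view(B, C, nh, p, nw, p).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, p * p * C)

# Re-lay-out the conv weight the same way the filter fn does, then apply it as a Linear.
lin_w = weight.permute(0, 2, 3, 1).flatten(1)  # (D, ph*pw*C), matching the (ph, pw, C) order
lin_out = patches @ lin_w.T

print(torch.allclose(conv_out, lin_out, atol=1e-4))  # True
```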
@@ -936,6 +1018,26 @@
 default_cfgs = generate_default_cfgs({
     'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
+
+    # sbb model testing
+    'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
+
+    # traditional vit testing
+    'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
+    'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch8_224.augreg2_in21k_ft_in1k'),
 })
@@ -948,10 +1050,22 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model


+@register_model
+def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def vit_naflex_base_patch16(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
+
     This model supports:
     1. Variable aspect ratios and resolutions via patch coordinates
     2. Position embedding interpolation for arbitrary grid sizes
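A side note on the two argument-merging styles used by these registrations: `dict(model_args, **kwargs)` (as in the new `vit_naflex_base_patch16` further below) lets caller kwargs cleanly override the defaults, while passing `**kwargs` into the `dict(...)` literal (as in `vit_naflex_mediumd_patch16_reg4_gap` above) raises on any duplicate key. A plain-Python illustration:

```python
model_args = dict(patch_size=16, embed_dim=512)

# dict(model_args, **kwargs): caller kwargs override the defaults.
print(dict(model_args, patch_size=8))  # {'patch_size': 8, 'embed_dim': 512}

# dict(patch_size=16, ..., **kwargs): a duplicate key raises instead of overriding.
try:
    dict(patch_size=16, embed_dim=512, **{'patch_size': 8})
except TypeError as e:
    print(e)  # dict() got multiple values for keyword argument 'patch_size'
```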
@@ -987,54 +1101,28 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     return model


-def checkpoint_filter_fn(state_dict, model):
-    """Handle state dict conversion from original ViT to the new version with combined embedding."""
-    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+@register_model
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.

-    # FIXME conversion of existing vit checkpoints has not been finished or tested
+    This model supports:
+    1. Variable aspect ratios and resolutions via patch coordinates
+    2. Position embedding interpolation for arbitrary grid sizes
+    3. Explicit patch coordinates and valid token masking
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+    return model

-    # Handle CombinedEmbed module pattern
-    out_dict = {}
-    for k, v in state_dict.items():
-        # Convert tokens and embeddings to combined_embed structure
-        if k == 'pos_embed':
-            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
-            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
-                # Original format is (1, N, C) - need to reshape to (1, H, W, C)
-                num_patches = v.shape[1]
-                grid_size = int(math.sqrt(num_patches))
-
-                # Check if it's a perfect square for a standard grid
-                if grid_size * grid_size == num_patches:
-                    # Reshape from (1, N, C) to (1, H, W, C)
-                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
-                else:
-                    # Not a square grid, we need to get the actual dimensions
-                    if hasattr(model.embeds.patch_embed, 'grid_size'):
-                        h, w = model.embeds.patch_embed.grid_size
-                        if h * w == num_patches:
-                            # We have the right dimensions
-                            v = v.reshape(1, h, w, v.shape[2])
-                        else:
-                            # Dimensions don't match, use interpolation
-                            _logger.warning(
-                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
-                                f"Using default initialization and will resize in forward pass."
-                            )
-                            # Keep v as is, the forward pass will handle resizing
-
-            out_dict['embeds.pos_embed'] = v
-
-        elif k == 'cls_token':
-            out_dict['embeds.cls_token'] = v
-        elif k == 'reg_token':
-            out_dict['embeds.reg_token'] = v
-        # Convert patch_embed.X to embeds.patch_embed.X
-        elif k.startswith('patch_embed.'):
-            new_key = 'embeds.' + k[12:]
-            out_dict[new_key] = v
-        else:
-            out_dict[k] = v
-
-    # Call the original filter function to handle other patterns
-    return orig_filter_fn(out_dict, model)
\ No newline at end of file
+
+@register_model
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token',
+        class_token=True, pos_embed_grid_size=(14, 14))
+    model = _create_vision_transformer_flex(
+        'vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
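For reference, the core `(1, N, C)` to `(1, H, W, C)` pos-embed conversion that both the removed and the new `checkpoint_filter_fn` perform, as a simplified sketch assuming one class token and a square patch grid (shapes are illustrative only):

```python
import math
import torch

# Hypothetical checkpoint tensor: (1, N, C) with one class-token row prepended.
num_prefix_tokens, C = 1, 768
v = torch.randn(1, num_prefix_tokens + 196, C)

patches = v[:, num_prefix_tokens:]        # strip prefix-token rows first
grid_size = math.isqrt(patches.shape[1])  # 14 for a 196-token square grid
assert grid_size * grid_size == patches.shape[1]

pos_embed = patches.reshape(1, grid_size, grid_size, C)  # (1, H, W, C), as FlexEmbeds expects
print(pos_embed.shape)  # torch.Size([1, 14, 14, 768])
```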
+ """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14)) + model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs)) + return model diff --git a/validate.py b/validate.py index f757855e..59e78a91 100755 --- a/validate.py +++ b/validate.py @@ -158,6 +158,12 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', parser.add_argument('--retry', default=False, action='store_true', help='Enable batch size decay & retry for single model validation') +# NaFlex loader arguments +parser.add_argument('--naflex-loader', action='store_true', default=False, + help='Use NaFlex loader (Requires NaFlex compatible model)') +parser.add_argument('--naflex-max-seq-len', type=int, default=576, + help='Fixed maximum sequence length for NaFlex loader (validation)') + def validate(args): # might as well try to validate something @@ -293,23 +299,43 @@ def validate(args): real_labels = None crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] - loader = create_loader( - dataset, - input_size=data_config['input_size'], - batch_size=args.batch_size, - use_prefetcher=args.prefetcher, - interpolation=data_config['interpolation'], - mean=data_config['mean'], - std=data_config['std'], - num_workers=args.workers, - crop_pct=crop_pct, - crop_mode=data_config['crop_mode'], - crop_border_pixels=args.crop_border_pixels, - pin_memory=args.pin_mem, - device=device, - img_dtype=model_dtype or torch.float32, - tf_preprocessing=args.tf_preprocessing, - ) + if args.naflex_loader: + from timm.data import create_naflex_loader + loader = create_naflex_loader( + dataset, + batch_size=args.batch_size, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=crop_pct, + crop_mode=data_config['crop_mode'], + crop_border_pixels=args.crop_border_pixels, + pin_memory=args.pin_mem, + device=device, + img_dtype=model_dtype or torch.float32, + patch_size=16, # Could be derived from model config + max_seq_len=args.naflex_max_seq_len, + ) + else: + loader = create_loader( + dataset, + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=crop_pct, + crop_mode=data_config['crop_mode'], + crop_border_pixels=args.crop_border_pixels, + pin_memory=args.pin_mem, + device=device, + img_dtype=model_dtype or torch.float32, + tf_preprocessing=args.tf_preprocessing, + ) batch_time = AverageMeter() losses = AverageMeter() @@ -345,10 +371,11 @@ def validate(args): real_labels.add_result(output) # measure accuracy and record loss + batch_size = output.shape[0] acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) - losses.update(loss.item(), input.size(0)) - top1.update(acc1.item(), input.size(0)) - top5.update(acc5.item(), input.size(0)) + losses.update(loss.item(), batch_size) + top1.update(acc1.item(), batch_size) + top5.update(acc5.item(), batch_size) # measure elapsed time batch_time.update(time.time() - end) @@ -364,7 +391,7 @@ def validate(args): batch_idx, len(loader), batch_time=batch_time, - rate_avg=input.size(0) / batch_time.avg, + rate_avg=batch_size / batch_time.avg, loss=losses, top1=top1, top5=top5