Add so400m model size for testing, plus a few tweaks.

This commit is contained in:
Ross Wightman 2025-05-23 18:29:30 -07:00
parent 7bfe606d9f
commit d7d3538335
4 changed files with 21 additions and 6 deletions

View File

@ -8,7 +8,7 @@ from .dataset_info import DatasetInfo, CustomDatasetInfo
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .naflex_dataset import VariableSeqMapWrapper
from .naflex_dataset import VariableSeqMapWrapper, calculate_naflex_batch_size
from .naflex_loader import create_naflex_loader
from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
from .naflex_transforms import (
@ -17,6 +17,8 @@ from .naflex_transforms import (
RandomCropToSequence,
RandomResizedCropToSequence,
ResizeKeepRatioToSequence,
Patchify,
patchify_image,
)
from .readers import create_reader
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions

View File

@ -24,10 +24,10 @@ from torchvision import transforms
from PIL import Image
from .naflex_transforms import Patchify, patchify
from .naflex_transforms import Patchify, patchify_image
def calculate_batch_size(
def calculate_naflex_batch_size(
tokens_per_batch: int,
seq_len: int,
max_size: Optional[int] = None,
@ -240,7 +240,7 @@ class VariableSeqMapWrapper(IterableDataset):
seq_len = self.seq_lens[seq_idx]
# Calculate batch size
batch_size = calculate_batch_size(
batch_size = calculate_naflex_batch_size(
tokens_per_batch=self.max_tokens_per_batch,
seq_len=seq_len,
# max_size should be remaining_samples to avoid overshooting

View File

@ -738,7 +738,7 @@ class RandomResizedCropToSequence(torch.nn.Module):
return format_string
def patchify(
def patchify_image(
img: torch.Tensor,
patch_size: Tuple[int, int],
pad: bool = True,
@ -794,7 +794,7 @@ class Patchify(torch.nn.Module):
# Convert PIL Image to tensor [C, H, W]
img = transforms.functional.to_tensor(img)
patches, coord, valid = patchify(img, self.patch_size)
patches, coord, valid = patchify_image(img, self.patch_size)
return {
'patches': patches,

View File

@ -1027,6 +1027,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
default_cfgs = generate_default_cfgs({
'vit_naflex_base_patch16_gap': _cfg(),
'vit_naflex_base_patch16_map': _cfg(),
'vit_naflex_so400m_patch16_map': _cfg(),
# sbb model testing
'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
@ -1110,3 +1111,15 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs):
    """ViT-SO400M (SigLIP shape-optimized 400M) with NaFlex support.

    NaFlex variant handles variable aspect ratios and resolutions; uses MAP
    (attention-pool) head with one register token and tanh-approx GELU.
    """
    # Base architecture config; caller kwargs take precedence on conflict.
    cfg = dict(
        patch_size=16,
        embed_dim=1152,
        depth=27,
        num_heads=16,
        mlp_ratio=3.7362,  # SO400M shape-optimized MLP ratio
        init_values=1e-5,
        global_pool='map',
        class_token=False,
        reg_tokens=1,
        act_layer='gelu_tanh',
    )
    cfg.update(kwargs)
    return _create_vision_transformer_flex(
        'vit_naflex_so400m_patch16_map', pretrained=pretrained, **cfg)