mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add so400m model size for test, few tweaks.
This commit is contained in:
parent
7bfe606d9f
commit
d7d3538335
@ -8,7 +8,7 @@ from .dataset_info import DatasetInfo, CustomDatasetInfo
|
||||
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
|
||||
from .loader import create_loader
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
from .naflex_dataset import VariableSeqMapWrapper
|
||||
from .naflex_dataset import VariableSeqMapWrapper, calculate_naflex_batch_size
|
||||
from .naflex_loader import create_naflex_loader
|
||||
from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
|
||||
from .naflex_transforms import (
|
||||
@ -17,6 +17,8 @@ from .naflex_transforms import (
|
||||
RandomCropToSequence,
|
||||
RandomResizedCropToSequence,
|
||||
ResizeKeepRatioToSequence,
|
||||
Patchify,
|
||||
patchify_image,
|
||||
)
|
||||
from .readers import create_reader
|
||||
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
|
||||
|
@ -24,10 +24,10 @@ from torchvision import transforms
|
||||
from PIL import Image
|
||||
|
||||
|
||||
from .naflex_transforms import Patchify, patchify
|
||||
from .naflex_transforms import Patchify, patchify_image
|
||||
|
||||
|
||||
def calculate_batch_size(
|
||||
def calculate_naflex_batch_size(
|
||||
tokens_per_batch: int,
|
||||
seq_len: int,
|
||||
max_size: Optional[int] = None,
|
||||
@ -240,7 +240,7 @@ class VariableSeqMapWrapper(IterableDataset):
|
||||
seq_len = self.seq_lens[seq_idx]
|
||||
|
||||
# Calculate batch size
|
||||
batch_size = calculate_batch_size(
|
||||
batch_size = calculate_naflex_batch_size(
|
||||
tokens_per_batch=self.max_tokens_per_batch,
|
||||
seq_len=seq_len,
|
||||
# max_size should be remaining_samples to avoid overshooting
|
||||
|
@ -738,7 +738,7 @@ class RandomResizedCropToSequence(torch.nn.Module):
|
||||
return format_string
|
||||
|
||||
|
||||
def patchify(
|
||||
def patchify_image(
|
||||
img: torch.Tensor,
|
||||
patch_size: Tuple[int, int],
|
||||
pad: bool = True,
|
||||
@ -794,7 +794,7 @@ class Patchify(torch.nn.Module):
|
||||
# Convert PIL Image to tensor [C, H, W]
|
||||
img = transforms.functional.to_tensor(img)
|
||||
|
||||
patches, coord, valid = patchify(img, self.patch_size)
|
||||
patches, coord, valid = patchify_image(img, self.patch_size)
|
||||
|
||||
return {
|
||||
'patches': patches,
|
||||
|
@ -1027,6 +1027,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'vit_naflex_base_patch16_gap': _cfg(),
|
||||
'vit_naflex_base_patch16_map': _cfg(),
|
||||
'vit_naflex_so400m_patch16_map': _cfg(),
|
||||
|
||||
# sbb model testing
|
||||
'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
|
||||
@ -1110,3 +1111,15 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
|
||||
global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
|
||||
model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
def vit_naflex_so400m_patch16_map(pretrained=False, **kwargs):
    """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions.
    """
    # SO400M configuration: MAP attention pooling head, one register token,
    # no class token, tanh-approximated GELU.
    cfg = dict(
        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, init_values=1e-5,
        global_pool='map', class_token=False, reg_tokens=1, act_layer='gelu_tanh')
    # Caller kwargs take precedence over the base configuration.
    cfg.update(kwargs)
    return _create_vision_transformer_flex(
        'vit_naflex_so400m_patch16_map', pretrained=pretrained, **cfg)
|
||||
|
Loading…
x
Reference in New Issue
Block a user