Merge remote-tracking branch 'origin/main' into efficientnet_x

Ross Wightman 2024-05-23 11:01:39 -07:00
commit cee79dada0
52 changed files with 1306 additions and 128 deletions

View File

@ -26,6 +26,12 @@
* The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license.
* Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version.
### May 14, 2024
* Support loading PaliGemma jax weights into SigLIP ViT models with average pooling.
* Add Hiera models from Meta (https://github.com/facebookresearch/hiera).
* Add `normalize=` flag for transforms, return non-normalized torch.Tensor with original dtype (for `chug`); see the sketch below.
* Version 1.0.3 release
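A minimal usage sketch of the new `normalize=` flag (editor's example; assumes `timm` >= 1.0.3):

```python
from timm.data import create_transform

# normalize=False ends the pipeline in PILToTensor(); for typical 8-bit images
# the output is an unscaled uint8 tensor rather than a normalized float tensor
tf = create_transform(224, is_training=False, normalize=False)
```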
### May 11, 2024
* `Searching for Better ViT Baselines (For the GPU Poor)` weights and vit variants released. Exploring model shapes between Tiny and Base.
@ -42,6 +48,7 @@
| [vit_medium_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_reg4_gap_256.sbb_in1k) | 83.47 | 96.622 | 38.88 | 256 |
| [vit_medium_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_reg1_gap_256.sbb_in1k) | 83.462 | 96.548 | 38.88 | 256 |
| [vit_little_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_little_patch16_reg4_gap_256.sbb_in1k) | 82.514 | 96.262 | 22.52 | 256 |
| [vit_wee_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_wee_patch16_reg1_gap_256.sbb_in1k) | 80.256 | 95.360 | 13.42 | 256 |
| [vit_pwee_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_pwee_patch16_reg1_gap_256.sbb_in1k) | 80.072 | 95.136 | 15.25 | 256 |
| [vit_mediumd_patch16_reg4_gap_256.sbb_in12k](https://huggingface.co/timm/vit_mediumd_patch16_reg4_gap_256.sbb_in12k) | N/A | N/A | 64.11 | 256 |
| [vit_betwixt_patch16_reg4_gap_256.sbb_in12k](https://huggingface.co/timm/vit_betwixt_patch16_reg4_gap_256.sbb_in12k) | N/A | N/A | 60.4 | 256 |

View File

@ -192,9 +192,9 @@ There are two additional creation arguments impacting the output features.
#### Output index selection
The `out_indices` argument is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most convnet models, index 0 is the stride 2 features, and index 4 is stride 32. For many ViT or ViT-Conv hybrids there may be many to all features maps of the same shape, or a combination of hierarchical and non-hieararchical feature maps. It is best to look at the `feature_info` attribute to see the number of features, their corresponding channel count and reduction level.
The `out_indices` argument is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most convnet models, index 0 is the stride 2 features, and index 4 is stride 32. For many ViT or ViT-Conv hybrids there may be many to all features maps of the same shape, or a combination of hierarchical and non-hierarchical feature maps. It is best to look at the `feature_info` attribute to see the number of features, their corresponding channel count and reduction level.
`out_indices` supports negative indexing, this makes it easy to get the last, penunltimate, etc feature map. `out_indices=(-2,)` would return the penultimate feature map for any model.
`out_indices` supports negative indexing, this makes it easy to get the last, penultimate, etc feature map. `out_indices=(-2,)` would return the penultimate feature map for any model.
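A short sketch of negative indexing with the `features_only` API (editor's example; `resnet50` is illustrative):

```python
import timm

# out_indices=(-2,) selects the penultimate feature map for any model
model = timm.create_model('resnet50', features_only=True, out_indices=(-2,))
print(model.feature_info.channels(), model.feature_info.reduction())  # e.g. [1024] [16]
```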
#### Output stride (feature map dilation)
@ -228,7 +228,7 @@ Accompanying the `forward_intermediates` function is a `prune_intermediate_layer
An `indices` argument is used for both `forward_intermediates()` and `prune_intermediate_layers()` to select the features to return or layers to remove. As with the `out_indices` for `features_only` API, `indices` is model specific and selects which intermediates are returned.
In non-hierarchical block based models such as ViT the indices correspond to the blocks, in models with hierarchical stages they usually correspond to the output of the stem + each hierarhical stage. Both positive (from the start), and negative (relative to the end) indexing works, and `None` is used to return all intermediates.
In non-hierarchical block based models such as ViT the indices correspond to the blocks, in models with hierarchical stages they usually correspond to the output of the stem + each hierarchical stage. Both positive (from the start), and negative (relative to the end) indexing works, and `None` is used to return all intermediates.
The `prune_intermediate_layers()` call returns an indices variable, as negative indices must be converted to absolute (positive) indices when the model is trimmed.
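A sketch of the two calls together (editor's example; `vit_base_patch16_224` is illustrative):

```python
import torch
import timm

model = timm.create_model('vit_base_patch16_224')
x = torch.randn(1, 3, 224, 224)

# final features plus the outputs of the last two blocks
final, intermediates = model.forward_intermediates(x, indices=(-2, -1))

# trim blocks not needed for those intermediates; returned indices are absolute
indices = model.prune_intermediate_layers(indices=(-2, -1))
```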

View File

@ -28,7 +28,7 @@ You should install `timm` in a [virtual environment](https://docs.python.org/3/l
# Deactivate the virtual environment
source .env/bin/deactivate
```
`
Once you've created your virtual environment, you can install `timm` in it.
## Using pip

View File

@ -52,7 +52,7 @@ FEAT_INTER_FILTERS = [
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
]
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@ -60,7 +60,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
@ -77,7 +77,7 @@ else:
EXCLUDE_FILTERS = ['*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']
EXCLUDE_JIT_FILTERS = []
EXCLUDE_JIT_FILTERS = ['hiera_*']
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
@ -486,7 +486,7 @@ def _create_fx_model(model, train=False):
return fx_model
EXCLUDE_FX_FILTERS = ['vit_gi*']
EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [

View File

@ -19,9 +19,10 @@ from timm.data.random_erasing import RandomErasing
def transforms_noaug_train(
img_size: Union[int, Tuple[int, int]] = 224,
interpolation: str = 'bilinear',
use_prefetcher: bool = False,
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
use_prefetcher: bool = False,
normalize: bool = True,
):
""" No-augmentation image transforms for training.
@ -31,6 +32,7 @@ def transforms_noaug_train(
mean: Image normalization mean.
std: Image normalization standard deviation.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
Returns:
@ -45,6 +47,9 @@ def transforms_noaug_train(
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
elif not normalize:
# when normalize is disabled, convert to tensor without scaling and keep original dtype
tfl += [transforms.PILToTensor()]
else:
tfl += [
transforms.ToTensor(),
@ -77,6 +82,7 @@ def transforms_imagenet_train(
re_count: int = 1,
re_num_splits: int = 0,
use_prefetcher: bool = False,
normalize: bool = True,
separate: bool = False,
):
""" ImageNet-oriented image transforms for training.
@ -103,6 +109,7 @@ def transforms_imagenet_train(
re_count: Number of random erasing regions.
re_num_splits: Control split of random erasing across batch size.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
separate: Output transforms in 3-stage tuple.
Returns:
@ -209,12 +216,15 @@ def transforms_imagenet_train(
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
final_tfl += [ToNumpy()]
elif not normalize:
# when normalize is disabled, convert to tensor without scaling and keep original dtype
final_tfl += [transforms.PILToTensor()]
else:
final_tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std)
std=torch.tensor(std),
),
]
if re_prob > 0.:
@ -243,6 +253,7 @@ def transforms_imagenet_eval(
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
use_prefetcher: bool = False,
normalize: bool = True,
):
""" ImageNet-oriented image transform for evaluation and inference.
@ -255,6 +266,7 @@ def transforms_imagenet_eval(
mean: Image normalization mean.
std: Image normalization standard deviation.
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
Returns:
Composed transform pipeline
@ -304,13 +316,16 @@ def transforms_imagenet_eval(
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
elif not normalize:
# when normalize is disabled, convert to tensor without scaling and keep original dtype
tfl += [transforms.PILToTensor()]
else:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
)
),
]
return transforms.Compose(tfl)
@ -342,6 +357,7 @@ def create_transform(
crop_border_pixels: Optional[int] = None,
tf_preprocessing: bool = False,
use_prefetcher: bool = False,
normalize: bool = True,
separate: bool = False,
):
"""
@ -373,6 +389,7 @@ def create_transform(
crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports
use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
separate: Output transforms in 3-stage tuple.
Returns:
@ -397,9 +414,10 @@ def create_transform(
transform = transforms_noaug_train(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
use_prefetcher=use_prefetcher,
normalize=normalize,
)
elif is_training:
transform = transforms_imagenet_train(
@ -415,13 +433,14 @@ def create_transform(
gaussian_blur_prob=gaussian_blur_prob,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
use_prefetcher=use_prefetcher,
normalize=normalize,
separate=separate,
)
else:
@ -429,12 +448,13 @@ def create_transform(
transform = transforms_imagenet_eval(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
crop_pct=crop_pct,
crop_mode=crop_mode,
crop_border_pixels=crop_border_pixels,
use_prefetcher=use_prefetcher,
normalize=normalize,
)
return transform
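A hedged sketch of the `normalize=` plumbing above (editor's example, not part of the diff). Note the flag is only consulted when `use_prefetcher=False`:

```python
from timm.data import create_transform

# default: ToTensor() + Normalize(mean, std) -> normalized float tensor
norm_tf = create_transform(224, is_training=False)
# normalize=False: PILToTensor() only -> unscaled tensor, original dtype kept
raw_tf = create_transform(224, is_training=False, normalize=False)
```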

View File

@ -108,7 +108,7 @@ class ClassifierHead(nn.Module):
self.fc = fc
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def reset(self, num_classes, pool_type=None):
def reset(self, num_classes: int, pool_type: Optional[str] = None):
if pool_type is not None and pool_type != self.global_pool.pool_type:
self.global_pool, self.fc = create_classifier(
self.in_features,
@ -180,7 +180,7 @@ class NormMlpClassifierHead(nn.Module):
self.drop = nn.Dropout(drop_rate)
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, pool_type=None):
def reset(self, num_classes: int, pool_type: Optional[str] = None):
if pool_type is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()

View File

@ -47,7 +47,7 @@ def get_norm_layer(norm_layer):
if isinstance(norm_layer, str):
if not norm_layer:
return None
layer_name = norm_layer.replace('_', '')
layer_name = norm_layer.replace('_', '').lower()
norm_layer = _NORM_MAP[layer_name]
else:
norm_layer = norm_layer
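With the added `.lower()`, string lookups become case-insensitive (editor's note; all of the names below resolve to the same `_NORM_MAP` entry):

```python
from timm.layers import get_norm_layer

# 'Layer_Norm'.replace('_', '').lower() == 'layernorm', so these are identical
assert get_norm_layer('LayerNorm') is get_norm_layer('layer_norm') is get_norm_layer('layernorm')
```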

View File

@ -6,7 +6,7 @@ import torch.nn as nn
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
"""
return_indices: torch.jit.Final[bool]

View File

@ -26,6 +26,7 @@ from .gcvit import *
from .ghostnet import *
from .hardcorenas import *
from .hgnet import *
from .hiera import *
from .hrnet import *
from .inception_next import *
from .inception_resnet_v2 import *

View File

@ -10,7 +10,8 @@ from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
load_custom_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
@ -185,7 +186,12 @@ def load_pretrained(
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
custom_load = pretrained_cfg.get('custom_load', False)
if isinstance(custom_load, str) and custom_load == 'hf':
load_custom_from_hf(*pretrained_loc, model)
return
else:
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc)
else:

View File

@ -118,6 +118,7 @@ class FeatureGraphNet(nn.Module):
out_indices: Tuple[int, ...],
out_map: Optional[Dict] = None,
output_fmt: str = 'NCHW',
return_dict: bool = False,
):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
@ -127,9 +128,13 @@ class FeatureGraphNet(nn.Module):
self.output_fmt = Format(output_fmt)
return_nodes = _get_return_layers(self.feature_info, out_map)
self.graph_module = create_feature_extractor(model, return_nodes)
self.return_dict = return_dict
def forward(self, x):
return list(self.graph_module(x).values())
out = self.graph_module(x)
if self.return_dict:
return out
return list(out.values())
class GraphExtractNet(nn.Module):
@ -144,19 +149,23 @@ class GraphExtractNet(nn.Module):
model: model to extract features from
return_nodes: node names to return features from (dict or list)
squeeze_out: if only one output, and output in list format, flatten to single tensor
return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
"""
def __init__(
self,
model: nn.Module,
return_nodes: Union[Dict[str, str], List[str]],
squeeze_out: bool = True,
return_dict: bool = False,
):
super().__init__()
self.squeeze_out = squeeze_out
self.graph_module = create_feature_extractor(model, return_nodes)
self.return_dict = return_dict
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
out = list(self.graph_module(x).values())
if self.squeeze_out and len(out) == 1:
return out[0]
return out
out = self.graph_module(x)
if self.return_dict:
return out
out = list(out.values())
return out[0] if self.squeeze_out and len(out) == 1 else out
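A sketch of the new `return_dict` behavior (editor's example; node names are illustrative, see `torchvision`'s `get_graph_node_names` for the real ones):

```python
import torch
import timm
from timm.models._features_fx import GraphExtractNet

model = timm.create_model('resnet50')
extractor = GraphExtractNet(model, ['layer2', 'layer3'], return_dict=True)
out = extractor(torch.randn(1, 3, 224, 224))  # {'layer2': Tensor, 'layer3': Tensor}
```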

View File

@ -190,6 +190,13 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
return torch.load(cached_file, map_location='cpu')
def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
assert has_hf_hub(True)
hf_model_id, hf_revision = hf_split(model_id)
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
return model.load_pretrained(cached_file)
def save_config_for_hf(
model,
config_path: str,

View File

@ -395,7 +395,7 @@ class Beit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -331,7 +331,7 @@ class Cait(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')

View File

@ -7,8 +7,7 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py
"""
from functools import partial
from typing import Tuple, List, Union
from typing import List, Optional, Union, Tuple
import torch
import torch.nn as nn
@ -560,7 +559,7 @@ class CoaT(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')

View File

@ -21,8 +21,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
'''These modules are adapted from those of timm, see
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
@ -349,7 +348,7 @@ class ConVit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')

View File

@ -1,6 +1,8 @@
""" ConvMixer
"""
from typing import Optional
import torch
import torch.nn as nn
@ -75,7 +77,7 @@ class ConvMixer(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)

View File

@ -37,7 +37,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

View File

@ -25,8 +25,7 @@ Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master
"""
from functools import partial
from typing import List
from typing import Tuple
from typing import List, Optional, Tuple
import torch
import torch.hub
@ -419,7 +418,7 @@ class CrossVit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')

View File

@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license
from functools import partial
from typing import Tuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
@ -568,7 +568,7 @@ class DaVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, global_pool)
def forward_features(self, x):

View File

@ -11,7 +11,7 @@ Modifications copyright 2021, Ross Wightman
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import Sequence, Union
from typing import Optional
import torch
from torch import nn as nn
@ -20,7 +20,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
@ -64,7 +63,7 @@ class VisionTransformerDistilled(VisionTransformer):
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

View File

@ -8,7 +8,6 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
from collections import OrderedDict
from functools import partial
from typing import Tuple
@ -17,7 +16,7 @@ import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
use_fused_attn, NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module

View File

@ -449,7 +449,7 @@ class EfficientFormer(nn.Module):
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman
"""
import math
from functools import partial
from typing import Dict
from typing import Dict, Optional
import torch
import torch.nn as nn
@ -612,7 +612,7 @@ class EfficientFormerV2(nn.Module):
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -13,7 +13,6 @@ from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
@ -740,7 +739,7 @@ class EfficientVit(nn.Module):
def get_classifier(self):
return self.head.classifier[-1]
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
@ -858,7 +857,7 @@ class EfficientVitLarge(nn.Module):
def get_classifier(self):
return self.head.classifier[-1]
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
__all__ = ['EfficientVitMsra']
import itertools
from collections import OrderedDict
from typing import Dict
from typing import Dict, Optional
import torch
import torch.nn as nn
@ -464,7 +464,7 @@ class EfficientVitMsra(nn.Module):
def get_classifier(self):
return self.head.linear
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
if global_pool == 'avg':

View File

@ -539,7 +539,7 @@ class Eva(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -396,7 +396,7 @@ class ReparamLargeKernelConv(nn.Module):
@staticmethod
def _fuse_bn(
conv: torch.Tensor, bn: nn.BatchNorm2d
conv: nn.Conv2d, bn: nn.BatchNorm2d
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to fuse batchnorm layer with conv layer.
@ -1232,7 +1232,7 @@ class FastVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

View File

@ -454,7 +454,7 @@ class FocalNet(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):

View File

@ -489,7 +489,7 @@ class GlobalContextVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is None:
global_pool = self.head.global_pool.pool_type

936
timm/models/hiera.py Normal file
View File

@ -0,0 +1,936 @@
""" An PyTorch implementation of Hiera
Adapted for timm from originals at https://github.com/facebookresearch/hiera
"""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
#
# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
#
# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
#
# Paper: https://arxiv.org/abs/2306.00989/
#
# References:
# slowfast: https://github.com/facebookresearch/SlowFast
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------
import math
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
from ._registry import generate_default_cfgs, register_model
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
def conv_nd(n: int) -> Type[nn.Module]:
"""
Returns a conv class for n spatial dims (e.g., Conv2d for n=2). Works up to n=3.
If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
"""
return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
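# e.g. conv_nd(2) is nn.Conv2d, so conv_nd(2)(3, 96, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
# matches the default Hiera patch embed configuration below (editor's note)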
@register_notrace_function
def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
# target_size: [(T), (H), W]
# (spatial) mask: [B, C, (t), (h), w]
if mask is None:
return mask
_assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
if mask.shape[2:] != target_size:
return F.interpolate(mask.float(), size=target_size)
return mask
def undo_windowing(
x: torch.Tensor,
shape: List[int],
mu_shape: List[int],
) -> torch.Tensor:
"""
Restore spatial organization by undoing windowed organization of mask units.
Args:
x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C]
shape: current spatial shape, if it were not organized into mask unit
windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C].
mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx]
Returns:
x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C]
"""
D = len(shape)
B, C = x.shape[0], x.shape[-1]
# [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C]
num_MUs = [s // mu for s, mu in zip(shape, mu_shape)]
x = x.view(B, *num_MUs, *mu_shape, C)
# [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C]
permute = (
[0]
+ sum([list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))], [])
+ [len(x.shape) - 1]
)
x = x.permute(permute).reshape(B, *shape, C)
return x
class Unroll(nn.Module):
"""
Reorders the tokens such that patches are contiguous in memory.
E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as
[B, (Sy, Sx, H // Sy, W // Sx), C]
This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1).
Not only is this faster, but it also makes it easy to support inputs of arbitrary
dimensions in addition to patch-wise sparsity.
Performing this operation multiple times in sequence puts entire windows as contiguous
in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
size 8x8 would be contiguous in memory, allowing operations like mask unit attention
to be computed easily and efficiently, while also allowing max to be applied sequentially.
Note: This means that intermediate values of the model are not in HxW order, so they
need to be re-rolled if you want to use the intermediate values as a HxW feature map.
The last block of the network is fine though, since by then the strides are all consumed.
"""
def __init__(
self,
input_size: Tuple[int, ...],
patch_stride: Tuple[int, ...],
unroll_schedule: List[Tuple[int, ...]],
):
super().__init__()
self.size = [i // s for i, s in zip(input_size, patch_stride)]
self.schedule = unroll_schedule
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Input: Flattened patch embeddings [B, N, C]
Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
"""
B, _, C = x.shape
cur_size = self.size
x = x.view(*([B] + cur_size + [C]))
for strides in self.schedule:
# Move patches with the given strides to the batch dimension
# Create a view of the tensor with the patch stride as separate dims
# For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C]
cur_size = [i // s for i, s in zip(cur_size, strides)]
new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C]
x = x.view(new_shape)
# Move the patch stride into the batch dimension
# For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C]
L = len(new_shape)
permute = [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1]
x = x.permute(permute)
# Now finally flatten the relevant dims into the batch dimension
x = x.flatten(0, len(strides))
B *= math.prod(strides)
x = x.reshape(-1, math.prod(self.size), C)
return x
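# Shape sketch (editor's note): for input_size=(224, 224), patch_stride=(4, 4) and
# schedule=[(2, 2)] * 3, a [B, 3136, C] input returns as [B, 3136, C] with tokens
# re-ordered so that x.view(B, 4, -1, C).amax(1) performs a 2x2 max pool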
class Reroll(nn.Module):
"""
Undos the "unroll" operation so that you can use intermediate features.
"""
def __init__(
self,
input_size: Tuple[int, ...],
patch_stride: Tuple[int, ...],
unroll_schedule: List[Tuple[int, ...]],
stage_ends: List[int],
q_pool: int,
):
super().__init__()
self.size = [i // s for i, s in zip(input_size, patch_stride)]
# The first stage has to reverse everything
# The next stage has to reverse all but the first unroll, etc.
self.schedule = {}
size = self.size
for i in range(stage_ends[-1] + 1):
self.schedule[i] = unroll_schedule, size
# schedule unchanged if no pooling at a stage end
if i in stage_ends[:q_pool]:
if len(unroll_schedule) > 0:
size = [n // s for n, s in zip(size, unroll_schedule[0])]
unroll_schedule = unroll_schedule[1:]
def forward(
self,
x: torch.Tensor,
block_idx: int,
mask: torch.Tensor = None
) -> torch.Tensor:
"""
Roll the given tensor back up to spatial order assuming it's from the given block.
If no mask is provided:
- Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
If a mask is provided:
- Returns [B, #MUs, MUy, MUx, C] for 2d, etc.
"""
schedule, size = self.schedule[block_idx]
B, N, C = x.shape
D = len(size)
cur_mu_shape = [1] * D
for strides in schedule:
# Extract the current patch from N
x = x.view(B, *strides, N // math.prod(strides), *cur_mu_shape, C)
# Move that patch into the current MU
# Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C]
L = len(x.shape)
permute = (
[0, 1 + D]
+ sum([list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))], [])
+ [L - 1]
)
x = x.permute(permute)
# Reshape to [B, N//(Sy*Sx), *MU, C]
for i in range(D):
cur_mu_shape[i] *= strides[i]
x = x.reshape(B, -1, *cur_mu_shape, C)
N = x.shape[1]
# Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C])
x = x.view(B, N, *cur_mu_shape, C)
# If masked, return [B, #MUs, MUy, MUx, C]
if mask is not None:
return x
# If not masked, we can return [B, H, W, C]
x = undo_windowing(x, size, cur_mu_shape)
return x
class MaskUnitAttention(nn.Module):
"""
Computes either Mask Unit or Global Attention. Also is able to perform q pooling.
Note: this assumes the tokens have already been flattened and unrolled into mask units.
See `Unroll` for more details.
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim: int,
dim_out: int,
heads: int,
q_stride: int = 1,
window_size: int = 0,
use_mask_unit_attn: bool = False,
):
"""
Args:
- dim, dim_out: The input and output feature dimensions.
- heads: The number of attention heads.
- q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
- window_size: The current (flattened) size of a mask unit *after* pooling (if any).
- use_mask_unit_attn: Use Mask Unit or Global Attention.
"""
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.heads = heads
self.q_stride = q_stride
self.head_dim = dim_out // heads
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, 3 * dim_out)
self.proj = nn.Linear(dim_out, dim_out)
self.window_size = window_size
self.use_mask_unit_attn = use_mask_unit_attn
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Input should be of shape [batch, tokens, channels]. """
B, N, _ = x.shape
num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
q, k, v = qkv.unbind(0)
if self.q_stride > 1:
# Refer to Unroll to see how this performs a maxpool-Nd
q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3)
if self.fused_attn:
# Note: the original paper did *not* use SDPA, it's a free boost!
x = F.scaled_dot_product_attention(q, k, v)
else:
attn = (q * self.scale) @ k.transpose(-1, -2)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x
class HieraBlock(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
heads: int,
mlp_ratio: float = 4.0,
drop_path: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
act_layer: nn.Module = nn.GELU,
q_stride: int = 1,
window_size: int = 0,
use_expand_proj: bool = True,
use_mask_unit_attn: bool = False,
):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.norm1 = norm_layer(dim)
if dim != dim_out:
self.do_expand = True
if use_expand_proj:
self.proj = nn.Linear(dim, dim_out)
else:
assert dim_out == dim * 2
self.proj = None
else:
self.do_expand = False
self.proj = None
self.attn = MaskUnitAttention(
dim,
dim_out,
heads,
q_stride,
window_size,
use_mask_unit_attn
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.norm2 = norm_layer(dim_out)
self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)
self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Attention + Q Pooling
x_norm = self.norm1(x)
if self.do_expand:
if self.proj is not None:
x = self.proj(x_norm)
x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1) # max-pool
else:
x = torch.cat([
x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1), # max-pool
x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1), # avg-pool
],
dim=-1,
)
x = x + self.drop_path1(self.attn(x_norm))
# MLP
x = x + self.drop_path2(self.mlp(self.norm2(x)))
return x
class NormClassifierHead(nn.Module):
def __init__(
self,
in_features: int,
num_classes: int,
pool_type: str = 'avg',
drop_rate: float = 0.0,
norm_layer: Union[str, Callable] = 'layernorm',
):
super().__init__()
norm_layer = get_norm_layer(norm_layer)
assert pool_type in ('avg', '')
self.in_features = self.num_features = in_features
self.pool_type = pool_type
self.norm = norm_layer(in_features)
self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
if pool_type is not None:
assert pool_type in ('avg', '')
self.pool_type = pool_type
if other:
# reset other non-fc layers
self.norm = nn.Identity()
self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
if self.pool_type == 'avg':
x = x.mean(dim=1)
x = self.norm(x)
x = self.drop(x)
if pre_logits:
return x
x = self.fc(x)
return x
class PatchEmbed(nn.Module):
"""Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""
def __init__(
self,
dim_in: int,
dim_out: int,
kernel: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Tuple[int, ...],
reshape: bool = True,
):
super().__init__()
# Support any number of spatial dimensions
self.spatial_dims = len(kernel)
self.reshape = reshape
self.proj = conv_nd(self.spatial_dims)(
dim_in,
dim_out,
kernel_size=kernel,
stride=stride,
padding=padding,
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if mask is not None:
mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
x = self.proj(x * mask.to(torch.bool))
else:
x = self.proj(x)
if self.reshape:
x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
return x
class Hiera(nn.Module):
def __init__(
self,
img_size: Tuple[int, ...] = (224, 224),
in_chans: int = 3,
embed_dim: int = 96, # initial embed dim
num_heads: int = 1, # initial number of heads
num_classes: int = 1000,
global_pool: str = 'avg',
stages: Tuple[int, ...] = (2, 3, 16, 3),
q_pool: int = 3, # number of q_pool stages
q_stride: Tuple[int, ...] = (2, 2),
mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1)
# mask_unit_attn: which stages use mask unit attention?
mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
dim_mul: float = 2.0,
head_mul: float = 2.0,
patch_kernel: Tuple[int, ...] = (7, 7),
patch_stride: Tuple[int, ...] = (4, 4),
patch_padding: Tuple[int, ...] = (3, 3),
mlp_ratio: float = 4.0,
drop_path_rate: float = 0.0,
norm_layer: Union[str, nn.Module] = "LayerNorm",
drop_rate: float = 0.0,
head_init_scale: float = 0.001,
sep_pos_embed: bool = False,
):
super().__init__()
self.num_classes = num_classes
self.grad_checkpointing = False
norm_layer = get_norm_layer(norm_layer)
self.patch_stride = patch_stride
self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
num_tokens = math.prod(self.tokens_spatial_shape)
flat_mu_size = math.prod(mask_unit_size)
flat_q_stride = math.prod(q_stride)
assert q_pool < len(stages)
self.q_pool, self.q_stride = q_pool, q_stride
self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)]
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
self.patch_embed = PatchEmbed(
in_chans,
embed_dim,
patch_kernel,
patch_stride,
patch_padding,
#reshape=False, # leave spatial / temporal dims in output
)
if sep_pos_embed:
self.pos_embed = None
self.pos_embed_spatial = nn.Parameter(
torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim)
)
self.pos_embed_temporal = nn.Parameter(
torch.zeros(1, self.tokens_spatial_shape[0], embed_dim)
)
else:
self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
self.pos_embed_spatial = None
self.pos_embed_temporal = None
# Setup roll and reroll modules
self.unroll = Unroll(
img_size,
patch_stride,
[q_stride] * len(self.stage_ends[:-1])
)
self.reroll = Reroll(
img_size,
patch_stride,
[q_stride] * len(self.stage_ends[:-1]),
self.stage_ends,
q_pool,
)
# q_pool locations
q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
# Transformer blocks
cur_stage = 0
depth = sum(stages)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList()
self.feature_info = []
for i in range(depth):
dim_out = embed_dim
# Mask unit or global attention.
# Lag by 1 block, so that global attention is
# applied post pooling on lower resolution
use_mask_unit_attn = mask_unit_attn[cur_stage]
if i - 1 in self.stage_ends:
dim_out = int(embed_dim * dim_mul)
num_heads = int(num_heads * head_mul)
cur_stage += 1
if i in q_pool_blocks:
flat_mu_size //= flat_q_stride
block = HieraBlock(
dim=embed_dim,
dim_out=dim_out,
heads=num_heads,
mlp_ratio=mlp_ratio,
drop_path=dpr[i],
norm_layer=norm_layer,
q_stride=(flat_q_stride if i in q_pool_blocks else 1),
window_size=flat_mu_size,
use_mask_unit_attn=use_mask_unit_attn,
)
embed_dim = dim_out
if i in self.stage_ends:
self.feature_info += [
dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
self.blocks.append(block)
self.num_features = embed_dim
self.head = NormClassifierHead(
embed_dim,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
norm_layer=norm_layer,
)
# Initialize everything
if sep_pos_embed:
nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
else:
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.apply(partial(self._init_weights))
if isinstance(self.head.fc, nn.Linear):
self.head.fc.weight.data.mul_(head_init_scale)
self.head.fc.bias.data.mul_(head_init_scale)
def _init_weights(self, m, init_bias=0.02):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, init_bias)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, init_bias)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
if self.pos_embed is not None:
return ["pos_embed"]
else:
return ["pos_embed_spatial", "pos_embed_temporal"]
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict:
return dict(
stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool, other=other)
def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
"""
Generates a random mask where a mask_ratio fraction of mask units is dropped.
1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc.
"""
B = x.shape[0]
# Tokens selected for masking at mask unit level
num_windows = math.prod(self.mask_spatial_shape) # num_mask_units
len_keep = int(num_windows * (1 - mask_ratio))
noise = torch.rand(B, num_windows, device=x.device)
# Sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# Generate the binary mask: 1 is *keep*, 0 is *remove*
# Note this is opposite to original MAE
mask = torch.zeros([B, num_windows], device=x.device)
mask[:, :len_keep] = 1
# Unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return mask.bool()
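# Editor's sketch, assuming the default 224x224 input and (8, 8) mask units
# (7 * 7 = 49 mask units per image):
#   mask = model.get_random_mask(x, mask_ratio=0.6)  # [B, 49] bool, 1 = keep
#   feats = model.forward_features(x, mask=mask)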
def _pos_embed(self, x) -> torch.Tensor:
if self.pos_embed is not None:
pos_embed = self.pos_embed
else:
pos_embed = (
self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
+
torch.repeat_interleave(
self.pos_embed_temporal,
self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
dim=1,
)
)
x = x + pos_embed
return x
def forward_intermediates(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
List of intermediate features, or tuple of (final features, intermediates)
"""
assert not norm, 'normalization of features not supported'
assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
take_indices = [self.stage_ends[i] for i in take_indices]
max_index = self.stage_ends[max_index]
if mask is not None:
patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
else:
patch_mask = None
x = self.patch_embed(x, mask=patch_mask)
x = self._pos_embed(x)
x = self.unroll(x)
# Discard masked tokens
if mask is not None:
x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
intermediates = []
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index + 1]
for i, blk in enumerate(blocks):
x = blk(x)
if i in take_indices:
intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
max_index = self.stage_ends[max_index]
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_head:
self.head.reset(0, other=True)
return take_indices
def forward_features(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
return_intermediates: bool = False,
) -> torch.Tensor:
"""
mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
"""
if mask is not None:
patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
else:
patch_mask = None
x = self.patch_embed(x, mask=patch_mask)
x = self._pos_embed(x)
x = self.unroll(x)
# Discard masked tokens
if mask is not None:
x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
intermediates = []
for i, blk in enumerate(self.blocks):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x)
else:
x = blk(x)
if return_intermediates and i in self.stage_ends:
intermediates.append(self.reroll(x, i, mask=mask))
# x may not always be in spatial order here.
# e.g. if q_pool = 2, mask_unit_size = (8, 8), and
# q_stride = (2, 2), not all unrolls were consumed,
# intermediates[-1] is x in spatial order
if return_intermediates:
return x, intermediates
return x
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
x = self.head(x, pre_logits=pre_logits)
return x
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = self.forward_features(x, mask=mask)
if mask is None:
x = self.forward_head(x)
return x
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = generate_default_cfgs({
"hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_tiny_224.mae": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_small_224.mae_in1k_ft_in1k": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_small_224.mae": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_base_224.mae_in1k_ft_in1k": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_base_224.mae": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_base_plus_224.mae_in1k_ft_in1k": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_base_plus_224.mae": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_large_224.mae_in1k_ft_in1k": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_large_224.mae": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_huge_224.mae_in1k_ft_in1k": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_huge_224.mae": _cfg(
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
})
def checkpoint_filter_fn(state_dict, model=None):
state_dict = state_dict.get('model_state', state_dict)
output = {}
for k, v in state_dict.items():
if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# # To resize pos embedding when using model at different size from pretrained weights
# from timm.layers import resample_abs_pos_embed
# v = resample_abs_pos_embed(
# v,
# new_size=(64, 64),
# num_prefix_tokens=0,
# verbose=True,
# )
#v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
pass
if 'head.projection.' in k:
k = k.replace('head.projection.', 'head.fc.')
if k.startswith('encoder_norm.'):
k = k.replace('encoder_norm.', 'head.norm.')
elif k.startswith('norm.'):
k = k.replace('norm.', 'head.norm.')
output[k] = v
return output
def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
out_indices = kwargs.pop('out_indices', 4)
return build_model_with_cfg(
Hiera,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
@register_model
def hiera_tiny_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
return _create_hiera('hiera_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_small_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2))
return _create_hiera('hiera_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_base_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
return _create_hiera('hiera_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_base_plus_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
return _create_hiera('hiera_base_plus_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_large_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
return _create_hiera('hiera_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_huge_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
return _create_hiera('hiera_huge_224', pretrained=pretrained, **dict(model_args, **kwargs))
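A hedged end-to-end sketch of the new models (editor's example; pretrained tags per `default_cfgs` above):

```python
import torch
import timm

model = timm.create_model('hiera_tiny_224')
logits = model(torch.randn(1, 3, 224, 224))  # [1, 1000]

# hierarchical NCHW feature maps via the getter-based features_only API
feat_model = timm.create_model('hiera_tiny_224', features_only=True)
feats = feat_model(torch.randn(1, 3, 224, 224))  # strides 4, 8, 16, 32
```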

View File

@ -628,7 +628,7 @@ class Levit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
@ -730,7 +730,7 @@ class LevitDistilled(Levit):
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -1248,7 +1248,7 @@ class MaxxVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

View File

@ -255,7 +255,7 @@ class MlpMixer(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')

View File

@ -622,8 +622,6 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
return model
def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
"""Creates a MobileNet-V4 model.

View File

@ -825,7 +825,7 @@ class MultiScaleVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -6,6 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
"""
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial
from typing import Optional
import torch
import torch.nn.functional as F
@ -553,7 +554,7 @@ class NextViT(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):

View File

@ -14,13 +14,13 @@ Modifications for timm by / Copyright 2020 Ross Wightman
import math
import re
from functools import partial
from typing import Sequence, Tuple
from typing import Optional, Sequence, Tuple
import torch
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, to_2tuple, LayerNorm
from timm.layers import trunc_normal_, to_2tuple
from ._builder import build_model_with_cfg
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Block
@ -246,7 +246,7 @@ class PoolingVisionTransformer(nn.Module):
else:
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.head_dist is not None:

View File

@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
"""
import math
from typing import Tuple, List, Callable, Union
from typing import Callable, List, Optional, Union
import torch
import torch.nn as nn
@ -379,7 +379,7 @@ class PyramidVisionTransformerV2(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('avg', '')

View File

@ -16,15 +16,16 @@ Adapted from official impl at https://github.com/jameslahm/RepViT
"""
__all__ = ['RepVit']
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._registry import register_model, generate_default_cfgs
from ._builder import build_model_with_cfg
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
from ._manipulate import checkpoint_seq
from typing import Optional
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
class ConvNorm(nn.Sequential):
@ -322,7 +323,7 @@ class RepVit(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None, distillation=False):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=False):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -9,7 +9,7 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2
import math
from functools import partial
from itertools import accumulate
from typing import Tuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
@ -419,7 +419,7 @@ class Sequencer2d(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)

View File

@ -604,7 +604,7 @@ class SwinTransformer(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)

View File

@ -605,7 +605,7 @@ class SwinTransformerV2(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

View File

@ -8,10 +8,9 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyV
__all__ = ['TinyVit']
import math
import itertools
from functools import partial
from typing import Dict
from typing import Dict, Optional
import torch
import torch.nn as nn
@ -533,7 +532,7 @@ class TinyVit(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)

View File

@ -7,6 +7,7 @@ The official mindspore code is released and available at
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
"""
import math
from typing import Optional
import torch
import torch.nn as nn
@ -298,7 +299,7 @@ class TNT(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')

View File

@ -7,6 +7,7 @@ Original model: https://github.com/mrT23/TResNet
"""
from collections import OrderedDict
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
@ -233,7 +234,7 @@ class TResNet(nn.Module):
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):

View File

@ -382,7 +382,7 @@ class Twins(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')

View File

@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
"""
import numpy as np
def _n2p(w, t=True):
def _n2p(w, t=True, idx=None):
if idx is not None:
w = w[idx]
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
@ -955,21 +957,28 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
block_prefix = f'{prefix}Transformer/encoderblock/'
idx = i
else:
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
idx = None
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
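The `idx=` plumbing above exists because some big_vision checkpoints (e.g. PaliGemma's SigLIP tower) store all encoder blocks under one shared `Transformer/encoderblock/` prefix, with the block index as the leading axis of each stacked array, rather than one `encoderblock_{i}/` subtree per block. A minimal numpy sketch of the idea, with invented shapes:

```python
import numpy as np

# Hypothetical scan-style checkpoint: each param holds all blocks,
# stacked along a leading (depth, ...) axis.
depth, dim = 4, 8
w = {'Transformer/encoderblock/LayerNorm_0/scale': np.ones((depth, dim))}

def n2p(a, idx=None):
    # Mirrors the _n2p idx handling: slice out one block before any
    # transpose/reshape into torch layout.
    if idx is not None:
        a = a[idx]
    return a

for i in range(depth):
    scale_i = n2p(w['Transformer/encoderblock/LayerNorm_0/scale'], idx=i)
    assert scale_i.shape == (dim,)  # per-block parameter for block i
```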
def _convert_openai_clip(
@ -1769,6 +1778,73 @@ default_cfgs = {
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/ViT-B-16-SigLIP',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0),
'vit_base_patch16_siglip_gap_256.webli': _cfg(
hf_hub_id='timm/ViT-B-16-SigLIP-256',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_384.webli': _cfg(
hf_hub_id='timm/ViT-B-16-SigLIP-384',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_gap_512.webli': _cfg(
hf_hub_id='timm/ViT-B-16-SigLIP-512',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 512, 512),
num_classes=0),
'vit_large_patch16_siglip_gap_256.webli': _cfg(
hf_hub_id='timm/ViT-L-16-SigLIP-256',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_gap_384.webli': _cfg(
hf_hub_id='timm/ViT-L-16-SigLIP-384',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/ViT-SO400M-14-SigLIP',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
hf_hub_id='google/paligemma-3b-mix-224-jax',
hf_hub_filename='paligemma-3b-mix-224.npz',
custom_load='hf',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-224-jax',
hf_hub_filename='paligemma-3b-pt-224.npz',
custom_load='hf',
num_classes=0),
'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 384, 384), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
hf_hub_id='google/paligemma-3b-mix-448-jax',
hf_hub_filename='paligemma-3b-mix-448.npz',
custom_load='hf',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-448-jax',
hf_hub_filename='paligemma-3b-pt-448.npz',
custom_load='hf',
input_size=(3, 448, 448), crop_pct=1.0,
num_classes=0),
'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
hf_hub_id='google/paligemma-3b-pt-896-jax',
hf_hub_filename='paligemma-3b-pt-896.npz',
custom_load='hf',
input_size=(3, 896, 896), crop_pct=1.0,
num_classes=0),
'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
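The `.pali_*` entries point at the original PaliGemma `.npz` checkpoints and set `custom_load='hf'`, so the JAX arrays are fetched from the Hub and remapped through `_load_weights` at create time. A hedged loading sketch; note the `google/paligemma-*` repos may be gated and require accepting a license on the Hub first:

```python
import torch
import timm

# Pulls the So400m SigLIP image tower with PaliGemma pretrain weights,
# remapped from the .npz checkpoint (custom_load='hf').
model = timm.create_model('vit_so400m_patch14_siglip_gap_224.pali_pt', pretrained=True)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1152]); num_classes=0 yields pooled embeddings
```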
@ -1791,6 +1867,7 @@ default_cfgs = {
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
@ -1801,9 +1878,16 @@ default_cfgs = {
'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg4_gap_256.sbb_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
@ -2374,7 +2458,6 @@ def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V
def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
"""
from timm.layers import get_act_layer
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
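The dropped `get_act_layer` import was dead code here because the string `'quick_gelu'` is resolved to an activation class inside the model builder. A small sketch of that lookup, assuming `'quick_gelu'` is registered in the activation registry as the string argument implies:

```python
import torch
from timm.layers import get_act_layer

# Strings resolve to activation classes via the layer registry, so model
# defs can pass act_layer='quick_gelu' without importing the class.
act = get_act_layer('quick_gelu')()
print(act(torch.tensor([-1.0, 0.0, 1.0])))  # x * sigmoid(1.702 * x)
```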
@ -2756,15 +2839,118 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
return model
# @register_model
# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
# model_args = dict(
# patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
# no_embed_class=True, reg_tokens=4,
# )
# model = _create_vision_transformer(
# 'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
# return model
@register_model
def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_base_patch16_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_base_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_base_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_base_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_large_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_large_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
model_args = dict(
patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
class_token=False, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
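All of these GAP registrations share one `model_args` dict per model width and differ only in default resolution; because creation merges `dict(model_args, **kwargs)`, per-call kwargs win over the registered defaults. A hedged override sketch:

```python
import timm

# kwargs override registered model_args at create time.
model = timm.create_model(
    'vit_base_patch16_siglip_gap_224',
    pretrained=False,
    img_size=192,     # override the default 224 input resolution
    num_classes=10,   # attach a small head instead of num_classes=0
)
print(model.patch_embed.grid_size)  # (12, 12) for 192 / patch 16
```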

View File

@ -622,7 +622,7 @@ class VOLO(nn.Module):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool

View File

@ -1,4 +1,5 @@
import fnmatch
import re
from collections import OrderedDict
from typing import Union, Optional, List
@ -17,6 +18,7 @@ class AttentionExtract(torch.nn.Module):
mode: str = 'eval',
method: str = 'fx',
hook_type: str = 'forward',
use_regex: bool = False,
):
""" Extract attention maps (or other activations) from a model by name.
@ -26,6 +28,7 @@ class AttentionExtract(torch.nn.Module):
mode: 'train' or 'eval' model mode.
method: 'fx' or 'hook' extraction method.
hook_type: 'forward' or 'forward_pre' hooks used.
use_regex: Use regex instead of fnmatch for matching names.
"""
super().__init__()
assert mode in ('train', 'eval')
@ -40,14 +43,16 @@ class AttentionExtract(torch.nn.Module):
from timm.models._features_fx import get_graph_node_names, GraphExtractNet
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
matched = []
names = names or self.default_node_names
for n in names:
matched.extend(fnmatch.filter(node_names, n))
if use_regex:
regexes = [re.compile(r) for r in names]
matched = [g for g in node_names if any([r.match(g) for r in regexes])]
else:
matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
if not matched:
raise RuntimeError(f'No node names found matching {names}.')
self.model = GraphExtractNet(model, matched)
self.model = GraphExtractNet(model, matched, return_dict=True)
self.hooks = None
else:
# names are module names
@ -55,10 +60,12 @@ class AttentionExtract(torch.nn.Module):
from timm.models._features import FeatureHooks
module_names = [n for n, m in model.named_modules()]
matched = []
names = names or self.default_module_names
for n in names:
matched.extend(fnmatch.filter(module_names, n))
if use_regex:
regexes = [re.compile(r) for r in names]
matched = [m for m in module_names if any([r.match(m) for r in regexes])]
else:
matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
if not matched:
raise RuntimeError(f'No module names found matching {names}.')
@ -75,5 +82,4 @@ class AttentionExtract(torch.nn.Module):
output = self.hooks.get_output(device=x.device)
else:
output = self.model(x)
output = OrderedDict(zip(self.names, output))
return output
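With `use_regex=True` the `names` patterns are compiled with `re` and matched against graph node or module names, and the fx path now returns a dict directly via `GraphExtractNet(..., return_dict=True)`. A hedged usage sketch; the import path and the need to disable fused attention are assumptions:

```python
import torch
import timm
from timm.layers import set_fused_attn
from timm.utils import AttentionExtract  # assumed export of timm/utils/attention_extract.py

# Assumption: disable fused SDPA so explicit attn.softmax nodes exist to trace.
set_fused_attn(False)
model = timm.create_model('vit_small_patch16_224', pretrained=False)

extractor = AttentionExtract(
    model,
    names=[r'blocks\.\d+\.attn\.softmax'],  # regex, not an fnmatch glob
    method='fx',
    use_regex=True,
)

with torch.no_grad():
    outputs = extractor(torch.randn(1, 3, 224, 224))
for name, attn in outputs.items():
    print(name, tuple(attn.shape))  # per-block attention maps
```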

View File

@ -108,9 +108,16 @@ def init_distributed_device_so(
world_size = 1
global_rank = 0
local_rank = 0
device_type, *device_idx = device.split(':', maxsplit=1)
if dist_backend is None:
# FIXME sane defaults for other device backends?
dist_backend = 'nccl' if 'cuda' in device else 'gloo'
# FIXME: verify that ROCm transforms nccl to rccl
dist_backends = {
"xpu": "ccl",
"hpu": "hccl",
"cuda": "nccl",
}
dist_backend = dist_backends.get(device_type, 'gloo')
dist_url = dist_url or 'env://'
# TBD, support horovod?
@ -150,18 +157,15 @@ def init_distributed_device_so(
global_rank = torch.distributed.get_rank()
distributed = True
if 'cuda' in device:
if device_type == 'cuda':
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
if distributed and device != 'cpu':
device, *device_idx = device.split(':', maxsplit=1)
# Ignore manually specified device index in distributed mode and
# override with resolved local rank, fewer headaches in most setups.
if device_idx:
_logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
device = f'{device}:{local_rank}'
device = f'{device_type}:{local_rank}'
if device.startswith('cuda:'):
torch.cuda.set_device(device)
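The rewrite parses the device type once up front and uses it both for backend defaulting and for the distributed device-index override. The selection logic in isolation, as a self-contained sketch:

```python
from typing import Optional

def pick_dist_backend(device: str, dist_backend: Optional[str] = None) -> str:
    """Sketch of the default-backend selection above."""
    device_type, *_ = device.split(':', maxsplit=1)
    if dist_backend is None:
        dist_backends = {'xpu': 'ccl', 'hpu': 'hccl', 'cuda': 'nccl'}
        # unknown device types (cpu, mps, ...) fall back to gloo
        dist_backend = dist_backends.get(device_type, 'gloo')
    return dist_backend

assert pick_dist_backend('cuda:1') == 'nccl'
assert pick_dist_backend('xpu') == 'ccl'
assert pick_dist_backend('cpu') == 'gloo'
```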

View File

@ -1 +1 @@
__version__ = '1.0.1.dev0'
__version__ = '1.0.4.dev0'