diff --git a/README.md b/README.md index 493d6c60..3c0b7bbe 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,12 @@ * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license. * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version. +### May 14, 2024 +* Support loading PaliGemma jax weights into SigLIP ViT models with average pooling. +* Add Hiera models from Meta (https://github.com/facebookresearch/hiera). +* Add `normalize=` flag for transforms, return non-normalized torch.Tensor with original dtype (for `chug`) +* Version 1.0.3 release + ### May 11, 2024 * `Searching for Better ViT Baselines (For the GPU Poor)` weights and vit variants released. Exploring model shapes between Tiny and Base. @@ -42,6 +48,7 @@ | [vit_medium_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_reg4_gap_256.sbb_in1k) | 83.47 | 96.622 | 38.88 | 256 | | [vit_medium_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_reg1_gap_256.sbb_in1k) | 83.462 | 96.548 | 38.88 | 256 | | [vit_little_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_little_patch16_reg4_gap_256.sbb_in1k) | 82.514 | 96.262 | 22.52 | 256 | +| [vit_wee_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_wee_patch16_reg1_gap_256.sbb_in1k) | 80.256 | 95.360 | 13.42 | 256 | | [vit_pwee_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_pwee_patch16_reg1_gap_256.sbb_in1k) | 80.072 | 95.136 | 15.25 | 256 | | [vit_mediumd_patch16_reg4_gap_256.sbb_in12k](https://huggingface.co/timm/vit_mediumd_patch16_reg4_gap_256.sbb_in12k) | N/A | N/A | 64.11 | 256 | | [vit_betwixt_patch16_reg4_gap_256.sbb_in12k](https://huggingface.co/timm/vit_betwixt_patch16_reg4_gap_256.sbb_in12k) | N/A | N/A | 60.4 | 256 | diff --git a/hfdocs/source/feature_extraction.mdx b/hfdocs/source/feature_extraction.mdx index d7343398..f443d4a9 100644 --- a/hfdocs/source/feature_extraction.mdx +++ b/hfdocs/source/feature_extraction.mdx @@ -192,9 +192,9 @@ There are two additional creation arguments impacting the output features. #### Output index selection -The `out_indices` argument is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most convnet models, index 0 is the stride 2 features, and index 4 is stride 32. For many ViT or ViT-Conv hybrids there may be many to all features maps of the same shape, or a combination of hierarchical and non-hieararchical feature maps. It is best to look at the `feature_info` attribute to see the number of features, their corresponding channel count and reduction level. +The `out_indices` argument is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most convnet models, index 0 is the stride 2 features, and index 4 is stride 32. For many ViT or ViT-Conv hybrids there may be many to all feature maps of the same shape, or a combination of hierarchical and non-hierarchical feature maps. 
It is best to look at the `feature_info` attribute to see the number of features, their corresponding channel count and reduction level. -`out_indices` supports negative indexing, this makes it easy to get the last, penunltimate, etc feature map. `out_indices=(-2,)` would return the penultimate feature map for any model. +`out_indices` supports negative indexing, this makes it easy to get the last, penultimate, etc feature map. `out_indices=(-2,)` would return the penultimate feature map for any model. #### Output stride (feature map dilation) @@ -228,7 +228,7 @@ Accompanying the `forward_intermediates` function is a `prune_intermediate_layer An `indices` argument is used for both `forward_intermediates()` and `prune_intermediate_layers()` to select the features to return or layers to remove. As with the `out_indices` for `features_only` API, `indices` is model specific and selects which intermediates are returned. -In non-hierarchical block based models such as ViT the indices correspond to the blocks, in models with hierarchical stages they usually correspond to the output of the stem + each hierarhical stage. Both positive (from the start), and negative (relative to the end) indexing works, and `None` is used to return all intermediates. +In non-hierarchical block based models such as ViT the indices correspond to the blocks, in models with hierarchical stages they usually correspond to the output of the stem + each hierarchical stage. Both positive (from the start), and negative (relative to the end) indexing works, and `None` is used to return all intermediates. The `prune_intermediate_layers()` call returns an indices variable, as negative indices must be converted to absolute (positive) indices when the model is trimmed. diff --git a/hfdocs/source/installation.mdx b/hfdocs/source/installation.mdx index 3ff210f3..b5093440 100644 --- a/hfdocs/source/installation.mdx +++ b/hfdocs/source/installation.mdx @@ -28,7 +28,7 @@ You should install `timm` in a [virtual environment](https://docs.python.org/3/l # Deactivate the virtual environment source .env/bin/deactivate ``` -` + Once you've created your virtual environment, you can install `timm` in it. ## Using pip diff --git a/tests/test_models.py b/tests/test_models.py index 9ff64c3b..652ea355 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -52,7 +52,7 @@ FEAT_INTER_FILTERS = [ 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', - 'regnet', 'byobnet', 'byoanet', 'mlp_mixer' + 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. 
@@ -60,7 +60,7 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*' + 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*' ] NUM_NON_STD = len(NON_STD_FILTERS) @@ -77,7 +77,7 @@ else: EXCLUDE_FILTERS = ['*enormous*'] NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*'] -EXCLUDE_JIT_FILTERS = [] +EXCLUDE_JIT_FILTERS = ['hiera_*'] TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 @@ -486,7 +486,7 @@ def _create_fx_model(model, train=False): return fx_model -EXCLUDE_FX_FILTERS = ['vit_gi*'] +EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*'] # not enough memory to run fx on more models than other tests if 'GITHUB_ACTIONS' in os.environ: EXCLUDE_FX_FILTERS += [ diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 504d1199..59754891 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -19,9 +19,10 @@ from timm.data.random_erasing import RandomErasing def transforms_noaug_train( img_size: Union[int, Tuple[int, int]] = 224, interpolation: str = 'bilinear', - use_prefetcher: bool = False, mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, + use_prefetcher: bool = False, + normalize: bool = True, ): """ No-augmentation image transforms for training. @@ -31,6 +32,7 @@ def transforms_noaug_train( mean: Image normalization mean. std: Image normalization standard deviation. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. + normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). Returns: @@ -45,6 +47,9 @@ if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm tfl += [ToNumpy()] + elif not normalize: + # when normalize disabled, converted to tensor without scaling, keeps original dtype + tfl += [transforms.PILToTensor()] else: tfl += [ transforms.ToTensor(), @@ -77,6 +82,7 @@ def transforms_imagenet_train( re_count: int = 1, re_num_splits: int = 0, use_prefetcher: bool = False, + normalize: bool = True, separate: bool = False, ): """ ImageNet-oriented image transforms for training. @@ -103,6 +109,7 @@ re_count: Number of random erasing regions. re_num_splits: Control split of random erasing across batch size. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. + normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). separate: Output transforms in 3-stage tuple. Returns: @@ -209,12 +216,15 @@ if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm final_tfl += [ToNumpy()] + elif not normalize: + # when normalize disabled, converted to tensor without scaling, keeps original dtype + final_tfl += [transforms.PILToTensor()] else: final_tfl += [ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), - std=torch.tensor(std) + std=torch.tensor(std), ), ] if re_prob > 0.: @@ -243,6 +253,7 @@ def transforms_imagenet_eval( mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, std: Tuple[float, ...] 
= IMAGENET_DEFAULT_STD, use_prefetcher: bool = False, + normalize: bool = True, ): """ ImageNet-oriented image transform for evaluation and inference. @@ -255,6 +266,7 @@ def transforms_imagenet_eval( mean: Image normalization mean. std: Image normalization standard deviation. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. + normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used). Returns: Composed transform pipeline @@ -304,13 +316,16 @@ def transforms_imagenet_eval( if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm tfl += [ToNumpy()] + elif not normalize: + # when normalize disabled, converted to tensor without scaling, keeps original dtype + tfl += [transforms.PILToTensor()] else: tfl += [ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std), - ) + ), ] return transforms.Compose(tfl) @@ -342,6 +357,7 @@ def create_transform( crop_border_pixels: Optional[int] = None, tf_preprocessing: bool = False, use_prefetcher: bool = False, + normalize: bool = True, separate: bool = False, ): """ @@ -373,6 +389,7 @@ def create_transform( crop_border_pixels: Inference crop border of specified # pixels around edge of original image. tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize. + normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used). separate: Output transforms in 3-stage tuple. Returns: @@ -397,9 +414,10 @@ def create_transform( transform = transforms_noaug_train( img_size, interpolation=interpolation, - use_prefetcher=use_prefetcher, mean=mean, std=std, + use_prefetcher=use_prefetcher, + normalize=normalize, ) elif is_training: transform = transforms_imagenet_train( @@ -415,13 +433,14 @@ def create_transform( gaussian_blur_prob=gaussian_blur_prob, auto_augment=auto_augment, interpolation=interpolation, - use_prefetcher=use_prefetcher, mean=mean, std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count, re_num_splits=re_num_splits, + use_prefetcher=use_prefetcher, + normalize=normalize, separate=separate, ) else: @@ -429,12 +448,13 @@ def create_transform( transform = transforms_imagenet_eval( img_size, interpolation=interpolation, - use_prefetcher=use_prefetcher, mean=mean, std=std, crop_pct=crop_pct, crop_mode=crop_mode, crop_border_pixels=crop_border_pixels, + use_prefetcher=use_prefetcher, + normalize=normalize, ) return transform diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 71e45c87..27ee5e70 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -108,7 +108,7 @@ class ClassifierHead(nn.Module): self.fc = fc self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() - def reset(self, num_classes, pool_type=None): + def reset(self, num_classes: int, pool_type: Optional[str] = None): if pool_type is not None and pool_type != self.global_pool.pool_type: self.global_pool, self.fc = create_classifier( self.in_features, @@ -180,7 +180,7 @@ class NormMlpClassifierHead(nn.Module): self.drop = nn.Dropout(drop_rate) self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def reset(self, num_classes, pool_type=None): + def reset(self, num_classes: int, pool_type: Optional[str] = None): if pool_type is not None: self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) self.flatten = nn.Flatten(1) if pool_type else nn.Identity() 
diff --git a/timm/layers/create_norm.py b/timm/layers/create_norm.py index 3c4d287a..fbf58985 100644 --- a/timm/layers/create_norm.py +++ b/timm/layers/create_norm.py @@ -47,7 +47,7 @@ def get_norm_layer(norm_layer): if isinstance(norm_layer, str): if not norm_layer: return None - layer_name = norm_layer.replace('_', '') + layer_name = norm_layer.replace('_', '').lower() norm_layer = _NORM_MAP[layer_name] else: norm_layer = norm_layer diff --git a/timm/layers/patch_dropout.py b/timm/layers/patch_dropout.py index 32dd1519..4428fe04 100644 --- a/timm/layers/patch_dropout.py +++ b/timm/layers/patch_dropout.py @@ -6,7 +6,7 @@ import torch.nn as nn class PatchDropout(nn.Module): """ - https://arxiv.org/abs/2212.00794 + https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220 """ return_indices: torch.jit.Final[bool] diff --git a/timm/models/__init__.py b/timm/models/__init__.py index e558c1a6..ed4df651 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -26,6 +26,7 @@ from .gcvit import * from .ghostnet import * from .hardcorenas import * from .hgnet import * +from .hiera import * from .hrnet import * from .inception_next import * from .inception_resnet_v2 import * diff --git a/timm/models/_builder.py b/timm/models/_builder.py index f248fbd3..7741cf94 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -10,7 +10,8 @@ from torch.hub import load_state_dict_from_url from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet from timm.models._features_fx import FeatureGraphNet from timm.models._helpers import load_state_dict -from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf +from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\ + load_custom_from_hf from timm.models._manipulate import adapt_input_conv from timm.models._pretrained import PretrainedCfg from timm.models._prune import adapt_model_from_file @@ -185,7 +186,12 @@ def load_pretrained( elif load_from == 'hf-hub': _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') if isinstance(pretrained_loc, (list, tuple)): - state_dict = load_state_dict_from_hf(*pretrained_loc) + custom_load = pretrained_cfg.get('custom_load', False) + if isinstance(custom_load, str) and custom_load == 'hf': + load_custom_from_hf(*pretrained_loc, model) + return + else: + state_dict = load_state_dict_from_hf(*pretrained_loc) else: state_dict = load_state_dict_from_hf(pretrained_loc) else: diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index b775871c..3a276046 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -118,6 +118,7 @@ class FeatureGraphNet(nn.Module): out_indices: Tuple[int, ...], out_map: Optional[Dict] = None, output_fmt: str = 'NCHW', + return_dict: bool = False, ): super().__init__() assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' @@ -127,9 +128,13 @@ class FeatureGraphNet(nn.Module): self.output_fmt = Format(output_fmt) return_nodes = _get_return_layers(self.feature_info, out_map) self.graph_module = create_feature_extractor(model, return_nodes) + self.return_dict = return_dict def forward(self, x): - return list(self.graph_module(x).values()) + out = self.graph_module(x) + if self.return_dict: + return out + return list(out.values()) class GraphExtractNet(nn.Module): @@ -144,19 +149,23 @@ class 
GraphExtractNet(nn.Module): model: model to extract features from return_nodes: node names to return features from (dict or list) squeeze_out: if only one output, and output in list format, flatten to single tensor + return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg """ def __init__( self, model: nn.Module, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True, + return_dict: bool = False, ): super().__init__() self.squeeze_out = squeeze_out self.graph_module = create_feature_extractor(model, return_nodes) + self.return_dict = return_dict def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: - out = list(self.graph_module(x).values()) - if self.squeeze_out and len(out) == 1: - return out[0] - return out + out = self.graph_module(x) + if self.return_dict: + return out + out = list(out.values()) + return out[0] if self.squeeze_out and len(out) == 1 else out diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 55ab04bf..a36321bf 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -190,6 +190,13 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME): return torch.load(cached_file, map_location='cpu') +def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module): + assert has_hf_hub(True) + hf_model_id, hf_revision = hf_split(model_id) + cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision) + return model.load_pretrained(cached_file) + + def save_config_for_hf( model, config_path: str, diff --git a/timm/models/beit.py b/timm/models/beit.py index 43048285..63b6db54 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -395,7 +395,7 @@ class Beit(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool diff --git a/timm/models/cait.py b/timm/models/cait.py index bf649076..2d4c7365 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -331,7 +331,7 @@ class Cait(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'token', 'avg') diff --git a/timm/models/coat.py b/timm/models/coat.py index 68358b3d..3e7b9c7a 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,8 +7,7 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from functools import partial -from typing import Tuple, List, Union +from typing import List, Optional, Union, Tuple import torch import torch.nn as nn @@ -560,7 +559,7 @@ class CoaT(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('token', 'avg') diff --git a/timm/models/convit.py b/timm/models/convit.py index 6cfcae27..fb42baa0 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -21,8 +21,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W '''These modules are adapted from those of timm, see 
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ''' - -from functools import partial +from typing import Optional import torch import torch.nn as nn @@ -349,7 +348,7 @@ class ConVit(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'token', 'avg') diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index 854f84a0..3e43dd66 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -1,6 +1,8 @@ """ ConvMixer """ +from typing import Optional + import torch import torch.nn as nn @@ -75,7 +77,7 @@ class ConvMixer(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index f10f6c7b..012a73c9 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -37,7 +37,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)) # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially. -from collections import OrderedDict from functools import partial from typing import Callable, List, Optional, Tuple, Union diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 5c90aec9..27d75808 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -25,8 +25,7 @@ Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master """ from functools import partial -from typing import List -from typing import Tuple +from typing import List, Optional, Tuple import torch import torch.hub @@ -419,7 +418,7 @@ class CrossVit(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('token', 'avg') diff --git a/timm/models/davit.py b/timm/models/davit.py index 442ca620..dceda60e 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # All rights reserved. # This source code is licensed under the MIT license from functools import partial -from typing import Tuple +from typing import Optional, Tuple import torch import torch.nn as nn @@ -568,7 +568,7 @@ class DaVit(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, global_pool) def forward_features(self, x): diff --git a/timm/models/deit.py b/timm/models/deit.py index 9400549d..96770beb 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -11,7 +11,7 @@ Modifications copyright 2021, Ross Wightman # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. 
from functools import partial -from typing import Sequence, Union +from typing import Optional import torch from torch import nn as nn @@ -20,7 +20,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import resample_abs_pos_embed from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn from ._builder import build_model_with_cfg -from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations __all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this @@ -64,7 +63,7 @@ class VisionTransformerDistilled(VisionTransformer): def get_classifier(self): return self.head, self.head_dist - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 661669d5..515bc225 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -8,7 +8,6 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt Modifications and additions for timm by / Copyright 2022, Ross Wightman """ import math -from collections import OrderedDict from functools import partial from typing import Tuple @@ -17,7 +16,7 @@ import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \ +from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \ use_fused_attn, NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 798f6435..c28538bc 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -449,7 +449,7 @@ class EfficientFormer(nn.Module): def get_classifier(self): return self.head, self.head_dist - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index 8f76bed3..ba3b7c5f 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman """ import math from functools import partial -from typing import Dict +from typing import Dict, Optional import torch import torch.nn as nn @@ -612,7 +612,7 @@ class EfficientFormerV2(nn.Module): def get_classifier(self): return self.head, self.head_dist - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 1960d3d2..c971fe61 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -13,7 +13,6 @@ from functools import 
partial import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.modules.batchnorm import _BatchNorm from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh @@ -740,7 +739,7 @@ class EfficientVit(nn.Module): def get_classifier(self): return self.head.classifier[-1] - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool @@ -858,7 +857,7 @@ class EfficientVitLarge(nn.Module): def get_classifier(self): return self.head.classifier[-1] - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 1b7f52a0..deaf1fba 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic __all__ = ['EfficientVitMsra'] import itertools from collections import OrderedDict -from typing import Dict +from typing import Dict, Optional import torch import torch.nn as nn @@ -464,7 +464,7 @@ class EfficientVitMsra(nn.Module): def get_classifier(self): return self.head.linear - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: if global_pool == 'avg': diff --git a/timm/models/eva.py b/timm/models/eva.py index d7763fe1..d424ab3d 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -539,7 +539,7 @@ class Eva(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 67961880..74b6cc28 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -396,7 +396,7 @@ class ReparamLargeKernelConv(nn.Module): @staticmethod def _fuse_bn( - conv: torch.Tensor, bn: nn.BatchNorm2d + conv: nn.Conv2d, bn: nn.BatchNorm2d ) -> Tuple[torch.Tensor, torch.Tensor]: """Method to fuse batchnorm layer with conv layer. 
@@ -1232,7 +1232,7 @@ class FastVit(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 07410da4..1624340b 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -454,7 +454,7 @@ class FocalNet(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x): diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 653bc370..16f93bf4 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -489,7 +489,7 @@ class GlobalContextVit(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is None: global_pool = self.head.global_pool.pool_type diff --git a/timm/models/hiera.py b/timm/models/hiera.py new file mode 100644 index 00000000..f229daf4 --- /dev/null +++ b/timm/models/hiera.py @@ -0,0 +1,936 @@ +""" An PyTorch implementation of Hiera + +Adapted for timm from originals at https://github.com/facebookresearch/hiera +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# +# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles +# +# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan, +# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed, +# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer. +# +# Paper: https://arxiv.org/abs/2306.00989/ +# +# References: +# slowfast: https://github.com/facebookresearch/SlowFast +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# -------------------------------------------------------- +import math +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer + + +from ._registry import generate_default_cfgs, register_model +from ._builder import build_model_with_cfg +from ._features import feature_take_indices +from ._features_fx import register_notrace_function + + +def conv_nd(n: int) -> Type[nn.Module]: + """ + Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3. + If you wanted a 4d Hiera, you could probably just implement this for n=4. 
(no promises) + """ + return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n] + + +@register_notrace_function +def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor: + # target_size: [(T), (H), W] + # (spatial) mask: [B, C, (t), (h), w] + if mask is None: + return mask + + _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.") + if mask.shape[2:] != target_size: + return F.interpolate(mask.float(), size=target_size) + return mask + + +def undo_windowing( + x: torch.Tensor, + shape: List[int], + mu_shape: List[int], +) -> torch.Tensor: + """ + Restore spatial organization by undoing windowed organization of mask units. + + Args: + x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C] + shape: current spatial shape, if it were not organized into mask unit + windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C]. + mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx] + Returns: + x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C] + """ + D = len(shape) + B, C = x.shape[0], x.shape[-1] + # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C] + num_MUs = [s // mu for s, mu in zip(shape, mu_shape)] + x = x.view(B, *num_MUs, *mu_shape, C) + + # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C] + permute = ( + [0] + + sum([list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))], []) + + [len(x.shape) - 1] + ) + x = x.permute(permute).reshape(B, *shape, C) + + return x + + +class Unroll(nn.Module): + """ + Reorders the tokens such that patches are contiguous in memory. + E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as + [B, (Sy, Sx, H // Sy, W // Sx), C] + + This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1). + Not only is this faster, but it also makes it easy to support inputs of arbitrary + dimensions in addition to patch-wise sparsity. + + Performing this operation multiple times in sequence puts entire windows as contiguous + in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of + size 8x8 would be contiguous in memory, allowing operations like mask unit attention + computed easily and efficiently, while also allowing max to be applied sequentially. + + Note: This means that intermediate values of the model are not in HxW order, so they + need to be re-rolled if you want to use the intermediate values as a HxW feature map. + The last block of the network is fine though, since by then the strides are all consumed. + """ + + def __init__( + self, + input_size: Tuple[int, ...], + patch_stride: Tuple[int, ...], + unroll_schedule: List[Tuple[int, ...]], + ): + super().__init__() + self.size = [i // s for i, s in zip(input_size, patch_stride)] + self.schedule = unroll_schedule + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Input: Flattened patch embeddings [B, N, C] + Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. 
performs MaxPoolNd + """ + B, _, C = x.shape + cur_size = self.size + x = x.view(*([B] + cur_size + [C])) + + for strides in self.schedule: + # Move patches with the given strides to the batch dimension + + # Create a view of the tensor with the patch stride as separate dims + # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C] + cur_size = [i // s for i, s in zip(cur_size, strides)] + new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C] + x = x.view(new_shape) + + # Move the patch stride into the batch dimension + # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C] + L = len(new_shape) + permute = [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1] + x = x.permute(permute) + + # Now finally flatten the relevant dims into the batch dimension + x = x.flatten(0, len(strides)) + B *= math.prod(strides) + + x = x.reshape(-1, math.prod(self.size), C) + return x + + +class Reroll(nn.Module): + """ + Undos the "unroll" operation so that you can use intermediate features. + """ + + def __init__( + self, + input_size: Tuple[int, ...], + patch_stride: Tuple[int, ...], + unroll_schedule: List[Tuple[int, ...]], + stage_ends: List[int], + q_pool: int, + ): + super().__init__() + self.size = [i // s for i, s in zip(input_size, patch_stride)] + + # The first stage has to reverse everything + # The next stage has to reverse all but the first unroll, etc. + self.schedule = {} + size = self.size + for i in range(stage_ends[-1] + 1): + self.schedule[i] = unroll_schedule, size + # schedule unchanged if no pooling at a stage end + if i in stage_ends[:q_pool]: + if len(unroll_schedule) > 0: + size = [n // s for n, s in zip(size, unroll_schedule[0])] + unroll_schedule = unroll_schedule[1:] + + def forward( + self, + x: torch.Tensor, + block_idx: int, + mask: torch.Tensor = None + ) -> torch.Tensor: + """ + Roll the given tensor back up to spatial order assuming it's from the given block. + + If no mask is provided: + - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc. + If a mask is provided: + - Returns [B, #MUs, MUy, MUx, C] for 2d, etc. + """ + schedule, size = self.schedule[block_idx] + B, N, C = x.shape + + D = len(size) + cur_mu_shape = [1] * D + + for strides in schedule: + # Extract the current patch from N + x = x.view(B, *strides, N // math.prod(strides), *cur_mu_shape, C) + + # Move that patch into the current MU + # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C] + L = len(x.shape) + permute = ( + [0, 1 + D] + + sum([list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))], []) + + [L - 1] + ) + x = x.permute(permute) + + # Reshape to [B, N//(Sy*Sx), *MU, C] + for i in range(D): + cur_mu_shape[i] *= strides[i] + x = x.reshape(B, -1, *cur_mu_shape, C) + N = x.shape[1] + + # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C]) + x = x.view(B, N, *cur_mu_shape, C) + + # If masked, return [B, #MUs, MUy, MUx, C] + if mask is not None: + return x + + # If not masked, we can return [B, H, W, C] + x = undo_windowing(x, size, cur_mu_shape) + + return x + + +class MaskUnitAttention(nn.Module): + """ + Computes either Mask Unit or Global Attention. Also is able to perform q pooling. + + Note: this assumes the tokens have already been flattened and unrolled into mask units. + See `Unroll` for more details. 
+ """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + dim_out: int, + heads: int, + q_stride: int = 1, + window_size: int = 0, + use_mask_unit_attn: bool = False, + ): + """ + Args: + - dim, dim_out: The input and output feature dimensions. + - heads: The number of attention heads. + - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4). + - window_size: The current (flattened) size of a mask unit *after* pooling (if any). + - use_mask_unit_attn: Use Mask Unit or Global Attention. + """ + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.heads = heads + self.q_stride = q_stride + self.head_dim = dim_out // heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, 3 * dim_out) + self.proj = nn.Linear(dim_out, dim_out) + + self.window_size = window_size + self.use_mask_unit_attn = use_mask_unit_attn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Input should be of shape [batch, tokens, channels]. """ + B, N, _ = x.shape + num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1 + + qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5) + q, k, v = qkv.unbind(0) + + if self.q_stride > 1: + # Refer to Unroll to see how this performs a maxpool-Nd + q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3) + + if self.fused_attn: + # Note: the original paper did *not* use SDPA, it's a free boost! + x = F.scaled_dot_product_attention(q, k, v) + else: + attn = (q * self.scale) @ k.transpose(-1, -2) + attn = attn.softmax(dim=-1) + x = attn @ v + + x = x.transpose(1, 3).reshape(B, -1, self.dim_out) + x = self.proj(x) + return x + + +class HieraBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + act_layer: nn.Module = nn.GELU, + q_stride: int = 1, + window_size: int = 0, + use_expand_proj: bool = True, + use_mask_unit_attn: bool = False, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.norm1 = norm_layer(dim) + if dim != dim_out: + self.do_expand = True + if use_expand_proj: + self.proj = nn.Linear(dim, dim_out) + else: + assert dim_out == dim * 2 + self.proj = None + else: + self.do_expand = False + self.proj = None + self.attn = MaskUnitAttention( + dim, + dim_out, + heads, + q_stride, + window_size, + use_mask_unit_attn + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer) + self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity() + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Attention + Q Pooling + x_norm = self.norm1(x) + if self.do_expand: + if self.proj is not None: + x = self.proj(x_norm) + x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1) # max-pool + else: + x = torch.cat([ + x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1), # max-pool + x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1), # avg-pool + ], + dim=-1, + ) + x = x + self.drop_path1(self.attn(x_norm)) + + # MLP + x = x + self.drop_path2(self.mlp(self.norm2(x))) + return x + + +class NormClassifierHead(nn.Module): + def __init__( + self, + in_features: int, + num_classes: int, + pool_type: 
str = 'avg', + drop_rate: float = 0.0, + norm_layer: Union[str, Callable] = 'layernorm', + ): + super().__init__() + norm_layer = get_norm_layer(norm_layer) + assert pool_type in ('avg', '') + self.in_features = self.num_features = in_features + self.pool_type = pool_type + self.norm = norm_layer(in_features) + self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity() + self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity() + + def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False): + if pool_type is not None: + assert pool_type in ('avg', '') + self.pool_type = pool_type + if other: + # reset other non-fc layers + self.norm = nn.Identity() + self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + if self.pool_type == 'avg': + x = x.mean(dim=1) + x = self.norm(x) + x = self.drop(x) + if pre_logits: + return x + x = self.fc(x) + return x + + +class PatchEmbed(nn.Module): + """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d).""" + + def __init__( + self, + dim_in: int, + dim_out: int, + kernel: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + reshape: bool = True, + ): + super().__init__() + + # Support any number of spatial dimensions + self.spatial_dims = len(kernel) + self.reshape = reshape + self.proj = conv_nd(self.spatial_dims)( + dim_in, + dim_out, + kernel_size=kernel, + stride=stride, + padding=padding, + ) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if mask is not None: + mask = get_resized_mask(target_size=x.shape[2:], mask=mask) + x = self.proj(x * mask.to(torch.bool)) + else: + x = self.proj(x) + if self.reshape: + x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1) + return x + + +class Hiera(nn.Module): + + def __init__( + self, + img_size: Tuple[int, ...] = (224, 224), + in_chans: int = 3, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + num_classes: int = 1000, + global_pool: str = 'avg', + stages: Tuple[int, ...] = (2, 3, 16, 3), + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, ...] = (2, 2), + mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1) + # mask_unit_attn: which stages use mask unit attention? + mask_unit_attn: Tuple[bool, ...] = (True, True, False, False), + dim_mul: float = 2.0, + head_mul: float = 2.0, + patch_kernel: Tuple[int, ...] = (7, 7), + patch_stride: Tuple[int, ...] = (4, 4), + patch_padding: Tuple[int, ...] 
= (3, 3), + mlp_ratio: float = 4.0, + drop_path_rate: float = 0.0, + norm_layer: Union[str, nn.Module] = "LayerNorm", + drop_rate: float = 0.0, + head_init_scale: float = 0.001, + sep_pos_embed: bool = False, + ): + super().__init__() + self.num_classes = num_classes + self.grad_checkpointing = False + norm_layer = get_norm_layer(norm_layer) + + self.patch_stride = patch_stride + self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)] + num_tokens = math.prod(self.tokens_spatial_shape) + flat_mu_size = math.prod(mask_unit_size) + flat_q_stride = math.prod(q_stride) + assert q_pool < len(stages) + self.q_pool, self.q_stride = q_pool, q_stride + self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size + self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)] + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + + self.patch_embed = PatchEmbed( + in_chans, + embed_dim, + patch_kernel, + patch_stride, + patch_padding, + #reshape=False, # leave spatial / temporal dims in output + ) + + if sep_pos_embed: + self.pos_embed = None + self.pos_embed_spatial = nn.Parameter( + torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim) + ) + self.pos_embed_temporal = nn.Parameter( + torch.zeros(1, self.tokens_spatial_shape[0], embed_dim) + ) + else: + self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim)) + self.pos_embed_spatial = None + self.pos_embed_temporal = None + + # Setup roll and reroll modules + self.unroll = Unroll( + img_size, + patch_stride, + [q_stride] * len(self.stage_ends[:-1]) + ) + self.reroll = Reroll( + img_size, + patch_stride, + [q_stride] * len(self.stage_ends[:-1]), + self.stage_ends, + q_pool, + ) + # q_pool locations + q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]] + + # Transformer blocks + cur_stage = 0 + depth = sum(stages) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList() + self.feature_info = [] + for i in range(depth): + dim_out = embed_dim + # Mask unit or global attention. 
+ # Lag by 1 block, so that global attention, + # applied post pooling on lower resolution + use_mask_unit_attn = mask_unit_attn[cur_stage] + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + if i in q_pool_blocks: + flat_mu_size //= flat_q_stride + + block = HieraBlock( + dim=embed_dim, + dim_out=dim_out, + heads=num_heads, + mlp_ratio=mlp_ratio, + drop_path=dpr[i], + norm_layer=norm_layer, + q_stride=(flat_q_stride if i in q_pool_blocks else 1), + window_size=flat_mu_size, + use_mask_unit_attn=use_mask_unit_attn, + ) + embed_dim = dim_out + if i in self.stage_ends: + self.feature_info += [ + dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')] + self.blocks.append(block) + + self.num_features = embed_dim + self.head = NormClassifierHead( + embed_dim, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + norm_layer=norm_layer, + ) + + # Initialize everything + if sep_pos_embed: + nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02) + nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02) + else: + nn.init.trunc_normal_(self.pos_embed, std=0.02) + self.apply(partial(self._init_weights)) + if isinstance(self.head.fc, nn.Linear): + self.head.fc.weight.data.mul_(head_init_scale) + self.head.fc.bias.data.mul_(head_init_scale) + + def _init_weights(self, m, init_bias=0.02): + if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, init_bias) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, init_bias) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + if self.pos_embed is not None: + return ["pos_embed"] + else: + return ["pos_embed_spatial", "pos_embed_temporal"] + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool, other=other) + + def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor: + """ + Generates a random mask, mask_ratio fraction are dropped. + 1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc. 
+ """ + B = x.shape[0] + # Tokens selected for masking at mask unit level + num_windows = math.prod(self.mask_spatial_shape) # num_mask_units + len_keep = int(num_windows * (1 - mask_ratio)) + noise = torch.rand(B, num_windows, device=x.device) + + # Sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # Generate the binary mask: 1 is *keep*, 0 is *remove* + # Note this is opposite to original MAE + mask = torch.zeros([B, num_windows], device=x.device) + mask[:, :len_keep] = 1 + # Unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return mask.bool() + + def _pos_embed(self, x) -> torch.Tensor: + if self.pos_embed is not None: + pos_embed = self.pos_embed + else: + pos_embed = ( + self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1) + + + torch.repeat_interleave( + self.pos_embed_temporal, + self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], + dim=1, + ) + ) + x = x + pos_embed + return x + + def forward_intermediates( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert not norm, 'normalization of features not supported' + assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.' + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + + if mask is not None: + patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape + else: + patch_mask = None + x = self.patch_embed(x, mask=patch_mask) + x = self._pos_embed(x) + x = self.unroll(x) + + # Discard masked tokens + if mask is not None: + x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1]) + + intermediates = [] + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2)) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_head: + self.head.reset(0, other=True) + return take_indices + + + def forward_features( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + return_intermediates: bool = False, + ) -> torch.Tensor: + """ + mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim. + Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch. + """ + if mask is not None: + patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape + else: + patch_mask = None + x = self.patch_embed(x, mask=patch_mask) + x = self._pos_embed(x) + x = self.unroll(x) + + # Discard masked tokens + if mask is not None: + x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1]) + + intermediates = [] + for i, blk in enumerate(self.blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x) + else: + x = blk(x) + if return_intermediates and i in self.stage_ends: + intermediates.append(self.reroll(x, i, mask=mask)) + + # x may not always be in spatial order here. + # e.g. if q_pool = 2, mask_unit_size = (8, 8), and + # q_stride = (2, 2), not all unrolls were consumed, + # intermediates[-1] is x in spatial order + if return_intermediates: + return x, intermediates + + return x + + def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor: + x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) + return x + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = self.forward_features(x, mask=mask) + if mask is None: + x = self.forward_head(x) + return x + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + **kwargs + } + +default_cfgs = generate_default_cfgs({ + "hiera_tiny_224.mae_in1k_ft_in1k": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + ), + "hiera_tiny_224.mae": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + num_classes=0, + ), + + "hiera_small_224.mae_in1k_ft_in1k": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + ), + "hiera_small_224.mae": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + num_classes=0, + ), + + "hiera_base_224.mae_in1k_ft_in1k": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + ), + "hiera_base_224.mae": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + num_classes=0, + ), + + "hiera_base_plus_224.mae_in1k_ft_in1k": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + ), + "hiera_base_plus_224.mae": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + num_classes=0, + ), + + "hiera_large_224.mae_in1k_ft_in1k": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + ), + "hiera_large_224.mae": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + num_classes=0, + ), + + "hiera_huge_224.mae_in1k_ft_in1k": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + ), + "hiera_huge_224.mae": _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + num_classes=0, + ), +}) + +def checkpoint_filter_fn(state_dict, model=None): + 
+    state_dict = state_dict.get('model_state', state_dict)
+    output = {}
+    for k, v in state_dict.items():
+        if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
+            # # To resize pos embedding when using model at different size from pretrained weights
+            # from timm.layers import resample_abs_pos_embed
+            # v = resample_abs_pos_embed(
+            #     v,
+            #     new_size=(64, 64),
+            #     num_prefix_tokens=0,
+            #     verbose=True,
+            # )
+            # v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
+            pass
+        if 'head.projection.' in k:
+            k = k.replace('head.projection.', 'head.fc.')
+        if k.startswith('encoder_norm.'):
+            k = k.replace('encoder_norm.', 'head.norm.')
+        elif k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm.')
+        output[k] = v
+    return output
+
+
+def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
+    out_indices = kwargs.pop('out_indices', 4)
+
+    return build_model_with_cfg(
+        Hiera,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
+
+
+@register_model
+def hiera_tiny_224(pretrained: bool = False, **kwargs) -> Hiera:
+    model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
+    return _create_hiera('hiera_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def hiera_small_224(pretrained: bool = False, **kwargs) -> Hiera:
+    model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2))
+    return _create_hiera('hiera_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def hiera_base_224(pretrained: bool = False, **kwargs) -> Hiera:
+    model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
+    return _create_hiera('hiera_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def hiera_base_plus_224(pretrained: bool = False, **kwargs) -> Hiera:
+    model_args = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
+    return _create_hiera('hiera_base_plus_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def hiera_large_224(pretrained: bool = False, **kwargs) -> Hiera:
+    model_args = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
+    return _create_hiera('hiera_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def hiera_huge_224(pretrained: bool = False, **kwargs) -> Hiera:
+    model_args = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
+    return _create_hiera('hiera_huge_224', pretrained=pretrained, **dict(model_args, **kwargs))
diff --git a/timm/models/levit.py b/timm/models/levit.py
index 037cae6e..ccac445c 100644
--- a/timm/models/levit.py
+++ b/timm/models/levit.py
@@ -628,7 +628,7 @@ class Levit(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
@@ -730,7 +730,7 @@ class LevitDistilled(Levit):
     def get_classifier(self):
         return self.head, self.head_dist
 
-    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index 0be7b9b3..3dc08a55 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -1248,7 +1248,7 @@ class MaxxVit(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
index b775b736..a1bf02be 100644
--- a/timm/models/mlp_mixer.py
+++ b/timm/models/mlp_mixer.py
@@ -255,7 +255,7 @@ class MlpMixer(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg')
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 0846e191..ad6d8a85 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -622,8 +622,6 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
     return model
 
 
-
-
 def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
     """Creates a MobileNet-V4 model.
 
diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py
index 5ad013e4..d6ba311c 100644
--- a/timm/models/mvitv2.py
+++ b/timm/models/mvitv2.py
@@ -825,7 +825,7 @@ class MultiScaleVit(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py
index 7ef56a38..80857ed1 100644
--- a/timm/models/nextvit.py
+++ b/timm/models/nextvit.py
@@ -6,6 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
 """
 # Copyright (c) ByteDance Inc. All rights reserved.
 from functools import partial
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -553,7 +554,7 @@ class NextViT(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
diff --git a/timm/models/pit.py b/timm/models/pit.py
index 993606d5..ce41b9fc 100644
--- a/timm/models/pit.py
+++ b/timm/models/pit.py
@@ -14,13 +14,13 @@ Modifications for timm by / Copyright 2020 Ross Wightman
 import math
 import re
 from functools import partial
-from typing import Sequence, Tuple
+from typing import Optional, Sequence, Tuple
 
 import torch
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, to_2tuple, LayerNorm
+from timm.layers import trunc_normal_, to_2tuple
 from ._builder import build_model_with_cfg
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Block
@@ -246,7 +246,7 @@ class PoolingVisionTransformer(nn.Module):
         else:
             return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
         if self.head_dist is not None:
diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py
index 16302002..1d9c6842 100644
--- a/timm/models/pvt_v2.py
+++ b/timm/models/pvt_v2.py
@@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
 """
 import math
-from typing import Tuple, List, Callable, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -379,7 +379,7 @@ class PyramidVisionTransformerV2(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('avg', '')
diff --git a/timm/models/repvit.py b/timm/models/repvit.py
index b4b2f46d..00ad78c8 100644
--- a/timm/models/repvit.py
+++ b/timm/models/repvit.py
@@ -16,15 +16,16 @@ Adapted from official impl at https://github.com/jameslahm/RepViT
 """
 
 __all__ = ['RepVit']
-
-import torch.nn as nn
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from ._registry import register_model, generate_default_cfgs
-from ._builder import build_model_with_cfg
-from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
-from ._manipulate import checkpoint_seq
+from typing import Optional
 
 import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
 
 
 class ConvNorm(nn.Sequential):
@@ -322,7 +323,7 @@ class RepVit(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None, distillation=False):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=False):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
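The `reset_classifier()` hunks in the surrounding model files all converge on the same annotated signature, `reset_classifier(num_classes: int, global_pool: Optional[str] = None)`. A minimal sketch of how that shared call is typically exercised (the model name below is illustrative only, not one of the files touched here):

```python
import timm

# Any registered timm model exposes reset_classifier(); resnet50 is just an example.
model = timm.create_model('resnet50', pretrained=False, num_classes=1000)

# Swap in a 10-class head; pooling is left unchanged when global_pool is None.
model.reset_classifier(num_classes=10)

# Drop the classifier entirely to use the model as a pooled feature extractor.
model.reset_classifier(0, global_pool='avg')
```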
diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py
index cb0f15f3..7e12453b 100644
--- a/timm/models/sequencer.py
+++ b/timm/models/sequencer.py
@@ -9,7 +9,7 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2
 import math
 from functools import partial
 from itertools import accumulate
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -419,7 +419,7 @@ class Sequencer2d(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 6614e4ad..d5369282 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -604,7 +604,7 @@ class SwinTransformer(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index 6bf2d767..49d449c9 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -605,7 +605,7 @@ class SwinTransformerV2(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py
index b4b29648..85eee7e0 100644
--- a/timm/models/tiny_vit.py
+++ b/timm/models/tiny_vit.py
@@ -8,10 +8,9 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyV
 __all__ = ['TinyVit']
 
-import math
 import itertools
 from functools import partial
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -533,7 +532,7 @@ class TinyVit(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
diff --git a/timm/models/tnt.py b/timm/models/tnt.py
index c3590187..00ab2ba7 100644
--- a/timm/models/tnt.py
+++ b/timm/models/tnt.py
@@ -7,6 +7,7 @@ The official mindspore code is released and available at
 https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
 """
 import math
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -298,7 +299,7 @@ class TNT(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'token', 'avg')
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index 33d375f7..c00a28e2 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -7,6 +7,7 @@ Original model: https://github.com/mrT23/TResNet
 """
 from collections import OrderedDict
 from functools import partial
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -233,7 +234,7 @@ class TResNet(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
diff --git a/timm/models/twins.py b/timm/models/twins.py
index 8e898f9f..c7b9e157 100644
--- a/timm/models/twins.py
+++ b/timm/models/twins.py
@@ -382,7 +382,7 @@ class Twins(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg')
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 5fbabb59..9c95ae4b 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     """
     import numpy as np
 
-    def _n2p(w, t=True):
+    def _n2p(w, t=True, idx=None):
+        if idx is not None:
+            w = w[idx]
         if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
             w = w.flatten()
         if t:
@@ -955,21 +957,28 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
+            block_prefix = f'{prefix}Transformer/encoderblock/'
+            idx = i
+        else:
+            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+            idx = None
         mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
         block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
         block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
         for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
 
 
 def _convert_openai_clip(
@@ -1769,6 +1778,73 @@ default_cfgs = {
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_base_patch16_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 512, 512),
+        num_classes=0),
+    'vit_large_patch16_siglip_gap_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_large_patch16_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-224-jax',
+        hf_hub_filename='paligemma-3b-mix-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-224-jax',
+        hf_hub_filename='paligemma-3b-pt-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-448-jax',
+        hf_hub_filename='paligemma-3b-mix-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-448-jax',
+        hf_hub_filename='paligemma-3b-pt-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-896-jax',
+        hf_hub_filename='paligemma-3b-pt-896.npz',
+        custom_load='hf',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+
     'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -1791,6 +1867,7 @@ default_cfgs = {
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
@@ -1801,9 +1878,16 @@ default_cfgs = {
     'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
@@ -2374,7 +2458,6 @@ def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V
 def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
     """
-    from timm.layers import get_act_layer
     model_args = dict(
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm,
         act_layer='quick_gelu')
@@ -2756,15 +2839,118 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
     return model
 
 
-# @register_model
-# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
-#     model_args = dict(
-#         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-#         no_embed_class=True, reg_tokens=4,
-#     )
-#     model = _create_vision_transformer(
-#         'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
-#     return model
+@register_model
+def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
 
 
 @register_model
diff --git a/timm/models/volo.py b/timm/models/volo.py
index a9ff905c..cefabd0e 100644
--- a/timm/models/volo.py
+++ b/timm/models/volo.py
@@ -622,7 +622,7 @@ class VOLO(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
diff --git a/timm/utils/attention_extract.py b/timm/utils/attention_extract.py
index 90021018..e813d42a 100644
--- a/timm/utils/attention_extract.py
+++ b/timm/utils/attention_extract.py
@@ -1,4 +1,5 @@
 import fnmatch
+import re
 from collections import OrderedDict
 from typing import Union, Optional, List
@@ -17,6 +18,7 @@ class AttentionExtract(torch.nn.Module):
             mode: str = 'eval',
             method: str = 'fx',
             hook_type: str = 'forward',
+            use_regex: bool = False,
     ):
         """ Extract attention maps (or other activations) from a model by name.
@@ -26,6 +28,7 @@ class AttentionExtract(torch.nn.Module):
             mode: 'train' or 'eval' model mode.
             method: 'fx' or 'hook' extraction method.
            hook_type: 'forward' or 'forward_pre' hooks used.
+            use_regex: Use regex instead of fnmatch
         """
         super().__init__()
         assert mode in ('train', 'eval')
@@ -40,14 +43,16 @@ class AttentionExtract(torch.nn.Module):
             from timm.models._features_fx import get_graph_node_names, GraphExtractNet
             node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
 
-            matched = []
             names = names or self.default_node_names
-            for n in names:
-                matched.extend(fnmatch.filter(node_names, n))
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [g for g in node_names if any([r.match(g) for r in regexes])]
+            else:
+                matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No node names found matching {names}.')
 
-            self.model = GraphExtractNet(model, matched)
+            self.model = GraphExtractNet(model, matched, return_dict=True)
             self.hooks = None
         else:
             # names are module names
@@ -55,10 +60,12 @@ class AttentionExtract(torch.nn.Module):
             from timm.models._features import FeatureHooks
 
             module_names = [n for n, m in model.named_modules()]
-            matched = []
             names = names or self.default_module_names
-            for n in names:
-                matched.extend(fnmatch.filter(module_names, n))
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [m for m in module_names if any([r.match(m) for r in regexes])]
+            else:
+                matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No module names found matching {names}.')
 
@@ -75,5 +82,4 @@ class AttentionExtract(torch.nn.Module):
             output = self.hooks.get_output(device=x.device)
         else:
             output = self.model(x)
-            output = OrderedDict(zip(self.names, output))
         return output
diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 286e8ba4..18f526bb 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -108,9 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm builds map the 'nccl' backend to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
@@ -150,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
diff --git a/timm/version.py b/timm/version.py
index c6092d3e..fb88fd21 100644
--- a/timm/version.py
+++ b/timm/version.py
@@ -1 +1 @@
-__version__ = '1.0.1.dev0'
+__version__ = '1.0.4.dev0'
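Taken together with the Hiera additions earlier in this diff, a minimal usage sketch (model names come from the `default_cfgs` added above; shapes in comments are indicative, and `pretrained=False` is used so no weights are downloaded):

```python
import torch
import timm

# The Hiera variants registered above behave like any other timm model.
model = timm.create_model('hiera_tiny_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)
logits = model(x)  # (1, 1000) with the default classification head

# forward_intermediates() returns final features plus per-stage feature maps in NCHW format.
final_feat, intermediates = model.forward_intermediates(x, indices=(-2, -1))
for feat in intermediates:
    print(feat.shape)

# The same stage outputs back the features_only API wired up via feature_cls='getter'.
feature_model = timm.create_model('hiera_tiny_224', features_only=True, out_indices=(1, 2, 3))
feature_maps = feature_model(x)
```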