mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Merge remote-tracking branch 'origin/main' into efficientnet_x

This commit is contained in: commit cee79dada0
@@ -26,6 +26,12 @@
* The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include links to papers, original source, and license.
* Previous 0.6.x can be cloned from the [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with a pinned version.

### May 14, 2024
* Support loading PaliGemma jax weights into SigLIP ViT models with average pooling.
* Add Hiera models from Meta (https://github.com/facebookresearch/hiera).
* Add `normalize=` flag for transforms, return non-normalized torch.Tensor with original dtype (for `chug`)
* Version 1.0.3 release

### May 11, 2024
* `Searching for Better ViT Baselines (For the GPU Poor)` weights and vit variants released. Exploring model shapes between Tiny and Base.
@@ -42,6 +48,7 @@
| [vit_medium_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_reg4_gap_256.sbb_in1k) | 83.47 | 96.622 | 38.88 | 256 |
| [vit_medium_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_reg1_gap_256.sbb_in1k) | 83.462 | 96.548 | 38.88 | 256 |
| [vit_little_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_little_patch16_reg4_gap_256.sbb_in1k) | 82.514 | 96.262 | 22.52 | 256 |
| [vit_wee_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_wee_patch16_reg1_gap_256.sbb_in1k) | 80.256 | 95.360 | 13.42 | 256 |
| [vit_pwee_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_pwee_patch16_reg1_gap_256.sbb_in1k) | 80.072 | 95.136 | 15.25 | 256 |
| [vit_mediumd_patch16_reg4_gap_256.sbb_in12k](https://huggingface.co/timm/vit_mediumd_patch16_reg4_gap_256.sbb_in12k) | N/A | N/A | 64.11 | 256 |
| [vit_betwixt_patch16_reg4_gap_256.sbb_in12k](https://huggingface.co/timm/vit_betwixt_patch16_reg4_gap_256.sbb_in12k) | N/A | N/A | 60.4 | 256 |
@@ -192,9 +192,9 @@ There are two additional creation arguments impacting the output features.

#### Output index selection

-The `out_indices` argument is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most convnet models, index 0 is the stride 2 features, and index 4 is stride 32. For many ViT or ViT-Conv hybrids there may be many to all features maps of the same shape, or a combination of hierarchical and non-hieararchical feature maps. It is best to look at the `feature_info` attribute to see the number of features, their corresponding channel count and reduction level.
+The `out_indices` argument is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most convnet models, index 0 is the stride 2 features, and index 4 is stride 32. For many ViT or ViT-Conv hybrids there may be many to all features maps of the same shape, or a combination of hierarchical and non-hierarchical feature maps. It is best to look at the `feature_info` attribute to see the number of features, their corresponding channel count and reduction level.

-`out_indices` supports negative indexing, this makes it easy to get the last, penunltimate, etc feature map. `out_indices=(-2,)` would return the penultimate feature map for any model.
+`out_indices` supports negative indexing, this makes it easy to get the last, penultimate, etc feature map. `out_indices=(-2,)` would return the penultimate feature map for any model.
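A minimal sketch of index selection (model name chosen only for illustration):

```python
import timm

# features_only wraps the model so forward() returns a list of feature maps
m = timm.create_model('resnet50', features_only=True, out_indices=(-2,))
print(m.feature_info.channels())   # channel counts for the selected indices
print(m.feature_info.reduction())  # their stride / reduction factors
```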

#### Output stride (feature map dilation)
@@ -228,7 +228,7 @@ Accompanying the `forward_intermediates` function is a `prune_intermediate_layer

An `indices` argument is used for both `forward_intermediates()` and `prune_intermediate_layers()` to select the features to return or layers to remove. As with the `out_indices` for `features_only` API, `indices` is model specific and selects which intermediates are returned.

-In non-hierarchical block based models such as ViT the indices correspond to the blocks, in models with hierarchical stages they usually correspond to the output of the stem + each hierarhical stage. Both positive (from the start), and negative (relative to the end) indexing works, and `None` is used to return all intermediates.
+In non-hierarchical block based models such as ViT the indices correspond to the blocks, in models with hierarchical stages they usually correspond to the output of the stem + each hierarchical stage. Both positive (from the start), and negative (relative to the end) indexing works, and `None` is used to return all intermediates.

The `prune_intermediate_layers()` call returns an indices variable, as negative indices must be converted to absolute (positive) indices when the model is trimmed.
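A minimal sketch of that flow (model name illustrative):

```python
import timm
import torch

model = timm.create_model('vit_base_patch16_224')
x = torch.randn(1, 3, 224, 224)

# return the last two block outputs alongside the final features
final, intermediates = model.forward_intermediates(x, indices=(-2, -1))

# trim the model down to those layers; returned indices are now absolute
indices = model.prune_intermediate_layers(indices=(-2, -1), prune_head=True)
```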
@@ -28,7 +28,7 @@ You should install `timm` in a [virtual environment](https://docs.python.org/3/l

# Deactivate the virtual environment
source .env/bin/deactivate
```

Once you've created your virtual environment, you can install `timm` in it.

## Using pip
@@ -52,7 +52,7 @@ FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -60,7 +60,7 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
@@ -77,7 +77,7 @@ else:
     EXCLUDE_FILTERS = ['*enormous*']
 NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']

-EXCLUDE_JIT_FILTERS = []
+EXCLUDE_JIT_FILTERS = ['hiera_*']

 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
 TARGET_BWD_SIZE = 128
@@ -486,7 +486,7 @@ def _create_fx_model(model, train=False):
     return fx_model


-EXCLUDE_FX_FILTERS = ['vit_gi*']
+EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*']
 # not enough memory to run fx on more models than other tests
 if 'GITHUB_ACTIONS' in os.environ:
     EXCLUDE_FX_FILTERS += [
@@ -19,9 +19,10 @@ from timm.data.random_erasing import RandomErasing
 def transforms_noaug_train(
         img_size: Union[int, Tuple[int, int]] = 224,
         interpolation: str = 'bilinear',
-        use_prefetcher: bool = False,
         mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+        use_prefetcher: bool = False,
+        normalize: bool = True,
 ):
     """ No-augmentation image transforms for training.

@@ -31,6 +32,7 @@
         mean: Image normalization mean.
         std: Image normalization standard deviation.
         use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
+        normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).

     Returns:

@@ -45,6 +47,9 @@
     if use_prefetcher:
         # prefetcher and collate will handle tensor conversion and norm
         tfl += [ToNumpy()]
+    elif not normalize:
+        # when normalize disabled, convert to tensor without scaling, keep original dtype
+        tfl += [transforms.PILToTensor()]
     else:
         tfl += [
             transforms.ToTensor(),
@@ -77,6 +82,7 @@ def transforms_imagenet_train(
         re_count: int = 1,
         re_num_splits: int = 0,
         use_prefetcher: bool = False,
+        normalize: bool = True,
         separate: bool = False,
 ):
     """ ImageNet-oriented image transforms for training.

@@ -103,6 +109,7 @@
         re_count: Number of random erasing regions.
         re_num_splits: Control split of random erasing across batch size.
         use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
+        normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
         separate: Output transforms in 3-stage tuple.

     Returns:

@@ -209,12 +216,15 @@
     if use_prefetcher:
         # prefetcher and collate will handle tensor conversion and norm
         final_tfl += [ToNumpy()]
+    elif not normalize:
+        # when normalize disabled, convert to tensor without scaling, keeps original dtype
+        final_tfl += [transforms.PILToTensor()]
     else:
         final_tfl += [
             transforms.ToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
-                std=torch.tensor(std)
+                std=torch.tensor(std),
             ),
         ]
     if re_prob > 0.:
@@ -243,6 +253,7 @@ def transforms_imagenet_eval(
         mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
         use_prefetcher: bool = False,
+        normalize: bool = True,
 ):
     """ ImageNet-oriented image transform for evaluation and inference.

@@ -255,6 +266,7 @@
         mean: Image normalization mean.
         std: Image normalization standard deviation.
         use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
+        normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).

     Returns:
         Composed transform pipeline
@@ -304,13 +316,16 @@
     if use_prefetcher:
         # prefetcher and collate will handle tensor conversion and norm
         tfl += [ToNumpy()]
+    elif not normalize:
+        # when normalize disabled, convert to tensor without scaling, keeps original dtype
+        tfl += [transforms.PILToTensor()]
     else:
         tfl += [
             transforms.ToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
-            )
+            ),
         ]

     return transforms.Compose(tfl)
@@ -342,6 +357,7 @@ def create_transform(
         crop_border_pixels: Optional[int] = None,
         tf_preprocessing: bool = False,
         use_prefetcher: bool = False,
+        normalize: bool = True,
         separate: bool = False,
 ):
     """

@@ -373,6 +389,7 @@
         crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
         tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports
         use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
+        normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
         separate: Output transforms in 3-stage tuple.

     Returns:
@@ -397,9 +414,10 @@
         transform = transforms_noaug_train(
             img_size,
             interpolation=interpolation,
-            use_prefetcher=use_prefetcher,
             mean=mean,
             std=std,
+            use_prefetcher=use_prefetcher,
+            normalize=normalize,
         )
     elif is_training:
         transform = transforms_imagenet_train(

@@ -415,13 +433,14 @@
             gaussian_blur_prob=gaussian_blur_prob,
             auto_augment=auto_augment,
             interpolation=interpolation,
-            use_prefetcher=use_prefetcher,
             mean=mean,
             std=std,
             re_prob=re_prob,
             re_mode=re_mode,
             re_count=re_count,
             re_num_splits=re_num_splits,
+            use_prefetcher=use_prefetcher,
+            normalize=normalize,
             separate=separate,
         )
     else:

@@ -429,12 +448,13 @@
         transform = transforms_imagenet_eval(
             img_size,
             interpolation=interpolation,
-            use_prefetcher=use_prefetcher,
             mean=mean,
             std=std,
             crop_pct=crop_pct,
             crop_mode=crop_mode,
             crop_border_pixels=crop_border_pixels,
+            use_prefetcher=use_prefetcher,
+            normalize=normalize,
         )

     return transform
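A sketch of the new flag end-to-end (values illustrative): with `normalize=False`, the pipeline ends in `PILToTensor()` instead of `ToTensor()` + `Normalize()`, so the output keeps the original uint8 dtype.

```python
from PIL import Image
from timm.data import create_transform

tf = create_transform(224, is_training=False, normalize=False)
img = Image.new('RGB', (256, 256))
t = tf(img)  # torch.Tensor, dtype uint8, values 0..255, no mean/std applied
```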
@@ -108,7 +108,7 @@ class ClassifierHead(nn.Module):
         self.fc = fc
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

-    def reset(self, num_classes, pool_type=None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
         if pool_type is not None and pool_type != self.global_pool.pool_type:
             self.global_pool, self.fc = create_classifier(
                 self.in_features,

@@ -180,7 +180,7 @@ class NormMlpClassifierHead(nn.Module):
         self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

-    def reset(self, num_classes, pool_type=None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
         if pool_type is not None:
             self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
             self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
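The same typed signature is rolled out to the per-model `reset_classifier()` methods below; usage is unchanged. A quick sketch:

```python
import timm

m = timm.create_model('resnet50')
m.reset_classifier(10, 'avg')  # fresh 10-class head with average pooling
m.reset_classifier(0, '')      # num_classes=0 removes the head entirely
```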
@@ -47,7 +47,7 @@ def get_norm_layer(norm_layer):
     if isinstance(norm_layer, str):
         if not norm_layer:
             return None
-        layer_name = norm_layer.replace('_', '')
+        layer_name = norm_layer.replace('_', '').lower()
         norm_layer = _NORM_MAP[layer_name]
     else:
         norm_layer = norm_layer
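With the added `.lower()`, string lookups become case-insensitive; a quick sketch:

```python
from timm.layers import get_norm_layer

# 'LayerNorm', 'layernorm', and 'layer_norm' now resolve to the same class
assert get_norm_layer('LayerNorm') is get_norm_layer('layer_norm')
```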
@@ -6,7 +6,7 @@ import torch.nn as nn

 class PatchDropout(nn.Module):
     """
-    https://arxiv.org/abs/2212.00794
+    https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
     """
     return_indices: torch.jit.Final[bool]
@@ -26,6 +26,7 @@ from .gcvit import *
 from .ghostnet import *
 from .hardcorenas import *
 from .hgnet import *
+from .hiera import *
 from .hrnet import *
 from .inception_next import *
 from .inception_resnet_v2 import *
@@ -10,7 +10,8 @@ from torch.hub import load_state_dict_from_url
 from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
+    load_custom_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file

@@ -185,7 +186,12 @@ def load_pretrained(
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
-            state_dict = load_state_dict_from_hf(*pretrained_loc)
+            custom_load = pretrained_cfg.get('custom_load', False)
+            if isinstance(custom_load, str) and custom_load == 'hf':
+                load_custom_from_hf(*pretrained_loc, model)
+                return
+            else:
+                state_dict = load_state_dict_from_hf(*pretrained_loc)
         else:
             state_dict = load_state_dict_from_hf(pretrained_loc)
     else:
@@ -118,6 +118,7 @@ class FeatureGraphNet(nn.Module):
             out_indices: Tuple[int, ...],
             out_map: Optional[Dict] = None,
             output_fmt: str = 'NCHW',
+            return_dict: bool = False,
     ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'

@@ -127,9 +128,13 @@ class FeatureGraphNet(nn.Module):
         self.output_fmt = Format(output_fmt)
         return_nodes = _get_return_layers(self.feature_info, out_map)
         self.graph_module = create_feature_extractor(model, return_nodes)
+        self.return_dict = return_dict

     def forward(self, x):
-        return list(self.graph_module(x).values())
+        out = self.graph_module(x)
+        if self.return_dict:
+            return out
+        return list(out.values())


 class GraphExtractNet(nn.Module):
@@ -144,19 +149,23 @@
         model: model to extract features from
         return_nodes: node names to return features from (dict or list)
         squeeze_out: if only one output, and output in list format, flatten to single tensor
+        return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
     """
     def __init__(
             self,
             model: nn.Module,
             return_nodes: Union[Dict[str, str], List[str]],
             squeeze_out: bool = True,
+            return_dict: bool = False,
     ):
         super().__init__()
         self.squeeze_out = squeeze_out
         self.graph_module = create_feature_extractor(model, return_nodes)
+        self.return_dict = return_dict

     def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
-        out = list(self.graph_module(x).values())
-        if self.squeeze_out and len(out) == 1:
-            return out[0]
-        return out
+        out = self.graph_module(x)
+        if self.return_dict:
+            return out
+        out = list(out.values())
+        return out[0] if self.squeeze_out and len(out) == 1 else out
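A sketch of the new dict-mode output (node names illustrative):

```python
import timm
import torch
from timm.models._features_fx import GraphExtractNet

model = timm.create_model('vit_base_patch16_224')
# return_dict=True yields {node_name: tensor} and bypasses squeeze_out
net = GraphExtractNet(model, return_nodes=['blocks.2', 'blocks.5'], return_dict=True)
feats = net(torch.randn(1, 3, 224, 224))
```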
@@ -190,6 +190,13 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     return torch.load(cached_file, map_location='cpu')


+def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
+    assert has_hf_hub(True)
+    hf_model_id, hf_revision = hf_split(model_id)
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    return model.load_pretrained(cached_file)


 def save_config_for_hf(
         model,
         config_path: str,
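To connect the pieces: `load_pretrained()` in `_builder.py` (above) calls this helper when a pretrained cfg carries `custom_load='hf'` and a `(model_id, filename)` hub location, handing the downloaded file to the model's own `load_pretrained()` method rather than loading a plain state dict. A hedged sketch of such a cfg entry (keys illustrative):

```python
# assumes the model class implements load_pretrained(checkpoint_path),
# e.g. the SigLIP ViT + PaliGemma jax weight loading mentioned in the changelog
cfg = dict(hf_hub_id='timm/some_model', custom_load='hf')
```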
@@ -395,7 +395,7 @@ class Beit(nn.Module):
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool

@@ -331,7 +331,7 @@ class Cait(nn.Module):
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'token', 'avg')

@@ -7,8 +7,7 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT

 Modified from timm/models/vision_transformer.py
 """
 from functools import partial
-from typing import Tuple, List, Union
+from typing import List, Optional, Union, Tuple

 import torch
 import torch.nn as nn

@@ -560,7 +559,7 @@ class CoaT(nn.Module):
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('token', 'avg')

@@ -21,8 +21,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 '''These modules are adapted from those of timm, see
 https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 '''

 from functools import partial
 from typing import Optional

 import torch
 import torch.nn as nn

@@ -349,7 +348,7 @@ class ConVit(nn.Module):
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'token', 'avg')

@@ -1,6 +1,8 @@
 """ ConvMixer

 """
+from typing import Optional

 import torch
 import torch.nn as nn

@@ -75,7 +77,7 @@ class ConvMixer(nn.Module):
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
@@ -37,7 +37,6 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
 # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.

 from collections import OrderedDict
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

@@ -25,8 +25,7 @@ Modified from timm. https://github.com/rwightman/pytorch-image-models/blob/master

 """
 from functools import partial
-from typing import List
-from typing import Tuple
+from typing import List, Optional, Tuple

 import torch
 import torch.hub

@@ -419,7 +418,7 @@ class CrossVit(nn.Module):
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('token', 'avg')

@@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
 # All rights reserved.
 # This source code is licensed under the MIT license
 from functools import partial
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn

@@ -568,7 +568,7 @@ class DaVit(nn.Module):
     def get_classifier(self):
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
@@ -11,7 +11,7 @@ Modifications copyright 2021, Ross Wightman
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
 from functools import partial
-from typing import Sequence, Union
+from typing import Optional

 import torch
 from torch import nn as nn

@@ -20,7 +20,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import resample_abs_pos_embed
 from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

 __all__ = ['VisionTransformerDistilled']  # model_registry will add each entrypoint fn to this

@@ -64,7 +63,7 @@ class VisionTransformerDistilled(VisionTransformer):
     def get_classifier(self):
         return self.head, self.head_dist

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
         self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

@@ -8,7 +8,6 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt
 Modifications and additions for timm by / Copyright 2022, Ross Wightman
 """
 import math
 from collections import OrderedDict
 from functools import partial
 from typing import Tuple

@@ -17,7 +16,7 @@ import torch.nn.functional as F
 from torch import nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
+from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
     use_fused_attn, NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module

@@ -449,7 +449,7 @@ class EfficientFormer(nn.Module):
     def get_classifier(self):
         return self.head, self.head_dist

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
@@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman
 """
 import math
 from functools import partial
-from typing import Dict
+from typing import Dict, Optional

 import torch
 import torch.nn as nn

@@ -612,7 +612,7 @@ class EfficientFormerV2(nn.Module):
     def get_classifier(self):
         return self.head, self.head_dist

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool

@@ -13,7 +13,6 @@ from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.batchnorm import _BatchNorm

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh

@@ -740,7 +739,7 @@ class EfficientVit(nn.Module):
     def get_classifier(self):
         return self.head.classifier[-1]

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool

@@ -858,7 +857,7 @@ class EfficientVitLarge(nn.Module):
     def get_classifier(self):
         return self.head.classifier[-1]

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
@@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
 __all__ = ['EfficientVitMsra']
 import itertools
 from collections import OrderedDict
-from typing import Dict
+from typing import Dict, Optional

 import torch
 import torch.nn as nn

@@ -464,7 +464,7 @@ class EfficientVitMsra(nn.Module):
     def get_classifier(self):
         return self.head.linear

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             if global_pool == 'avg':

@@ -539,7 +539,7 @@ class Eva(nn.Module):
     def get_classifier(self):
         return self.head

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
@@ -396,7 +396,7 @@ class ReparamLargeKernelConv(nn.Module):

     @staticmethod
     def _fuse_bn(
-            conv: torch.Tensor, bn: nn.BatchNorm2d
+            conv: nn.Conv2d, bn: nn.BatchNorm2d
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Method to fuse batchnorm layer with conv layer.

@@ -1232,7 +1232,7 @@ class FastVit(nn.Module):
     def get_classifier(self):
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)


@@ -454,7 +454,7 @@ class FocalNet(nn.Module):
     def get_classifier(self):
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)

     def forward_features(self, x):

@@ -489,7 +489,7 @@ class GlobalContextVit(nn.Module):
     def get_classifier(self):
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is None:
             global_pool = self.head.global_pool.pool_type
timm/models/hiera.py (936 lines, new file)
@@ -0,0 +1,936 @@
""" An PyTorch implementation of Hiera
|
||||
|
||||
Adapted for timm from originals at https://github.com/facebookresearch/hiera
|
||||
"""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
#
|
||||
# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
|
||||
#
|
||||
# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
|
||||
# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
|
||||
# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
|
||||
#
|
||||
# Paper: https://arxiv.org/abs/2306.00989/
|
||||
#
|
||||
# References:
|
||||
# slowfast: https://github.com/facebookresearch/SlowFast
|
||||
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
|
||||
|
||||
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._features_fx import register_notrace_function
|
||||
|
||||
|
||||
def conv_nd(n: int) -> Type[nn.Module]:
|
||||
"""
|
||||
Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3.
|
||||
If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
|
||||
"""
|
||||
return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
|
||||
|
||||
|
||||
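# Usage sketch (illustrative): conv_nd indexes the Conv class by spatial rank,
# e.g. conv_nd(2) is nn.Conv2d, and PatchEmbed below instantiates it as
#   conv_nd(2)(dim_in, dim_out, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))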
@register_notrace_function
def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
    # target_size: [(T), (H), W]
    # (spatial) mask: [B, C, (t), (h), w]
    if mask is None:
        return mask

    _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
    if mask.shape[2:] != target_size:
        return F.interpolate(mask.float(), size=target_size)
    return mask


def undo_windowing(
        x: torch.Tensor,
        shape: List[int],
        mu_shape: List[int],
) -> torch.Tensor:
    """
    Restore spatial organization by undoing windowed organization of mask units.

    Args:
        x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C]
        shape: current spatial shape, if it were not organized into mask unit
            windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C].
        mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx]
    Returns:
        x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C]
    """
    D = len(shape)
    B, C = x.shape[0], x.shape[-1]
    # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C]
    num_MUs = [s // mu for s, mu in zip(shape, mu_shape)]
    x = x.view(B, *num_MUs, *mu_shape, C)

    # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C]
    permute = (
        [0]
        + sum([list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))], [])
        + [len(x.shape) - 1]
    )
    x = x.permute(permute).reshape(B, *shape, C)

    return x
class Unroll(nn.Module):
    """
    Reorders the tokens such that patches are contiguous in memory.
    E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as
    [B, (Sy, Sx, H // Sy, W // Sx), C]

    This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1).
    Not only is this faster, but it also makes it easy to support inputs of arbitrary
    dimensions in addition to patch-wise sparsity.

    Performing this operation multiple times in sequence puts entire windows as contiguous
    in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
    size 8x8 would be contiguous in memory, allowing operations like mask unit attention
    to be computed easily and efficiently, while also allowing max to be applied sequentially.

    Note: This means that intermediate values of the model are not in HxW order, so they
    need to be re-rolled if you want to use the intermediate values as a HxW feature map.
    The last block of the network is fine though, since by then the strides are all consumed.
    """

    def __init__(
            self,
            input_size: Tuple[int, ...],
            patch_stride: Tuple[int, ...],
            unroll_schedule: List[Tuple[int, ...]],
    ):
        super().__init__()
        self.size = [i // s for i, s in zip(input_size, patch_stride)]
        self.schedule = unroll_schedule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input: Flattened patch embeddings [B, N, C]
        Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
        """
        B, _, C = x.shape
        cur_size = self.size
        x = x.view(*([B] + cur_size + [C]))

        for strides in self.schedule:
            # Move patches with the given strides to the batch dimension

            # Create a view of the tensor with the patch stride as separate dims
            # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C]
            cur_size = [i // s for i, s in zip(cur_size, strides)]
            new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C]
            x = x.view(new_shape)

            # Move the patch stride into the batch dimension
            # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C]
            L = len(new_shape)
            permute = [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1]
            x = x.permute(permute)

            # Now finally flatten the relevant dims into the batch dimension
            x = x.flatten(0, len(strides))
            B *= math.prod(strides)

        x = x.reshape(-1, math.prod(self.size), C)
        return x
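# Shape sketch (illustrative): a 32x32 input with 4x4 patch stride gives an 8x8
# grid of 64 tokens; one (2, 2) unroll step returns [B, 64, C] reordered so that
# x.view(B, 4, 16, C).amax(1) acts as a 2x2 max pool:
#   u = Unroll(input_size=(32, 32), patch_stride=(4, 4), unroll_schedule=[(2, 2)])
#   out = u(torch.randn(1, 64, 96))  # [1, 64, 96], window-contiguous order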
class Reroll(nn.Module):
    """
    Undoes the "unroll" operation so that you can use intermediate features.
    """

    def __init__(
            self,
            input_size: Tuple[int, ...],
            patch_stride: Tuple[int, ...],
            unroll_schedule: List[Tuple[int, ...]],
            stage_ends: List[int],
            q_pool: int,
    ):
        super().__init__()
        self.size = [i // s for i, s in zip(input_size, patch_stride)]

        # The first stage has to reverse everything
        # The next stage has to reverse all but the first unroll, etc.
        self.schedule = {}
        size = self.size
        for i in range(stage_ends[-1] + 1):
            self.schedule[i] = unroll_schedule, size
            # schedule unchanged if no pooling at a stage end
            if i in stage_ends[:q_pool]:
                if len(unroll_schedule) > 0:
                    size = [n // s for n, s in zip(size, unroll_schedule[0])]
                unroll_schedule = unroll_schedule[1:]

    def forward(
            self,
            x: torch.Tensor,
            block_idx: int,
            mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Roll the given tensor back up to spatial order assuming it's from the given block.

        If no mask is provided:
            - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
        If a mask is provided:
            - Returns [B, #MUs, MUy, MUx, C] for 2d, etc.
        """
        schedule, size = self.schedule[block_idx]
        B, N, C = x.shape

        D = len(size)
        cur_mu_shape = [1] * D

        for strides in schedule:
            # Extract the current patch from N
            x = x.view(B, *strides, N // math.prod(strides), *cur_mu_shape, C)

            # Move that patch into the current MU
            # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C]
            L = len(x.shape)
            permute = (
                [0, 1 + D]
                + sum([list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))], [])
                + [L - 1]
            )
            x = x.permute(permute)

            # Reshape to [B, N//(Sy*Sx), *MU, C]
            for i in range(D):
                cur_mu_shape[i] *= strides[i]
            x = x.reshape(B, -1, *cur_mu_shape, C)
            N = x.shape[1]

        # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C])
        x = x.view(B, N, *cur_mu_shape, C)

        # If masked, return [B, #MUs, MUy, MUx, C]
        if mask is not None:
            return x

        # If not masked, we can return [B, H, W, C]
        x = undo_windowing(x, size, cur_mu_shape)

        return x
class MaskUnitAttention(nn.Module):
    """
    Computes either Mask Unit or Global Attention. Also is able to perform q pooling.

    Note: this assumes the tokens have already been flattened and unrolled into mask units.
    See `Unroll` for more details.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: int,
            heads: int,
            q_stride: int = 1,
            window_size: int = 0,
            use_mask_unit_attn: bool = False,
    ):
        """
        Args:
        - dim, dim_out: The input and output feature dimensions.
        - heads: The number of attention heads.
        - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
        - window_size: The current (flattened) size of a mask unit *after* pooling (if any).
        - use_mask_unit_attn: Use Mask Unit or Global Attention.
        """
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.heads = heads
        self.q_stride = q_stride
        self.head_dim = dim_out // heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, 3 * dim_out)
        self.proj = nn.Linear(dim_out, dim_out)

        self.window_size = window_size
        self.use_mask_unit_attn = use_mask_unit_attn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Input should be of shape [batch, tokens, channels]. """
        B, N, _ = x.shape
        num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1

        qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
        q, k, v = qkv.unbind(0)

        if self.q_stride > 1:
            # Refer to Unroll to see how this performs a maxpool-Nd
            q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3)

        if self.fused_attn:
            # Note: the original paper did *not* use SDPA, it's a free boost!
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            attn = (q * self.scale) @ k.transpose(-1, -2)
            attn = attn.softmax(dim=-1)
            x = attn @ v

        x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
        x = self.proj(x)
        return x
class HieraBlock(nn.Module):
    def __init__(
            self,
            dim: int,
            dim_out: int,
            heads: int,
            mlp_ratio: float = 4.0,
            drop_path: float = 0.0,
            norm_layer: nn.Module = nn.LayerNorm,
            act_layer: nn.Module = nn.GELU,
            q_stride: int = 1,
            window_size: int = 0,
            use_expand_proj: bool = True,
            use_mask_unit_attn: bool = False,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out

        self.norm1 = norm_layer(dim)
        if dim != dim_out:
            self.do_expand = True
            if use_expand_proj:
                self.proj = nn.Linear(dim, dim_out)
            else:
                assert dim_out == dim * 2
                self.proj = None
        else:
            self.do_expand = False
            self.proj = None
        self.attn = MaskUnitAttention(
            dim,
            dim_out,
            heads,
            q_stride,
            window_size,
            use_mask_unit_attn
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention + Q Pooling
        x_norm = self.norm1(x)
        if self.do_expand:
            if self.proj is not None:
                x = self.proj(x_norm)
                x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
            else:
                x = torch.cat([
                        x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1),  # max-pool
                        x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1),  # avg-pool
                    ],
                    dim=-1,
                )
        x = x + self.drop_path1(self.attn(x_norm))

        # MLP
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
class NormClassifierHead(nn.Module):
    def __init__(
            self,
            in_features: int,
            num_classes: int,
            pool_type: str = 'avg',
            drop_rate: float = 0.0,
            norm_layer: Union[str, Callable] = 'layernorm',
    ):
        super().__init__()
        norm_layer = get_norm_layer(norm_layer)
        assert pool_type in ('avg', '')
        self.in_features = self.num_features = in_features
        self.pool_type = pool_type
        self.norm = norm_layer(in_features)
        self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()

    def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
        if pool_type is not None:
            assert pool_type in ('avg', '')
            self.pool_type = pool_type
        if other:
            # reset other non-fc layers
            self.norm = nn.Identity()
        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        if self.pool_type == 'avg':
            x = x.mean(dim=1)
        x = self.norm(x)
        x = self.drop(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
class PatchEmbed(nn.Module):
    """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""

    def __init__(
            self,
            dim_in: int,
            dim_out: int,
            kernel: Tuple[int, ...],
            stride: Tuple[int, ...],
            padding: Tuple[int, ...],
            reshape: bool = True,
    ):
        super().__init__()

        # Support any number of spatial dimensions
        self.spatial_dims = len(kernel)
        self.reshape = reshape
        self.proj = conv_nd(self.spatial_dims)(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(
            self,
            x: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if mask is not None:
            mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
            x = self.proj(x * mask.to(torch.bool))
        else:
            x = self.proj(x)
        if self.reshape:
            x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
        return x
class Hiera(nn.Module):

    def __init__(
            self,
            img_size: Tuple[int, ...] = (224, 224),
            in_chans: int = 3,
            embed_dim: int = 96,  # initial embed dim
            num_heads: int = 1,  # initial number of heads
            num_classes: int = 1000,
            global_pool: str = 'avg',
            stages: Tuple[int, ...] = (2, 3, 16, 3),
            q_pool: int = 3,  # number of q_pool stages
            q_stride: Tuple[int, ...] = (2, 2),
            mask_unit_size: Tuple[int, ...] = (8, 8),  # must divide q_stride ** (#stages-1)
            # mask_unit_attn: which stages use mask unit attention?
            mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
            dim_mul: float = 2.0,
            head_mul: float = 2.0,
            patch_kernel: Tuple[int, ...] = (7, 7),
            patch_stride: Tuple[int, ...] = (4, 4),
            patch_padding: Tuple[int, ...] = (3, 3),
            mlp_ratio: float = 4.0,
            drop_path_rate: float = 0.0,
            norm_layer: Union[str, nn.Module] = "LayerNorm",
            drop_rate: float = 0.0,
            head_init_scale: float = 0.001,
            sep_pos_embed: bool = False,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.grad_checkpointing = False
        norm_layer = get_norm_layer(norm_layer)

        self.patch_stride = patch_stride
        self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
        num_tokens = math.prod(self.tokens_spatial_shape)
        flat_mu_size = math.prod(mask_unit_size)
        flat_q_stride = math.prod(q_stride)
        assert q_pool < len(stages)
        self.q_pool, self.q_stride = q_pool, q_stride
        self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
        self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)]
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]

        self.patch_embed = PatchEmbed(
            in_chans,
            embed_dim,
            patch_kernel,
            patch_stride,
            patch_padding,
            #reshape=False,  # leave spatial / temporal dims in output
        )

        if sep_pos_embed:
            self.pos_embed = None
            self.pos_embed_spatial = nn.Parameter(
                torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim)
            )
            self.pos_embed_temporal = nn.Parameter(
                torch.zeros(1, self.tokens_spatial_shape[0], embed_dim)
            )
        else:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
            self.pos_embed_spatial = None
            self.pos_embed_temporal = None

        # Setup roll and reroll modules
        self.unroll = Unroll(
            img_size,
            patch_stride,
            [q_stride] * len(self.stage_ends[:-1])
        )
        self.reroll = Reroll(
            img_size,
            patch_stride,
            [q_stride] * len(self.stage_ends[:-1]),
            self.stage_ends,
            q_pool,
        )
        # q_pool locations
        q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]

        # Transformer blocks
        cur_stage = 0
        depth = sum(stages)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList()
        self.feature_info = []
        for i in range(depth):
            dim_out = embed_dim
            # Mask unit or global attention.
            # Lag by 1 block, so that global attention is
            # applied post pooling on lower resolution.
            use_mask_unit_attn = mask_unit_attn[cur_stage]

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1
                if i in q_pool_blocks:
                    flat_mu_size //= flat_q_stride

            block = HieraBlock(
                dim=embed_dim,
                dim_out=dim_out,
                heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                q_stride=(flat_q_stride if i in q_pool_blocks else 1),
                window_size=flat_mu_size,
                use_mask_unit_attn=use_mask_unit_attn,
            )
            embed_dim = dim_out
            if i in self.stage_ends:
                self.feature_info += [
                    dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
            self.blocks.append(block)

        self.num_features = embed_dim
        self.head = NormClassifierHead(
            embed_dim,
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            norm_layer=norm_layer,
        )

        # Initialize everything
        if sep_pos_embed:
            nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
            nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
        else:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(partial(self._init_weights))
        if isinstance(self.head.fc, nn.Linear):
            self.head.fc.weight.data.mul_(head_init_scale)
            self.head.fc.bias.data.mul_(head_init_scale)
    def _init_weights(self, m, init_bias=0.02):
        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, init_bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, init_bias)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        if self.pos_embed is not None:
            return ["pos_embed"]
        else:
            return ["pos_embed_spatial", "pos_embed_temporal"]

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        return dict(
            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool, other=other)

    def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
        """
        Generates a random mask, mask_ratio fraction are dropped.
        1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc.
        """
        B = x.shape[0]
        # Tokens selected for masking at mask unit level
        num_windows = math.prod(self.mask_spatial_shape)  # num_mask_units
        len_keep = int(num_windows * (1 - mask_ratio))
        noise = torch.rand(B, num_windows, device=x.device)

        # Sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Generate the binary mask: 1 is *keep*, 0 is *remove*
        # Note this is opposite to original MAE
        mask = torch.zeros([B, num_windows], device=x.device)
        mask[:, :len_keep] = 1
        # Unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return mask.bool()
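    # MAE-style sketch (illustrative): mask units are dropped before the blocks run,
    # and forward() skips the classifier head whenever a mask is passed:
    #   mask = model.get_random_mask(x, mask_ratio=0.6)  # [B, #mask_units], True = keep
    #   feats = model(x, mask=mask)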
    def _pos_embed(self, x) -> torch.Tensor:
        if self.pos_embed is not None:
            pos_embed = self.pos_embed
        else:
            pos_embed = (
                self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
                +
                torch.repeat_interleave(
                    self.pos_embed_temporal,
                    self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
                    dim=1,
                )
            )
        x = x + pos_embed
        return x
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
|
||||
norm: bool = False,
|
||||
stop_early: bool = True,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
||||
norm: Apply norm layer to all intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert not norm, 'normalization of features not supported'
|
||||
assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
|
||||
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
|
||||
take_indices = [self.stage_ends[i] for i in take_indices]
|
||||
max_index = self.stage_ends[max_index]
|
||||
|
||||
if mask is not None:
|
||||
patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
|
||||
else:
|
||||
patch_mask = None
|
||||
x = self.patch_embed(x, mask=patch_mask)
|
||||
x = self._pos_embed(x)
|
||||
x = self.unroll(x)
|
||||
|
||||
# Discard masked tokens
|
||||
if mask is not None:
|
||||
x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
|
||||
|
||||
intermediates = []
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
blocks = self.blocks
|
||||
else:
|
||||
blocks = self.blocks[:max_index + 1]
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x)
|
||||
if i in take_indices:
|
||||
intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
return x, intermediates
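
    # Usage sketch (assumed call against the signature above): feature maps for the
    # last two stages of an unmasked batch, one NCHW tensor per selected stage:
    #   feats = model.forward_intermediates(x, indices=2, intermediates_only=True)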

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int], Tuple[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
        max_index = self.stage_ends[max_index]
        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
        if prune_head:
            self.head.reset(0, other=True)
        return take_indices

    def forward_features(
            self,
            x: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
            return_intermediates: bool = False,
    ) -> torch.Tensor:
        """
        mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the
        number of mask units in that dim. Note: 1 in mask is *keep*, 0 is *remove*;
        mask.sum(dim=-1) should be the same across the batch.
        """
        if mask is not None:
            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
        else:
            patch_mask = None
        x = self.patch_embed(x, mask=patch_mask)
        x = self._pos_embed(x)
        x = self.unroll(x)

        # Discard masked tokens
        if mask is not None:
            x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])

        intermediates = []
        for i, blk in enumerate(self.blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)
            if return_intermediates and i in self.stage_ends:
                intermediates.append(self.reroll(x, i, mask=mask))

        # x may not always be in spatial order here.
        # e.g. if q_pool = 2, mask_unit_size = (8, 8), and
        # q_stride = (2, 2), not all unrolls were consumed,
        # intermediates[-1] is x in spatial order
        if return_intermediates:
            return x, intermediates

        return x

    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
        return x

    def forward(
            self,
            x: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = self.forward_features(x, mask=mask)
        if mask is None:
            # Head is only applied on the unmasked (classification) path.
            x = self.forward_head(x)
        return x
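
    # Pretraining sketch (assumed usage, not part of this file): pair the random
    # mask with a masked forward; the classifier head is skipped when mask is set.
    #   mask = model.get_random_mask(x, mask_ratio=0.6)
    #   feats = model(x, mask=mask)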


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    "hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
    ),
    "hiera_tiny_224.mae": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
        num_classes=0,
    ),

    "hiera_small_224.mae_in1k_ft_in1k": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
    ),
    "hiera_small_224.mae": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
        num_classes=0,
    ),

    "hiera_base_224.mae_in1k_ft_in1k": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
    ),
    "hiera_base_224.mae": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
        num_classes=0,
    ),

    "hiera_base_plus_224.mae_in1k_ft_in1k": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
    ),
    "hiera_base_plus_224.mae": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
        num_classes=0,
    ),

    "hiera_large_224.mae_in1k_ft_in1k": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
    ),
    "hiera_large_224.mae": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
        num_classes=0,
    ),

    "hiera_huge_224.mae_in1k_ft_in1k": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
    ),
    "hiera_huge_224.mae": _cfg(
        hf_hub_id='timm/',
        license='cc-by-nc-4.0',
        num_classes=0,
    ),
})
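
# Example (sketch using the standard timm API): the `.mae_in1k_ft_in1k` tags are
# fine-tuned classifiers, the `.mae` tags are headless MAE encoders, e.g.
#   model = timm.create_model('hiera_base_224.mae_in1k_ft_in1k', pretrained=True)
#   encoder = timm.create_model('hiera_base_224.mae', pretrained=True)  # num_classes=0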


def checkpoint_filter_fn(state_dict, model=None):
    state_dict = state_dict.get('model_state', state_dict)
    output = {}
    for k, v in state_dict.items():
        if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
            # To resize pos embedding when using model at different size from pretrained weights
            # from timm.layers import resample_abs_pos_embed
            # v = resample_abs_pos_embed(
            #     v,
            #     new_size=(64, 64),
            #     num_prefix_tokens=0,
            #     verbose=True,
            # )
            # v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
            pass
        # Remap original Hiera checkpoint keys to timm's head/norm naming
        if 'head.projection.' in k:
            k = k.replace('head.projection.', 'head.fc.')
        if k.startswith('encoder_norm.'):
            k = k.replace('encoder_norm.', 'head.norm.')
        elif k.startswith('norm.'):
            k = k.replace('norm.', 'head.norm.')
        output[k] = v
    return output
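
# For instance (sketch of the remapping above), an original checkpoint key
# 'head.projection.weight' comes out as 'head.fc.weight', and 'norm.bias'
# as 'head.norm.bias'.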


def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
    out_indices = kwargs.pop('out_indices', 4)

    return build_model_with_cfg(
        Hiera,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
    )


@register_model
def hiera_tiny_224(pretrained=False, **kwargs):
    model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
    return _create_hiera('hiera_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def hiera_small_224(pretrained=False, **kwargs):
    model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2))
    return _create_hiera('hiera_small_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def hiera_base_224(pretrained=False, **kwargs):
    model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
    return _create_hiera('hiera_base_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def hiera_base_plus_224(pretrained=False, **kwargs):
    model_args = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
    return _create_hiera('hiera_base_plus_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def hiera_large_224(pretrained=False, **kwargs):
    model_args = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
    return _create_hiera('hiera_large_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def hiera_huge_224(pretrained=False, **kwargs):
    model_args = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
    return _create_hiera('hiera_huge_224', pretrained=pretrained, **dict(model_args, **kwargs))
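
# Feature-map sketch (assuming the feature_cfg getter config above): hierarchical
# features can be pulled the usual timm way, e.g.
#   backbone = timm.create_model('hiera_small_224', features_only=True, out_indices=(0, 1, 2, 3))
#   fmaps = backbone(torch.randn(1, 3, 224, 224))  # one NCHW tensor per stage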

@@ -628,7 +628,7 @@ class Levit(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool

@@ -730,7 +730,7 @@ class LevitDistilled(Levit):
     def get_classifier(self):
         return self.head, self.head_dist
 
-    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool

@@ -1248,7 +1248,7 @@ class MaxxVit(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

@@ -255,7 +255,7 @@ class MlpMixer(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg')

@@ -622,8 +622,6 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
     return model
 
 
-
-
 def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
     """Creates a MobileNet-V4 model.
 

@@ -825,7 +825,7 @@ class MultiScaleVit(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool

@@ -6,6 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V
 """
 # Copyright (c) ByteDance Inc. All rights reserved.
 from functools import partial
+from typing import Optional
 
 import torch
 import torch.nn.functional as F

@@ -553,7 +554,7 @@ class NextViT(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):

@@ -14,13 +14,13 @@ Modifications for timm by / Copyright 2020 Ross Wightman
 import math
 import re
 from functools import partial
-from typing import Sequence, Tuple
+from typing import Optional, Sequence, Tuple
 
 import torch
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, to_2tuple, LayerNorm
+from timm.layers import trunc_normal_, to_2tuple
 from ._builder import build_model_with_cfg
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Block

@@ -246,7 +246,7 @@ class PoolingVisionTransformer(nn.Module):
         else:
             return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
         if self.head_dist is not None:

@@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
 """
 
 import math
-from typing import Tuple, List, Callable, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 import torch.nn as nn

@@ -379,7 +379,7 @@ class PyramidVisionTransformerV2(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('avg', '')

@@ -16,15 +16,16 @@ Adapted from official impl at https://github.com/jameslahm/RepViT
 """
 
 __all__ = ['RepVit']
 
-import torch.nn as nn
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from ._registry import register_model, generate_default_cfgs
-from ._builder import build_model_with_cfg
-from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
-from ._manipulate import checkpoint_seq
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
 
 
 class ConvNorm(nn.Sequential):

@@ -322,7 +323,7 @@ class RepVit(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None, distillation=False):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=False):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool

@@ -9,7 +9,7 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2
 import math
 from functools import partial
 from itertools import accumulate
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn

@@ -419,7 +419,7 @@ class Sequencer2d(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)

@@ -604,7 +604,7 @@ class SwinTransformer(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)

@@ -605,7 +605,7 @@ class SwinTransformerV2(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

@@ -8,10 +8,9 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyV
 
 __all__ = ['TinyVit']
 
-import math
 import itertools
 from functools import partial
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn

@@ -533,7 +532,7 @@ class TinyVit(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)

@@ -7,6 +7,7 @@ The official mindspore code is released and available at
 https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
 """
 import math
+from typing import Optional
 
 import torch
 import torch.nn as nn

@@ -298,7 +299,7 @@ class TNT(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'token', 'avg')

@@ -7,6 +7,7 @@ Original model: https://github.com/mrT23/TResNet
 """
 from collections import OrderedDict
 from functools import partial
+from typing import Optional
 
 import torch
 import torch.nn as nn

@@ -233,7 +234,7 @@ class TResNet(nn.Module):
     def get_classifier(self):
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):

@@ -382,7 +382,7 @@ class Twins(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg')

@@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     """
     import numpy as np
 
-    def _n2p(w, t=True):
+    def _n2p(w, t=True, idx=None):
+        if idx is not None:
+            w = w[idx]
         if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
             w = w.flatten()
         if t:

@@ -955,21 +957,28 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
 
     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
+            block_prefix = f'{prefix}Transformer/encoderblock/'
+            idx = i
+        else:
+            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+            idx = None
         mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
         block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
         block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
         for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
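
# Context note (editor's sketch of the mechanism in the hunk above): newer
# big_vision / PaliGemma checkpoints store all encoder blocks stacked under a
# single 'Transformer/encoderblock/...' entry with a leading per-block axis, so
# each weight is selected as w[idx] for block idx instead of via a separate
# 'encoderblock_{i}' key per block.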
 
 
 def _convert_openai_clip(

@@ -1769,6 +1778,73 @@ default_cfgs = {
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_base_patch16_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 512, 512),
+        num_classes=0),
+    'vit_large_patch16_siglip_gap_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_large_patch16_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-224-jax',
+        hf_hub_filename='paligemma-3b-mix-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-224-jax',
+        hf_hub_filename='paligemma-3b-pt-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-448-jax',
+        hf_hub_filename='paligemma-3b-mix-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-448-jax',
+        hf_hub_filename='paligemma-3b-pt-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-896-jax',
+        hf_hub_filename='paligemma-3b-pt-896.npz',
+        custom_load='hf',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+
 
 'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',

@@ -1791,6 +1867,7 @@ default_cfgs = {
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
 
+    'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',

@@ -1801,9 +1878,16 @@ default_cfgs = {
     'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),

@@ -2374,7 +2458,6 @@ def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> V
 def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
     """
-    from timm.layers import get_act_layer
     model_args = dict(
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
         norm_layer=nn.LayerNorm, act_layer='quick_gelu')

@@ -2756,15 +2839,118 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
     return model
 
 
-# @register_model
-# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
-#     model_args = dict(
-#         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-#         no_embed_class=True, reg_tokens=4,
-#     )
-#     model = _create_vision_transformer(
-#         'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
-#     return model
+@register_model
+def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
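
# Loading sketch (model/tag names from the cfgs above, standard timm API assumed):
# PaliGemma jax image-tower weights load into the GAP SigLIP variants, e.g.
#   tower = timm.create_model('vit_so400m_patch14_siglip_gap_224.pali_pt', pretrained=True)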
+
+
 @register_model

@@ -622,7 +622,7 @@ class VOLO(nn.Module):
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool

@@ -1,4 +1,5 @@
 import fnmatch
+import re
 from collections import OrderedDict
 from typing import Union, Optional, List
 

@@ -17,6 +18,7 @@ class AttentionExtract(torch.nn.Module):
             mode: str = 'eval',
             method: str = 'fx',
             hook_type: str = 'forward',
+            use_regex: bool = False,
     ):
         """ Extract attention maps (or other activations) from a model by name.
 
@@ -26,6 +28,7 @@ class AttentionExtract(torch.nn.Module):
             mode: 'train' or 'eval' model mode.
             method: 'fx' or 'hook' extraction method.
             hook_type: 'forward' or 'forward_pre' hooks used.
+            use_regex: Use regex instead of fnmatch
         """
         super().__init__()
         assert mode in ('train', 'eval')

@@ -40,14 +43,16 @@ class AttentionExtract(torch.nn.Module):
             from timm.models._features_fx import get_graph_node_names, GraphExtractNet
 
             node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
-            matched = []
             names = names or self.default_node_names
-            for n in names:
-                matched.extend(fnmatch.filter(node_names, n))
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [g for g in node_names if any([r.match(g) for r in regexes])]
+            else:
+                matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No node names found matching {names}.')
 
-            self.model = GraphExtractNet(model, matched)
+            self.model = GraphExtractNet(model, matched, return_dict=True)
             self.hooks = None
         else:
             # names are module names

@@ -55,10 +60,12 @@ class AttentionExtract(torch.nn.Module):
             from timm.models._features import FeatureHooks
 
             module_names = [n for n, m in model.named_modules()]
-            matched = []
             names = names or self.default_module_names
-            for n in names:
-                matched.extend(fnmatch.filter(module_names, n))
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [m for m in module_names if any([r.match(m) for r in regexes])]
+            else:
+                matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No module names found matching {names}.')
 

@@ -75,5 +82,4 @@ class AttentionExtract(torch.nn.Module):
             output = self.hooks.get_output(device=x.device)
         else:
             output = self.model(x)
-            output = OrderedDict(zip(self.names, output))
         return output
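
# Usage sketch (pattern string is illustrative, not from the source): with
# use_regex=True the names are treated as regular expressions rather than
# fnmatch globs, e.g.
#   extractor = AttentionExtract(model, names=[r'blocks\.\d+\.attn\.softmax'], method='fx', use_regex=True)
#   attn_maps = extractor(torch.randn(1, 3, 224, 224))  # dict keyed by matched name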

@@ -108,9 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
 
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm transform nccl to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
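
# Resolution sketch (behavior implied by the mapping above): a device string like
# 'xpu:0' parses to device_type 'xpu' and selects the 'ccl' backend, 'cuda:1'
# selects 'nccl', and anything unlisted (e.g. 'cpu') falls back to 'gloo'.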

@@ -150,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
 
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)

@@ -1 +1 @@
-__version__ = '1.0.1.dev0'
+__version__ = '1.0.4.dev0'