Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Merge pull request #1593 from rwightman/multi-weight_effnet_convnext
Update efficientnet.py and convnext.py to multi-weight, add new 12k pretrained weights

Commit: 4e24f75289
.gitignore (vendored, 10 lines changed)
@@ -106,6 +106,16 @@ output/
 *.tar
 *.pth
 *.pt
+*.torch
 *.gz
 Untitled.ipynb
 Testing notebook.ipynb
+
+# Root dir exclusions
+/*.csv
+/*.yaml
+/*.json
+/*.jpg
+/*.png
+/*.zip
+/*.tar.*
tests/test_models.py

@@ -27,7 +27,7 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
+    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)

timm/data/dataset_factory.py

@@ -151,7 +151,7 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, reader=name, split=split, **kwargs)
+        ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
timm/data/readers/reader_factory.py

@@ -6,7 +6,7 @@ from .reader_image_in_tar import ReaderImageInTar


 def create_reader(name, root, split='train', **kwargs):
     name = name.lower()
-    name = name.split('/', 2)
+    name = name.split('/', 1)
     prefix = ''
     if len(name) > 1:
         prefix = name[0]
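A quick illustration of the parsing change above (editor's sketch, not from the commit): with maxsplit=1 only the reader prefix is split off, so a dataset name that itself contains '/' stays intact for the underlying reader.

# Plain-Python illustration of str.split with maxsplit=1 vs the previous maxsplit=2;
# the dataset name is a made-up example.
name = 'hfds/some-org/some-dataset'.lower()
print(name.split('/', 1))  # ['hfds', 'some-org/some-dataset']
print(name.split('/', 2))  # ['hfds', 'some-org', 'some-dataset'] -- the old call split one level too far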
timm/data/readers/reader_hfds.py

@@ -13,13 +13,14 @@ try:
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
     exit(1)
+from .class_map import load_class_map
 from .reader import Reader


-def get_class_labels(info):
+def get_class_labels(info, label_key='label'):
     if 'label' not in info.features:
         return {}
-    class_label = info.features['label']
+    class_label = info.features[label_key]
     class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
     return class_to_idx
@@ -32,6 +33,7 @@ class ReaderHfds(Reader):
             name,
             split='train',
             class_map=None,
+            label_key='label',
             download=False,
     ):
         """
@@ -43,12 +45,17 @@ class ReaderHfds(Reader):
             name,  # 'name' maps to path arg in hf datasets
             split=split,
             cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
-            #use_auth_token=True,
         )
         # leave decode for caller, plus we want easy access to original path names...
         self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))

-        self.class_to_idx = get_class_labels(self.dataset.info)
+        self.label_key = label_key
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
         self.split_info = self.dataset.info.splits[split]
         self.num_samples = self.split_info.num_examples
@@ -60,7 +67,10 @@ class ReaderHfds(Reader):
         else:
             assert 'path' in image and image['path']
             image = open(image['path'], 'rb')
-        return image, item['label']
+        label = item[self.label_key]
+        if self.remap_class:
+            label = self.class_to_idx[label]
+        return image, label

     def __len__(self):
         return len(self.dataset)
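A minimal usage sketch of the new class_map / label_key plumbing (editor's illustration, not from the commit; the dataset name and class-map file are placeholders and require the corresponding HF dataset to be accessible):

from timm.data import create_dataset

# class_map is now forwarded through create_dataset() to ReaderHfds and, when set,
# labels read from the HF dataset are remapped through it
ds = create_dataset(
    'hfds/imagenet-1k',          # placeholder HF datasets path
    root='./hf_cache',
    split='validation',
    class_map='class_map.txt',   # placeholder label -> index mapping file
)
img, label = ds[0]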
timm/layers/__init__.py

@@ -1,6 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer
@@ -30,8 +31,12 @@ from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
-from .patch_embed import PatchEmbed
+from .patch_embed import PatchEmbed, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
+from .pos_embed import resample_abs_pos_embed
+from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
+from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
+    FourierEmbed, RotaryEmbedding
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
timm/layers/attention_pool2d.py

@@ -13,7 +13,7 @@ import torch
 import torch.nn as nn

 from .helpers import to_2tuple
-from .pos_embed import apply_rot_embed, RotaryEmbedding
+from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_


timm/layers/helpers.py

@@ -10,7 +10,7 @@ import collections.abc
 def _ntuple(n):
     def parse(x):
         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
-            return x
+            return tuple(x)
         return tuple(repeat(x, n))
     return parse

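The behavioral effect of the change above, illustrated (editor's sketch, not from the commit): iterables passed to the _ntuple helpers are now normalized to tuples rather than returned unchanged.

from timm.layers.helpers import to_2tuple

print(to_2tuple(7))       # (7, 7)
print(to_2tuple([3, 5]))  # (3, 5) -- previously the list [3, 5] was returned as-is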
timm/layers/patch_embed.py

@@ -2,15 +2,24 @@

 A convolution based approach to patchifying a 2D image w/ embedding projection.

-Based on the impl in https://github.com/google-research/vision_transformer
+Based on code in:
+  * https://github.com/google-research/vision_transformer
+  * https://github.com/google-research/big_vision/tree/main/big_vision

 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
+from typing import List
+
+import torch
 from torch import nn as nn
+import torch.nn.functional as F

 from .helpers import to_2tuple
 from .trace_utils import _assert

+
+_logger = logging.getLogger(__name__)
+
+
 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
@@ -46,3 +55,122 @@ class PatchEmbed(nn.Module):
             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
         x = self.norm(x)
         return x
+
+
+def resample_patch_embed(
+        patch_embed,
+        new_size: List[int],
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+        verbose: bool = False,
+):
+    """Resample the weights of the patch embedding kernel to target resolution.
+    We resample the patch embedding kernel by approximately inverting the effect
+    of patch resizing.
+
+    Code based on:
+      https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
+
+    With this resizing, we can for example load a B/8 filter into a B/16 model
+    and, on 2x larger input image, the result will match.
+
+    Args:
+        patch_embed: original parameter to be resized.
+        new_size (tuple(int, int): target shape (height, width)-only.
+        interpolation (str): interpolation for resize
+        antialias (bool): use anti-aliasing filter in resize
+        verbose (bool): log operation
+    Returns:
+        Resized patch embedding kernel.
+    """
+    import numpy as np
+
+    assert len(patch_embed.shape) == 4, "Four dimensions expected"
+    assert len(new_size) == 2, "New shape should only be hw"
+    old_size = patch_embed.shape[-2:]
+    if tuple(old_size) == tuple(new_size):
+        return patch_embed
+
+    if verbose:
+        _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
+
+    def resize(x_np, _new_size):
+        x_tf = torch.Tensor(x_np)[None, None, ...]
+        x_upsampled = F.interpolate(
+            x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
+        return x_upsampled
+
+    def get_resize_mat(_old_size, _new_size):
+        mat = []
+        for i in range(np.prod(_old_size)):
+            basis_vec = np.zeros(_old_size)
+            basis_vec[np.unravel_index(i, _old_size)] = 1.
+            mat.append(resize(basis_vec, _new_size).reshape(-1))
+        return np.stack(mat).T
+
+    resize_mat = get_resize_mat(old_size, new_size)
+    resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
+
+    def resample_kernel(kernel):
+        resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
+        return resampled_kernel.reshape(new_size)
+
+    v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
+    return v_resample_kernel(patch_embed)
+
+
+# def divs(n, m=None):
+#     m = m or n // 2
+#     if m == 1:
+#         return [1]
+#     if n % m == 0:
+#         return [m] + divs(n, m - 1)
+#     return divs(n, m - 1)
+#
+#
+# class FlexiPatchEmbed(nn.Module):
+#     """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
+#     FIXME WIP
+#     """
+#     def __init__(
+#             self,
+#             img_size=240,
+#             patch_size=16,
+#             in_chans=3,
+#             embed_dim=768,
+#             base_img_size=240,
+#             base_patch_size=32,
+#             norm_layer=None,
+#             flatten=True,
+#             bias=True,
+#     ):
+#         super().__init__()
+#         self.img_size = to_2tuple(img_size)
+#         self.patch_size = to_2tuple(patch_size)
+#         self.num_patches = 0
+#
+#         # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
+#         self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
+#
+#         self.base_img_size = to_2tuple(base_img_size)
+#         self.base_patch_size = to_2tuple(base_patch_size)
+#         self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
+#         self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
+#
+#         self.flatten = flatten
+#         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
+#         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+#
+#     def forward(self, x):
+#         B, C, H, W = x.shape
+#
+#         if self.patch_size == self.base_patch_size:
+#             weight = self.proj.weight
+#         else:
+#             weight = resample_patch_embed(self.proj.weight, self.patch_size)
+#             patch_size = self.patch_size
+#         x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
+#         if self.flatten:
+#             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+#         x = self.norm(x)
+#         return x
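A minimal usage sketch for the new kernel resampling helper (editor's illustration, not from the commit; the weight tensor is a random placeholder, and torch.vmap as used inside resample_patch_embed requires a sufficiently recent PyTorch, e.g. 2.0+):

import torch
from timm.layers import resample_patch_embed

# resample a 16x16 patch projection kernel to 8x8, FlexiViT-style
proj_weight = torch.randn(768, 3, 16, 16)   # (embed_dim, in_chans, kh, kw)
resized = resample_patch_embed(proj_weight, new_size=(8, 8), interpolation='bicubic', antialias=True)
print(resized.shape)  # torch.Size([768, 3, 8, 8])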
timm/layers/pos_embed.py

@@ -1,207 +1,52 @@
(The previous contents of this module -- pixel_freq_bands, inv_freq_bands, build_sincos2d_pos_embed,
build_fourier_pos_embed, FourierEmbed, rot, apply_rot_embed, apply_rot_embed_list, apply_rot_embed_split,
build_rotary_pos_embed and RotaryEmbedding -- are removed here; they move essentially unchanged into the
new timm/layers/pos_embed_sincos.py listed below. The rewritten module now holds the absolute position
embedding resampling helper:)
+""" Position Embedding Utilities
+
+Hacked together by / Copyright 2022 Ross Wightman
+"""
+import logging
+import math
+from typing import List, Tuple, Optional, Union
+
+import torch
+import torch.nn.functional as F
+
+from .helpers import to_2tuple
+
+_logger = logging.getLogger(__name__)
+
+
+def resample_abs_pos_embed(
+        posemb,
+        new_size: List[int],
+        old_size: Optional[List[int]] = None,
+        num_prefix_tokens: int = 1,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+        verbose: bool = False,
+):
+    # sort out sizes, assume square if old size not provided
+    new_size = to_2tuple(new_size)
+    new_ntok = new_size[0] * new_size[1]
+    if not old_size:
+        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
+    old_size = to_2tuple(old_size)
+    if new_size == old_size:  # might not both be same container type
+        return posemb
+
+    if num_prefix_tokens:
+        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
+    else:
+        posemb_prefix, posemb = None, posemb
+
+    # do the interpolation
+    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
+    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
+    posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
+
+    if verbose:
+        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
+
+    # add back extra (class, etc) prefix tokens
+    if posemb_prefix is not None:
+        print(posemb_prefix.shape, posemb.shape)
+        posemb = torch.cat([posemb_prefix, posemb], dim=1)
+    return posemb
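A minimal usage sketch for resample_abs_pos_embed (editor's illustration, not from the commit; the embedding tensor is a random placeholder):

import torch
from timm.layers import resample_abs_pos_embed

# interpolate a ViT absolute position embedding from a 14x14 grid (224/16)
# to 24x24 (384/16), keeping the single class-token entry untouched
pos_embed = torch.randn(1, 1 + 14 * 14, 768)   # [1, prefix + H*W, C]
resized = resample_abs_pos_embed(pos_embed, new_size=(24, 24), num_prefix_tokens=1)
print(resized.shape)  # torch.Size([1, 577, 768])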
timm/layers/pos_embed_rel.py (new file, 283 lines)

""" Relative position embedding modules and functions

Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import Mlp
from .weight_init import trunc_normal_


def gen_relative_position_index(
        q_size: Tuple[int, int],
        k_size: Tuple[int, int] = None,
        class_token: bool = False) -> torch.Tensor:
    # Adapted with significant modifications from Swin / BeiT codebases
    # get pair-wise relative position index for each token inside the window
    q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1)  # 2, Wh, Ww
    if k_size is None:
        k_coords = q_coords
        k_size = q_size
    else:
        # different q vs k sizes is a WIP
        k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
    relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0)  # Wh*Ww, Wh*Ww, 2
    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)

    if class_token:
        # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
        # NOTE not intended or tested with MLP log-coords
        max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
        num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
        relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
        relative_position_index[0, 0:] = num_relative_distance - 3
        relative_position_index[0:, 0] = num_relative_distance - 2
        relative_position_index[0, 0] = num_relative_distance - 1

    return relative_position_index.contiguous()


class RelPosBias(nn.Module):
    """ Relative Position Bias
    Adapted from Swin-V1 relative position bias impl, modularized.
    """

    def __init__(self, window_size, num_heads, prefix_tokens=0):
        super().__init__()
        assert prefix_tokens <= 1
        self.window_size = window_size
        self.window_area = window_size[0] * window_size[1]
        self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)

        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
        self.register_buffer(
            "relative_position_index",
            gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
            persistent=False,
        )

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.relative_position_bias_table, std=.02)

    def get_bias(self) -> torch.Tensor:
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        # win_h * win_w, win_h * win_w, num_heads
        relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
        return relative_position_bias.unsqueeze(0).contiguous()

    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
        return attn + self.get_bias()


def gen_relative_log_coords(
        win_size: Tuple[int, int],
        pretrained_win_size: Tuple[int, int] = (0, 0),
        mode='swin',
):
    assert mode in ('swin', 'cr', 'rw')
    # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
    relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
    relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
    if mode == 'swin':
        if pretrained_win_size[0] > 0:
            relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
            relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
        else:
            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            1.0 + relative_coords_table.abs()) / math.log2(8)
    else:
        if mode == 'rw':
            # cr w/ window size normalization -> [-1,1] log coords
            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
            relative_coords_table *= 8  # scale to -8, 8
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                1.0 + relative_coords_table.abs())
            relative_coords_table /= math.log2(9)   # -> [-1, 1]
        else:
            # mode == 'cr'
            relative_coords_table = torch.sign(relative_coords_table) * torch.log(
                1.0 + relative_coords_table.abs())

    return relative_coords_table


class RelPosMlp(nn.Module):
    """ Log-Coordinate Relative Position MLP
    Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)

    This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw')
    """
    def __init__(
            self,
            window_size,
            num_heads=8,
            hidden_dim=128,
            prefix_tokens=0,
            mode='cr',
            pretrained_window_size=(0, 0)
    ):
        super().__init__()
        self.window_size = window_size
        self.window_area = self.window_size[0] * self.window_size[1]
        self.prefix_tokens = prefix_tokens
        self.num_heads = num_heads
        self.bias_shape = (self.window_area,) * 2 + (num_heads,)
        if mode == 'swin':
            self.bias_act = nn.Sigmoid()
            self.bias_gain = 16
            mlp_bias = (True, False)
        elif mode == 'rw':
            self.bias_act = nn.Tanh()
            self.bias_gain = 4
            mlp_bias = True
        else:
            self.bias_act = nn.Identity()
            self.bias_gain = None
            mlp_bias = True

        self.mlp = Mlp(
            2,  # x, y
            hidden_features=hidden_dim,
            out_features=num_heads,
            act_layer=nn.ReLU,
            bias=mlp_bias,
            drop=(0.125, 0.)
        )

        self.register_buffer(
            "relative_position_index",
            gen_relative_position_index(window_size),
            persistent=False)

        # get relative_coords_table
        self.register_buffer(
            "rel_coords_log",
            gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
            persistent=False)

    def get_bias(self) -> torch.Tensor:
        relative_position_bias = self.mlp(self.rel_coords_log)
        if self.relative_position_index is not None:
            relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
                self.relative_position_index.view(-1)]  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.view(self.bias_shape)
        relative_position_bias = relative_position_bias.permute(2, 0, 1)
        relative_position_bias = self.bias_act(relative_position_bias)
        if self.bias_gain is not None:
            relative_position_bias = self.bias_gain * relative_position_bias
        if self.prefix_tokens:
            relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
        return relative_position_bias.unsqueeze(0).contiguous()

    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
        return attn + self.get_bias()


def generate_lookup_tensor(
        length: int,
        max_relative_position: Optional[int] = None,
):
    """Generate a one_hot lookup tensor to reindex embeddings along one dimension.

    Args:
        length: the length to reindex to.
        max_relative_position: the maximum relative position to consider.
            Relative position embeddings for distances above this threshold
            are zeroed out.
    Returns:
        a lookup Tensor of size [length, length, vocab_size] that satisfies
            ret[n,m,v] = 1{m - n + max_relative_position = v}.
    """
    if max_relative_position is None:
        max_relative_position = length - 1
    # Return the cached lookup tensor, otherwise compute it and cache it.
    vocab_size = 2 * max_relative_position + 1
    ret = torch.zeros(length, length, vocab_size)
    for i in range(length):
        for x in range(length):
            v = x - i + max_relative_position
            if abs(x - i) > max_relative_position:
                continue
            ret[i, x, v] = 1
    return ret


def reindex_2d_einsum_lookup(
        relative_position_tensor,
        height: int,
        width: int,
        height_lookup: torch.Tensor,
        width_lookup: torch.Tensor,
) -> torch.Tensor:
    """Reindex 2d relative position bias with 2 independent einsum lookups.

    Adapted from:
     https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py

    Args:
        relative_position_tensor: tensor of shape
            [..., vocab_height, vocab_width, ...].
        height: height to reindex to.
        width: width to reindex to.
        height_lookup: one-hot height lookup
        width_lookup: one-hot width lookup
    Returns:
        reindexed_tensor: a Tensor of shape
            [..., height * width, height * width, ...]
    """
    reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
    reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
    area = height * width
    return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)


class RelPosBiasTf(nn.Module):
    """ Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
    Adapted from:
     https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
    """
    def __init__(self, window_size, num_heads, prefix_tokens=0):
        super().__init__()
        assert prefix_tokens <= 1
        self.window_size = window_size
        self.window_area = window_size[0] * window_size[1]
        self.num_heads = num_heads

        vocab_height = 2 * window_size[0] - 1
        vocab_width = 2 * window_size[1] - 1
        self.bias_shape = (self.num_heads, vocab_height, vocab_width)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
        self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
        self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.relative_position_bias_table, std=.02)

    def get_bias(self) -> torch.Tensor:
        # FIXME change to not use one-hot/einsum?
        return reindex_2d_einsum_lookup(
            self.relative_position_bias_table,
            self.window_size[0],
            self.window_size[1],
            self.height_lookup,
            self.width_lookup
        )

    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
        return attn + self.get_bias()
timm/layers/pos_embed_sincos.py (new file, 219 lines)

""" Sin-cos, fourier, rotary position embedding modules and functions

Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import List, Tuple, Optional, Union

import torch
from torch import nn as nn


def pixel_freq_bands(
        num_bands: int,
        max_freq: float = 224.,
        linear_bands: bool = True,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
):
    if linear_bands:
        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
    else:
        bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
    return bands * torch.pi


def inv_freq_bands(
        num_bands: int,
        temperature: float = 100000.,
        step: int = 2,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
    return inv_freq


def build_sincos2d_pos_embed(
        feat_shape: List[int],
        dim: int = 64,
        temperature: float = 10000.,
        reverse_coord: bool = False,
        interleave_sin_cos: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None
) -> torch.Tensor:
    """

    Args:
        feat_shape:
        dim:
        temperature:
        reverse_coord: stack grid order W, H instead of H, W
        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
        dtype:
        device:

    Returns:

    """
    assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
    pos_dim = dim // 4
    bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)

    if reverse_coord:
        feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
    grid = torch.stack(
        torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
    # FIXME add support for unflattened spatial dim?

    stack_dim = 2 if interleave_sin_cos else 1  # stack sin, cos, sin, cos instead of sin sin cos cos
    pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
    return pos_emb


def build_fourier_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        num_bands: int = 64,
        max_res: int = 224,
        linear_bands: bool = False,
        include_grid: bool = False,
        concat_out: bool = True,
        in_pixels: bool = True,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    if bands is None:
        if in_pixels:
            bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
        else:
            bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
    else:
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype

    if in_pixels:
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
    else:
        grid = torch.stack(torch.meshgrid(
            [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
    grid = grid.unsqueeze(-1)
    pos = grid * bands

    pos_sin, pos_cos = pos.sin(), pos.cos()
    out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
    # FIXME torchscript doesn't like multiple return types, probably need to always cat?
    if concat_out:
        out = torch.cat(out, dim=-1)
    return out


class FourierEmbed(nn.Module):

    def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)

    def forward(self, x):
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device)
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
        else:
            x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
            x = x.reshape(B, feat_shape.numel(), -1)

        return x


def rot(x):
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
    return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]


def apply_rot_embed_split(x: torch.Tensor, emb):
    split = emb.shape[-1] // 2
    return x * emb[:, :split] + rot(x) * emb[:, split:]


def build_rotary_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        dim: int = 64,
        max_freq: float = 224,
        linear_bands: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
):
    """
    NOTE: shape arg should include spatial dim only
    """
    feat_shape = torch.Size(feat_shape)

    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,
        max_res=max_freq,
        linear_bands=linear_bands,
        concat_out=False,
        device=device,
        dtype=dtype,
    )
    N = feat_shape.numel()
    sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb


class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
    been well tested, and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(self, dim, max_res=224, linear_bands: bool = False):
        super().__init__()
        self.dim = dim
        self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)

    def get_embed(self, shape: List[int]):
        return build_rotary_pos_embed(shape, self.bands)

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
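A minimal usage sketch for the sin-cos and rotary embedding helpers (editor's illustration, not from the new file; all sizes and tensors are placeholders):

import torch
from timm.layers import build_sincos2d_pos_embed, build_rotary_pos_embed
from timm.layers.pos_embed_sincos import apply_rot_embed

# fixed 2D sin-cos position table for an 8x8 token grid, 256-dim
pos = build_sincos2d_pos_embed(feat_shape=[8, 8], dim=256)
print(pos.shape)  # torch.Size([64, 256])

# rotary sin/cos tables for a 7x7 grid, applied to (B, N, dim) tokens
sin_emb, cos_emb = build_rotary_pos_embed(feat_shape=[7, 7], dim=64)
q = torch.randn(2, 7 * 7, 64)
q = apply_rot_embed(q, sin_emb, cos_emb)
print(q.shape)    # torch.Size([2, 49, 64])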
timm/models/_builder.py

@@ -1,5 +1,6 @@
 import dataclasses
 import logging
+import os
 from copy import deepcopy
 from typing import Optional, Dict, Callable, Any, Tuple

@@ -9,7 +10,7 @@ from torch.hub import load_state_dict_from_url
 from timm.models._features import FeatureListNet, FeatureHookNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -32,6 +33,7 @@ def _resolve_pretrained_source(pretrained_cfg):
     pretrained_url = pretrained_cfg.get('url', None)
     pretrained_file = pretrained_cfg.get('file', None)
     hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
+
     # resolve where to load pretrained weights from
     load_from = ''
     pretrained_loc = ''
@@ -43,15 +45,20 @@ def _resolve_pretrained_source(pretrained_cfg):
     else:
         # default source == timm or unspecified
         if pretrained_file:
+            # file load override is the highest priority if set
             load_from = 'file'
             pretrained_loc = pretrained_file
-        elif pretrained_url:
-            load_from = 'url'
-            pretrained_loc = pretrained_url
-        elif hf_hub_id and has_hf_hub(necessary=True):
-            # hf-hub available as alternate weight source in default_cfg
-            load_from = 'hf-hub'
-            pretrained_loc = hf_hub_id
+        else:
+            # next, HF hub is prioritized unless a valid cached version of weights exists already
+            cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
+            if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
+                # hf-hub available as alternate weight source in default_cfg
+                load_from = 'hf-hub'
+                pretrained_loc = hf_hub_id
+            elif pretrained_url:
+                load_from = 'url'
+                pretrained_loc = pretrained_url

     if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
         # if a filename override is set, return tuple for location w/ (hub_id, filename)
         pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
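An illustration of the new resolution order (editor's sketch, not from the commit; the helper below simplifies the real logic by omitting the has_hf_hub() availability check):

def resolve_order(cfg, cached_url_valid):
    # explicit local file always wins; otherwise prefer the HF hub unless a valid,
    # hash-checked download of the url weights already sits in the local cache
    if cfg.get('file'):
        return 'file'
    if cfg.get('hf_hub_id') and not cached_url_valid:
        return 'hf-hub'
    if cfg.get('url'):
        return 'url'
    return ''

print(resolve_order({'file': 'w.pth', 'hf_hub_id': 'org/model'}, False))           # file
print(resolve_order({'hf_hub_id': 'org/model', 'url': 'https://x/w.pth'}, False))  # hf-hub
print(resolve_order({'hf_hub_id': 'org/model', 'url': 'https://x/w.pth'}, True))   # url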
@@ -105,7 +112,7 @@ def load_custom_pretrained(
         pretrained_loc = download_cached_file(
             pretrained_loc,
             check_hash=_CHECK_HASH,
-            progress=_DOWNLOAD_PROGRESS
+            progress=_DOWNLOAD_PROGRESS,
         )

     if load_fn is not None:
@@ -146,6 +153,15 @@ def load_pretrained(
         state_dict = load_state_dict(pretrained_loc)
     elif load_from == 'url':
         _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
+        if pretrained_cfg.get('custom_load', False):
+            pretrained_loc = download_cached_file(
+                pretrained_loc,
+                progress=_DOWNLOAD_PROGRESS,
+                check_hash=_CHECK_HASH,
+            )
+            model.load_pretrained(pretrained_loc)
+            return
+        else:
             state_dict = load_state_dict_from_url(
                 pretrained_loc,
                 map_location='cpu',
@@ -364,12 +380,6 @@ def build_model_with_cfg(
     # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
     if pretrained:
-        if pretrained_cfg.get('custom_load', False):
-            load_custom_pretrained(
-                model,
-                pretrained_cfg=pretrained_cfg,
-            )
-        else:
-            load_pretrained(
-                model,
-                pretrained_cfg=pretrained_cfg,
+        load_pretrained(
+            model,
+            pretrained_cfg=pretrained_cfg,
timm/models/_hub.py

@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import os
@@ -67,6 +68,26 @@ def download_cached_file(url, check_hash=True, progress=False):
     return cached_file


+def check_cached_file(url, check_hash=True):
+    if isinstance(url, (list, tuple)):
+        url, filename = url
+    else:
+        parts = urlparse(url)
+        filename = os.path.basename(parts.path)
+    cached_file = os.path.join(get_cache_dir(), filename)
+    if os.path.exists(cached_file):
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+            if hash_prefix:
+                with open(cached_file, 'rb') as f:
+                    hd = hashlib.sha256(f.read()).hexdigest()
+                if hd[:len(hash_prefix)] != hash_prefix:
+                    return False
+        return True
+    return False
+
+
 def has_hf_hub(necessary=False):
     if not _has_hf_hub and necessary:
         # if no HF Hub module installed, and it is necessary to continue, raise error
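A minimal usage sketch for check_cached_file (editor's illustration, not from the commit; the URL is a placeholder whose filename carries a sha256 prefix in the usual torch hub style):

from timm.models._hub import check_cached_file

url = 'https://example.com/weights/some_model-0c4f44f7.pth'
if check_cached_file(url):
    print('valid cached copy exists, no download needed')
else:
    print('not cached yet, or the cached file failed the hash check')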
@@ -90,14 +111,14 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
     return json.loads(text)


-def _download_from_hf(model_id: str, filename: str):
+def download_from_hf(model_id: str, filename: str):
     hf_model_id, hf_revision = hf_split(model_id)
     return hf_hub_download(hf_model_id, filename, revision=hf_revision)


 def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
-    cached_file = _download_from_hf(model_id, 'config.json')
+    cached_file = download_from_hf(model_id, 'config.json')

     hf_config = load_cfg_from_json(cached_file)
     if 'pretrained_cfg' not in hf_config:
@@ -124,34 +145,28 @@ def load_model_config_from_hf(model_id: str):

 def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
     assert has_hf_hub(True)
-    cached_file = _download_from_hf(model_id, filename)
+    cached_file = download_from_hf(model_id, filename)
     state_dict = torch.load(cached_file, map_location='cpu')
     return state_dict


-def save_for_hf(model, save_directory, model_config=None):
-    assert has_hf_hub(True)
+def save_config_for_hf(model, config_path, model_config=None):
     model_config = model_config or {}
-    save_directory = Path(save_directory)
-    save_directory.mkdir(exist_ok=True, parents=True)
-
-    weights_path = save_directory / 'pytorch_model.bin'
-    torch.save(model.state_dict(), weights_path)
-
-    config_path = save_directory / 'config.json'
     hf_config = {}
     pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
     # set some values at root config level
     hf_config['architecture'] = pretrained_cfg.pop('architecture')
     hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
     hf_config['num_features'] = model_config.get('num_features', model.num_features)
-    hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
-    if 'label' in model_config:
+    global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
+    if isinstance(global_pool_type, str) and global_pool_type:
+        hf_config['global_pool'] = global_pool_type
+
+    if 'labels' in model_config:
         _logger.warning(
-            "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
+            "'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
             "Using provided 'label' field as 'label_name'.")
-        model_config['label_name'] = model_config.pop('label')
+        model_config['label_name'] = model_config.pop('labels')

     label_name = model_config.pop('label_name', None)
     if label_name:
@ -173,6 +188,18 @@ def save_for_hf(model, save_directory, model_config=None):
|
|||||||
json.dump(hf_config, f, indent=2)
|
json.dump(hf_config, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def save_for_hf(model, save_directory, model_config=None):
|
||||||
|
assert has_hf_hub(True)
|
||||||
|
save_directory = Path(save_directory)
|
||||||
|
save_directory.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
weights_path = save_directory / 'pytorch_model.bin'
|
||||||
|
torch.save(model.state_dict(), weights_path)
|
||||||
|
|
||||||
|
config_path = save_directory / 'config.json'
|
||||||
|
save_config_for_hf(model, config_path, model_config=model_config)
|
||||||
|
|
||||||
|
|
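For context, a hedged sketch of the new split in use (not from the diff): save_for_hf() writes the weights and then delegates config.json to save_config_for_hf(). It assumes huggingface_hub is installed and that the created model carries a pretrained_cfg, as timm models built via build_model_with_cfg do.

import timm
from timm.models._hub import save_for_hf

# Any timm model works here; resnet18 is just a small example, the output path is arbitrary.
model = timm.create_model('resnet18', pretrained=False, num_classes=10)
save_for_hf(model, 'exported/resnet18-10cls')   # writes pytorch_model.bin + config.json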
 def push_to_hf_hub(
     model,
     repo_id: str,
@@ -19,6 +19,7 @@ class PretrainedCfg:
     source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
     architecture: Optional[str] = None  # architecture variant can be set when not implicit
+    tag: Optional[str] = None  # pretrained tag of source
     custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)
 
     # input / data config
@@ -44,9 +45,11 @@ class PretrainedCfg:
     classifier: Optional[str] = None
 
     license: Optional[str] = None
-    source_url: Optional[str] = None
-    paper: Optional[str] = None
-    notes: Optional[str] = None
+    description: Optional[str] = None
+    origin_url: Optional[str] = None
+    paper_name: Optional[str] = None
+    paper_ids: Optional[Union[str, Tuple[str]]] = None
+    notes: Optional[Tuple[str]] = None
 
     @property
     def has_weights(self):
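A hedged illustration of the renamed metadata fields (field names taken from the diff above; the values are invented for illustration):

from dataclasses import asdict
from timm.models._pretrained import PretrainedCfg

cfg = PretrainedCfg(
    url='https://example.com/weights.pth',        # placeholder weight location
    architecture='convnext_tiny',
    tag='fb_in1k',
    paper_name='A ConvNet for the 2020s',
    paper_ids='arXiv:2201.03545',
    origin_url='https://github.com/facebookresearch/ConvNeXt',
)
print(cfg.tag, cfg.paper_name)
print(asdict(cfg)['origin_url'])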
@@ -62,11 +65,11 @@ class PretrainedCfg:
 
 def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
     filtered_cfg = {}
-    keep_none = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if none
+    keep_null = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if none
     for k, v in cfg.items():
         if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
             continue
-        if remove_null and v is None and k not in keep_none:
+        if remove_null and v is None and k not in keep_null:
             continue
         filtered_cfg[k] = v
     return filtered_cfg

@@ -7,6 +7,7 @@ import re
 import sys
 from collections import defaultdict, deque
 from copy import deepcopy
+from dataclasses import replace
 from typing import List, Optional, Union, Tuple
 
 from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@@ -20,7 +21,7 @@ _model_to_module = {}  # mapping of model names to module names
 _model_entrypoints = {}  # mapping of model names to architecture entrypoint fns
 _model_has_pretrained = set()  # set of model names that have pretrained weight url present
 _model_default_cfgs = dict()  # central repo for model arch -> default cfg objects
-_model_pretrained_cfgs = dict()  # central repo for model arch + tag -> pretrained cfgs
+_model_pretrained_cfgs = dict()  # central repo for model arch.tag -> pretrained cfgs
 _model_with_tags = defaultdict(list)  # shortcut to map each model arch to all model + tag names
 
@@ -48,24 +49,31 @@ def register_model(fn):
     if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
-        cfg = mod.default_cfgs[model_name]
-        if not isinstance(cfg, DefaultCfg):
+        default_cfg = mod.default_cfgs[model_name]
+        if not isinstance(default_cfg, DefaultCfg):
             # new style default cfg dataclass w/ multiple entries per model-arch
-            assert isinstance(cfg, dict)
+            assert isinstance(default_cfg, dict)
             # old style cfg dict per model-arch
-            cfg = PretrainedCfg(**cfg)
-            cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
+            pretrained_cfg = PretrainedCfg(**default_cfg)
+            default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})
 
-        for tag_idx, tag in enumerate(cfg.tags):
+        for tag_idx, tag in enumerate(default_cfg.tags):
             is_default = tag_idx == 0
-            pretrained_cfg = cfg.cfgs[tag]
+            pretrained_cfg = default_cfg.cfgs[tag]
+            model_name_tag = '.'.join([model_name, tag]) if tag else model_name
+            replace_items = dict(architecture=model_name, tag=tag if tag else None)
+            if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
+                # auto-complete hub name w/ architecture.tag
+                replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
+            pretrained_cfg = replace(pretrained_cfg, **replace_items)
+
             if is_default:
                 _model_pretrained_cfgs[model_name] = pretrained_cfg
                 if pretrained_cfg.has_weights:
                     # add tagless entry if it's default and has weights
                     _model_has_pretrained.add(model_name)
 
             if tag:
-                model_name_tag = '.'.join([model_name, tag])
                 _model_pretrained_cfgs[model_name_tag] = pretrained_cfg
                 if pretrained_cfg.has_weights:
                     # add model w/ tag if tag is valid
@@ -74,7 +82,7 @@ def register_model(fn):
             else:
                 _model_with_tags[model_name].append(model_name)  # has empty tag (to slowly remove these instances)
 
-        _model_default_cfgs[model_name] = cfg
+        _model_default_cfgs[model_name] = default_cfg
 
     return fn
 
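A hedged, standalone sketch of the 'timm/' auto-complete logic added above, mirroring the register_model() snippet with dataclasses.replace(); the architecture and tag values are examples.

from dataclasses import replace
from timm.models._pretrained import PretrainedCfg

pretrained_cfg = PretrainedCfg(hf_hub_id='timm/')
model_name, tag = 'convnext_tiny', 'fb_in1k'          # example arch + tag
model_name_tag = '.'.join([model_name, tag]) if tag else model_name

replace_items = dict(architecture=model_name, tag=tag if tag else None)
if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
    # auto-complete hub name w/ architecture.tag, as register_model() now does
    replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
pretrained_cfg = replace(pretrained_cfg, **replace_items)

print(pretrained_cfg.hf_hub_id)   # -> 'timm/convnext_tiny.fb_in1k'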
@@ -198,15 +206,21 @@ def is_model_pretrained(model_name):
     return model_name in _model_has_pretrained
 
 
-def get_pretrained_cfg(model_name):
+def get_pretrained_cfg(model_name, allow_unregistered=True):
     if model_name in _model_pretrained_cfgs:
         return deepcopy(_model_pretrained_cfgs[model_name])
-    raise RuntimeError(f'No pretrained config exists for model {model_name}.')
+    arch_name, tag = split_model_name_tag(model_name)
+    if arch_name in _model_default_cfgs:
+        # if model arch exists, but the tag is wrong, error out
+        raise RuntimeError(f'Invalid pretrained tag ({tag}) for {arch_name}.')
+    if allow_unregistered:
+        # if model arch doesn't exist, it has no pretrained_cfg registered, allow a default to be created
+        return None
+    raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
 
 
 def get_pretrained_cfg_value(model_name, cfg_key):
     """ Get a specific model default_cfg value by key. None if key doesn't exist.
     """
-    if model_name in _model_pretrained_cfgs:
-        return getattr(_model_pretrained_cfgs[model_name], cfg_key, None)
-    raise RuntimeError(f'No pretrained config exist for model {model_name}.')
+    cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
+    return getattr(cfg, cfg_key, None)
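A hedged usage sketch of the updated lookup behaviour (imports taken from the registry module shown above; the model names are examples of registered and unregistered entries):

from timm.models._registry import get_pretrained_cfg, get_pretrained_cfg_value

cfg = get_pretrained_cfg('convnext_tiny.fb_in1k')               # arch.tag -> PretrainedCfg copy
print(get_pretrained_cfg_value('convnext_tiny.fb_in1k', 'num_classes'))
print(get_pretrained_cfg('some_unregistered_model'))            # unknown arch -> None by default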

@@ -355,64 +355,76 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = generate_default_cfgs({
     'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth',
+        hf_hub_id='timm/'),
     'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), crop_pct=1.0,
     ),
     'beit_base_patch16_224.in22k_ft_in22k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
+        hf_hub_id='timm/',
         num_classes=21841,
     ),
     'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
-        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth',
+        hf_hub_id='timm/'),
     'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), crop_pct=1.0,
     ),
     'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
+        hf_hub_id='timm/',
         input_size=(3, 512, 512), crop_pct=1.0,
     ),
     'beit_large_patch16_224.in22k_ft_in22k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
+        hf_hub_id='timm/',
         num_classes=21841,
     ),
 
     'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
     ),
     'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
-        num_classes=21841,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+        hf_hub_id='timm/',
+        num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
     ),
     'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
-        crop_pct=0.95,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+        hf_hub_id='timm/',
+        crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
     ),
     'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
-        num_classes=21841,
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+        hf_hub_id='timm/',
+        num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
     ),
 
     'eva_giant_patch14_224.clip_ft_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
+        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
     ),
     'eva_giant_patch14_336.clip_ft_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
+        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
     'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
+        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
     'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
+        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
 })
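A hedged example of what the hf_hub_id='timm/' entries enable (not part of the diff): pretrained weights for the tagged configs above now resolve through the Hugging Face Hub, so an ordinary create_model call is enough. Requires network access and huggingface_hub.

import timm

model = timm.create_model('beit_base_patch16_224.in22k_ft_in22k_in1k', pretrained=True)
model = model.eval()
print(round(sum(p.numel() for p in model.parameters()) / 1e6, 1), 'M params')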

@@ -361,7 +361,6 @@ def _create_convnext(variant, pretrained=False, **kwargs):
     return model
 
 
-
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -375,90 +374,131 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = generate_default_cfgs({
     # timm specific variants
-    'convnext_atto.timm_in1k': _cfg(
+    'convnext_atto.d2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_atto_ols.timm_in1k': _cfg(
+    'convnext_atto_ols.a2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_femto.timm_in1k': _cfg(
+    'convnext_femto.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_femto_ols.timm_in1k': _cfg(
+    'convnext_femto_ols.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_pico.timm_in1k': _cfg(
+    'convnext_pico.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_pico_ols.timm_in1k': _cfg(
+    'convnext_pico_ols.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_nano.timm_in1k': _cfg(
+    'convnext_nano.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_nano.d1h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_nano_ols.timm_in1k': _cfg(
+    'convnext_nano_ols.d1h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_tiny_hnf.timm_in1k': _cfg(
+    'convnext_tiny_hnf.a2h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
+    'convnext_nano.in12k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95, num_classes=11821),
+
     'convnext_tiny.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_small.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_base.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_large.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_xlarge.untrained': _cfg(),
+    'convnext_xxlarge.untrained': _cfg(),
 
     'convnext_tiny.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_small.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_base.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_large.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
     'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
-    'convnext_small..fb_in22k_ft_in1k_384': _cfg(
+    'convnext_small.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_base.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_large.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
 
-    'convnext_tiny_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
-    'convnext_small_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
-    'convnext_base_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
-    'convnext_large_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
-    'convnext_xlarge_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
+    'convnext_tiny.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_small.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_base.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_large.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_xlarge.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
 })
 
 
@@ -576,3 +616,10 @@ def convnext_xlarge(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
     model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args)
     return model
 
 
+@register_model
+def convnext_xxlarge(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs)
+    model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args)
+    return model
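A hedged example of the new ConvNeXt entries (not part of the diff): the in12k tags and the xxlarge config can be instantiated like any other timm model; pretrained=True needs network access.

import timm

# 12k-pretrained nano fine-tuned on ImageNet-1k, resolved via hf_hub_id='timm/'
m1 = timm.create_model('convnext_nano.in12k_ft_in1k', pretrained=True)
print(m1.num_classes)            # 1000

# the newly registered xxlarge arch; 'untrained' tag, so weights stay randomly initialized
m2 = timm.create_model('convnext_xxlarge', pretrained=False)
print(round(sum(p.numel() for p in m2.parameters()) / 1e6), 'M params')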

(File diff suppressed because it is too large)
@@ -2,6 +2,7 @@
 from timm.layers.activations import *
 from timm.layers.adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from timm.layers.blur_pool import BlurPool2d
 from timm.layers.classifier import ClassifierHead, create_classifier
 from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer

@@ -47,16 +47,15 @@ import torch
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm
-from timm.layers import SelectAdaptivePool2d, create_pool2d
-from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d
-from timm.layers import to_2tuple, extend_tuple, make_divisible, _assert
+from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, LayerNorm, SelectAdaptivePool2d
+from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
+from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
+from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply, checkpoint_seq
 from ._pretrained import generate_default_cfgs
 from ._registry import register_model
-from .vision_transformer_relpos import RelPosMlp, RelPosBias  # FIXME move these to common location
 
 __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
 
@@ -1076,93 +1075,6 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
     return cfg
 
 
-def generate_lookup_tensor(
-        length: int,
-        max_relative_position: Optional[int] = None,
-):
-    """Generate a one_hot lookup tensor to reindex embeddings along one dimension.
-    Args:
-        length: the length to reindex to.
-        max_relative_position: the maximum relative position to consider.
-            Relative position embeddings for distances above this threshold
-            are zeroed out.
-    Returns:
-        a lookup Tensor of size [length, length, vocab_size] that satisfies
-            ret[n,m,v] = 1{m - n + max_relative_position = v}.
-    """
-    if max_relative_position is None:
-        max_relative_position = length - 1
-    # Return the cached lookup tensor, otherwise compute it and cache it.
-    vocab_size = 2 * max_relative_position + 1
-    ret = torch.zeros(length, length, vocab_size)
-    for i in range(length):
-        for x in range(length):
-            v = x - i + max_relative_position
-            if abs(x - i) > max_relative_position:
-                continue
-            ret[i, x, v] = 1
-    return ret
-
-
-def reindex_2d_einsum_lookup(
-        relative_position_tensor,
-        height: int,
-        width: int,
-        height_lookup: torch.Tensor,
-        width_lookup: torch.Tensor,
-) -> torch.Tensor:
-    """Reindex 2d relative position bias with 2 independent einsum lookups.
-    Args:
-        relative_position_tensor: tensor of shape
-            [..., vocab_height, vocab_width, ...].
-        height: height to reindex to.
-        width: width to reindex to.
-        height_lookup: one-hot height lookup
-        width_lookup: one-hot width lookup
-    Returns:
-        reindexed_tensor: a Tensor of shape
-            [..., height * width, height * width, ...]
-    """
-    reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
-    reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
-    area = height * width
-    return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
-
-
-class RelPosBiasTf(nn.Module):
-
-    def __init__(self, window_size, num_heads, prefix_tokens=0):
-        super().__init__()
-        assert prefix_tokens <= 1
-        self.window_size = window_size
-        self.window_area = window_size[0] * window_size[1]
-        self.num_heads = num_heads
-
-        vocab_height = 2 * window_size[0] - 1
-        vocab_width = 2 * window_size[1] - 1
-        self.bias_shape = (self.num_heads, vocab_height, vocab_width)
-        self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
-        self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
-        self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
-        self.init_weights()
-
-    def init_weights(self):
-        nn.init.normal_(self.relative_position_bias_table, std=.02)
-
-    def get_bias(self) -> torch.Tensor:
-        # FIXME change to not use one-hot/einsum?
-        return reindex_2d_einsum_lookup(
-            self.relative_position_bias_table,
-            self.window_size[0],
-            self.window_size[1],
-            self.height_lookup,
-            self.width_lookup
-        )
-
-    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
-        return attn + self.get_bias()
-
-
 class NormMlpHead(nn.Module):
 
     def __init__(
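The block removed above (now imported from timm.layers as RelPosBiasTf and friends) builds a one-hot lookup with ret[n, m, v] = 1 iff m - n + max_relative_position == v, so an einsum against it turns a relative-position table into per-(query, key) biases. A hedged, self-contained sketch of that property for one axis:

import torch

def one_hot_rel_pos_lookup(length: int) -> torch.Tensor:
    # same contract as the removed generate_lookup_tensor() with its default max_relative_position
    max_rel = length - 1
    ret = torch.zeros(length, length, 2 * max_rel + 1)
    for i in range(length):
        for x in range(length):
            ret[i, x, x - i + max_rel] = 1
    return ret

lookup = one_hot_rel_pos_lookup(4)                           # (len, len, 2*len - 1)
rel_table = torch.randn(8, 7)                                # e.g. heads x relative vocab for one axis
abs_bias = torch.einsum('nv,qkv->nqk', rel_table, lookup)    # heads x len x len
print(abs_bias.shape)                                        # torch.Size([8, 4, 4])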
@@ -21,93 +21,12 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks
 from ._manipulate import checkpoint_seq
+from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 
 __all__ = ['MobileNetV3', 'MobileNetV3Features']
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'conv_stem', 'classifier': 'classifier',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'mobilenetv3_large_075': _cfg(url=''),
-    'mobilenetv3_large_100': _cfg(
-        interpolation='bicubic',
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
-    'mobilenetv3_large_100_miil': _cfg(
-        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth'),
-    'mobilenetv3_large_100_miil_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
-        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
-
-    'mobilenetv3_small_050': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
-        interpolation='bicubic'),
-    'mobilenetv3_small_075': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
-        interpolation='bicubic'),
-    'mobilenetv3_small_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
-        interpolation='bicubic'),
-
-    'mobilenetv3_rw': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
-        interpolation='bicubic'),
-
-    'tf_mobilenetv3_large_075': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_large_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_large_minimal_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_small_075': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_small_100': _cfg(
-        url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-    'tf_mobilenetv3_small_minimal_100': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
-
-    'fbnetv3_b': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
-        test_input_size=(3, 256, 256), crop_pct=0.95),
-    'fbnetv3_d': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
-        test_input_size=(3, 256, 256), crop_pct=0.95),
-    'fbnetv3_g': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
-        input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
-
-    "lcnet_035": _cfg(),
-    "lcnet_050": _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
-        interpolation='bicubic',
-    ),
-    "lcnet_075": _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
-        interpolation='bicubic',
-    ),
-    "lcnet_100": _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
-        interpolation='bicubic',
-    ),
-    "lcnet_150": _cfg(),
-}
-
-
 class MobileNetV3(nn.Module):
     """ MobiletNet-V3
 
@@ -124,9 +43,24 @@ class MobileNetV3(nn.Module):
     """
 
     def __init__(
-            self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
-            head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
-            round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
+            self,
+            block_args,
+            num_classes=1000,
+            in_chans=3,
+            stem_size=16,
+            fix_stem=False,
+            num_features=1280,
+            head_bias=True,
+            pad_type='',
+            act_layer=None,
+            norm_layer=None,
+            se_layer=None,
+            se_from_exp=True,
+            round_chs_fn=round_channels,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            global_pool='avg',
+    ):
         super(MobileNetV3, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
@@ -145,8 +79,15 @@ class MobileNetV3(nn.Module):
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
-            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
+            output_stride=32,
+            pad_type=pad_type,
+            round_chs_fn=round_chs_fn,
+            se_from_exp=se_from_exp,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            se_layer=se_layer,
+            drop_path_rate=drop_path_rate,
+        )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
         head_chs = builder.in_chs
@@ -225,9 +166,23 @@ class MobileNetV3Features(nn.Module):
     """
 
     def __init__(
-            self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
-            stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
-            se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
+            self,
+            block_args,
+            out_indices=(0, 1, 2, 3, 4),
+            feature_location='bottleneck',
+            in_chans=3,
+            stem_size=16,
+            fix_stem=False,
+            output_stride=32,
+            pad_type='',
+            round_chs_fn=round_channels,
+            se_from_exp=True,
+            act_layer=None,
+            norm_layer=None,
+            se_layer=None,
+            drop_rate=0.,
+            drop_path_rate=0.,
+    ):
         super(MobileNetV3Features, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
@@ -243,9 +198,16 @@ class MobileNetV3Features(nn.Module):
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
-            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
-            drop_path_rate=drop_path_rate, feature_location=feature_location)
+            output_stride=output_stride,
+            pad_type=pad_type,
+            round_chs_fn=round_chs_fn,
+            se_from_exp=se_from_exp,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            se_layer=se_layer,
+            drop_path_rate=drop_path_rate,
+            feature_location=feature_location,
+        )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@@ -286,7 +248,9 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
         kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
         model_cls = MobileNetV3Features
     model = build_model_with_cfg(
-        model_cls, variant, pretrained,
+        model_cls,
+        variant,
+        pretrained,
         pretrained_strict=not features_only,
         kwargs_filter=kwargs_filter,
         **kwargs)
@@ -567,6 +531,110 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'conv_stem', 'classifier': 'classifier',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'mobilenetv3_large_075.untrained': _cfg(url=''),
+    'mobilenetv3_large_100.ra_in1k': _cfg(
+        interpolation='bicubic',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
+        hf_hub_id='timm/'),
+    'mobilenetv3_large_100.miil_in21k_ft_in1k': _cfg(
+        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
+        origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
+        paper_ids='arXiv:2104.10972v4',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth',
+        hf_hub_id='timm/'),
+    'mobilenetv3_large_100.miil_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
+        hf_hub_id='timm/',
+        origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
+        paper_ids='arXiv:2104.10972v4',
+        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
+
+    'mobilenetv3_small_050.lamb_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv3_small_075.lamb_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic'),
+    'mobilenetv3_small_100.lamb_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic'),
+
+    'mobilenetv3_rw.rmsp_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
+        interpolation='bicubic'),
+
+    'tf_mobilenetv3_large_075.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_large_100.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_large_minimal_100.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_small_075.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_small_100.in1k': _cfg(
+        url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+    'tf_mobilenetv3_small_minimal_100.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+
+    'fbnetv3_b.ra2_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
+        hf_hub_id='timm/',
+        test_input_size=(3, 256, 256), crop_pct=0.95),
+    'fbnetv3_d.ra2_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
+        hf_hub_id='timm/',
+        test_input_size=(3, 256, 256), crop_pct=0.95),
+    'fbnetv3_g.ra2_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
+
+    "lcnet_035.untrained": _cfg(),
+    "lcnet_050.ra2_in1k": _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic',
+    ),
+    "lcnet_075.ra2_in1k": _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic',
+    ),
+    "lcnet_100.ra2_in1k": _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
+        hf_hub_id='timm/',
+        interpolation='bicubic',
+    ),
+    "lcnet_150.untrained": _cfg(),
+})
+
+
 @register_model
 def mobilenetv3_large_075(pretrained=False, **kwargs):
     """ MobileNet V3 """
@ -581,24 +649,6 @@ def mobilenetv3_large_100(pretrained=False, **kwargs):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
|
||||||
def mobilenetv3_large_100_miil(pretrained=False, **kwargs):
|
|
||||||
""" MobileNet V3
|
|
||||||
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
|
||||||
"""
|
|
||||||
model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
|
||||||
def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs):
|
|
||||||
""" MobileNet V3, 21k pretraining
|
|
||||||
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
|
||||||
"""
|
|
||||||
model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def mobilenetv3_small_050(pretrained=False, **kwargs):
|
def mobilenetv3_small_050(pretrained=False, **kwargs):
|
||||||
""" MobileNet V3 """
|
""" MobileNet V3 """
|
||||||
|
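Note: the entries above move the MobileNetV3 / FBNetV3 / LCNet pretrained configs to the new tagged naming (architecture.tag) used by the multi-weight refactor. A minimal usage sketch, assuming the tagged weights above resolve through the standard timm.create_model path:

import timm
import torch

# Select a specific pretrained checkpoint via the '.tag' suffix on the model name.
model = timm.create_model('tf_mobilenetv3_large_100.in1k', pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)  # default 224x224 input for this config
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])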
@@ -8,12 +8,16 @@ A PyTorch implement of Vision Transformers as described in:
 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
     - https://arxiv.org/abs/2106.10270
 
-The official jax code is released and available at https://github.com/google-research/vision_transformer
+`FlexiViT: One Model for All Patch Sizes`
+    - https://arxiv.org/abs/2212.08013
+
+The official jax code is released and available at
+  * https://github.com/google-research/vision_transformer
+  * https://github.com/google-research/big_vision
 
 Acknowledgments:
 * The paper authors for releasing code and weights, thanks!
-* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
-for some einops/einsum fun
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
 * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
 * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
 
@@ -23,7 +27,7 @@ import logging
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Optional
+from typing import Optional, List
 
 import torch
 import torch.nn as nn
@@ -32,7 +36,8 @@ import torch.utils.checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
+from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
+    resample_abs_pos_embed
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._pretrained import generate_default_cfgs
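Note: the new resample_patch_embed / resample_abs_pos_embed imports above are what the loaders further down use to rescale patch-embedding kernels and position embeddings. A small sketch of the patch-embed case, assuming the call signature used later in this diff (tensor, target (H, W), interpolation, antialias, verbose):

import torch
from timm.layers import resample_patch_embed

proj_w = torch.randn(768, 3, 16, 16)  # pretrained 16x16 conv patchify kernel
proj_w_30 = resample_patch_embed(
    proj_w,
    (30, 30),                    # target patch size, FlexiViT-style
    interpolation='bilinear',    # what the 'flexi' filter later in this diff selects
    antialias=False,
    verbose=True,
)
print(proj_w_30.shape)  # torch.Size([768, 3, 30, 30])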
@@ -449,6 +454,39 @@ def get_init_weights_vit(mode='jax', head_bias: float = 0.):
     return init_weights_vit_timm
 
 
+def resize_pos_embed(
+        posemb,
+        posemb_new,
+        num_prefix_tokens=1,
+        gs_new=(),
+        interpolation='bicubic',
+        antialias=False,
+):
+    """ Rescale the grid of position embeddings when loading from state_dict.
+
+    *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
+
+    Adapted from:
+        https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    """
+    ntok_new = posemb_new.shape[1]
+    if num_prefix_tokens:
+        posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
+        ntok_new -= num_prefix_tokens
+    else:
+        posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
+    gs_old = int(math.sqrt(len(posemb_grid)))
+    if not len(gs_new):  # backwards compatibility
+        gs_new = [int(math.sqrt(ntok_new))] * 2
+    assert len(gs_new) >= 2
+    _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
+    posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
+    return posemb
+
+
 @torch.no_grad()
 def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
     """ Load weights from .npz checkpoints for official Google Brain Flax implementation
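Note: resize_pos_embed above is kept only for backwards compatibility; new code paths go through resample_abs_pos_embed. A sketch of the replacement call, mirroring the keyword usage in the loaders below (treat the exact signature as an assumption taken from this diff):

import torch
from timm.layers import resample_abs_pos_embed

pos_embed = torch.randn(1, 1 + 14 * 14, 768)   # class token + 14x14 grid
pos_embed_24 = resample_abs_pos_embed(
    pos_embed,
    new_size=(24, 24),        # e.g. 384px input with 16px patches
    num_prefix_tokens=1,      # keep the class token untouched
    interpolation='bicubic',
    antialias=True,
)
print(pos_embed_24.shape)  # torch.Size([1, 577, 768])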
@@ -468,8 +506,15 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
         return torch.from_numpy(w)
 
     w = np.load(checkpoint_path)
-    if not prefix and 'opt/target/embedding/kernel' in w:
-        prefix = 'opt/target/'
+    interpolation = 'bilinear'
+    antialias = False
+    big_vision = False
+    if not prefix:
+        if 'opt/target/embedding/kernel' in w:
+            prefix = 'opt/target/'
+        elif 'params/embedding/kernel' in w:
+            prefix = 'params/'
+            big_vision = True
 
     if hasattr(model.patch_embed, 'backbone'):
         # hybrid
@@ -495,17 +540,33 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     else:
         embed_conv_w = adapt_input_conv(
             model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+    if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
+        embed_conv_w = resample_patch_embed(
+            embed_conv_w,
+            model.patch_embed.proj.weight.shape[-2:],
+            interpolation=interpolation,
+            antialias=antialias,
+            verbose=True,
+        )
+
     model.patch_embed.proj.weight.copy_(embed_conv_w)
     model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
     if model.cls_token is not None:
         model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
-    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+    if big_vision:
+        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
+    else:
+        pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
     if pos_embed_w.shape != model.pos_embed.shape:
-        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
+        old_shape = pos_embed_w.shape
+        num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
+        pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
             pos_embed_w,
-            model.pos_embed,
-            getattr(model, 'num_prefix_tokens', 1),
-            model.patch_embed.grid_size
+            new_size=model.patch_embed.grid_size,
+            num_prefix_tokens=num_prefix_tokens,
+            interpolation=interpolation,
+            antialias=antialias,
+            verbose=True,
        )
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
@@ -517,9 +578,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
     #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
     #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+    mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
-        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
         block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
         block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
         block.attn.qkv.weight.copy_(torch.cat([
@@ -529,32 +591,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
         block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
         block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
         for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
-
-
-def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
-    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
-    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
-    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
-    ntok_new = posemb_new.shape[1]
-    if num_prefix_tokens:
-        posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
-        ntok_new -= num_prefix_tokens
-    else:
-        posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
-    gs_old = int(math.sqrt(len(posemb_grid)))
-    if not len(gs_new):  # backwards compatibility
-        gs_new = [int(math.sqrt(ntok_new))] * 2
-    assert len(gs_new) >= 2
-    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
-    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
-    posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
-    return posemb
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
 
 
 def _convert_openai_clip(state_dict, model):
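Note: the (mha_sub, b_sub, ln1_sub) tuple above captures how big_vision .npz checkpoints index their Flax sub-modules differently from the vit_jax ones. A sketch of the resulting key patterns, built the same way the loader does:

def block_keys(i, big_vision=False):
    # Indices taken directly from the loader above.
    mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
    prefix = f'Transformer/encoderblock_{i}/'
    return [
        prefix + 'LayerNorm_0/scale',                                   # norm1
        prefix + f'MultiHeadDotProductAttention_{mha_sub}/out/kernel',  # attn out proj
        prefix + f'MlpBlock_{b_sub}/Dense_0/kernel',                    # mlp fc1
        prefix + f'LayerNorm_{ln1_sub}/scale',                          # norm2
    ]

print(block_keys(0))                   # vit_jax-style keys (..._1, MlpBlock_3, LayerNorm_2)
print(block_keys(0, big_vision=True))  # big_vision-style keys (..._0, MlpBlock_0, LayerNorm_1)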
@@ -591,7 +631,13 @@ def _convert_openai_clip(state_dict, model):
     return out_dict
 
 
-def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
+def checkpoint_filter_fn(
+        state_dict,
+        model,
+        adapt_layer_scale=False,
+        interpolation='bicubic',
+        antialias=True,
+):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     import re
     out_dict = {}
@@ -603,17 +649,30 @@ def checkpoint_filter_fn(
         return _convert_openai_clip(state_dict, model)
 
     for k, v in state_dict.items():
-        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
-            # For old models that I trained prior to conv based patchification
-            O, I, H, W = model.patch_embed.proj.weight.shape
-            v = v.reshape(O, -1, H, W)
+        if 'patch_embed.proj.weight' in k:
+            O, I, H, W = model.patch_embed.proj.weight.shape
+            if len(v.shape) < 4:
+                # For old models that I trained prior to conv based patchification
+                O, I, H, W = model.patch_embed.proj.weight.shape
+                v = v.reshape(O, -1, H, W)
+            if v.shape[-1] != W or v.shape[-2] != H:
+                v = resample_patch_embed(
+                    v,
+                    (H, W),
+                    interpolation=interpolation,
+                    antialias=antialias,
+                    verbose=True,
+                )
         elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(
+            num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
+            v = resample_abs_pos_embed(
                 v,
-                model.pos_embed,
-                0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
-                model.patch_embed.grid_size
+                new_size=model.patch_embed.grid_size,
+                num_prefix_tokens=num_prefix_tokens,
+                interpolation=interpolation,
+                antialias=antialias,
+                verbose=True,
             )
         elif adapt_layer_scale and 'gamma_' in k:
             # remap layer-scale gamma into sub-module (deit3 models)
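Note: with the changes above, creating a ViT at a non-default resolution exercises the new resample path inside checkpoint_filter_fn. A sketch, assuming the usual img_size override supported by timm's ViT models:

import timm

model = timm.create_model(
    'vit_base_patch16_224.augreg_in21k_ft_in1k',
    pretrained=True,
    img_size=384,   # 24x24 patch grid instead of the pretrained 14x14
)
# pos_embed was resampled from (1, 197, 768) to (1, 577, 768) while loading.
print(model.pos_embed.shape)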
@@ -641,67 +700,101 @@ default_cfgs = generate_default_cfgs({
     # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
     'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        hf_hub_id='timm/',
         custom_load=True),
     'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        hf_hub_id='timm/',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
     'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        hf_hub_id='timm/',
         custom_load=True),
     'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        hf_hub_id='timm/',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
     'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        hf_hub_id='timm/',
         custom_load=True),
     'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        hf_hub_id='timm/',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
     'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        hf_hub_id='timm/',
         custom_load=True),
     'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        hf_hub_id='timm/',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
     'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        hf_hub_id='timm/',
         custom_load=True),
     'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        hf_hub_id='timm/',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
     'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        hf_hub_id='timm/',
         custom_load=True),
     'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        hf_hub_id='timm/',
         custom_load=True),
     'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        hf_hub_id='timm/',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
 
     # re-finetuned augreg 21k FT on in1k weights
     'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
-        file='b16_augreg-a-8.pth'),
-    'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(
-        url=''),
+        hf_hub_id='timm/'),
+    'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(),
     'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
-        url=''),
+        hf_hub_id='timm/'),
 
     # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
     'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+        hf_hub_id='timm/'),
     'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), crop_pct=1.0),
 
-    # How to train your ViT (augreg) weights trained on in1k
+    # How to train your ViT (augreg) weights trained on in1k only
+    'vit_small_patch16_224.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        hf_hub_id='timm/',
+        custom_load=True),
+    'vit_small_patch16_384.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        hf_hub_id='timm/',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_patch32_224.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        hf_hub_id='timm/',
+        custom_load=True),
+    'vit_base_patch32_384.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        hf_hub_id='timm/',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
     'vit_base_patch16_224.augreg_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        hf_hub_id='timm/',
         custom_load=True),
     'vit_base_patch16_384.augreg_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        hf_hub_id='timm/',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
 
     'vit_large_patch14_224.untrained': _cfg(url=''),
@@ -711,77 +804,94 @@ default_cfgs = generate_default_cfgs({
 
 
     # patch models, imagenet21k (weights from official Google JAX impl)
-    'vit_large_patch32_224.v1_in21k': _cfg(
+    'vit_large_patch32_224.orig_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
+        hf_hub_id='timm/',
         num_classes=21843),
-    'vit_huge_patch14_224.v1_in21k': _cfg(
+    'vit_huge_patch14_224.orig_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
-        hf_hub_id='timm/vit_huge_patch14_224_in21k',
+        hf_hub_id='timm/',
         custom_load=True, num_classes=21843),
 
     # How to train your ViT (augreg) weights, pretrained on in21k
     'vit_tiny_patch16_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
         custom_load=True, num_classes=21843),
     'vit_small_patch32_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
         custom_load=True, num_classes=21843),
     'vit_small_patch16_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
         custom_load=True, num_classes=21843),
     'vit_base_patch32_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
         custom_load=True, num_classes=21843),
     'vit_base_patch16_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
         custom_load=True, num_classes=21843),
     'vit_base_patch8_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
         custom_load=True, num_classes=21843),
     'vit_large_patch16_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
+        hf_hub_id='timm/',
         custom_load=True, num_classes=21843),
 
     # SAM trained models (https://arxiv.org/abs/2106.01548)
     'vit_base_patch32_224.sam': _cfg(
-        url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True),
+        url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True,
+        hf_hub_id='timm/'),
     'vit_base_patch16_224.sam': _cfg(
-        url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True),
+        url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True,
+        hf_hub_id='timm/'),
 
     # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
     'vit_small_patch16_224.dino': _cfg(
         url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
     'vit_small_patch8_224.dino': _cfg(
         url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
     'vit_base_patch16_224.dino': _cfg(
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
     'vit_base_patch8_224.dino': _cfg(
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
 
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
+        hf_hub_id='timm/',
         mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
     'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
+        hf_hub_id='timm/',
         mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
 
     # custom timm variants
     'vit_base_patch16_rpn_224.in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth',
+        hf_hub_id='timm/'),
     'vit_medium_patch16_gap_240.in12k': _cfg(
-        hf_hub_id='timm/vit_medium_patch16_gap_240.in12k',
+        hf_hub_id='timm/',
         input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
     'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
     'vit_base_patch16_gap_224': _cfg(),
 
@@ -808,24 +918,24 @@ default_cfgs = generate_default_cfgs({
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
 
     'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
     'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
     'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
         crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
     'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
     'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
         hf_hub_id='',
@@ -833,33 +943,33 @@ default_cfgs = generate_default_cfgs({
         crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
 
     'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
     'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
     'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
     'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
     'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
         crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
     'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
     'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
 
@@ -867,58 +977,58 @@ default_cfgs = generate_default_cfgs({
         #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
     'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
     'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k',
+        hf_hub_id='timm/',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
     'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
-        hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
 
     'vit_base_patch32_clip_224.openai': _cfg(
-        hf_hub_id='timm/clip_vit_base_patch32_224.openai',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_base_patch16_clip_224.openai': _cfg(
-        hf_hub_id='timm/clip_vit_base_patch16_224.openai',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_large_patch14_clip_224.openai': _cfg(
-        hf_hub_id='timm/clip_vit_large_patch14_224.openai',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
 
     'vit_base_patch32_clip_224.openai_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch16_clip_224.openai_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
 
     'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
         #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
     'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
     'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_336.openai_ft_in12k_in1k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
 
@@ -926,10 +1036,10 @@ default_cfgs = generate_default_cfgs({
         #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
     'vit_base_patch16_clip_224.openai_ft_in12k': _cfg(
-        hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
     'vit_large_patch14_clip_224.openai_ft_in12k': _cfg(
-        hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
 
     # experimental (may be removed)
@ -942,21 +1052,81 @@ default_cfgs = generate_default_cfgs({
|
|||||||
# EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain
|
# EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain
|
||||||
# https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
|
# https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
    'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
+        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
+        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 196, 196), crop_pct=1.0),
    'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
+        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
+        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
    'eva_large_patch14_196.in22k_ft_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
+        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
+        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 196, 196), crop_pct=1.0),
    'eva_large_patch14_336.in22k_ft_in1k': _cfg(
-        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
+        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
+        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
+
+    'flexivit_small.1200ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+    'flexivit_small.600ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+    'flexivit_small.300ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+
+    'flexivit_base.1200ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+    'flexivit_base.600ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+    'flexivit_base.300ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+    'flexivit_base.1000ep_in21k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+    'flexivit_base.300ep_in21k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+
+    'flexivit_large.1200ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+    'flexivit_large.600ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+    'flexivit_large.300ep_in1k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95),
+
+    'flexivit_base.patch16_in21k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+    'flexivit_base.patch30_in21k': _cfg(
+        url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
+        hf_hub_id='timm/',
+        input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
 })
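The FlexiViT entries above are ordinary timm pretrained-config tags, so once the weights are published they should resolve through the usual model factory. A minimal sketch, assuming the timm 0.8.x multi-weight API and that the `flexivit_base.1200ep_in1k` tag is downloadable:

    import timm
    import torch

    # Names only, nothing downloaded.
    print(timm.list_models('flexivit*'))

    # pretrained=True pulls either the big_vision .npz (custom_load=True)
    # or the hf_hub_id='timm/' re-hosted copy of the same weights.
    model = timm.create_model('flexivit_base.1200ep_in1k', pretrained=True)
    model.eval()

    # The configs above declare input_size=(3, 240, 240), crop_pct=0.95.
    with torch.no_grad():
        out = model(torch.randn(1, 3, 240, 240))
    print(out.shape)  # torch.Size([1, 1000])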
@ -964,9 +1134,16 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
 
+    if 'flexi' in variant:
+        # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
+        # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
+        _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
+    else:
+        _filter_fn = checkpoint_filter_fn
+
     return build_model_with_cfg(
         VisionTransformer, variant, pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
+        pretrained_filter_fn=_filter_fn,
         **kwargs,
     )
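The `partial` above only pre-binds the interpolation keyword arguments before the filter function is handed to `build_model_with_cfg`; no resampling happens at this point. A small sketch of the same pattern, using a hypothetical stand-in for the real `checkpoint_filter_fn` defined earlier in this file:

    from functools import partial

    def demo_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
        # Hypothetical stand-in: the real filter fn remaps and resamples
        # checkpoint tensors before they are loaded into the model.
        print(f'resampling with interpolation={interpolation}, antialias={antialias}')
        return state_dict

    flexi_filter = partial(demo_filter_fn, interpolation='bilinear', antialias=False)
    flexi_filter({}, None)  # prints: resampling with interpolation=bilinear, antialias=False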
@ -1396,3 +1573,30 @@ def eva_large_patch14_336(pretrained=False, **kwargs):
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs)
     model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs)
     return model
 
 
+@register_model
+def flexivit_small(pretrained=False, **kwargs):
+    """ FlexiViT-Small
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs)
+    model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def flexivit_base(pretrained=False, **kwargs):
+    """ FlexiViT-Base
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs)
+    model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def flexivit_large(pretrained=False, **kwargs):
+    """ FlexiViT-Large
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs)
+    model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs)
+    return model
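The three entrypoints mirror the standard ViT-S/B/L widths with `no_embed_class=True`. As a quick sanity sketch (assuming only the registered names above and the standard timm factory), one might compare their parameter counts without downloading any weights:

    import timm

    for name in ('flexivit_small', 'flexivit_base', 'flexivit_large'):
        m = timm.create_model(name, pretrained=False, num_classes=0)  # num_classes=0 drops the classifier
        n_params = sum(p.numel() for p in m.parameters())
        print(f'{name}: {n_params / 1e6:.1f}M params')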
@ -27,72 +27,6 @@ from .resnetv2 import ResNetV2, create_resnetv2_stem
 from .vision_transformer import _create_vision_transformer
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
-        'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = generate_default_cfgs({
-    # hybrid in-1k models (weights from official JAX impl where they exist)
-    'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
-        custom_load=True,
-        first_conv='patch_embed.backbone.conv'),
-    'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
-        first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
-    'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
-        custom_load=True,
-    ),
-    'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
-        input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
-    'vit_base_r26_s32_224.untrained': _cfg(),
-    'vit_base_r50_s16_384.v1_in21k_ft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
-        custom_load=True,
-    ),
-    'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
-        input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
-    ),
-
-    # hybrid in-21k models (weights from official Google JAX impl where they exist)
-    'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
-        num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
-    'vit_small_r26_s32_224.augreg_in21k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
-        num_classes=21843, crop_pct=0.9, custom_load=True),
-    'vit_base_r50_s16_224.v1_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
-        num_classes=21843, crop_pct=0.9),
-    'vit_large_r50_s32_224.augreg_in21k': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
-        num_classes=21843, crop_pct=0.9, custom_load=True),
-
-    # hybrid models (using timm resnet backbones)
-    'vit_small_resnet26d_224': _cfg(
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
-    'vit_small_resnet50d_s16_224': _cfg(
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
-    'vit_base_resnet26d_224': _cfg(
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
-    'vit_base_resnet50d_224': _cfg(
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
-})
-
-
 class HybridEmbed(nn.Module):
     """ CNN Feature Map Embedding
     Extract feature map from CNN, flatten, project to embedding dim.
@ -166,6 +100,83 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
     return backbone
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # hybrid in-1k models (weights from official JAX impl where they exist)
+    'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        hf_hub_id='timm/',
+        custom_load=True,
+        first_conv='patch_embed.backbone.conv'),
+    'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        hf_hub_id='timm/',
+        first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
+    'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        hf_hub_id='timm/',
+        custom_load=True,
+    ),
+    'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
+    'vit_base_r26_s32_224.untrained': _cfg(),
+    'vit_base_r50_s16_384.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        hf_hub_id='timm/',
+        custom_load=True,
+    ),
+    'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
+    ),
+
+    # hybrid in-21k models (weights from official Google JAX impl where they exist)
+    'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
+        num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
+    'vit_small_r26_s32_224.augreg_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
+        num_classes=21843, crop_pct=0.9, custom_load=True),
+    'vit_base_r50_s16_224.orig_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
+        hf_hub_id='timm/',
+        num_classes=21843, crop_pct=0.9),
+    'vit_large_r50_s32_224.augreg_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
+        hf_hub_id='timm/',
+        num_classes=21843, crop_pct=0.9, custom_load=True),
+
+    # hybrid models (using timm resnet backbones)
+    'vit_small_resnet26d_224.untrained': _cfg(
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+    'vit_small_resnet50d_s16_224.untrained': _cfg(
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+    'vit_base_resnet26d_224.untrained': _cfg(
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+    'vit_base_resnet50d_224.untrained': _cfg(
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+})
 
 
 @register_model
 def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
     """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
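With the hybrid configs regenerated through `generate_default_cfgs`, a tagged name pins one specific pretrained cfg, while the bare architecture name falls back to its default tag. A rough sketch, assuming the renamed `orig_in21k_ft_in1k` weights for `vit_base_r50_s16_384` remain downloadable:

    import timm

    # Explicit tag selects the cfg entry shown above.
    m1 = timm.create_model('vit_base_r50_s16_384.orig_in21k_ft_in1k', pretrained=True)

    # Bare name resolves to the architecture's default tag.
    m2 = timm.create_model('vit_base_r50_s16_384', pretrained=True)

    print(m1.default_cfg)  # includes hf_hub_id='timm/' and input_size=(3, 384, 384)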
@ -11,12 +11,12 @@ from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
+from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
 from ._builder import build_model_with_cfg
+from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 
 __all__ = ['VisionTransformerRelPos']  # model_registry will add each entrypoint fn to this
@ -24,216 +24,6 @@ __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint
 _logger = logging.getLogger(__name__)
-
-
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'vit_relpos_base_patch32_plus_rpn_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
-        input_size=(3, 256, 256)),
-    'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)),
-
-    'vit_relpos_small_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'),
-    'vit_relpos_medium_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'),
-    'vit_relpos_base_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
-
-    'vit_srelpos_small_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'),
-    'vit_srelpos_medium_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'),
-
-    'vit_relpos_medium_patch16_cls_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'),
-    'vit_relpos_base_patch16_cls_224': _cfg(
-        url=''),
-    'vit_relpos_base_patch16_clsgap_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
-
-    'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
-    'vit_relpos_medium_patch16_rpn_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'),
-    'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
-}
-
-
-def gen_relative_position_index(
-        q_size: Tuple[int, int],
-        k_size: Tuple[int, int] = None,
-        class_token: bool = False) -> torch.Tensor:
-    # Adapted with significant modifications from Swin / BeiT codebases
-    # get pair-wise relative position index for each token inside the window
-    q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1)  # 2, Wh, Ww
-    if k_size is None:
-        k_coords = q_coords
-        k_size = q_size
-    else:
-        # different q vs k sizes is a WIP
-        k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
-    relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-    relative_coords = relative_coords.permute(1, 2, 0)  # Wh*Ww, Wh*Ww, 2
-    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
-
-    if class_token:
-        # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
-        # NOTE not intended or tested with MLP log-coords
-        max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
-        num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
-        relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
-        relative_position_index[0, 0:] = num_relative_distance - 3
-        relative_position_index[0:, 0] = num_relative_distance - 2
-        relative_position_index[0, 0] = num_relative_distance - 1
-
-    return relative_position_index.contiguous()
-
-
-def gen_relative_log_coords(
-        win_size: Tuple[int, int],
-        pretrained_win_size: Tuple[int, int] = (0, 0),
-        mode='swin',
-):
-    assert mode in ('swin', 'cr', 'rw')
-    # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
-    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
-    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
-    relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
-    relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
-    if mode == 'swin':
-        if pretrained_win_size[0] > 0:
-            relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
-            relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
-        else:
-            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
-            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
-        relative_coords_table *= 8  # normalize to -8, 8
-        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-            1.0 + relative_coords_table.abs()) / math.log2(8)
-    else:
-        if mode == 'rw':
-            # cr w/ window size normalization -> [-1,1] log coords
-            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
-            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
-            relative_coords_table *= 8  # scale to -8, 8
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-                1.0 + relative_coords_table.abs())
-            relative_coords_table /= math.log2(9)  # -> [-1, 1]
-        else:
-            # mode == 'cr'
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log(
-                1.0 + relative_coords_table.abs())
-
-    return relative_coords_table
-
-
-class RelPosMlp(nn.Module):
-    def __init__(
-            self,
-            window_size,
-            num_heads=8,
-            hidden_dim=128,
-            prefix_tokens=0,
-            mode='cr',
-            pretrained_window_size=(0, 0)
-    ):
-        super().__init__()
-        self.window_size = window_size
-        self.window_area = self.window_size[0] * self.window_size[1]
-        self.prefix_tokens = prefix_tokens
-        self.num_heads = num_heads
-        self.bias_shape = (self.window_area,) * 2 + (num_heads,)
-        if mode == 'swin':
-            self.bias_act = nn.Sigmoid()
-            self.bias_gain = 16
-            mlp_bias = (True, False)
-        elif mode == 'rw':
-            self.bias_act = nn.Tanh()
-            self.bias_gain = 4
-            mlp_bias = True
-        else:
-            self.bias_act = nn.Identity()
-            self.bias_gain = None
-            mlp_bias = True
-
-        self.mlp = Mlp(
-            2,  # x, y
-            hidden_features=hidden_dim,
-            out_features=num_heads,
-            act_layer=nn.ReLU,
-            bias=mlp_bias,
-            drop=(0.125, 0.)
-        )
-
-        self.register_buffer(
-            "relative_position_index",
-            gen_relative_position_index(window_size),
-            persistent=False)
-
-        # get relative_coords_table
-        self.register_buffer(
-            "rel_coords_log",
-            gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
-            persistent=False)
-
-    def get_bias(self) -> torch.Tensor:
-        relative_position_bias = self.mlp(self.rel_coords_log)
-        if self.relative_position_index is not None:
-            relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
-                self.relative_position_index.view(-1)]  # Wh*Ww,Wh*Ww,nH
-            relative_position_bias = relative_position_bias.view(self.bias_shape)
-        relative_position_bias = relative_position_bias.permute(2, 0, 1)
-        relative_position_bias = self.bias_act(relative_position_bias)
-        if self.bias_gain is not None:
-            relative_position_bias = self.bias_gain * relative_position_bias
-        if self.prefix_tokens:
-            relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
-        return relative_position_bias.unsqueeze(0).contiguous()
-
-    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
-        return attn + self.get_bias()
-
-
-class RelPosBias(nn.Module):
-
-    def __init__(self, window_size, num_heads, prefix_tokens=0):
-        super().__init__()
-        assert prefix_tokens <= 1
-        self.window_size = window_size
-        self.window_area = window_size[0] * window_size[1]
-        self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
-
-        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
-        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
-        self.register_buffer(
-            "relative_position_index",
-            gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
-            persistent=False,
-        )
-
-        self.init_weights()
-
-    def init_weights(self):
-        trunc_normal_(self.relative_position_bias_table, std=.02)
-
-    def get_bias(self) -> torch.Tensor:
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
-        # win_h * win_w, win_h * win_w, num_heads
-        relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
-        return relative_position_bias.unsqueeze(0).contiguous()
-
-    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
-        return attn + self.get_bias()
-
-
 class RelPosAttention(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
         super().__init__()
@ -513,6 +303,57 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
     return model
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'vit_relpos_base_patch32_plus_rpn_256.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256)),
+    'vit_relpos_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240)),
+
+    'vit_relpos_small_patch16_224.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth',
+        hf_hub_id='timm/'),
+    'vit_relpos_medium_patch16_224.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth',
+        hf_hub_id='timm/'),
+    'vit_relpos_base_patch16_224.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth',
+        hf_hub_id='timm/'),
+
+    'vit_srelpos_small_patch16_224.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth',
+        hf_hub_id='timm/'),
+    'vit_srelpos_medium_patch16_224.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth',
+        hf_hub_id='timm/'),
+
+    'vit_relpos_medium_patch16_cls_224.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth',
+        hf_hub_id='timm/'),
+    'vit_relpos_base_patch16_cls_224.untrained': _cfg(),
+    'vit_relpos_base_patch16_clsgap_224.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth',
+        hf_hub_id='timm/'),
+
+    'vit_relpos_small_patch16_rpn_224.untrained': _cfg(),
+    'vit_relpos_medium_patch16_rpn_224.sw_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth',
+        hf_hub_id='timm/'),
+    'vit_relpos_base_patch16_rpn_224.untrained': _cfg(),
+})
 
 
 @register_model
 def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
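The relative-position helpers removed earlier in this file now come in through `timm.layers`, per the updated import at the top. A minimal sketch of using the bias module standalone, assuming the same constructor and forward signatures as the removed class:

    import torch
    from timm.layers import RelPosBias

    # 14x14 token grid with one class-token prefix, 12 attention heads.
    rel_pos = RelPosBias(window_size=(14, 14), num_heads=12, prefix_tokens=1)

    attn = torch.zeros(1, 12, 14 * 14 + 1, 14 * 14 + 1)  # dummy attention logits
    out = rel_pos(attn)  # adds the learned relative position bias to the logits
    print(out.shape)     # torch.Size([1, 12, 197, 197])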
@ -1 +1 @@
-__version__ = '0.8.1dev0'
+__version__ = '0.8.2dev0'