Merge pull request #1467 from rwightman/clip_laion2b
Adding support for fine-tune CLIP LAION-2B image tower weights for B/32, L/14, H/14, and g/14.pull/1476/head
commit
d199f6651d
|
@ -1,3 +1,4 @@
|
|||
torch>=1.4.0
|
||||
torchvision>=0.5.0
|
||||
torch>=1.7
|
||||
torchvision
|
||||
pyyaml
|
||||
huggingface_hub
|
||||
|
|
9
setup.py
9
setup.py
|
@ -25,13 +25,15 @@ setup(
|
|||
# 3 - Alpha
|
||||
# 4 - Beta
|
||||
# 5 - Production/Stable
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Development Status :: 4 - Beta',
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3.10',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
|
@ -40,9 +42,10 @@ setup(
|
|||
],
|
||||
|
||||
# Note that this is a string of words separated by whitespace, not a list.
|
||||
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
|
||||
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
|
||||
packages=find_packages(exclude=['convert', 'tests', 'results']),
|
||||
include_package_data=True,
|
||||
install_requires=['torch >= 1.4', 'torchvision'],
|
||||
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
|
||||
python_requires='>=3.6',
|
||||
)
|
||||
|
||||
|
|
|
@ -5,3 +5,5 @@ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
|||
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
||||
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
|
||||
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
|
||||
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
|
|
@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg):
|
|||
# hf-hub available as alternate weight source in default_cfg
|
||||
load_from = 'hf-hub'
|
||||
pretrained_loc = hf_hub_id
|
||||
if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
|
||||
# if a filename override is set, return tuple for location w/ (hub_id, filename)
|
||||
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
|
||||
return load_from, pretrained_loc
|
||||
|
||||
|
||||
|
@ -246,7 +249,10 @@ def load_pretrained(
|
|||
pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
|
||||
elif load_from == 'hf-hub':
|
||||
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
|
||||
state_dict = load_state_dict_from_hf(pretrained_loc)
|
||||
if isinstance(pretrained_loc, (list, tuple)):
|
||||
state_dict = load_state_dict_from_hf(*pretrained_loc)
|
||||
else:
|
||||
state_dict = load_state_dict_from_hf(pretrained_loc)
|
||||
else:
|
||||
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
|
||||
return
|
||||
|
|
|
@ -13,6 +13,7 @@ except ImportError:
|
|||
from torch.hub import _get_torch_home as get_dir
|
||||
|
||||
from timm import __version__
|
||||
|
||||
try:
|
||||
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
|
||||
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
||||
|
@ -55,7 +56,7 @@ def download_cached_file(url, check_hash=True, progress=False):
|
|||
|
||||
def has_hf_hub(necessary=False):
|
||||
if not _has_hf_hub and necessary:
|
||||
# if no HF Hub module installed and it is necessary to continue, raise error
|
||||
# if no HF Hub module installed, and it is necessary to continue, raise error
|
||||
raise RuntimeError(
|
||||
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
||||
return _has_hf_hub
|
||||
|
@ -78,7 +79,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
|
|||
|
||||
def _download_from_hf(model_id: str, filename: str):
|
||||
hf_model_id, hf_revision = hf_split(model_id)
|
||||
return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf'))
|
||||
return hf_hub_download(hf_model_id, filename, revision=hf_revision)
|
||||
|
||||
|
||||
def load_model_config_from_hf(model_id: str):
|
||||
|
@ -91,9 +92,9 @@ def load_model_config_from_hf(model_id: str):
|
|||
return pretrained_cfg, model_name
|
||||
|
||||
|
||||
def load_state_dict_from_hf(model_id: str):
|
||||
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
|
||||
assert has_hf_hub(True)
|
||||
cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
|
||||
cached_file = _download_from_hf(model_id, filename)
|
||||
state_dict = torch.load(cached_file, map_location='cpu')
|
||||
return state_dict
|
||||
|
||||
|
|
|
@ -15,7 +15,16 @@ from .trace_utils import _assert
|
|||
class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
@ -25,7 +34,7 @@ class PatchEmbed(nn.Module):
|
|||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -30,7 +30,8 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\
|
||||
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
|
||||
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
|
||||
from .registry import register_model
|
||||
|
@ -177,6 +178,24 @@ default_cfgs = {
|
|||
'vit_small_patch16_36x1_224': _cfg(url=''),
|
||||
'vit_small_patch16_18x2_224': _cfg(url=''),
|
||||
'vit_base_patch16_18x2_224': _cfg(url=''),
|
||||
|
||||
'vit_base_patch32_224_clip_laion2b': _cfg(
|
||||
hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
|
||||
'vit_large_patch14_224_clip_laion2b': _cfg(
|
||||
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768),
|
||||
'vit_huge_patch14_224_clip_laion2b': _cfg(
|
||||
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
|
||||
'vit_giant_patch14_224_clip_laion2b': _cfg(
|
||||
hf_hub_id='CLIP-ViT-g-14-laion2B-s12B-b42K',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -221,8 +240,18 @@ class LayerScale(nn.Module):
|
|||
class Block(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
init_values=None,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
|
@ -244,8 +273,18 @@ class Block(nn.Module):
|
|||
class ResPostBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
init_values=None,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.init_values = init_values
|
||||
|
||||
|
@ -274,8 +313,19 @@ class ResPostBlock(nn.Module):
|
|||
class ParallelBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
|
||||
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
num_parallel=2,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
init_values=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.num_parallel = num_parallel
|
||||
self.attns = nn.ModuleList()
|
||||
|
@ -320,10 +370,31 @@ class VisionTransformer(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
|
||||
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
|
||||
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='token',
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
init_values=None,
|
||||
class_token=True,
|
||||
no_embed_class=False,
|
||||
pre_norm=False,
|
||||
fc_norm=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
weight_init='',
|
||||
embed_layer=PatchEmbed,
|
||||
norm_layer=None,
|
||||
act_layer=None,
|
||||
block_fn=Block,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
|
@ -362,19 +433,34 @@ class VisionTransformer(nn.Module):
|
|||
self.grad_checkpointing = False
|
||||
|
||||
self.patch_embed = embed_layer(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
||||
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.Sequential(*[
|
||||
block_fn(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
init_values=init_values,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer
|
||||
)
|
||||
for i in range(depth)])
|
||||
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
||||
|
||||
|
@ -445,6 +531,7 @@ class VisionTransformer(nn.Module):
|
|||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self._pos_embed(x)
|
||||
x = self.norm_pre(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
|
@ -623,6 +710,40 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
|
|||
return posemb
|
||||
|
||||
|
||||
def _convert_openai_clip(state_dict, model):
|
||||
out_dict = {}
|
||||
swaps = [
|
||||
('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
|
||||
('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
|
||||
('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
|
||||
]
|
||||
for k, v in state_dict.items():
|
||||
if not k.startswith('visual.'):
|
||||
continue
|
||||
for sp in swaps:
|
||||
k = k.replace(sp[0], sp[1])
|
||||
|
||||
if k == 'proj':
|
||||
k = 'head.weight'
|
||||
v = v.transpose(0, 1)
|
||||
out_dict['head.bias'] = torch.zeros(v.shape[0])
|
||||
elif k == 'class_embedding':
|
||||
k = 'cls_token'
|
||||
v = v.unsqueeze(0).unsqueeze(1)
|
||||
elif k == 'pos_embed':
|
||||
v = v.unsqueeze(0)
|
||||
if v.shape[1] != model.pos_embed.shape[1]:
|
||||
# To resize pos embedding when using model at different size from pretrained weights
|
||||
v = resize_pos_embed(
|
||||
v,
|
||||
model.pos_embed,
|
||||
0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
|
||||
model.patch_embed.grid_size
|
||||
)
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
import re
|
||||
|
@ -631,6 +752,9 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
|
|||
# For deit models
|
||||
state_dict = state_dict['model']
|
||||
|
||||
if 'visual.class_embedding' in state_dict:
|
||||
return _convert_openai_clip(state_dict, model)
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
||||
# For old models that I trained prior to conv based patchification
|
||||
|
@ -833,7 +957,7 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def vit_giant_patch14_224(pretrained=False, **kwargs):
|
||||
""" ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
||||
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
||||
"""
|
||||
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
|
||||
|
@ -842,7 +966,7 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def vit_gigantic_patch14_224(pretrained=False, **kwargs):
|
||||
""" ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
||||
""" ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
||||
"""
|
||||
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
|
||||
|
@ -1085,3 +1209,44 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
|
|||
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs):
|
||||
""" ViT-B/32
|
||||
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/14)
|
||||
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs):
|
||||
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, **kwargs)
|
||||
model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs):
|
||||
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
|
||||
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, **kwargs)
|
||||
model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
|
|
@ -101,7 +101,16 @@ class HybridEmbed(nn.Module):
|
|||
""" CNN Feature Map Embedding
|
||||
Extract feature map from CNN, flatten, project to embedding dim.
|
||||
"""
|
||||
def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
|
||||
def __init__(
|
||||
self,
|
||||
backbone,
|
||||
img_size=224,
|
||||
patch_size=1,
|
||||
feature_size=None,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(backbone, nn.Module)
|
||||
img_size = to_2tuple(img_size)
|
||||
|
@ -130,7 +139,7 @@ class HybridEmbed(nn.Module):
|
|||
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
||||
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.backbone(x)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
""" Optimizer Factory w/ Custom Weight Decay
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from itertools import islice
|
||||
from typing import Optional, Callable, Tuple
|
||||
|
||||
|
@ -31,6 +31,8 @@ try:
|
|||
except ImportError:
|
||||
has_apex = False
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def param_groups_weight_decay(
|
||||
model: nn.Module,
|
||||
|
@ -92,6 +94,7 @@ def param_groups_layer_decay(
|
|||
no_weight_decay_list: Tuple[str] = (),
|
||||
layer_decay: float = .75,
|
||||
end_layer_decay: Optional[float] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Parameter groups for layer-wise lr decay & weight decay
|
||||
|
@ -142,8 +145,9 @@ def param_groups_layer_decay(
|
|||
param_group_names[group_name]["param_names"].append(name)
|
||||
param_groups[group_name]["params"].append(param)
|
||||
|
||||
# FIXME temporary output to debug new feature
|
||||
print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
||||
if verbose:
|
||||
import json
|
||||
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
||||
|
||||
return list(param_groups.values())
|
||||
|
||||
|
|
Loading…
Reference in New Issue