Update metaformers.py
parent 0bde1c1218
commit ec202b4d16
@@ -20,12 +20,20 @@ Some implementations are modified from timm (https://github.com/rwightman/pytorc
 from functools import partial
 import torch
 import torch.nn as nn
-from timm.layers import trunc_normal_, DropPath
-from ._registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_, DropPath
+from timm.layers.helpers import to_2tuple
+from ._builder import build_model_with_cfg
+from ._features import FeatureInfo
+from ._features_fx import register_notrace_function
+from ._manipulate import checkpoint_seq
+from ._pretrained import generate_default_cfgs
+from ._registry import register_model
 
 
 __all__ = ['MetaFormer']
 
 
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -187,6 +195,7 @@ default_cfgs = {
         num_classes=21841),
 }
 
+cfgs_v2 = generate_default_cfgs(default_cfgs)
 
 class Downsampling(nn.Module):
     """
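Note: the added `cfgs_v2 = generate_default_cfgs(default_cfgs)` line feeds the flat cfg dict through timm's helper so that `build_model_with_cfg` (introduced further down) can resolve pretrained weights per model name and tag. A minimal hedged sketch of the pattern, reusing the module's own `_cfg` helper; the key and URL below are placeholders, not entries from this diff:

# Hypothetical illustration only -- the real keys/URLs live in default_cfgs above.
example_cfgs = {
    'identityformer_s12.sail_in1k': _cfg(url='https://example.com/identityformer_s12.pth'),
}
example_cfgs_v2 = generate_default_cfgs(example_cfgs)  # model name -> DefaultCfg, pretrained tags grouped per model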
@@ -592,16 +601,17 @@ class MetaFormer(nn.Module):
         cur = 0
         for i in range(num_stage):
             stage = nn.Sequential(
-                *[MetaFormerBlock(dim=dims[i],
-                token_mixer=token_mixers[i],
-                mlp=mlps[i],
-                norm_layer=norm_layers[i],
-                drop_path=dp_rates[cur + j],
-                layer_scale_init_value=layer_scale_init_values[i],
-                res_scale_init_value=res_scale_init_values[i],
+                downsample_layers[i],
+                *[MetaFormerBlock(
+                    dim=dims[i],
+                    token_mixer=token_mixers[i],
+                    mlp=mlps[i],
+                    norm_layer=norm_layers[i],
+                    drop_path=dp_rates[cur + j],
+                    layer_scale_init_value=layer_scale_init_values[i],
+                    res_scale_init_value=res_scale_init_values[i],
                 ) for j in range(depths[i])]
             )
-            stages.append(downsample_layers[i])
             stages.append(stage)
             cur += depths[i]
 
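With the downsampling module folded into each stage's nn.Sequential, `stages` ends up holding exactly one module per stage, which lines up with the per-stage `out_indices` computed in `_create_metaformer` below. A rough sketch of how such a flattened stage list is typically consumed in a timm-style forward pass; the attribute names (`self.stages`, `self.grad_checkpointing`) are assumptions for illustration, not code from this diff:

    # Illustrative sketch only: assumes the constructor ends with
    #   self.stages = nn.Sequential(*stages)
    # so each entry already contains its own downsampling followed by its blocks.
    def forward_features(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)  # newly imported helper: gradient-checkpoint one stage at a time
        else:
            x = self.stages(x)
        return x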
@@ -639,8 +649,20 @@ class MetaFormer(nn.Module):
         x = self.head(x)
         return x
 
 
+def _create_metaformer(variant, pretrained=False, **kwargs):
+    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+    model = build_model_with_cfg(
+        MetaFormer,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs)
+    return model
+'''
 @register_model
 def identityformer_s12(pretrained=False, **kwargs):
     model = MetaFormer(
@@ -656,6 +678,17 @@ def identityformer_s12(pretrained=False, **kwargs):
     model.load_state_dict(state_dict)
     return model
 
+'''
+
+@register_model
+def identityformer_s12(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        depths=[2, 2, 6, 2],
+        dims=[64, 128, 320, 512],
+        token_mixers=nn.Identity,
+        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
+        **kwargs)
+    return _create_metaformer('identityformer_s12', pretrained=pretrained, **model_kwargs)
 
 @register_model
 def identityformer_s24(pretrained=False, **kwargs):
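Because `_create_metaformer` now routes through `build_model_with_cfg` with `feature_cfg=dict(flatten_sequential=True, out_indices=out_indices)`, the registered entrypoints should work with timm's standard factory, including features-only extraction. A hedged usage sketch (assumes this module ships inside timm so `identityformer_s12` is registered; pretrained weights may not be available under that name):

import torch
import timm

# Plain classification model from the registered entrypoint.
model = timm.create_model('identityformer_s12', pretrained=False)
logits = model(torch.randn(1, 3, 224, 224))

# Feature-pyramid usage enabled by flatten_sequential + out_indices.
feat_model = timm.create_model('identityformer_s12', pretrained=False, features_only=True)
feats = feat_model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # one feature map per stage, following default_out_indices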