Update metaformers.py
parent
ec202b4d16
commit
d90ed530dc
|
@ -469,14 +469,19 @@ class MetaFormerBlock(nn.Module):
|
|||
Implementation of one MetaFormer block.
|
||||
"""
|
||||
def __init__(self, dim,
|
||||
token_mixer=nn.Identity, mlp=Mlp,
|
||||
token_mixer=nn.Identity,
|
||||
mlp=Mlp,
|
||||
norm_layer=nn.LayerNorm,
|
||||
drop=0., drop_path=0.,
|
||||
layer_scale_init_value=None, res_scale_init_value=None
|
||||
layer_scale_init_value=None,
|
||||
res_scale_init_value=None,
|
||||
downsample = nn.Identity()
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.token_mixer = token_mixer(dim=dim, drop=drop)
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -494,6 +499,7 @@ class MetaFormerBlock(nn.Module):
|
|||
if res_scale_init_value else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downsample(x)
|
||||
x = self.res_scale1(x) + \
|
||||
self.layer_scale1(
|
||||
self.drop_path1(
|
||||
|
@ -600,18 +606,18 @@ class MetaFormer(nn.Module):
|
|||
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
||||
cur = 0
|
||||
for i in range(num_stage):
|
||||
stage = nn.Sequential(
|
||||
downsample_layers[i],
|
||||
*[MetaFormerBlock(
|
||||
dim=dims[i],
|
||||
token_mixer=token_mixers[i],
|
||||
mlp=mlps[i],
|
||||
norm_layer=norm_layers[i],
|
||||
drop_path=dp_rates[cur + j],
|
||||
layer_scale_init_value=layer_scale_init_values[i],
|
||||
res_scale_init_value=res_scale_init_values[i],
|
||||
stage = nn.Sequential(*[MetaFormerBlock(
|
||||
dim=dims[i],
|
||||
token_mixer=token_mixers[i],
|
||||
mlp=mlps[i],
|
||||
norm_layer=norm_layers[i],
|
||||
drop_path=dp_rates[cur + j],
|
||||
layer_scale_init_value=layer_scale_init_values[i],
|
||||
res_scale_init_value=res_scale_init_values[i],
|
||||
downsample = downsample_layers[i]
|
||||
) for j in range(depths[i])]
|
||||
)
|
||||
|
||||
stages.append(stage)
|
||||
cur += depths[i]
|
||||
|
||||
|
@ -649,6 +655,17 @@ class MetaFormer(nn.Module):
|
|||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
    """Rename legacy downsample keys in a checkpoint's state dict.

    Every key containing ``downsample_layers.<i>`` (where ``<i>`` is a run of
    decimal digits) has that fragment rewritten to ``stages.<i>.downsample``.
    All other keys, and all values, are passed through untouched.

    Args:
        state_dict: Mapping of parameter names to values, as loaded from a
            checkpoint.
        model: Unused by this function; kept so the signature stays
            compatible with existing callers.

    Returns:
        A new dict with the remapped keys. The input mapping is not modified.
    """
    import re

    # The dot is escaped so that only a literal "downsample_layers.<digits>"
    # is rewritten; an unescaped "." would match any character and could
    # remap unrelated keys. The pattern is compiled once, outside the loop.
    legacy_key = re.compile(r'downsample_layers\.([0-9]+)')
    return {
        legacy_key.sub(r'stages.\1.downsample', key): value
        for key, value in state_dict.items()
    }
|
||||
|
||||
|
||||
def _create_metaformer(variant, pretrained=False, **kwargs):
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
|
Loading…
Reference in New Issue