Update metaformers.py
parent
10bde717e5
commit
f938beb81b
|
@ -24,8 +24,11 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
|
|||
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1
|
||||
from timm.layers.helpers import to_2tuple
|
||||
|
@ -415,6 +418,61 @@ class MetaFormerBlock(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
class MetaFormerStage(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
depth=2,
|
||||
downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
|
||||
token_mixer=nn.Identity,
|
||||
mlp=Mlp,
|
||||
mlp_fn=nn.Linear,
|
||||
mlp_act=StarReLU,
|
||||
mlp_bias=False,
|
||||
norm_layer=partial(LayerNormGeneral, eps=1e-6, bias=False),
|
||||
dp_rates=[0.]*2,
|
||||
layer_scale_init_value=None,
|
||||
res_scale_init_value=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.grad_checkpointing = False
|
||||
|
||||
self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_layer=downsample_norm
|
||||
)
|
||||
|
||||
self.blocks = nn.Sequential(*[MetaFormerBlock(
|
||||
dim=out_chs,
|
||||
token_mixer=token_mixer,
|
||||
mlp=mlp,
|
||||
mlp_fn=mlp_fn,
|
||||
mlp_act=mlp_act,
|
||||
mlp_bias=mlp_bias,
|
||||
norm_layer=norm_layer,
|
||||
drop_path=dp_rates[i],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
res_scale_init_value=res_scale_init_value
|
||||
) for i in range(depth)])
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x = self.downsample(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
x = self.blocks(x)
|
||||
return x
|
||||
|
||||
class MetaFormer(nn.Module):
|
||||
r""" MetaFormer
|
||||
A PyTorch impl of : `MetaFormer Baselines for Vision` -
|
||||
|
@ -447,7 +505,7 @@ class MetaFormer(nn.Module):
|
|||
token_mixers=nn.Identity,
|
||||
mlps=Mlp,
|
||||
mlp_fn=nn.Linear,
|
||||
mlp_act = StarReLU,
|
||||
mlp_act=StarReLU,
|
||||
mlp_bias=False,
|
||||
norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
|
||||
drop_path_rate=0.,
|
||||
|
@ -491,7 +549,7 @@ class MetaFormer(nn.Module):
|
|||
self.grad_checkpointing = False
|
||||
self.feature_info = []
|
||||
|
||||
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
||||
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
|
||||
|
||||
self.stem = Stem(
|
||||
|
@ -502,7 +560,9 @@ class MetaFormer(nn.Module):
|
|||
|
||||
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
||||
cur = 0
|
||||
last_dim = dims[0]
|
||||
for i in range(self.num_stages):
|
||||
'''
|
||||
stage = nn.Sequential(OrderedDict([
|
||||
('downsample', nn.Identity() if i == 0 else Downsampling(
|
||||
dims[i-1],
|
||||
|
@ -526,8 +586,27 @@ class MetaFormer(nn.Module):
|
|||
) for j in range(depths[i])])
|
||||
)])
|
||||
)
|
||||
'''
|
||||
|
||||
stage = MetaFormerStage(
|
||||
dim,
|
||||
dims[i],
|
||||
depth=depths[i],
|
||||
downsample_norm=downsample_norm,
|
||||
token_mixer=token_mixers[i],
|
||||
mlp=mlps[i],
|
||||
mlp_fn=mlp_fn,
|
||||
mlp_act=mlp_act,
|
||||
mlp_bias=mlp_bias,
|
||||
norm_layer=norm_layers[i],
|
||||
dp_rates=dp_rates[i],
|
||||
layer_scale_init_value=layer_scale_init_values[i],
|
||||
res_scale_init_value=res_scale_init_values[i],
|
||||
)
|
||||
|
||||
stages.append(stage)
|
||||
cur += depths[i]
|
||||
last_dim = dims[i]
|
||||
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
|
||||
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
@ -596,7 +675,7 @@ class MetaFormer(nn.Module):
|
|||
x = self.norm_pre(x).permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: Tensor):
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
|
Loading…
Reference in New Issue