Update metaformers.py

pull/1647/head
Fredo Guan 2023-02-10 00:24:58 -08:00
parent 10bde717e5
commit f938beb81b
1 changed file with 82 additions and 3 deletions


@@ -24,8 +24,11 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1
from timm.layers.helpers import to_2tuple
@@ -415,6 +418,61 @@ class MetaFormerBlock(nn.Module):
        return x


class MetaFormerStage(nn.Module):

    def __init__(
            self,
            in_chs,
            out_chs,
            depth=2,
            downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
            token_mixer=nn.Identity,
            mlp=Mlp,
            mlp_fn=nn.Linear,
            mlp_act=StarReLU,
            mlp_bias=False,
            norm_layer=partial(LayerNormGeneral, eps=1e-6, bias=False),
            dp_rates=[0.] * 2,
            layer_scale_init_value=None,
            res_scale_init_value=None,
    ):
        super().__init__()

        self.grad_checkpointing = False

        # downsample between stages unless the channel count is unchanged
        self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
            in_chs,
            out_chs,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_layer=downsample_norm,
        )

        # stack of identically configured metaformer blocks, one drop-path rate per block
        self.blocks = nn.Sequential(*[MetaFormerBlock(
            dim=out_chs,
            token_mixer=token_mixer,
            mlp=mlp,
            mlp_fn=mlp_fn,
            mlp_act=mlp_act,
            mlp_bias=mlp_bias,
            norm_layer=norm_layer,
            drop_path=dp_rates[i],
            layer_scale_init_value=layer_scale_init_value,
            res_scale_init_value=res_scale_init_value,
        ) for i in range(depth)])

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x: Tensor):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
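A minimal sketch (not part of the diff) of how the new MetaFormerStage might be exercised standalone. It assumes the surrounding metaformers.py definitions (Downsampling, MetaFormerBlock, StarReLU, LayerNormGeneral, checkpoint_seq) are in scope, and that stages see channels-last input, as the final `.permute(0, 3, 1, 2)` in forward_features later in this diff suggests; the channel sizes here are illustrative only.

    # Hedged usage sketch -- illustrative sizes, not from the commit.
    import torch

    stage = MetaFormerStage(
        64,                     # in_chs
        128,                    # out_chs != in_chs, so Downsampling is used
        depth=2,
        dp_rates=[0.0, 0.1],    # one drop-path rate per block
    )
    x = torch.randn(2, 56, 56, 64)   # (B, H, W, C), assuming channels-last between stages
    out = stage(x)
    print(out.shape)                 # spatial dims halved by the stride-2 downsample

    # gradient checkpointing can be toggled per stage
    stage.set_grad_checkpointing(True)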
class MetaFormer(nn.Module):
    r""" MetaFormer
        A PyTorch impl of : `MetaFormer Baselines for Vision` -
@@ -447,7 +505,7 @@ class MetaFormer(nn.Module):
            token_mixers=nn.Identity,
            mlps=Mlp,
            mlp_fn=nn.Linear,
-           mlp_act = StarReLU,
+           mlp_act=StarReLU,
            mlp_bias=False,
            norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
            drop_path_rate=0.,
@@ -491,7 +549,7 @@ class MetaFormer(nn.Module):
        self.grad_checkpointing = False
        self.feature_info = []
-       dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+       dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        self.stem = Stem(
@@ -502,7 +560,9 @@ class MetaFormer(nn.Module):
        stages = nn.ModuleList()  # each stage consists of multiple metaformer blocks
        cur = 0
        last_dim = dims[0]
        for i in range(self.num_stages):
            '''
            stage = nn.Sequential(OrderedDict([
                ('downsample', nn.Identity() if i == 0 else Downsampling(
                    dims[i-1],
@@ -526,8 +586,27 @@ class MetaFormer(nn.Module):
                ) for j in range(depths[i])])
                )])
            )
            '''
            stage = MetaFormerStage(
                last_dim,
                dims[i],
                depth=depths[i],
                downsample_norm=downsample_norm,
                token_mixer=token_mixers[i],
                mlp=mlps[i],
                mlp_fn=mlp_fn,
                mlp_act=mlp_act,
                mlp_bias=mlp_bias,
                norm_layer=norm_layers[i],
                dp_rates=dp_rates[i],
                layer_scale_init_value=layer_scale_init_values[i],
                res_scale_init_value=res_scale_init_values[i],
            )
            stages.append(stage)
            cur += depths[i]
            last_dim = dims[i]
            self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)
@@ -596,7 +675,7 @@ class MetaFormer(nn.Module):
        x = self.norm_pre(x).permute(0, 3, 1, 2)
        return x

-   def forward(self, x):
+   def forward(self, x: Tensor):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
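A small sketch (not part of the commit) of what the dp_rates change above does: the old expression built one flat list of drop-path rates indexed with a running offset, while the new `.split(depths)` groups the same linspace into one sub-list per stage, so each MetaFormerStage can take `dp_rates[i]` directly. The `depths` and `drop_path_rate` values below are illustrative.

    # Hedged sketch of the dp_rates change -- depths/drop_path_rate are illustrative.
    import torch

    depths = [2, 2, 6, 2]
    drop_path_rate = 0.3

    # old: one flat list of rates across all blocks
    flat = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

    # new: one sub-list per stage, matching each stage's block count
    per_stage = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]

    print(len(flat))                    # 12 rates in one list
    print([len(s) for s in per_stage])  # [2, 2, 6, 2] -- one sub-list per stage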