Update metaformers.py

Fredo Guan 2023-02-10 00:24:58 -08:00
parent 10bde717e5
commit f938beb81b


@@ -24,8 +24,11 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
 from collections import OrderedDict
 from functools import partial
 
 import torch
 import torch.nn as nn
+from torch import Tensor
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1
 from timm.layers.helpers import to_2tuple
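The new `from torch import Tensor` import backs the `x: Tensor` annotations this commit adds to the `forward` methods below.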
@@ -415,6 +418,61 @@ class MetaFormerBlock(nn.Module):
         return x
 
 
+class MetaFormerStage(nn.Module):
+
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            depth=2,
+            downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
+            token_mixer=nn.Identity,
+            mlp=Mlp,
+            mlp_fn=nn.Linear,
+            mlp_act=StarReLU,
+            mlp_bias=False,
+            norm_layer=partial(LayerNormGeneral, eps=1e-6, bias=False),
+            dp_rates=[0.] * 2,
+            layer_scale_init_value=None,
+            res_scale_init_value=None,
+    ):
+        super().__init__()
+
+        self.grad_checkpointing = False
+
+        # no-op when the channel count is unchanged, strided reduction otherwise
+        self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            norm_layer=downsample_norm,
+        )
+
+        self.blocks = nn.Sequential(*[MetaFormerBlock(
+            dim=out_chs,
+            token_mixer=token_mixer,
+            mlp=mlp,
+            mlp_fn=mlp_fn,
+            mlp_act=mlp_act,
+            mlp_bias=mlp_bias,
+            norm_layer=norm_layer,
+            drop_path=dp_rates[i],
+            layer_scale_init_value=layer_scale_init_value,
+            res_scale_init_value=res_scale_init_value,
+        ) for i in range(depth)])
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    def forward(self, x: Tensor):
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
+        return x
+
 
 class MetaFormer(nn.Module):
     r""" MetaFormer
     A PyTorch impl of : `MetaFormer Baselines for Vision` -
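The new stage module makes the downsample-plus-blocks pattern easy to check in isolation. Below is a minimal sketch of the same pattern with stand-in conv blocks instead of the real MetaFormerBlock and a whole-sequence checkpoint instead of timm's checkpoint_seq; channel counts and input size are illustrative only:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyStage(nn.Module):
    """Stand-in for MetaFormerStage: optional strided downsample, then blocks."""

    def __init__(self, in_chs, out_chs, depth=2):
        super().__init__()
        self.grad_checkpointing = False
        # Identity when channels match, otherwise a stride-2 reduction,
        # mirroring the kernel_size=3 / stride=2 / padding=1 Downsampling above.
        self.downsample = nn.Identity() if in_chs == out_chs else nn.Conv2d(
            in_chs, out_chs, kernel_size=3, stride=2, padding=1)
        self.blocks = nn.Sequential(
            *[nn.Conv2d(out_chs, out_chs, 3, padding=1) for _ in range(depth)])

    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Recompute block activations during backward to save memory.
            x = checkpoint(self.blocks, x, use_reentrant=False)
        else:
            x = self.blocks(x)
        return x


stage = ToyStage(64, 128, depth=2)
print(stage(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])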
@@ -447,7 +505,7 @@ class MetaFormer(nn.Module):
             token_mixers=nn.Identity,
             mlps=Mlp,
             mlp_fn=nn.Linear,
-            mlp_act = StarReLU,
+            mlp_act=StarReLU,
             mlp_bias=False,
             norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
             drop_path_rate=0.,
@@ -491,7 +549,7 @@ class MetaFormer(nn.Module):
         self.grad_checkpointing = False
         self.feature_info = []
 
-        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
 
         self.stem = Stem(
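The schedule change replaces one flat list of per-block drop-path rates with one list per stage: `Tensor.split(depths)` chunks the linspace by stage depth, so each MetaFormerStage can index its rates locally. A standalone check, with assumed example values:

import torch

depths = (2, 2, 6, 2)      # assumed per-stage block counts, for illustration
drop_path_rate = 0.1

# One linearly increasing rate per block over the whole network,
# then chunked by stage depth.
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]

print([len(r) for r in dp_rates])   # [2, 2, 6, 2]
print(dp_rates[0])                  # first stage's rates, starting at 0.0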
@@ -502,7 +560,9 @@
         stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
 
         cur = 0
+        last_dim = dims[0]
         for i in range(self.num_stages):
+            '''
             stage = nn.Sequential(OrderedDict([
                 ('downsample', nn.Identity() if i == 0 else Downsampling(
                     dims[i-1],
@@ -526,8 +586,27 @@
                 ) for j in range(depths[i])])
                 )])
             )
+            '''
+            stage = MetaFormerStage(
+                last_dim,
+                dims[i],
+                depth=depths[i],
+                downsample_norm=downsample_norm,
+                token_mixer=token_mixers[i],
+                mlp=mlps[i],
+                mlp_fn=mlp_fn,
+                mlp_act=mlp_act,
+                mlp_bias=mlp_bias,
+                norm_layer=norm_layers[i],
+                dp_rates=dp_rates[i],
+                layer_scale_init_value=layer_scale_init_values[i],
+                res_scale_init_value=res_scale_init_values[i],
+            )
+
             stages.append(stage)
             cur += depths[i]
+            last_dim = dims[i]
             self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
 
         self.stages = nn.Sequential(*stages)
@@ -596,7 +675,7 @@ class MetaFormer(nn.Module):
         x = self.norm_pre(x).permute(0, 3, 1, 2)
         return x
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
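With `set_grad_checkpointing` defined on the stage as well as the model, toggling activation checkpointing after construction stays a one-liner. A hedged usage sketch, assuming `model` is a constructed MetaFormer instance and the model-level toggle propagates to the stages (not shown in this diff), with the usual 224x224 input:

# `model` is assumed to be a built MetaFormer instance.
model.set_grad_checkpointing(True)    # trade recompute for activation memory
out = model(torch.randn(1, 3, 224, 224))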