Mirror of https://github.com/huggingface/pytorch-image-models.git

Update metaformers.py
commit f938beb81b (parent 10bde717e5)
@@ -24,8 +24,11 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
 
 from collections import OrderedDict
 from functools import partial
 
 import torch
 import torch.nn as nn
+from torch import Tensor
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1
 from timm.layers.helpers import to_2tuple
@@ -415,6 +418,61 @@ class MetaFormerBlock(nn.Module):
 
         return x
 
+class MetaFormerStage(nn.Module):
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            depth=2,
+            downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
+            token_mixer=nn.Identity,
+            mlp=Mlp,
+            mlp_fn=nn.Linear,
+            mlp_act=StarReLU,
+            mlp_bias=False,
+            norm_layer=partial(LayerNormGeneral, eps=1e-6, bias=False),
+            dp_rates=[0.]*2,
+            layer_scale_init_value=None,
+            res_scale_init_value=None,
+    ):
+        super().__init__()
+
+        self.grad_checkpointing = False
+
+        self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            norm_layer=downsample_norm
+        )
+
+        self.blocks = nn.Sequential(*[MetaFormerBlock(
+            dim=out_chs,
+            token_mixer=token_mixer,
+            mlp=mlp,
+            mlp_fn=mlp_fn,
+            mlp_act=mlp_act,
+            mlp_bias=mlp_bias,
+            norm_layer=norm_layer,
+            drop_path=dp_rates[i],
+            layer_scale_init_value=layer_scale_init_value,
+            res_scale_init_value=res_scale_init_value
+        ) for i in range(depth)])
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    def forward(self, x: Tensor):
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
+        return x
+
 class MetaFormer(nn.Module):
     r""" MetaFormer
         A PyTorch impl of : `MetaFormer Baselines for Vision` -
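Note: the new MetaFormerStage bundles the optional stride-2 Downsampling with the stage's MetaFormerBlock sequence and owns its own grad-checkpointing flag. A minimal smoke-test sketch, assuming metaformers.py is importable as timm.models.metaformers (the module path is an assumption) and that stages consume the channels-last (B, H, W, C) layout the upstream sail-sg code uses between stages:

    import torch
    from timm.models.metaformers import MetaFormerStage  # path is an assumption

    # in_chs != out_chs, so the stage prepends a stride-2 Downsampling
    stage = MetaFormerStage(in_chs=64, out_chs=128, depth=2)
    x = torch.randn(2, 56, 56, 64)        # (B, H, W, C), layout assumed
    y = stage(x)
    print(y.shape)                        # expected roughly (2, 28, 28, 128)

    stage.set_grad_checkpointing(True)    # checkpoint the blocks during training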
@@ -447,7 +505,7 @@ class MetaFormer(nn.Module):
             token_mixers=nn.Identity,
             mlps=Mlp,
             mlp_fn=nn.Linear,
-            mlp_act = StarReLU,
+            mlp_act=StarReLU,
             mlp_bias=False,
             norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
             drop_path_rate=0.,
@@ -491,7 +549,7 @@ class MetaFormer(nn.Module):
         self.grad_checkpointing = False
         self.feature_info = []
 
-        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
 
 
         self.stem = Stem(
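The drop-path change above swaps the flat per-block rate list for a nested one: torch.linspace still yields one rate per block across the whole network, and split(depths) regroups those rates per stage, which is the shape the new MetaFormerStage(dp_rates=dp_rates[i]) call expects. A quick illustration with made-up settings:

    import torch

    depths = (2, 2, 6, 2)             # example stage depths
    drop_path_rate = 0.1
    dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    print(len(dp_rates))              # 4 sub-lists, one per stage
    print(dp_rates[0])                # [0.0, 0.00909...] for the first stage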
@@ -502,7 +560,9 @@ class MetaFormer(nn.Module):
 
         stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
         cur = 0
+        last_dim = dims[0]
         for i in range(self.num_stages):
+            '''
             stage = nn.Sequential(OrderedDict([
                 ('downsample', nn.Identity() if i == 0 else Downsampling(
                     dims[i-1],
@@ -526,8 +586,27 @@ class MetaFormer(nn.Module):
                 ) for j in range(depths[i])])
                 )])
             )
+            '''
+
+            stage = MetaFormerStage(
+                last_dim,
+                dims[i],
+                depth=depths[i],
+                downsample_norm=downsample_norm,
+                token_mixer=token_mixers[i],
+                mlp=mlps[i],
+                mlp_fn=mlp_fn,
+                mlp_act=mlp_act,
+                mlp_bias=mlp_bias,
+                norm_layer=norm_layers[i],
+                dp_rates=dp_rates[i],
+                layer_scale_init_value=layer_scale_init_values[i],
+                res_scale_init_value=res_scale_init_values[i],
+            )
+
             stages.append(stage)
             cur += depths[i]
+            last_dim = dims[i]
             self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
 
         self.stages = nn.Sequential(*stages)
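In the rewritten loop, last_dim threads the previous stage's width into the next stage's in_chs (the first positional argument was rendered as the undefined name dim in the page above and is written as last_dim here). On the first stage last_dim == dims[0], so in_chs == out_chs and MetaFormerStage resolves its downsample to nn.Identity(), reproducing the old i == 0 special case. A sketch of the threading with example widths:

    dims = (64, 128, 320, 512)                # example stage widths
    last_dim = dims[0]
    for i in range(len(dims)):
        in_chs, out_chs = last_dim, dims[i]   # stage 0: 64 -> 64, no downsample
        print(i, in_chs, out_chs)
        last_dim = dims[i]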
@@ -596,7 +675,7 @@ class MetaFormer(nn.Module):
         x = self.norm_pre(x).permute(0, 3, 1, 2)
         return x
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
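The annotated signature pairs with the new torch Tensor import: explicit types keep forward() self-documenting and TorchScript-friendly. A tiny analogue, using a hypothetical module purely for illustration:

    import torch
    from torch import Tensor

    class Passthrough(torch.nn.Module):      # hypothetical, for illustration
        def forward(self, x: Tensor) -> Tensor:
            return x

    scripted = torch.jit.script(Passthrough())
    print(scripted(torch.ones(3)).shape)     # torch.Size([3])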