update norm eps for efficientvit large

parent 87ba43a9bc
commit c9d093a58e
@@ -8,6 +8,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
 __all__ = ['EfficientVit']
 from typing import Optional
+from functools import partial

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -15,7 +16,7 @@ import torch.nn.functional as F
 from torch.nn.modules.batchnorm import _BatchNorm

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import SelectAdaptivePool2d, create_conv2d
+from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
@@ -71,10 +72,7 @@ class ConvNormAct(nn.Module):
             bias=bias,
         )
         self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
-        if act_layer is not None:
-            self.act = act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh")
-        else:
-            self.act = nn.Identity()
+        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()

     def forward(self, x):
         x = self.dropout(x)
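The removed nn.GELU special case is now handled by passing timm's GELUTanh as act_layer (imported in the hunk above). A minimal sketch of the idea, assuming GELUTanh simply wraps the tanh-approximate GELU; the real timm layer may differ in details:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GELUTanhSketch(nn.Module):
    """Sketch of a tanh-approximate GELU module (stand-in for timm.layers.GELUTanh)."""

    def __init__(self, inplace: bool = False):
        super().__init__()
        # `inplace` is accepted only so the generic `act_layer(inplace=True)` call in
        # ConvNormAct still works; F.gelu has no in-place variant, so it is ignored.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.gelu(x, approximate='tanh')
```

With an activation class like this, ConvNormAct no longer needs to branch on `act_layer is nn.GELU`; the single `act_layer(inplace=True) if act_layer is not None else nn.Identity()` line covers every case.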
@@ -641,14 +639,15 @@ class ClassifierHead(nn.Module):
             norm_layer=nn.BatchNorm2d,
             act_layer=nn.Hardswish,
             global_pool='avg',
+            norm_eps=1e-5,
     ):
         super(ClassifierHead, self).__init__()
         self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
         self.classifier = nn.Sequential(
             nn.Linear(widths[0], widths[1], bias=False),
-            nn.LayerNorm(widths[1]),
-            act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh"),
+            nn.LayerNorm(widths[1], eps=norm_eps),
+            act_layer(inplace=True) if act_layer is not None else nn.Identity(),
             nn.Dropout(dropout, inplace=False),
             nn.Linear(widths[1], n_classes, bias=True),
         )
@@ -784,17 +783,19 @@ class EfficientVitLarge(nn.Module):
             depths=(),
             head_dim=32,
             norm_layer=nn.BatchNorm2d,
-            act_layer=nn.GELU,
+            act_layer=GELUTanh,
             global_pool='avg',
             head_widths=(),
             drop_rate=0.0,
             num_classes=1000,
-            eps=1e-7,
+            norm_eps=1e-7,
     ):
         super(EfficientVitLarge, self).__init__()
         self.grad_checkpointing = False
         self.global_pool = global_pool
         self.num_classes = num_classes
+        self.norm_eps = norm_eps
+        norm_layer = partial(norm_layer, eps=self.norm_eps)

         # input stem
         self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large')
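Binding eps into norm_layer with functools.partial means every norm created while building the network is constructed with the requested value up front, which is what allows the post-hoc set_norm_eps() walk in the next hunk to be dropped. A minimal, self-contained illustration of the same pattern:

```python
from functools import partial

import torch.nn as nn

# Every BatchNorm produced by this factory is born with eps=1e-7,
# so no later pass over model.modules() is needed to patch m.eps.
norm_layer = partial(nn.BatchNorm2d, eps=1e-7)

bn = norm_layer(num_features=64)
assert bn.eps == 1e-7
```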
@@ -830,20 +831,13 @@ class EfficientVitLarge(nn.Module):
                 dropout=self.head_dropout,
                 global_pool=self.global_pool,
                 act_layer=act_layer,
+                norm_eps=self.norm_eps,
             )
         else:
             if self.global_pool == 'avg':
                 self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
             else:
                 self.head = nn.Identity()
-        self.set_norm_eps(eps)
-
-    @torch.jit.ignore
-    def set_norm_eps(self, eps):
-        for m in self.modules():
-            if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
-                if eps is not None:
-                    m.eps = eps

     @torch.jit.ignore
     def group_matcher(self, coarse=False):
@@ -875,6 +869,7 @@ class EfficientVitLarge(nn.Module):
                 n_classes=num_classes,
                 dropout=self.head_dropout,
                 global_pool=self.global_pool,
+                norm_eps=self.norm_eps
             )
         else:
             if self.global_pool == 'avg':
@@ -1056,19 +1051,19 @@ def efficientvit_l3(pretrained=False, **kwargs):
 @register_model
 def efficientvit_l0_sam(pretrained=False, **kwargs):
     model_args = dict(
-        widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0)  # only backbone for segment-anything-model weights
+        widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6)  # only backbone for segment-anything-model weights
     return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def efficientvit_l1_sam(pretrained=False, **kwargs):
     model_args = dict(
-        widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0)  # only backbone for segment-anything-model weights
+        widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6)  # only backbone for segment-anything-model weights
     return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def efficientvit_l2_sam(pretrained=False, **kwargs):
     model_args = dict(
-        widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0)  # only backbone for segment-anything-model weights
+        widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6)  # only backbone for segment-anything-model weights
     return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs))
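With norm_eps exposed as a model arg, the SAM backbones pin their norms to 1e-6 at construction time. A quick check along these lines (untested sketch; assumes a timm version that includes this commit):

```python
import timm
from torch.nn.modules.batchnorm import _BatchNorm

# Build the SAM backbone without weights; num_classes=0 per the model args above.
model = timm.create_model('efficientvit_l0_sam', pretrained=False)

# Norms built through the eps-bound norm_layer should report eps=1e-6.
eps_values = {m.eps for m in model.modules() if isinstance(m, _BatchNorm)}
print(eps_values)  # expected: {1e-06}
```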