update norm eps for efficientvit large

pull/2034/head
方曦 2023-11-18 17:46:47 +08:00
parent 87ba43a9bc
commit c9d093a58e
1 changed file with 15 additions and 20 deletions

View File

@@ -8,6 +8,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
__all__ = ['EfficientVit']
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
@@ -15,7 +16,7 @@ import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, create_conv2d
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq
@@ -71,10 +72,7 @@ class ConvNormAct(nn.Module):
bias=bias,
)
self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
if act_layer is not None:
self.act = act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh")
else:
self.act = nn.Identity()
self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()
def forward(self, x):
x = self.dropout(x)
@@ -641,14 +639,15 @@ class ClassifierHead(nn.Module):
norm_layer=nn.BatchNorm2d,
act_layer=nn.Hardswish,
global_pool='avg',
norm_eps=1e-5,
):
super(ClassifierHead, self).__init__()
self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
self.classifier = nn.Sequential(
nn.Linear(widths[0], widths[1], bias=False),
nn.LayerNorm(widths[1]),
act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh"),
nn.LayerNorm(widths[1], eps=norm_eps),
act_layer(inplace=True) if act_layer is not None else nn.Identity(),
nn.Dropout(dropout, inplace=False),
nn.Linear(widths[1], n_classes, bias=True),
)
@@ -784,17 +783,19 @@ class EfficientVitLarge(nn.Module):
depths=(),
head_dim=32,
norm_layer=nn.BatchNorm2d,
act_layer=nn.GELU,
act_layer=GELUTanh,
global_pool='avg',
head_widths=(),
drop_rate=0.0,
num_classes=1000,
eps=1e-7,
norm_eps=1e-7,
):
super(EfficientVitLarge, self).__init__()
self.grad_checkpointing = False
self.global_pool = global_pool
self.num_classes = num_classes
self.norm_eps = norm_eps
norm_layer = partial(norm_layer, eps=self.norm_eps)
# input stem
self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large')
@@ -830,20 +831,13 @@ class EfficientVitLarge(nn.Module):
dropout=self.head_dropout,
global_pool=self.global_pool,
act_layer=act_layer,
norm_eps=self.norm_eps,
)
else:
if self.global_pool == 'avg':
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
else:
self.head = nn.Identity()
self.set_norm_eps(eps)
@torch.jit.ignore
def set_norm_eps(self, eps):
for m in self.modules():
if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
if eps is not None:
m.eps = eps
@torch.jit.ignore
def group_matcher(self, coarse=False):
@@ -875,6 +869,7 @@ class EfficientVitLarge(nn.Module):
n_classes=num_classes,
dropout=self.head_dropout,
global_pool=self.global_pool,
norm_eps=self.norm_eps
)
else:
if self.global_pool == 'avg':
@@ -1056,19 +1051,19 @@ def efficientvit_l3(pretrained=False, **kwargs):
@register_model
def efficientvit_l0_sam(pretrained=False, **kwargs):
model_args = dict(
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights
return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_l1_sam(pretrained=False, **kwargs):
model_args = dict(
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights
return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_l2_sam(pretrained=False, **kwargs):
model_args = dict(
widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0) # only backbone for segment-anything-model weights
widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6) # only backbone for segment-anything-model weights
return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs))