From c9d093a58ed6927923902ec831f1925634ec3c38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=96=B9=E6=9B=A6?=
Date: Sat, 18 Nov 2023 17:46:47 +0800
Subject: [PATCH] update norm eps for efficientvit large

---
 timm/models/efficientvit_mit.py | 35 ++++++++++++++-------------------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py
index c38a6f42..b464fc88 100644
--- a/timm/models/efficientvit_mit.py
+++ b/timm/models/efficientvit_mit.py
@@ -8,6 +8,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
 
 __all__ = ['EfficientVit']
 from typing import Optional
+from functools import partial
 
 import torch
 import torch.nn as nn
@@ -15,7 +16,7 @@ import torch.nn.functional as F
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import SelectAdaptivePool2d, create_conv2d
+from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
@@ -71,10 +72,7 @@ class ConvNormAct(nn.Module):
             bias=bias,
         )
         self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
-        if act_layer is not None:
-            self.act = act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh")
-        else:
-            self.act = nn.Identity()
+        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()
 
     def forward(self, x):
         x = self.dropout(x)
@@ -641,14 +639,15 @@ class ClassifierHead(nn.Module):
             norm_layer=nn.BatchNorm2d,
             act_layer=nn.Hardswish,
             global_pool='avg',
+            norm_eps=1e-5,
     ):
         super(ClassifierHead, self).__init__()
         self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
         self.classifier = nn.Sequential(
             nn.Linear(widths[0], widths[1], bias=False),
-            nn.LayerNorm(widths[1]),
-            act_layer(inplace=True) if act_layer is not nn.GELU else act_layer(approximate="tanh"),
+            nn.LayerNorm(widths[1], eps=norm_eps),
+            act_layer(inplace=True) if act_layer is not None else nn.Identity(),
             nn.Dropout(dropout, inplace=False),
             nn.Linear(widths[1], n_classes, bias=True),
         )
@@ -784,17 +783,19 @@ class EfficientVitLarge(nn.Module):
             depths=(),
             head_dim=32,
             norm_layer=nn.BatchNorm2d,
-            act_layer=nn.GELU,
+            act_layer=GELUTanh,
             global_pool='avg',
             head_widths=(),
             drop_rate=0.0,
             num_classes=1000,
-            eps=1e-7,
+            norm_eps=1e-7,
     ):
         super(EfficientVitLarge, self).__init__()
         self.grad_checkpointing = False
         self.global_pool = global_pool
         self.num_classes = num_classes
+        self.norm_eps = norm_eps
+        norm_layer = partial(norm_layer, eps=self.norm_eps)
 
         # input stem
         self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large')
@@ -830,20 +831,13 @@ class EfficientVitLarge(nn.Module):
                 dropout=self.head_dropout,
                 global_pool=self.global_pool,
                 act_layer=act_layer,
+                norm_eps=self.norm_eps,
             )
         else:
             if self.global_pool == 'avg':
                 self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
             else:
                 self.head = nn.Identity()
-        self.set_norm_eps(eps)
-
-    @torch.jit.ignore
-    def set_norm_eps(self, eps):
-        for m in self.modules():
-            if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
-                if eps is not None:
-                    m.eps = eps
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
@@ -875,6 +869,7 @@ class EfficientVitLarge(nn.Module):
                 n_classes=num_classes,
                 dropout=self.head_dropout,
                 global_pool=self.global_pool,
+                norm_eps=self.norm_eps
             )
         else:
             if self.global_pool == 'avg':
@@ -1056,19 +1051,19 @@ def efficientvit_l3(pretrained=False, **kwargs):
 @register_model
 def efficientvit_l0_sam(pretrained=False, **kwargs):
     model_args = dict(
-        widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0)  # only backbone for segment-anything-model weights
+        widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6)  # only backbone for segment-anything-model weights
     return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def efficientvit_l1_sam(pretrained=False, **kwargs):
     model_args = dict(
-        widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0)  # only backbone for segment-anything-model weights
+        widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6)  # only backbone for segment-anything-model weights
     return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def efficientvit_l2_sam(pretrained=False, **kwargs):
     model_args = dict(
-        widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0)  # only backbone for segment-anything-model weights
+        widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6)  # only backbone for segment-anything-model weights
     return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs))
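
Not part of the patch, just a quick sanity check of the change. The snippet below assumes this patch is applied on top of timm and uses the existing efficientvit_l1 entrypoint from this file. It confirms that the backbone norm layers are now constructed with eps=1e-7 up front via functools.partial, replacing the removed set_norm_eps() post-construction pass; the classifier head's LayerNorm picks up the same value through the new norm_eps argument.

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

import timm

# Build one of the L-series classification models (default norm_eps=1e-7).
model = timm.create_model('efficientvit_l1', pretrained=False)

# Backbone norms should carry the constructor-time eps; no post-hoc patching.
backbone_eps = {
    m.eps for n, m in model.named_modules()
    if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)) and not n.startswith('head')
}
print(backbone_eps)  # expected: {1e-07}

# Smoke test the forward pass.
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])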