diff --git a/timm/models/convnext.py b/timm/models/convnext.py index ba63a453..15000b40 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -19,7 +19,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\ +from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple from .registry import register_model @@ -161,7 +161,7 @@ class ConvNeXtBlock(nn.Module): out_chs = out_chs or in_chs act_layer = get_act_layer(act_layer) if not norm_layer: - norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer = LayerNorm2d if conv_mlp else LayerNorm mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp @@ -291,8 +291,8 @@ class ConvNeXt(nn.Module): assert output_stride in (8, 16, 32) kernel_sizes = to_ntuple(4)(kernel_sizes) if norm_layer is None: - norm_layer = partial(LayerNorm2d, eps=1e-6) - norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer = LayerNorm2d + norm_layer_cl = norm_layer if conv_mlp else LayerNorm else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'