mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
For ConvNeXt, use timm internal LayerNorm for fast_norm in non conv_mlp mode
This commit is contained in:
parent
cac0a4570a
commit
837c68263b
@ -19,7 +19,7 @@ import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
|
||||
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
|
||||
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
|
||||
create_conv2d, get_act_layer, make_divisible, to_ntuple
|
||||
from .registry import register_model
|
||||
|
||||
@ -161,7 +161,7 @@ class ConvNeXtBlock(nn.Module):
|
||||
out_chs = out_chs or in_chs
|
||||
act_layer = get_act_layer(act_layer)
|
||||
if not norm_layer:
|
||||
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
|
||||
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
|
||||
mlp_layer = ConvMlp if conv_mlp else Mlp
|
||||
self.use_conv_mlp = conv_mlp
|
||||
|
||||
@ -291,8 +291,8 @@ class ConvNeXt(nn.Module):
|
||||
assert output_stride in (8, 16, 32)
|
||||
kernel_sizes = to_ntuple(4)(kernel_sizes)
|
||||
if norm_layer is None:
|
||||
norm_layer = partial(LayerNorm2d, eps=1e-6)
|
||||
norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
|
||||
norm_layer = LayerNorm2d
|
||||
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
|
||||
else:
|
||||
assert conv_mlp,\
|
||||
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
|
||||
|
Loading…
x
Reference in New Issue
Block a user