diff --git a/timm/layers/create_norm.py b/timm/layers/create_norm.py index 74e893d8..75262b5e 100644 --- a/timm/layers/create_norm.py +++ b/timm/layers/create_norm.py @@ -10,7 +10,7 @@ from typing import Type import torch.nn as nn -from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d +from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d from torchvision.ops.misc import FrozenBatchNorm2d _NORM_MAP = dict( @@ -23,6 +23,8 @@ _NORM_MAP = dict( layernorm2d=LayerNorm2d, rmsnorm=RmsNorm, rmsnorm2d=RmsNorm2d, + simplenorm=SimpleNorm, + simplenorm2d=SimpleNorm2d, frozenbatchnorm2d=FrozenBatchNorm2d, ) _NORM_TYPES = {m for n, m in _NORM_MAP.items()}