Add SimpleNorm to create_norm factory

This commit is contained in:
Ross Wightman 2024-12-30 14:22:42 -08:00
parent 5809c2fe5e
commit a4146b79d1

View File

@ -10,7 +10,7 @@ from typing import Type
import torch.nn as nn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d
_NORM_MAP = dict(
@ -23,6 +23,8 @@ _NORM_MAP = dict(
layernorm2d=LayerNorm2d,
rmsnorm=RmsNorm,
rmsnorm2d=RmsNorm2d,
simplenorm=SimpleNorm,
simplenorm2d=SimpleNorm2d,
frozenbatchnorm2d=FrozenBatchNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}